# Script by https://github.com/ProGamerGov
"""Replace the VAE weights inside a model checkpoint with those of a VAE checkpoint.

Every tensor from the VAE checkpoint (except keys starting with ``"loss"`` or
``"mode"``) is copied into the model checkpoint under the
``first_stage_model.`` prefix, and the result is saved under a new file name.
"""
import copy
import warnings

import torch

# Path to model and VAE files that you want to merge
vae_file_path = "vae-ft-mse-840000-ema-pruned.ckpt"
model_file_path = "v1-5-pruned-emaonly.ckpt"

# Name to use for new model file
new_model_name = "v1-5-pruned-emaonly_ema_vae.ckpt"

# VAE-checkpoint key prefixes that are not copied into the model. "mode" matches
# any key beginning with "model" (presumably EMA bookkeeping such as
# "model_ema.*" — verify against the checkpoint you are merging).
_SKIPPED_PREFIXES = ("loss", "mode")


def _state_dict(checkpoint):
    """Return the weights mapping of *checkpoint*.

    Checkpoints that wrap their weights in a ``"state_dict"`` entry yield that
    entry; a bare state dict is returned as-is.
    """
    return checkpoint.get("state_dict", checkpoint)


def merge_vae(full_model, vae_model, prefix="first_stage_model."):
    """Copy the VAE weights of *vae_model* into *full_model* in place.

    Args:
        full_model: Model checkpoint (wrapped or bare state dict) to modify.
        vae_model: VAE checkpoint (wrapped or bare state dict) to copy from.
        prefix: Namespace the VAE weights live under inside *full_model*.

    Returns:
        *full_model*, mutated in place.

    Warns:
        UserWarning: if some VAE keys had no existing counterpart in
        *full_model*; they are still added, but this usually means the VAE
        does not match the model.
    """
    target = _state_dict(full_model)
    unmatched = []
    for key, tensor in _state_dict(vae_model).items():
        if key.startswith(_SKIPPED_PREFIXES):
            continue
        new_key = prefix + key
        if new_key not in target:
            unmatched.append(key)
        # Deep copy so the merged checkpoint shares no storage with vae_model.
        target[new_key] = copy.deepcopy(tensor)
    if unmatched:
        warnings.warn(
            f"{len(unmatched)} VAE key(s) had no counterpart in the model "
            f"checkpoint and were added: {unmatched}",
            stacklevel=2,
        )
    return full_model


def main(vae_path=vae_file_path, model_path=model_file_path, output_path=new_model_name):
    """Load both checkpoints, swap in the new VAE, and save the result."""
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from sources you trust.
    vae_model = torch.load(vae_path, map_location="cpu")
    full_model = torch.load(model_path, map_location="cpu")

    # Replace VAE in model file with new VAE
    merge_vae(full_model, vae_model)
    del vae_model  # release the VAE copy before serialising the full model

    # Save model with new VAE
    torch.save(full_model, output_path)


if __name__ == "__main__":
    main()