@zackzhou-work
Forked from ProGamerGov/replace_vae.py
Created September 12, 2023 12:10
Replace the VAE in a Stable Diffusion checkpoint with a different VAE. Tested on SD v1.4 and v1.5 models.
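At its core, the swap is a key rename. Each tensor in the VAE checkpoint is written into the full model under the `first_stage_model.` prefix. Keys that begin with `loss` or `mode` are skipped. The sketch below shows only that mapping. It uses plain dicts, so it runs without any checkpoint files. The helper name `remap_vae_keys` and its parameters are hypothetical illustrations, not part of the script. `str.startswith` with a tuple is used as an equivalent of the script's `k[0:4] not in [...]` check, since both prefixes are four characters long.

```python
def remap_vae_keys(vae_state_dict, prefix="first_stage_model.", skip=("loss", "mode")):
    """Hypothetical helper: map VAE keys to their names inside a full SD checkpoint.

    Keys starting with any prefix in `skip` are dropped, mirroring the
    script's k[0:4] filter; all other keys get `prefix` prepended.
    """
    return {
        prefix + k: v
        for k, v in vae_state_dict.items()
        if not k.startswith(skip)
    }

# Toy stand-in for a VAE state dict (strings instead of tensors)
example = {
    "encoder.conv_in.weight": "w1",
    "decoder.conv_out.bias": "b1",
    "loss.logvar": "x",
    "model_ema.decay": "y",
}
print(remap_vae_keys(example))
# {'first_stage_model.encoder.conv_in.weight': 'w1', 'first_stage_model.decoder.conv_out.bias': 'b1'}
```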
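# Script by https://github.com/ProGamerGov
import copy

import torch

# Paths to the model and VAE checkpoint files that you want to merge
vae_file_path = "vae-ft-mse-840000-ema-pruned.ckpt"
model_file_path = "v1-5-pruned-emaonly.ckpt"

# File name for the new model checkpoint
new_model_name = "v1-5-pruned-emaonly_ema_vae.ckpt"

# Load both checkpoints onto the CPU (no GPU required).
# On PyTorch 2.6+, torch.load defaults to weights_only=True; if loading fails
# and you trust these files, pass weights_only=False explicitly.
vae_model = torch.load(vae_file_path, map_location="cpu")
full_model = torch.load(model_file_path, map_location="cpu")

# Keep only the autoencoder weights: drop keys beginning with "loss"
# (training loss modules) or "mode" (such as "model_ema.*" bookkeeping entries).
vae_dict = {k: v for k, v in vae_model["state_dict"].items() if k[0:4] not in ["loss", "mode"]}

# In a full SD v1.x checkpoint the VAE lives under the "first_stage_model." prefix,
# so overwrite each matching entry with a copy of the new VAE's tensor.
for k, v in vae_dict.items():
    key_name = "first_stage_model." + k
    full_model["state_dict"][key_name] = copy.deepcopy(v)

# Save the model with the new VAE
torch.save(full_model, new_model_name)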
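The copy loop assigns keys unconditionally. If the VAE does not match the model's architecture, the loop adds stray entries or replaces tensors with ones of the wrong shape, and it raises no error. A quick check before saving can catch this. The sketch below assumes both checkpoints are dicts with a `"state_dict"` entry, as in the script. `check_vae_swap` is a hypothetical helper. It is demonstrated on tiny synthetic tensors, not real SD weights.

```python
import torch

def check_vae_swap(full_state_dict, vae_state_dict,
                   prefix="first_stage_model.", skip=("loss", "mode")):
    """Hypothetical helper: report problems the copy loop would silently cause.

    Returns (missing, mismatched):
      missing    -- prefixed VAE keys absent from the full model, which the
                    loop would add as new entries instead of replacing
      mismatched -- prefixed keys present in both but with different shapes
    """
    missing, mismatched = [], []
    for k, v in vae_state_dict.items():
        if k.startswith(skip):  # same filter as the script's k[0:4] check
            continue
        key_name = prefix + k
        if key_name not in full_state_dict:
            missing.append(key_name)
        elif full_state_dict[key_name].shape != v.shape:
            mismatched.append(key_name)
    return missing, mismatched

# Tiny synthetic example (shapes are arbitrary, not real SD weights)
full_sd = {
    "first_stage_model.encoder.conv_in.weight": torch.zeros(4, 3, 3, 3),
    "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(2, 2),
}
vae_sd = {
    "encoder.conv_in.weight": torch.ones(4, 3, 3, 3),
    "decoder.conv_out.bias": torch.ones(3),
    "loss.logvar": torch.zeros(()),
}
missing, mismatched = check_vae_swap(full_sd, vae_sd)
print(missing)     # ['first_stage_model.decoder.conv_out.bias']
print(mismatched)  # []
```

With the real files, you could call `check_vae_swap(full_model["state_dict"], vae_model["state_dict"])` after loading. When the VAE matches the model's architecture, both lists should normally come back empty.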