@ProGamerGov
Last active December 26, 2023 07:15

Revisions

  1. ProGamerGov revised this gist Oct 23, 2022. No changes.
  2. ProGamerGov revised this gist Oct 23, 2022. 1 changed file with 1 addition and 1 deletion.
     replace_vae.py:
     @@ -8,7 +8,7 @@
      model_file_path = "v1-5-pruned-emaonly.ckpt"

      # Name to use for new model file
     -new_model_name = "v1-5-pruned-emaonly_mse_vae.ckpt"
     +new_model_name = "v1-5-pruned-emaonly_ema_vae.ckpt"

      # Load files
      vae_model = torch.load(vae_file_path, map_location="cpu")
  3. ProGamerGov created this gist Oct 23, 2022.
    replace_vae.py (new file, 25 additions):
    @@ -0,0 +1,25 @@
    # Script by https://github.com/ProGamerGov
    import copy
    import torch


    # Paths to the model and VAE files that you want to merge
    vae_file_path = "vae-ft-mse-840000-ema-pruned.ckpt"
    model_file_path = "v1-5-pruned-emaonly.ckpt"

    # Name to use for the new model file
    new_model_name = "v1-5-pruned-emaonly_mse_vae.ckpt"

    # Load files
    vae_model = torch.load(vae_file_path, map_location="cpu")
    full_model = torch.load(model_file_path, map_location="cpu")

    # Replace the VAE in the model file with the new VAE, skipping the
    # VAE checkpoint's training-only keys ("loss.*" and "model_ema.*")
    vae_dict = {k: v for k, v in vae_model["state_dict"].items() if k[0:4] not in ["loss", "mode"]}
    for k in vae_dict:
        key_name = "first_stage_model." + k
        full_model["state_dict"][key_name] = copy.deepcopy(vae_dict[k])

    # Save the model with the new VAE
    torch.save(full_model, new_model_name)
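The core of the script is the key renaming: every VAE weight is copied into the full checkpoint under a `first_stage_model.` prefix, which is where Stable Diffusion checkpoints store the VAE, while `loss.*` and `model_ema.*` entries from the standalone VAE file are skipped. A minimal sketch of that filter-and-rename step, using tiny dummy state dicts with placeholder strings instead of tensors (the key names here are illustrative, not taken from a real checkpoint):

```python
import copy

# Hypothetical miniature state dicts standing in for the real checkpoint files
vae_model = {"state_dict": {
    "encoder.conv_in.weight": "vae_enc_w",   # placeholder for a tensor
    "decoder.conv_in.weight": "vae_dec_w",
    "loss.logvar": "training_only",          # loss key, filtered out by the script
}}
full_model = {"state_dict": {
    "model.diffusion_model.w": "unet_w",
    "first_stage_model.encoder.conv_in.weight": "old_vae_w",  # gets overwritten
}}

# Same filter and rename as the script: drop keys starting with "loss"/"mode",
# then write the rest into the full model under the "first_stage_model." prefix
vae_dict = {k: v for k, v in vae_model["state_dict"].items()
            if k[0:4] not in ["loss", "mode"]}
for k in vae_dict:
    full_model["state_dict"]["first_stage_model." + k] = copy.deepcopy(vae_dict[k])

print(full_model["state_dict"]["first_stage_model.encoder.conv_in.weight"])  # vae_enc_w
print("first_stage_model.loss.logvar" in full_model["state_dict"])           # False
```

Note the filter checks only the first four characters of each key, so `"mode"` matches `model_ema.*` keys; the UNet's own weights in the full checkpoint are untouched because only prefixed VAE keys are written.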