Last active
June 2, 2024 14:56
Revisions
Christopher-Hayes revised this gist
Oct 23, 2022. 1 changed file with 1 addition and 1 deletion.

@@ -1,4 +1,4 @@

## Converting DreamBooth `.bin` files to a `.ckpt` model file.

These instructions are based on DreamBooth usage with the https://github.com/ShivamShrirao/diffusers repo.
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 7 additions and 6 deletions.

@@ -28,17 +28,18 @@ cd examples/dreambooth

### 2b. Either run the Python script directly, or run the convenience CLI script

**The convenience CLI script:**

```bash
./toCkpt.sh ./name_of_model_folder
```

Remember to change ./name_of_model_folder to the correct folder name.

**If you want to run the original Python script, that can still be done:**

```bash
python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./model.ckpt
```

Remember to change ./name_of_model_folder to the correct folder name. `model.ckpt` can be anything.

### 3. Try out your new checkpoint model.
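Before running either command, it can help to confirm that the model folder holds the three weight files that `convertToCkpt.py` loads. The sketch below is not part of the gist: the helper name `missing_weight_files` is hypothetical, and it only checks that the files exist, not that they are valid weights.

```python
from pathlib import Path

# Relative paths that convertToCkpt.py reads from the model folder.
EXPECTED_FILES = [
    Path("unet") / "diffusion_pytorch_model.bin",
    Path("vae") / "diffusion_pytorch_model.bin",
    Path("text_encoder") / "pytorch_model.bin",
]


def missing_weight_files(model_path):
    """Return the expected weight files that are absent from model_path.

    Hypothetical helper for illustration; it does not load or validate weights.
    """
    root = Path(model_path)
    return [rel.as_posix() for rel in EXPECTED_FILES if not (root / rel).is_file()]
```

Calling `missing_weight_files("./name_of_model_folder")` returns an empty list when all three files are present, so a non-empty result is a hint that the folder name is wrong or that training did not finish saving.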
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 3 additions and 1 deletion.

@@ -28,15 +28,17 @@ cd examples/dreambooth

### 2b. Either run the Python script directly, or run the convenience CLI script

To run the original convert script, run this in the CLI (from inside the `examples/dreambooth` directory):

```bash
python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./model.ckpt
```

Remember to change `./name_of_model_folder` to the correct folder name. The output name `model.ckpt` can be anything.

A convenience CLI script is also available:

```bash
./toCkpt.sh ./name_of_model_folder
```

Remember to change `./name_of_model_folder` to the correct folder name.

### 3. Try out your new checkpoint model.
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 6 additions and 6 deletions.

@@ -4,21 +4,21 @@ These instructions are based on DreamBooth usage with the https://github.com/Shi

## 1. Add the script files

Below are 2 files: `convertToCkpt.py` and `toCkpt.sh`. Create those files inside the `examples/dreambooth` folder with the code provided.

### 1a. Python convert script (required)

Put the `convertToCkpt.py` file in the `examples/dreambooth` folder.

Credit to @jachiam for this Python script: https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

### 1b. Convenience CLI command (optional)

This runs the Python script. It accepts the model folder as the only argument. It's a little easier to type and automatically uses the folder name as the `.ckpt` filename.

Put the `toCkpt.sh` file in the `examples/dreambooth` folder as well.

## 2. Running the script

### 2a. Make sure you're in the `examples/dreambooth` folder

Run this if you're still in the project root directory.
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 4 additions and 0 deletions.

@@ -1,3 +1,7 @@

# Converting DreamBooth `.bin` files to a `.ckpt` model file.

These instructions are based on DreamBooth usage with the https://github.com/ShivamShrirao/diffusers repo.

## 1. Add the script files

Below are 2 files: `toCkpt.sh` and `convertToCkpt.py`. Create those files inside the `examples/dreambooth` folder with the code provided.
Christopher-Hayes revised this gist
Oct 19, 2022. 3 changed files with 259 additions and 267 deletions.

@@ -1,14 +1,28 @@

## 1. Add the script files

Below are 2 files: `toCkpt.sh` and `convertToCkpt.py`. Create those files inside the `examples/dreambooth` folder with the code provided.

### 1a. Python convert script (required)

Put the `convertToCkpt.py` file in the `examples/dreambooth` folder.

Credit to @jachiam; this file is originally from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

### 1b. Convenience CLI command (optional)

This runs the Python script. It accepts the model folder as the only argument. It's a little easier to type and automatically uses the folder name as the `.ckpt` filename.

Put the `toCkpt.sh` file in the `examples/dreambooth` folder as well.

## 2. Running the script

Before running these, make sure to **first** create the Python script and (optionally) the `.sh` script under "Code files".

### 2a. Make sure you're in the `examples/dreambooth` folder

Run this if you're still in the project root directory.

```bash
cd examples/dreambooth
```

### 2b. Either run the Python script directly, or run the convenience CLI script

To run the original convert script, run this in the CLI (from inside the `examples/dreambooth` directory):

@@ -23,267 +37,3 @@ A convenience CLI script is also available:

### 3. Try out your new checkpoint model.

If you're using Automatic1111, copy-paste that `.ckpt` model file into the `models/Stable-diffusion` folder.

@@ -0,0 +1,235 @@

```python
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam

import argparse
import os.path as osp

import torch


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location='cpu')
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location='cpu')
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location='cpu')
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
```

@@ -0,0 +1,7 @@

```bash
#!/bin/bash
# Usage: ./toCkpt.sh ./name_of_model_folder
# Quoting keeps folder names that contain spaces intact.
model_path="$1"
ckpt_name="$(basename "$model_path")"
ckpt_path="${ckpt_name}.ckpt"

python convertToCkpt.py --model_path="$model_path" --checkpoint_path="$ckpt_path"
```
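To see what the prefix tables in `convertToCkpt.py` do, the sketch below applies the same string-replacement idea to a single UNet key. It copies only a small subset of the script's tables (down-block resnets) and works on plain strings, so it needs no PyTorch. The helper name `rename_unet_key` is hypothetical, and this is an illustration rather than a replacement for `convert_unet_state_dict`.

```python
# Subset of unet_conversion_map_resnet from convertToCkpt.py:
# (stable-diffusion, HF Diffusers)
RESNET_PARTS = [
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
]

# Down-block resnet prefixes, built the same way the script builds them.
LAYER_PREFIXES = []
for i in range(4):
    for j in range(2):
        hf_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_prefix = f"input_blocks.{3*i + j + 1}.0."
        LAYER_PREFIXES.append((sd_prefix, hf_prefix))


def rename_unet_key(hf_key):
    """Translate one HF Diffusers down-block resnet key to its Stable Diffusion name."""
    sd_key = hf_key
    # Step 1: rename the pieces inside a resnet (norm1, conv1, ...).
    if "resnets" in hf_key:
        for sd_part, hf_part in RESNET_PARTS:
            sd_key = sd_key.replace(hf_part, sd_part)
    # Step 2: rename the block prefix (down_blocks.i.resnets.j -> input_blocks.n.0).
    for sd_prefix, hf_prefix in LAYER_PREFIXES:
        sd_key = sd_key.replace(hf_prefix, sd_prefix)
    return sd_key


print(rename_unet_key("down_blocks.0.resnets.1.norm1.weight"))
# prints: input_blocks.2.0.in_layers.0.weight
```

The order matters here just as the script's "buyer beware" comment warns: the inner resnet names are swapped first, then the block prefix.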
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 1 addition and 1 deletion.

@@ -1,6 +1,6 @@

## To run

Before running these, make sure to **first** create the Python script and (optionally) the .sh script under "Code files"

### 1. Make sure you're in the `examples/dreambooth` folder
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 1 addition and 1 deletion.

@@ -1,6 +1,6 @@

## To run

Before running these, make sure to create the Python script and (optionally) the .sh script under "Code files"

### 1. Make sure you're in the `examples/dreambooth` folder
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 3 additions and 1 deletion.

@@ -20,7 +20,9 @@ A convenience CLI script is also available:

```bash
./toCkpt.sh ./name_of_model_folder
```

### 3. Try out your new checkpoint model.

If you're using Automatic1111, copy-paste that `.ckpt` model file into the `models/Stable-diffusion` folder.

---
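The copy step can also be scripted. A minimal sketch, assuming you know where your Automatic1111 web UI checkout lives: the `webui_root` argument and the helper name `install_checkpoint` are illustrative and do not come from the gist.

```python
import shutil
from pathlib import Path


def install_checkpoint(ckpt_file, webui_root):
    """Copy a .ckpt file into <webui_root>/models/Stable-diffusion and return the new path."""
    ckpt_file = Path(ckpt_file)
    if ckpt_file.suffix != ".ckpt":
        raise ValueError(f"expected a .ckpt file, got {ckpt_file.name}")
    target_dir = Path(webui_root) / "models" / "Stable-diffusion"
    # Fail loudly instead of creating the folder: a missing folder usually
    # means webui_root points at the wrong place.
    if not target_dir.is_dir():
        raise FileNotFoundError(f"{target_dir} does not exist")
    return Path(shutil.copy2(ckpt_file, target_dir))
```

For example, `install_checkpoint("./name_of_model_folder.ckpt", "/path/to/webui")` would place the file next to the other models, where `/path/to/webui` is a placeholder for your own install location.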
Christopher-Hayes revised this gist
Oct 19, 2022. 1 changed file with 6 additions and 3 deletions.

@@ -12,13 +12,16 @@ cd examples/dreambooth

To run the original convert script, run this in the CLI (from inside the `examples/dreambooth` directory):

```bash
python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./model.ckpt
```

A convenience CLI script is also available:

```bash
./toCkpt.sh ./name_of_model_folder
```

### 3. If you're using Automatic1111, copy-paste that `.ckpt` model file into the `models/Stable-diffusion` folder.

---

## Code files

@@ -270,15 +273,15 @@ if __name__ == "__main__":

### 2. Convenience CLI command (optional)

This runs the Python script. It accepts the model folder as the single argument. The ckpt will show up with the same name as the model folder.

Create the file below as "toCkpt.sh"

```bash
#!/bin/bash
# Usage: ./toCkpt.sh ./name_of_model_folder
# Quoting keeps folder names that contain spaces intact.
model_path="$1"
ckpt_name="$(basename "$model_path")"
ckpt_path="${ckpt_name}.ckpt"

python convertToCkpt.py --model_path="$model_path" --checkpoint_path="$ckpt_path"
```
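The naming rule used by `toCkpt.sh` (take the folder's base name and append `.ckpt`) can be mirrored in Python. A small sketch: `ckpt_path_for` is a hypothetical helper, and it assumes an ordinary folder path rather than `.` or `/`.

```python
import os.path as osp


def ckpt_path_for(model_path):
    """Return '<folder name>.ckpt', mirroring the basename step in toCkpt.sh."""
    # normpath drops a trailing slash so basename sees the folder name,
    # which matches what the shell's basename does.
    folder_name = osp.basename(osp.normpath(model_path))
    return f"{folder_name}.ckpt"


print(ckpt_path_for("./name_of_model_folder"))   # name_of_model_folder.ckpt
print(ckpt_path_for("./name_of_model_folder/"))  # name_of_model_folder.ckpt
```

Because only the base name is kept, the checkpoint lands in the directory you run the script from, not inside the model folder.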
Christopher-Hayes created this gist
Oct 19, 2022.

@@ -0,0 +1,284 @@

## To run

Before running these, make sure to create the 2 files under "Code files".

### 1. Make sure you're in the `examples/dreambooth` folder

```bash
cd examples/dreambooth
```

### 2. Either run the Python script directly, or run the convenience CLI script

To run the original convert script, run this in the CLI (from inside the `examples/dreambooth` directory):

```bash
python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./name_of_model_folder/model.ckpt
```

A convenience CLI script is also available:

```bash
./toCkpt.sh ./name_of_model_folder
```

---

## Code files

Below are 2 files: `toCkpt.sh` and `convertToCkpt.py`. Create those files inside the `examples/dreambooth` folder with the code provided.

### 1. Python convert script (required)

Create the file below as "convertToCkpt.py"

Credit to @jachiam; this file is originally from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

```python
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam

import argparse
import os.path as osp

import torch


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location='cpu')
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location='cpu')
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location='cpu')
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
```

---

### 2. Convenience CLI command (optional)

This runs the Python script. It accepts the model folder as the single argument. The ckpt is put inside the model folder, and uses the same name as the folder.

Create the file below as "toCkpt.sh"

```bash
#!/bin/bash
# Usage: ./toCkpt.sh ./name_of_model_folder
# Quoting keeps folder names that contain spaces intact.
model_path="$1"
ckpt_name="$(basename "$model_path")"
ckpt_path="${model_path}/${ckpt_name}.ckpt"

python convertToCkpt.py --model_path="$model_path" --checkpoint_path="$ckpt_path"
```
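The last lines of `convertToCkpt.py` give each sub-model's keys a prefix and nest everything under a single `state_dict` entry. The sketch below reproduces that assembly with plain Python values standing in for tensors; the sample keys and the helper name `assemble_checkpoint` are illustrative only.

```python
def assemble_checkpoint(unet, vae, text_enc):
    """Merge three converted state dicts into the layout the script saves."""
    merged = {}
    merged.update({"model.diffusion_model." + k: v for k, v in unet.items()})
    merged.update({"first_stage_model." + k: v for k, v in vae.items()})
    merged.update({"cond_stage_model.transformer." + k: v for k, v in text_enc.items()})
    return {"state_dict": merged}


# Placeholder values; the real script stores torch tensors here.
checkpoint = assemble_checkpoint(
    unet={"time_embed.0.weight": 0.0},
    vae={"encoder.conv_in.weight": 0.0},
    text_enc={"text_model.final_layer_norm.weight": 0.0},
)
for key in checkpoint["state_dict"]:
    print(key)
```

The prefixes keep the three sets of keys from colliding once they share one dictionary, which is why the script can merge them with a plain dict union.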