@Christopher-Hayes
Last active June 2, 2024 14:56

Revisions

  1. Christopher-Hayes revised this gist Oct 23, 2022. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -1,4 +1,4 @@
    # Converting DreamBooth `.bin` files to a `.ckpt` model file.
    ## Converting DreamBooth `.bin` files to a `.ckpt` model file.

    These instructions are based on DreamBooth usage with the https://github.com/ShivamShrirao/diffusers repo.

  2. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 7 additions and 6 deletions.
    13 changes: 7 additions & 6 deletions convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -28,17 +28,18 @@ cd examples/dreambooth

    ### 2b. Either run the Python script directly, or run the convenience CLI script

    To run the original convert script run this in the CLI (from inside the examples/dreambooth directory).
    **The convenience CLI script:**
    ```bash
    python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./model.ckpt
    ./toCkpt.sh ./name_of_model_folder
    ```
    Remember to change ./name_of_model_folder to the correct folder name. `model.ckpt` can be anything.
    Remember to change ./name_of_model_folder to the correct folder name.

    **If you want to run the original Python script, that can still be done:**

    A convenience CLI script is also available:
    ```bash
    ./toCkpt.sh ./name_of_model_folder
    python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./model.ckpt
    ```
    Remember to change ./name_of_model_folder to the correct folder name.
    Remember to change ./name_of_model_folder to the correct folder name. `model.ckpt` can be anything.

    ### 3. Try out your new checkpoint model.

  3. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 3 additions and 1 deletion.
    4 changes: 3 additions & 1 deletion convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -28,15 +28,17 @@ cd examples/dreambooth

    ### 2b. Either run the Python script directly, or run the convenience CLI script

    To run the original convert script run this in the CLI (from inside the examples/dreambooth directory):
    To run the original convert script run this in the CLI (from inside the examples/dreambooth directory).
    ```bash
    python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./model.ckpt
    ```
    Remember to change ./name_of_model_folder to the correct folder name. `model.ckpt` can be anything.

    A convenience CLI script is also available:
    ```bash
    ./toCkpt.sh ./name_of_model_folder
    ```
    Remember to change ./name_of_model_folder to the correct folder name.

    ### 3. Try out your new checkpoint model.

  4. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 6 additions and 6 deletions.
    12 changes: 6 additions & 6 deletions convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -4,21 +4,21 @@ These instructions are based on DreamBooth usage with the https://github.com/Shi

    ## 1. Add the script files

    Below is 2 files. "toCkpt.sh" and "convertToCkpt.py". Create those files inside the `examples/dreambooth` folder with the code provided.
    Below are 2 files. "convertToCkpt.py" and "toCkpt.sh". Create those files inside the `examples/dreambooth` folder with the code provided.

    ### 1a. Python convert script (required)
    Put the `convertToCkpt.py` file in the "examples/dreambooth" folder.
    Credit to @jachiam this file is originally from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

    Put the `convertToCkpt.py` file in the `examples/dreambooth` folder.
    Credit to @jachiam for this Python script https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

    ### 1b. Convenience CLI command (optional)

    This runs the Python script. It accepts the model folder as the only argument. It's a little easier to type and automatically uses the folder name as the .ckpt filename.

    Put the `toCkpt.sh` file in the "examples/dreambooth" folder as well.
    Put the `toCkpt.sh` file in the `examples/dreambooth` folder as well.

    ## 2. Running the script

    Before running these, make sure to **first** create the Python script and (optionally) the .sh script under "Code files"

    ### 2a. Make sure you're in the `examples/dreambooth` folder

    Run if you're still in the project root directory.
  5. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 4 additions and 0 deletions.
    4 changes: 4 additions & 0 deletions convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -1,3 +1,7 @@
    # Converting DreamBooth `.bin` files to a `.ckpt` model file.

    These instructions are based on DreamBooth usage with the https://github.com/ShivamShrirao/diffusers repo.

    ## 1. Add the script files

    Below is 2 files. "toCkpt.sh" and "convertToCkpt.py". Create those files inside the `examples/dreambooth` folder with the code provided.
  6. Christopher-Hayes revised this gist Oct 19, 2022. 3 changed files with 259 additions and 267 deletions.
    284 changes: 17 additions & 267 deletions convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -1,14 +1,28 @@
    ## To run
    ## 1. Add the script files

    Below is 2 files. "toCkpt.sh" and "convertToCkpt.py". Create those files inside the `examples/dreambooth` folder with the code provided.

    ### 1a. Python convert script (required)
    Put the `convertToCkpt.py` file in the "examples/dreambooth" folder.
    Credit to @jachiam this file is originally from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

    ### 1b. Convenience CLI command (optional)
    This runs the Python script. It accepts the model folder as the only argument. It's a little easier to type and automatically uses the folder name as the .ckpt filename.

    Put the `toCkpt.sh` file in the "examples/dreambooth" folder as well.

    ## 2. Running the script

    Before running these, make sure to **first** create the Python script and (optionally) the .sh script under "Code files"

    ### 1. Make sure you're in the `examples/dreambooth` folder
    ### 2a. Make sure you're in the `examples/dreambooth` folder

    Run if you're still in the project root directory.
    ```bash
    cd examples/dreambooth
    ```

    ### 2. Either run the Python script directly, or run the convenience CLI script
    ### 2b. Either run the Python script directly, or run the convenience CLI script

    To run the original convert script run this in the CLI (from inside the examples/dreambooth directory):
    ```bash
    @@ -23,267 +37,3 @@ A convenience CLI script is also available:
    ### 3. Try out your new checkpoint model.

    If you're using Automatic1111, copy-paste that `.ckpt` model file into the `models/Stable-diffusion` folder.

    ---

    ## Code files

    Below is 2 files. "toCkpt.sh" and "convertToCkpt.py". Create those files inside the `examples/dreambooth` folder with the code provided.

    ### 1. Python convert script (required)
    Create the file below as "convertToCkpt.py"
    Credit to @jachiam this file is originally from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

    ```python
    # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
    # *Only* converts the UNet, VAE, and Text Encoder.
    # Does not convert optimizer state or any other thing.
    # Written by jachiam

    import argparse
    import os.path as osp

    import torch


    # =================#
    # UNet Conversion #
    # =================#

    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    unet_conversion_map_layer = []
    # hardcoded number of downblocks and resnets/attentions...
    # would need smarter logic for other networks.
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


    def convert_unet_state_dict(unet_state_dict):
        # buyer beware: this is a *brittle* function,
        # and correct output requires that all of these pieces interact in
        # the exact order in which I have arranged them.
        mapping = {k: k for k in unet_state_dict.keys()}
        for sd_name, hf_name in unet_conversion_map:
            mapping[hf_name] = sd_name
        for k, v in mapping.items():
            if "resnets" in k:
                for sd_part, hf_part in unet_conversion_map_resnet:
                    v = v.replace(hf_part, sd_part)
                mapping[k] = v
        for k, v in mapping.items():
            for sd_part, hf_part in unet_conversion_map_layer:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
        new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
        return new_state_dict


    # ================#
    # VAE Conversion #
    # ================#

    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]


    def reshape_weight_for_sd(w):
        # convert HF linear weights to SD conv2d weights
        return w.reshape(*w.shape, 1, 1)


    def convert_vae_state_dict(vae_state_dict):
        mapping = {k: k for k in vae_state_dict.keys()}
        for k, v in mapping.items():
            for sd_part, hf_part in vae_conversion_map:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
        for k, v in mapping.items():
            if "attentions" in k:
                for sd_part, hf_part in vae_conversion_map_attn:
                    v = v.replace(hf_part, sd_part)
                mapping[k] = v
        new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
        weights_to_convert = ["q", "k", "v", "proj_out"]
        for k, v in new_state_dict.items():
            for weight_name in weights_to_convert:
                if f"mid.attn_1.{weight_name}.weight" in k:
                    print(f"Reshaping {k} for SD format")
                    new_state_dict[k] = reshape_weight_for_sd(v)
        return new_state_dict


    # =========================#
    # Text Encoder Conversion #
    # =========================#
    # pretty much a no-op


    def convert_text_enc_state_dict(text_enc_dict):
        return text_enc_dict


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()

        parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
        parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
        parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

        args = parser.parse_args()

        assert args.model_path is not None, "Must provide a model path!"

        assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

        unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
        vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
        text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

        # Convert the UNet model
        unet_state_dict = torch.load(unet_path, map_location='cpu')
        unet_state_dict = convert_unet_state_dict(unet_state_dict)
        unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

        # Convert the VAE model
        vae_state_dict = torch.load(vae_path, map_location='cpu')
        vae_state_dict = convert_vae_state_dict(vae_state_dict)
        vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

        # Convert the text encoder model
        text_enc_dict = torch.load(text_enc_path, map_location='cpu')
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

        # Put together new checkpoint
        state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
        if args.half:
            state_dict = {k: v.half() for k, v in state_dict.items()}
        state_dict = {"state_dict": state_dict}
        torch.save(state_dict, args.checkpoint_path)
    ```
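For context on `reshape_weight_for_sd` above: HF Diffusers stores the VAE mid-block attention projections (`q`, `k`, `v`, `proj_out`) as 2-D linear weights, while the original SD checkpoint format expects 4-D 1x1-conv weights. A minimal sketch of the shape change, with numpy standing in for torch and placeholder values:

```python
import numpy as np

# Toy stand-in for an HF attention projection weight (out_features x in_features).
w = np.zeros((8, 8))

# Same operation as reshape_weight_for_sd: append two singleton dims
# so the linear weight matches SD's 1x1-conv layout.
w_sd = w.reshape(*w.shape, 1, 1)

print(w_sd.shape)  # (8, 8, 1, 1)
```

The data is untouched; only the shape metadata changes, which is why a plain `reshape` is enough.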
    ---

    ### 2. Convenience CLI command (optional)
    This runs the Python script. It accepts the model folder as the single argument.
    The ckpt will show up with the same name as the model folder.
    Create the file below as "toCkpt.sh"

    ```bash
    #!/bin/bash

    model_path=$1
    ckpt_name=$(basename $model_path)
    ckpt_path="${ckpt_name}.ckpt"

    python convertToCkpt.py --model_path=$model_path --checkpoint_path=$ckpt_path
    ```
    235 changes: 235 additions & 0 deletions convertToCkpt.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,235 @@
    # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
    # *Only* converts the UNet, VAE, and Text Encoder.
    # Does not convert optimizer state or any other thing.
    # Written by jachiam

    import argparse
    import os.path as osp

    import torch


    # =================#
    # UNet Conversion #
    # =================#

    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    unet_conversion_map_layer = []
    # hardcoded number of downblocks and resnets/attentions...
    # would need smarter logic for other networks.
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


    def convert_unet_state_dict(unet_state_dict):
        # buyer beware: this is a *brittle* function,
        # and correct output requires that all of these pieces interact in
        # the exact order in which I have arranged them.
        mapping = {k: k for k in unet_state_dict.keys()}
        for sd_name, hf_name in unet_conversion_map:
            mapping[hf_name] = sd_name
        for k, v in mapping.items():
            if "resnets" in k:
                for sd_part, hf_part in unet_conversion_map_resnet:
                    v = v.replace(hf_part, sd_part)
                mapping[k] = v
        for k, v in mapping.items():
            for sd_part, hf_part in unet_conversion_map_layer:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
        new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
        return new_state_dict


    # ================#
    # VAE Conversion #
    # ================#

    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]


    def reshape_weight_for_sd(w):
        # convert HF linear weights to SD conv2d weights
        return w.reshape(*w.shape, 1, 1)


    def convert_vae_state_dict(vae_state_dict):
        mapping = {k: k for k in vae_state_dict.keys()}
        for k, v in mapping.items():
            for sd_part, hf_part in vae_conversion_map:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
        for k, v in mapping.items():
            if "attentions" in k:
                for sd_part, hf_part in vae_conversion_map_attn:
                    v = v.replace(hf_part, sd_part)
                mapping[k] = v
        new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
        weights_to_convert = ["q", "k", "v", "proj_out"]
        for k, v in new_state_dict.items():
            for weight_name in weights_to_convert:
                if f"mid.attn_1.{weight_name}.weight" in k:
                    print(f"Reshaping {k} for SD format")
                    new_state_dict[k] = reshape_weight_for_sd(v)
        return new_state_dict


    # =========================#
    # Text Encoder Conversion #
    # =========================#
    # pretty much a no-op


    def convert_text_enc_state_dict(text_enc_dict):
        return text_enc_dict


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()

        parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
        parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
        parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

        args = parser.parse_args()

        assert args.model_path is not None, "Must provide a model path!"

        assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

        unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
        vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
        text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

        # Convert the UNet model
        unet_state_dict = torch.load(unet_path, map_location='cpu')
        unet_state_dict = convert_unet_state_dict(unet_state_dict)
        unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

        # Convert the VAE model
        vae_state_dict = torch.load(vae_path, map_location='cpu')
        vae_state_dict = convert_vae_state_dict(vae_state_dict)
        vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

        # Convert the text encoder model
        text_enc_dict = torch.load(text_enc_path, map_location='cpu')
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

        # Put together new checkpoint
        state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
        if args.half:
            state_dict = {k: v.half() for k, v in state_dict.items()}
        state_dict = {"state_dict": state_dict}
        torch.save(state_dict, args.checkpoint_path)
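The heart of `convertToCkpt.py` is plain key renaming: build a mapping from each HF Diffusers key to its SD name, rename, prefix each sub-model, and wrap everything under `"state_dict"`. A minimal sketch of that flow on a toy state dict (keys come from `unet_conversion_map` above; the float values are placeholders for tensors):

```python
# Toy "state dict": parameter name -> tensor stand-in.
unet = {
    "time_embedding.linear_1.weight": 1.0,
    "conv_in.weight": 2.0,
}

# (stable-diffusion, HF Diffusers) pairs, as in unet_conversion_map.
conversion_map = [
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
]

# Rename HF keys to SD keys, as convert_unet_state_dict does.
mapping = {k: k for k in unet}
for sd_name, hf_name in conversion_map:
    mapping[hf_name] = sd_name
renamed = {v: unet[k] for k, v in mapping.items()}

# Prefix the sub-model and wrap, as the __main__ block does.
checkpoint = {"state_dict": {"model.diffusion_model." + k: v for k, v in renamed.items()}}

print(sorted(checkpoint["state_dict"]))
# ['model.diffusion_model.input_blocks.0.0.weight', 'model.diffusion_model.time_embed.0.weight']
```

The real script does the same thing at scale, with extra substring replacements for resnet and per-layer prefixes.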
    7 changes: 7 additions & 0 deletions toCkpt.sh
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,7 @@
    #!/bin/bash

    model_path=$1
    ckpt_name=$(basename $model_path)
    ckpt_path="${ckpt_name}.ckpt"

    python convertToCkpt.py --model_path=$model_path --checkpoint_path=$ckpt_path
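One caveat on `toCkpt.sh` above: `$model_path` is left unquoted in the `basename` call and the `python` invocation, so folder names containing spaces misbehave. A small sketch of the difference (the folder name is hypothetical):

```shell
# Hypothetical folder name with a space, to show why quoting matters.
model_path="./some model"

# Quoted: basename sees one argument.
ckpt_name=$(basename "$model_path")
ckpt_path="${ckpt_name}.ckpt"
echo "$ckpt_path"    # some model.ckpt

# Unquoted (as in toCkpt.sh): word splitting hands basename two
# arguments, and the second is treated as a suffix to strip.
broken=$(basename $model_path)
echo "$broken"       # some
```

Quoting the expansions (`basename "$model_path"`, `--model_path="$model_path"`) avoids this; for folder names without spaces the script works as-is.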
  7. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,6 @@
    ## To run

    Before running these, make sure to create the Python script and (optionally) the .sh script under "Code files"
    Before running these, make sure to **first** create the Python script and (optionally) the .sh script under "Code files"

    ### 1. Make sure you're in the `examples/dreambooth` folder

  8. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,6 @@
    ## To run

    Before running these, make sure to create the 2 files under "code files"
    Before running these, make sure to create the Python script and (optionally) the .sh script under "Code files"

    ### 1. Make sure you're in the `examples/dreambooth` folder

  9. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 3 additions and 1 deletion.
    4 changes: 3 additions & 1 deletion convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -20,7 +20,9 @@ A convenience CLI script is also available:
    ./toCkpt.sh ./name_of_model_folder
    ```

    ### 3. If you're using Automatic1111, copy-paste that `.ckpt` model file into the `models/Stable-Diffusion` folder.
    ### 3. Try out your new checkpoint model.

    If you're using Automatic1111, copy-paste that `.ckpt` model file into the `models/Stable-diffusion` folder.

    ---

  10. Christopher-Hayes revised this gist Oct 19, 2022. 1 changed file with 6 additions and 3 deletions.
    9 changes: 6 additions & 3 deletions convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -12,13 +12,16 @@ cd examples/dreambooth

    To run the original convert script run this in the CLI (from inside the examples/dreambooth directory):
    ```bash
    python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./name_of_model_folder/model.ckpt
    python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./model.ckpt
    ```

    A convenience CLI script is also available:
    ```bash
    ./toCkpt.sh ./name_of_model_folder
    ```

    ### 3. If you're using Automatic1111, copy-paste that `.ckpt` model file into the `models/Stable-Diffusion` folder.

    ---

    ## Code files
    @@ -270,15 +273,15 @@ if __name__ == "__main__":

    ### 2. Convenience CLI command (optional)
    This runs the Python script. It accepts the model folder as the single argument.
    The ckpt is put inside the model folder, and uses the same name as the folder.
    The ckpt will show up with the same name as the model folder.
    Create the file below as "toCkpt.sh"

    ```bash
    #!/bin/bash

    model_path=$1
    ckpt_name=$(basename $model_path)
    ckpt_path="${model_path}/${ckpt_name}.ckpt"
    ckpt_path="${ckpt_name}.ckpt"

    python convertToCkpt.py --model_path=$model_path --checkpoint_path=$ckpt_path
    ```
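This revision changes where the checkpoint lands: before, `toCkpt.sh` wrote it inside the model folder; after, it writes into the current directory under the folder's name. The two derivations side by side (folder name is a placeholder):

```shell
model_path="./dreambooth_output"
ckpt_name=$(basename "$model_path")

old_ckpt_path="${model_path}/${ckpt_name}.ckpt"   # before this revision
new_ckpt_path="${ckpt_name}.ckpt"                 # after this revision

echo "$old_ckpt_path"   # ./dreambooth_output/dreambooth_output.ckpt
echo "$new_ckpt_path"   # dreambooth_output.ckpt
```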
  11. Christopher-Hayes created this gist Oct 19, 2022.
    284 changes: 284 additions & 0 deletions convertToCheckpoint.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,284 @@
    ## To run

    Before running these, make sure to create the 2 files under "code files"

    ### 1. Make sure you're in the `examples/dreambooth` folder

    ```bash
    cd examples/dreambooth
    ```

    ### 2. Either run the Python script directly, or run the convenience CLI script

    To run the original convert script run this in the CLI (from inside the examples/dreambooth directory):
    ```bash
    python convertToCkpt.py --model_path ./name_of_model_folder --checkpoint_path ./name_of_model_folder/model.ckpt
    ```

    A convenience CLI script is also available:
    ```bash
    ./toCkpt.sh ./name_of_model_folder
    ```
    ---

    ## Code files

    Below is 2 files. "toCkpt.sh" and "convertToCkpt.py". Create those files inside the `examples/dreambooth` folder with the code provided.

    ### 1. Python convert script (required)
    Create the file below as "convertToCkpt.py"
    Credit to @jachiam this file is originally from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05

    ```python
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam

import argparse
import os.path as osp

import torch


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location='cpu')
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location='cpu')
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location='cpu')
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
    ```
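
The heart of the script is a simple renaming pass: build an identity key-to-key mapping, rewrite each key using (stable-diffusion, HF Diffusers) substring pairs, then re-key the state dict under the new names. A toy illustration of that strategy, with plain numbers standing in for tensors:

```python
# (stable-diffusion, HF Diffusers) substring pairs, as in the maps above
toy_conversion_map = [
    ("in_layers.0", "norm1"),
    ("out_layers.3", "conv2"),
]

def rename_keys(state_dict, pairs):
    mapping = {k: k for k in state_dict}      # identity mapping first
    for k, v in mapping.items():
        for sd_part, hf_part in pairs:
            v = v.replace(hf_part, sd_part)   # HF name -> SD name
        mapping[k] = v
    return {v: state_dict[k] for k, v in mapping.items()}

hf_weights = {"resnets.0.norm1.weight": 1, "resnets.0.conv2.bias": 2}
print(rename_keys(hf_weights, toy_conversion_map))
# → {'resnets.0.in_layers.0.weight': 1, 'resnets.0.out_layers.3.bias': 2}
```

Because the pairs are applied with substring replacement, their order matters; that is why the real script warns that it is brittle and depends on the exact arrangement of its maps.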
    ---

### 2. Convenience CLI command (optional)
This runs the Python script. It accepts the model folder as its single argument.
The .ckpt is placed inside the model folder and uses the same name as the folder.
Create the file below as "toCkpt.sh" and make it executable (`chmod +x toCkpt.sh`).

```bash
#!/bin/bash

# Model folder is the only argument; quote paths so folders with spaces work.
model_path=$1
ckpt_name=$(basename "$model_path")
ckpt_path="${model_path}/${ckpt_name}.ckpt"

python convertToCkpt.py --model_path="$model_path" --checkpoint_path="$ckpt_path"
```
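
For reference, the shell script's naming logic (checkpoint named after the model folder and written inside it) can be sketched in Python; the `ckpt_path_for` helper is hypothetical, just a mirror of the two lines above:

```python
import os.path as osp

def ckpt_path_for(model_path):
    # Hypothetical mirror of toCkpt.sh: model folder "foo" -> "foo/foo.ckpt".
    # normpath strips any trailing slash so basename gives the folder name.
    name = osp.basename(osp.normpath(model_path))
    return osp.join(model_path, name + ".ckpt")
```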