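# Monkey patches that let a Transformers/Unsloth model load a checkpoint whose
# tensor shapes differ from the freshly initialized weights (e.g. token
# embeddings resized after adding new tokens). Three patches are applied:
#   1. unsloth.tokenizer_utils.fix_untrained_tokens   -> no-op
#   2. peft.utils.save_and_load._find_mismatched_keys -> keep mismatched keys
#   3. torch.nn.Module.load_state_dict                -> recreate tensors at the
#      checkpoint's shapes instead of raising a size-mismatch error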
import torch
from torch.nn import Module
from collections import OrderedDict
from typing import Mapping, Any, Dict, List, Tuple, Union, NamedTuple

from unsloth import tokenizer_utils

# Replace Unsloth's fix_untrained_tokens with a no-op, presumably so it does
# not rewrite embedding rows that the checkpoint loaded below already provides.
def do_nothing(*args, **kwargs):
    pass

tokenizer_utils.fix_untrained_tokens = do_nothing

# Imports for the surrounding training pipeline (not all are used in this excerpt).
from datasets import load_dataset
import datasets
from trl import SFTTrainer
import pandas as pd
import numpy as np
import os
from unsloth import FastLanguageModel
from transformers import TrainingArguments, Trainer
import warnings
from transformers import DataCollatorForLanguageModeling
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from transformers import Qwen2ForCausalLM, Qwen2Tokenizer
# PEFT uses this helper to filter out checkpoint tensors whose shapes do not
# match the model. Returning the state dict unchanged, with an empty mismatch
# list, lets size-mismatched tensors reach the patched load_state_dict below.
def _find_mismatched_keys(
    model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = True
) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
    return peft_model_state_dict, []

# Monkey patch the original function
import peft.utils.save_and_load
peft.utils.save_and_load._find_mismatched_keys = _find_mismatched_keys
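# NOTE: _find_mismatched_keys is a private PEFT helper; if your PEFT version
# does not define it under peft.utils.save_and_load, or has changed its
# signature, this assignment is silently ineffective (or PEFT's call fails).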
class _IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]
def patched_load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
    if not isinstance(state_dict, Mapping):
        raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, local_state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata["assign_to_params_buffers"] = assign
        module._load_from_state_dict(
            local_state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        # Recurse into child modules, passing only the keys under each child's prefix.
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + "."
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)

        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                " expected to return new values, if incompatible_keys need to be modified,"
                " it should be done inplace."
            )
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for name, param in self._parameters.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                if param.shape != input_param.shape:
                    print(f"Shape mismatch for {key}, creating new tensor. Old shape: {param.shape}, New shape: {input_param.shape}")
                    # Create a new parameter with the shape from state_dict
                    new_param = torch.nn.Parameter(torch.empty_like(input_param), requires_grad=param.requires_grad)
                    new_param.data.copy_(input_param)
                    setattr(self, name, new_param)
                else:
                    param.data.copy_(input_param)
            elif strict:
                missing_keys.append(key)

        for name, buf in self._buffers.items():
            key = prefix + name
            if key in state_dict:
                input_buf = state_dict[key]
                if buf.shape != input_buf.shape:
                    print(f"Shape mismatch for buffer {key}, creating new tensor. Old shape: {buf.shape}, New shape: {input_buf.shape}")
                    # Create a new buffer with the shape from state_dict
                    new_buf = torch.empty_like(input_buf)
                    new_buf.copy_(input_buf)
                    setattr(self, name, new_buf)
                else:
                    buf.copy_(input_buf)
            elif strict:
                missing_keys.append(key)
    # Monkey patch the _load_from_state_dict method
    Module._load_from_state_dict = _load_from_state_dict

    load(self, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)


# Apply the monkey patch
Module.load_state_dict = patched_load_state_dict
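# Minimal sanity check (illustrative, not part of the original workflow): with
# the patch applied, loading a differently shaped tensor no longer raises;
# the parameter is recreated at the checkpoint's shape instead.
_small = torch.nn.Embedding(10, 4)         # freshly initialized module, weight is (10, 4)
_big = torch.nn.Embedding(12, 4)           # stands in for a checkpoint with a resized vocab
_small.load_state_dict(_big.state_dict())  # prints "Shape mismatch for weight, ..."
assert _small.weight.shape == (12, 4)      # weight was recreated at the checkpoint shape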
# Load model (`saved_name` must be defined elsewhere in the script: the path or
# repo id of the previously saved checkpoint).
model, tokenizer = FastLanguageModel.from_pretrained(saved_name)
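# Tensors whose shapes changed (e.g. resized embedding / lm_head rows) should
# show up in "Shape mismatch for ..." prints and be loaded at their checkpoint
# shapes instead of being skipped.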