Last active
December 4, 2025 20:05
-
-
Save timothelaborie/071bb3ae8e13e1036e8cb8d0b0694330 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch.nn import Module | |
| from collections import OrderedDict | |
| from typing import Mapping, Any, List, NamedTuple | |
| from unsloth import tokenizer_utils | |
def do_nothing(*args, **kwargs):
    """Accept any positional/keyword arguments, ignore them, and return ``None``."""
    return None
# Replace unsloth's tokenizer_utils.fix_untrained_tokens with the no-op
# do_nothing, so any later call to it does nothing and returns None.
tokenizer_utils.fix_untrained_tokens = do_nothing
| from datasets import load_dataset | |
| import datasets | |
| from trl import SFTTrainer | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
| import pandas as pd | |
| import numpy as np | |
| from unsloth import FastLanguageModel | |
| from trl import SFTTrainer | |
| from transformers import TrainingArguments, Trainer | |
| from typing import Tuple | |
| import warnings | |
| from typing import Any, Dict, List, Union | |
| from transformers import DataCollatorForLanguageModeling | |
| from sklearn.model_selection import train_test_split | |
| import matplotlib.pyplot as plt | |
| from transformers import Qwen2ForCausalLM, Qwen2Tokenizer | |
def _find_mismatched_keys(
    model: torch.nn.Module,
    peft_model_state_dict: dict[str, torch.Tensor],
    ignore_mismatched_sizes: bool = True,
) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
    """Report that no keys are mismatched, whatever the inputs are.

    Args:
        model: Unused; accepted only so the signature matches the caller's.
        peft_model_state_dict: State dict, handed back to the caller as-is
            (same object, not a copy).
        ignore_mismatched_sizes: Unused.

    Returns:
        A pair ``(peft_model_state_dict, [])``: the untouched state dict and
        an always-empty list of ``(key, shape, shape)`` mismatch records.
    """
    no_mismatches: list[tuple[str, tuple[int, ...], tuple[int, ...]]] = []
    return peft_model_state_dict, no_mismatches
# Monkey patch the original function: import the PEFT module that owns
# _find_mismatched_keys and rebind that name to the stub defined in this file,
# which always returns the state dict unchanged plus an empty mismatch list.
import peft.utils.save_and_load
peft.utils.save_and_load._find_mismatched_keys = _find_mismatched_keys
class _IncompatibleKeys(NamedTuple):
    """Pair of key lists returned by ``patched_load_state_dict``."""

    # Keys collected as missing while loading.
    missing_keys: List[str]
    # Keys collected as unexpected while loading.
    unexpected_keys: List[str]
def patched_load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
    """Load ``state_dict`` into this module tree, adopting checkpoint shapes on mismatch.

    Written to be installed as ``Module.load_state_dict``. As a side effect,
    every call rebinds ``Module._load_from_state_dict`` (globally; it is never
    restored) to a simplified per-module loader that:

    * copies a tensor in place when the shapes agree;
    * replaces the parameter/buffer with a new tensor of the incoming shape
      (and prints a notice) when the shapes differ, instead of recording an
      error;
    * never records unexpected keys and does not read ``local_metadata``.

    Parameters or buffers registered as ``None`` are skipped entirely, and
    non-persistent buffers are not reported as missing.

    Args:
        self: Root module to load into.
        state_dict: Mapping from fully qualified names to tensors. A shallow
            copy is made, so the caller's mapping is not modified.
        strict: When true, raise ``RuntimeError`` if any key is recorded as
            missing or unexpected.
        assign: When true, stored as ``assign_to_params_buffers`` in each
            module's local metadata. The simplified loader installed here
            does not consult it.

    Returns:
        ``_IncompatibleKeys(missing_keys, unexpected_keys)``.

    Raises:
        TypeError: If ``state_dict`` is not a ``Mapping``.
        RuntimeError: If any error message was collected (in strict mode this
            includes missing/unexpected keys).
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, local_state_dict, prefix=""):
        """Recursively load ``module`` and its children, then run post-hooks."""
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata["assign_to_params_buffers"] = assign
        # ``strict`` is always passed as True here so that missing keys are
        # collected even when the outer call is non-strict; they are returned
        # to the caller rather than raised in that case.
        module._load_from_state_dict(
            local_state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + "."
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)

        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                "expected to return new values, if incompatible_keys need to be modified, "
                "it should be done inplace."
            )

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load this module's own parameters and buffers (not its children's)."""
        # Iterate over snapshots: setattr() below re-registers entries in
        # self._parameters / self._buffers while we are looping.
        for name, param in list(self._parameters.items()):
            if param is None:
                # Registered-but-absent parameter (e.g. a layer built without
                # a bias): there is nothing to load and nothing is missing.
                continue
            key = prefix + name
            if key not in state_dict:
                if strict:
                    missing_keys.append(key)
                continue
            input_param = state_dict[key]
            if param.shape != input_param.shape:
                print(f"Shape mismatch for {key}, creating new tensor. Old shape: {param.shape}, New shape: {input_param.shape}")
                # Create a new parameter with the shape from state_dict
                new_param = torch.nn.Parameter(torch.empty_like(input_param), requires_grad=param.requires_grad)
                new_param.data.copy_(input_param)
                setattr(self, name, new_param)
            else:
                param.data.copy_(input_param)

        for name, buf in list(self._buffers.items()):
            if buf is None:
                continue
            key = prefix + name
            if key not in state_dict:
                # Non-persistent buffers are not part of a state dict, so
                # their absence is expected and must not be reported.
                if strict and name not in self._non_persistent_buffers_set:
                    missing_keys.append(key)
                continue
            input_buf = state_dict[key]
            if buf.shape != input_buf.shape:
                print(f"Shape mismatch for buffer {key}, creating new tensor. Old shape: {buf.shape}, New shape: {input_buf.shape}")
                # Create a new buffer with the shape from state_dict
                new_buf = torch.empty_like(input_buf)
                new_buf.copy_(input_buf)
                setattr(self, name, new_buf)
            else:
                buf.copy_(input_buf)

    # Monkey patch the _load_from_state_dict method (global, takes effect from
    # the first call of this function onwards).
    Module._load_from_state_dict = _load_from_state_dict

    load(self, state_dict)
    del load  # drop the recursive closure once loading is finished

    if strict:
        if unexpected_keys:
            quoted_unexpected = ", ".join(f'"{k}"' for k in unexpected_keys)
            error_msgs.insert(0, f"Unexpected key(s) in state_dict: {quoted_unexpected}. ")
        if missing_keys:
            quoted_missing = ", ".join(f'"{k}"' for k in missing_keys)
            error_msgs.insert(0, f"Missing key(s) in state_dict: {quoted_missing}. ")

    if error_msgs:
        details = "\n\t".join(error_msgs)
        raise RuntimeError(f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{details}")
    return _IncompatibleKeys(missing_keys, unexpected_keys)
# Apply the monkey patch: from here on, load_state_dict on torch.nn.Module
# resolves to patched_load_state_dict.
Module.load_state_dict = patched_load_state_dict
# Load model
# NOTE(review): saved_name is not defined in the visible part of this file —
# confirm it is assigned before this line runs.
model, tokenizer = FastLanguageModel.from_pretrained(saved_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment