Last active
September 1, 2025 18:57
-
-
Save umtksa/912050d7c76c4aff182f4e922432bf94 to your computer and use it in GitHub Desktop.
Qwen3-0.6B finetune
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, DataCollatorForSeq2Seq,
    set_seed,
)
from peft import LoraConfig, get_peft_model, TaskType
import torch

# Qwen3-0.6B LoRA fine-tune on the "umtksa/tools" input -> output pairs.

# Seed the random/numpy/torch RNGs (this also makes the later LoRA weight
# initialisation reproducible).
set_seed(42)

# load dataset
dataset = load_dataset("umtksa/tools", split="train")

# Model & tokenizer
model_id = "Qwen/Qwen3-0.6B"
# Qwen3 ships no custom modeling code (it is native in recent transformers),
# so trust_remote_code is not needed and would only allow executing hub code.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Full fp32 weights. No device_map: placement is done once, explicitly, below,
# instead of being split between accelerate dispatch and a manual .to().
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

# Batching needs a pad token; fall back to EOS if the tokenizer has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prefer CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)
def preprocess(example, tok=None, max_length=128):
    """Turn one ``{"input", "output"}`` record into causal-LM training features.

    The user message is rendered with the tokenizer's chat template, together
    with a fixed system message and the assistant generation prompt. The
    expected answer and an EOS token are then appended.

    Args:
        example: A dataset row with string fields ``"input"`` (user message)
            and ``"output"`` (target assistant reply).
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer`` so
            ``dataset.map(preprocess)`` keeps working unchanged.
        max_length: Maximum sequence length. Longer examples are truncated
            from the right.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels``: unpadded
        lists of equal length. ``labels`` is ``-100`` over the prompt, so only
        the answer and its EOS token contribute to the loss. Padding is left
        to the data collator, which pads ``labels`` with ``-100``; padding
        therefore never leaks into the loss.
    """
    if tok is None:
        tok = tokenizer
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": example["input"]},
    ]
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Tokenize prompt and answer separately so the prompt/answer boundary is
    # exact. The rendered template already contains its special tokens.
    prompt_ids = tok(prompt, add_special_tokens=False)["input_ids"]
    answer_ids = tok(example["output"] + tok.eos_token, add_special_tokens=False)["input_ids"]

    input_ids = (prompt_ids + answer_ids)[:max_length]
    labels = ([-100] * len(prompt_ids) + answer_ids)[:max_length]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }
# Tokenize every record and drop the raw text columns.
tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
# Drop rows with no supervised tokens, e.g. a prompt that alone fills the
# context. Every label in such a row is -100, so the row teaches nothing and
# can produce a NaN loss.
tokenized_dataset = tokenized_dataset.filter(
    lambda row: any(label != -100 for label in row["labels"])
)

# Relative output directory. The original "/qwen3-tool" pointed at the
# filesystem root, which normal users cannot write to (and which is
# read-only on modern macOS).
OUTPUT_DIR = "qwen3-tool"

# LoRA prefs: rank-4 adapters on every attention and MLP projection.
lora_config = LoraConfig(
    r=4, lora_alpha=16, lora_dropout=0.1, bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora_config)

# Training args: effective batch size 1 x 8 = 8, full fp32 (no fp16/bf16).
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=5,
    learning_rate=2e-4,
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=1,
    fp16=False, bf16=False,
    report_to="none",
    # Pinned host memory is a CUDA feature; presumably disabled for MPS.
    dataloader_pin_memory=False,
)

# Pads input_ids with the tokenizer's pad token and labels with -100, so
# padding never contributes to the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)
trainer = Trainer(
    model=model, args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# start training
trainer.train()
# For a PEFT model this saves the LoRA adapter weights and config.
trainer.save_model(OUTPUT_DIR)
# The tokenizer is not passed to Trainer, so save it next to the adapter.
tokenizer.save_pretrained(OUTPUT_DIR)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment