Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save timothelaborie/fe284d6ccb08b50bb0e9c34e28c622b6 to your computer and use it in GitHub Desktop.

Select an option

Save timothelaborie/fe284d6ccb08b50bb0e9c34e28c622b6 to your computer and use it in GitHub Desktop.
adam_mini attempt for unsloth/Qwen2-7B-bnb-4bit
import torch
from torch.optim.optimizer import Optimizer
import math
import torch.distributed as dist
# NOTE(review): _dispatch_sqrt is a private torch helper that is never used in
# this file, and it was removed from torch.optim.optimizer in newer PyTorch
# releases — this import can raise ImportError there. Consider deleting it.
from torch.optim.optimizer import _dispatch_sqrt
# NOTE(review): `device` is never referenced anywhere in this file.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Adam_mini(Optimizer):
    """Adam-mini: Adam with one shared second-moment estimate (v) per
    parameter *block* instead of one per element, removing most of Adam's
    optimizer-state memory.

    Partitioning used in :meth:`step`:
      * embedding / lm_head weights keep full element-wise Adam (they need it);
      * ``self_attn.*`` projection weights share one v per head-sized slice;
      * fused ``attn.attn``/``attn.qkv`` weights share one v per
        (head, q|k|v) slice;
      * every other tensor shares a single scalar v over the whole tensor
        (all-reduced across ranks under DeepSpeed ZeRO-3 when partitioned).

    NOTE(review): the hyper-parameter defaults below are the values that were
    hard-coded in the original gist (tuned for unsloth/Qwen2-7B-bnb-4bit);
    they can now be overridden through keyword arguments.
    """

    def __init__(self, *args, **kwargs):
        """Build one parameter group per named parameter of ``model``.

        Keyword arguments (all optional except ``model``):
            model: the ``torch.nn.Module`` being trained; may also be passed
                positionally. The original gist referenced an undefined name
                ``model`` and crashed with NameError — it must be supplied.
            lr, weight_decay, beta1, beta2, epsilon: usual Adam knobs.
                The original ignored the ``lr`` forwarded by the HF Trainer
                (see ``get_optimizer_cls_and_kwargs``); it is honoured now.
            zero_3: set True under DeepSpeed ZeRO-3 or multi-GPU model
                parallelism.
            n_embd, n_head, n_query_groups: transformer geometry (defaults
                match Qwen2-7B). ``n_query_groups`` falls back to ``n_head``
                when passed as None; ``n_head`` must be divisible by it (GQA).
        """
        # Locate the model: explicit kwarg first, then any positional Module.
        model = kwargs.get("model")
        if model is None:
            for candidate in args:
                if isinstance(candidate, torch.nn.Module):
                    model = candidate
                    break
        if model is None:
            raise ValueError(
                "Adam_mini needs the model it optimizes, "
                "e.g. Adam_mini(model=my_model, lr=2e-4)"
            )

        # Defaults are the values hard-coded in the original gist.
        weight_decay = float(kwargs.get("weight_decay", 0.01))
        lr = float(kwargs.get("lr", 2e-4))
        beta1 = float(kwargs.get("beta1", 0.9))
        beta2 = float(kwargs.get("beta2", 0.999))
        epsilon = float(kwargs.get("epsilon", 1e-8))
        zero_3 = bool(kwargs.get("zero_3", False))
        n_embd = int(kwargs.get("n_embd", 3584))
        n_head = int(kwargs.get("n_head", 28))
        n_query_groups = kwargs.get("n_query_groups", 7)

        self.n_embd = n_embd
        self.n_head = n_head
        if n_query_groups is not None:
            self.n_query_groups = int(n_query_groups)
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        self.model = model
        self.world_size = torch.cuda.device_count()
        self.zero_optimization_stage = 3 if zero_3 else 0
        if zero_3:
            print("Adam-mini is using zero_3")

        # One group per parameter so step() can branch on the parameter name.
        optim_groups = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            dic = {"name": name, "params": param}
            # Norm layers are conventionally excluded from weight decay.
            dic["weight_decay"] = 0 if ("norm" in name or "ln_f" in name) else weight_decay
            if "self_attn.k_proj.weight" in name or "self_attn.q_proj.weight" in name:
                dic["parameter_per_head"] = self.n_embd * self.n_embd // self.n_head
            if "attn.attn.weight" in name or "attn.qkv.weight" in name:
                dic["n_head"] = self.n_head
                dic["q_per_kv"] = self.n_head // self.n_query_groups
            optim_groups.append(dic)

        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, epsilon=epsilon)
        super().__init__(optim_groups, defaults)

    def step(self, *args, **kwargs):
        """Perform one Adam-mini update over all parameter groups.

        Parameters with ``grad is None`` are skipped, except that under
        ZeRO-3 they still contribute a zero to the grad-norm all_reduce so
        every rank joins the collective.
        """
        with torch.no_grad():
            for group in self.param_groups:
                beta1 = group["beta1"]
                beta2 = group["beta2"]
                lr = group["lr"]
                name = group["name"]
                epsilon = group["epsilon"]
                for p in group["params"]:
                    state = self.state[p]
                    if "embed_tokens" in name or "wte" in name or "lm_head" in name:
                        # Embedding / output head: plain element-wise Adam.
                        if p.grad is None:
                            continue
                        if len(state) == 0:
                            state["m"] = torch.zeros_like(p.data).to(torch.float32)
                            state["iteration"] = 0
                            state["v"] = torch.zeros_like(p.data).to(torch.float32)
                        grad = p.grad.data.to(torch.float32)
                        state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
                        state["iteration"] += 1
                        # Decoupled (AdamW-style) weight decay.
                        if group["weight_decay"] != 0:
                            p.data.mul_(1 - lr * group["weight_decay"])
                        state["m"].lerp_(grad, 1 - beta1)
                        bias_correction_1 = 1 - beta1 ** state["iteration"]
                        bias_correction_2 = 1 - beta2 ** state["iteration"]
                        bias_correction_2_sqrt = math.sqrt(bias_correction_2)
                        h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(epsilon)
                        stepsize = lr / bias_correction_1
                        p.addcdiv_(state["m"], h, value=-stepsize)
                    elif "self_attn" in name:
                        # Attention projections: one shared v per head-dim slice.
                        if p.grad is None:
                            continue
                        # Head dimension. 3584 // 28 == 128, the value that was
                        # hard-coded in the original gist. NOTE(review): every
                        # self_attn.* tensor (q/k/v/o, weights and biases) must
                        # have numel divisible by this — true for Qwen2-7B.
                        dim = self.n_embd // self.n_head
                        if len(state) == 0:
                            state["m"] = torch.zeros_like(p.data).to(torch.float32).view(-1, dim)
                            state["head"] = state["m"].shape[0]
                            state["iteration"] = 0
                            # was .cuda(): allocate on the param's device so the
                            # optimizer also works on CPU-only hosts.
                            state["vmean"] = torch.zeros(state["head"], device=p.device)
                        grad = p.grad.data.to(torch.float32)
                        head = state["head"]
                        grad = grad.view(head, dim)
                        # Mean squared gradient per head slice -> shared v.
                        tmp_lr = torch.mean(grad * grad, dim=1)
                        state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)
                        v = state["vmean"]
                        state["iteration"] += 1
                        if group["weight_decay"] != 0:
                            p.data.mul_(1 - lr * group["weight_decay"])
                        state["m"].lerp_(grad, 1 - beta1)
                        bias_correction_1 = 1 - beta1 ** state["iteration"]
                        bias_correction_2 = 1 - beta2 ** state["iteration"]
                        bias_correction_2_sqrt = math.sqrt(bias_correction_2)
                        h = (v.sqrt() / bias_correction_2_sqrt).add_(epsilon)
                        stepsize = ((1 / bias_correction_1) / h).view(head, 1)
                        update = state["m"] * stepsize.to(state["m"].device)
                        # Restore the parameter's own shape before applying.
                        if p.dim() > 1:
                            d0, d1 = p.size()
                            update = update.view(d0, d1)
                        else:
                            update = update.view(-1)
                        update.mul_(lr)
                        p.add_(-update)
                    elif "attn.attn.weight" in name or "attn.qkv.weight" in name:
                        # Fused qkv weight: one shared v per (head, q|k|v) slice.
                        if p.grad is None:
                            continue
                        if len(state) == 0:
                            state["m"] = torch.zeros_like(p.data).to(torch.float32)
                            state["m"] = state["m"].view(group["n_head"], group["q_per_kv"] + 2, -1)
                            state["iteration"] = 0
                            # was .cuda(): allocate on the param's device.
                            state["vmean"] = torch.zeros(
                                group["n_head"], group["q_per_kv"] + 2, device=p.device
                            )
                        grad = p.grad.data.to(torch.float32)
                        grad = grad.view(group["n_head"], group["q_per_kv"] + 2, -1)
                        tmp_lr = torch.mean(grad * grad, dim=2)
                        state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)
                        v = state["vmean"]
                        state["iteration"] += 1
                        if group["weight_decay"] != 0:
                            p.data.mul_(1 - lr * group["weight_decay"])
                        state["m"].lerp_(grad, 1 - beta1)
                        bias_correction_1 = 1 - beta1 ** state["iteration"]
                        bias_correction_2 = 1 - beta2 ** state["iteration"]
                        bias_correction_2_sqrt = math.sqrt(bias_correction_2)
                        h = (v.sqrt() / bias_correction_2_sqrt).add_(epsilon)
                        stepsize = ((1 / bias_correction_1) / h).view(
                            group["n_head"], group["q_per_kv"] + 2, 1
                        )
                        update = state["m"] * stepsize.to(state["m"].device)
                        if p.dim() > 1:
                            d0, d1 = p.size()
                            update = update.view(d0, d1)
                        else:
                            update = update.view(-1)
                        update.mul_(lr)
                        p.add_(-update)
                    else:
                        # Everything else: a single scalar v per tensor.
                        if len(state) == 0:
                            dimension = torch.tensor(
                                p.data.numel(), device=p.device, dtype=torch.float32
                            )
                            reduced = False
                            if (self.world_size > 1) and (self.zero_optimization_stage == 3):
                                # Under ZeRO-3 the tensor is partitioned across
                                # ranks; sum the per-rank element counts so the
                                # mean below is over the *global* tensor.
                                tensor_list = [
                                    torch.zeros_like(dimension) for _ in range(self.world_size)
                                ]
                                dist.all_gather(tensor_list, dimension)
                                s = 0
                                dimension = 0
                                for d in tensor_list:
                                    if d > 0:
                                        s = s + 1
                                    dimension = dimension + d
                                if s >= 2:
                                    reduced = True
                            state["m"] = torch.zeros_like(p.data).to(torch.float32)
                            state["iteration"] = 0
                            state["reduced"] = reduced
                            # was .cuda(): allocate on the param's device.
                            state["vmean"] = torch.tensor(0.0, device=p.device)
                            state["dimension"] = dimension.item()
                        # Grad-less ranks still contribute a zero so every rank
                        # joins the all_reduce below (ZeRO-3 partitioning).
                        if p.grad is None:
                            tmp_lr = torch.zeros((), device=p.device)
                        else:
                            grad = p.grad.data.to(torch.float32)
                            tmp_lr = torch.sum(grad * grad)
                        if state["reduced"]:
                            dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM)
                        tmp_lr = tmp_lr / state["dimension"]
                        # BUG FIX: the original did tmp_lr.to(grad.device) here,
                        # which raised NameError whenever p.grad was None.
                        tmp_lr = tmp_lr.to(p.device)
                        if p.grad is None:
                            continue
                        if group["weight_decay"] != 0:
                            p.data.mul_(1 - lr * group["weight_decay"])
                        state["iteration"] += 1
                        state["m"].lerp_(grad, 1 - beta1)
                        bias_correction_1 = 1 - beta1 ** state["iteration"]
                        bias_correction_2 = 1 - beta2 ** state["iteration"]
                        bias_correction_2_sqrt = math.sqrt(bias_correction_2)
                        state["vmean"] = (1 - beta2) * tmp_lr + beta2 * state["vmean"]
                        h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(epsilon)
                        stepsize = (1 / bias_correction_1) / h
                        update = state["m"] * stepsize.to(state["m"].device)
                        update.mul_(lr)
                        p.add_(-update)

    # optim monkey patch
    @staticmethod
    def get_optimizer_cls_and_kwargs(args):
        """Return ``(optimizer_cls, optimizer_kwargs)`` for transformers' Trainer.

        ``args`` is a ``transformers.TrainingArguments``. The original
        signature annotated it with ``TrainingArguments``/``Tuple``/``Any``,
        none of which are imported here, so the class definition itself raised
        NameError — the annotations were dropped to fix that.
        """
        # Parse args.optim_args ("k1=v1,k2=v2" -> {"k1": "v1", ...}).
        # NOTE(review): the parsed values are currently not forwarded to the
        # optimizer — they are collected and discarded, as in the original.
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value
        optimizer_kwargs = {"lr": args.learning_rate}
        return Adam_mini, optimizer_kwargs


# Monkey-patch HF's Trainer so it constructs Adam_mini. ``Trainer`` is expected
# to already be in scope in the notebook this gist is pasted into — it is not
# imported in this file, so guard the patch (the original line raised NameError
# when Trainer was absent).
try:
    Trainer.get_optimizer_cls_and_kwargs = Adam_mini.get_optimizer_cls_and_kwargs
except NameError:
    pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment