from tinygrad import Tensor, nn, Device, TinyJit
from tinygrad.nn.datasets import mnist
from tinygrad.nn.state import safe_save, get_state_dict
import math

print(f"Using device: {Device.DEFAULT}")

# normalization for Pattention
def nonlinear_normalization(inputs, normalization_type, dim=-1):
    if normalization_type == 'softmax':
        outputs = inputs.softmax(axis=dim)  # * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'scaled_softmax':  # note: this one works better without growth
        scale = 1.0 / math.sqrt(inputs.shape[dim])
        outputs = (inputs * scale).softmax(axis=dim)
    elif normalization_type == 'scaled_softmax2':
        scale = 1.0 / math.sqrt(inputs.shape[dim])
        inputs = inputs * scale
        max_val = inputs.max(axis=dim, keepdim=True)
        exp_inputs = (inputs - max_val).exp()
        outputs = exp_inputs / exp_inputs.sum(axis=dim, keepdim=True)
    elif normalization_type == 'l1_norm':
        norm = inputs.abs().sum(axis=dim, keepdim=True)
        outputs = inputs / norm * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'l2_norm':
        norm = (inputs ** 2).sum(axis=dim, keepdim=True).sqrt()
        outputs = inputs / norm * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'gelu_l2_norm':
        nonlinear_outputs = inputs.gelu()
        norm = (nonlinear_outputs ** 2).sum(axis=dim, keepdim=True).sqrt()
        outputs = nonlinear_outputs / norm * math.sqrt(inputs.shape[dim])
    elif normalization_type == 'l2_norm_gelu':
        norm = (inputs ** 2).sum(axis=dim, keepdim=True).sqrt()
        norm_outputs = inputs / norm * math.sqrt(inputs.shape[dim])
        outputs = norm_outputs.gelu()
    else:
        raise NotImplementedError(f"unknown normalization_type: {normalization_type}")
    return outputs


class Pattention:
    def __init__(self, input_channels, output_channels, token_num, normalization_type):
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.normalization_type = normalization_type
        # Initialize with small random values and enable gradients
        self.key_param_tokens = Tensor.randn(token_num, input_channels) * 0.02
        self.key_param_tokens.requires_grad = True
        self.value_param_tokens = Tensor.randn(token_num, output_channels) * 0.00001
        self.value_param_tokens.requires_grad = True

    def grow_parameters(self, num_to_add):
        # Create new parameter tokens with the same init scales as __init__
        new_keys = Tensor.randn(num_to_add, self.input_channels) * 0.02
        new_values = Tensor.randn(num_to_add, self.output_channels) * 0.00001
        # Concatenate onto the existing (learned) tokens; detach + realize so the
        # result is a fresh leaf tensor the optimizer can update in place, then
        # re-enable gradients on it
        self.key_param_tokens = self.key_param_tokens.detach().cat(new_keys, dim=0).contiguous().realize()
        self.key_param_tokens.requires_grad = True
        self.value_param_tokens = self.value_param_tokens.detach().cat(new_values, dim=0).contiguous().realize()
        self.value_param_tokens.requires_grad = True

    def __call__(self, inputs):
        attn_weights = inputs @ self.key_param_tokens.transpose()
        attn_weights = nonlinear_normalization(attn_weights, self.normalization_type)
        output = attn_weights @ self.value_param_tokens
        return output


# Define the model
class Model:
    def __init__(self):
        self.l1 = nn.Conv2d(1, 32, kernel_size=(3, 3))
        self.l2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        # Replace Linear with Pattention
        self.pattention = Pattention(input_channels=1600, output_channels=10,
                                     token_num=4, normalization_type='l2_norm')

    def __call__(self, x: Tensor) -> Tensor:
        x = self.l1(x).relu().max_pool2d((2, 2))
        x = self.l2(x).relu().max_pool2d((2, 2))
        # No dropout; use Pattention instead of Linear
        return self.pattention(x.flatten(1))
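# How Pattention computes its output, and why growth is cheap: the layer is
# output = normalize(x @ K^T) @ V, so each of the token_num rows of K and V is
# an independent "parameter token", and appending rows only widens the inner
# dimension without disturbing what the existing tokens learned. A quick shape
# check (illustrative sketch, not part of the training script; the batch size 8
# is arbitrary, 1600 matches the flattened conv output used by Model above):
#
#   layer = Pattention(input_channels=1600, output_channels=10, token_num=4,
#                      normalization_type='l2_norm')
#   out = layer(Tensor.randn(8, 1600))   # (8,1600) @ (1600,4) -> (8,4) weights -> (8,10)
#   assert out.shape == (8, 10)
#   layer.grow_parameters(2)             # tokens: 4 -> 6; learned rows are kept
#   assert layer(Tensor.randn(8, 1600)).shape == (8, 10)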
# Move all training code inside a main block
if __name__ == "__main__":
    # Load dataset
    X_train, Y_train, X_test, Y_test = mnist()

    # Initialize model and optimizer
    model = Model()
    optim = nn.optim.Adam(nn.state.get_parameters(model))
    batch_size = 256

    # Define training step
    def step():
        Tensor.training = True
        samples = Tensor.randint(batch_size, high=X_train.shape[0])
        X, Y = X_train[samples], Y_train[samples]
        optim.zero_grad()
        loss = model(X).sparse_categorical_crossentropy(Y).backward()
        optim.step()
        return loss

    # JIT compile the step function
    jit_step = TinyJit(step)

    # Configuration for growing parameters
    GROWTH_START = 0      # Start growing after N steps
    GROWTH_INTERVAL = 50  # Grow parameters every N steps
    GROWTH_RATE = 2       # Tokens added per growth event
    MAX_TOKENS = 64

    # Training loop
    for step_num in range(7000):
        loss = jit_step()

        # Grow parameters periodically
        if (step_num > GROWTH_START
                and (step_num - GROWTH_START) % GROWTH_INTERVAL == 0
                and model.pattention.key_param_tokens.shape[0] < MAX_TOKENS):
            model.pattention.grow_parameters(GROWTH_RATE)
            # Reinitialize the optimizer with the new parameters
            optim = nn.optim.Adam(nn.state.get_parameters(model))
            # Need to recompile after changing parameter shapes
            jit_step = TinyJit(step)
            print(f"Growing parameters at step {step_num}. New token count: {model.pattention.key_param_tokens.shape[0]}")

        if step_num % 100 == 0:
            Tensor.training = False
            acc = (model(X_test).argmax(axis=1) == Y_test).mean().item()
            print(f"step {step_num:4d}, loss {loss.item():.3f}, acc {acc*100.:.2f}%")

    # Save the trained model
    state_dict = get_state_dict(model)
    safe_save(state_dict, "mnist_model_pattention2.safetensors")
    print("Model saved successfully")
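    # To reuse the saved weights later, note that a fresh Model starts at
    # token_num=4 while the trained checkpoint may hold up to MAX_TOKENS
    # parameter tokens, so the model must be grown to the saved token count
    # before loading. A minimal loading sketch (commented; assumes the
    # safetensors file written above):
    #
    #   from tinygrad.nn.state import safe_load, load_state_dict
    #   state = safe_load("mnist_model_pattention2.safetensors")
    #   model = Model()
    #   saved = state["pattention.key_param_tokens"].shape[0]
    #   model.pattention.grow_parameters(saved - model.pattention.key_param_tokens.shape[0])
    #   load_state_dict(model, state)
    #   Tensor.training = False
    #   preds = model(X_test).argmax(axis=1)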