Skip to content

Instantly share code, notes, and snippets.

@okaris
Last active September 18, 2024 18:17
Show Gist options
  • Select an option

  • Save okaris/d90f1d134e7788fd6037f32fae5e2a9a to your computer and use it in GitHub Desktop.

Select an option

Save okaris/d90f1d134e7788fd6037f32fae5e2a9a to your computer and use it in GitHub Desktop.
class LiDBLayer(nn.Module):
    """Low-rank weight-pair generator (LiDB layer).

    Holds two fixed (non-trainable) auxiliary projections and two frozen
    factor matrices, and combines them in ``forward`` — scaled elementwise
    by values read from the input vector — to produce a low-rank pair
    ``(A, B)`` of LoRA-style weights.

    Shapes:
        aaux:   (r, a)              fixed, orthogonally initialized
        baux:   (b, r)              fixed, orthogonally initialized
        atrain: (a, out_features)   frozen parameter
        btrain: (in_features, b)    frozen parameter
        A = aaux @ (atrain * scale_a)  -> (r, out_features)
        B = (btrain * scale_b) @ baux  -> (in_features, r)
    """

    def __init__(self, in_features, out_features, hidden_dim, r=128, a=8, b=4):
        super().__init__()
        self.r, self.a, self.b = r, a, b
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_dim = hidden_dim
        # Fixed auxiliary matrices. Buffers are never touched by autograd,
        # so no requires_grad handling is needed (the previous flag-setting
        # on these buffers was a no-op). Note nn.init.orthogonal_ on a tall
        # (r x a) matrix yields orthonormal COLUMNS, not rows.
        self.register_buffer('aaux', torch.empty(r, a))
        self.register_buffer('baux', torch.empty(b, r))
        nn.init.orthogonal_(self.aaux)
        nn.init.orthogonal_(self.baux)
        # Frozen factors: kept as nn.Parameter so state_dict keys and
        # attribute names are unchanged, but created with
        # requires_grad=False directly instead of toggling the flag after
        # construction.
        self.atrain = nn.Parameter(torch.empty(a, out_features), requires_grad=False)
        self.btrain = nn.Parameter(torch.empty(in_features, b), requires_grad=False)
        nn.init.kaiming_uniform_(self.atrain, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.btrain, a=math.sqrt(5))

    def forward(self, x):
        """Return the low-rank pair ``(A, B)`` derived from ``x``.

        Only the first ``a + b`` entries of ``x`` are used, as per-row /
        per-column scaling factors. Assumes ``x`` is 1-D with at least
        ``a + b`` elements (``.view`` raises otherwise) — TODO confirm
        against callers.
        """
        # Slice scaling factors out of the input.
        atrain_scale = x[:self.a].view(self.a, 1)
        btrain_scale = x[self.a:self.a + self.b].view(1, self.b)
        # Modulate the frozen factors with the input-derived scales.
        scaled_atrain = self.atrain * atrain_scale
        scaled_btrain = self.btrain * btrain_scale
        # Project through the fixed auxiliary matrices.
        A = torch.mm(self.aaux, scaled_atrain)   # (r, out_features)
        B = torch.mm(scaled_btrain, self.baux)   # (in_features, r)
        return A, B
class FluxCapacitor(nn.Module):
    """Predicts per-block LoRA weight pairs from a face embedding.

    The embedding is projected to ``hidden_dim``, prepended to
    ``num_blocks`` zero placeholder rows, and iteratively refined by a
    TransformerDecoder; each refined row is then fed to a LiDBLayer that
    emits the (A, B) LoRA matrices for one transformer block.

    NOTE(review): the zero placeholders are 2-D, so this appears to expect
    an unbatched 1-D ``face_embedding`` — confirm against callers.
    """

    def __init__(self, input_dim=768, hidden_dim=768, num_layers=2, num_blocks=20, transformer_loops=10):
        super(FluxCapacitor, self).__init__()
        self.transformer_loops = transformer_loops
        self.embedding_linear = nn.Linear(input_dim, hidden_dim)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=2),
            num_layers=num_layers,
        )
        self.num_blocks = num_blocks
        # One LiDB generator per target transformer block.
        self.lidb_layers = nn.ModuleList(
            LiDBLayer(3072, 15360, hidden_dim) for _ in range(self.num_blocks)
        )

    def forward(self, face_embedding, _):
        """Return a dict mapping LoRA weight names to predicted tensors.

        The second positional argument is accepted for interface
        compatibility and ignored.
        """
        dev = face_embedding.device
        # Project and add a leading sequence dimension.
        emb = self.embedding_linear(face_embedding).unsqueeze(0)
        # Zero placeholder rows, one per target block.
        placeholders = [
            torch.zeros(1, emb.size(1), device=dev) for _ in range(self.num_blocks)
        ]
        x = torch.cat([emb] + placeholders, dim=0)
        # Iteratively refine: residual-add the decoder output onto the
        # placeholder rows while re-pinning the embedding as row 0.
        for _ in range(self.transformer_loops):
            refined = self.transformer(x, x)[1:]  # drop row 0 of the output
            x = torch.cat([emb, x[1:] + refined], dim=0)
        # Decode each refined row into a LoRA (A, B) pair.
        predicted_weights = {}
        for i, lidb in enumerate(self.lidb_layers):
            A, B = lidb(x[i + 1])
            predicted_weights[f'transformer.single_transformer_blocks.{i}.proj_out.lora_A.weight'] = A  # (128, 15360)
            predicted_weights[f'transformer.single_transformer_blocks.{i}.proj_out.lora_B.weight'] = B  # (3072, 128)
        return predicted_weights
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment