Last active
September 18, 2024 18:17
-
-
Save okaris/d90f1d134e7788fd6037f32fae5e2a9a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LiDBLayer(nn.Module):
    """Low-rank (LoRA-style) weight-pair generator.

    Holds fixed orthogonal auxiliary matrices plus small frozen factor
    matrices, and at forward time scales the factors by values read from
    the input vector to produce a pair of rank-``r`` matrices ``(A, B)``.

    Args:
        in_features: row count of the produced B matrix.
        out_features: column count of the produced A matrix.
        hidden_dim: stored for reference; not used inside this layer.
        r: rank of the produced A/B factorization.
        a: number of scaling values taken from the front of ``x`` (and
           column count of the ``aaux`` buffer).
        b: number of scaling values taken after those (and row count of
           the ``baux`` buffer).
    """

    def __init__(self, in_features, out_features, hidden_dim, r=128, a=8, b=4):
        super().__init__()
        self.r, self.a, self.b = r, a, b
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): hidden_dim is stored but never read by this layer.
        self.hidden_dim = hidden_dim
        # Fixed auxiliary matrices, orthogonally initialized. Buffers are
        # never trainable, so no requires_grad toggling is needed (the
        # original explicitly set requires_grad = False on them — a no-op).
        self.register_buffer('aaux', torch.empty(r, a))
        self.register_buffer('baux', torch.empty(b, r))
        nn.init.orthogonal_(self.aaux)
        nn.init.orthogonal_(self.baux)
        # Small factor matrices. Although declared as nn.Parameter, they are
        # explicitly frozen below; kept frozen here to preserve the original
        # behavior. NOTE(review): the original comment called these
        # "trainable" — confirm the freeze is intentional.
        self.atrain = nn.Parameter(torch.empty(a, out_features))
        self.btrain = nn.Parameter(torch.empty(in_features, b))
        nn.init.kaiming_uniform_(self.atrain, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.btrain, a=math.sqrt(5))
        self.atrain.requires_grad = False
        self.btrain.requires_grad = False

    def forward(self, x):
        """Produce the ``(A, B)`` low-rank pair from scaling values in ``x``.

        Args:
            x: 1-D tensor with at least ``a + b`` elements; the first ``a``
               entries scale the rows of ``atrain`` and the next ``b``
               entries scale the columns of ``btrain``.

        Returns:
            A: tensor of shape ``(r, out_features)``.
            B: tensor of shape ``(in_features, r)``.
        """
        # Read per-row / per-column scaling factors from the input vector.
        atrain_scale = x[:self.a].view(self.a, 1)
        btrain_scale = x[self.a:self.a + self.b].view(1, self.b)
        # Apply the scaling factors to the frozen factor matrices.
        scaled_atrain = self.atrain * atrain_scale
        scaled_btrain = self.btrain * btrain_scale
        # Project through the fixed orthogonal auxiliaries up to rank r.
        A = torch.mm(self.aaux, scaled_atrain)   # (r, out_features)
        B = torch.mm(scaled_btrain, self.baux)   # (in_features, r)
        return A, B
class FluxCapacitor(nn.Module):
    """Predict per-block LoRA A/B weight tensors from a face embedding.

    The embedding is projected to ``hidden_dim``, concatenated with one
    zero token per target block, iteratively refined by a transformer
    decoder, and each refined token is mapped to an ``(A, B)`` pair by a
    dedicated :class:`LiDBLayer`.

    NOTE(review): the sequence construction assumes an unbatched 1-D
    ``face_embedding`` — confirm against callers.
    """

    def __init__(self, input_dim=768, hidden_dim=768, num_layers=2, num_blocks=20, transformer_loops=10):
        super(FluxCapacitor, self).__init__()
        self.transformer_loops = transformer_loops
        self.num_blocks = num_blocks
        self.embedding_linear = nn.Linear(input_dim, hidden_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=2)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        # One weight-generating LiDB layer per targeted transformer block.
        self.lidb_layers = nn.ModuleList(
            LiDBLayer(3072, 15360, hidden_dim) for _ in range(num_blocks)
        )

    def forward(self, face_embedding, _):
        """Return a dict mapping LoRA state-dict keys to predicted tensors.

        Args:
            face_embedding: embedding vector of size ``input_dim``.
            _: unused (kept for caller compatibility).
        """
        device = face_embedding.device
        projected = self.embedding_linear(face_embedding).unsqueeze(0)  # add sequence dim
        # Seed sequence: the embedding token followed by one zero token per block.
        seed_tokens = [projected]
        seed_tokens.extend(
            torch.zeros(1, projected.size(1), device=device)
            for _ in range(self.num_blocks)
        )
        x = torch.cat(seed_tokens, dim=0)
        # Iteratively refine: residually add the transformer output to the
        # block tokens while the embedding token itself stays fixed.
        for _ in range(self.transformer_loops):
            refined = self.transformer(x, x)[1:]  # drop the embedding position
            x = torch.cat([projected, x[1:] + refined], dim=0)
        # Map each refined block token to its LoRA A/B matrices.
        predicted_weights = {}
        for idx, lidb in enumerate(self.lidb_layers):
            A, B = lidb(x[idx + 1])
            prefix = f'transformer.single_transformer_blocks.{idx}.proj_out'
            predicted_weights[f'{prefix}.lora_A.weight'] = A  # (128, 15360)
            predicted_weights[f'{prefix}.lora_B.weight'] = B  # (3072, 128)
        return predicted_weights
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment