class NgramBlock(nn.Module):
    """Pre-norm residual block pairing n-gram attention with a SiLU MLP.

    Layout per sub-layer: RMSNorm -> sub-layer -> dropout -> residual add.
    The attention sub-layer (``Ngram``) additionally consumes the raw token
    ids, hence the ``requires_input_ids`` flag for the surrounding model.

    NOTE(review): the original docstring claims "parameter size 4d^2", but
    with ``mlp_hidden == d_model`` the MLP alone contributes 2*d_model^2
    parameters — confirm whether the count includes the attention weights
    or whether the hidden width was meant to be larger.
    """

    # Signals to the caller that forward() needs input_ids, not just x.
    requires_input_ids = True

    def __init__(self, config, ngram):
        """
        parameter size 4d^2
        """
        super().__init__()
        self.ln_1 = RMSNorm(config.d_model, eps=1e-5)
        self.attn = Ngram(config, ngram)
        self.ln_2 = RMSNorm(config.d_model, eps=1e-5)
        # Hidden width equals the model width (no expansion factor).
        hidden_dim = config.d_model
        self.mlp = nn.Sequential(
            nn.Linear(config.d_model, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, config.d_model),
        )
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x, input_ids):
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            x: hidden states of shape (..., d_model) — presumably
               (batch, seq, d_model); TODO confirm against callers.
            input_ids: raw token ids forwarded to the n-gram attention.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Attention / n-gram sub-layer (pre-norm) with residual.
        attn_out = self.attn(self.ln_1(x), input_ids)
        x = x + self.resid_dropout(attn_out)
        # Feed-forward sub-layer (pre-norm) with residual.
        mlp_out = self.mlp(self.ln_2(x))
        x = x + self.resid_dropout(mlp_out)
        return x