class AE(nn.Module):
    """Fully connected autoencoder with a symmetric encoder/decoder pair.

    With the default arguments the encoder is
    ``Linear(784, 128) -> ReLU -> Linear(128, 64) -> ReLU -> Linear(64, n_latent)``
    and the decoder mirrors it, ending in a ``Sigmoid`` so every reconstructed
    value lies in the range [0, 1].

    Args:
        n_latent: Width of the bottleneck (latent code) layer.
        in_features: Width of one flattened input vector. Defaults to
            ``28 * 28`` (presumably flattened 28x28 images -- confirm against
            the data pipeline).
        hidden: Widths of the encoder's hidden layers, ordered from input
            side to latent side. The decoder uses them in reverse order.
            Defaults to ``(128, 64)``.

    Attributes are documented inline: ``encoder`` maps inputs to latent codes
    and ``decoder`` maps latent codes back to input space.
    """

    def __init__(self, n_latent: int, *, in_features: int = 28 * 28,
                 hidden=(128, 64)):
        super().__init__()
        widths = [in_features, *hidden, n_latent]
        # The encoder is built before the decoder so parameter initialisation
        # draws from the RNG in a fixed, reproducible order.
        self.encoder = nn.Sequential(*self._linear_stack(widths))
        # Mirror image of the encoder; the Sigmoid squashes outputs to [0, 1].
        self.decoder = nn.Sequential(
            *self._linear_stack(widths[::-1]),
            nn.Sigmoid(),
        )

    @staticmethod
    def _linear_stack(widths):
        """Return Linear layers joining consecutive widths, ReLU in between."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            # In-place ReLU overwrites the preceding Linear's output buffer;
            # it never touches the tensor passed in by the caller.
            layers.append(nn.ReLU(inplace=True))
        # No activation after the last Linear: the caller decides what follows.
        return layers[:-1]

    def forward(self, x):
        """Encode ``x`` to the latent space and decode it back.

        Args:
            x: Tensor whose last dimension equals ``in_features``; inputs are
                not flattened here, so the caller must flatten beforehand.

        Returns:
            The reconstruction, same shape as ``x``, with values in [0, 1].
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x