for epoch in range(num_epochs): for n, (real_samples, mnist_labels) in enumerate(train_loader): # Данные для тренировки дискриминатора real_samples = real_samples.to(device=device) real_samples_labels = torch.ones((batch_size, 1)).to( device=device) latent_space_samples = torch.randn((batch_size, 100)).to( device=device) generated_samples = generator(latent_space_samples) generated_samples_labels = torch.zeros((batch_size, 1)).to( device=device) all_samples = torch.cat((real_samples, generated_samples)) all_samples_labels = torch.cat( (real_samples_labels, generated_samples_labels)) # Обучение дискриминатора discriminator.zero_grad() output_discriminator = discriminator(all_samples) loss_discriminator = loss_function( output_discriminator, all_samples_labels) loss_discriminator.backward() optimizer_discriminator.step() # Данные для обучения генератора latent_space_samples = torch.randn((batch_size, 100)).to( device=device) # Обучение генератора generator.zero_grad() generated_samples = generator(latent_space_samples) output_discriminator_generated = discriminator(generated_samples) loss_generator = loss_function( output_discriminator_generated, real_samples_labels) loss_generator.backward() optimizer_generator.step() # Показываем loss if n == batch_size - 1: print(f"Epoch: {epoch} Loss D.: {loss_discriminator} | Loss G.:{loss_generator}")