|
|
@@ -0,0 +1,346 @@ |
|
|
{ |
|
|
"cells": [ |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"# Deep Convolutional Generative Adversarial Network (DCGAN) Tutorial" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"This tutorials walks through an implementation of DCGAN as described in [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434).\n", |
|
|
"\n", |
|
|
"To learn more about generative adversarial networks, see my [Medium post](https://medium.com/p/54deab2fce39) on them." |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": true |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"#Import the libraries we will need.\n", |
|
|
"import tensorflow as tf\n", |
|
|
"import numpy as np\n", |
|
|
"import input_data\n", |
|
|
"import matplotlib.pyplot as plt\n", |
|
|
"import tensorflow.contrib.slim as slim\n", |
|
|
"import os\n", |
|
|
"import scipy.misc\n", |
|
|
"import scipy" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"We will be using the MNIST dataset. input_data is a library that downloads the dataset and uzips it automatically. It can be acquired Github here: https://gist.github.com/awjuliani/1d21151bc17362bf6738c3dc02f37906" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": false, |
|
|
"scrolled": true |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=False)" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"### Helper Functions" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": true |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"#This function performns a leaky relu activation, which is needed for the discriminator network.\n", |
|
|
"def lrelu(x, leak=0.2, name=\"lrelu\"):\n", |
|
|
" with tf.variable_scope(name):\n", |
|
|
" f1 = 0.5 * (1 + leak)\n", |
|
|
" f2 = 0.5 * (1 - leak)\n", |
|
|
" return f1 * x + f2 * abs(x)\n", |
|
|
" \n", |
|
|
"#The below functions are taken from carpdem20's implementation https://github.com/carpedm20/DCGAN-tensorflow\n", |
|
|
"#They allow for saving sample images from the generator to follow progress\n", |
|
|
"def save_images(images, size, image_path):\n", |
|
|
" return imsave(inverse_transform(images), size, image_path)\n", |
|
|
"\n", |
|
|
"def imsave(images, size, path):\n", |
|
|
" return scipy.misc.imsave(path, merge(images, size))\n", |
|
|
"\n", |
|
|
"def inverse_transform(images):\n", |
|
|
" return (images+1.)/2.\n", |
|
|
"\n", |
|
|
"def merge(images, size):\n", |
|
|
" h, w = images.shape[1], images.shape[2]\n", |
|
|
" img = np.zeros((h * size[0], w * size[1]))\n", |
|
|
"\n", |
|
|
" for idx, image in enumerate(images):\n", |
|
|
" i = idx % size[1]\n", |
|
|
" j = idx / size[1]\n", |
|
|
" img[j*h:j*h+h, i*w:i*w+w] = image\n", |
|
|
"\n", |
|
|
" return img" |
|
|
] |
|
|
}, |
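{ |
"cell_type": "markdown", |
"metadata": {}, |
"source": [ |
"The `lrelu` helper above relies on the identity `0.5*(1+leak)*x + 0.5*(1-leak)*abs(x) = max(x, leak*x)`, which holds for `0 <= leak <= 1`. The NumPy snippet below is a minimal sketch that checks this numerically; `lrelu_np` and `lrelu_reference` are hypothetical names used only for illustration and are not part of the notebook.\n", |
"\n", |
"```python\n", |
"import numpy as np\n", |
"\n", |
"def lrelu_np(x, leak=0.2):\n", |
"    # Same arithmetic as the TensorFlow helper, written with NumPy.\n", |
"    f1 = 0.5 * (1 + leak)\n", |
"    f2 = 0.5 * (1 - leak)\n", |
"    return f1 * x + f2 * np.abs(x)\n", |
"\n", |
"def lrelu_reference(x, leak=0.2):\n", |
"    # Textbook form: keep positives, scale negatives by leak.\n", |
"    return np.maximum(x, leak * x)\n", |
"\n", |
"x = np.linspace(-3.0, 3.0, 13)\n", |
"assert np.allclose(lrelu_np(x), lrelu_reference(x))\n", |
"```" |
] |
}, |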
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"## Defining the Adversarial Networks" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"### Generator Network\n", |
|
|
"\n", |
|
|
"The generator takes a vector of random numbers and transforms it into a 32x32 image. Each layer in the network involves a strided transpose convolution, batch normalization, and rectified nonlinearity. Tensorflow's slim library allows us to easily define each of these layers." |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": true |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"def generator(z):\n", |
|
|
" \n", |
|
|
"    zP = slim.fully_connected(z,4*4*256,normalizer_fn=slim.batch_norm,\\\n", |
|
|
"        activation_fn=tf.nn.relu,scope='g_project',weights_initializer=initializer)\n", |
|
|
"    zCon = tf.reshape(zP,[-1,4,4,256])\n", |
|
|
"    \n", |
|
|
"    gen1 = slim.convolution2d_transpose(\\\n", |
|
|
"        zCon,num_outputs=64,kernel_size=[5,5],stride=[2,2],\\\n", |
|
|
"        padding=\"SAME\",normalizer_fn=slim.batch_norm,\\\n", |
|
|
"        activation_fn=tf.nn.relu,scope='g_conv1', weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    gen2 = slim.convolution2d_transpose(\\\n", |
|
|
"        gen1,num_outputs=32,kernel_size=[5,5],stride=[2,2],\\\n", |
|
|
"        padding=\"SAME\",normalizer_fn=slim.batch_norm,\\\n", |
|
|
"        activation_fn=tf.nn.relu,scope='g_conv2', weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    gen3 = slim.convolution2d_transpose(\\\n", |
|
|
"        gen2,num_outputs=16,kernel_size=[5,5],stride=[2,2],\\\n", |
|
|
"        padding=\"SAME\",normalizer_fn=slim.batch_norm,\\\n", |
|
|
"        activation_fn=tf.nn.relu,scope='g_conv3', weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    g_out = slim.convolution2d_transpose(\\\n", |
|
|
"        gen3,num_outputs=1,kernel_size=[32,32],padding=\"SAME\",\\\n", |
|
|
"        biases_initializer=None,activation_fn=tf.nn.tanh,\\\n", |
|
|
"        scope='g_out', weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    return g_out" |
|
|
] |
|
|
}, |
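{ |
"cell_type": "markdown", |
"metadata": {}, |
"source": [ |
"To see why the generator above ends at 32x32: with `padding='SAME'`, a transpose convolution multiplies each spatial dimension by its stride. The three stride-2 layers therefore take the 4x4 projection to 8x8, 16x16, and 32x32, and the output layer, which sets no stride (assumed here to mean slim's default of 1), keeps that size. A small sketch of this bookkeeping, where `transpose_conv_same_size` is a hypothetical helper introduced only for illustration:\n", |
"\n", |
"```python\n", |
"def transpose_conv_same_size(size, stride):\n", |
"    # With SAME padding, output size = input size * stride.\n", |
"    return size * stride\n", |
"\n", |
"size = 4  # spatial size after reshaping the projection to [-1, 4, 4, 256]\n", |
"sizes = [size]\n", |
"for stride in [2, 2, 2, 1]:  # g_conv1, g_conv2, g_conv3, g_out\n", |
"    size = transpose_conv_same_size(size, stride)\n", |
"    sizes.append(size)\n", |
"\n", |
"print(sizes)  # [4, 8, 16, 32, 32]\n", |
"```" |
] |
}, |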
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"### Discriminator Network\n", |
|
|
"The discriminator network takes as input a 32x32 image and transforms it into a single valued probability of being generated from real-world data. Again we use tf.slim to define the convolutional layers, batch normalization, and weight initialization." |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": true |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"def discriminator(bottom, reuse=False):\n", |
|
|
" \n", |
|
|
"    dis1 = slim.convolution2d(bottom,16,[4,4],stride=[2,2],padding=\"SAME\",\\\n", |
|
|
"        biases_initializer=None,activation_fn=lrelu,\\\n", |
|
|
"        reuse=reuse,scope='d_conv1',weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    dis2 = slim.convolution2d(dis1,32,[4,4],stride=[2,2],padding=\"SAME\",\\\n", |
|
|
"        normalizer_fn=slim.batch_norm,activation_fn=lrelu,\\\n", |
|
|
"        reuse=reuse,scope='d_conv2', weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    dis3 = slim.convolution2d(dis2,64,[4,4],stride=[2,2],padding=\"SAME\",\\\n", |
|
|
"        normalizer_fn=slim.batch_norm,activation_fn=lrelu,\\\n", |
|
|
"        reuse=reuse,scope='d_conv3',weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    d_out = slim.fully_connected(slim.flatten(dis3),1,activation_fn=tf.nn.sigmoid,\\\n", |
|
|
"        reuse=reuse,scope='d_out', weights_initializer=initializer)\n", |
|
|
"    \n", |
|
|
"    return d_out" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"### Connecting them together" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": false |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"tf.reset_default_graph()\n", |
|
|
"\n", |
|
|
"z_size = 100 #Size of z vector used for generator.\n", |
|
|
"\n", |
|
|
"#This initializaer is used to initialize all the weights of the network.\n", |
|
|
"initializer = tf.truncated_normal_initializer(stddev=0.02)\n", |
|
|
"\n", |
|
|
"#These two placeholders are used for input into the generator and discriminator, respectively.\n", |
|
|
"z_in = tf.placeholder(shape=[None,z_size],dtype=tf.float32) #Random vector\n", |
|
|
"real_in = tf.placeholder(shape=[None,32,32,1],dtype=tf.float32) #Real images\n", |
|
|
"\n", |
|
|
"Gz = generator(z_in) #Generates images from random z vectors\n", |
|
|
"Dx = discriminator(real_in) #Produces probabilities for real images\n", |
|
|
"Dg = discriminator(Gz,reuse=True) #Produces probabilities for generator images\n", |
|
|
"\n", |
|
|
"#These functions together define the optimization objective of the GAN.\n", |
|
|
"d_loss = -tf.reduce_mean(tf.log(Dx) + tf.log(1.-Dg)) #This optimizes the discriminator.\n", |
|
|
"g_loss = -tf.reduce_mean(tf.log(Dg)) #This optimizes the generator.\n", |
|
|
"\n", |
|
|
"tvars = tf.trainable_variables()\n", |
|
|
"\n", |
|
|
"#The below code is responsible for applying gradient descent to update the GAN.\n", |
|
|
"trainerD = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5)\n", |
|
|
"trainerG = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5)\n", |
|
|
"d_grads = trainerD.compute_gradients(d_loss,tvars[9:]) #Only update the weights for the discriminator network.\n", |
|
|
"g_grads = trainerG.compute_gradients(g_loss,tvars[0:9]) #Only update the weights for the generator network.\n", |
|
|
"\n", |
|
|
"update_D = trainerD.apply_gradients(d_grads)\n", |
|
|
"update_G = trainerG.apply_gradients(g_grads)" |
|
|
] |
|
|
}, |
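{ |
"cell_type": "markdown", |
"metadata": {}, |
"source": [ |
"The two objectives above can be evaluated by hand to get a feel for their scale. The NumPy sketch below applies the same formulas as `d_loss` and `g_loss` to made-up discriminator outputs; the `Dx` and `Dg` arrays here are illustrative values, not results from this model, and `gan_losses` is a hypothetical helper name.\n", |
"\n", |
"```python\n", |
"import numpy as np\n", |
"\n", |
"def gan_losses(Dx, Dg):\n", |
"    # Discriminator: push Dx toward 1 and Dg toward 0.\n", |
"    d_loss = -np.mean(np.log(Dx) + np.log(1.0 - Dg))\n", |
"    # Generator: push Dg toward 1.\n", |
"    g_loss = -np.mean(np.log(Dg))\n", |
"    return d_loss, g_loss\n", |
"\n", |
"# Illustrative case: a discriminator that outputs 0.5 for everything.\n", |
"Dx = np.full(4, 0.5)\n", |
"Dg = np.full(4, 0.5)\n", |
"d_loss, g_loss = gan_losses(Dx, Dg)\n", |
"print(d_loss, g_loss)\n", |
"```\n", |
"\n", |
"When the discriminator cannot tell real from generated images, `d_loss` equals `2*ln(2)` and `g_loss` equals `ln(2)`, which gives a rough reference point when reading the losses printed during training." |
] |
}, |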
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": { |
|
|
"collapsed": true |
|
|
}, |
|
|
"source": [ |
|
|
"## Training the network\n", |
|
|
"Now that we have fully defined our network, it is time to train it!" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": false, |
|
|
"scrolled": true |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"batch_size = 128 #Size of image batch to apply at each iteration.\n", |
|
|
"iterations = 500000 #Total number of iterations to use.\n", |
|
|
"sample_directory = './figs' #Directory to save sample images from generator in.\n", |
|
|
"model_directory = './models' #Directory to save trained model to.\n", |
|
|
"\n", |
|
|
"init = tf.initialize_all_variables()\n", |
|
|
"saver = tf.train.Saver()\n", |
|
|
"with tf.Session() as sess: \n", |
|
|
"    sess.run(init)\n", |
|
|
"    for i in range(iterations):\n", |
|
|
"        zs = np.random.uniform(-1.0,1.0,size=[batch_size,z_size]).astype(np.float32) #Generate a random z batch\n", |
|
|
"        xs,_ = mnist.train.next_batch(batch_size) #Draw a sample batch from MNIST dataset.\n", |
|
|
"        xs = (np.reshape(xs,[batch_size,28,28,1]) - 0.5) * 2.0 #Transform it to be between -1 and 1\n", |
|
|
"        xs = np.lib.pad(xs, ((0,0),(2,2),(2,2),(0,0)),'constant', constant_values=(-1, -1)) #Pad the images so they are 32x32\n", |
|
|
"        _,dLoss = sess.run([update_D,d_loss],feed_dict={z_in:zs,real_in:xs}) #Update the discriminator\n", |
|
|
"        _,gLoss = sess.run([update_G,g_loss],feed_dict={z_in:zs}) #Update the generator, twice for good measure.\n", |
|
|
"        _,gLoss = sess.run([update_G,g_loss],feed_dict={z_in:zs})\n", |
|
|
"        if i % 10 == 0:\n", |
|
|
"            print \"Gen Loss: \" + str(gLoss) + \" Disc Loss: \" + str(dLoss)\n", |
|
|
"            z2 = np.random.uniform(-1.0,1.0,size=[batch_size,z_size]).astype(np.float32) #Generate another z batch\n", |
|
|
"            newZ = sess.run(Gz,feed_dict={z_in:z2}) #Use new z to get sample images from generator.\n", |
|
|
"            if not os.path.exists(sample_directory):\n", |
|
|
"                os.makedirs(sample_directory)\n", |
|
|
"            #Save sample generator images for viewing training progress.\n", |
|
|
"            save_images(np.reshape(newZ[0:36],[36,32,32]),[6,6],sample_directory+'/fig'+str(i)+'.png')\n", |
|
|
"        if i % 1000 == 0 and i != 0:\n", |
|
|
"            if not os.path.exists(model_directory):\n", |
|
|
"                os.makedirs(model_directory)\n", |
|
|
"            saver.save(sess,model_directory+'/model-'+str(i)+'.cptk')\n", |
|
|
"            print \"Saved Model\"" |
|
|
] |
|
|
}, |
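{ |
"cell_type": "markdown", |
"metadata": {}, |
"source": [ |
"The preprocessing inside the training loop can be tried in isolation. A minimal sketch, assuming a batch of flattened 28x28 images with pixel values in [0, 1]; random data stands in for MNIST here, and `preprocess` is a hypothetical wrapper name that does not appear in the notebook:\n", |
"\n", |
"```python\n", |
"import numpy as np\n", |
"\n", |
"def preprocess(xs):\n", |
"    # Reshape flat vectors to NHWC and rescale [0, 1] -> [-1, 1].\n", |
"    xs = (np.reshape(xs, [-1, 28, 28, 1]) - 0.5) * 2.0\n", |
"    # Pad 2 pixels on each spatial side with -1 (background) to reach 32x32.\n", |
"    return np.pad(xs, ((0, 0), (2, 2), (2, 2), (0, 0)),\n", |
"                  'constant', constant_values=-1)\n", |
"\n", |
"batch = np.random.uniform(0.0, 1.0, size=[8, 784])  # stand-in for MNIST\n", |
"out = preprocess(batch)\n", |
"print(out.shape)  # (8, 32, 32, 1)\n", |
"```\n", |
"\n", |
"Padding with -1 rather than 0 matters because -1 is the background value after rescaling, which also matches the range of the generator's tanh output." |
] |
}, |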
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"## Using a trained network\n", |
|
|
"Once we have a trained model saved, we may want to use it to generate new images, and explore the representation it has learned." |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"execution_count": null, |
|
|
"metadata": { |
|
|
"collapsed": true |
|
|
}, |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"sample_directory = './figs' #Directory to save sample images from generator in.\n", |
|
|
"model_directory = './models' #Directory to load trained model from.\n", |
|
|
"batch_size_sample = 36\n", |
|
|
"\n", |
|
|
"init = tf.initialize_all_variables()\n", |
|
|
"saver = tf.train.Saver()\n", |
|
|
"with tf.Session() as sess: \n", |
|
|
"    sess.run(init)\n", |
|
|
"    #Reload the model.\n", |
|
|
"    print 'Loading Model...'\n", |
|
|
"    ckpt = tf.train.get_checkpoint_state(model_directory)\n", |
|
|
"    saver.restore(sess,ckpt.model_checkpoint_path)\n", |
|
|
"    \n", |
|
|
"    zs = np.random.uniform(-1.0,1.0,size=[batch_size_sample,z_size]).astype(np.float32) #Generate a random z batch\n", |
|
|
"    newZ = sess.run(Gz,feed_dict={z_in:zs}) #Use new z to get sample images from generator.\n", |
|
|
"    if not os.path.exists(sample_directory):\n", |
|
|
"        os.makedirs(sample_directory)\n", |
|
|
"    save_images(np.reshape(newZ[0:batch_size_sample],[batch_size_sample,32,32]),[6,6],sample_directory+'/fig_sample.png')" |
|
|
] |
|
|
} |
|
|
], |
|
|
"metadata": { |
|
|
"kernelspec": { |
|
|
"display_name": "Python 2", |
|
|
"language": "python", |
|
|
"name": "python2" |
|
|
}, |
|
|
"language_info": { |
|
|
"codemirror_mode": { |
|
|
"name": "ipython", |
|
|
"version": 2 |
|
|
}, |
|
|
"file_extension": ".py", |
|
|
"mimetype": "text/x-python", |
|
|
"name": "python", |
|
|
"nbconvert_exporter": "python", |
|
|
"pygments_lexer": "ipython2", |
|
|
"version": "2.7.11" |
|
|
} |
|
|
}, |
|
|
"nbformat": 4, |
|
|
"nbformat_minor": 0 |
|
|
} |