Created
March 21, 2019 15:42
-
-
Save laol777/0ac74698317f890b9a087e354d358239 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from MODEL import * | |
| from EVALUATION import * | |
# Wall-clock timer for the whole run; timer.end() is printed after restore
# and again at the very end of the script.
timer = Timer()
timer.start()

# Activation-output selectors consumed by model(): the generator exposes two
# output heads (index 0 and 1), the discriminator a single one.
netG_act_o_1 = {'size': 2, 'index': 0}
netG_act_o_2 = {'size': 2, 'index': 1}
netD_act_o = {'size': 1, 'index': 0}

# Gradient-penalty weights for the two discriminator branches, fed at run time.
gp_weight_1 = tf.placeholder(tf.float32)
gp_weight_2 = tf.placeholder(tf.float32)
def wgan_gp(fake_data, real_data):
    """Build one WGAN-GP gradient-penalty term against netD.

    Both tensors are flattened to (batch, features); a random point on the
    line between each real/fake pair is reshaped back to image layout, scored
    by netD, and the L2 norm of the score's gradient w.r.t. the interpolate
    is penalised.  Returns the scalar penalty tensor (note: negated here;
    the sign convention depends on how the caller assembles the loss).
    """
    batch = FLAGS['data_train_batch_size']
    fake_flat = tf.reshape(fake_data, [batch, -1])
    real_flat = tf.reshape(real_data, [batch, -1])
    # One interpolation coefficient per sample, broadcast across features.
    alpha = tf.random_uniform(shape=[batch, 1],
                              minval=0., maxval=1., seed=FLAGS['process_random_seed'])
    mix = real_flat + (alpha * (fake_flat - real_flat))
    # netD consumes images, so restore the spatial layout before scoring.
    mix_img = tf.reshape(mix, [batch, FLAGS['data_image_size'],
                               FLAGS['data_image_size'], FLAGS['data_image_channel']])
    grads = tf.gradients(model(netD, mix_img, True, netD_act_o), [mix])[0]
    # Per-sample gradient norm (reduction_indices is the pre-1.0 name of axis).
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), reduction_indices=[1]))
    if FLAGS['loss_wgan_use_g_to_one']:
        # Two-sided penalty: push the gradient norm towards 1.
        penalty = -tf.reduce_mean((grad_norm - 1.) ** 2)
    else:
        # One-sided penalty: only norms above 1 contribute.
        # NOTE(review): the WGAN-GP one-sided variant usually squares this
        # term — confirm the linear form is intentional.
        penalty = -tf.reduce_mean(tf.maximum(0., grad_norm - 1.))
    return penalty
# ---------------------------------------------------------------------------
# Generator graph.  Two parameter sets live under the generator variable
# scope: sub-scope 'B' maps domain-2 inputs, sub-scope 'A' maps domain-1
# inputs.  Each sub-scope is built once (is_first=True, training mode) and
# then reused for the inference copies and for the copy fed to netD.
# NOTE(review): the paste lost its indentation; the nesting below is
# reconstructed from the scope/reuse call order — confirm against the source.
# ---------------------------------------------------------------------------
with tf.name_scope(netG.name):
    with tf.variable_scope(netG.variable_scope_name) as scope_full:
        with tf.variable_scope(netG.variable_scope_name + 'B') as scopeB:
            # First build creates the 'B' variables.
            netG_train_output2 = model(netG, train_df.input2, True, netG_act_o_1, is_first=True)
            scopeB.reuse_variables()
            # Same weights, inference mode / discriminator-feeding copy.
            netG_test_output2 = model(netG, test_df.input2, False, netG_act_o_1)
            netG_train_output2_for_netD = model(netG, train_df.input2, False, netG_act_o_1)
        with tf.variable_scope(netG.variable_scope_name + 'A') as scopeA:
            netG_train_output1 = model(netG, train_df.input1, True, netG_act_o_1, is_first=True)
            scopeA.reuse_variables()
            netG_test_output1 = model(netG, test_df.input1, False, netG_act_o_1)
            netG_train_output1_for_netD = model(netG, train_df.input1, False, netG_act_o_1)
            # Cycle ('inverse') mappings: outputs of the other branch are
            # clipped to [0, 1] and mapped back through scope 'A' using the
            # generator's second activation head.
            netG_train_output2_inv = model(netG, tf.clip_by_value(netG_train_output2, 0, 1), True, netG_act_o_2)
            netG_test_output2_inv = model(netG, tf.clip_by_value(netG_test_output2, 0, 1), False, netG_act_o_2)
        with tf.variable_scope(netG.variable_scope_name + 'B') as scopeB:
            # Re-enter scope 'B' (its variables already exist) for the
            # inverse pass of branch 1.
            scopeB.reuse_variables()
            netG_train_output1_inv = model(netG, tf.clip_by_value(netG_train_output1, 0, 1), True, netG_act_o_2)
            netG_test_output1_inv = model(netG, tf.clip_by_value(netG_test_output1, 0, 1), False, netG_act_o_2)
# ---------------------------------------------------------------------------
# Discriminator graph.  Sub-scope 'A' scores domain-2-style images (real
# input2 vs. generated output1); sub-scope 'B' scores domain-1-style images
# (real input1 vs. generated output2).  Each branch also builds its averaged
# WGAN gradient-penalty term.
# NOTE(review): indentation reconstructed from scope/reuse order — confirm.
# ---------------------------------------------------------------------------
with tf.name_scope(netD.name):
    with tf.variable_scope(netD.variable_scope_name) as scope_full:
        with tf.variable_scope(netD.variable_scope_name + 'A') as scopeA:
            # First build creates the 'A' variables.
            netD_train_output1_1 = model(netD, netG_train_output1_for_netD, True, netD_act_o, is_first=True)
            scopeA.reuse_variables()
            netD_train_output2_1 = model(netD, train_df.input2, True, netD_act_o)
            # Score of the generator output with gradients flowing into netG
            # (built from netG_train_output1, not the _for_netD copy).
            netD_netG_train_output1_1 = model(netD, netG_train_output1, True, netD_act_o)
            netD_netG_train_output2_1 = netD_train_output2_1
            netD_test_output1_1 = model(netD, netG_test_output1, False, netD_act_o)
            netD_test_output2_1 = model(netD, test_df.input2, False, netD_act_o)
            # wgan-gp: average several penalty draws (independent alpha samples).
            if FLAGS['loss_wgan_gp_use_all']:
                assert False, 'not yet'
            else:
                w_list = []
                for _ in range(FLAGS['loss_wgan_gp_times']):
                    w_list.append(wgan_gp(netG_train_output1_for_netD, train_df.input2))
                # tf.pack is the pre-1.0 name of tf.stack.
                gradient_penalty_1 = tf.reduce_mean(tf.pack(w_list)) * gp_weight_1
        with tf.variable_scope(netD.variable_scope_name + 'B') as scopeB:
            netD_train_output1_2 = model(netD, train_df.input1, True, netD_act_o, is_first=True)
            scopeB.reuse_variables()
            netD_train_output2_2 = model(netD, netG_train_output2_for_netD, True, netD_act_o)
            netD_netG_train_output1_2 = netD_train_output1_2
            netD_netG_train_output2_2 = model(netD, netG_train_output2, True, netD_act_o)
            netD_test_output1_2 = model(netD, test_df.input1, False, netD_act_o)
            netD_test_output2_2 = model(netD, netG_test_output2, False, netD_act_o)
            # wgan-gp for the 'B' branch (mirrors the 'A' branch above).
            if FLAGS['loss_wgan_gp_use_all']:
                assert False, 'not yet'
            else:
                w_list = []
                for _ in range(FLAGS['loss_wgan_gp_times']):
                    w_list.append(wgan_gp(netG_train_output2_for_netD, train_df.input1))
                gradient_penalty_2 = tf.reduce_mean(tf.pack(w_list)) * gp_weight_2
def update_cache_dict(csr_ind, csr_val, csr_ind_r, csr_val_r, csr_ind_g, csr_val_g,
                      csr_ind_b, csr_val_b, csr_names, csr_dict=None):
    """Fill the per-image sparse-matrix caches from csr_dict, in place.

    For each entry i of csr_names (one name list per batch item), every name
    is routed by its two-character suffix ('_r', '_g', '_b', or anything
    else) to the matching (indices, values) output pair: columns 0..1 of
    csr_dict[name] minus 1 (1-based -> 0-based indices) are written to the
    *_ind list at position i, and the last column to the *_val list.

    Args:
        csr_ind, csr_val, csr_ind_r, csr_val_r, csr_ind_g, csr_val_g,
        csr_ind_b, csr_val_b: parallel output lists, mutated at position i.
        csr_names: list of lists of keys into csr_dict.
        csr_dict: mapping name -> (n, 3) array.  Defaults to the module-level
            `csr_dict` global, preserving the original behaviour for
            existing callers that pass nothing.
    """
    if csr_dict is None:
        # Backward compatibility: the original implementation read this
        # module-level global directly.
        csr_dict = globals()['csr_dict']
    # Suffix -> destination (indices, values) pair; anything without a
    # recognised colour suffix goes to the plain csr_ind/csr_val pair.
    routes = {'_r': (csr_ind_r, csr_val_r),
              '_g': (csr_ind_g, csr_val_g),
              '_b': (csr_ind_b, csr_val_b)}
    for i, names in enumerate(csr_names):
        for name in names:
            ind_dst, val_dst = routes.get(name[-2:], (csr_ind, csr_val))
            mat = csr_dict[name]
            ind_dst[i] = mat[:, :2] - 1  # convert 1-based to 0-based indices
            val_dst[i] = mat[:, -1]
# Saver over both networks' weights; max_to_keep=None keeps every checkpoint.
saver = tf.train.Saver(var_list=netD.weights + netG.weights, max_to_keep=None)
sess_config = tf.ConfigProto(log_device_placement=False)
# Grow GPU memory on demand unless the config asks to grab it all up front.
# NOTE(review): `is False` means a 0/None flag value would NOT enable growth;
# presumably the flag is a strict bool — confirm.
sess_config.gpu_options.allow_growth = FLAGS['sys_use_all_gpu_memory'] is False
# ---------------------------------------------------------------------------
# Inference: restore the trained weights, run the generator on each test
# image and write the cropped, enhanced result to disk.  Also freezes the
# generator sub-graph to 'output_graph.pb'.
# NOTE(review): indentation was lost in the paste and is reconstructed from
# data flow — confirm which statements sit inside the per-image loop.  As
# written, the graph is frozen and 'output_graph.pb' rewritten on EVERY
# iteration; that almost certainly belongs before the loop.
# ---------------------------------------------------------------------------
with tf.Session(config=sess_config) as sess:
    print(FLAGS['load_model_path'])
    saver.restore(sess, FLAGS['load_model_path'])
    print('after restore ', str(timer.end()))
    test_count = 42  # number of test images to process — TODO confirm vs. dataset size
    data_loader = DataLoader()
    # Crop the generator's test output back to the original image rectangle.
    netG_test_output1_crop = tf_crop_rect(netG_test_output1, test_df.mat1, 0)
    for i in range(test_count):
        label_img = data_loader.get_next_test_label()
        input_img, data = data_loader.get_next_test_input_batch()
        # Refresh the sparse-matrix caches for both input domains.
        update_cache_dict(data['csr_ind1'], data['csr_val1'], data['csr_ind_r1'], data['csr_val_r1'],
                          data['csr_ind_g1'], data['csr_val_g1'], data['csr_ind_b1'], data['csr_val_b1'],
                          data['csr_names1'])
        update_cache_dict(data['csr_ind2'], data['csr_val2'], data['csr_ind_r2'], data['csr_val_r2'],
                          data['csr_ind_g2'], data['csr_val_g2'], data['csr_ind_b2'], data['csr_val_b2'],
                          data['csr_names2'])
        # Feed values (dict_d) and the placeholders they bind to (dict_t);
        # the two lists are zipped positionally, so their order must match.
        dict_d = \
            [input_img, label_img] + \
            data['rect1'] + data['rot1'] + \
            data['rect2'] + data['rot2'] + \
            data['csr_ind1'] + data['csr_val1'] + \
            data['csr_ind_r1'] + data['csr_val_r1'] + \
            data['csr_ind_g1'] + data['csr_val_g1'] + \
            data['csr_ind_b1'] + data['csr_val_b1'] + data['csr_sha1'] + \
            data['csr_ind2'] + data['csr_val2'] + \
            data['csr_ind_r2'] + data['csr_val_r2'] + \
            data['csr_ind_g2'] + data['csr_val_g2'] + \
            data['csr_ind_b2'] + data['csr_val_b2'] + data['csr_sha2']
        dict_t = \
            [test_df.input1_src, test_df.input2_src] + \
            test_df.mat1.rect + test_df.mat1.rot + \
            test_df.mat2.rect + test_df.mat2.rot + \
            test_df.mat1.csr_ind + test_df.mat1.csr_val + \
            test_df.mat1.csr_ind_r + test_df.mat1.csr_val_r + \
            test_df.mat1.csr_ind_g + test_df.mat1.csr_val_g + \
            test_df.mat1.csr_ind_b + test_df.mat1.csr_val_b + test_df.mat1.csr_sha + \
            test_df.mat2.csr_ind + test_df.mat2.csr_val + \
            test_df.mat2.csr_ind_r + test_df.mat2.csr_val_r + \
            test_df.mat2.csr_ind_g + test_df.mat2.csr_val_g + \
            test_df.mat2.csr_ind_b + test_df.mat2.csr_val_b + test_df.mat2.csr_sha
        # Freeze the generator sub-graph up to the hard-coded output node.
        output_node_names = ['netG/netG_var_scope/netG_var_scopeA/netG_3_1/Add']
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names)
        # Save the frozen graph (overwrites the same file each iteration).
        with open('output_graph.pb', 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
        print(1)
        enhance_test_img = sess.run([netG_test_output1_crop], feed_dict={t: d for t, d in zip(dict_t, dict_d)})
        print(2)
        # Scale the [0, 1] output to the input dtype's full range, then cast.
        enhance_test_img = safe_casting(enhance_test_img[0] * tf.as_dtype(FLAGS['data_input_dtype']).max,
                                        FLAGS['data_input_dtype'])
        cv2.imwrite(FLAGS['folder_test_img'] + test_image_name_list[i] + FLAGS['data_input_ext'],
                    enhance_test_img)
    print('end ', str(timer.end()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment