# Script from GitHub gist laol777/0ac74698317f890b9a087e354d358239 (@laol777, created March 21, 2019).
from MODEL import *
from EVALUATION import *
# Start timing before any graph construction; elapsed time is printed later.
timer = Timer()
timer.start()

# Output selectors handed to model(). Their exact meaning lives in MODEL;
# presumably `index` picks one of `size` heads — TODO confirm.
netG_act_o_1 = {'size': 2, 'index': 0}
netG_act_o_2 = {'size': 2, 'index': 1}
netD_act_o = {'size': 1, 'index': 0}

# Run-time scalar multipliers for the two WGAN-GP gradient-penalty terms.
gp_weight_1 = tf.placeholder(tf.float32)
gp_weight_2 = tf.placeholder(tf.float32)
def wgan_gp(fake_data, real_data):
    """Build the (negated) WGAN-GP gradient-penalty term for one batch pair.

    Both inputs are flattened to ``(FLAGS['data_train_batch_size'], -1)``.
    For each sample, a point is taken on the segment between the real and fake
    sample: ``real + alpha * (fake - real)``, where ``alpha`` is uniform in [0, 1).
    The discriminator ``netD`` is evaluated at that point, and the per-sample
    L2 norm of its gradient with respect to the point is penalised.

    Args:
        fake_data: generator-output tensor. Its leading dimension must allow a
            reshape to ``(FLAGS['data_train_batch_size'], -1)``.
        real_data: real-image tensor with the same element count as fake_data.

    Returns:
        A scalar tensor holding the NEGATED mean penalty:
        ``-mean((||g|| - 1)^2)`` if ``FLAGS['loss_wgan_use_g_to_one']`` is set,
        otherwise ``-mean(max(0, ||g|| - 1))``.

    Presumably this must be called where netD's variables are reusable
    (inside a discriminator scope after ``reuse_variables()``).
    """
    batch = FLAGS['data_train_batch_size']
    side = FLAGS['data_image_size']
    channels = FLAGS['data_image_channel']

    fake_flat = tf.reshape(fake_data, [batch, -1])
    real_flat = tf.reshape(real_data, [batch, -1])
    # NOTE(review): the op-level seed is fixed. Every call therefore creates a
    # random op with identical seeds, which in TF likely yields the same alpha
    # for each call, so averaging several calls (loss_wgan_gp_times) may add no
    # variety — confirm.
    mix = tf.random_uniform(shape=[batch, 1],
                            minval=0., maxval=1., seed=FLAGS['process_random_seed'])
    delta = fake_flat - real_flat
    point = real_flat + (mix * delta)
    # The discriminator expects image-shaped input: (batch, side, side, channels).
    point_img = tf.reshape(point, [batch, side, side, channels])
    grad = tf.gradients(model(netD, point_img, True, netD_act_o), [point])[0]
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), reduction_indices=[1]))

    excess = grad_norm - 1.
    if FLAGS['loss_wgan_use_g_to_one']:
        # Two-sided penalty: pull the gradient norm towards exactly 1.
        per_sample = excess ** 2
    else:
        # One-sided penalty: only norms above 1 are penalised.
        per_sample = tf.maximum(0., excess)
    return -tf.reduce_mean(per_sample)
# Generator graph. Two generators share the netG name scope, each in its own
# variable scope:
#   "<netG scope>B" is applied to input2 images;
#   "<netG scope>A" is applied to input1 images.
# The third argument to model() is True for training-input passes and False for
# test / "for_netD" passes; it is presumably a training-mode flag (confirm in MODEL).
# Do NOT reorder these model() calls. TF auto-names ops in creation order, and the
# frozen-graph export later in this file refers to an auto-generated node name
# ('.../netG_var_scopeA/netG_3_1/Add').
with tf.name_scope(netG.name):
    with tf.variable_scope(netG.variable_scope_name) as scope_full:
        with tf.variable_scope(netG.variable_scope_name + 'B') as scopeB:
            # is_first=True only on the first call; presumably it creates the
            # variables, which reuse_variables() then shares with later calls.
            netG_train_output2 = model(netG, train_df.input2, True, netG_act_o_1, is_first=True)
            scopeB.reuse_variables()
            netG_test_output2 = model(netG, test_df.input2, False, netG_act_o_1)
            # Same weights and input as netG_train_output2, built with the flag
            # False; consumed by the discriminator graph.
            netG_train_output2_for_netD = model(netG, train_df.input2, False, netG_act_o_1)
        with tf.variable_scope(netG.variable_scope_name + 'A') as scopeA:
            netG_train_output1 = model(netG, train_df.input1, True, netG_act_o_1, is_first=True)
            scopeA.reuse_variables()
            netG_test_output1 = model(netG, test_df.input1, False, netG_act_o_1)
            netG_train_output1_for_netD = model(netG, train_df.input1, False, netG_act_o_1)
            # Generator A applied to B's outputs clipped to [0, 1]. These calls use
            # netG_act_o_2 (index=1) rather than netG_act_o_1 (index=0). The "_inv"
            # name suggests a cycle reconstruction of input2 — confirm.
            netG_train_output2_inv = model(netG, tf.clip_by_value(netG_train_output2, 0, 1), True, netG_act_o_2)
            netG_test_output2_inv = model(netG, tf.clip_by_value(netG_test_output2, 0, 1), False, netG_act_o_2)
        # Re-enter B's variable scope, reusing its weights, to map A's outputs back.
        with tf.variable_scope(netG.variable_scope_name + 'B') as scopeB:
            scopeB.reuse_variables()
            netG_train_output1_inv = model(netG, tf.clip_by_value(netG_train_output1, 0, 1), True, netG_act_o_2)
            netG_test_output1_inv = model(netG, tf.clip_by_value(netG_test_output1, 0, 1), False, netG_act_o_2)
# Discriminator graph. Two discriminators share the netD name scope:
#   "<netD scope>A" scores real input2 images against generator-A outputs
#                   (netG_*_output1);
#   "<netD scope>B" scores real input1 images against generator-B outputs
#                   (netG_*_output2).
# Each scope also builds a WGAN-GP gradient-penalty term scaled by gp_weight_1/_2.
# Op-creation order is unchanged from the original so auto-generated TF names stay stable.
with tf.name_scope(netD.name):
    with tf.variable_scope(netD.variable_scope_name) as scope_full:
        with tf.variable_scope(netD.variable_scope_name + 'A') as scopeA:
            # is_first=True on the first call; the rest reuse its variables.
            netD_train_output1_1 = model(netD, netG_train_output1_for_netD, True, netD_act_o, is_first=True)
            scopeA.reuse_variables()
            netD_train_output2_1 = model(netD, train_df.input2, True, netD_act_o)
            netD_netG_train_output1_1 = model(netD, netG_train_output1, True, netD_act_o)
            # Alias of the same tensor (D_A on real train input2), not a new op.
            netD_netG_train_output2_1 = netD_train_output2_1
            netD_test_output1_1 = model(netD, netG_test_output1, False, netD_act_o)
            netD_test_output2_1 = model(netD, test_df.input2, False, netD_act_o)
            # WGAN-GP for D_A: fakes from generator A vs real train input2.
            if FLAGS['loss_wgan_gp_use_all']:
                # Raise rather than `assert False`: asserts are stripped under
                # `python -O`, which would silently skip building the penalty.
                raise NotImplementedError('not yet')
            w_list = [wgan_gp(netG_train_output1_for_netD, train_df.input2)
                      for _ in range(FLAGS['loss_wgan_gp_times'])]
            # tf.pack is the pre-1.0 TensorFlow name of tf.stack.
            gradient_penalty_1 = tf.reduce_mean(tf.pack(w_list)) * gp_weight_1
        with tf.variable_scope(netD.variable_scope_name + 'B') as scopeB:
            netD_train_output1_2 = model(netD, train_df.input1, True, netD_act_o, is_first=True)
            scopeB.reuse_variables()
            netD_train_output2_2 = model(netD, netG_train_output2_for_netD, True, netD_act_o)
            # Alias of the same tensor (D_B on real train input1), not a new op.
            netD_netG_train_output1_2 = netD_train_output1_2
            netD_netG_train_output2_2 = model(netD, netG_train_output2, True, netD_act_o)
            netD_test_output1_2 = model(netD, test_df.input1, False, netD_act_o)
            netD_test_output2_2 = model(netD, netG_test_output2, False, netD_act_o)
            # WGAN-GP for D_B: fakes from generator B vs real train input1.
            if FLAGS['loss_wgan_gp_use_all']:
                raise NotImplementedError('not yet')
            w_list = [wgan_gp(netG_train_output2_for_netD, train_df.input1)
                      for _ in range(FLAGS['loss_wgan_gp_times'])]
            gradient_penalty_2 = tf.reduce_mean(tf.pack(w_list)) * gp_weight_2
def update_cache_dict(csr_ind, csr_val, csr_ind_r, csr_val_r, csr_ind_g, csr_val_g,
                      csr_ind_b, csr_val_b, csr_names, csr_lookup=None):
    """Fill per-sample sparse-matrix index/value slots from a name lookup table.

    For every sample ``i`` and every name in ``csr_names[i]``, the entry
    ``csr_lookup[name]`` is split into two parts:

    * indices: its first two columns, minus 1 (presumably converting 1-based
      indices to 0-based);
    * values: its last column.

    The last two characters of the name choose the destination. ``'_r'``,
    ``'_g'`` and ``'_b'`` go to the matching per-channel lists; every other
    name goes to ``csr_ind`` / ``csr_val``. If several names of one sample map
    to the same destination, the last one wins.

    Args:
        csr_ind, csr_val: destination lists for names without a channel suffix.
        csr_ind_r, csr_val_r: destination lists for ``'_r'`` names.
        csr_ind_g, csr_val_g: destination lists for ``'_g'`` names.
        csr_ind_b, csr_val_b: destination lists for ``'_b'`` names.
        csr_names: sequence whose i-th item is an iterable of matrix names for
            sample i. Each destination list must have at least
            ``len(csr_names)`` slots.
        csr_lookup: mapping from name to an array supporting ``[:, :2]`` and
            ``[:, -1]`` indexing. Defaults to the module-level ``csr_dict``
            (the original behaviour).

    Returns:
        None. The destination lists are mutated in place at index ``i``.
    """
    if csr_lookup is None:
        csr_lookup = csr_dict  # module-level table provided elsewhere
    default_targets = (csr_ind, csr_val)
    channel_targets = {
        '_r': (csr_ind_r, csr_val_r),
        '_g': (csr_ind_g, csr_val_g),
        '_b': (csr_ind_b, csr_val_b),
    }
    for i, names in enumerate(csr_names):
        for name in names:
            ind_dst, val_dst = channel_targets.get(name[-2:], default_targets)
            entry = csr_lookup[name]
            ind_dst[i] = entry[:, :2] - 1
            val_dst[i] = entry[:, -1]
# Restore the trained weights and export a frozen copy of the graph once.
# Then enhance each test image with generator A and write it to disk.
saver = tf.train.Saver(var_list=netD.weights + netG.weights, max_to_keep=None)
sess_config = tf.ConfigProto(log_device_placement=False)
# Identity check: on-demand GPU memory growth is enabled only when the flag is
# the literal False (e.g. 0 or None leave it disabled).
sess_config.gpu_options.allow_growth = FLAGS['sys_use_all_gpu_memory'] is False
with tf.Session(config=sess_config) as sess:
    print(FLAGS['load_model_path'])
    saver.restore(sess, FLAGS['load_model_path'])
    print('after restore ', str(timer.end()))
    # Hard-coded number of test images.
    # NOTE(review): presumably meant to match len(test_image_name_list) — confirm.
    test_count = 42
    data_loader = DataLoader()
    netG_test_output1_crop = tf_crop_rect(netG_test_output1, test_df.mat1, 0)

    # Freeze and export the graph ONCE. This used to run inside the per-image
    # loop, re-serialising and overwriting the same file on every iteration.
    # Variable values do not change while testing, and the export keeps only the
    # subgraph feeding output_node_names, so one export gives the same file.
    # NOTE(review): this node name is TF-auto-generated and depends on the order
    # in which the generator graph was built.
    output_node_names = ['netG/netG_var_scope/netG_var_scopeA/netG_3_1/Add']
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        tf.get_default_graph().as_graph_def(),
        output_node_names)
    with open('output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

    # Tensors to feed, in the same order as the per-image values in dict_d.
    # They are the same for every image, so the list is built once.
    dict_t = \
        [test_df.input1_src, test_df.input2_src] + \
        test_df.mat1.rect + test_df.mat1.rot + \
        test_df.mat2.rect + test_df.mat2.rot + \
        test_df.mat1.csr_ind + test_df.mat1.csr_val + \
        test_df.mat1.csr_ind_r + test_df.mat1.csr_val_r + \
        test_df.mat1.csr_ind_g + test_df.mat1.csr_val_g + \
        test_df.mat1.csr_ind_b + test_df.mat1.csr_val_b + test_df.mat1.csr_sha + \
        test_df.mat2.csr_ind + test_df.mat2.csr_val + \
        test_df.mat2.csr_ind_r + test_df.mat2.csr_val_r + \
        test_df.mat2.csr_ind_g + test_df.mat2.csr_val_g + \
        test_df.mat2.csr_ind_b + test_df.mat2.csr_val_b + test_df.mat2.csr_sha

    for i in range(test_count):
        label_img = data_loader.get_next_test_label()
        input_img, data = data_loader.get_next_test_input_batch()
        # Fill the per-sample sparse index/value lists for both inputs in place.
        update_cache_dict(data['csr_ind1'], data['csr_val1'], data['csr_ind_r1'], data['csr_val_r1'],
                          data['csr_ind_g1'], data['csr_val_g1'], data['csr_ind_b1'], data['csr_val_b1'],
                          data['csr_names1'])
        update_cache_dict(data['csr_ind2'], data['csr_val2'], data['csr_ind_r2'], data['csr_val_r2'],
                          data['csr_ind_g2'], data['csr_val_g2'], data['csr_ind_b2'], data['csr_val_b2'],
                          data['csr_names2'])
        dict_d = \
            [input_img, label_img] + \
            data['rect1'] + data['rot1'] + \
            data['rect2'] + data['rot2'] + \
            data['csr_ind1'] + data['csr_val1'] + \
            data['csr_ind_r1'] + data['csr_val_r1'] + \
            data['csr_ind_g1'] + data['csr_val_g1'] + \
            data['csr_ind_b1'] + data['csr_val_b1'] + data['csr_sha1'] + \
            data['csr_ind2'] + data['csr_val2'] + \
            data['csr_ind_r2'] + data['csr_val_r2'] + \
            data['csr_ind_g2'] + data['csr_val_g2'] + \
            data['csr_ind_b2'] + data['csr_val_b2'] + data['csr_sha2']
        # zip() would silently drop trailing entries if the two lists disagree.
        if len(dict_t) != len(dict_d):
            raise ValueError('feed mismatch: %d tensors vs %d values'
                             % (len(dict_t), len(dict_d)))
        enhance_test_img = sess.run(netG_test_output1_crop,
                                    feed_dict=dict(zip(dict_t, dict_d)))
        # Scale to the maximum of the input dtype, then cast with safe_casting.
        enhance_test_img = safe_casting(
            enhance_test_img * tf.as_dtype(FLAGS['data_input_dtype']).max,
            FLAGS['data_input_dtype'])
        out_path = FLAGS['folder_test_img'] + test_image_name_list[i] + FLAGS['data_input_ext']
        # cv2.imwrite reports failure through its return value instead of raising.
        if not cv2.imwrite(out_path, enhance_test_img):
            print('failed to write ' + out_path)
print('end ', str(timer.end()))
# End of gist script.