Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save jireh-father/b3cbeb29a3193ec130825dd9f58f9a37 to your computer and use it in GitHub Desktop.

Select an option

Save jireh-father/b3cbeb29a3193ec130825dd9f58f9a37 to your computer and use it in GitHub Desktop.
cnn_attention_model.py
# CNN with a spatial-attention side branch (TensorFlow 1.x tf.layers API).
# `inputs` (image batch, NHWC) and `config` (provides batch_size, num_class)
# are assumed to be defined by the surrounding script — TODO confirm.

# --- Trunk: conv1 + pool (AlexNet-like first stage) ---
net = tf.layers.conv2d(inputs, 64, 11, 4, padding="VALID", activation=tf.nn.relu, name="conv1")
net = tf.layers.max_pooling2d(net, 3, 2, name="max_pool1")

# --- Attention branch: 1x1 conv -> softmax over spatial positions ---
attention_mask = tf.layers.conv2d(net, 1, 1, padding="SAME")
shape = attention_mask.get_shape()
# Flatten each channel's HxW grid into one row so the softmax normalizes
# across spatial locations. Use -1 for the batch dimension instead of the
# hard-coded config.batch_size so the graph also accepts a smaller final
# batch (behavior is unchanged for full batches).
features = tf.reshape(tf.transpose(attention_mask, [0, 3, 1, 2]),
                      [-1, int(shape[1]) * int(shape[2])])
spatial_softmax = tf.nn.softmax(features)
spatial_softmax = tf.transpose(
    tf.reshape(spatial_softmax, [-1, int(shape[3]), int(shape[1]), int(shape[2])]),
    [0, 2, 3, 1])

# Element-wise multiply: the single-channel mask broadcasts over the
# feature map's 64 channels, re-weighting every spatial position.
attention_head = tf.multiply(net, spatial_softmax)

# Global average pool over the spatial dims. Pool size is taken from the
# height only, so this assumes a square feature map — TODO confirm.
output_head = tf.layers.average_pooling2d(
    attention_head, int(attention_head.get_shape()[1]), strides=1)
# BUG FIX: squeeze only the two spatial axes. The original bare
# tf.squeeze(...) would also drop the batch axis when batch_size == 1.
output_head = tf.squeeze(output_head, axis=[1, 2])
attention_prediction = tf.layers.dense(output_head, config.num_class)

# Confidence gate: tanh-bounded scores, softmax-normalized per class,
# multiplied back onto the attention branch's logits.
confidence = tf.nn.tanh(tf.layers.dense(attention_prediction, config.num_class))
gate_weights = tf.nn.softmax(confidence)
attention_output = tf.multiply(attention_prediction, gate_weights)

# --- Trunk continues: conv2..conv5 + pools, then the FC head ---
net = tf.layers.conv2d(net, 128, 5, activation=tf.nn.relu, name="conv2")
net = tf.layers.max_pooling2d(net, 3, 2, name="max_pool2")
net = tf.layers.conv2d(net, 256, 3, activation=tf.nn.relu, name="conv3")
net = tf.layers.conv2d(net, 256, 3, activation=tf.nn.relu, name="conv4")
net = tf.layers.conv2d(net, 128, 3, activation=tf.nn.relu, name="conv5")
net = tf.layers.max_pooling2d(net, 3, 2, name="max_pool3")
# Flatten to (batch, H*W*C) for the dense layers.
net = tf.reshape(net, [-1, int(net.get_shape()[1]) * int(net.get_shape()[2]) * int(net.get_shape()[3])])
net = tf.layers.dense(net, 512, activation=tf.nn.relu, name="fc1")
net = tf.layers.dense(net, 256, activation=tf.nn.relu, name="fc2")
net = tf.layers.dense(net, config.num_class, name="fc3")

# Same confidence-gating scheme applied to the trunk's logits.
net_gate_weights = tf.nn.softmax(tf.nn.tanh(tf.layers.dense(net, config.num_class)))
net_output = tf.multiply(net, net_gate_weights)

# Final prediction: sum of the gated trunk logits and gated attention logits.
net = net_output + attention_output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment