Skip to content

Instantly share code, notes, and snippets.

@isspek
Last active January 3, 2021 21:59
Show Gist options
  • Select an option

  • Save isspek/0fbeda0de8c0cc6b1448e2f216ff862f to your computer and use it in GitHub Desktop.

Select an option

Save isspek/0fbeda0de8c0cc6b1448e2f216ff862f to your computer and use it in GitHub Desktop.
Cannot call get_weights() in a custom train_step
def train_step(self, input_data):
    """Run one mean-teacher training step on a batch of labeled data.

    The student gets two losses:
    - a classification loss on one augmentation of the batch;
    - a consistency cost comparing student and teacher logits on a second
      call to ``augment_data``.

    Only the student's trainable weights receive gradients. The teacher is
    then updated from the student via ``self.ema()``.

    Args:
        input_data: A ``(labeled_data, labels)`` pair for the current batch.

    Returns:
        dict: ``m.result()`` for every metric in ``self.metrics`` (keyed by
        metric name), plus ``"loss"``, the overall loss for this step.
    """
    labeled_data, labels = input_data
    with tf.GradientTape() as tape:
        # Classification loss for the student on the first augmentation.
        augmented_data, augmented_labels = augment_data(
            labeled_data, labels, self.unlabeled_train
        )
        logits_student = self.student(augmented_data)
        clf_loss = self.classification_costs(logits_student, augmented_labels)

        # Augment the batch again; student and teacher both see this second
        # augmentation, and the consistency cost compares their logits.
        # NOTE(review): presumably augment_data is random, so this differs
        # from the first augmentation -- confirm.
        augmented_data, _ = augment_data(labeled_data, labels, self.unlabeled_train)
        logits_student = self.student(augmented_data)
        logits_teacher = self.teacher(augmented_data)
        consistency_cost = self.compute_consistency_cost(logits_teacher, logits_student)

        loss = self.compute_overall_cost(clf_loss, consistency_cost)

    # Only the student is optimised; the teacher follows it through ema().
    grads = tape.gradient(loss, self.student.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.student.trainable_weights))
    self.ema()

    # Metrics use the teacher's predictions on the un-augmented features.
    # input_data is a (features, labels) tuple, so pass only the features.
    logits_teacher = self.teacher(labeled_data)
    self.compiled_metrics.update_state(labels, logits_teacher)
    metrics = {m.name: m.result() for m in self.metrics}
    metrics["loss"] = loss
    return metrics
def ema(self):
    """Update the teacher's weights as an exponential moving average of the student's.

    Each pair of corresponding variables is updated in place:
    ``teacher = alpha * teacher + (1 - alpha) * student``.

    The pairs come from ``.weights``, which includes non-trainable variables.
    That is the same list ``get_weights()`` reads.

    Why not ``get_weights()``/``set_weights()``: they convert to and from
    NumPy, so they fail when this method is called from ``train_step``, which
    Keras runs as a graph function by default. Calling ``assign`` directly on
    the variables works both eagerly and inside a traced graph.

    Raises:
        ValueError: If the student and teacher expose a different number of
            weight variables.
    """
    teacher_vars = self.teacher.weights
    student_vars = self.student.weights
    # Variables are paired by position, so the two lists must line up.
    if len(student_vars) != len(teacher_vars):
        raise ValueError(
            f'length of student and teachers weights are not equal Please check. \n Student: {len(student_vars)}, \n Teacher:{len(teacher_vars)}'
        )
    for teacher_var, student_var in zip(teacher_vars, student_vars):
        teacher_var.assign(self.alpha * teacher_var + (1 - self.alpha) * student_var)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment