@wangz10
Last active January 14, 2019 02:51
def interpolate_from_a_to_b_for_c(model, X, labels, a=None, b=None, x_c=None, alpha=0.):
    '''Shift a sample x_c along the latent direction between classes a and b.

    model: a trained generative model exposing encode() and decode()
    X: data in the original space, shape (n_samples, n_features)
    labels: array of class labels, shape (n_samples,)
    a, b: the two class labels that define the direction
    x_c: input sample to manipulate, shape (1, n_features)
    alpha: scalar setting the magnitude and direction of the shift;
        positive values move x_c toward class a, negative values toward class b
    '''
    # Encode the samples of each class into the latent space
    Z_a, Z_b = model.encode(X[labels == a]), model.encode(X[labels == b])
    # Find the centroids of classes a and b in the latent space
    z_a_centroid = Z_a.mean(axis=0)
    z_b_centroid = Z_b.mean(axis=0)
    # The interpolation vector pointing from b -> a
    z_b2a = z_a_centroid - z_b_centroid
    # Move the encoding of x_c along that vector, then decode it
    z_c = model.encode(x_c)
    z_c_interp = z_c + alpha * z_b2a
    return model.decode(z_c_interp)
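
The effect of `alpha` is easiest to see on synthetic data. The following is a minimal sketch, not part of the original gist: `ToyLinearModel` is a hypothetical stand-in for a trained generative model (an orthogonal linear map, so decoding exactly inverts encoding), and `shift_along_class_direction` is a hypothetical helper that repeats the same centroid arithmetic as the function above so that the example runs on its own. A real model would be nonlinear, and the decoded shift would not reduce to a simple offset in the input space.

```python
import numpy as np


class ToyLinearModel:
    """Hypothetical stand-in for a trained generative model (an assumption).

    encode/decode are an orthogonal linear map and its inverse, so
    decode(encode(x)) recovers x up to floating-point error.
    """

    def __init__(self, n_features, seed=0):
        rng = np.random.default_rng(seed)
        # The Q factor of a QR decomposition is an orthogonal matrix
        self.Q, _ = np.linalg.qr(rng.normal(size=(n_features, n_features)))

    def encode(self, X):
        return X @ self.Q

    def decode(self, Z):
        return Z @ self.Q.T


def shift_along_class_direction(model, X, labels, a, b, x_c, alpha):
    """Hypothetical helper: same latent arithmetic as the function above."""
    z_a = model.encode(X[labels == a]).mean(axis=0)
    z_b = model.encode(X[labels == b]).mean(axis=0)
    return model.decode(model.encode(x_c) + alpha * (z_a - z_b))


rng = np.random.default_rng(1)
# Two synthetic classes: "a" centred near +2, "b" centred near -2
X = np.vstack([
    rng.normal(2.0, 0.1, size=(50, 4)),
    rng.normal(-2.0, 0.1, size=(50, 4)),
])
labels = np.array(["a"] * 50 + ["b"] * 50)
model = ToyLinearModel(n_features=4)

x_c = X[labels == "b"][:1]  # one class-b sample, shape (1, 4)
x_same = shift_along_class_direction(model, X, labels, "a", "b", x_c, alpha=0.0)
x_moved = shift_along_class_direction(model, X, labels, "a", "b", x_c, alpha=1.0)
```

With `alpha=0.0` the sample is simply encoded and decoded unchanged. With `alpha=1.0` it is displaced by the full centroid difference, and because this toy model is linear, that displacement equals the difference between the class means in the input space.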