Last active January 14, 2019 02:51
GitHub: wangz10/b867fc57d146ea17805e050ed46bc64c
```python
import numpy as np


def interpolate_from_a_to_b_for_c(model, X, labels, a=None, b=None, x_c=None, alpha=0.):
    '''Shift a sample x_c along the direction between the latent centroids of classes b and a.

    model: a trained generative model exposing encode() and decode()
    X: data in the original space, shape (n_samples, n_features)
    labels: array of class labels, shape (n_samples,)
    a, b: class labels a and b
    x_c: input sample to manipulate, shape (1, n_features)
    alpha: scalar setting the magnitude and direction of the shift;
        alpha > 0 moves z_c in the b -> a direction, alpha < 0 in the
        a -> b direction, and alpha = 0 returns the model's reconstruction of x_c
    '''
    # Ensure elementwise comparison even if labels is a plain Python list
    labels = np.asarray(labels)
    # Encode the samples of each class into the latent space
    Z_a = model.encode(X[labels == a])
    Z_b = model.encode(X[labels == b])
    # Find the centroids of classes a and b in the latent space
    z_a_centroid = Z_a.mean(axis=0)
    z_b_centroid = Z_b.mean(axis=0)
    # The interpolation vector pointing from b -> a
    z_b2a = z_a_centroid - z_b_centroid
    # Encode x_c, shift it along z_b2a, and decode back to the original space
    z_c = model.encode(x_c)
    z_c_interp = z_c + alpha * z_b2a
    return model.decode(z_c_interp)
```
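The following is a minimal sketch of what the centroid shift does, under stated assumptions:

- **Model.** `PCAModel` is a hypothetical linear stand-in for the trained generative model and is not part of the gist. It fits PCA with NumPy's SVD and exposes the `encode`/`decode` methods the function expects.
- **Data.** The two-class data and the labels `"a"` and `"b"` are synthetic.
- **Function.** The function is repeated in condensed form so the snippet runs on its own.

```python
import numpy as np


class PCAModel:
    """Hypothetical linear encoder/decoder fit by PCA (stand-in for a trained generative model)."""

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        self.mean_ = X.mean(axis=0)
        # Rows of Vt are the principal axes, ordered by explained variance
        _, _, Vt = np.linalg.svd(X - self.mean_, full_matrices=False)
        self.components_ = Vt[: self.n_components]
        return self

    def encode(self, X):
        # Project centered data onto the principal axes
        return (X - self.mean_) @ self.components_.T

    def decode(self, Z):
        # Map latent coordinates back to the original feature space
        return Z @ self.components_ + self.mean_


def interpolate_from_a_to_b_for_c(model, X, labels, a, b, x_c, alpha=0.0):
    # Condensed copy of the gist's function: shift z_c along centroid(a) - centroid(b)
    labels = np.asarray(labels)
    z_b2a = model.encode(X[labels == a]).mean(axis=0) - model.encode(X[labels == b]).mean(axis=0)
    return model.decode(model.encode(x_c) + alpha * z_b2a)


rng = np.random.default_rng(0)
n_features = 5
# Synthetic classes with different means (illustrative data only)
X_a = rng.normal(loc=1.0, size=(50, n_features))
X_b = rng.normal(loc=-1.0, size=(50, n_features))
X = np.vstack([X_a, X_b])
labels = np.array(["a"] * 50 + ["b"] * 50)

# Full-rank PCA, so encode/decode is an exact round trip
model = PCAModel(n_components=n_features).fit(X)
x_c = X_b[:1]  # one class-b sample, shape (1, n_features)

# Sweep alpha to move x_c step by step toward class a
for alpha in (0.0, 0.5, 1.0):
    x_new = interpolate_from_a_to_b_for_c(model, X, labels, "a", "b", x_c, alpha)
    print(alpha, np.round(x_new, 2))
```

With a full-rank linear encoder like this one, `decode(encode(x_c) + alpha * z_b2a)` reduces to `x_c + alpha * (mean of class a - mean of class b)`. In other words, the latent arithmetic becomes a plain shift by the difference of class means. With a nonlinear decoder, the same straight-line move in latent space generally traces a curved path in data space.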