Skip to content

Instantly share code, notes, and snippets.

@mfrashad
Created May 4, 2022 18:40
Show Gist options
  • Select an option

  • Save mfrashad/12d456740c78fb72b58bb213b6365401 to your computer and use it in GitHub Desktop.

Select an option

Save mfrashad/12d456740c78fb72b58bb213b6365401 to your computer and use it in GitHub Desktop.
ClothingGAN - Load model and components
#@title Load Model
# Name of the model to load -- this is the only value meant to be edited.
selected_model = 'lookbook'

# Dependencies
import torch
import numpy as np
from PIL import Image
from models import get_instrumented_model
from decomposition import get_or_compute
from config import Config

# Nothing below needs gradients, so switch autograd off globally, and turn on
# cuDNN benchmark mode to speed up computation.
torch.autograd.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Settings shared by the model loader and the component lookup below.
config = Config(
    model='StyleGAN2',
    layer='style',
    output_class=selected_model,
    components=20,
    use_w=True,
    batch_size=5_000,  # style layer quite small
)

# Instrumented wrapper around the generator; the raw model is exposed as
# `inst.model`.
inst = get_instrumented_model(
    config.model,
    config.output_class,
    config.layer,
    device,
    use_w=config.use_w,
)

# `get_or_compute` hands back the location of the component archive, which is
# then opened with NumPy.
# NOTE(review): presumably it reuses a cached archive when one exists and
# computes it otherwise -- confirm in `decomposition`.
path_to_components = get_or_compute(config, inst)
model = inst.model
comps = np.load(path_to_components)
# Unpack the component archive into two plain Python lists:
#   latent_dirs   -- one entry per row of the '<prefix>_comp' array
#   latent_stdevs -- one entry per row of the '<prefix>_stdev' array
# where <prefix> is 'act' when `load_activations` is True and 'lat' otherwise.
lst = comps.files  # names of the arrays stored in the archive
load_activations = False

_prefix = 'act' if load_activations else 'lat'
_comp_key = f'{_prefix}_comp'
_stdev_key = f'{_prefix}_stdev'

# Index the archive exactly once per key. An archive returned by `np.load`
# loads its members lazily, so indexing `comps[key]` inside a per-row loop
# would read the whole array again for every row.
# A key that is absent from the archive leaves the corresponding list empty
# instead of raising.
latent_dirs = list(comps[_comp_key]) if _comp_key in lst else []
latent_stdevs = list(comps[_stdev_key]) if _stdev_key in lst else []
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment