import sys
from collections import OrderedDict

PY2 = sys.version_info[0] == 2

# Attributes managed internally by nn.Module; everything else found in a
# module's __dict__ (e.g. `training`, conv hyper-parameters) is plain state
# that we copy onto the functional stand-in.
_internal_attrs = {'_backend', '_parameters', '_buffers', '_backward_hooks',
                   '_forward_hooks', '_forward_pre_hooks', '_modules'}


class Scope(object):
    """Lightweight stand-in for an nn.Module instance.

    Holds copied (non-internal) attributes and functional children so the
    original module class's unbound `forward` can be run against it with
    externally supplied parameters.
    """

    def __init__(self):
        self._modules = OrderedDict()


def _make_functional(module, params_box, params_offset):
    """Recursively build a functional version of `module`.

    Parameters are not stored in the result; instead, every call reads this
    module's slice of the flat parameter list held in `params_box[0]`,
    starting at `params_offset`.

    Returns:
        (next_offset, fmodule): `next_offset` is the offset just past the
        parameters consumed by this subtree; `fmodule(*args, **kwargs)` runs
        the original forward with the externally supplied parameters.
    """
    self = Scope()
    num_params = len(module._parameters)
    param_names = list(module._parameters.keys())
    # Unbound forward of the module's class; on Python 2 the underlying
    # function must be unwrapped from the unbound-method object.
    forward = type(module).forward.__func__ if PY2 else type(module).forward

    # Copy non-internal attributes (e.g. `training`) onto the stand-in.
    for name, attr in module.__dict__.items():
        if name in _internal_attrs:
            continue
        setattr(self, name, attr)

    # Recurse into children; each child consumes a contiguous slice of the
    # flat parameter list, in named_children() order (which matches the
    # order of module.parameters()).
    child_params_offset = params_offset + num_params
    for name, child in module.named_children():
        child_params_offset, fchild = _make_functional(child, params_box,
                                                       child_params_offset)
        self._modules[name] = fchild
        setattr(self, name, fchild)

    def fmodule(*args, **kwargs):
        # Bind this module's parameters from the shared flat list, then run
        # the original class's forward against the Scope stand-in.
        params = params_box[0][params_offset:params_offset + num_params]
        for name, param in zip(param_names, params):
            setattr(self, name, param)
        return forward(self, *args, **kwargs)

    return child_params_offset, fmodule


def make_functional(module):
    """Return `f` such that `f(*args, params=tensors, **kwargs)` runs
    `module.forward` using the given flat list of parameter tensors instead
    of the parameters stored on `module` (order must match
    `list(module.parameters())`).
    """
    params_box = [None]
    _, fmodule_internal = _make_functional(module, params_box, 0)

    def fmodule(*args, **kwargs):
        # `params` is required; popping it keeps the remaining kwargs clean
        # for the wrapped forward.
        params_box[0] = kwargs.pop('params')
        return fmodule_internal(*args, **kwargs)

    return fmodule


################################################################################


import torch
from torch import nn
from torch.nn import functional as F


class Net(nn.Module):
    """Small MNIST-style CNN: two conv/pool/relu stages, dropout, two FCs."""

    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Dropout2d())
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.layers(x)
        # 28x28 input -> conv5 -> pool2 -> conv5 -> pool2 leaves 20x4x4 = 320.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
# Build the functional version while the model is in eval mode so the
# copied `training` flags disable dropout.
model.eval()
eval_fmodel = make_functional(model)
model.train()
train_fmodel = make_functional(model)

# Verify correctness in eval mode (because we have dropout).
model.eval()
params = list(model.parameters())
x = torch.randn(10, 1, 28, 28)
print(model(x).sum())
# Fix: the eval-mode functional model is `eval_fmodel`; the original line
# referenced `fmodel`, which is never defined at module scope (it is only a
# local name inside make_functional), so it raised NameError.
print(eval_fmodel(x, params=params).sum())