""" Fuse conv-bn pattern in torch.Module, an example for torch.fx see: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html """ import copy from typing import Tuple, Dict, Any import torch import torch.fx as fx import torch.nn as nn from ipdb import set_trace # helper functions to fuse the conv and bn # nothing special, just math operations def fuse_conv_bn_eval(conv, bn): """ Given a conv Module `A` and an batch_norm module `B`, returns a conv module `C` such that C(x) == B(A(x)) in inference mode. """ assert(not (conv.training or bn.training)), "Fusion only for eval!" fused_conv = copy.deepcopy(conv) fused_conv.weight, fused_conv.bias = \ fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) return fused_conv def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): if conv_b is None: conv_b = torch.zeros_like(bn_rm) if bn_w is None: bn_w = torch.ones_like(bn_rm) if bn_b is None: bn_b = torch.zeros_like(bn_rm) bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) # Module part: create a nn.Module with conv-bn pattern inside # notice that the conv-bn could be very flexible: # - used in nn.Module directly # - wrapped bn # - nested style with a Sequential container class WrappedBatchNorm(nn.Module): def __init__(self): super().__init__() self.mod = nn.BatchNorm2d(1) def forward(self, x): return self.mod(x) class M(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 1, 1) self.bn1 = nn.BatchNorm2d(1) self.conv2 = nn.Conv2d(1, 1, 1) self.nested = nn.Sequential( nn.BatchNorm2d(1), nn.Conv2d(1, 1, 1), ) self.wrapped = WrappedBatchNorm() def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.conv2(x) x = self.nested(x) x = self.wrapped(x) return x # create an instance of the module model = M().eval() # Let's start! def _parent_name(target : str) -> Tuple[str, str]: """ Splits a qualname into parent path and last atom. For example, `foo.bar.baz` -> (`foo.bar`, `baz`) """ *parent, name = target.rsplit('.', 1) return parent[0] if parent else '', name def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): assert(isinstance(node.target, str)) # this is for nested modules parent_name, name = _parent_name(node.target) print(f'Modules[{parent_name}].{name} <- {new_module.__class__.__name__}') setattr(modules[parent_name], name, new_module) def fuse(model: nn.Module) -> nn.Module: """ Fuse the conv and bn, where magic happens """ # get the graph representation of the model model = copy.deepcopy(model) fx_model: fx.GraphModule = fx.symbolic_trace(model) modules = dict(fx_model.named_modules()) # Each `GraphModule` has a `Graph` associated with it # The `Graph` itself is represented as a list of `Node` objects. # To iterate the `Graph`, we need iterate the `Node`s for node in fx_model.graph.nodes: # only consider the nodes of `call_module` type if node.op != 'call_module': continue # For call sites, `Node.target` represents the module/function/method # that's being called. # Here, we check `Node.target` to see if it's a batch norm module, # and then check `Node.args[0].target` to see if the input `Node` is # a convolution. 
# Module part: create a nn.Module with the conv-bn pattern inside.
# Notice that the conv-bn usage can be very flexible:
# - used in the nn.Module directly
# - a wrapped bn
# - nested inside a Sequential container
class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x


# create an instance of the module
model = M().eval()


# Let's start!
def _parent_name(target: str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name


def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert isinstance(node.target, str)
    # this handles nested modules
    parent_name, name = _parent_name(node.target)
    print(f'Modules[{parent_name}].{name} <- {new_module.__class__.__name__}')
    setattr(modules[parent_name], name, new_module)


def fuse(model: nn.Module) -> nn.Module:
    """
    Fuse the conv and bn; this is where the magic happens.
    """
    # get the graph representation of the model
    model = copy.deepcopy(model)
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    # Each `GraphModule` has a `Graph` associated with it.
    # The `Graph` itself is represented as a list of `Node` objects.
    # To iterate the `Graph`, we need to iterate over the `Node`s.
    for node in fx_model.graph.nodes:
        # only consider the nodes of `call_module` type
        if node.op != 'call_module':
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called.
        # Here, we check `Node.target` to see if it's a batch norm module,
        # and then check `Node.args[0].target` to see if the input `Node` is
        # a convolution.
        cur_module = modules[node.target]
        if isinstance(cur_module, nn.BatchNorm2d):
            prev_module = modules[node.args[0].target]
            if isinstance(prev_module, nn.Conv2d):
                # found a conv-bn pattern
                if len(node.args[0].users) > 1:
                    # the output of the conv is used by other nodes as well
                    continue
                fused_conv = fuse_conv_bn_eval(prev_module, cur_module)
                replace_node_module(node.args[0], modules, fused_conv)
                # As we've folded the batch norm into the conv, we need to
                # replace all uses of the batch norm with the conv.
                node.replace_all_uses_with(node.args[0])
                # Now that all uses of the batch norm have been replaced, we
                # can safely remove the batch norm.
                fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile it in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model


fused_model = fuse(model)
print(f'The `forward` code after fusion:\n{fused_model.code}')

# check the output
inp = torch.randn(5, 1, 1, 1)
if torch.allclose(fused_model(inp), model(inp)):
    print('Fused successfully')
else:
    print('Failed to fuse: the difference is too large')
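
# Optional sanity check (a sketch, not in the original script): after fusion,
# no `call_module` node in the traced graph should target a BatchNorm2d anymore.
# Note that the original bn submodules may still hang off the GraphModule as
# (now unused) attributes; only the graph nodes were erased.
fused_modules = dict(fused_model.named_modules())
bn_calls = [node.target for node in fused_model.graph.nodes
            if node.op == 'call_module'
            and isinstance(fused_modules[node.target], nn.BatchNorm2d)]
print(f'BatchNorm call sites left in the graph: {bn_calls}')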