import ast
import collections
import copy
import inspect
import numbers
import textwrap

import numpy

# Callables used by the generated code to save (push) and restore (pop)
# values. They refer to a `_stack` object that must be in scope when the
# generated primal/adjoint code runs.
PUSH = ast.Attribute(value=ast.Name(id='_stack', ctx=ast.Load()),
                     attr='push', ctx=ast.Load())
POP = ast.Attribute(value=ast.Name(id='_stack', ctx=ast.Load()),
                    attr='pop', ctx=ast.Load())


def parse_function(fn):
    """Return the module AST of the (dedented) source code of `fn`."""
    return ast.parse(textwrap.dedent(inspect.getsource(fn)))


class NodeReverse(object):
    """Generate a primal and adjoint for a given AST tree.

    Notes
    -----
    In principle, this class simply walks the AST recursively and for each
    node returns a new primal and an adjoint. A limited amount of
    communication happens through the state of the class. Assign statements
    set `current_target` so that the adjoint of the right hand side knows
    what gradient to read. On the other hand, right hand side expressions set
    `current_partials` (a mapping from each variable operand to the temporary
    holding its partial gradient) to tell assignment statements what
    variables the partials were written to.

    The generated code expects a `_stack` object with `push`/`pop` methods
    and the `add_grad` function to be in scope.
    """

    def __init__(self):
        # Target (an `ast.Name`) of the assignment being reversed.
        self.current_target = None
        # Operand `ast.Name` node -> temporary holding its partial gradient.
        self.current_partials = None
        # Counter for generating unique temporary names (see `_fresh_var`).
        self._name_count = 0

    def visit(self, node):
        """Dispatch `node` to the matching ``visit_<NodeType>`` method."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    @staticmethod
    def create_grad(node):
        """Given a variable, create variable name for the gradient.

        WARNING: This returns an invalid node, with the `ctx` attribute
        missing. It is assumed that this attribute is filled in later (e.g.
        by the `replace` function).

        Raises
        ------
        TypeError
            If `node` is not an `ast.Name`.
        """
        if not isinstance(node, ast.Name):
            raise TypeError("can only create the gradient of a variable")
        return ast.Name(id='d' + node.id)

    @staticmethod
    def create_var(id_):
        """Method to create a named variable.

        Used for temporaries."""
        return ast.Name(id=id_, ctx=ast.Load())

    def _fresh_var(self, prefix):
        """Create a variable whose name is unique within this instance."""
        name = '%s%d' % (prefix, self._name_count)
        self._name_count += 1
        return self.create_var(name)

    def visit_FunctionDef(self, node):
        # TODO Change function signatures to receive stack
        # TODO Change adjoint signature to take output of primal and its
        # initial gradient
        raise NotImplementedError("function definitions are not supported")

    def visit_statements(self, nodes):
        """Generate the adjoint of a series of statements.

        Returns
        -------
        primals, adjoints : list, list
            The primal statements in program order, and the adjoint
            statements: per-statement adjoints in reverse program order,
            each keeping its own internal order.
        """
        primals, adjoints = [], collections.deque()
        for node in nodes:
            primal, adjoint = self.visit(node)
            primals.extend(primal)
            # extendleft reverses its argument; pre-reverse so that this
            # statement's adjoint is prepended in its original order.
            adjoints.extendleft(adjoint[::-1])
        return primals, list(adjoints)

    def visit_For(self, node):
        """Reverse a ``for`` loop (without an ``else`` clause).

        The primal counts the iterations and pushes the loop target at the
        end of every iteration, then pushes the iteration count. The adjoint
        pops the count and runs the reversed body that many times, restoring
        the loop target first so the body's adjoint sees that iteration's
        value.
        """
        if node.orelse:
            raise ValueError("for loops with an else clause are not supported")
        primal_body, adjoint_body = self.visit_statements(node.body)
        # A fresh name keeps the counter from clobbering user variables
        # (e.g. a loop target called `i`) or the counter of an outer loop.
        counter = self._fresh_var('__count')

        def primal_template(body, iter_, target, push, counter):
            counter = 0
            for target in iter_:
                counter += 1
                body
                push(target)
            push(counter)
        primal = replace(primal_template, body=primal_body, iter_=node.iter,
                         target=node.target, push=PUSH, counter=counter)

        def adjoint_template(body, pop, target, counter):
            counter = pop()
            for _ in range(counter):
                target = pop()
                body
        adjoint = replace(adjoint_template, body=adjoint_body, pop=POP,
                          target=node.target, counter=counter)
        return primal, adjoint

    def visit_BinOp(self, node):
        """Reverse a binary operation on variables and/or constants.

        Returns the expression itself as the primal and, as the adjoint, the
        statements writing the partial gradients w.r.t. the left and right
        operand into the temporaries `__dx` and `__dy`, computed from the
        gradient of `self.current_target`. Sets `self.current_partials` to
        map each variable operand to its temporary; constant operands have
        no gradient and get no entry.

        Raises
        ------
        ValueError
            For unsupported operators, or operands that are neither
            variables nor constants.
        """
        def adjoint_Mult_template(x, y, dx, dy, dz):
            dx = dz * y
            dy = dz * x

        def adjoint_Add_template(x, y, dx, dy, dz):
            dx = dz
            dy = dz

        def adjoint_Sub_template(x, y, dx, dy, dz):
            dx = dz
            dy = -dz

        def adjoint_Div_template(x, y, dx, dy, dz):
            dx = dz / y
            dy = -dz * x / y ** 2

        adjoint_templates = {
            ast.Mult: adjoint_Mult_template,
            ast.Add: adjoint_Add_template,
            ast.Sub: adjoint_Sub_template,
            ast.Div: adjoint_Div_template,
        }
        op = type(node.op)
        if op not in adjoint_templates:
            raise ValueError("unknown binary operator")

        partial_left = self.create_var('__dx')
        partial_right = self.create_var('__dy')
        self.current_partials = {}
        for operand, partial in ((node.left, partial_left),
                                 (node.right, partial_right)):
            if isinstance(operand, ast.Name):
                self.current_partials[operand] = partial
            elif not isinstance(operand, ast.Constant):
                # Nested expressions would need temporaries of their own.
                raise ValueError("operands must be variables or constants")
        return node, replace(
            adjoint_templates[op], x=node.left, y=node.right,
            dx=partial_left, dy=partial_right,
            dz=self.create_grad(self.current_target))

    def visit_Assign(self, node):
        """Reverse a single-target assignment ``name = <expression>``.

        The primal saves the old value of the target before assigning. The
        adjoint restores it, computes the partial gradients of the right
        hand side, resets the target's gradient to 0 and accumulates the
        partials into the operands' gradients.
        """
        if len(node.targets) != 1:
            raise ValueError("only single-target assignments are supported")
        if isinstance(node.targets[0], ast.Tuple):
            if not isinstance(node.value, ast.Name):
                raise ValueError("can only unpack variables")
            # TODO Pack the gradients into a tuple
            raise ValueError("no support for tuple assignments")
        if not isinstance(node.targets[0], ast.Name):
            raise ValueError("can only assign to names")

        # NOTE We simplify things here by EAFP. Ideally each variable that is
        # pushed at any point should be set to None at the beginning
        def primal_template(target, primal_rhs, push):
            try:
                push(target)
            except NameError:
                push(None)
            target = primal_rhs

        # NOTE For each partial gradient from the rhs we want to accumulate
        # it into the existing gradient; this is the template for that
        # NOTE EAFP approach again; gradients should be initialized beforehand
        def accumulate_template(in_grad, partial_grad):
            try:
                in_grad = add_grad(in_grad, partial_grad)
            except NameError:
                in_grad = partial_grad

        # The final adjoint restores the input (pop), stores the partials
        # in temporary variables, resets the gradient w.r.t. output,
        # and finally updates the gradients
        def adjoint_template(target, adjoint_rhs, target_grad,
                             gradient_accumulation, pop):
            target = pop()
            adjoint_rhs
            target_grad = 0
            gradient_accumulation

        # Extract the target and store it in the state so that the
        # right hand side templates can use it
        target = node.targets[0]
        self.current_target = target
        self.current_partials = {}
        try:
            primal_rhs, adjoint_rhs = self.visit(node.value)
            primal = replace(primal_template, target=target,
                             primal_rhs=primal_rhs, push=PUSH)

            gradient_accumulation = []
            for operand, partial in (self.current_partials or {}).items():
                gradient_accumulation.extend(replace(
                    accumulate_template,
                    in_grad=self.create_grad(operand),
                    partial_grad=partial))

            adjoint = replace(adjoint_template, target=target,
                              adjoint_rhs=adjoint_rhs,
                              gradient_accumulation=gradient_accumulation,
                              target_grad=self.create_grad(target),
                              pop=POP)
        finally:
            # Reset the state, even if the right hand side was rejected
            self.current_target = None
            self.current_partials = None
        return primal, adjoint

    def generic_visit(self, node):
        """Reject any node type without a dedicated visitor."""
        raise ValueError("unknown node type")


class ReplaceTransformer(ast.NodeTransformer):
    """Replace placeholder variables with copies of AST nodes.

    Every substitution inserts a fresh deep copy, so the same replacement can
    appear several times (in different contexts, or in different templates)
    without the resulting trees sharing -- and overwriting the `ctx` of --
    the same node objects.
    """

    def __init__(self, replacements):
        self.replacements = replacements

    @staticmethod
    def _set_ctx(node, ctx):
        """Give `node` (and the elements of a tuple/list/starred) `ctx`."""
        if 'ctx' in node._fields:
            node.ctx = ctx
        if isinstance(node, (ast.Tuple, ast.List)):
            for elt in node.elts:
                ReplaceTransformer._set_ctx(elt, ctx)
        elif isinstance(node, ast.Starred):
            ReplaceTransformer._set_ctx(node.value, ctx)

    def visit_Expr(self, node):
        # A placeholder standing alone as a statement may be replaced by a
        # statement or a list of statements; returning a list makes
        # NodeTransformer splice them into the enclosing body.
        value = node.value
        if isinstance(value, ast.Name) and value.id in self.replacements:
            replacement = self.replacements[value.id]
            if isinstance(replacement, ast.stmt):
                replacement = [replacement]
            if isinstance(replacement, list):
                return [copy.deepcopy(stmt) for stmt in replacement]
        return self.generic_visit(node)

    def visit_Name(self, node):
        if node.id not in self.replacements:
            return node
        replacement = self.replacements[node.id]
        if isinstance(replacement, list):
            raise ValueError("a list of statements can only replace a "
                             "placeholder used as a statement")
        if not isinstance(replacement, ast.AST):
            return replacement
        # Use a copy of the replacement node in the same context as the
        # placeholder
        replacement = copy.deepcopy(replacement)
        self._set_ctx(replacement, node.ctx)
        return replacement


def replace(fn, **replacements):
    """Replace placeholders in a Python template (quote).

    One special thing happens: If a replacement node has a ctx attribute, a
    copy of it is made to match the ctx attribute of the variable it is
    replacing (recursively for tuple/list targets).

    Parameters
    ----------
    fn : function
        A function used as a metaprogramming template.
    replacements : dict
        A mapping from the variable names of the function's arguments to
        (lists of) AST nodes that these variables will be replaced with
        wherever they appear in the function body. A replacement can be a
        list, in which case the placeholder must be used as a statement and
        the list will be merged into the list of statements containing it.

    Returns
    -------
    body : list
        A list of statements in the form of AST nodes.

    Raises
    ------
    ValueError
        If the replacement names do not exactly match the template's
        arguments.
    """
    tree = parse_function(fn).body[0]
    if replacements.keys() != {arg.arg for arg in tree.args.args}:
        raise ValueError("too many or few replacements")
    tree = ReplaceTransformer(replacements).visit(tree)
    return tree.body


def add_grad(left, right):
    """Recursively add the gradient of e.g. tuples.

    `None` denotes an undefined (i.e. zero) gradient. Numbers and numpy
    arrays may be mixed freely (e.g. the int ``0`` that resets a gradient
    plus a float partial); tuples must match in length and are added
    element-wise.

    Raises
    ------
    TypeError
        If the gradients are of incompatible or unknown types.
    """
    # If either gradient is undefined, then we simply return the other one
    # NOTE This is more efficient than initializing empty gradients and
    # adding to them, since we could be adding to large matrix of zeros then
    if left is None:
        return right
    if right is None:
        return left
    numeric = (numpy.ndarray, numbers.Number)
    if isinstance(left, numeric) and isinstance(right, numeric):
        return left + right
    if isinstance(left, tuple) and isinstance(right, tuple):
        if len(left) != len(right):
            raise TypeError("incompatible gradients")
        return tuple(add_grad(lelem, relem)
                     for lelem, relem in zip(left, right))
    if type(left) != type(right):
        raise TypeError("incompatible gradients")
    raise TypeError("unknown gradient type")


# Example functions to differentiate; their bodies must only contain
# supported statements (no docstrings).
def f(x):
    y = x * x


def g(x):
    for i in range(10):
        y = x * x


if __name__ == "__main__":
    try:
        from astor import to_source
    except ImportError:
        to_source = ast.unparse  # Python 3.9+
    body = parse_function(g).body[0].body
    primal, adjoint = NodeReverse().visit_statements(body)
    print("PRIMAL")
    print(to_source(ast.Module(body=primal, type_ignores=[])))
    print("ADJOINT")
    print(to_source(ast.Module(body=adjoint, type_ignores=[])))

# Expected output (exact formatting depends on astor / ast.unparse):
#
# PRIMAL
# __count0 = 0
# for i in range(10):
#     __count0 += 1
#     try:
#         _stack.push(y)
#     except NameError:
#         _stack.push(None)
#     y = x * x
#     _stack.push(i)
# _stack.push(__count0)
#
# ADJOINT
# __count0 = _stack.pop()
# for _ in range(__count0):
#     i = _stack.pop()
#     y = _stack.pop()
#     __dx = dy * x
#     __dy = dy * x
#     dy = 0
#     try:
#         dx = add_grad(dx, __dx)
#     except NameError:
#         dx = __dx
#     try:
#         dx = add_grad(dx, __dy)
#     except NameError:
#         dx = __dy