@cswiercz
Last active August 3, 2017 17:30
Understanding MXNet elemwise_sum

Introduction

This document is a guide to creating an MXNet operator. Rather than describing the process in the abstract, we will learn by exploring an example: elemwise_sum. (.h, .cc, .cu)

The code in elemwise_sum.{h,cc,cu} defines and registers the add_n() operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape.

>>> import mxnet as mx
>>> a = mx.nd.array([1,2,3])
>>> b = mx.nd.array([4,5,6])
>>> c = mx.nd.array([7,8,9])
>>> x = mx.nd.add_n(a,b,c)
>>> x.asnumpy()
array([ 12.,  15.,  18.], dtype=float32)
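
To make the semantics concrete without needing MXNet installed, here is a minimal pure-Python sketch of what add_n computes. The helper name add_n_reference is hypothetical and operates on plain lists; it only illustrates the behavior and says nothing about how MXNet implements the operator.

```python
def add_n_reference(*args):
    """Element-wise sum of any number of equal-length sequences.

    Hypothetical reference helper for illustration; not part of MXNet.
    """
    if not args:
        raise ValueError("add_n_reference needs at least one input")
    length = len(args[0])
    # All inputs must share the same shape (here: the same length).
    if any(len(a) != length for a in args):
        raise ValueError("all inputs must have the same length")
    # zip(*args) groups the i-th element of every input together.
    return [sum(column) for column in zip(*args)]


result = add_n_reference([1, 2, 3], [4, 5, 6], [7, 8, 9])
print(result)
```

The same function accepts two inputs or twenty, which is the point of add_n: the number of arguments is not fixed in advance.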

Registering an Operator

Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this, we'll see how all of the constituents of the implementation are linked together.

Registration is done using the macro NNVM_REGISTER_OP. The code for registering add_n can be found in elemwise_sum.cc and elemwise_sum.cu. Since the GPU-side registration inherits most of its attributes from the CPU-side registration, we will focus on the contents of elemwise_sum.cc:

NNVM_REGISTER_OP(add_n)
.add_alias("ElementWiseSum")
.describe(R"doc(Adds all input arguments element-wise.
.. math::
   add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n
``add_n`` is potentially more efficient than calling ``add`` by `n` times.
)doc" ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    return ret;
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("arg") + std::to_string(i));
    }
    return ret;
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
.set_attr<nnvm::FInplaceOption>(
    "FInplaceOption", [](const NodeAttrs& attrs) {
      return std::vector<std::pair<int, int> >{{0, 0}};
    })
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add_n"})
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");

There is a lot to parse here, but don't worry, we'll take each component one at a time.
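
Before going attribute by attribute, it may help to see the overall pattern in miniature. The sketch below is a toy Python model of the chained-setter style used in the registration above: each setter stores a value and returns the operator entry so the next call can be chained. The class ToyOp and the function list_input_names are hypothetical and are not part of NNVM or MXNet; the attribute keys are copied from the registration only as labels.

```python
class ToyOp:
    """Hypothetical stand-in for an operator entry; not the NNVM class."""

    def __init__(self, name):
        self.name = name
        self.description = ""
        self.num_inputs = None  # callable taking the parsed attributes
        self.attrs = {}

    def describe(self, text):
        self.description = text
        return self  # returning self is what allows the chained calls

    def set_num_inputs(self, fn):
        self.num_inputs = fn
        return self

    def set_attr(self, key, value):
        self.attrs[key] = value
        return self


def list_input_names(attrs):
    # Mirrors the FListInputNames lambda above: arg0, arg1, ..., arg{n-1}.
    return ["arg" + str(i) for i in range(attrs["num_args"])]


toy_add_n = (
    ToyOp("add_n")
    .describe("Adds all input arguments element-wise.")
    .set_num_inputs(lambda attrs: attrs["num_args"])
    .set_attr("FListInputNames", list_input_names)
    .set_attr("key_var_num_args", "num_args")
    .set_attr("FInplaceOption", lambda attrs: [(0, 0)])
)

attrs = {"num_args": 3}
print(toy_add_n.num_inputs(attrs))
print(toy_add_n.attrs["FListInputNames"](attrs))
```

Note that the number of inputs and their names are computed from the parsed attributes rather than hard-coded, which is how an operator with a variable argument count can be described once.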

Operator Registration Attributes

Let's look at each of the attributes in the above operator definition one by one. For reference, the source code for the macro and basic operator structure/attributes can be found in NNVM's op.h.

  • NNVM_REGISTER_OP(add_n)

    This registers the name of the operator. Once MXNet is built, you can invoke the operator from the Python interface, for example via mx.nd.add_n() or mx.sym.add_n(). At the C++ level, the operator is looked up by its string name using Op::Get("add_n").

  • .add_alias("ElementWiseSum")

    Registers an alias for the operator, allowing you to look it up at the C++ level by writing Op::Get("ElementWiseSum").
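
The relationship between a registered name and its alias can be sketched with a toy lookup table. This is a minimal illustration, assuming an alias resolves to the same entry as the primary name; ToyRegistry and its methods are hypothetical and do not reproduce the NNVM registry's interface.

```python
class ToyRegistry:
    """Hypothetical name-to-operator table with alias support."""

    def __init__(self):
        self._ops = {}

    def register(self, name):
        op = {"name": name}
        self._ops[name] = op
        return op

    def add_alias(self, name, alias):
        # The alias points at the very same entry, not a copy.
        self._ops[alias] = self._ops[name]

    def get(self, name):
        # Lookup is by string key, for both names and aliases.
        return self._ops[name]


registry = ToyRegistry()
registry.register("add_n")
registry.add_alias("add_n", "ElementWiseSum")

same = registry.get("add_n") is registry.get("ElementWiseSum")
print(same)
```

In this model, anything attached to the entry under one name is visible under the other, so the alias costs nothing beyond an extra key.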
