@cswiercz
Last active August 3, 2017 17:30

Understanding MXNet elemwise_sum

Introduction

This document is a guide to creating an MXNet operator. Rather than starting from general principles, we will learn by exploring a concrete example: elemwise_sum (elemwise_sum.h, elemwise_sum.cc, elemwise_sum.cu).

The code in elemwise_sum.{h,cc,cu} defines and registers the add_n() operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape.

>>> import mxnet as mx
>>> a = mx.nd.array([1,2,3])
>>> b = mx.nd.array([4,5,6])
>>> c = mx.nd.array([7,8,9])
>>> x = mx.nd.add_n(a,b,c)
>>> x.asnumpy()
array([ 12.,  15.,  18.], dtype=float32)
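To pin down the semantics before looking at MXNet's code, here is a minimal, framework-free C++ sketch of the arithmetic add_n performs. It assumes the inputs are flat std::vector<float> buffers of equal length. The function name add_n_sketch is hypothetical. This is not MXNet's implementation, which also has to dispatch on device and data type.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: element-wise sum of an arbitrary number of
// equally sized inputs, mirroring what add_n computes.
std::vector<float> add_n_sketch(const std::vector<std::vector<float>>& inputs) {
  if (inputs.empty()) {
    // add_n needs at least one input (num_args has a lower bound of 1).
    throw std::invalid_argument("add_n needs at least one input");
  }
  // Start from a copy of the first input, then accumulate the others.
  std::vector<float> out = inputs[0];
  for (std::size_t k = 1; k < inputs.size(); ++k) {
    if (inputs[k].size() != out.size()) {
      throw std::invalid_argument("all inputs must have the same shape");
    }
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] += inputs[k][i];
    }
  }
  return out;
}
```

Calling add_n_sketch on {1, 2, 3}, {4, 5, 6} and {7, 8, 9} returns {12, 15, 18}, the same values as the Python session above.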

Registering an Operator

Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this, we'll see how all of the constituents of the implementation are linked together.

Registration is done using the macro NNVM_REGISTER_OP. The code for registering add_n can be found in elemwise_sum.cc and elemwise_sum.cu. Since the GPU-side registration inherits most of its attributes from the CPU-side registration, we will focus on the contents of elemwise_sum.cc:

NNVM_REGISTER_OP(add_n)
.add_alias("ElementWiseSum")
.describe(R"doc(Adds all input arguments element-wise.
.. math::
   add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n
``add_n`` is potentially more efficient than calling ``add`` by `n` times.
)doc" ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    return ret;
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("arg") + std::to_string(i));
    }
    return ret;
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
.set_attr<nnvm::FInplaceOption>(
    "FInplaceOption", [](const NodeAttrs& attrs) {
      return std::vector<std::pair<int, int> >{{0, 0}};
    })
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add_n"})
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");

There is a lot to parse here, but don't worry, we'll take each component one at a time.
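Why can all of these calls be chained after NNVM_REGISTER_OP(add_n)? Every setter returns a reference to the same operator object, so the whole registration is one expression. As far as we can tell from NNVM's op.h, the macro expands to a namespace-scope static reference initialized by a registry lookup. The chain therefore runs during static initialization, before main(). The sketch below imitates that pattern. ToyOp, RegisterOrGet and TOY_REGISTER_OP are hypothetical names, not NNVM's API; the real nnvm::Op is backed by dmlc::Registry and stores much more. The sketch also shows how a name and an alias can both resolve to one operator through a Get lookup.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical, heavily simplified stand-in for nnvm::Op. Every setter
// returns *this, which is what makes the chained registration syntax work.
class ToyOp {
 public:
  std::string name;
  std::string description;

  ToyOp& describe(const std::string& doc) {
    description = doc;
    return *this;
  }

  // Make this op reachable under an additional name.
  ToyOp& add_alias(const std::string& alias) {
    Lookup()[alias] = this;
    return *this;
  }

  // Find an op by canonical name or alias; nullptr if it is unknown.
  static const ToyOp* Get(const std::string& key) {
    auto it = Lookup().find(key);
    return it == Lookup().end() ? nullptr : it->second;
  }

  // Return the op registered under op_name, creating it on first use.
  static ToyOp& RegisterOrGet(const std::string& op_name) {
    ToyOp& op = Storage()[op_name];  // std::map never relocates its elements
    op.name = op_name;
    Lookup()[op_name] = &op;
    return op;
  }

 private:
  // Function-local statics avoid static-initialization-order problems,
  // since registrations run before main().
  static std::map<std::string, ToyOp>& Storage() {
    static std::map<std::string, ToyOp> storage;
    return storage;
  }
  static std::map<std::string, ToyOp*>& Lookup() {
    static std::map<std::string, ToyOp*> lookup;
    return lookup;
  }
};

// Hypothetical macro in the spirit of NNVM_REGISTER_OP: it opens a
// namespace-scope static reference whose initializer is the whole chain.
#define TOY_REGISTER_OP(OpName) \
  static ToyOp& toy_registered_##OpName = ToyOp::RegisterOrGet(#OpName)

TOY_REGISTER_OP(add_n)
    .add_alias("ElementWiseSum")
    .describe("Adds all input arguments element-wise.");
```

After static initialization, ToyOp::Get("add_n") and ToyOp::Get("ElementWiseSum") return the same pointer.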

Operator Registration Attributes

Let's look at each of the attributes in the above operator definition one by one. For reference, the source code for the macro and the basic operator structure and attributes can be found in NNVM's op.h. Also see op_attr_types.h for a list of the different attribute types. If you're already very comfortable working with MXNet, go ahead and read these two source files. My goal here is to take some of the documentation in these files and expand upon it.

  • NNVM_REGISTER_OP(add_n)

    This registers the name of the operator. Once MXNet is built, you can invoke the operator from the Python interface, for example as mx.nd.add_n() or mx.sym.add_n(). At the C++ level, the registered operator can be retrieved by name with Op::Get("add_n"). For example, one could write the following to get a handle on the operator:

    using namespace nnvm;  // Op is declared in NNVM's op.h
    const Op* add_n = Op::Get("add_n");
    // use add_n below (to be discussed later)
  • .add_alias("ElementWiseSum")

    Registers an alias for the operator, so that it can also be retrieved at the C++ level with Op::Get("ElementWiseSum").

  • .describe(R"...")

    Easy enough: this attaches a docstring, written in reStructuredText (reST) syntax, which is parsed when generating the online documentation. See here for a reST primer.

  • .set_attr_parser(ParamParser<dmlc::Parameter>)

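    Sets the function that parses the operator's attributes into a typed parameter struct. These attributes are the keyword arguments supplied when the operator is defined or invoked, and they arrive as strings. ParamParser<ElementWiseSumParam> does this parsing and stores the resulting ElementWiseSumParam in the node's attrs.parsed field. This matters for add_n() because we want to allow an arbitrary number of input arguments, and that number has to come from a parsed parameter.

    Let's look at ElementWiseSumParam more closely to see what's going on: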

    struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
      int num_args;
      DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
        DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
            .describe("Number of inputs to be summed.");
      }
    };

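    Right off the bat, we see that the struct uses the Curiously Recurring Template Pattern (CRTP), which should make you feel like a rockstar C++ programmer. (Basically, CRTP is a compile-time polymorphism technique. ElementWiseSumParam passes itself as the template argument of its base class, dmlc::Parameter<ElementWiseSumParam>, so the base class can work with the derived type without virtual functions.) Long story short, the parameter inherits from dmlc::Parameter, a lightweight parameter management system for operators and objects. The DMLC_DECLARE_FIELD line declares num_args as a field that must be at least 1.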

  • .set_num_inputs(int OR std::function<uint32_t (const NodeAttrs& attr)>)

    Sets the number of inputs to the operator. That's it! For some operators the number of inputs is fixed and known, in which case all you need to do is provide a hard-coded integer (e.g. .set_num_inputs(2) for an operator with two inputs).

    But that's not really it, because here we are already in the deep end. The number of inputs to add_n() is not fixed, so we want to compute it from the parameters produced by our custom attribute parser. The code used in add_n() is repeated here for convenience:

    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
      uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
      return ret;
    })

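    As the prototype requests, we provide a function that accepts the node's attributes (a const nnvm::NodeAttrs&) and returns an unsigned integer. The attribute parser has already stored an ElementWiseSumParam in attrs.parsed, so dmlc::get<ElementWiseSumParam>(attrs.parsed) retrieves it, and the function simply returns its num_args field.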
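ElementWiseSumParam combines two ideas: the CRTP and a lower bound on num_args. Both can be illustrated without dmlc-core. The sketch below is a hypothetical, much-simplified analogue. ToyParameter, ToyElementWiseSumParam, SetField and Validate are invented names. The real dmlc::Parameter builds its field table through the DMLC_DECLARE_PARAMETER and DMLC_DECLARE_FIELD macros rather than a hand-written SetField. What carries over is the shape of the pattern: the base class template is parameterized on the derived class and reaches it through a static_cast resolved at compile time.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical CRTP base, loosely modelled on the idea behind dmlc::Parameter.
// The base knows nothing about concrete fields; it reaches the derived class
// through static_cast<Derived*>(this), so no virtual functions are involved.
template <typename Derived>
class ToyParameter {
 public:
  // Parse string keyword arguments, as they would arrive from a frontend.
  void Init(const std::map<std::string, std::string>& kwargs) {
    Derived* self = static_cast<Derived*>(this);
    for (const auto& kv : kwargs) {
      self->SetField(kv.first, kv.second);
    }
    self->Validate();
  }
};

// Hypothetical counterpart of ElementWiseSumParam.
struct ToyElementWiseSumParam : public ToyParameter<ToyElementWiseSumParam> {
  int num_args = 0;

  void SetField(const std::string& key, const std::string& value) {
    if (key == "num_args") {
      num_args = std::stoi(value);
    } else {
      throw std::invalid_argument("unknown parameter: " + key);
    }
  }

  // Plays the role of .set_lower_bound(1) on the num_args field.
  void Validate() const {
    if (num_args < 1) {
      throw std::invalid_argument("num_args must be >= 1");
    }
  }
};
```

Because ToyParameter<ToyElementWiseSumParam>::Init is instantiated for the concrete type, the calls to SetField and Validate are ordinary, statically bound member calls.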
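The flow behind set_num_inputs for add_n has three steps. First, the attributes are parsed once. Second, the typed parameter is cached on the node. Third, the input count is derived from it on demand. The sketch below reproduces that flow with the standard library, using std::any (C++17) as a stand-in for the type-erased attrs.parsed field and std::any_cast in place of dmlc::get. ToyNodeAttrs, ToySumParam, ToySumParamParser, toy_num_inputs and ToyListInputNames are all hypothetical names. The last one mimics the FListInputNames lambda from the registration code, which names the inputs arg0, arg1, and so on.

```cpp
#include <any>
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for ElementWiseSumParam and nnvm::NodeAttrs.
struct ToySumParam {
  int num_args = 0;
};

struct ToyNodeAttrs {
  std::map<std::string, std::string> dict;  // raw keyword arguments
  std::any parsed;                          // typed result of the attribute parser
};

// Plays the role of ParamParser<ElementWiseSumParam>: parse the raw
// strings once and cache the typed struct in attrs->parsed.
void ToySumParamParser(ToyNodeAttrs* attrs) {
  ToySumParam param;
  param.num_args = std::stoi(attrs->dict.at("num_args"));
  attrs->parsed = param;
}

// The input count is not fixed, so it is computed from the parsed
// parameter, like the lambda passed to set_num_inputs for add_n.
const std::function<std::uint32_t(const ToyNodeAttrs&)> toy_num_inputs =
    [](const ToyNodeAttrs& attrs) {
      const ToySumParam& param = std::any_cast<const ToySumParam&>(attrs.parsed);
      return static_cast<std::uint32_t>(param.num_args);
    };

// Mirrors the FListInputNames lambda: one name per input, "arg0", "arg1", ...
std::vector<std::string> ToyListInputNames(const ToyNodeAttrs& attrs) {
  const std::uint32_t n = toy_num_inputs(attrs);
  std::vector<std::string> names;
  for (std::uint32_t i = 0; i < n; ++i) {
    names.push_back("arg" + std::to_string(i));
  }
  return names;
}
```

For an operator with a fixed arity, the function would be replaced by a constant, as in the .set_num_inputs(2) example above.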
