This document is a guide to creating an MXNet operator. Rather than covering the topic in the abstract, we will learn by exploring an example: elemwise_sum (.h, .cc, .cu).
The code in elemwise_sum.{h,cc,cu} defines and registers the add_n() operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape.
>>> import mxnet as mx
>>> a = mx.nd.array([1,2,3])
>>> b = mx.nd.array([4,5,6])
>>> c = mx.nd.array([7,8,9])
>>> x = mx.nd.add_n(a,b,c)
>>> x.asnumpy()
array([ 12., 15., 18.], dtype=float32)
Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this, we'll see how all of the constituents of the implementation are linked together.
Registration is done using the macro NNVM_REGISTER_OP. The code for registering add_n can be found in elemwise_sum.cc and elemwise_sum.cu. Both files register attributes on the same operator, and the GPU side essentially only adds its own compute function, so we will focus on the contents of elemwise_sum.cc:
NNVM_REGISTER_OP(add_n)
.add_alias("ElementWiseSum")
.describe(R"doc(Adds all input arguments element-wise.
.. math::
add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n
``add_n`` is potentially more efficient than calling ``add`` by `n` times.
)doc" ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
return ret;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
uint32_t num_args = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
std::vector<std::string> ret;
for (uint32_t i = 0; i < num_args; ++i) {
ret.push_back(std::string("arg") + std::to_string(i));
}
return ret;
})
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
.set_attr<nnvm::FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add_n"})
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");
There is a lot to parse here, but don't worry: we'll take each component one at a time.
For reference, the source code for the macro and the basic operator structure and attributes can be found in NNVM's op.h. Also see op_attr_types.h for a list of the different types of attributes. If you're already comfortable working with MXNet, go ahead and read these two source files. My goal here is to take some of the documentation in these files and expand upon it.
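Before going item by item, it may help to see the general shape of the mechanism. Roughly speaking, NNVM_REGISTER_OP(add_n) expands to a namespace-scope static variable initialized from a registry call that returns the entry named add_n (creating it if needed), and each chained call modifies that entry and returns it again. The sketch below reproduces that pattern in plain C++. MiniOp, MiniRegistry, RegisterOrGet and MINI_REGISTER_OP are hypothetical names made up for illustration, not NNVM's actual code.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-in for an operator registry entry (not nnvm::Op).
struct MiniOp {
  std::string name;
  std::string doc;
  uint32_t num_inputs = 1;

  // Each setter updates the entry and returns *this, which is what makes
  // a chain like REGISTER(...).describe(...).set_num_inputs(...) work.
  MiniOp& describe(const std::string& d) {
    doc = d;
    return *this;
  }
  MiniOp& set_num_inputs(uint32_t n) {
    num_inputs = n;
    return *this;
  }
};

// Function-local static: the map is constructed on first use, so it is safe
// to call from the initializer of another namespace-scope variable.
inline std::map<std::string, MiniOp>& MiniRegistry() {
  static std::map<std::string, MiniOp> ops;
  return ops;
}

// Returns the entry for `name`, creating it on first use, so several
// translation units can attach attributes to the same operator.
inline MiniOp& RegisterOrGet(const std::string& name) {
  MiniOp& op = MiniRegistry()[name];
  op.name = name;
  return op;
}

// Hypothetical macro: a namespace-scope static reference whose initializer
// runs before main(), so registration happens as a side effect of loading.
#define MINI_REGISTER_OP(OpName) \
  static MiniOp& mini_op_entry_##OpName = RegisterOrGet(#OpName)

MINI_REGISTER_OP(my_add_n)
    .describe("Adds all inputs element-wise.")
    .set_num_inputs(3);
```

Because RegisterOrGet hands back the existing entry when the name is already registered, another translation unit (a .cu file, say) could invoke the macro for the same name and chain more setters onto the same operator. As far as I can tell, NNVM's real macro also pastes __COUNTER__ into the variable name so the same operator can be opened more than once in a single file; the sketch skips that.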
-
NNVM_REGISTER_OP(add_n)
This registers the name of the operator. Once MXNet is built, you can invoke the operator from the Python interface, for example with mx.nd.add_n() or mx.sym.add_n(). At the C++ level, the registered operator is retrieved with Op::Get("add_n"), which returns a pointer to its registry entry. For example, one could write the following to make the operator easier to use:
using nnvm::Op;  // Op is declared in the nnvm namespace (nnvm/op.h)
const Op* add_n = Op::Get("add_n");
// use add_n below using an OpKernel (to be discussed below)
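Conceptually, Op::Get is a lookup in that registry, and an alias registered with .add_alias (the next item) is simply a second key that resolves to the same entry, so both names give back the same pointer. Here is a simplified, hypothetical sketch of that idea. OpTable, OpEntry, Register and AddAlias are illustrative names, and throwing std::out_of_range for an unknown name is this sketch's choice, not NNVM's error mechanism.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical registry entry; a real nnvm::Op carries many more fields.
struct OpEntry {
  std::string name;
};

// Hypothetical name-based lookup in the spirit of Op::Get(): canonical
// names map to entries, and aliases map to canonical names.
class OpTable {
 public:
  OpEntry& Register(const std::string& name) {
    OpEntry& entry = entries_[name];
    entry.name = name;
    return entry;
  }

  void AddAlias(const std::string& alias, const std::string& name) {
    aliases_[alias] = name;
  }

  // Resolves an alias first, then returns a pointer to the shared entry.
  // Unknown names are reported with an exception.
  const OpEntry* Get(const std::string& key) const {
    auto alias = aliases_.find(key);
    const std::string& name = (alias == aliases_.end()) ? key : alias->second;
    auto it = entries_.find(name);
    if (it == entries_.end()) {
      throw std::out_of_range("operator not registered: " + key);
    }
    return &it->second;
  }

 private:
  std::map<std::string, OpEntry> entries_;
  std::map<std::string, std::string> aliases_;
};
```

Returning a pointer to a single shared entry is what makes Get("add_n") and Get("ElementWiseSum") interchangeable: there is only one operator, reachable under two names.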
-
.add_alias("ElementWiseSum")
Registers an alias for the operator, so that it can also be looked up at the C++ level with Op::Get("ElementWiseSum").
-
.describe(R"doc(...)doc")
Easy enough: this attaches a docstring. The text is parsed as reStructuredText when the online documentation is generated, so consult a reST primer if the syntax is unfamiliar.
-
.set_attr_parser(ParamParser<ElementWiseSumParam>)
Allows you to customize the way the attributes are parsed in the definition and invocation of the operator. In this case, add_n() is a bit tricky since we want to allow an arbitrary number of input arguments.
Let's look at ElementWiseSumParam more closely to see what's going on:
struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
  int num_args;
  DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
    .describe("Number of inputs to be summed.");
  }
};
DMLC_REGISTER_PARAMETER(ElementWiseSumParam);
Right off the bat we see that the struct makes use of the Curiously Recurring Template Pattern (CRTP), which should make you feel like a rockstar C++ programmer. (Basically, CRTP is a compile-time polymorphism technique: the base class template receives the derived class as its template argument, so it can call into the derived class without virtual functions.) Long story short, the parameter inherits from dmlc::Parameter, which is a lightweight parameter management system for operators and objects.
Inside this struct we declare the parameter(s) we want to manage in the construction of the operator, in this case num_args. (Again, the whole point of this is so that we can provide the function a variable number of arguments.) DMLC_DECLARE_PARAMETER() opens the block in which the fields are declared, and DMLC_DECLARE_FIELD(num_args) declares an individual field and returns an object on which we chain its properties: here, a lower bound of 1 and a description.
Note that parameters need to be registered separately from the operator, which is what DMLC_REGISTER_PARAMETER(ElementWiseSumParam) does.
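To see what the CRTP buys us, here is a hypothetical miniature of the same idea. MiniParameter, SumParam, SetField and Validate are invented for illustration (dmlc::Parameter's real field-declaration machinery is much more general), but the shape is the same: the base class template knows the derived type, so Init() can call into the derived struct without virtual functions, and the lower-bound check plays the role of set_lower_bound(1).

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical miniature of the dmlc::Parameter idea, not its real API.
// The base is templated on the derived type (CRTP), so Init() can call the
// derived struct's SetField/Validate through a static_cast; the calls are
// resolved at compile time, with no virtual functions involved.
template <typename Derived>
struct MiniParameter {
  void Init(const std::map<std::string, std::string>& kwargs) {
    Derived& self = static_cast<Derived&>(*this);
    for (const auto& kv : kwargs) {
      if (!self.SetField(kv.first, kv.second)) {
        throw std::invalid_argument("unknown parameter: " + kv.first);
      }
    }
    self.Validate();
  }
};

struct SumParam : public MiniParameter<SumParam> {
  int num_args = 0;  // 0 means "not supplied"

  // Converts the string value for a known key; returns false otherwise.
  bool SetField(const std::string& key, const std::string& value) {
    if (key == "num_args") {
      num_args = std::stoi(value);
      return true;
    }
    return false;
  }

  // Plays the role of .set_lower_bound(1) in ElementWiseSumParam.
  void Validate() const {
    if (num_args < 1) {
      throw std::out_of_range("num_args must be at least 1");
    }
  }
};
```

The kwargs map stands in for the string attributes (e.g. num_args=3) that arrive when the operator is invoked; parsing them into a typed struct once is what lets later attributes read num_args directly.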
This is all done so we can define the next two attributes...
-
.set_num_inputs(uint32_t OR std::function<uint32_t (const NodeAttrs& attr)>)
Set the number of inputs to the operator. That's it! For some operators the number of inputs is fixed and known, and all you need to do is provide a hard-coded integer (e.g. .set_num_inputs(2) for an operator with two inputs).
But that's not really it, because here we are already in the deep end and want to compute the number of inputs from the value produced by our custom parameter parser. The code used in add_n() is repeated here for convenience:
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
  uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
  return ret;
})
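As the prototype requires, we provide a function that accepts a NodeAttrs struct, whose parsed field here holds an ElementWiseSumParam, and returns an unsigned integer. The lambda above retrieves the ElementWiseSumParam with dmlc::get and returns its num_args, which holds the number of arguments passed to the operator. As for when num_args gets set: the attribute parser fills it in from the node's attributes when the node is created, and the "key_var_num_args" attribute in the registration tells the frontends which attribute should receive the number of positional inputs, which is why mx.nd.add_n(a,b,c) works without an explicit num_args.
-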
.set_attr<nnvm::FListInputNames>
There is a class of templated operator attributes; see op_attr_types.h for a pre-defined list. The function .set_attr<T>() (here be dragons) accepts three arguments:
const std::string& attr_name - the name of the attribute
const T& value - the value to set this attribute to
int plevel - the priority level of this attribute. If the same attribute is set more than once on an operator, the value with the higher priority level replaces the lower one, and setting it twice with the same priority level is an error. The priority level defaults to 10.
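The priority rule is easier to see in isolation. In NNVM's op.h, plevel must be positive, a value set with a higher plevel replaces one set with a lower plevel, and setting an attribute again with the plevel it is currently stored at is an error. PrioritizedAttr below is a hypothetical class that models just that rule for one attribute across operator names; it is not NNVM's actual attribute storage.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical per-attribute table: one value per operator name, stored
// together with the priority level (plevel) it was set at.
template <typename T>
class PrioritizedAttr {
 public:
  void Set(const std::string& op_name, const T& value, int plevel = 10) {
    if (plevel <= 0) {
      throw std::invalid_argument("plevel must be greater than 0");
    }
    auto it = slots_.find(op_name);
    if (it == slots_.end()) {
      slots_.emplace(op_name, std::make_pair(value, plevel));
      return;
    }
    if (it->second.second == plevel) {
      throw std::logic_error("attribute already set with the same plevel");
    }
    if (it->second.second < plevel) {
      it->second = std::make_pair(value, plevel);  // higher plevel wins
    }
    // A lower plevel is silently ignored.
  }

  const T& Get(const std::string& op_name) const {
    return slots_.at(op_name).first;
  }

 private:
  std::map<std::string, std::pair<T, int>> slots_;
};
```

With this rule, a value registered at the default level 10 can be overridden later by registering the same attribute at, say, level 11, without editing the original registration.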
With that in mind, add_n() defines an attribute called "FListInputNames" whose value has the type nnvm::FListInputNames: a function that accepts a const NodeAttrs& (whose parsed field here holds an ElementWiseSumParam) and returns a vector of strings. We can see in the registration code above that the function simply extracts num_args from the ElementWiseSumParam and returns the vector ["arg0", "arg1", ..., "arg(num_args-1)"].
The point of this is to enable automatic variable creation for missing arguments.
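Both set_num_inputs and FListInputNames boil down to functions of the parsed num_args. The sketch below mirrors the two add_n lambdas with hypothetical stand-in types: SumParam, NodeAttrsLite and OpSketch are not NNVM types, and the real NodeAttrs keeps the parsed struct type-erased in attrs.parsed, hence the dmlc::get call in the original. It also shows one way a fixed count and a per-node function can share a single interface. As an assumption about how a symbolic frontend uses the input names, the hypothetical MissingInputVars helper builds "<node>_<input>" names for inputs the caller did not supply.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: the parsed parameter is a plain member here.
struct SumParam { uint32_t num_args; };
struct NodeAttrsLite { SumParam parsed; };

using NumInputsFn = std::function<uint32_t(const NodeAttrsLite&)>;
using ListInputNamesFn =
    std::function<std::vector<std::string>(const NodeAttrsLite&)>;

// Hypothetical op entry with the two set_num_inputs flavours.
struct OpSketch {
  NumInputsFn num_inputs;
  // Fixed count: wrap the constant so callers always see a function.
  OpSketch& set_num_inputs(uint32_t n) {
    num_inputs = [n](const NodeAttrsLite&) { return n; };
    return *this;
  }
  // Per-node count computed from the parsed attributes.
  OpSketch& set_num_inputs(NumInputsFn fn) {
    num_inputs = std::move(fn);
    return *this;
  }
};

// Same logic as add_n's FListInputNames lambda: "arg0" ... "arg{num_args-1}".
const ListInputNamesFn kListInputNames = [](const NodeAttrsLite& attrs) {
  std::vector<std::string> names;
  for (uint32_t i = 0; i < attrs.parsed.num_args; ++i) {
    names.push_back("arg" + std::to_string(i));
  }
  return names;
};

// Hypothetical helper: name a placeholder variable "<node>_<input>" for each
// input the caller did not supply (inputs are assumed to be filled in order).
std::vector<std::string> MissingInputVars(const std::string& node_name,
                                          const NodeAttrsLite& attrs,
                                          std::size_t supplied) {
  std::vector<std::string> names = kListInputNames(attrs);
  std::vector<std::string> created;
  for (std::size_t i = supplied; i < names.size(); ++i) {
    created.push_back(node_name + "_" + names[i]);
  }
  return created;
}
```

Calling set_num_inputs(2) selects the integer overload, while passing a lambda selects the function overload. Wrapping the constant in a function is this sketch's design choice; as far as I know, the real nnvm::Op keeps the fixed count and the function in separate fields.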