This document is a guide to creating an MXNet operator. Rather than describing the process in the abstract, we will learn by exploring an example: elemwise_sum (.h, .cc, .cu).
The code in elemwise_sum.{h,cc,cu} defines and registers the add_n() operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape.
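To make the semantics concrete, here is a minimal C++ sketch of what add_n computes. The function add_n_reference is hypothetical and is not MXNet's ElementWiseSumCompute. It assumes each input is a flat std::vector<float>, so "same shape" is modelled simply as "same length".

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical reference implementation (not MXNet's ElementWiseSumCompute):
// sums any number of equally sized float arrays element by element.
std::vector<float> add_n_reference(const std::vector<std::vector<float>>& inputs) {
  // This sketch assumes at least one input is always passed.
  assert(!inputs.empty() && "add_n_reference needs at least one input");
  const std::size_t size = inputs[0].size();
  std::vector<float> out(size, 0.0f);
  for (const std::vector<float>& in : inputs) {
    // Every input must have the same shape (here: the same length).
    if (in.size() != size) {
      throw std::invalid_argument("all inputs must have the same shape");
    }
    for (std::size_t i = 0; i < size; ++i) {
      out[i] += in[i];
    }
  }
  return out;
}
```

Called with the three arrays from the Python session below, {1, 2, 3}, {4, 5, 6} and {7, 8, 9}, this sketch returns {12, 15, 18}.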
>>> import mxnet as mx
>>> a = mx.nd.array([1,2,3])
>>> b = mx.nd.array([4,5,6])
>>> c = mx.nd.array([7,8,9])
>>> x = mx.nd.add_n(a,b,c)
>>> x.asnumpy()
array([ 12., 15., 18.], dtype=float32)
Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this, we'll see how all of the constituents of the implementation are linked together.
Registration is done using the macro NNVM_REGISTER_OP. The code for registering add_n can be found in elemwise_sum.cc and elemwise_sum.cu. The GPU-side registration reuses most of the attributes defined on the CPU side, so we will focus on the contents of elemwise_sum.cc:
NNVM_REGISTER_OP(add_n)
.add_alias("ElementWiseSum")
.describe(R"doc(Adds all input arguments element-wise.
.. math::
add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n
``add_n`` is potentially more efficient than calling ``add`` by `n` times.
)doc" ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    return ret;
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("arg") + std::to_string(i));
    }
    return ret;
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
.set_attr<nnvm::FInplaceOption>(
  "FInplaceOption", [](const NodeAttrs& attrs) {
    return std::vector<std::pair<int, int> >{{0, 0}};
  })
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add_n"})
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");
There is a lot to parse here, but don't worry: we'll take each component one at a time.
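The chained-call syntax works because the registration macro produces a reference to a registry entry, and every setter returns that same entry. The toy sketch below imitates this pattern. ToyOp, RegisterOrGet and TOY_REGISTER_OP are hypothetical stand-ins, not NNVM's real types; the real macro lives in NNVM's op.h. To keep things small, num_args is passed in directly instead of being parsed from ElementWiseSumParam.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an operator entry (NOT nnvm::Op). It stores a name and
// two callbacks that receive an already-parsed num_args value.
struct ToyOp {
  std::string name;
  std::function<std::uint32_t(std::uint32_t)> num_inputs;
  std::function<std::vector<std::string>(std::uint32_t)> list_input_names;

  // Every setter returns *this, which is what makes chained calls possible.
  ToyOp& set_num_inputs(std::function<std::uint32_t(std::uint32_t)> f) {
    num_inputs = std::move(f);
    return *this;
  }
  ToyOp& set_list_input_names(
      std::function<std::vector<std::string>(std::uint32_t)> f) {
    list_input_names = std::move(f);
    return *this;
  }
};

// Toy global registry: returns the entry for `name`, creating it on first use.
ToyOp& RegisterOrGet(const std::string& name) {
  assert(!name.empty() && "operator name must not be empty");
  static std::map<std::string, std::unique_ptr<ToyOp>> registry;
  std::unique_ptr<ToyOp>& slot = registry[name];
  if (!slot) {
    slot.reset(new ToyOp());
    slot->name = name;
  }
  return *slot;
}

// Hypothetical macro: a file-scope static reference whose initialiser runs
// before main(), so the operator is registered as a side effect.
#define TOY_REGISTER_OP(OpName) \
  static ToyOp& toy_op_##OpName = RegisterOrGet(#OpName)

TOY_REGISTER_OP(add_n)
    .set_num_inputs([](std::uint32_t num_args) { return num_args; })
    .set_list_input_names([](std::uint32_t num_args) {
      // Same naming scheme as FListInputNames above: arg0, arg1, ...
      std::vector<std::string> ret;
      for (std::uint32_t i = 0; i < num_args; ++i) {
        ret.push_back("arg" + std::to_string(i));
      }
      return ret;
    });
```

With num_args = 3, the registered callbacks report three inputs named arg0, arg1 and arg2. This mirrors how set_num_inputs and FListInputNames both derive their answers from the single num_args parameter.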
We'll go through the attributes in the order they appear. For reference, the source code for the macro and the basic operator structure and attributes can be found in NNVM's op.h.
- NNVM_REGISTER_OP(add_n): registers the operator under the name add_n. Once MXNet is built, you can invoke the operator from the Python interface, for example with mx.nd.add_n() or mx.sym.add_n(). At the C++ level, the operator can be looked up with Op::Get("add_n").
- .add_alias("ElementWiseSum"): registers an alias for the operator, so that it can also be looked up at the C++ level with Op::Get("ElementWiseSum").
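Conceptually, an alias is just a second name that resolves to the same operator entry. The sketch below illustrates that idea with a hypothetical AliasRegistry and AliasOp. These are not NNVM's classes, and the real Op::Get and add_alias are implemented differently inside NNVM and dmlc-core.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical operator entry; only the canonical name is stored.
struct AliasOp {
  std::string name;
};

// Hypothetical registry (not NNVM's): canonical names own the entries, and
// every name, canonical or alias, maps to a pointer into that storage.
class AliasRegistry {
 public:
  AliasOp& Register(const std::string& name) {
    std::unique_ptr<AliasOp>& slot = ops_[name];
    if (!slot) {
      slot.reset(new AliasOp());
      slot->name = name;
    }
    lookup_[name] = slot.get();
    return *slot;
  }

  // Makes `alias` resolve to the entry already registered under `name`.
  void AddAlias(const std::string& name, const std::string& alias) {
    assert(alias != name && "an alias must differ from the canonical name");
    lookup_[alias] = &Get(name);
  }

  // Looks an operator up by any registered name, similar in spirit to Op::Get.
  const AliasOp& Get(const std::string& name) const {
    auto it = lookup_.find(name);
    if (it == lookup_.end()) {
      throw std::out_of_range("operator not registered: " + name);
    }
    return *it->second;
  }

 private:
  std::map<std::string, std::unique_ptr<AliasOp>> ops_;
  std::map<std::string, const AliasOp*> lookup_;
};
```

After registering add_n and aliasing it as ElementWiseSum, both names return a reference to the very same entry, whose canonical name is still add_n.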