@cswiercz
Last active August 3, 2017 17:30
Revisions

  1. cswiercz revised this gist Aug 3, 2017. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion understanding_elemwise_sum.md
    @@ -2,7 +2,7 @@

    1. [Introduction](#introduction)
    1. [Registering an Operator](#registering-an-operator)
    1. [Operator Registration Attributes](#operator-registration-attributes)
    1. [Operator Registration Attributes](#operator-registration-attributes)
    1. [Inferring Shapes](#inferring-shapes)
    1. [Inferring Types](#inferring-types)

  2. cswiercz revised this gist Aug 3, 2017. 1 changed file with 3 additions and 3 deletions.
    6 changes: 3 additions & 3 deletions understanding_elemwise_sum.md
    @@ -1,10 +1,10 @@
    # Table of Contents

    1. [Introduction](#introduction)
    2. [Registering an Operator](#registering-an-operator)
    1. [Registering an Operator](#registering-an-operator)
    1. [Operator Registration Attributes](#operator-registration-attributes)
    3. [Inferring Shapes](#inferring-shapes)
    4. [Inferring Types](#inferring-types)
    1. [Inferring Shapes](#inferring-shapes)
    1. [Inferring Types](#inferring-types)

    # Introduction

  3. cswiercz revised this gist Aug 3, 2017. 1 changed file with 10 additions and 2 deletions.
    12 changes: 10 additions & 2 deletions understanding_elemwise_sum.md
    @@ -1,3 +1,11 @@
    # Table of Contents

    1. [Introduction](#introduction)
    2. [Registering an Operator](#registering-an-operator)
    1. [Operator Registration Attributes](#operator-registration-attributes)
    3. [Inferring Shapes](#inferring-shapes)
    4. [Inferring Types](#inferring-types)

    # Introduction

    This document is a sort of guide to creating an MXNet operator. However, here we
    @@ -263,11 +271,11 @@ and expand upon them.

    See [Inferring Types](#inferringtypes).

    # <a name="inferringshapes"></a>Inferring Shapes
    # Inferring Shapes

    foo

    # <a name="inferringtypes"></a>Inferring Types
    # Inferring Types


    bar
  4. cswiercz revised this gist Aug 3, 2017. 1 changed file with 159 additions and 17 deletions.
    176 changes: 159 additions & 17 deletions understanding_elemwise_sum.md
    @@ -1,8 +1,15 @@
    # Introduction

    This document is a sort of guide to creating an MXNet operator. However, here we will learn by exploring an example: `elemwise_sum`. ([`.h`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.h), [`.c`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cc), [`.cu`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cu))
    This document is a sort of guide to creating an MXNet operator. However, here we
    will learn by exploring an example: `elemwise_sum`.
    ([`.h`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.h),
    [`.c`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cc),
    [`.cu`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cu))

    The code in `elemwise_sum.{h,cc,cu}` defines and registers the [`add_n()`](http://mxnet.io/api/python/ndarray.html#mxnet.ndarray.add_n) operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape.
    The code in `elemwise_sum.{h,cc,cu}` defines and registers
    the [`add_n()`](http://mxnet.io/api/python/ndarray.html#mxnet.ndarray.add_n)
    operator, which computes the elementwise sum of an arbitrary number of input
    arguments of the same shape.

    ```python
    >>> import mxnet as mx
    >>> a = mx.nd.array([1,2,3])
    >>> b = mx.nd.array([4,5,6])
    >>> c = mx.nd.array([7,8,9])
    >>> x = mx.nd.add_n(a,b,c)
    >>> x.asnumpy()
    array([ 12.,  15.,  18.], dtype=float32)
    ```

    # Registering an Operator

    Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this we'll see how all of the constituents of the implementation are linked together.
    Before jumping into the workhorse part of the code, let's see how an operator is
    registered with MXNet. Once we understand this we'll see how all of the
    constituents of the implementation are linked together.

    Registration is done using the macro `NNVM_REGISTER_OP`. The code for registering `add_n` can be found in [`elemwise_sum.cc`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cc) and [`elemwise_sum.cu`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cu). Since the operator registration on the GPU side inherits most of its parameters from the CPU-side registration, we will focus on the contents of `elemwise_sum.cc`:
    Registration is done using the macro `NNVM_REGISTER_OP`. The code for
    registering `add_n` can be found in
    [`elemwise_sum.cc`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cc) and
    [`elemwise_sum.cu`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cu).
    Since the operator registration on the GPU side inherits most of its parameters
    from the CPU-side registration, we will focus on the contents of `elemwise_sum.cc`:

    ```c++
    NNVM_REGISTER_OP(add_n)
    ```
    ## Operator Registration Attributes
    Let's look at each of the attributes in the above operator definition one by one. For reference, the source code for the macro and basic operator structure/attributes can be found in NNVM's [`op.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h). Also see [`op_attr_types.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h) for a list of the different types of attributes. If you're already comfortable working with MXNet then go ahead and read these two source files. My goal here is to take some of the documentation listed in these files and expand upon them.
    Let's look at each of the attributes in the above operator definition one by
    one. For reference, the source code for the macro and basic operator
    structure/attributes can be found in NNVM's
    [`op.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h). Also see
    [`op_attr_types.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h)
    for a list of the different types of attributes. If you're already comfortable
    working with MXNet then go ahead and read these two source files. My goal here
    is to take some of the documentation listed in these files
    and expand upon them.
    * `NNVM_REGISTER_OP(add_n)`
    This registers the name of the operator. When built, you can invoke the operator from the Python interface, for example, from `mx.nd.add_n()` or `mx.sym.add_n()`. At the C++ level, the operator would be invoked using `Op::Get("add_n")`. For example, one could write the following to make the operator easier to use:
    This registers the name of the operator. When built, you can invoke the
    operator from the Python interface, for example, from `mx.nd.add_n()` or
    `mx.sym.add_n()`. At the C++ level, the operator would be invoked using
    `Op::Get("add_n")`. For example, one could write the following to make the
    operator easier to use:
    ```c++
    using namespace mxnet::op; // TODO: is this the right one?
    const Op* add_n = Op::Get("add_n");
    // use add_n below using an OpKernel (to be discussed below)
    ```
    @@ -72,15 +101,20 @@ Let's look at each of the attributes in the above operator definition one by one

    * `.add_alias("ElementWiseSum")`

    Register an alias for the operator, allowing you to invoke it at the C++ level by writing `Op::Get("ElementWiseSum")`.
    Register an alias for the operator, allowing you to invoke it at the C++ level
    by writing `Op::Get("ElementWiseSum")`.

    * `.describe(R"...")`

    Easy enough: include a docstring. Parsed as reStructuredText syntax for the purposes of generating online documentation. See [here](http://www.sphinx-doc.org/en/stable/rest.html) for a reST primer.
    Easy enough: include a docstring. Parsed as reStructuredText syntax for the
    purposes of generating online documentation.
    See [here](http://www.sphinx-doc.org/en/stable/rest.html) for a reST primer.

    * `.set_attr_parser(ParamParser<dmlc::Parameter>)`

    Allows you to customize the way the attributes are parsed in the definition and invocation of the operator. In this case, the function `add_n()` is a bit tricky since we want to allow an arbitrary number of input arguments.
    Allows you to customize the way the attributes are parsed in the definition
    and invocation of the operator. In this case, the function `add_n()` is a bit
    tricky since we want to allow an arbitrary number of input arguments.

    Let's look at `ElementWiseSumParam` more closely to see what's going on,

    @@ -96,19 +130,37 @@ Let's look at each of the attributes in the above operator definition one by one
    ```c++
    DMLC_REGISTER_PARAMETER(ElementWiseSumParam);
    ```
    Right off the bat we see that the struct makes use of the [Curiously Recurring Template Pattern](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) which should make you feel like you're a rockstar C++ programmer. (Basically, CRTP is a compile-time polymorphism technique.) Long story short, the parameter inherits from [`dmlc::Parameter`](https://github.com/dmlc/dmlc-core/blob/master/include/dmlc/parameter.h#L114) which is a lightweight parameter management system for operators and objects.
    Right off the bat we see that the struct makes use of the
    [Curiously Recurring Template Pattern](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern),
    which should make you feel like you're a rockstar C++ programmer. (Basically,
    CRTP is a compile-time polymorphism technique.) Long story short, the parameter
    inherits from
    [`dmlc::Parameter`](https://github.com/dmlc/dmlc-core/blob/master/include/dmlc/parameter.h#L114),
    which is a lightweight parameter management system for operators and objects.
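    To make the pattern concrete, here is a minimal self-contained sketch of the CRTP idea. This is a toy model for illustration only, not dmlc's actual `Parameter` code; the names `ParameterBase` and `MyParam` are made up:

    ```cpp
    #include <string>

    // Toy illustration of CRTP, NOT dmlc's actual code: the base class is
    // templated on the derived class, so the base can call derived-class
    // methods without virtual dispatch (the cast resolves at compile time).
    template <typename Derived>
    struct ParameterBase {
      std::string Describe() {
        return static_cast<Derived*>(this)->Name() + " parameter";
      }
    };

    // A hypothetical parameter type plugging itself into the base.
    struct MyParam : public ParameterBase<MyParam> {
      std::string Name() { return "num_args"; }
    };
    ```

    The base class gains access to `MyParam::Name()` without any virtual function table, which is the trick `dmlc::Parameter` relies on.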
    Inside this struct we declare the parameter(s) we want to manage in the construction of the operator, in this case `num_args`. (Again, the whole point of this is so that we can provide the function a variable number of arguments.) `DMLC_DECLARE_PARAMETER()` is a macro for augmenting a particular parameter. In this case, we want to provide a description and set a lower bound.
    Inside this struct we declare the parameter(s) we want to manage in the
    construction of the operator, in this case `num_args`. (Again, the whole point
    of this is so that we can provide the function a variable number of
    arguments.) `DMLC_DECLARE_PARAMETER()` is a macro for augmenting a particular
    parameter. In this case, we want to provide a description and set a lower
    bound.
    Note that parameters need to be registered separately from the operator.
    This is all done so we can define the next two attributes...
    * `.set_num_inputs(int OR std::function<uint32_t (const NodeAttrs& attr)>)`
    Set the number of inputs to the operator. That's it! For some operators the number of inputs is fixed and known. In this case, all you need to do is provide a hard-coded integer here (e.g. `.set_num_inputs(2)` for an operator with two arguments/inputs).
    Set the number of inputs to the operator. That's it! For some operators the
    number of inputs is fixed and known. In this case, all you need to do is
    provide a hard-coded integer here (e.g. `.set_num_inputs(2)` for an operator
    with two arguments/inputs).
    But that's not really it, because here we are already in the deep end and want to define the way we get the number of inputs using our custom parameter parser. The code used in `add_n()` is repeated here for convenience,
    But that's not really it, because here we are already in the deep end and want
    to define the way we get the number of inputs using our custom parameter parser.
    The code used in `add_n()` is repeated here for convenience,
    ```c++
    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
        uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
        return ret;
      })
    ```

    As the prototype requests we provide a function that accepts a `NodeAttrs` struct, which in this case is of type `ElementWiseSumParam`, and returns an unsigned integer. The lambda function above gets the `ElementWiseSumParam` and returns its `num_args`, which is set to the number of arguments passed to the operator. *(TODO: at what point is `num_args` set?)*
    As the prototype requests we provide a function that accepts a `NodeAttrs`
    struct, which in this case is of type `ElementWiseSumParam`, and returns an
    unsigned integer. The lambda function above gets the `ElementWiseSumParam` and
    returns its `num_args`, which is set to the number of arguments passed to the
    operator. *(TODO: at what point is `num_args` set?)*

    * `.set_attr<nnvm::FListInputNames>`

    There is a class of templatized operator attributes. See [here](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h) for a pre-defined list. The function [`.set_attr<T>`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h#L419) (here be dragons at this link) accepts three arguments:
    There is a class of templatized operator attributes. See
    [here](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h)
    for a pre-defined list. The function
    [`.set_attr<T>`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h#L419)
    (here be dragons at this link) accepts three arguments:
    * `const string& attr_name` - the name of the attribute
    * `const T& value` - the value to set this attribute to
    * `int plevel` - the priority level of this attribute. If the operator inherits from another operator this tells the compiler which definition of the attribute to use. The priority level is set to 10 by default.
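    As a rough mental model of how `plevel` resolves competing definitions, consider the following toy sketch. This is not nnvm's actual registry code; `ToyAttrRegistry` and its members are made-up names, and only the "higher priority wins" idea is taken from the description above:

    ```cpp
    #include <map>
    #include <string>
    #include <utility>

    // Toy model of plevel semantics, NOT nnvm's implementation: a stored
    // attribute value is replaced only when the incoming priority level is
    // strictly higher than the one already recorded. Default plevel is 10.
    struct ToyAttrRegistry {
      std::map<std::string, std::pair<std::string, int> > attrs;  // name -> (value, plevel)

      void SetAttr(const std::string& name, const std::string& value, int plevel = 10) {
        auto it = attrs.find(name);
        if (it == attrs.end() || plevel > it->second.second) {
          attrs[name] = std::make_pair(value, plevel);
        }
      }
    };
    ```

    Under this model, an operator that inherits an attribute at the default priority can have it overridden by a later `SetAttr` call with a higher `plevel`, while lower-priority definitions are ignored.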

    That being said, `add_n()` defines an attribute called `"FListInputNames"` where the value is the same type as [`nnvm::FListInputNames`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h#L31): a function which accepts a `const NodeAttrs&`, which in this case is interpreted as an `ElementWiseSumParam`, and returns a vector of strings. We can see in the attribute definition below that the function simply extracts `num_args` from the `ElementWiseSumParam` and returns the vector `["arg0", "arg1", ..., "arg(num_args-1)"]`.
    That being said, `add_n()` defines an attribute called `"FListInputNames"`
    where the value is the same type as
    [`nnvm::FListInputNames`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h#L31):
    a function which accepts a `const NodeAttrs&`, which in this case is
    interpreted as an `ElementWiseSumParam`, and returns a vector of strings. We
    can see in the attribute definition below that the function simply extracts
    `num_args` from the `ElementWiseSumParam` and returns the vector `["arg0",
    "arg1", ..., "arg(num_args-1)"]`.

    The point of this is to enable automatic variable creation for missing arguments.
    The point of this is to enable automatic variable creation for missing
    arguments.

    * `.set_attr<std::string>("key_var_num_args", "num_args")`

    This is not well documented, but examining the source code it seems to be a
    hint to the docstring generator that this function accepts a variable number
    of arguments. It seems to be needed only in this kind of situation.

    * `.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)`

    ***This is a key attribute for a new operator!*** Assigning `FCompute` tells
    MXNet which function to call when the operator is called. That is, `add_n()`
    is, more or less, the function `ElementWiseSumCompute<>()` but with some
    layers of pre-processing. Later in the document we'll talk about this function
    in more detail, but I'll show the `FCompute` function prototype anyway,

    ```c++
    using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
                                         const OpContext& ctx,
                                         const std::vector<TBlob>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<TBlob>& outputs)>;
    ```

    This doesn't look like what we would use as input to `add_n()`. Again, we'll
    get back to this later.
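    To build some intuition for what the compute function eventually does, here is a hedged sketch of the elementwise-sum arithmetic with the `TBlob`/`OpContext` machinery replaced by plain vectors. The name `ElementWiseSumSketch` is made up; the real `ElementWiseSumCompute<>` also handles dtype and device dispatch:

    ```cpp
    #include <cstddef>
    #include <vector>

    // Core arithmetic of an elementwise sum: accumulate every input
    // array into a single output array of the same shape. All inputs are
    // assumed to have equal length, as add_n() requires equal shapes.
    std::vector<float> ElementWiseSumSketch(
        const std::vector<std::vector<float> >& inputs) {
      if (inputs.empty()) return std::vector<float>();
      std::vector<float> out(inputs[0].size(), 0.0f);
      for (std::size_t k = 0; k < inputs.size(); ++k) {
        for (std::size_t i = 0; i < out.size(); ++i) {
          out[i] += inputs[k][i];
        }
      }
      return out;
    }
    ```

    Feeding it the arrays from the Python example above reproduces the same sums.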

    * `.set_attr<nnvm::FInplaceOption>`

    Operators have the option of performing computations in-place. That is, you
    can optionally store the output of the operator in the memory already occupied
    by one of the inputs. The
    [prototype for `FInplaceOption`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h#L115) is,

    ```c++
    using FInplaceOption = std::function<
        std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
    ```

    Basically, the value of `FInplaceOption` is a function mapping the attributes
    of this compute node, which in this case is of type `ElementWiseSumParam`, to
    a list of two-tuples. Each tuple `{i,j}` defines a map from input `i` to
    output `j`. That is, the memory location of input `i` is the same as the
    memory location of output `j`.

    For `add_n()` the `FInplaceOption` function always returns `{{0,0}}`, meaning
    that no matter how many arguments are passed we only map the first input to
    the first (and only) output. This makes sense since `add_n()` is a variable
    argument function and, by the definition of `ElementWiseSumParam`, will always
    have at least one input.
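    In isolation, that in-place rule amounts to a function like the following sketch, with `NodeAttrs` omitted (`AddNInplaceOption` is a made-up name for illustration):

    ```cpp
    #include <utility>
    #include <vector>

    // Sketch of add_n's FInplaceOption value: regardless of how many
    // inputs there are, only input 0 may share memory with output 0.
    std::vector<std::pair<int, int> > AddNInplaceOption() {
      std::vector<std::pair<int, int> > ret;
      ret.push_back(std::make_pair(0, 0));
      return ret;
    }
    ```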

    The nice thing about this design is that the result of our computation can be
    stored in the appropriate output pointer and MXNet will take care of the
    in-place'edness.

    * `.set_attr<nnvm::FInferShape>`

    See [Inferring Shapes](#inferringshapes).

    * `.set_attr<nnvm::FInferType>`

    See [Inferring Types](#inferringtypes).

    # <a name="inferringshapes"></a>Inferring Shapes

    foo

    # <a name="inferringtypes"></a>Inferring Types


    bar
  5. cswiercz revised this gist Aug 2, 2017. 1 changed file with 19 additions and 1 deletion.
    20 changes: 19 additions & 1 deletion understanding_elemwise_sum.md
    @@ -92,10 +92,17 @@ Let's look at each of the attributes in the above operator definition one by one
    ```c++
        .describe("Number of inputs to be summed.");
      }
    };

    DMLC_REGISTER_PARAMETER(ElementWiseSumParam);
    ```
    Right off the bat we see that the struct makes use of the [Curiously Recurring Template Pattern](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) which should make you feel like you're a rockstar C++ programmer. (Basically, CRTP is a compile-time polymorphism technique.) Long story short, the parameter inherits from [`dmlc::Parameter`](https://github.com/dmlc/dmlc-core/blob/master/include/dmlc/parameter.h#L114) which is a lightweight parameter management system for operators and objects.
    Inside this struct we declare the parameter(s) we want to manage in the construction of the operator, in this case `num_args`. (Again, the whole point of this is so that we can provide the function a variable number of arguments.) `DMLC_DECLARE_PARAMETER()` is a macro for augmenting a particular parameter. In this case, we want to provide a description and set a lower bound.
    Note that parameters need to be registered separately from the operator.
    This is all done so we can define the next two attributes...
    * `.set_num_inputs(int OR std::function<uint32_t (const NodeAttrs& attr)>)`
    @@ -110,4 +117,15 @@ Let's look at each of the attributes in the above operator definition one by one
    ```c++
      })
    ```

    As the prototype requests we provide a function that accepts a list of attributes, which in this case are of type `ElementWiseSumParam`, and returns an unsigned integer.
    As the prototype requests we provide a function that accepts a `NodeAttrs` struct, which in this case is of type `ElementWiseSumParam`, and returns an unsigned integer. The lambda function above gets the `ElementWiseSumParam` and returns its `num_args`, which is set to the number of arguments passed to the operator. *(TODO: at what point is `num_args` set?)*

    * `.set_attr<nnvm::FListInputNames>`

    There is a class of templatized operator attributes. See [here](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h) for a pre-defined list. The function [`.set_attr<T>`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h#L419) (here be dragons at this link) accepts three arguments:
    * `const string& attr_name` - the name of the attribute
    * `const T& value` - the value to set this attribute to
    * `int plevel` - the priority level of this attribute. If the operator inherits from another operator this tells the compiler which definition of the attribute to use. The priority level is set to 10 by default.

    So that being said, `add_n()` defines an attribute called `"FListInputNames"` where the value is the same type as [`nnvm::FListInputNames`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h#L31): a function which accepts a `const NodeAttrs&`, which in this case is interpreted as an `ElementWiseSumParam`, and returns a vector of strings. We can see in the attribute definition below that the function simply extracts `num_args` from the `ElementWiseSumParam` and returns the vector `["arg0", "arg1", ..., "arg(num_args-1)"]`.

    The point of this is to enable automatic variable creation for missing arguments.
  6. cswiercz revised this gist Aug 2, 2017. 1 changed file with 31 additions and 1 deletion.
    32 changes: 31 additions & 1 deletion understanding_elemwise_sum.md
    @@ -80,4 +80,34 @@ Let's look at each of the attributes in the above operator definition one by one

    * `.set_attr_parser(ParamParser<dmlc::Parameter>)`

    Allows you to customize the way the attributes are parsed in the definition and invocation of the operator. In this case, the function `add_n()` is a bit tricky since we want to allow an arbitrary number of input arguments.
    Allows you to customize the way the attributes are parsed in the definition and invocation of the operator. In this case, the function `add_n()` is a bit tricky since we want to allow an arbitrary number of input arguments.

    Let's look at `ElementWiseSumParam` more closely to see what's going on,

    ```c++
    struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
      int num_args;
      DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
        DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
        .describe("Number of inputs to be summed.");
      }
    };
    ```
    Right off the bat we see that the struct makes use of the [Curiously Recurring Template Pattern](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) which should make you feel like you're a rockstar C++ programmer. (Basically, CRTP is a compile-time polymorphism technique.) Long story short, the parameter inherits from [`dmlc::Parameter`](https://github.com/dmlc/dmlc-core/blob/master/include/dmlc/parameter.h#L114) which is a lightweight parameter management system for operators and objects.
    * `.set_num_inputs(int OR std::function<uint32_t (const NodeAttrs& attr)>)`
    Set the number of inputs to the operator. That's it! For some operators the number of inputs is fixed and known. In this case, all you need to do is provide a hard-coded integer here (e.g. `.set_num_inputs(2)` for an operator with two arguments/inputs).
    But that's not really it, because here we are already in the deep end and want to define the way we get the number of inputs using our custom parameter parser. The code used in `add_n()` is repeated here for convenience,
    ```c++
    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
        uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
        return ret;
      })
    ```

    As the prototype requests we provide a function that accepts a list of attributes, which in this case are of type `ElementWiseSumParam`, and returns an unsigned integer.
  7. cswiercz revised this gist Aug 2, 2017. 1 changed file with 15 additions and 2 deletions.
    17 changes: 15 additions & 2 deletions understanding_elemwise_sum.md
    @@ -58,13 +58,26 @@ There is a lot to parse here, but don't worry, we'll take each component one at
    ## Operator Registration Attributes
    Let's look at each of the attributes in the above operator definition one by one. For reference, the source code for the macro and basic operator structure/attributes can be found in NNVM's [`op.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h).
    Let's look at each of the attributes in the above operator definition one by one. For reference, the source code for the macro and basic operator structure/attributes can be found in NNVM's [`op.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h). Also see [`op_attr_types.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op_attr_types.h) for a list of the different types of attributes. If you're already comfortable working with MXNet then go ahead and read these two source files. My goal here is to take some of the documentation listed in these files and expand upon them.
    * `NNVM_REGISTER_OP(add_n)`
    This registers the name of the operator. When built, you can invoke the operator from the Python interface, for example, from `mx.nd.add_n()` or `mx.sym.add_n()`. At the C++ level, the operator would be invoked using `Op::Get("add_n")`.
    This registers the name of the operator. When built, you can invoke the operator from the Python interface, for example, from `mx.nd.add_n()` or `mx.sym.add_n()`. At the C++ level, the operator would be invoked using `Op::Get("add_n")`. For example, one could write the following to make the operator easier to use:
    ```c++
    using namespace mxnet::op; // TODO: is this the right one?
    const Op* add_n = Op::Get("add_n");
    // use add_n below using an OpKernel (to be discussed below)
    ```

    * `.add_alias("ElementWiseSum")`

    Register an alias for the operator, allowing you to invoke it at the C++ level by writing `Op::Get("ElementWiseSum")`.

    * `.describe(R"...")`

    Easy enough: include a docstring. Parsed as reStructuredText syntax for the purposes of generating online documentation. See [here](http://www.sphinx-doc.org/en/stable/rest.html) for a reST primer.

    * `.set_attr_parser(ParamParser<dmlc::Parameter>)`

    Allows you to customize the way the attributes are parsed in the definition and invocation of the operator. In this case, the function `add_n()` is a bit tricky since we want to allow an arbitrary number of input arguments.
  8. cswiercz created this gist Aug 2, 2017.
    70 changes: 70 additions & 0 deletions understanding_elemwise_sum.md
    @@ -0,0 +1,70 @@
    # Introduction

    This document is a sort of guide to creating an MXNet operator. However, here we will learn by exploring an example: `elemwise_sum`. ([`.h`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.h), [`.c`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cc), [`.cu`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cu))

    The code in `elemwise_sum.{h,cc,cu}` defines and registers the [`add_n()`](http://mxnet.io/api/python/ndarray.html#mxnet.ndarray.add_n) operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape.

    ```python
    >>> import mxnet as mx
    >>> a = mx.nd.array([1,2,3])
    >>> b = mx.nd.array([4,5,6])
    >>> c = mx.nd.array([7,8,9])
    >>> x = mx.nd.add_n(a,b,c)
    >>> x.asnumpy()
    array([ 12., 15., 18.], dtype=float32)
    ```

    # Registering an Operator

    Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this we'll see how all of the constituents of the implementation are linked together.

    Registration is done using the macro `NNVM_REGISTER_OP`. The code for registering `add_n` can be found in [`elemwise_sum.cc`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cc) and [`elemwise_sum.cu`](https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/elemwise_sum.cu). Since the operator registration on the GPU side inherits most of its parameters from the CPU-side registration, we will focus on the contents of `elemwise_sum.cc`:

    ```c++
    NNVM_REGISTER_OP(add_n)
    .add_alias("ElementWiseSum")
    .describe(R"doc(Adds all input arguments element-wise.

    .. math::
       add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n

    ``add_n`` is potentially more efficient than calling ``add`` by `n` times.
    )doc" ADD_FILELINE)
    .set_attr_parser(ParamParser<ElementWiseSumParam>)
    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
        uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
        return ret;
      })
    .set_attr<nnvm::FListInputNames>("FListInputNames",
      [](const NodeAttrs& attrs) {
        uint32_t num_args = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
        std::vector<std::string> ret;
        for (uint32_t i = 0; i < num_args; ++i) {
          ret.push_back(std::string("arg") + std::to_string(i));
        }
        return ret;
      })
    .set_attr<std::string>("key_var_num_args", "num_args")
    .set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
    .set_attr<nnvm::FInplaceOption>(
      "FInplaceOption", [](const NodeAttrs& attrs) {
        return std::vector<std::pair<int, int> >{{0, 0}};
      })
    .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
    .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
    .set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add_n"})
    .add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");
    ```
    There is a lot to parse here, but don't worry, we'll take each component one at a time.
    ## Operator Registration Attributes
    Let's look at each of the attributes in the above operator definition one by one. For reference, the source code for the macro and basic operator structure/attributes can be found in NNVM's [`op.h`](https://github.com/dmlc/nnvm/blob/master/include/nnvm/op.h).
    * `NNVM_REGISTER_OP(add_n)`
    This registers the name of the operator. When built, you can invoke the operator from the Python interface, for example, from `mx.nd.add_n()` or `mx.sym.add_n()`. At the C++ level, the operator would be invoked using `Op::Get("add_n")`.
    * `.add_alias("ElementWiseSum")`
    Register an alias for the operator, allowing you to invoke it at the C++ level by writing `Op::Get("ElementWiseSum")`.