@ricardoV94
Created March 16, 2026 23:49

How Scan.L_op Works

Summary

Scan.L_op computes the gradient of a Scan operation by constructing a new backward Scan that runs in reverse, propagating gradients through time. The method:

  1. Differentiates the inner function symbolically to get per-step gradients
  2. Adds accumulation terms for recurrent states
  3. Constructs a new backward Scan with reversed sequences and mit-mot states (initialized with the output gradients, accumulating the total gradients as the scan runs)
  4. Re-orders the backward Scan's outputs to match the expected gradient layout

The backward Scan always converts all recurrent states (sit-sot, mit-sot, mit-mot) into mit-mot form, which is the most general recurrence type.

Background: Scan State Types

A Scan operates on a state buffer. Each state type describes how the inner function reads/writes:

| Type    | Input taps             | Output                | Example                               |
|---------|------------------------|-----------------------|---------------------------------------|
| sit-sot | single (e.g. -1)       | single (current step) | x[t] = f(x[t-1])                      |
| mit-sot | multiple (e.g. -2, -1) | single (current step) | x[t] = f(x[t-2], x[t-1])              |
| mit-mot | multiple               | multiple              | general read/write at multiple offsets |
| nit-sot | none                   | single                | pure output, no recurrence            |

Taps are offsets relative to the current step s into the state buffer. Tap -1 means buffer[s-1], tap 0 means buffer[s], tap 1 means buffer[s+1].

In the user-facing API, taps are always negative (read from the past). Positive/zero taps arise internally in gradient scans.
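The tap arithmetic can be sketched in plain Python (no PyTensor; the buffer values and the doubling recurrence are illustrative):

```python
# Plain-Python sketch of tap indexing: a tap is an offset from the current
# step s into the state buffer, so tap k reads buffer[s + k].
buf = [10, 20, 30, 40]
s = 2                        # current step
taps = (-2, -1, 0, 1)
reads = [buf[s + tap] for tap in taps]
assert reads == [10, 20, 30, 40]

# A sit-sot recurrence x[t] = f(x[t-1]), here with f doubling its input.
# The initial state pads the front of the buffer; each step reads tap -1
# and writes the current position.
x = [1]
for t in range(1, 4):
    x.append(x[t - 1] * 2)
assert x == [1, 2, 4, 8]
```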

Phase 1: Determine Gradient Step Count

info = self.info
if info.n_nit_sot > 0:
    grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
elif info.n_sit_sot > 0:
    grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif info.n_mit_sot > 0:
    grad_steps = (
        self.outer_mitsot_outs(outs)[0].shape[0] + self.mintaps[info.n_mit_mot]
    )
else:
    grad_steps = inputs[0]

The backward Scan runs for the same number of steps as the forward Scan. The step count is inferred from the output shapes rather than from inputs[0] directly. This is necessary because for while-loop scans, the actual number of steps executed may be less than the allocated buffer size — the output shape reflects the true step count. The - 1 for sit-sot and the + self.mintaps[...] for mit-sot account for the initial state entries that pad the output buffer beyond the number of computed steps.
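The buffer-length arithmetic can be checked numerically; this is a sketch with assumed buffer lengths, not the PyTensor code:

```python
# For a scan of n_steps, the output buffers have these lengths:
n_steps = 10
nit_sot_len = n_steps           # nit-sot: one row per computed step, no padding
sit_sot_len = n_steps + 1       # sit-sot: padded by one initial-state entry
mintap = -2                     # e.g. a mit-sot reading taps (-2, -1)
mit_sot_len = n_steps - mintap  # padded by |mintap| initial entries

# Recovering grad_steps from the shapes, mirroring the three branches above:
assert nit_sot_len == n_steps
assert sit_sot_len - 1 == n_steps
assert mit_sot_len + mintap == n_steps
```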

Phase 2: Classify Inputs and Outputs

diff_inputs = (
    self.inner_seqs(self_inputs)
    + self.inner_mitmot(self_inputs)
    + self.inner_mitsot(self_inputs)
    + self.inner_sitsot(self_inputs)
    + self.inner_non_seqs(self_inputs)
)
diff_outputs = (
    self.inner_mitmot_outs(self_outputs)
    + self.inner_mitsot_outs(self_outputs)
    + self.inner_sitsot_outs(self_outputs)
    + self.inner_nitsot_outs(self_outputs)
)

Collects all differentiable inner inputs and outputs. The ordering matters: sequences first, then states (mit-mot, mit-sot, sit-sot), then non-sequences.

Phase 3: compute_all_gradients — Differentiate the Inner Function

def compute_all_gradients(known_grads):
    # ...
    known_grads = {k.copy(): v for (k, v) in known_grads.items()}
    grads = grad(
        cost=None,
        known_grads=known_grads,
        wrt=wrt,
        consider_constant=wrt,
        disconnected_inputs="ignore",
        return_disconnected="None",
        null_gradients="return",
    )
    # ...
    rval = [gmp.get(p, None) for p in diff_inputs]
    return rval

This differentiates the inner function symbolically using PyTensor's grad() with known_grads. The known_grads dict maps each inner output to a placeholder variable (dC_dXt) representing the output gradient signal. The result is the gradient of each inner output w.r.t. each inner input, weighted by dC_dXt.

The consider_constant=wrt prevents the gradient from flowing further back through the inputs (they'll be handled by the backward Scan's recurrence).

The .copy() on keys is required because grad() with known_grads works by injecting gradient values into the graph at the specified variables. If the same inner output variable y is used for two different scan outputs (e.g., a scan that returns the same expression twice), known_grads would have {y: dC_dXt_0 + dC_dXt_1} after the summing loop. But grad() matches known_grads keys by identity, not by value — it looks for the exact variable object in the computation graph. If we passed the original y, grad() would find it in the graph and attach the gradient. But y.copy() creates a distinct variable that grad() treats as a new node whose output "is" y. This forces grad() to propagate the combined gradient dC_dXt_0 + dC_dXt_1 backward through y's computation, rather than potentially short-circuiting when it encounters y as a wrt target.

Phase 4: Create Gradient Placeholders

dC_dXts = []
Xts = []
for idx, Xt in enumerate(diff_outputs):
    if idx >= n_mit_mot_outs:
        Xt_placeholder = safe_new(Xt)
        Xts.append(Xt_placeholder)
    # ... dtype handling ...
    dC_dXt = safe_new(Xt, dtype=new_dtype)
    dC_dXts.append(dC_dXt)

Creates two sets of placeholder variables:

  • dC_dXts: Gradient of the cost w.r.t. each inner output. These become inputs to the backward Scan (carrying the gradient signal at each time step).
  • Xts: Placeholder copies of the inner outputs (for non-mit-mot outputs). Used later by forced_replace to substitute inner outputs with sequence inputs in the backward Scan.

Phase 5: Build known_grads and Compute Inner Gradients

known_grads = {}
for i in range(len(diff_outputs)):
    if diff_outputs[i] in known_grads:
        known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
    else:
        known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
    dc_dxts_idx += 1

dC_dinps_t = compute_all_gradients(known_grads)

Calls compute_all_gradients to differentiate the inner function. The result dC_dinps_t is a list with one entry per diff_input, containing the symbolic gradient expression (or None if disconnected).

After computing, forced_replace substitutes inner output variables with placeholder variables (Xts):

for Xt, Xt_placeholder in zip(
    diff_outputs[info.n_mit_mot_outs:], Xts, strict=True
):
    tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder)
    dC_dinps_t[dx] = tmp

This substitution is an optimization to avoid recomputing forward outputs inside the backward Scan. An Op's L_op often reuses subexpressions from the forward pass (e.g., exp(x).L_op returns output_gradient * exp(x)). If the L_op reuses the provided output variable itself, a simple graph_replace would correctly substitute the inner input that carries the output computed during the forward pass. But if the L_op builds a new (equivalent) exp(x), graph_replace would miss it, and the value would be recomputed inside the backward Scan. forced_replace instead uses equal_computations to detect when a subexpression in the gradient graph matches an inner output by structure, not just by identity.

This approach is not comprehensive: if the gradient expression computes the same value in a slightly different way (e.g., exp(x + 0) instead of exp(x)), equal_computations won't match and the forward value will be recomputed instead of reused. More generally, it can't reuse results returned by a distinct Scan output (say, if the same scan has x and 2x as outputs, the grad of x could in theory make use of the 2x computed in the forward pass), or intermediate results that are not returned.

This is only done for non-mit-mot outputs (mit-sot, sit-sot, nit-sot). Mit-mot outputs are handled differently because they have multiple output taps and their values are already accessible through the mit-mot state buffer.

Phase 6: Create Accumulation Placeholders and Handle Overwritten Taps

6a: Create dC_dXtm1s — accumulated gradient placeholders

dC_dXtm1s = []
n_internal_recurrent_states = sum(
    len(t)
    for t in chain(
        info.mit_mot_in_slices, info.mit_sot_in_slices, info.sit_sot_in_slices,
    )
)
for pos, x in enumerate(dC_dinps_t[info.n_seqs:]):
    idxs = var_mappings["inner_out_from_inner_inp"][info.n_seqs + pos]
    x_is_state = pos < n_internal_recurrent_states
    if x_is_state and len(idxs) > 0:
        opos = idxs[0]
        dC_dXtm1s.append(safe_new(dC_dXts[opos]))
    else:
        dC_dXtm1s.append(safe_new(x))

For each state input, creates a placeholder for the "accumulated gradient from future steps." This placeholder will become an inner input of the backward Scan, carrying the gradient that flows backward through the recurrence.

6b: Skip accumulation for overwritten buffer positions

In a mit-mot, each buffer position that the inner function reads from may or may not also be written to at the same step. A tap whose index appears in both in_slices and out_slices corresponds to a buffer position that is overwritten — the output replaces the input value at that position. A tap that appears only in in_slices corresponds to a buffer position that is preserved — the value persists and may be read by future steps.

The code calls these "overlapping taps" (named overlapping_taps in the source):

overlapping_taps = set()
dx_offset = 0
for idx in range(info.n_mit_mot):
    in_taps = info.mit_mot_in_slices[idx]
    out_taps = info.mit_mot_out_slices[idx]
    for k, tap in enumerate(in_taps):
        if tap in out_taps:
            overlapping_taps.add(dx_offset + k)
    dx_offset += len(in_taps)

The detection is purely mechanical: does the same integer appear in both in_slices and out_slices? It doesn't matter where the scan came from.

Why do overwritten taps exist? They arise naturally when differentiating sit-sot or mit-sot scans. The L_op of a sit-sot converts it to a mit-mot where one position in the state buffer is both read and written. Concretely, for a sit-sot backward scan with in_taps=(0, 1), out_taps=(1,):

  • Tap 0 reads the accumulated gradient from the previous backward step (initially the cost gradient seed). This position is preserved — it acts as a read-only input.
  • Tap 1 reads the external cost gradient at the next buffer position. This position is overwritten by the output — the new accumulated value replaces the old external gradient.

This is a memory optimization: external gradients and accumulated results share a single buffer instead of using a separate sequence for the externals.

Why skip accumulation for overwritten positions? An overwritten buffer position is analogous to result = set_subtensor(x[i], y). The gradient of set_subtensor w.r.t. x zeroes out position i of the incoming output gradient, and routes the output gradient at position i entirely through y. In our case, y is the inner function's output expression (e.g., accumulated * f'(x) + external).
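That gradient rule can be written out manually in NumPy (no autodiff involved; the values of g and i are illustrative):

```python
import numpy as np

# result = set_subtensor(x[i], y): position i of x is replaced by y.
g = np.array([1.0, 2.0, 3.0])  # incoming output gradient d(cost)/d(result)
i = 1

g_x = g.copy()
g_x[i] = 0.0                   # d(cost)/dx: the old x[i] never reaches the result
g_y = g[i]                     # d(cost)/dy: g[i] routes entirely through y

assert g_x.tolist() == [1.0, 0.0, 3.0]
assert g_y == 2.0
```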

6c: Add accumulation for preserved taps

for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
    if dx in overlapping_taps:
        continue
    # ... NullType handling ...
    else:
        dC_dinps_t[dx + info.n_seqs] += dC_dXtm1

For preserved taps, the value at that buffer position persists to future steps — it is read but never overwritten. The total gradient has two components: the gradient through the current step's output (from compute_all_gradients) plus the gradient from future steps that also read this position (carried by dC_dXtm1). The += dC_dXtm1 adds the latter.

Phase 7: Build Outer Sequence Inputs

outer_inp_seqs = [x[::-1] for x in inputs[1 : 1 + info.n_seqs]]
# ... add reversed forward state outputs for each tap ...
outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
# ... add reversed nit-sot gradient signals ...
outer_inp_seqs = [s_[:grad_steps] for s_ in outer_inp_seqs]

The backward Scan needs the forward pass values as sequence inputs (to evaluate f'(x[t])). All sequences are reversed because the backward Scan processes time in reverse, and truncated to grad_steps length.

The sequences are built in groups, matching the order expected by the backward Scan's inner function (assembled from inner_inp_seqs):

| Group | Source | Slice expression | Purpose |
|---|---|---|---|
| Forward sequences | inputs[1:1+n_seqs] | x[::-1] | External inputs u[t] the forward Scan received |
| Mit-mot/mit-sot state values per tap | outer_mitmot_outs / outer_mitsot_outs | x[offset:-offset][::-1] (tap-dependent) | Forward state values x[t+tap] read by the inner function |
| Sit-sot state values | outer_sitsot_outs | x[:-1][::-1] | Forward sit-sot output values, last element dropped |
| Nit-sot cost gradients | dC_douts for nit-sot outputs | x[::-1] | External gradient signal dC/dout[t] at each time step |
| Forward mit-sot/sit-sot outputs (Xts) | outer_mitsot_outs / outer_sitsot_outs | x[::-1] | Forward output values for forced_replace placeholders (Phase 5) |
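The slice expressions can be checked with NumPy on an assumed 3-step sit-sot buffer:

```python
import numpy as np

# A sit-sot buffer for 3 steps: [initial, x1, x2, x3].
sitsot_out = np.array([0.0, 1.0, 2.0, 3.0])

# x[:-1][::-1]: drop the final state (the inner function reads x[t-1], so the
# last output is never an input), then reverse for backward time order.
seq = sitsot_out[:-1][::-1]
assert seq.tolist() == [2.0, 1.0, 0.0]

# Plain sequence inputs are just reversed: x[::-1].
u = np.array([10.0, 11.0, 12.0])
assert u[::-1].tolist() == [12.0, 11.0, 10.0]
```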

Phase 8: Build Backward Scan's Mit-Mot

This is the most complex part. Every recurrent state in the forward scan becomes a mit-mot in the backward scan. The conversion rule is:

  • Each output tap of the forward state becomes an input tap of the backward mit-mot (negated). This input carries the gradient signal.
  • Each input tap of the forward state becomes an output tap of the backward mit-mot (negated). This output is the gradient to propagate.
  • Each backward output tap is a recurrent state: the backward scan writes the accumulated gradient at that position, and at the next step it needs to read the previously accumulated value. So each backward output position also requires a backward input tap. If that position already has an input tap from the first rule, they share the buffer slot (the overwritten case). Otherwise a new input tap is created.

A tap that appears in both the backward in and out is overwritten (see Phase 6b). All others are preserved.

| Forward type | Forward taps | Backward in | Backward out | Overwritten | Preserved |
|---|---|---|---|---|---|
| sit-sot | in=(-1,) out=(0,) | (0, 1) | (1,) | tap 1 | tap 0 |
| mit-sot | in=(-2,-1) out=(0,) | (0, 2, 1) | (2, 1) | taps 1, 2 | tap 0 |
| mit-mot | in=(0,1) out=(1,) | (-1, 0) | (0, -1) | taps -1, 0 | (none) |
| mit-mot | in=(-1,0) out=(0,-1) | (0, 1) | (1, 0) | taps 0, 1 | (none) |
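The conversion rule in the three bullets above can be written as a small function. This is a sketch of the rule, not PyTensor's implementation:

```python
def backward_mitmot_taps(in_taps, out_taps):
    """Convert a forward state's taps to its backward mit-mot's taps."""
    # Rule 1: each forward OUTPUT tap becomes a backward INPUT tap (negated);
    # it carries the gradient signal.
    b_in = [-t for t in out_taps]
    # Rule 2: each forward INPUT tap becomes a backward OUTPUT tap (negated);
    # it is the gradient to propagate.
    b_out = [-t for t in in_taps]
    # Rule 3: every backward output position also needs an input tap for the
    # previously accumulated gradient; reuse the slot if it already exists
    # (the overwritten case), otherwise add a new one (the preserved case).
    for t in b_out:
        if t not in b_in:
            b_in.append(t)
    return tuple(b_in), tuple(b_out)

# The rows of the table above:
assert backward_mitmot_taps((-1,), (0,)) == ((0, 1), (1,))           # sit-sot
assert backward_mitmot_taps((-2, -1), (0,)) == ((0, 2, 1), (2, 1))   # mit-sot
assert backward_mitmot_taps((0, 1), (1,)) == ((-1, 0), (0, -1))      # mit-mot
assert backward_mitmot_taps((-1, 0), (0, -1)) == ((0, 1), (1, 0))    # mit-mot
```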

8a: Forward mit-mot → backward mit-mot

for idx, taps in enumerate(info.mit_mot_in_slices):
    outer_inp_mitmot.append(dC_douts[idx][::-1])
    # For each OUTPUT tap of the forward mit-mot:
    #   → create an INPUT to the backward scan (gradient signal)
    for mit_mot_out_slice in info.mit_mot_out_slices[idx]:
        inner_inp_mitmot.append(dC_dXts[out_pos])
        mitmot_inp_taps[idx].append(-mit_mot_out_slice)

    # For each INPUT tap of the forward mit-mot:
    #   → create an OUTPUT of the backward scan (gradient to propagate)
    for tap in taps:
        tap = -tap
        if tap not in mitmot_inp_taps[idx]:
            inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs])
        # ... handle overwritten taps with clone_replace ...
        inner_out_mitmot.append(new_inner_out_mitmot)
        mitmot_out_taps[idx].append(tap)

When a negated input tap matches an existing backward input tap (the overwritten case), no new inner input is created — one already exists from the output-tap loop. Instead, the clone_replace block substitutes the dC_dXtm1s placeholder in the gradient expression with the existing inner input variable. This is a consequence of the memory optimization: since both the gradient signal and the accumulated state share the same buffer position, they must share the same inner input variable in the backward Scan's inner function.

8b: Forward mit-sot → backward mit-mot

for idx, taps in enumerate(info.mit_sot_in_slices):
    outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
    mitmot_inp_taps[idx + offset].append(0)  # gradient signal at tap 0
    for tap in taps:
        tap = -tap
        inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs])
        inner_out_mitmot.append(dC_dinps_t[ins_pos])
        mitmot_inp_taps[idx + offset].append(tap)
        mitmot_out_taps[idx].append(tap)

Similar to mit-mot but simpler — the forward output is always at the current step (implicit tap 0), so the backward scan gets a single input tap at 0 for the gradient signal; each negated forward input tap then becomes both a backward input tap (accumulated gradient) and a backward output tap (gradient to propagate).

8c: Forward sit-sot → backward mit-mot

for idx in range(info.n_sit_sot):
    mitmot_inp_taps.append([0, 1])
    mitmot_out_taps.append([1])
    outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
    inner_out_mitmot.append(dC_dinps_t[ins_pos])
    inner_inp_mitmot += [
        dC_dXts[out_pos],       # tap 0: gradient signal
        dC_dXtm1s[ins_pos - info.n_seqs],  # tap 1: accumulated gradient
    ]

A sit-sot (x[t] = f(x[t-1])) becomes a mit-mot with in_taps=(0, 1), out_taps=(1,):

  • Tap 0 (preserved): reads the accumulated gradient from the previous backward step (or initial seed)
  • Tap 1 (overwritten): reads the external cost gradient at the next position, then gets replaced by the new accumulated value
  • Output at tap 1: accumulated * f'(x) + external

The buffer stores both external gradients (initial values) and accumulated results (overwritten) — a memory optimization that avoids allocating a separate sequence.

Phase 9: Handle Remaining Outputs and Disconnected Gradients

n_nit_sot = info.n_seqs
inner_out_nitsot = dC_dinps_t[:info.n_seqs]   # gradients w.r.t. sequences
inner_out_sitsot = dC_dinps_t[ins_pos:]       # gradients w.r.t. non-sequences

  • nit-sot outputs of the backward Scan: gradients w.r.t. the forward Scan's sequence inputs. These are "no input tap" outputs — they don't feed back into the recurrence.
  • sit-sot outputs: gradients w.r.t. non-sequence inputs, accumulated over time steps.

Disconnected and null gradients

Throughout the gradient construction, some inputs may have None or NullType gradients:

  • None (disconnected): the output does not depend on that input at all. compute_all_gradients returns None for these. They are excluded from the backward Scan's inner function and produce a disconnected gradient in the final output.
  • NullType: the output depends on the input, but the gradient is not defined (e.g., integer-typed inputs). These are propagated as-is — the backward Scan does not attempt to differentiate through them. The accumulation step (Phase 6c) skips NullType entries.

When building the backward Scan's inner outputs and the final gradient list, None and NullType entries are filtered out so only truly differentiable paths are wired into the Scan. The Phase 11 re-ordering restores them to the correct positions as disconnected or null markers.

Untraced sit-sot (n_sit_sot_untraced)

The forward Scan may have untraced sit-sot states — recurrent states that are deliberately excluded from gradient computation. These are used for outputs that cannot be or are chosen not to be differentiated through, most commonly shared random number generator (RNG) states.

Untraced sit-sot states are not included in diff_inputs or diff_outputs, so compute_all_gradients never sees them. In the final output (Phase 11), their gradient slots are filled with a special through_untraced marker that signals "this input was intentionally not traced."

Phase 10: Construct and Run the Backward Scan

out_info = ScanInfo(
    n_seqs=len(outer_inp_seqs),
    mit_mot_in_slices=tuple(tuple(v) for v in mitmot_inp_taps),
    mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
    mit_sot_in_slices=(),        # always empty — everything is mit-mot
    sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)),
    n_nit_sot=n_nit_sot,
    # ...
)

local_op = Scan(
    inner_gfn_ins, inner_gfn_outs, out_info,
    name=f"grad_of_{self.name}" if self.name else None,
    # ...
)
outputs = local_op(*outer_inputs, return_list=True)

Creates the backward Scan op with:

  • All forward state types converted to mit-mot
  • Reversed sequences as inputs
  • The inner gradient function as the step function

Phase 11: Re-order Outputs

The backward Scan's outputs are ordered as: mit-mot states, sit-sot states, nit-sot outputs. The L_op must re-order them to match the forward Scan's input ordering: [n_steps, sequences, mit-mot initial, mit-sot initial, sit-sot initial, nit-sot, untraced, non-sequences].

Each output is also reversed (the backward Scan produces values in reverse time order) and classified as connected, disconnected, through_untraced, or null.

Interaction with scan_save_mem

The scan_save_mem optimization shortens output buffers so that only the final k entries are stored (where k is however many the rest of the graph actually uses). For example, if a scan runs 10 steps but only xs[-1] is used, scan_save_mem can reduce the buffer from size 11 to size 2 (a circular buffer holding just the current and previous values).
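The truncation can be sketched in plain Python (assumed recurrence and values; the real rewrite operates on the graph, not on lists):

```python
# Full buffer: all 11 entries for a 10-step recurrence x[t] = x[t-1] * 0.5.
n_steps, x0 = 10, 1.0
full = [x0]
for _ in range(n_steps):
    full.append(full[-1] * 0.5)
assert len(full) == n_steps + 1

# Size-2 circular buffer: only the current and previous value survive.
small = [x0, 0.0]
for t in range(n_steps):
    small[(t + 1) % 2] = small[t % 2] * 0.5

assert small[n_steps % 2] == full[-1]   # the final value is still correct
```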

This optimization would break gradients if applied before grad(). The backward scan needs all intermediate states as sequence inputs (to evaluate f'(x[t]) at each time step). With a truncated buffer, those intermediate values are lost and the gradient is silently wrong — the forward pass still produces the correct final value, but the backward pass cannot reconstruct the full history.

In practice this is safe because scan_save_mem is a compilation-time rewrite. The typical workflow is:

  1. Build the forward graph (full buffers)
  2. Call grad() on the un-optimized graph → L_op sees full buffers and builds the backward scan with all intermediate states wired as sequences
  3. Compile the combined forward+backward graph with FAST_RUN → scan_save_mem runs and can safely truncate the forward scan's buffer within the gradient graph, since the backward scan already has its own copy of the forward outputs as inputs

If someone manually applies scan_save_mem to a graph and then calls grad(), the gradient will be incorrect. This is not a supported workflow.

Concrete Example: Sit-Sot Gradient

For a forward scan x[t] = x[t-1]² with n_steps=2:

Forward:

buf = [x, x², x⁴]

Backward (gradient scan, a mit-mot with in=(0,1), out=(1,)):

buf = [dC/dx₂, dC/dx₁, dC/dx₀]    (initialized with reversed cost gradient)
step 0: buf[1] = buf[0] * 2·x₁ + buf[1]    (overwrite)
step 1: buf[2] = buf[1] * 2·x₀ + buf[2]    (overwrite)

The buffer serves dual purpose:

  • Tap 0 (preserved): carries the accumulated gradient forward through the backward pass
  • Tap 1 (overwritten): initially holds the external cost gradient, gets replaced with the accumulated result

This is equivalent to the SSA form:

b' = a · f'(x₁) + b
c' = b' · f'(x₀) + c

where a = dC/dx₂ (the seed), b = dC/dx₁, c = dC/dx₀ (external cost gradients at each position). The final result c' is the total gradient dC/dx₀.

The overwrite and its gradient: In the buffer form, buf[1] is overwritten at step 0. The original value b is consumed and replaced by b' = a · f'(x₁) + b. When differentiating this scan (for higher-order derivatives), the gradient w.r.t. the initial value b must reflect that b only affects the result through the expression a · f'(x₁) + b. In SSA terms, b appears once (in the first assignment) and the chain rule handles everything naturally. In the buffer form, the overwrite makes this implicit: buf[1] before the overwrite and buf[1] after the overwrite are different values sharing the same slot. The L_op must not add an extra gradient contribution for the old buf[1] as if the original value also passes through unchanged to future steps — it doesn't, because it was overwritten. The only way b affects the result is through b', and the chain rule d(result)/d(b) = d(result)/d(b') · d(b')/d(b) already captures this.
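The buffer updates above can be executed numerically (plain Python; x = 3 is an assumed input) and checked against the analytic gradient d(x⁴)/dx = 4x³:

```python
# Forward: x[t] = x[t-1]**2 for 2 steps → buffer [x, x**2, x**4].
x = 3.0
fwd = [x, x**2, x**4]

# Backward mit-mot, in=(0, 1), out=(1,). The cost is the final state x**4,
# so the reversed cost-gradient seed is [dC/dx2, dC/dx1, dC/dx0] = [1, 0, 0].
buf = [1.0, 0.0, 0.0]
for s in range(2):
    x_t = fwd[1 - s]                             # reversed forward values: x1, x0
    buf[s + 1] = buf[s] * 2 * x_t + buf[s + 1]   # overwrite at tap 1

assert buf[2] == 4 * x**3                        # dC/dx0 = 4x^3 = 108
```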

Concrete Example: Gradient of the Gradient Scan (Mit-Mot L_op)

Consider the gradient scan from the previous example: a mit-mot with in=(0, 1), out=(1,) and inner function output = buf[s] * 2·x_seq[s] + buf[s+1], running for 2 steps.

Forward (gradient scan):

buf = [a, b, c]                                  (initialized with reversed cost gradient)
step 0: buf[1] = buf[0] * 2·x₁ + buf[1]         (reads a, b; overwrites b)
step 1: buf[2] = buf[1] * 2·x₀ + buf[2]         (reads b', c; overwrites c)
  • Tap 0 (preserved): buf[s] is read but never overwritten — carries the accumulated gradient forward
  • Tap 1 (overwritten): buf[s+1] is read and then replaced by the output

Constructing its L_op:

The inner function is output = buf[s] * 2·x_seq[s] + buf[s+1]. The L_op needs the gradient of this output w.r.t. each mit-mot input:

  • d(output)/d(buf[s]) = dC * 2·x_seq[s] — gradient w.r.t. the input at tap 0
  • d(output)/d(buf[s+1]) = dC — gradient w.r.t. the input at tap 1

where dC is the gradient signal placeholder (from known_grads).

Now the preserved/overwritten distinction determines whether to add an accumulation term:

  • Tap 0 is preserved in the forward scan (not in out_taps). The value at this buffer position persists — future steps also read it. So the total gradient is the chain-rule term plus an accumulation placeholder: dC * 2·x_seq[s] + dC_acc.
  • Tap 1 is overwritten in the forward scan (in both in_taps and out_taps). The value is replaced by the output — future steps never see the original. So the gradient is just the chain-rule term: dC. No accumulation.

Note that the chain-rule term for tap 0 has the same structure as the original sit-sot gradient (2·x times a gradient signal). This is expected: the gradient scan's inner function is itself a linear combination involving f'(x), so differentiating it naturally produces f'(x) again.

The resulting backward scan has in=(-1, 0), out=(0, -1) and inner function:

backward_out[tap 0]  = dC * 2·x_seq[s] + dC_acc     (gradient for the preserved input)
backward_out[tap -1] = dC                             (gradient for the overwritten input)

Backward scan inputs and buffers:

The backward scan takes as inputs:

  • Sequences (reversed forward values): x_seq provides the forward states needed to evaluate 2·x_seq[s]. These come from the forward gradient scan's output buffer, reversed.
  • Mit-mot buffer [p, q, r]: initialized with the reversed gradient of the 2nd-derivative cost w.r.t. the gradient scan's output. In a typical scalar 2nd derivative, this is one-hot (e.g., [1, 0, 0] — only the extracted position is nonzero).

The scan reads from taps -1 and 0, and writes to taps 0 and -1. The buffer positions are indexed absolutely (not Python-style wrapping), so a mit-mot with taps reaching -1 has buffer positions starting at -1:

         position:  -1   0   1
buf =             [ p,   q,  r ]      (p, q, r = reversed d(cost)/d(grad_scan_output))
step 0: pos[ 0] = pos[-1] * 2·x₁ + pos[0]
        pos[-1] = pos[-1]
step 1: pos[ 1] = pos[ 0] * 2·x₀ + pos[1]
        pos[ 0] = pos[ 0]

In this backward scan, both taps happen to be overwritten (both -1 and 0 appear in in and out). This means that if we were to differentiate this scan again (for the 3rd derivative), no accumulation would be added for either tap.
