"""TVMScript (TensorIR) sketches of sparse matrix kernels.

Defines five kernels with the legacy ``@tvm.script.tir`` decorator:

* ``csr_spmm``  -- CSR sparse matrix times dense matrix.
* ``csr_sddmm`` -- dense x dense product evaluated only at a list of
  (row, col) coordinates.
* ``bsr_spmm``  -- block-sparse (BSR) matrix times dense matrix.
* ``bsr_sddmm`` -- block-wise sampled dense x dense product.
* ``ell_spmm``  -- blocked-ELL sparse matrix times dense matrix.

Running the module as a script prints each parsed function.

The decorated function bodies are TVMScript, not ordinary executable
Python. They are therefore documented with ``#`` comments only.
NOTE(review): presumably the legacy parser does not accept docstrings
inside a prim_func body -- confirm before adding any there.
"""
import tvm
from tvm import tir
from tvm.script import ty


# CSR SpMM: C = A @ B, where A is an m x k sparse matrix stored as CSR.
#   indptr_  -- (m + 1,) int32 row pointers; the nonzeros of row vi are
#               entries indptr[vi] .. indptr[vi + 1] - 1.
#   indices_ -- (nnz,) int32 column index of each nonzero (used as the
#               row index into B).
#   a_data   -- (nnz,) float32 nonzero values of A.
#   b        -- (k, n) float32 dense right operand.
#   c        -- (m, n) float32 output; each element is zeroed by
#               tir.init() and then accumulated.
@tvm.script.tir
def csr_spmm(indptr_: ty.handle, indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Symbolic int32 extents, tied to the buffer shapes via match_buffer.
    m = tir.var('int32')
    n = tir.var('int32')
    k = tir.var('int32')
    nnz = tir.var('int32')
    indptr = tir.match_buffer(indptr_, [m + 1], 'int32')
    indices = tir.match_buffer(indices_, [nnz], 'int32')
    A = tir.match_buffer(a_data, [nnz], 'float32')
    B = tir.match_buffer(b, [k, n], 'float32')
    C = tir.match_buffer(c, [m, n], 'float32')
    # Outer block: one instance per output element C[vi, vj].
    # NOTE(review): this block declares no reduce axis but has an init
    # (bsr_spmm's outer block does declare one); confirm the TVM version
    # in use accepts this form.
    with tir.block([m, n], 'spmm_outer') as [vi, vj]:
        with tir.init():
            C[vi, vj] = 0.
        # Inner reduction over the nonzeros of row vi. Its bounds are read
        # from indptr, i.e. they are data-dependent.
        with tir.block([tir.reduce_axis(indptr[vi], indptr[vi + 1])], 'spmm_inner') as [vk]:
            C[vi, vj] = C[vi, vj] + A[vk] * B[indices[vk], vj]


# SDDMM over an explicit coordinate list: for each sampled entry eid,
#   C[eid] = sum_vk A[row[eid], vk] * B[vk, col[eid]]
# i.e. the dense product A @ B evaluated only at (row[eid], col[eid]).
#   row_, col_ -- (nnz,) int32 coordinates of the sampled entries. Although
#                 the kernel is named csr_*, the pattern is given as
#                 per-entry (row, col) pairs (a COO-style layout).
#   a          -- (m, k) float32 dense left operand.
#   b          -- (k, n) float32 dense right operand.
#   c          -- (nnz,) float32 output, one value per sampled entry.
# No sparse value buffer is read, so the result is not scaled by any
# existing nonzero values.
@tvm.script.tir
def csr_sddmm(row_: ty.handle, col_: ty.handle, a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    m = tir.var('int32')
    n = tir.var('int32')
    k = tir.var('int32')
    nnz = tir.var('int32')
    row = tir.match_buffer(row_, [nnz,], 'int32')
    col = tir.match_buffer(col_, [nnz,], 'int32')
    A = tir.match_buffer(a, [m, k], 'float32')
    B = tir.match_buffer(b, [k, n], 'float32')
    C = tir.match_buffer(c, [nnz,], 'float32')
    # eid is spatial over the sampled entries; vk is the reduction over
    # the shared inner dimension k.
    with tir.block([nnz, tir.reduce_axis(0, k)], 'sddmm') as [eid, vk]:
        with tir.init():
            C[eid] = 0.
        C[eid] = C[eid] + A[row[eid], vk] * B[vk, col[eid]]


# BSR SpMM: block-sparse A (mb block-rows; block columns index the kb
# blocks of B) times dense B. Blocks are block_size x block_size.
#   indptr_  -- (mb + 1,) int32 block-row pointers; the stored blocks of
#               block-row io are indptr[io] .. indptr[io + 1] - 1.
#   indices_ -- (nnzb,) int32 block-column index of each stored block.
#   a_data   -- (nnzb, block_size, block_size) float32 stored blocks.
#   b        -- (kb, block_size, n) float32 dense operand in blocked form
#               (presumably row kb_idx * block_size + ki of the flat
#               matrix -- confirm with the caller).
#   c        -- (mb, block_size, n) float32 output in the same blocked form.
# Computes
#   C[io, ii, j] = sum_{ko in block-row io} sum_ki A[ko, ii, ki] * B[indices[ko], ki, j]
@tvm.script.tir
def bsr_spmm(indptr_: ty.handle, indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    mb = tir.var('int32')
    n = tir.var('int32')
    kb = tir.var('int32')
    nnzb = tir.var('int32')
    block_size = tir.var('int32')
    indptr = tir.match_buffer(indptr_, [mb + 1], 'int32')
    indices = tir.match_buffer(indices_, [nnzb], 'int32')
    A = tir.match_buffer(a_data, [nnzb, block_size, block_size], 'float32')
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')
    # Outer block: io, ii, j are spatial; ki (the column inside a block)
    # is a reduction axis.
    with tir.block([mb, tir.reduce_axis(0, block_size), block_size, n], 'spmm_outer') as [io, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        # Inner reduction over the stored blocks of block-row io; the
        # bounds are read from indptr (data-dependent).
        with tir.block([tir.reduce_axis(indptr[io], indptr[io + 1])], 'spmm_inner') as [ko]:
            C[io, ii, j] = C[io, ii, j] + A[ko, ii, ki] * B[indices[ko], ki, j]


# Block-wise SDDMM: for each stored block bid at block coordinates
# (row[bid], col[bid]),
#   C[bid, vi, vj] = sum_vk A[row[bid], vi, vk] * B[vk, col[bid], vj]
#   row_, col_ -- (nnzb,) int32 block-row / block-column of each block.
#   a          -- (mb, block_size, k) float32 dense left operand, blocked
#                 along its rows.
#   b          -- (k, nb, block_size) float32 dense right operand, blocked
#                 along its columns.
#   c          -- (nnzb, block_size, block_size) float32 output blocks.
@tvm.script.tir
def bsr_sddmm(row_: ty.handle, col_: ty.handle, a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    mb = tir.var('int32')
    nb = tir.var('int32')
    k = tir.var('int32')
    nnzb = tir.var('int32')
    block_size = tir.var('int32')
    row = tir.match_buffer(row_, [nnzb,], 'int32')
    col = tir.match_buffer(col_, [nnzb,], 'int32')
    A = tir.match_buffer(a, [mb, block_size, k], 'float32')
    B = tir.match_buffer(b, [k, nb, block_size], 'float32')
    C = tir.match_buffer(c, [nnzb, block_size, block_size], 'float32')
    # bid, vi, vj are spatial (block id and position inside the block);
    # vk is the reduction over the shared inner dimension k.
    with tir.block([nnzb, block_size, block_size, tir.reduce_axis(0, k)], 'sddmm') as [bid, vi, vj, vk]:
        with tir.init():
            C[bid, vi, vj] = 0.
        C[bid, vi, vj] = C[bid, vi, vj] + A[row[bid], vi, vk] * B[vk, col[bid], vj]


# Blocked-ELL SpMM: every block-row io stores exactly ell_cols blocks.
#   indices_ -- (mb, ell_cols) int32 block-column index of each stored block.
#   a_data   -- (mb, ell_cols, block_size, block_size) float32 stored blocks.
#   b        -- (kb, block_size, n) float32 dense operand in blocked form.
#   c        -- (mb, block_size, n) float32 output in the same blocked form.
# Computes
#   C[io, ii, j] = sum_ko sum_ki A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
# NOTE(review): every one of the ell_cols slots is read unconditionally, so
# padding slots presumably hold zero blocks with an in-range index --
# confirm with the data producer.
@tvm.script.tir
def ell_spmm(indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    mb = tir.var('int32')
    n = tir.var('int32')
    kb = tir.var('int32')
    block_size = tir.var('int32')
    ell_cols = tir.var('int32')
    indices = tir.match_buffer(indices_, [mb, ell_cols], 'int32')
    A = tir.match_buffer(a_data, [mb, ell_cols, block_size, block_size], 'float32')
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')
    # A single block covers all five iterators: io, ii, j are spatial;
    # ko (stored-block slot) and ki (column inside a block) are reductions.
    with tir.block([mb, tir.reduce_axis(0, ell_cols), tir.reduce_axis(0, block_size), block_size, n], 'spmm') as [io, ko, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        C[io, ii, j] = C[io, ii, j] + A[io, ko, ii, ki] * B[indices[io, ko], ki, j]


if __name__ == '__main__':
    # Print the script form of each parsed kernel.
    print(csr_spmm)
    print(csr_sddmm)
    print(bsr_spmm)
    print(bsr_sddmm)
    print(ell_spmm)