Chapter 2: Tensor Program Abstraction

This chapter of the MLC.ai book is divided into three major parts:

  1. Primitive Tensor Functions & Tensor Program Abstraction
  2. Case Study: Building and Transforming mm_relu
  3. Exercises: Writing and Transforming TensorIR

1. Primitive Tensor Functions & Abstraction

What are Primitive Tensor Functions?

  • Primitive tensor functions are the atomic building blocks of machine learning computation (e.g., add, matmul, relu).
  • Each function maps input tensors to output tensors and is defined by loop-based semantics.
  • They can be written in different ways (two of these forms are contrasted in the sketch after this list):
    • High-level Python/Numpy implementation
    • PyTorch/JAX operator
    • Low-level C/CUDA kernel
  • To unify across these forms, we adopt a common abstraction in the compiler.
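
As a concrete illustration, here is one primitive tensor function (relu) written in two of these forms: a high-level numpy call and a low-level loop version closer to what a C/CUDA kernel spells out. This is a minimal sketch with assumed 1-D float32 input, not an example taken verbatim from the book.

import numpy as np

# High-level form: the whole computation in one library call.
def relu_highlevel(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0)

# Low-level form: the same semantics written as an explicit loop over the buffer.
def relu_lowlevel(x: np.ndarray) -> np.ndarray:
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        out[i] = max(x[i], 0.0)
    return out

x = np.random.randn(128).astype("float32")
np.testing.assert_allclose(relu_highlevel(x), relu_lowlevel(x))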

Tensor Program Abstraction

A tensor program specifies:

  • Buffers: multidimensional arrays holding input/output.
  • Iteration axes: loops describing computation space.
  • Axis types:
    • Spatial: indexes over independent data points.
    • Reduce: indexes where accumulation happens (e.g., sum over k in matmul).
  • Blocks: logical computation regions with explicit reads/writes.

For example, in TensorIR we annotate axes:

for i, j, k in T.grid(M, N, K):
    with T.block("matmul"):
        vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # S=spatial, R=reduce
        with T.init():                               # initialize the reduction result
            C[vi, vj] = T.float32(0.0)
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

This meta-information is essential for safe program transformations such as reordering loops, tiling, or parallelization.

Why Do We Need This Abstraction?

  • Writing kernels by hand for each hardware backend does not scale.
  • Tensor programs give us a hardware-agnostic but optimizable representation.
  • By separating semantics from schedule, we can:
    • Guarantee correctness.
    • Explore optimization choices systematically.
    • Retarget the same operator to CPU, GPU, WebGPU, or NPUs.

2. Case Study: mm_relu

To ground the abstraction, the chapter walks through a case study: building a combined matmul + ReLU operator (mm_relu) in TensorIR.

High-Level Idea

  1. Compute intermediate result Y = A × B.
  2. Apply ReLU: C = max(Y, 0).
  3. Express everything as a single TensorIR module with blocks and axis annotations (a plain numpy reference of steps 1–2 is sketched below).
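
Before writing the TensorIR version, a plain numpy reference makes the two steps explicit. This is a minimal sketch; the 128×128 shapes simply match the example that follows.

import numpy as np

def mm_relu_numpy(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    Y = A @ B                # step 1: matmul intermediate
    return np.maximum(Y, 0)  # step 2: elementwise ReLU

A = np.random.rand(128, 128).astype("float32")
B = np.random.rand(128, 128).astype("float32")
C = mm_relu_numpy(A, B)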

Example TensorIR for mm_relu

# TVMScript imports needed to define the module below.
from tvm.script import ir as I, tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        Y = T.alloc_buffer((128, 128))
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0.0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
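
Assuming TVM with an LLVM backend is available, the module above can be built and checked against the numpy reference. This is a minimal verification sketch, not the book's exact test harness.

import numpy as np
import tvm

rt_lib = tvm.build(Module, target="llvm")

a_np = np.random.rand(128, 128).astype("float32")
b_np = np.random.rand(128, 128).astype("float32")
a = tvm.nd.array(a_np)
b = tvm.nd.array(b_np)
c = tvm.nd.array(np.zeros((128, 128), dtype="float32"))

rt_lib["mm_relu"](a, b, c)
np.testing.assert_allclose(c.numpy(), np.maximum(a_np @ b_np, 0), rtol=1e-5)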

Key features:

  • Block Y: defines matmul with reduction axis k.
  • Block C: applies elementwise ReLU.
  • Axis remap ensures compiler knows which axes are spatial and which are reductions.

Why Blocks Matter

  • Each block explicitly declares read/write regions.
  • This allows dependence analysis, automatic verification, and safe transformation.
  • For example: reordering loops won’t break correctness if dependence rules are satisfied.

Transformations on mm_relu

Once expressed in TensorIR, we can apply transformations:

  • Loop tiling, reordering, fusion.
  • Caching reads into shared/local memory.
  • Mapping loops onto GPU threads and blocks.
  • Tensorization (replace loop patterns with tensor core instructions).

This illustrates the separation of algorithm and schedule: same semantics, many optimized implementations.
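
As a hedged sketch of what such a schedule looks like with tvm.tir.Schedule (the tile factor of 4 is arbitrary, and the exact sequence in the book may differ):

import tvm

sch = tvm.tir.Schedule(Module)                     # wrap the IRModule defined above
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 4])           # tile j (arbitrary factor)
sch.reorder(j0, k, j1)                             # reorder the tiled loop nest
block_C = sch.get_block("C", func_name="mm_relu")
sch.reverse_compute_at(block_C, j0)                # move the ReLU block under the tile loop
sch.decompose_reduction(block_Y, k)                # split out the initialization of Y
sch.mod.show()                                     # inspect the transformed module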


3. Exercises: Writing and Transforming TensorIR

The book provides exercises to practice TensorIR. These highlight how to write primitive operators and how to transform them for performance.

Part 1: Writing TensorIR

  • Element-wise Add:
    Write C[i] = A[i] + B[i] with explicit block annotations.

  • Broadcast Add:
    One tensor has a smaller shape and is broadcast along the missing dimensions. The exercise demonstrates how to handle the axis mapping (a TensorIR sketch follows this list).

  • 2D Convolution:
    Express sliding-window operations with spatial and reduction axes. Convolution is more complex due to stride, padding, and channels.
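
As referenced above, a minimal TensorIR sketch of the broadcast add, with assumed (4, 4) and (4,) float32 shapes (the exercise itself may use different shapes or dtypes):

from tvm.script import ir as I, tir as T

@I.ir_module
class BroadcastAddModule:
    @T.prim_func
    def broadcast_add(A: T.Buffer((4, 4), "float32"),
                      B: T.Buffer((4,), "float32"),
                      C: T.Buffer((4, 4), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                # B is broadcast along the first axis, so only vj indexes it.
                C[vi, vj] = A[vi, vj] + B[vj]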

Part 2: Transforming TensorIR

  • Batch MatMul + ReLU (bmm_relu):
    Extend the matmul example to batched input, then apply an elementwise ReLU afterwards.

  • Apply transformations:
    • Tile and reorder loops.
    • Cache reads/writes into shared memory.
    • Parallelize and vectorize.
    • Compare runtime with naive baseline.
  • Measure performance:
    Build and run the transformed modules and evaluate the speedup against the baseline (a timing sketch follows this list). This ties the theory back to practical outcomes.
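
A hedged sketch of the measurement step, assuming sch holds a transformed schedule for a bmm_relu prim_func with a batch shape of (16, 128, 128); the names and shapes are illustrative, not fixed by the exercise:

import numpy as np
import tvm

a_np = np.random.rand(16, 128, 128).astype("float32")
b_np = np.random.rand(16, 128, 128).astype("float32")
a = tvm.nd.array(a_np)
b = tvm.nd.array(b_np)
c = tvm.nd.array(np.zeros((16, 128, 128), dtype="float32"))

rt_lib = tvm.build(sch.mod, target="llvm")          # sch: the transformed tir.Schedule
rt_lib["bmm_relu"](a, b, c)
np.testing.assert_allclose(c.numpy(), np.maximum(a_np @ b_np, 0), rtol=1e-5)

timer = rt_lib.time_evaluator("bmm_relu", tvm.cpu(), number=10)
print("mean time: %.3f ms" % (timer(a, b, c).mean * 1e3))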

4. Cross-Cutting Themes

  • Expressiveness vs Optimizability: TensorIR balances clarity of computation with flexibility for optimization.
  • Algorithm vs Schedule: first write correct semantics, then tune schedule for performance.
  • Blocks and Dependencies: explicit read/write sets enable safe transformations.
  • Iterative Cycle: write → transform → measure → refine.

5. Key Takeaways

  • Tensor program abstraction provides a unified foundation for compiling ML operators.
  • Primitive functions (matmul, relu, conv) can be represented in a general loop/block form.
  • TensorIR makes transformations explicit and safe.
  • Case studies (e.g., mm_relu) show how high-level computations lower to structured IR.
  • Exercises demonstrate how to write and optimize real operators.
  • Ultimately, this abstraction enables systematic optimization across heterogeneous backends.

Summary

Chapter 2 introduces tensor programs as the bridge between ML semantics and hardware execution. By writing operators once in TensorIR, we can explore schedules, optimizations, and transformations without changing semantics. This foundation underpins the rest of the MLC.ai book and sets the stage for advanced compilation topics such as scheduling search, tensorization, and heterogeneous deployment.
