Learn by doing: TorchInductor (DeviceAssert)

I recently contributed a PR to PyTorch that fixes how assert behaves inside torch.compile. While working on it, I realized the issue touches many components of TorchInductor: decompositions, lowering, the IR, op handlers, the scheduler, and codegen. That gave me the idea to write a 'learn-by-doing' series, where we solve real issues and use them to understand the PyTorch compiler. This post is where it begins.

My goal with this blog post is to show how we actually work through a problem: build a mental model, try different approaches, debug, and finally solve it. As we go, I'll explain each component involved (to the best of my knowledge). Before getting into the issue, I'll first give a bird's-eye view of torch.compile.

High-level view of torch.compile:


The main difference between eager mode and compile mode is that in eager mode PyTorch just runs the program line by line. That's simple and good for debugging, but it doesn't leave much room for optimizations like advanced kernel fusion.

With compile mode (torch.compile), the idea is different. TorchDynamo first takes your Python program and captures it as a Torch FX graph for the forward pass. Then AOTAutograd builds the backward pass ahead of time. After that, the graph is lowered into ATen IR. Finally, TorchInductor takes that FX/ATen IR and generates optimized kernels (C++ for CPU or Triton for GPU).

Note: this is just a very high-level view of torch.compile. If you want more in-depth details, I recommend starting with Jason's blog post, the PT2 paper, and the official PyTorch docs.

The Problem

Here's a simple Python repro with an assert statement. When you run the code below, should not run gets printed, but the program should instead raise an AssertionError with the message should throw, because a > 0 is false for one element. In eager mode it throws. That means correctness checks silently disappear under torch.compile. We need to fix that.

import torch

def f():
    a = torch.tensor([1.0, -2.0], device="cuda")
    result = torch.all(a > 0)
    assert result, "should throw"
    print("should not run")

f_c = torch.compile(f)
f_c()

Where do we start? TorchInductor consumes FX/ATen graphs, so the first step is to check what got captured and what got generated. In general, run the code above with TORCH_TRACE="./logs"; this generates trace logs that we can then inspect with tlparse. I highly recommend this workflow. We can also use TORCH_COMPILE_DEBUG=1 or TORCH_LOGS.
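For example, the same logging artifacts can be switched on from Python instead of the environment, via torch._logging.set_logs (a minimal sketch; artifact names such as graph_code and output_code are the ones I'm aware of and may differ across versions):

import torch
import torch._logging

# Roughly the Python-side equivalent of TORCH_LOGS="graph_code,output_code":
# print the captured FX graph code and the final generated kernels.
torch._logging.set_logs(graph_code=True, output_code=True)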

This is the output graph from TorchDynamo:

class GraphModule(torch.nn.Module):
    def forward(self):
        # code: a = torch.tensor([1.0, -2.0], device="cuda")
        a: "f32[2][1]cuda:0" = torch.tensor([1.0, -2.0], device = "cuda")

        # code: result = torch.all(a > 0)
        gt: "b8[2][1]cuda:0" = a > 0; a_0 = None
        result: "b8[1]cuda:0" = torch.all(gt); gt = None

        # code: assert result, "should throw"
        __assert_async = torch._assert_async(result, 'should throw'); result = None
        return ()

In the FX graph, the assert is translated to aten._assert_async.msg, yet the generated kernel still prints should not run. Why? After capture, decompositions replace higher-level ops with simpler/core ops (see more examples in this file). For _assert_async, the registered rule just returns None: no value, no side effect, so the assert becomes a no-op. By the time the graph reaches Inductor, there is nothing to lower or codegen for the assert, and the scheduler's Dead Code Elimination (DCE) removes it (no users + no side effects). That's why the final kernel has no assert.

@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return


# Following `assert_async_msg_decomp` and implement as non-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return

The fix starts by removing those decomposition registrations for _assert_async. We don't want to decompose the assert into nothing; we want to lower it into a real op that has side effects.
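The PR itself simply deletes those two registrations. Purely as an illustration of the same idea, the helper that torch/_inductor/decomposition.py already uses for its exclusion list can strip them from the table (a sketch, not the PR's actual diff):

import torch
from torch._decomp import remove_decompositions
from torch._inductor.decomposition import decompositions

aten = torch.ops.aten

# Drop the no-op decomps so both assert ops survive until Inductor's lowering.
remove_decompositions(
    decompositions,
    [aten._assert_async.msg, aten._functional_assert_async.msg],
)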

How do we lower it? Graph lowering lives in lowering.py, and we register lowerings with the @register_lowering decorator. Since we already disabled the decompositions, _assert_async now actually reaches lowering. We need to map aten._assert_async.msg and aten._functional_assert_async.msg to an internal Inductor op, device_assert_async (which we still need to implement). In general, graph lowering converts FX/ATen into Inductor IR (I'll explain what Inductor IR looks like later). In lowering.py I add two registrations like this:

@register_lowering(aten._assert_async.msg)
def lower_assert_async(cond, msg):
    # TODO: need to implement
    pass


@register_lowering(aten._functional_assert_async.msg)
def lower_assert_functional_async(cond, msg):
    # TODO: need to implement
    pass

If we add breakpoint() inside lower_assert_async, we can see the inputs cond and msg:

(Pdb) cond
TensorBox(StorageBox(
  Pointwise(
    'cuda',
    torch.bool,
    def inner_fn(index):
        tmp0 = ops.index_expr(0, torch.int64)
        tmp1 = ops.constant(1.0, torch.int64)
        tmp2 = tmp0 < tmp1
        tmp3 = ops.constant(1.0, torch.float32)
        tmp4 = ops.constant(-2.0, torch.float32)
        tmp5 = ops.where(tmp2, tmp3, tmp4)
        tmp6 = ops.constant(1.0, torch.float32)
        tmp7 = tmp5 > tmp6
        tmp8 = tmp5 < tmp6
        tmp9 = tmp7 | tmp8
        tmp10 = ops.index_expr(1, torch.int64)
        tmp11 = ops.constant(1.0, torch.float32)
        tmp12 = ops.constant(-2.0, torch.float32)
        tmp13 = ops.where(tmp11, tmp12, tmp3)
        tmp14 = ops.constant(0.0, torch.float32)
        tmp15 = tmp13 != tmp14
        tmp16 = ops.logical_not(tmp15)
        tmp17 = tmp14 != tmp15
        tmp18 = ops.logical_or(tmp16, tmp17)
        return tmp19
    ,
    ranges=[]
    origin_node=logical_not_1,
    origins=OrderedSet([logical_not_1, any_1, logical_not, gt_1, li…,
    stack_traces = {
    File "test_device_assert.py", line 7, in func,
        result = torch.all(a > 0),
    },
    stack_traces = {
    File "test_device_assert.py", line 6, in func,
        a = torch.tensor([1.0, -2.0], device="cuda"),
    }
  )
))

(Pdb) msg
'should throw'

cond arrives as a TensorBox and msg is a plain Python string. In Inductor terms, TensorBox is the value wrapper; StorageBox means the value is buffer-backed (i.e. the TensorBox currently has storage); Pointwise is the elementwise compute node whose inner_fn(index) Inductor will tile/vectorize when it builds the kernel. Because this cond comes from torch.all(a > 0), it's a 0-D boolean on 'cuda' with dtype=torch.bool, so the iteration ranges are empty.

For more details on Inductor IR, please refer to this section in ir.py.
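To get a feel for the wrapper, the accessors we will rely on later in the lowering already tell us most of what we need. From the same breakpoint (outputs reproduced from memory, so treat them as approximate):

(Pdb) cond.get_device()
device(type='cuda', index=0)
(Pdb) cond.get_dtype()
torch.bool
(Pdb) cond.get_size()
[]    # 0-D, which is why ranges=[] above
(Pdb) loader = cond.make_loader()  # maps an index to an ops.* expression for that element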

How to create a device assert op?

There are a couple of ways to add an op in Inductor. One is to subclass ExternKernel, which is implemented in torch/_inductor/ir.py:

class DeviceAssert(ExternKernel):

    def __init__(self, cond: TensorBox, msg: str, device: torch.device) -> None:
        cond.realize()
        super().__init__(None, NoneLayout(device=device), [cond])
        self.msg = msg

    def has_side_effects(self) -> bool:
        return True

    def is_fusible(self) -> bool:
        return False

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        # generate Triton code for GPU and C++ code for CPU

Let me first explain what it does. This class extends ExternKernel (which creates an external implementation of an op). It overrides has_side_effects() so the scheduler's DCE won't drop it (more on that later). It takes the condition and the message, and in codegen() it generates the backend code: C++ for CPU (throw on failure) and Triton for GPU (a device-side assert).
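Had we stayed on this route, the lowering registrations from earlier would simply construct this node (a hypothetical sketch, not the code that landed in the PR):

@register_lowering(aten._assert_async.msg)
def lower_assert_async(cond, msg):
    # Build the extern-kernel IR node; has_side_effects() keeps it alive
    # through DCE, and codegen() later emits the backend-specific check.
    DeviceAssert(cond, msg, cond.get_device())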

If you thought that solved the problem, so did I initially. I implemented DeviceAssert as an ExternKernel (see here), it worked, and the assert got triggered. But extern kernels don't fuse with other ops, so the assert sits in its own kernel. We want the assert to participate in fusion, so the extern-kernel route isn't the right fit. Let's move to the other method: the ops handler.

Another way is the OpsHandler. OpsHandler is Inductor's internal op interface: the protocol for the ops we can call via torch._inductor.virtualized.ops during lowering. For more details, see torch/_inductor/ops_handler.py. Here, I add device_assert_async(cond, msg) to the OpsHandler protocol, implement it in KernelFormatterHandler (so it prints correctly in generated code), and implement dtype/shape propagation so the scheduler knows the op's output dtype and shape. Backend overrides then map it to real code: a C++ throw on CPU via CppOverrides, tl.device_assert on GPU via TritonOverrides.

# torch/_inductor/ops_handler.py
class OpsHandler(Generic[T]):
    def device_assert_async(self, cond: T, msg: str) -> T:
        raise NotImplementedError

# torch/_inductor/ops_handler.py
class KernelFormatterHandler(DefaultHandler):
    def device_assert_async(self, cond, msg: str):
        return f"ops.device_assert_async({cond}, {msg})"

# torch/_inductor/dtype_propagation.py
class DtypePropagationOpsHandler:
    @staticmethod
    def device_assert_async(cond, msg: str) -> torch.dtype:
        return torch.bool

# torch/_inductor/shape_propagation.py
class ShapePropagationOpsHandler:
    @staticmethod
    def device_assert_async(cond: ShapeArg, msg: str) -> None:
        return None

# torch/_inductor/codegen/triton.py
class TritonOverrides(OpOverrides):
    @staticmethod
    def device_assert_async(cond, msg):
        return f"tl.device_assert({cond}, {repr(msg)})"

# torch/_inductor/codegen/cpp.py
class CppOverrides(OpOverrides):
    @staticmethod
    def device_assert_async(cond, msg):
        return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))'

The next step is adding the device_assert_async op in lowering, i.e. converting the aten._assert_async node into Inductor IR. First, realize cond (a TensorBox) by calling cond.realize(); this materializes storage if it's still lazy (a no-op if it's already realized), ensuring cond is concretely computed and available in memory so the device assert can evaluate it. Then wrap it in a Pointwise: inside inner_fn(index), load the predicate with cond.make_loader()(index) and call ops.device_assert_async(pred, msg). Build the node with Pointwise.create() over cond's iteration space (ranges = list(cond.get_size())). Finally, realize the assertion node to pin its placement in scheduling (we have to realize it explicitly, because the output has no consumers, so it would never get implicitly realized).


def _assert_async(cond, msg):
    cond.realize()
    cond = to_dtype(cond, torch.bool)

    def inner_fn(index):
        return ops.device_assert_async(cond.make_loader()(index), msg)

    assertion_op = Pointwise.create(
        device=cond.get_device(),
        dtype=cond.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(cond.get_size()),
    )
    assertion_op.realize()

@register_lowering(aten._assert_async.msg)
def lower_assert_async(cond, msg):
     return _assert_async(cond, msg)


@register_lowering(aten._functional_assert_async.msg)
def lower_assert_functional_async(cond, msg):
     return _assert_async(cond, msg)

Okay, let's run it again: still no device_assert in the kernel. Why? With ExternKernel we set has_side_effects() to True, so DCE couldn't drop it; in the ops-handler path we didn't, so the scheduler prunes it (no users, so it's gone). How do we mitigate this? The idea is to mark device_assert_async as side-effecting. During scheduling, if self._body.has_op("device_assert_async"), then has_side_effects() returns True so DCE won't delete it.

class SchedulerNode(BaseSchedulerNode):

    @cache_on_self
    def has_side_effects(self) -> bool:
        if self._body is not None and self._body.has_op("device_assert_async"):
            return True
        return super().has_side_effects()

Now when we run the code, we get tl.device_assert in the generated Triton kernel:

@triton.jit
def triton_poi_fused_assert_async_all_gt_lift_fresh_0(xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)

    tmp0 = tl.full([1], 0, tl.int64)
    tmp1 = tl.full([1], 1, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = 1.0
    tmp4 = -2.0
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = 0.0
    tmp7 = tmp5 > tmp6
    tmp8 = tmp7 == 0
    tmp9 = tmp1 < tmp1
    tmp10 = tl.where(tmp9, tmp3, tmp4)
    tmp11 = tmp10 > tmp6
    tmp12 = tmp11 == 0
    tmp13 = tmp8 | tmp12
    tmp14 = tmp13 == 0
    tmp15 = tl.device_assert(tmp14, 'should throw')
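With that in place, re-running the repro from the top of the post no longer prints should not run; the kernel itself trips the assert (the exact error text depends on the backend and on when the device synchronizes):

f_c = torch.compile(f)
f_c()
# On CUDA this typically surfaces as
#   RuntimeError: CUDA error: device-side assert triggered
# while on CPU the generated C++ throws std::runtime_error("should throw").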

Summary


We fixed the torch.compile case where the assert just vanished: in eager mode it throws, but under torch.compile it printed should not run. The FX graph still showed aten._assert_async.msg, but the registered decomposition (applied during AOT tracing) turned it into a no-op, so by the time it hit Inductor there was nothing to lower and DCE dropped the stub. We removed that decomposition for this path, lowered the op to a real in-kernel op, device_assert_async, via the ops handler, implemented dtype/shape propagation, and taught the scheduler that any _body containing device_assert_async is side-effecting so it can't be pruned. Codegen then does the obvious: a C++ throw on CPU, tl.device_assert on GPU. Now the assert survives torch.compile and fails at runtime like it should.

Credits

You can find more details in PR #160677. The solution is based on Elias's suggestion in this comment. I want to thank Michael for his guidance, and finally I want to thank all my reviewers!

Hope this was helpful! If you have any suggestions, please drop a comment. Thanks for reading!

References

  1. https://github.com/pytorch/pytorch

  2. https://docs.pytorch.org/docs/stable/index.html

  3. https://pytorch.org/blog/pytorch-pytorch-2-paper-tutorial/

  4. https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747

  5. https://www.youtube.com/watch?v=egZB5Uxki0I