Learn by Doing: TorchInductor Pattern Matcher
This is my third post in the “Learn by Doing: TorchInductor” series. The goal is to understand torch.compile from first principles, by working through examples.
This post is about the pattern matcher. In Inductor, this is the pass that rewrites FX graphs: it looks for a specific subgraph (the pattern) and swaps it with a replacement subgraph. vLLM uses it for fusions such as RMSNorm + quant, QK norm + RoPE and more. If you want to understand vLLM and how torch.compile delivers fusion optimizations, this post should help.
There are a few ways to register patterns in Inductor. The lower-level APIs (register_graph_pattern, register_lowering_pattern) build patterns explicitly with CallFunction. The higher-level API (register_replacement) lets you write the pattern and replacement as plain PyTorch code, and Inductor traces them into FX graphs. vLLM uses this style extensively because it usually keeps patterns readable and close to the original code.
I’ll focus on register_replacement: walk through an example, then explain how the pattern is registered, how matching works internally, and why auto_functionalized is required for mutating custom ops.
torch.compile captures your Python function into FX, and Inductor runs a series of graph passes before codegen. Pattern matching happens in multiple phases (pre-grad/joint and post-grad), but most Inductor fusion patterns you’ll encounter are applied in post-grad. You can see it in torch/_inductor/fx_passes/post_grad.py:
for i, patterns in enumerate(pass_patterns):
GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass(
patterns.apply
)
In my last post on reduction kernels, the softmax rewrite happened in exactly this stage.
To hook in your own pass, you set a config key inside the custom backend and pass it into compile_fx:
current_config["post_grad_custom_post_pass"] = custom_pass
return compile_fx(graph, example_inputs, config_patches=current_config)
Inside post_grad_passes, Inductor checks that key and runs the pass:
if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass(
post_grad_custom_post_pass
)
Inductor calls your custom pass on the same FX graph and exposes other hooks in torch/_inductor/config.py (post_grad_custom_pre_pass, joint_custom_pre_pass, joint_custom_post_pass, pre_grad_custom_pass, and custom_partitioner_fn). I use post_grad_custom_post_pass here because it runs immediately after the built-in post-grad patterns.
Before we jump into the full example, it helps to know the pieces we’ll use:
- PatternMatcherPass: a container that holds many patterns and can apply them to an FX graph.
- register_replacement: traces a pattern function, traces a replacement function, and wires them into a matcher pass.
- fwd_only: the tracing mode we use for forward-only patterns (good enough here).
- auto_functionalized: a wrapper that makes mutating custom ops behave functionally in FX (so the graph has real data dependencies).
With those in mind, the example below is just: build a pattern, build a replacement, register them, and run the pass.
Example: RMSNorm + int8 quant
Below is a vLLM-style example. If you understand this one, you can read most of vLLM’s pattern passes and write your own fusions. The ops are placeholders here; in real code they would be backed by CUDA or Triton kernels.
The key call is register_replacement:
register_replacement(
rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns
)
It takes a pattern function, a replacement function, and example inputs for tracing.
# https://github.com/pytorch/pytorch/blob/main/test/inductor/test_pattern_matcher.py
import torch
from typing import Optional
from collections.abc import Callable
from torch._inductor.utils import run_and_get_code
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import fwd_only, PatternMatcherPass, register_replacement
# fused rms_norm + quant custom op (mutates result + scale)
@torch.library.custom_op(
"vllm::fused_rms_norm_quant_static", mutates_args=["result", "scale"]
)
def fused_rms_norm_quant_static(
result: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
azp: torch.Tensor,
epsilon: float,
) -> None:
result_rms = torch.mul(input, weight) + epsilon
_result = torch.mul(result_rms, scale).to(torch.int8)
scale.fill_(0.5)
# rms norm custom op (mutates result)
@torch.library.custom_op("vllm::rms_norm", mutates_args=["result"])
def rms_norm(
result: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> None:
_result = torch.mul(input, weight) + epsilon
# static int8 quant custom op (mutates result + scale)
@torch.library.custom_op(
"vllm::static_scaled_int8_quant", mutates_args=["result", "scale"]
)
def static_scaled_int8_quant(
result: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
azp: Optional[torch.Tensor] = None,
) -> None:
_result = torch.mul(input, scale).to(torch.int8)
scale.fill_(0.5)
# pattern: rms_norm then quant
def rms_pattern_static(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
torch.ops.vllm.rms_norm.default,
result=result_rms,
input=input,
weight=weight,
epsilon=1e-6,
)
at2 = auto_functionalized(
torch.ops.vllm.static_scaled_int8_quant.default,
result=result,
input=at1[1],
scale=scale,
azp=None,
)
return at2[1], at2[2]
# replacement: fused op
def rms_replacement_static(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
torch.ops.vllm.fused_rms_norm_quant_static.default,
result=result,
input=input,
weight=weight,
epsilon=1e-6,
scale=scale,
azp=None,
)
return at[1], at[2]
my_patterns = PatternMatcherPass()
inputs = [
torch.empty(5, 4, dtype=torch.int8, device="cuda"),
torch.empty(5, 4, dtype=torch.bfloat16, device="cuda"),
torch.empty(5, 4, dtype=torch.bfloat16, device="cuda"),
torch.empty(5, 1, dtype=torch.bfloat16, device="cuda"),
torch.empty(1, 1, device="cuda"),
]
register_replacement(
rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns
)
def custom_pass(graph: torch.fx.Graph) -> None:
print("=== before ===")
print(graph)
my_patterns.apply(graph)
graph.eliminate_dead_code()
print("=== after ===")
print(graph)
def custom_backend(
graph: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> Callable:
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx
current_config = config.shallow_copy_dict()
current_config["post_grad_custom_post_pass"] = custom_pass
return compile_fx(graph, example_inputs, config_patches=current_config)
@torch.compile(backend=custom_backend)
def fn(x, w, epsilon):
quant_result = torch.empty_like(x, dtype=torch.int8, device="cuda")
result_rms = torch.empty_like(x, dtype=torch.bfloat16, device="cuda")
scale = torch.ones((1, 1), device="cuda")
x = x.to(torch.bfloat16)
w = w.to(torch.bfloat16)
quant_result, scale = rms_pattern_static(
result=quant_result,
result_rms=result_rms,
input=x,
weight=w,
scale=scale,
)
return quant_result, scale
inputs = [torch.empty((5, 4), device="cuda"), torch.empty((5, 1), device="cuda"), 1e-6]
_, (code,) = run_and_get_code(fn, *inputs)
The example defines three custom ops, builds a two-op pattern (rms_norm then static_scaled_int8_quant), defines a fused replacement op, registers the replacement with example inputs, and then compiles a toy function with a custom backend that runs the pattern pass and prints the FX graph before and after the rewrite.
What register_replacement actually does
It traces the pattern function with fwd_only into a small FX graph, converts that graph into a PatternExpr, stores the replacement function for later tracing, and then, during apply, tries to match candidate nodes in the main FX graph against the pattern's output node before rewriting them.
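If you want to look at that intermediate FX graph yourself, you can run the same trace by hand: fwd_only takes a function plus a list of example inputs and returns the GraphModule that register_replacement starts from. A small inspection sketch, reusing rms_pattern_static from the example above (pattern_inputs is just a local name here):
pattern_inputs = [
    torch.empty(5, 4, dtype=torch.int8, device="cuda"),
    torch.empty(5, 4, dtype=torch.bfloat16, device="cuda"),
    torch.empty(5, 4, dtype=torch.bfloat16, device="cuda"),
    torch.empty(5, 1, dtype=torch.bfloat16, device="cuda"),
    torch.empty(1, 1, device="cuda"),
]
# Trace the pattern function the same way register_replacement does internally.
pattern_gm = fwd_only(rms_pattern_static, pattern_inputs)
# Expect two auto_functionalized calls plus the getitem nodes.
print(pattern_gm.graph)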
A few constraints to keep in mind:
- The pattern and replacement must be traceable.
- In its initial trace, register_replacement builds the pattern with ignore_types=(int, float, list, device, dtype), so scalar constants (like epsilon) are not matched strictly.
- The example inputs are only used to build the initial pattern graph; a later recheck uses concrete shape metadata from the main graph.
How the pattern is constructed
The FX graph produced by tracing is converted into a PatternExpr tree by fx_to_pattern inside torch/_inductor/pattern_matcher.py.
Useful mental model:
- Placeholders become KeywordArg. This is how the matcher captures inputs by name. Even if the original function uses positional args, the placeholder name still becomes the keyword key in the match object.
- Arg() and KeywordArg() match anything. They're the leaves of the pattern, and they bind the matched node into the resulting Match.
- Call nodes match by (node.op, node.target). If the target doesn't match exactly, the match fails.
- Ignored() exists so you can ignore certain sub-args (and Inductor also uses it when it decides to ignore constants during pattern construction). There's special-casing around getitem so indices don't get "ignored away".
That’s why structure is strict: the tree has to line up in op type + target + argument structure.
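To make that concrete, here is a hand-written approximation of the PatternExpr tree that fx_to_pattern would build for our RMSNorm + quant pattern. Treat it as a sketch: the real generated pattern also wraps the two outputs in a MultiOutputPattern and carries user-count metadata omitted here, and it assumes the vllm custom ops from the example are already registered.
import operator
import torch
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

# rms_norm wrapped in auto_functionalized: placeholders become KeywordArg,
# and the epsilon literal is dropped because float is in ignore_types.
rms = CallFunction(
    torch.ops.higher_order.auto_functionalized,
    torch.ops.vllm.rms_norm.default,
    result=KeywordArg("result_rms"),
    input=KeywordArg("input"),
    weight=KeywordArg("weight"),
    epsilon=Ignored(),
)
# getitem indices are matched exactly (they are never "ignored away").
rms_out = CallFunction(operator.getitem, rms, 1)
# static_scaled_int8_quant consumes the rms_norm output; azp=None is a literal
# that must match exactly.
quant = CallFunction(
    torch.ops.higher_order.auto_functionalized,
    torch.ops.vllm.static_scaled_int8_quant.default,
    result=KeywordArg("result"),
    input=rms_out,
    scale=KeywordArg("scale"),
    azp=None,
)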
Walkthrough of the example pattern
- rms_pattern_static contains two auto_functionalized calls. That's the exact subgraph the matcher will look for.
- rms_replacement_static contains one auto_functionalized call to the fused op. That's what gets inserted.
- The inputs list matches the pattern signature: (result, result_rms, input, weight, scale). These are only for tracing the pattern; they aren't runtime inputs to your model.
The key requirement is that the pattern graph and the real graph align structurally: same targets and argument structure. If those differ (different overload, different kwarg names, missing a getitem, etc.), the pattern won’t fire.
How PatternMatcherPass.apply finds and applies matches
The matcher is a recursive PatternExpr matcher over FX nodes. PatternMatcherPass.apply iterates candidate nodes grouped by (node.op, node.target) and calls pattern.match(node). The match checks op and target, then recurses through the flattened argument structure. Each child either matches another PatternExpr or (for non-ignored literals) must match an exact constant. The matcher also enforces reused-node and multi-user constraints; multi-output patterns can start from anchor nodes already matched in the context.
PatternMatcherPass.apply also supports an extra_check provided via register_replacement; the check can reject candidate matches. Once a match succeeds, the replacement function is traced with the matched inputs, the new subgraph is inserted, and the old nodes are erased.
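As an example of extra_check, here is a sketch that only allows the rewrite when the quant output really is int8. The only_int8 helper and its condition are made up for illustration; what is real is that register_replacement accepts an extra_check callable, that it receives the Match, and that returning False rejects the candidate.
def only_int8(match):
    # KeywordArg("result") from the pattern is bound by name in match.kwargs;
    # in post-grad graphs a node's meta["val"] is a FakeTensor carrying dtype info.
    result_node = match.kwargs["result"]
    return result_node.meta["val"].dtype == torch.int8

register_replacement(
    rms_pattern_static,
    rms_replacement_static,
    inputs,  # the tracing example inputs from the registration above
    fwd_only,
    my_patterns,
    extra_check=only_int8,
)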
Why auto_functionalized is required here
Consider this mutable-op sequence:
foo_inplace(x) # Mutates tensor x
bar_out(x, out)      # Uses mutated x, produces out
Traced naively into FX, the same sequence looks like this:
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
There’s no explicit edge from foo_inplace to bar_out even though bar_out depends on the mutation. Without explicit edges, graph transforms (and pattern matching!) can’t reliably reason about ordering and dependencies.
In our example the custom ops mutate outputs (result, scale). FX graphs are functional and need explicit data dependencies. If you call a mutating custom op directly, the graph doesn't encode the mutation as real value flow, and optimizations can reorder things. auto_functionalized fixes this by turning mutation into a functional return: it clones the mutated arguments, calls the op on the clones, and returns a tuple:
(original_return, mutated_arg1, mutated_arg2, ...)
That’s why the pattern uses auto_functionalized(...) and then indexes at[1], at[2] to recover the mutated outputs.
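As a rough sketch of those semantics (not the real higher-order op, which reads the mutated argument names from the op schema and also handles views, optional tensors, and the newer _v2 variant), you can picture it like this; mutated_arg_names is a made-up parameter standing in for what the schema provides:
def auto_functionalized_sketch(op, mutated_arg_names, **kwargs):
    # Clone every argument the op mutates so the caller's tensors stay untouched.
    clones = {name: kwargs[name].clone() for name in mutated_arg_names}
    # Run the op against the clones; the op writes its results into them.
    ret = op(**{**kwargs, **clones})
    # Return the clones as ordinary values: (original_return, mutated_arg1, ...).
    return (ret, *clones.values())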
One more Inductor detail: later in post-grad, the compiler decomposes auto_functionalized / auto_functionalized_v2 back into explicit clones + the underlying mutation op. This happens in decompose_auto_functionalized inside torch/_inductor/fx_passes/post_grad.py, and it’s implemented using the same pattern matcher mechanism (it registers patterns on torch.ops.higher_order.auto_functionalized and _v2, then replaces them with auto_functionalized_dense / auto_functionalized_v2_dense). This pass runs after reinplace_inplaceable_ops and asserts that no auto_functionalized nodes remain afterwards.
What the FX graphs show
Here is the before/after from the example. The key change is that two auto_functionalized calls (RMSNorm + quant) become a single fused auto_functionalized call.
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%empty_1 : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([5, 4],), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda, pin_memory: False})
%convert_element_type : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.bfloat16), kwargs = {})
%convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.bfloat16), kwargs = {})
%auto_functionalized : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized](args = (vllm.rms_norm.default,), kwargs = {result: %empty_1, input: %convert_element_type, weight: %convert_element_type_1, epsilon: 1e-06})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized, 1), kwargs = {})
%empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([5, 4],), kwargs = {dtype: torch.int8, layout: torch.strided, device: cuda, pin_memory: False})
%full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([1, 1], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
%auto_functionalized_1 : [num_users=2] = call_function[target=torch.ops.higher_order.auto_functionalized](args = (vllm.static_scaled_int8_quant.default,), kwargs = {result: %empty, input: %getitem_1, scale: %full_default, azp: None})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_1, 1), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_1, 2), kwargs = {})
return (getitem_3, getitem_4)
After the pattern replacement:
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%empty_1 : [num_users=0] = call_function[target=torch.ops.aten.empty.memory_format](args = ([5, 4],), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda, pin_memory: False})
%convert_element_type : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.bfloat16), kwargs = {})
%convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.bfloat16), kwargs = {})
%empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([5, 4],), kwargs = {dtype: torch.int8, layout: torch.strided, device: cuda, pin_memory: False})
%full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([1, 1], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
%auto_functionalized_2 : [num_users=2] = call_function[target=torch.ops.higher_order.auto_functionalized](args = (vllm.fused_rms_norm_quant_static.default,), kwargs = {result: %empty, input: %convert_element_type, weight: %convert_element_type_1, epsilon: 1e-06, scale: %full_default, azp: None})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_2, 1), kwargs = {})
%getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_2, 2), kwargs = {})
return (getitem_5, getitem_6)
That's the exact goal: match two consecutive ops and replace them with a single fused op, while keeping mutation semantics correct via auto_functionalized. Later in post-grad, auto_functionalized gets decomposed back into the underlying mutation op.
Wrap-up
If you want to add custom optimizations, the recipe is simple:
- Express the pattern in Python.
- Ensure mutable custom ops go through auto_functionalized.
- Register the pattern with register_replacement and apply it in post-grad.
That's the same mechanism vLLM uses to land its biggest fusions. Once you understand this, the rest of those fusion passes read like normal Python code. If you want to see the real thing, browse the fusion passes in the vLLM repository (linked in the references below).
References
- https://github.com/pytorch/pytorch
- https://github.com/vllm-project/vllm