Custom operations are common in PyTorch models. They can take the form of custom classes or custom functions, implemented in C++ and CUDA, and used from both Python and C++ inference programs.
In this blog post, I would like to share how to implement PyTorch custom operations in C++ and CUDA, and how to use them in PyTorch models and in AOTInductor-compiled inference programs, using a simple identity convolution as the running example.
PyTorch Custom Function
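PyTorch custom functions can be implemented in C++ and CUDA. The operator schema is declared once with m.def inside a TORCH_LIBRARY block for the namespace (my_ops here, not shown in the excerpt), and the per-backend kernels are registered with the TORCH_LIBRARY_IMPL macro. When both CPU and CUDA kernels are provided, PyTorch dispatches to the correct one based on the device of the input tensors.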
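#include <torch/library.h>
#include <torch/torch.h>

// CPU kernel: the identity convolution simply returns a copy of the input.
torch::Tensor identity_conv_cpu_impl(const torch::Tensor& input)
{
    TORCH_CHECK(!input.is_cuda(), "identity_conv_cpu_impl: input must be a CPU tensor");
    return input.clone();
}

// CUDA kernel wrapper: allocates the output and launches the CUDA kernel.
torch::Tensor identity_conv_cuda_impl(const torch::Tensor& input)
{
    TORCH_CHECK(input.is_cuda(), "identity_conv_cuda_impl: input must be a CUDA tensor");

    auto output = torch::empty_like(input);
    const int64_t numel = input.numel();

    // ... CUDA kernel launch over numel elements (omitted in this excerpt) ...

    return output;
}

// Register the CUDA kernel for my_ops::identity_conv_op.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m)
{
    m.impl("identity_conv_op", identity_conv_cuda_impl);
}

// Register the CPU kernel for my_ops::identity_conv_op.
TORCH_LIBRARY_IMPL(my_ops, CPU, m)
{
    m.impl("identity_conv_op", identity_conv_cpu_impl);
}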
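The device-based dispatch described above can be illustrated with a small, pure-Python toy model. It is not PyTorch's dispatcher: ToyTensor, impl, and call are hypothetical names invented for this sketch, and the real dispatcher computes dispatch keys from all tensor arguments rather than from a single device string. It only shows the idea that each TORCH_LIBRARY_IMPL block fills one (operator, backend) slot and that the input's device selects which slot runs.

```python
# Toy model of per-backend kernel registration and dispatch.
# Hypothetical names throughout; this is NOT PyTorch's real dispatcher.
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class ToyTensor:
    device: str  # e.g. "cpu" or "cuda"
    data: List[float]


Kernel = Callable[[ToyTensor], ToyTensor]

# (qualified op name, backend key) -> kernel, filled in by impl()
_KERNELS: Dict[Tuple[str, str], Kernel] = {}


def impl(qualname: str, backend: str) -> Callable[[Kernel], Kernel]:
    """Plays the role of m.impl(...) inside TORCH_LIBRARY_IMPL(ns, BACKEND, m)."""
    def register(fn: Kernel) -> Kernel:
        _KERNELS[(qualname, backend)] = fn
        return fn
    return register


def call(qualname: str, x: ToyTensor) -> ToyTensor:
    """Select the kernel from the input's device, then run it."""
    backend = x.device.upper()  # "cpu" -> "CPU", "cuda" -> "CUDA"
    kernel = _KERNELS.get((qualname, backend))
    if kernel is None:
        raise NotImplementedError(f"{qualname} has no kernel registered for {backend}")
    return kernel(x)


@impl("my_ops::identity_conv_op", "CPU")
def identity_conv_cpu(x: ToyTensor) -> ToyTensor:
    assert x.device == "cpu", "CPU kernel received a non-CPU tensor"
    return ToyTensor("cpu", list(x.data))  # like input.clone()


@impl("my_ops::identity_conv_op", "CUDA")
def identity_conv_cuda(x: ToyTensor) -> ToyTensor:
    assert x.device == "cuda", "CUDA kernel received a non-CUDA tensor"
    return ToyTensor("cuda", list(x.data))  # stands in for the CUDA copy kernel


if __name__ == "__main__":
    print(call("my_ops::identity_conv_op", ToyTensor("cuda", [1.0, 2.0])))
    # ToyTensor(device='cuda', data=[1.0, 2.0])
```

Supporting another backend would only add one more slot, which mirrors why the C++ code needs one TORCH_LIBRARY_IMPL block per backend; an input on a device with no registered kernel (say "mps" in the toy) raises instead of silently falling back.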
PyTorch Custom Class
PyTorch custom functions are stateless and cannot hold parameters. To implement an operation that holds state (for example, weights) and exposes a forward() method callable from Python, we can define a C++ class that derives from torch::CustomClassHolder and register it with m.class_ inside the TORCH_LIBRARY macro.
The custom classes, the custom functions, and their C++ registrations are built into a shared library (libidentity_conv_ops.so), which is loaded into PyTorch with torch.ops.load_library; loading the library runs its registrations. For torch.compile and torch.export compatibility, we also need to register “fake” (abstract) versions of the custom classes and functions in Python, using @register_fake_class and @torch.library.register_fake, so that FakeTensor-based tracing can propagate shapes, dtypes, and devices without executing the actual C++/CUDA code.
"""
custom_ops.py
=============
Loads the C++ / CUDA shared library and sets up all custom PyTorch operations
used by the IdentityModel:

1. torch.classes.my_ops.IdentityConvClass (registered by the shared library)
   - A fake/abstract version is registered here so that torch.export can
     trace through module attributes that hold an instance of this class.

2. my_ops::identity_conv_op (schema + CPU + CUDA registered by the shared library)
   - register_fake: abstract implementation for torch.export / FakeTensor.
"""

from torch._library.fake_class_registry import register_fake_class


@register_fake_class("my_ops::IdentityConvClass")
class FakeIdentityConvClass:
    """Abstract counterpart of IdentityConvClass used during torch.export."""

    # Body omitted in this excerpt. A fake class registered this way must
    # define a classmethod __obj_unflatten__ plus fake versions of the real
    # class's methods (e.g. forward).
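Why a fake implementation is enough for tracing can be shown with another toy model. FakeTensor below is a hypothetical stand-in for PyTorch's FakeTensor that carries only shape, dtype, and device, never data, and trace is a made-up helper. A fake kernel only has to compute output metadata; for an identity op, a real register_fake function would plausibly just return torch.empty_like(input), which the toy mirrors.

```python
# Toy illustration of metadata-only ("fake") tracing. Hypothetical names;
# PyTorch's real FakeTensor and tracer are far more involved.
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class FakeTensor:
    shape: Tuple[int, ...]
    dtype: str
    device: str  # no data field: values cannot be computed, only described


FakeFn = Callable[[FakeTensor], FakeTensor]


def identity_conv_fake(x: FakeTensor) -> FakeTensor:
    # Counterpart of a register_fake kernel: same metadata as the input,
    # analogous to returning torch.empty_like(input).
    return FakeTensor(x.shape, x.dtype, x.device)


def trace(layers: List[Tuple[str, FakeFn]], example: FakeTensor) -> List[Tuple[str, Tuple[int, ...]]]:
    """Run only the fake kernels and record (layer name, output shape) per step."""
    graph = []
    t = example
    for name, fake_fn in layers:
        t = fake_fn(t)
        graph.append((name, t.shape))
    return graph


if __name__ == "__main__":
    x = FakeTensor((1, 3, 32, 32), "float32", "cuda")
    print(trace([("layer2", identity_conv_fake), ("layer3", identity_conv_fake)], x))
    # [('layer2', (1, 3, 32, 32)), ('layer3', (1, 3, 32, 32))]
```

The CUDA kernel never runs in this process: the tracer learns everything it needs (output shapes and dtypes) from the fake functions alone, which is the property torch.export relies on.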
Once the shared library is loaded, the custom class is available as torch.classes.my_ops.IdentityConvClass and the custom function as torch.ops.my_ops.identity_conv_op.
"""
model.py
========
Defines the four-layer IdentityModel used in the AOTInductor demo.

Layer layout
------------
layer1 : IdentityConv            - native PyTorch operators
layer2 : IdentityConvCustomClass - torch.classes C++/CUDA custom class
layer3 : IdentityConvCustomOp    - torch.library.custom_op C++/CUDA op
layer4 : IdentityConv            - native PyTorch operators

Every layer is an identity transformation, so model(x) == x for any input x.
"""

import torch
import torch.nn as nn

from custom_ops import identity_conv_op


class IdentityConv(nn.Module):
    """Identity convolution implemented with native PyTorch operators.

    Uses a depthwise Conv2d with kernel_size=1 and weight=1.0, which is
    equivalent to a no-op (output == input). This layer is compatible with
    torch.export and AOTInductor out of the box.
    """

    # ... (implementation omitted in this excerpt) ...


class IdentityConvCustomClass(nn.Module):
    """Identity convolution backed by a torch.classes C++/CUDA custom class.

    At runtime the forward call is dispatched to the CUDA kernel registered
    inside IdentityConvClass (csrc/identity_conv.cpp + .cu).

    For torch.export compatibility a FakeIdentityConvClass is registered in
    custom_ops.py via @register_fake_class so that symbolic tracing works.
    """

    # ... (implementation omitted in this excerpt) ...


class IdentityConvCustomOp(nn.Module):
    """Identity convolution backed by a torch.library.custom_op C++/CUDA op.

    The op (my_ops::identity_conv_op) is defined in custom_ops.py with:
      • a register_fake implementation for torch.export tracing
      • a register_kernel("cuda") implementation that calls the CUDA kernel
    """

    # ... (implementation omitted in this excerpt) ...


class IdentityModel(nn.Module):
    # ... (__init__ defining layer1 to layer4 omitted in this excerpt) ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


def create_model(channels: int = 3) -> IdentityModel:
    """Return an IdentityModel in eval mode on the default CUDA device."""
    return IdentityModel(channels=channels).cuda().eval()
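The claim that a depthwise Conv2d with kernel_size=1 and weight 1.0 is a no-op can be checked in plain Python. This sketch assumes the layer has no bias (or a zero bias), which the excerpt does not show; with nn.Conv2d that would correspond to something like nn.Conv2d(C, C, kernel_size=1, groups=C, bias=False) with its weight filled with 1.0, but that construction is an assumption rather than the author's code. Images are nested lists laid out as [channel][row][column], and depthwise_conv1x1 and identity_model are hypothetical helper names.

```python
# Pure-Python depthwise 1x1 convolution, to check the identity claim.
# Assumes zero bias; layout is [channel][row][column].
from typing import List

Image = List[List[List[float]]]


def depthwise_conv1x1(x: Image, weight: List[float], bias: List[float]) -> Image:
    """groups == channels and a 1x1 kernel: each channel is scaled and shifted on its own."""
    assert len(weight) == len(x) == len(bias), "one weight and one bias per channel"
    return [
        [[weight[c] * v + bias[c] for v in row] for row in channel]
        for c, channel in enumerate(x)
    ]


def identity_model(x: Image, num_layers: int = 4) -> Image:
    """Stack of identity convs, mirroring the four-layer IdentityModel."""
    channels = len(x)
    for _ in range(num_layers):
        x = depthwise_conv1x1(x, [1.0] * channels, [0.0] * channels)
    return x


if __name__ == "__main__":
    img = [[[1.5, -2.0], [0.25, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]
    print(identity_model(img) == img)  # True
```

Each layer maps every value v to 1.0 * v + 0.0, which is exactly v in floating point, so a stack of four such layers is still the identity; this is what lets the inference script compare the model output directly with its input.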
PyTorch Model Export and Lowering
A PyTorch model that uses custom classes and custom functions can be exported with torch.export, provided that fake (abstract) versions of all of them are registered for symbolic tracing.
"""
export_compile.py
=================
Exports the IdentityModel with torch.export and compiles it with
torch._inductor.aoti_compile_and_package. The resulting package (model.pt2)
is written to the artifacts/ directory and can be loaded by both
run_inference.py (Python) and the C++ inference binary.

Usage (run from the python/ directory):
    python export_compile.py
"""
In the exported graph, calls to the custom class method IdentityConvClass.forward are represented as torch.ops.higher_order.call_torchbind, while the custom op identity_conv_op appears as a call to torch.ops.my_ops.identity_conv_op.
The exported program can be compiled and packaged with torch._inductor.aoti_compile_and_package into a model.pt2 package that can be loaded by both Python and C++ inference programs. The package does not contain the custom class and custom op implementations; they are resolved at runtime from the shared library, which therefore must already be loaded in the process when the compiled model is loaded and executed.
"""
run_inference.py
================
Loads the AOTInductor-compiled IdentityModel package (model.pt2) and runs
inference to verify correctness. The output of the identity model must equal
the input within a tight floating-point tolerance.

Usage (run from the python/ directory after export_compile.py):
    python run_inference.py [MODEL_PATH [OP_LIB_PATH]]

Arguments:
    MODEL_PATH   Path to the compiled model package (.pt2).
                 Defaults to ../artifacts/model.pt2 relative to this script.
    OP_LIB_PATH  Path to the custom-op shared library (.so). When provided,
                 the library path is forwarded to custom_ops.py via the
                 IDENTITY_CONV_OPS_LIB environment variable so that
                 torch.ops.load_library uses that file instead of the default
                 ../ext/libidentity_conv_ops.so.
"""
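The command-line behavior documented in run_inference.py can be sketched with the standard library. This is one possible implementation under stated assumptions, not the author's script: resolve_paths, forward_op_lib, default_op_lib, and allclose are hypothetical helper names; the default library path is assumed to be resolved relative to the script directory; and the environment variable is assumed to need setting before custom_ops.py is imported, on the guess that custom_ops.py loads the library at import time. allclose follows the elementwise rule used by torch.allclose, |a - b| <= atol + rtol * |b|, with the same default tolerances.

```python
# Sketch of run_inference.py's argument handling and output check.
# Helper names are hypothetical; default paths follow the docstring.
import os
from typing import List, MutableMapping, Optional, Sequence, Tuple

ENV_VAR = "IDENTITY_CONV_OPS_LIB"


def resolve_paths(argv: Sequence[str], script_dir: str) -> Tuple[str, Optional[str]]:
    """Return (model_path, op_lib_path or None) from argv = [prog, MODEL_PATH?, OP_LIB_PATH?]."""
    default_model = os.path.join(script_dir, "..", "artifacts", "model.pt2")
    model_path = argv[1] if len(argv) > 1 else default_model
    op_lib_path = argv[2] if len(argv) > 2 else None
    return os.path.normpath(model_path), op_lib_path


def forward_op_lib(op_lib_path: Optional[str], environ: MutableMapping[str, str]) -> None:
    """Expose OP_LIB_PATH to custom_ops.py (assumed to be needed before it is imported)."""
    if op_lib_path is not None:
        environ[ENV_VAR] = op_lib_path


def default_op_lib(script_dir: str, environ: MutableMapping[str, str]) -> str:
    """Library path as custom_ops.py is described to choose it (assumed relative to its directory)."""
    fallback = os.path.join(script_dir, "..", "ext", "libidentity_conv_ops.so")
    return os.path.normpath(environ.get(ENV_VAR, fallback))


def allclose(a: List[float], b: List[float], rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Elementwise |a - b| <= atol + rtol * |b|, the rule torch.allclose uses."""
    return len(a) == len(b) and all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


if __name__ == "__main__":
    env: MutableMapping[str, str] = {}
    model, lib = resolve_paths(["run_inference.py"], "/work/python")
    forward_op_lib(lib, env)
    print(model, default_op_lib("/work/python", env))
    # /work/artifacts/model.pt2 /work/ext/libidentity_conv_ops.so
```

In a real script the environ argument would simply be os.environ; passing the mapping explicitly keeps the helpers testable without touching the process environment.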
In a pure C++ inference program, the shared library can be loaded with dlopen. Loading it runs the library's static TORCH_LIBRARY and TORCH_LIBRARY_IMPL registrations for the custom class and custom function, so no pybind11 or libpython dependency is needed.
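In C++ this amounts to calling dlopen on libidentity_conv_ops.so, for example with RTLD_NOW | RTLD_GLOBAL, before the compiled package is loaded. To keep to a single language here, the sketch below exercises the same dlopen flags from Python through ctypes, which wraps dlopen on Linux and macOS. load_shared_library is a hypothetical helper, the flag choice is a common convention rather than something this post prescribes, and the demo loads the system math library only so that it runs without a libtorch installation.

```python
# Load a shared library with explicit dlopen flags via ctypes (POSIX only).
# load_shared_library is a hypothetical helper name.
import ctypes
import ctypes.util
import os


def load_shared_library(path: str) -> ctypes.CDLL:
    """dlopen(path, RTLD_NOW | RTLD_GLOBAL): resolve all symbols immediately and
    make them visible to libraries loaded afterwards. Loading also runs the
    library's static initializers (for a PyTorch op library, its registrations)."""
    return ctypes.CDLL(path, mode=os.RTLD_NOW | os.RTLD_GLOBAL)


if __name__ == "__main__":
    libm_path = ctypes.util.find_library("m") or "libm.so.6"
    libm = load_shared_library(libm_path)
    libm.cos.restype = ctypes.c_double
    libm.cos.argtypes = [ctypes.c_double]
    print(libm.cos(0.0))  # 1.0
```

Inside a Python process, torch.ops.load_library remains the supported way to load the op library; the ctypes version only illustrates the load step that the C++ program performs itself.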