TorchTPU: Running PyTorch Natively on TPUs at Google Scale

Original link: https://developers.googleblog.com/torchtpu-running-pytorch-natively-on-tpus-at-google-scale/

## TorchTPU: Bringing PyTorch to Google's TPU Supercomputing Infrastructure

Modern AI requires scaling models across massive distributed systems, potentially spanning tens of thousands of accelerators. Google's Tensor Processing Units (TPUs) are central to its AI infrastructure, powering models like Gemini and serving Cloud customers. To broaden access, Google developed **TorchTPU**, a stack designed to integrate PyTorch seamlessly with TPUs, prioritizing ease of use, portability, and performance.

TorchTPU leverages the distinctive TPU architecture, interconnected chips with TensorCores for dense math and SparseCores for irregular operations, while aiming to deliver an experience that "feels like PyTorch." It achieves this through an "Eager First" approach that lets developers run existing PyTorch code with minimal changes. Three eager modes (Debug, Strict, and Fused) offer progressively higher performance, with Fused Eager automatically delivering a 50-100%+ speedup.

For peak performance, TorchTPU integrates with `torch.compile`, using XLA and StableHLO for optimized compilation. It also supports custom kernels via Pallas and JAX. Importantly, TorchTPU addresses challenges in distributed training by supporting divergent execution (MPMD) and providing hardware-aware model design guidance.

Future development focuses on reducing compilation overhead, expanding the library of precompiled kernels, and deeper integration with the PyTorch ecosystem, including tools such as Helion and vLLM. TorchTPU aims to eliminate friction between PyTorch and TPU hardware, powering the next generation of AI innovation.


## Original Article

The challenges of building for modern AI infrastructure have fundamentally shifted. The modern frontier of machine learning now requires leveraging distributed systems, spanning thousands of accelerators. As models scale to run on clusters of O(100,000) chips, the software that powers these models must meet new demands for performance, hardware portability, and reliability.

At Google, our Tensor Processing Units (TPUs) are foundational to our supercomputing infrastructure. These custom ASICs power training and serving for both Google’s own AI platforms, like Gemini and Veo, and the massive workloads of our Cloud customers. The entire AI community should be able to easily access the full capabilities of TPUs, and because many of these potential users build models in PyTorch, an integration that allows PyTorch to work natively and efficiently on the TPU is crucial.

Enter TorchTPU. As an engineering team, our mandate was to build a stack that leads with usability, portability, and excellent performance. We wanted to enable developers to migrate existing PyTorch workloads with minimal code changes while giving them the APIs and the tools to extract every ounce of compute from our hardware. Here is a look under the hood at the engineering principles driving TorchTPU, the technical architecture we’ve built, and our roadmap for 2026.

Architecting for Usability, Portability, and Performance

To understand TorchTPU, you first have to understand the hardware it targets.

A TPU system is not just a chip; it is an integrated network. A host is attached to multiple chips, and each chip connects to the host and to other chips via our Inter-Chip Interconnect (ICI). This ICI links the chips into a highly efficient 2D or 3D Torus topology, allowing for massive scale-up without traditional networking bottlenecks. Within each chip, execution is divided between TensorCores and SparseCores. TensorCores are single-threaded units dedicated to dense matrix math, while SparseCores handle irregular memory access patterns like embeddings, gather/scatter operations, and offloading collectives.

These features make TPUs a powerful tool for machine learning, and our goal is to provide the specialized support needed to fully leverage these unique capabilities. This is where PyTorch comes in: the PyTorch toolchain already provides a consistent, widely used interface across other device types.

Our core principle for usability is simple: it should feel like PyTorch. A developer should be able to take an existing PyTorch script, change their initialization to “tpu”, and run their training loop without modifying a single line of core logic.

Achieving this required an entirely new approach to how PyTorch interacts with the TPU compiler and runtime stack.

Engineering the TorchTPU Stack: The Technical Reality

Eager First: Flexibility Without Compromise

Moving from concept to a native PyTorch experience on TPU meant rethinking the execution stack. We established an "Eager First" philosophy. Instead of forcing developers into static graph compilation immediately, we implemented TorchTPU using PyTorch's "PrivateUse1" interface. No subclasses, no wrappers; just ordinary, familiar PyTorch Tensors on a TPU. By integrating at this deep level, we are able to fully prioritize the eager execution experience developers expect from PyTorch.

We engineered three distinct eager modes to support the development lifecycle.

The first eager mode is Debug Eager, which dispatches one operation at a time and synchronizes with the CPU after every execution. It is inherently slow, but invaluable for tracking down shape mismatches, NaN values, and out-of-memory crashes.

The second is Strict Eager, which maintains single-op dispatch, but executes asynchronously, with the intent of mirroring the default PyTorch experience. This allows both the CPU and TPU to execute simultaneously, until a synchronization point is reached in the user’s script.

The breakthrough, however, is our Fused Eager mode. Using automated reflection on the stream of operations, TorchTPU fuses steps on the fly into larger, computationally dense chunks before handing them to the TPU. By maximizing TensorCore utilization and minimizing memory bandwidth overhead, Fused Eager consistently delivers a 50% to 100+% performance increase over Strict Eager, with no setup required by the user.
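The fusion idea above can be sketched in a few lines. This is a hypothetical illustration, not the TorchTPU implementation: dispatched operations that are fusion candidates are buffered, and the buffer is flushed as one computationally dense chunk whenever a non-fusable operation (or an explicit flush) forms a barrier. The `FusingQueue` class and the set of fusable op names are invented for this sketch.

```python
# Hypothetical sketch of on-the-fly op fusion; names and the fusable set
# are illustrative, not TorchTPU internals.
from dataclasses import dataclass, field

@dataclass
class FusingQueue:
    """Collects dispatched ops and groups fusable ones into chunks."""
    fusable: set = field(default_factory=lambda: {"add", "mul", "relu"})
    _pending: list = field(default_factory=list)
    chunks: list = field(default_factory=list)

    def dispatch(self, op):
        if op in self.fusable:
            self._pending.append(op)   # defer: candidate for fusion
        else:
            self.flush()               # barrier: emit pending fused chunk
            self.chunks.append([op])   # non-fusable op runs on its own

    def flush(self):
        if self._pending:
            self.chunks.append(self._pending)
            self._pending = []

q = FusingQueue()
for op in ["add", "mul", "relu", "matmul", "add", "relu"]:
    q.dispatch(op)
q.flush()
print(q.chunks)  # [['add', 'mul', 'relu'], ['matmul'], ['add', 'relu']]
```

Each inner list represents one dense chunk handed to the TPU, so six single-op dispatches become three launches in this toy run.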

All three modes are backed by a shared Compilation Cache that can operate on a single host, or be configured as persistent across multi-host setups. This means that as TorchTPU learns your workload, you spend less time compiling, and more time running.

Static Compilation: Dynamo, XLA, and StableHLO

For users who want to unlock peak performance on the TPU, TorchTPU integrates natively with the torch.compile interface for full-graph compilation. We start by capturing the FX graph using Torch Dynamo. However, rather than routing through Torch Inductor, we utilize XLA as our primary backend compiler.

This was a highly deliberate architectural decision. XLA is rigorously battle-tested for TPU topologies. More importantly, it natively understands how to optimize the critical overlap between dense computation and collective communications across the ICI. Our translation layer maps PyTorch's operators directly into StableHLO, XLA’s primary Intermediate Representation (IR) for tensor math. This creates a direct connection from PyTorch into XLA’s core lowering path, allowing us to generate highly optimized TPU binaries while reusing the execution paths established by our eager modes.
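Conceptually, that translation layer is a mapping from captured PyTorch operators to StableHLO operations. The table below is a hypothetical, heavily simplified stand-in for the real lowering rules, shown only to make the pipeline concrete.

```python
# Illustrative lowering sketch: captured graph ops map to StableHLO-style
# op names. The mapping table is an assumption, not TorchTPU's real table.
TORCH_TO_STABLEHLO = {
    "aten.add": "stablehlo.add",
    "aten.mm": "stablehlo.dot_general",
    "aten.relu": "stablehlo.maximum",  # relu lowers to max(x, 0)
}

def lower(fx_ops):
    """Translate a list of captured op names into StableHLO-style ops."""
    return [TORCH_TO_STABLEHLO[op] for op in fx_ops]

print(lower(["aten.mm", "aten.add", "aten.relu"]))
# ['stablehlo.dot_general', 'stablehlo.add', 'stablehlo.maximum']
```

In the real stack, Dynamo produces the FX graph, this translation emits StableHLO, and XLA lowers that IR the rest of the way to a TPU binary.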

For developers writing custom operators, we ensure extensibility doesn't break performance. TorchTPU natively supports custom kernels written in Pallas and JAX. By decorating a JAX function with @torch_tpu.pallas.custom_jax_kernel, engineers can write low-level hardware instructions that interface directly with our lowering path. Work is ongoing to also support Helion kernels.

Distributed Training and the MPMD Challenge

To preserve the flexibility and usability of eager and compiled modes at scale, we focused heavily on PyTorch's distributed APIs. Today, TorchTPU supports Distributed Data Parallel (DDP), Fully Sharded Data Parallel v2 (FSDPv2), and PyTorch’s DTensor out of the box. We've validated that many third-party libraries that build on PyTorch's distributed APIs work unchanged on TorchTPU.

One major limitation of PyTorch/XLA (a predecessor to TorchTPU) was that it only supported pure SPMD code. In practice, PyTorch programs frequently diverge slightly across ranks: for instance, it is common for the "rank 0" process to do a little extra work for logging or analytics. This kind of program is a challenge for the TPU stack, which is heavily optimized for SPMD execution. XLA works best with a global view of the code running on the system, and working around this divergence burdens the developer, who has to carefully strip out the impure behavior.

TorchTPU is architected to carefully support divergent executions (MPMD), and will isolate communication primitives where necessary to preserve correctness, at minimal cost. This approach helps ensure that the experience of using PyTorch on the TPU is as natural as possible to existing PyTorch developers, while preserving XLA’s ability to overlap communication and computation with a global view of a distributed TPU deployment wherever possible.
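The divergence pattern in question looks like the sketch below: every rank participates in the same collective, but rank 0 alone does extra logging. The `all_reduce_sum` stand-in and function names are illustrative, not TorchTPU APIs.

```python
# Sketch of rank-divergent (MPMD) code: identical collectives on all
# ranks, plus rank-0-only side work.
def all_reduce_sum(per_rank):
    """Stand-in for a collective: every rank sees the same sum."""
    return sum(per_rank)

def train_step(rank, per_rank_grads):
    total = all_reduce_sum(per_rank_grads)   # every rank must participate
    if rank == 0:
        # Divergent, rank-0-only work: harmless logging, but it makes the
        # program MPMD rather than pure SPMD.
        print(f"mean grad = {total / len(per_rank_grads)}")
    return total

grads = [1.0, 2.0, 3.0, 4.0]
results = [train_step(r, grads) for r in range(4)]
print(results)  # [10.0, 10.0, 10.0, 10.0]
```

The collective result is identical on every rank; only the `if rank == 0` branch diverges, which is exactly the kind of branch TorchTPU must isolate without breaking the global SPMD view.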

TPU Hardware Awareness

The TPU can achieve very high performance and efficiency, but optimal model design may differ slightly from other hardware. For example, we frequently see models hardcoding attention head dimensions to 64, while current-generation TPUs achieve peak matrix multiplication efficiency at dimensions of 128 or 256. Modifying the model to target 128 or 256 dimensions makes better use of the large, dense, and efficient TensorCores on the TPU chip.
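The arithmetic behind that guidance can be sketched under the simplifying assumption that the matrix unit processes fixed tiles of 128 lanes per dimension, so any dimension is padded up to a multiple of 128 (the tile size here is illustrative, not a precise spec for any particular TPU generation):

```python
import math

# Rough utilization arithmetic, assuming a 128-wide matrix-unit tile.
TILE = 128

def mxu_utilization(head_dim):
    """Fraction of the padded tile doing useful work for this head_dim."""
    padded = math.ceil(head_dim / TILE) * TILE
    return head_dim / padded

print(mxu_utilization(64))   # 0.5 -> half the tile is padding
print(mxu_utilization(128))  # 1.0 -> one full tile
print(mxu_utilization(256))  # 1.0 -> two full tiles
```

Under this model, a head dimension of 64 wastes half of every tile, while 128 or 256 fill tiles exactly, which is the intuition behind preferring those sizes.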

Portability doesn't eliminate hardware realities, so TorchTPU facilitates a tiered workflow: establish correct execution first, then use our upcoming deep-dive guidelines to identify and refactor suboptimal architectures, or to inject custom kernels, for optimal hardware utilization.

The Road Ahead: 2026 and Beyond

We have laid a rock-solid foundation across training and serving support today, and we are actively tackling several open challenges to make TorchTPU a frictionless backend in the PyTorch ecosystem.

A primary focus for our compiler team is reducing recompilations triggered by dynamic sequence lengths and batch sizes. By implementing advanced bounded dynamism within XLA, we aim to handle shape changes without incurring compilation overhead. This can be an important feature for certain workloads, such as iterative next-token prediction.
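Today, a common user-side workaround for this recompilation problem is shape bucketing: rounding dynamic sequence lengths up to a small fixed set of sizes so only a bounded number of shapes (and hence compilations) ever occurs. The sketch below shows that workaround, with invented bucket sizes; bounded dynamism inside XLA would make the padding implicit instead.

```python
# Shape-bucketing sketch: round dynamic lengths up to a fixed bucket so
# that only len(buckets) distinct compiled programs are ever needed.
def bucket_length(seq_len, buckets=(128, 256, 512, 1024)):
    """Return the smallest bucket that fits seq_len."""
    for b in buckets:
        if seq_len <= b:
            return b
    raise ValueError(f"sequence length {seq_len} exceeds largest bucket")

print([bucket_length(n) for n in (90, 128, 300, 1000)])
# [128, 128, 512, 1024]
```

The cost is padded compute on the rounded-up tokens, which is exactly the overhead that first-class bounded dynamism aims to remove.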

We are also building out a comprehensive library of precompiled TPU kernels for standard operations to drastically reduce the latency of the first execution iteration.

Looking through the rest of 2026, we are working on:

  • The launch of our public GitHub repository, complete with extensive documentation and reproducible architectural tutorials.
  • Integration with PyTorch’s Helion DSL to further expand our custom kernel capabilities.
  • First-class support for dynamic shapes directly through torch.compile.
  • Native multi-queue support to ease migration of heavily asynchronous codebases with decoupled memory and compute streams.
  • Deep integrations with ecosystem pillars like vLLM and TorchTitan, alongside validated linear scaling up to full Pod-size infrastructure.

TorchTPU represents our dedicated engineering effort to provide a seamless, high-performance PyTorch experience on TPU hardware. We are breaking down obstacles and removing friction between the framework you love and the TPU supercomputing hardware required for the next generation of AI.

To stay informed on the latest TorchTPU updates, please visit the TPU Developer Hub.
