Jagged Flash Attention Optimization

Original link: https://www.shaped.ai/blog/jagged-flash-attention-optimization

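Meta's paper, "Enhancing Performance and Scalability of Large-Scale Recommendation Systems with Jagged Flash Attention", addresses the inefficiency of padding variable-length categorical features in recommendation systems. Padding introduces unnecessary memory and compute overhead. Their solution, Jagged Feature Interaction Kernels, uses dynamically sized tensors to process variable-length features contiguously, without padding. The kernels are built with custom Triton code that optimizes data locality and parallelism for operations such as multiplication and softmax. The gains are significant: up to 9× speedup and up to 22× memory reduction compared with dense attention. The authors also use Flash Attention's tiling optimization to further improve the attention computation. In production, this optimization improved queries per second (QPS) by 10% and reduced memory usage by 18%. Combined with the open-source TorchRec library, this work makes it possible to train larger models and handle longer feature sequences, marking a significant advance for recommendation systems.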


Original article

A write-up on the RecSys '24: Proceedings of the 18th ACM Conference on Recommender Systems paper, “Enhancing Performance and Scalability of Large-Scale Recommendation Systems with Jagged Flash Attention”, by Meta Platforms, CA, USA.

The Problem: Why Traditional Methods Fall Short

Traditional recommendation systems face challenges with variable-length categorical features, such as user interaction history. Unlike fixed-size numerical features, these require special handling. The conventional approach of padding to standardize lengths introduces significant overhead, especially in GPU-intensive operations.

Consider this scenario: If you're tracking a user's last 100 interactions, but they only have 20, you'd need to pad the remaining 80 slots with zeros. This padding creates:

  • Unnecessary memory usage
  • Increased computational load
  • Higher communication overhead between system components
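To put a number on the scenario above, here is a minimal sketch that counts padded versus real slots. The helper name `padding_overhead` and the example batch are illustrative assumptions, not something taken from the paper.

```python
def padding_overhead(lengths, max_len):
    """Return (padded_slots, real_slots, wasted_fraction) for a batch.

    `lengths` holds the true number of interactions per user; every user is
    padded (or truncated) to `max_len` slots in the dense layout.
    """
    real = sum(min(n, max_len) for n in lengths)
    padded = len(lengths) * max_len
    return padded, real, (padded - real) / padded


# The single user from the scenario: 20 real interactions in 100 slots.
print(*padding_overhead([20], max_len=100))  # 100 20 0.8

# A hypothetical batch of four users with different history lengths.
print(*padding_overhead([20, 100, 5, 75], max_len=100))  # 400 200 0.5
```

In the single-user case 80% of the slots hold only padding, and every one of those slots still costs memory, compute, and bandwidth in a dense layout.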

TorchRec: Scalable Recommender Systems 

TorchRec is a powerful PyTorch domain library designed to address the unique challenges of building and deploying large-scale recommendation systems. It offers several key features and optimizations:

Embedding Operations

  • Fused embedding tables and bulk lookups for improved performance
  • Efficient single kernel lookups across multiple embedding tables

Sparse Data Handling

  • Specialized containers and operations for sparse data
  • Optimized permutation and all-to-all communication

Advanced Sharding Capabilities

  • Supports various techniques: data parallel, table-wise, row-wise, column-wise
  • Hierarchical sharding for scaling to many GPUs
  • Automated sharding planner for optimal strategies

Performance Optimizations

  • Quantization support for embeddings (int8/int4)
  • High-performance GPU inference with TorchDeploy integration
  • Caching between GPU and system memory

Production Impact at Meta

  • Enables training of 3+ trillion parameter models
  • Up to 10x performance improvements
  • Facilitates transition to accelerator-based full-sync training

TorchRec excels at handling models combining deep neural networks with wide embedding tables, addressing PyTorch's previous limitations with sparse data and wide models. This enables researchers and engineers to build and efficiently deploy state-of-the-art personalization models in production environments.

The Game-Changer: Jagged Feature Interaction Kernels

Jagged Feature Interaction Kernels represent a significant advancement in handling variable-length categorical features in recommendation systems. The approach efficiently extracts fine-grained insights from long categorical features by using dynamically sized tensors. The kernels operate on jagged tensors, which store variable-length features from multiple samples contiguously in memory without padding.

Image Source: Research paper

The key components of Jagged Feature Interaction Kernels include:

  • Values tensor: A contiguous array storing all feature values collectively
  • Offset tensor: Determines sample boundaries for each feature segment
  • Triton kernels: Custom-built for both forward and backward computations, optimizing data locality and parallelism
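The values/offsets layout in the list above can be sketched in a few lines of plain Python. This is only an illustration of the data structure: the function names `to_jagged` and `get_sample` are hypothetical, and the real kernels operate on GPU tensors rather than Python lists.

```python
from itertools import accumulate


def to_jagged(samples):
    """Flatten variable-length samples into (values, offsets).

    values  : all feature values stored back to back, with no padding
    offsets : offsets[i] .. offsets[i + 1] delimits sample i inside values
    """
    values = [v for sample in samples for v in sample]
    offsets = [0] + list(accumulate(len(s) for s in samples))
    return values, offsets


def get_sample(values, offsets, i):
    """Recover sample i from the flat layout."""
    return values[offsets[i]:offsets[i + 1]]


samples = [[11, 12, 13], [21], [], [41, 42]]
values, offsets = to_jagged(samples)
print(values)   # [11, 12, 13, 21, 41, 42]
print(offsets)  # [0, 3, 4, 4, 6]
print(get_sample(values, offsets, 3))  # [41, 42]
```

Note that an empty sample costs nothing in `values`; it only shows up as two equal consecutive entries in `offsets`.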

These kernels enable efficient operations such as jagged tensor multiplication, softmax computations, and element-wise operations specifically tailored for sparse data structures. By prioritizing the most relevant feature values and assigning them higher weights, Jagged Feature Interaction Kernels significantly improve the performance and memory efficiency of large-scale recommendation models.
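As a rough illustration of what a jagged softmax computes, the sketch below normalizes each segment of a flat values list independently, using the offsets to find segment boundaries. It is a sequential pure-Python reference under our own naming (`jagged_softmax` is hypothetical); the Triton kernels described in the paper would process segments in parallel on the GPU.

```python
import math


def jagged_softmax(values, offsets):
    """Apply a numerically stable softmax to each segment of `values`."""
    out = [0.0] * len(values)
    for start, end in zip(offsets, offsets[1:]):
        if start == end:
            continue  # empty segment: nothing to normalize
        segment = values[start:end]
        seg_max = max(segment)  # subtract the max to avoid overflow in exp
        exps = [math.exp(v - seg_max) for v in segment]
        total = sum(exps)
        out[start:end] = [e / total for e in exps]
    return out


values = [1.0, 1.0, 0.0, 2.0, 2.0, 2.0, 2.0]
offsets = [0, 2, 3, 3, 7]  # segments of length 2, 1, 0 and 4
print(jagged_softmax(values, offsets))
# [0.5, 0.5, 1.0, 0.25, 0.25, 0.25, 0.25]
```

No padded position ever enters the normalization, so there is no need for the masking that a dense softmax over padded rows requires.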

Performance Gains

Image Source: Research paper

Speedup

  • Jagged attention: Up to 2× faster than dense attention
  • Jagged Flash Attention: Up to 9× speedup compared to dense attention
  • Jagged Flash Attention: 3× speedup over dense flash attention

Memory Efficiency

  • Jagged attention: Up to 3.5× reduction vs. dense attention
  • Jagged Flash Attention: Impressive 22× memory reduction

Real-World Impact (Production)

  • 10% improvement in Queries Per Second (QPS)
  • 18% reduction in memory usage
  • Enhanced ability to handle longer feature sequences
  • Support for more complex model architectures

These optimizations significantly enhance the efficiency and scalability of large-scale recommendation systems, enabling more complex model architectures and longer feature sequences.

Flash Attention Tiling Optimization

Flash Attention's tiling optimization is a key innovation that significantly improves the efficiency of attention computations in large language models. By exploiting the GPU memory hierarchy, Flash Attention reduces the number of reads and writes to high-bandwidth memory (HBM) and maximizes the use of fast on-chip SRAM. The tiling strategy divides the input matrices into smaller blocks that fit into SRAM, so each block can be processed without excessive data movement.

The core algorithm employs two main techniques:

  • Tiling: Input matrices Q, K, and V are divided into blocks of size B×d, where B is the block size and d is the embedding dimension.
  • Incremental softmax: A modified online softmax algorithm is used to process attention scores block-wise, maintaining running statistics to ensure numerical stability.
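The two techniques above can be sketched for a single query vector in plain Python. This is a simplified reference under our own naming (`attention_reference`, `attention_tiled`, `block_size`), not the actual Flash Attention kernel: it walks over K and V one block at a time, keeps a running maximum and a running normalizer, and rescales the partial output whenever the maximum changes.

```python
import math


def attention_reference(q, keys, values):
    """Standard softmax attention for one query (all scores held at once)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / z
            for c in range(len(values[0]))]


def attention_tiled(q, keys, values, block_size):
    """Same result, computed block by block with an online softmax."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])  # unnormalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum
        # (the factor is exp(-inf) = 0.0 for the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.5, -0.3], [0.0, 2.0], [-1.0, 0.4], [0.7, 0.7]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
tiled = attention_tiled(q, keys, values, block_size=2)
full = attention_reference(q, keys, values)
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(tiled, full)))  # True
```

Only one block of scores exists at any time, which is what lets the real kernel keep its working set in SRAM instead of writing the full score matrix to HBM.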

Because the full N×N attention score matrix is never materialized, this approach reduces the extra memory needed by attention from O(N²) to O(N), where N is the sequence length, and cuts the number of HBM accesses. The practical benefits include up to 3× speedup over dense attention implementations and significant memory savings, enabling the processing of longer sequences with limited GPU resources.

A New Era for Recommendation Systems

Jagged Flash Attention and the open-source TorchRec implementation are a fantastic contribution to the recommendation system community. Together they address key challenges in handling variable-length categorical features and attention mechanisms, significantly improving performance in production systems and paving the way for further advances in the field.

Key implementation considerations for leveraging Jagged Flash Attention include:

  • Memory efficiency: Prioritize jagged tensor implementations over dense padded approaches to reduce memory overhead.
  • Computational optimization: Utilize custom Triton kernels for jagged tensor operations, achieving up to 2.52× speedup for matrix multiplications and 3.06× for softmax operations.
  • Scalability: Implement block-wise processing for large-scale operations, allowing for efficient handling of longer sequences and more complex model architectures.
  • GPU utilization: Leverage shared memory effectively and implement fused operations to maximize computational efficiency.
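To make the first consideration concrete, here is a minimal sketch of self-attention computed directly on a jagged layout. It assumes, for brevity, that queries, keys, and values are all the same rows (no learned projections), and the name `jagged_self_attention` is our own. It is a sequential reference for the idea, not the fused, block-wise Triton kernel from the paper.

```python
import math


def jagged_self_attention(rows, offsets):
    """Softmax attention within each segment of a jagged batch.

    rows    : d-dimensional vectors for all samples, stored contiguously
    offsets : offsets[i] .. offsets[i + 1] delimits sample i inside rows
    Assumption: Q = K = V = rows, to keep the sketch short.
    """
    out = []
    for start, end in zip(offsets, offsets[1:]):
        segment = rows[start:end]
        for q in segment:
            scale = 1.0 / math.sqrt(len(q))
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in segment]
            m = max(scores)
            weights = [math.exp(s - m) for s in scores]
            z = sum(weights)
            out.append([sum(w * v[c] for w, v in zip(weights, segment)) / z
                        for c in range(len(q))])
    return out


rows = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
offsets = [0, 2, 3]  # sample 0 has two rows, sample 1 has one
result = jagged_self_attention(rows, offsets)
print(result[2])  # [5.0, 5.0]
```

Each row attends only to rows in its own segment, so the lone row of the second sample is returned unchanged and no attention mask over padded positions is needed.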

The practical impact of these optimizations is substantial, with production models demonstrating a 10% improvement in Queries Per Second (QPS) and an 18% reduction in memory usage. The experiments were performed on recommendation system use-cases, but we could see this being useful for any use-case that combines batches of sparse, variable-length sequences with attention models.

At Shaped we use Jagged Tensors and the TorchRec library to power many of our PyTorch models. We're excited to start integrating the Flash Attention model and see what improvements we get across our customer base! 
