Batch Tiling on Attention (BTA) tiles attention along the batch dimension to reduce the memory pressure of attention, whose activation footprint is quadratic in sequence length. This frees memory for larger batch sizes, which improve MoE layer performance.

Prerequisites

Make sure you have read Trainer Overview and Trainer Configuration Overview, which provide a basic overview of how to run Model Zoo models. This document uses the tools and configurations outlined in those pages.

How It Works

MoE layers need larger batches to increase arithmetic intensity, but batch size is limited by a memory bottleneck that comes primarily from the attention layer. BTA helps by tiling attention, which reduces its effective working batch size. At smaller sequence lengths, MoE cycles generally dominate end-to-end runtime, so increasing the batch size improves both MoE and overall end-to-end performance.
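The tiling idea above can be sketched in a few lines of NumPy. This is an illustrative model of the technique, not the actual compiler implementation: `attention` and `tiled_attention` are hypothetical names, and the real on-device scheduling is handled by the stack.

```python
import numpy as np

def attention(q, k, v):
    # Standard scaled dot-product attention. The (seq, seq) score
    # matrix per sample is what makes memory quadratic in sequence length.
    d = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def tiled_attention(q, k, v, tile):
    # Process the batch in tiles of size `tile`: only one tile's
    # (tile, seq, seq) score tensor is live at a time, so peak
    # attention memory drops by roughly a factor of batch / tile.
    batch = q.shape[0]
    assert batch % tile == 0, "tile must be a factor of the batch size"
    out = [attention(q[i:i + tile], k[i:i + tile], v[i:i + tile])
           for i in range(0, batch, tile)]
    return np.concatenate(out, axis=0)
```

The tiled result is numerically identical to the untiled one; only the peak size of the intermediate score tensor changes.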

Performance Impact

Without BTA, MoE models experience significant throughput degradation as sparsity increases:
Configuration               Throughput degradation (without BTA)
128 experts                 Up to 53% slower (2x slowdown)
Low top_k (high sparsity)   Up to 86% slower (7x slowdown)
With BTA enabled, throughput remains stable and comparable to dense models across all configurations.
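The percentage figures and slowdown factors in the table are two views of the same measurement: a throughput degradation of p corresponds to a runtime increase of 1 / (1 - p). A quick sanity check:

```python
def slowdown_factor(throughput_degradation):
    # A fractional throughput drop p means runtime grows by 1 / (1 - p),
    # e.g. a 0.5 (50%) drop doubles the runtime.
    return 1.0 / (1.0 - throughput_degradation)

print(round(slowdown_factor(0.53), 1))  # 2.1 -> ~2x for 128 experts
print(round(slowdown_factor(0.86), 1))  # 7.1 -> ~7x for low top_k
```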

Parameters

  • ws_opt_enable_bta: Set to true to enable Batch Tiling on Attention.
  • ws_opt_bta_max_tile: Optional. Caps the maximum tile size. By default, BTA automatically selects a tile size based on model dimensions. If the automatic tiling is too aggressive or not aggressive enough, set this to a positive integer to cap the tile size. A smaller value reduces per-tile memory usage in attention, allowing for a larger overall batch size to improve MoE performance. The value must be a factor of your batch size.
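Since ws_opt_bta_max_tile must be a factor of the batch size, it can help to enumerate the valid settings before editing the config. The helper below is illustrative only, not part of any Cerebras API:

```python
def valid_tile_sizes(batch_size):
    # Valid ws_opt_bta_max_tile values are exactly the factors of the
    # batch size, so the batch divides evenly into tiles.
    return [t for t in range(1, batch_size + 1) if batch_size % t == 0]

print(valid_tile_sizes(12))  # [1, 2, 3, 4, 6, 12]
```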

Enabling Batch Tiling on Attention

To enable BTA, set the following in your Trainer configuration:
YAML
trainer:
  init:
    callbacks:
      - GlobalFlags:
          csx.debug.ini:
            ws_opt_enable_bta: true
            ws_opt_bta_max_tile: 4  # Optional. Must be a factor of your batch size

Key Considerations

  • Memory Trade-off: BTA reduces peak memory usage but may introduce additional compute overhead due to tiling
  • Model Compatibility: BTA is supported for transformer-based models with standard attention mechanisms
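As a back-of-the-envelope for the memory trade-off above: the attention score tensor holds roughly batch x heads x seq² elements without tiling, and only tile x heads x seq² with it. The estimate below is a rough sketch with assumed dimensions; the real on-device footprint depends on the compiler's scheduling.

```python
def attention_score_bytes(batch, heads, seq_len, bytes_per_elem=2, tile=None):
    # Rough size of the live attention-score tensor. With batch tiling,
    # only `tile` samples' scores are alive at once.
    live_batch = tile if tile is not None else batch
    return live_batch * heads * seq_len ** 2 * bytes_per_elem

full = attention_score_bytes(16, 32, 4096)          # untiled
tiled = attention_score_bytes(16, 32, 4096, tile=4) # with BTA, tile=4
print(full // tiled)  # 4x reduction in peak score memory
```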

Further Reading

To learn more about optimizing your training workflows, see: