[CUDA] PagedAttention: add SM<80 fp16 fallback via memory-efficient attention #28200

Open
elwhyjay wants to merge 10 commits into microsoft:main from elwhyjay:feature/paged-attention-mea-fallback

Conversation

@elwhyjay
Contributor

@elwhyjay elwhyjay commented Apr 23, 2026

Description

Adds a CUTLASS memory-efficient attention (MEA) fallback to the CUDA PagedAttention op, enabling the operator on sm<80 (Turing / Volta / Pascal) with fp16 for the first time. On sm>=80 the default FlashAttention path is unchanged; MEA is reachable via ORT_DISABLE_FLASH_ATTENTION=1 or the sdpa_kernel CUDA provider option for debugging and perf comparison.

| Environment | Before | After |
| --- | --- | --- |
| sm<80 + fp16 | ❌ error | ✅ MEA |
| sm<80 + bf16 | ❌ error | ❌ error (MEA requires sm>=80 for bf16) |
| sm>=80 + fp16/bf16 (default) | ✅ FA | ✅ FA (unchanged) |
| sm>=80 + ORT_DISABLE_FLASH_ATTENTION=1 / sdpa_kernel=EFFICIENT_ATTENTION | ❌ error | ✅ MEA |

Motivation and Context

The original PagedAttention PR (#24595) landed with the title "CUDA SM80 support" — the op errors out immediately whenever FlashAttention isn't available (sm<80 or USE_FLASH_ATTENTION=0 builds). During that review, @tianleiwu flagged that the interface was too FlashAttention-specific ("not good for other EP like WebGPU, CPU etc.") and @aciddelgado agreed the FA-specific dependencies could be lifted at the kernel level.

This PR closes that gap for sm<80 fp16 by mirroring the exact pattern established in #20012 ("Packed QKV and Rotary Embedding Support for sm<80 GQA"). The same CUTLASS memory-efficient attention backend that covers GQA's sm<80 path now covers PagedAttention.

Related work:

Implementation

Dispatch cascade in paged_attention.cc: FlashAttention preferred; fall back to MemoryEfficientAttention via has_memory_efficient_attention(sm, is_half, is_bf16, head_size, head_size). No custom head-size or dtype bounds hardcoded — MEA's own helper gates fp16 sm>=53 / bf16 sm>=80 / head_size <= 1024 and % 8 == 0. This keeps us forward-compatible with any future expansion of MEA's supported range.
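As a sketch, the cascade reduces to an ordered gate check. The Python analogue below is illustrative only: the helper names mirror the real gates, but the signatures are simplified and the SM thresholds are hardcoded from the description above.

```python
# Illustrative sketch of the backend dispatch; the real gates live in the
# ORT CUDA attention headers with richer signatures.
def has_flash_attention(sm: int) -> bool:
    return sm >= 80

def has_memory_efficient_attention(sm, is_half, is_bf16, head_size):
    if head_size > 1024 or head_size % 8 != 0:
        return False
    if is_bf16:
        return sm >= 80   # bf16 needs sm>=80 even on the MEA path
    if is_half:
        return sm >= 53   # fp16 supported from sm53 up
    return True

def select_backend(sm, is_half, is_bf16, head_size,
                   disable_flash=False, disable_mea=False):
    if not disable_flash and has_flash_attention(sm):
        return "flash"
    if not disable_mea and has_memory_efficient_attention(
            sm, is_half, is_bf16, head_size):
        return "efficient"
    return "error"  # both backends rejected: surface a descriptive error
```

This reproduces the table above: sm75 fp16 selects MEA, sm80 defaults to FA (MEA only with the flash gate disabled), and sm75 bf16 rejects both backends.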

MEA path (UnfusedAttention<T>):

  1. Reuses existing preprocessing: LaunchGetCumulativeSeqlensKV (hoisted to paged_attention.cc so both FA and MEA paths consume a pre-populated buffer — single-producer refactor), rotary, packed-QKV unpack, ReshapeAndCache.
  2. New GatherAndExpandPagedKVCache CUDA kernel walks block_table to gather paged K/V into a packed-varlen [total_kv_tokens, num_heads, head_size] buffer, folding in GQA head expansion (so downstream MEA sees num_heads uniformly).
  3. Dispatches to run_memory_efficient_attention in varlen mode with seqstart_q_ptr = cumulative_seqlens_q and seqstart_k_ptr = cumulative_seqlens_kv (and has_custom_right_padding = false). No padding required; the layout matches the kernel's expected [total_tokens, num_heads, head_size] with BSNH strides.
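A host-side Python analogue of the gather/expand step (2): each flat token position is mapped through block_table to its physical page, and GQA expansion repeats each KV head across its query-head group. The nested-list cache layout here is a simplification of the real flat device buffers; all names follow the description above, not the actual kernel signature.

```python
def gather_paged_kv(paged_cache, block_table, cumulative_seqlens_kv,
                    block_size, num_heads, num_kv_heads):
    """Gather paged KV into packed-varlen [total_kv_tokens][num_heads] rows,
    expanding each KV head across its GQA query-head group.
    paged_cache[block][slot][kv_head] holds one head vector (simplified)."""
    ratio = num_heads // num_kv_heads
    out = []
    batch_size = len(cumulative_seqlens_kv) - 1
    for b in range(batch_size):
        start = cumulative_seqlens_kv[b]
        for tok in range(start, cumulative_seqlens_kv[b + 1]):
            local = tok - start                          # position within sequence b
            block = block_table[b][local // block_size]  # logical page -> physical block
            slot = local % block_size
            out.append([paged_cache[block][slot][h // ratio]
                        for h in range(num_heads)])
    return out
```

Downstream, the MEA kernel then sees a uniform num_heads layout and never needs to know the cache was paged or grouped.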

Scratch allocation: the MEA path D->H syncs cumulative_seqlens_kv[batch_size] via a pinned buffer to obtain total_kv_tokens on the host for tight gathered_key / gathered_value / fmha_buffer allocation. This adds one cudaStreamSynchronize per forward call — acceptable for a compatibility fallback (FA remains the hot path on supported hardware). The no-sync alternative, over-allocation, would consume B × max_num_blocks_per_seq × block_size × num_heads × head_size × 2 × sizeof(T), which reaches GB scale for realistic GQA models and was rejected.

fmha_buffer is sized with sizeof(float) (matching the GQA EfficientAttention pattern at group_query_attention.cc:482) because MEA's output accumulator is fp32 regardless of input dtype.

Testing

New TestPagedAttentionMEA class in test_paged_attention_cuda.py runs the existing parity matrix (rotary on/off, rotary_interleaved on/off, packed-QKV on/off, local window on/off, softcap 0/50, varied head sizes/shapes) against the MEA path via the sdpa_kernel CUDA provider option set to EFFICIENT_ATTENTION (=2, from AttentionBackend enum). Using a per-session provider option instead of an env var means both FA and MEA test classes coexist in the same pytest process — each InferenceSession creates its own CUDA EP with its own attention_kernel_options_.

The existing TestPagedAttention class is skipped wholesale on sm<80 by its has_flash_attention() gate, so without the new MEA class the fallback path would have no CI coverage.

Local verification (NVIDIA A100 80GB, CUDA 12.8, GCC 13.3):

TestPagedAttention:       24/24 passed (~60s)   # FA baseline — no regression
TestPagedAttentionMEA:    24/24 passed (~59s)   # new MEA path

Tolerance: rtol = atol = 5e-3 against the same torch reference used by the FA parity test. All combinations match.

sm<80 hardware coverage: I don't have local Turing / Volta / Pascal hardware, so real-SM coverage relies on MS CI. The code path exercised on A100 via sdpa_kernel=EFFICIENT_ATTENTION is the same one taken on sm<80; only the underlying CUTLASS kernel (run_memory_efficient_attention_sm50/70/75/80) differs per SM, and those are upstream and unmodified by this change.

Build note: built with --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 CMAKE_CXX_STANDARD=20. The explicit C++20 define was needed because the initial configure resolved CMAKE_CXX_STANDARD=17, under which ort_version_check.h's consteval usage fails to compile. Unrelated to this change.

… fallback

Prepares the SM<80 / FlashAttention-disabled fallback path for PagedAttention.
Adds a gather kernel that reads paged KV cache through block_table into a
packed-varlen layout [total_kv_tokens, num_heads, head_size], with GQA head
expansion folded in. The layout matches CUTLASS memory-efficient attention's
varlen contract (seqstart_q_ptr / seqstart_k_ptr).

Also adds gathered_key / gathered_value / fmha_buffer / use_memory_efficient_attention
fields to PagedAttentionData. No dispatch changes yet; follow-up commit wires
this into QkvToContext and paged_attention.cc.
Adds UnfusedAttention<T> that dispatches through CUTLASS memory-efficient
attention when FlashAttention is unavailable (SM<80 or ORT_DISABLE_FLASH_ATTENTION=1).
Mirrors the FlashAttention preprocessing (rotary, unpack, ReshapeAndCache),
then gathers the paged KV cache into a packed-varlen buffer via the new
GatherAndExpandPagedKVCache kernel and calls run_memory_efficient_attention
using the seqstart_q / seqstart_k varlen ABI.

QkvToContext now dispatches to UnfusedAttention when data.use_memory_efficient_attention
is set. total_kv_tokens added to PagedAttentionData, populated by the caller
after a D->H sync on cumulative_seqlens_kv[batch_size].

Follow-up commit (paged_attention.cc) wires the dispatch, sync, and scratch
buffer allocation. After that commit the op will actually exercise this path.

Mirrors the sm<80 GQA pattern established in PR microsoft#20012.
…tion path

In the prior Phase 2 commit, UnfusedAttention consumed data.cumulative_seqlens_kv
(for the gather kernel's seqstart_k_ptr and for the MemoryEfficientAttentionParams)
but never launched LaunchGetCumulativeSeqlensKV. That launcher lived only inside
FlashAttention(), so on the MEA dispatch path the buffer was uninitialized
memory -> silent corruption.

Mirror FlashAttention by calling LaunchGetCumulativeSeqlensKV at the top of
UnfusedAttention. Phase 3 will lift this out of both FA and MEA paths into
paged_attention.cc (the host also needs total_kv_tokens, read from the last
element, to size the tight gather scratch buffers before calling QkvToContext).
At that point the per-path calls become redundant and are removed in a single
symmetric edit.

Keeps each commit on the branch individually buildable and correct: bisect-safe.
Completes the SM<80 fp16 fallback path by wiring the dispatch in
paged_attention.cc and allocating the MEA-specific scratch buffers.

- Add disable_memory_efficient_attention_ member, initialized from
  kernel_options_->UseEfficientAttention() (honors ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION).
- Select backend: FA preferred; fall back to MEA via has_memory_efficient_attention()
  which correctly encodes fp16 sm>=53 / bf16 sm>=80 / head_size<=1024 & %8==0.
  No custom bounds hardcoded; MEA's own helper decides.
- Lift LaunchGetCumulativeSeqlensKV out of both FlashAttention() and
  UnfusedAttention() into paged_attention.cc. Both consumers now see a
  pre-populated buffer; paged_attention.cc is the single producer.
- For the MEA path only, D->H sync cumulative_seqlens_kv[batch_size] via
  pinned buffer to obtain total_kv_tokens on host, then tight-allocate
  gathered_key/value ([total_kv_tokens, num_heads, head_size]) and the
  conditional fmha_buffer. The fmha_buffer uses sizeof(float) not sizeof(T)
  because MEA's output accumulator is fp32 regardless of input dtype
  (matches GQA pattern at group_query_attention.cc:482).
- Surface a descriptive error message when both backends reject the
  configuration (e.g. SM<80 + bf16), naming the relevant gates and env vars
  so users can self-diagnose.

Scope:
- Opens: SM<80 fp16 PagedAttention (primary target); SM>=80 fp16/bf16 with
  ORT_DISABLE_FLASH_ATTENTION=1 (debug).
- Out of scope: SM<80 + bf16 (MEA structural limitation), custom paged kernel,
  quantized KV cache, FA3.

Mirrors the sm<80 GQA fallback pattern established in PR microsoft#20012.
Adds a parallel TestPagedAttentionMEA class that runs the full parity matrix
(rotary, packed_qkv, local_window, softcap, varied head sizes and shapes)
against the CUTLASS memory-efficient attention path. Without this, SM<80 CI
runs skip the existing TestPagedAttention class entirely (has_flash_attention
gate) and the fallback would never be exercised.

Mechanism: pass the CUDA provider option sdpa_kernel=2 (AttentionBackend::
EFFICIENT_ATTENTION bit) to force the kernel to select MEA even on SM>=80,
where FlashAttention would otherwise take the traffic. Per-session provider
options mean FA and MEA test classes can coexist in the same pytest process
(each InferenceSession creates its own CUDA EP instance with its own
attention_kernel_options_ member).

- paged_attention_func / parity_check_paged_attention gain sdpa_kernel param,
  defaulted to 0 (no override) so the existing TestPagedAttention class is
  byte-for-byte equivalent.
- has_memory_efficient_attention() skip gate covers sm>=53 (fp16-only tests).
  bf16 tests would require sm>=80 but are out of scope for this file.

Follows the benchmark_mha.py precedent (SdpaKernel.EFFICIENT_ATTENTION = 2).
…eader

Phase 3 hoisted LaunchGetCumulativeSeqlensKV out of FlashAttention() /
UnfusedAttention() into paged_attention.cc, which now calls the launcher
from a different TU than its definition. The declaration was missing from
paged_attention_impl.h, causing the Release build to fail with

  error: 'LaunchGetCumulativeSeqlensKV' was not declared in this scope

at paged_attention.cc:203 during template instantiation.

Signature matches the definition in paged_attention_impl.cu (non-templated,
int32_t pointers, cudaStream_t). No behavioral change — fixes linkage only.
In paged_attention.cc the local `ort_stream` returned by GetOrtStream(context)
is an OrtStreamAdapter (value type), not a pointer. The previous
`ort_stream->GetHandle()` was a copy-paste from impl.cu where the equivalent
variable is a `Stream*` parameter. The value-type access fails to compile:

  error: base operand of '->' has non-pointer type
  'onnxruntime::cuda::OrtStreamAdapter'

Use `ort_stream.get()->GetHandle()` — OrtStreamAdapter::get() returns a
Stream*, and Stream::GetHandle() returns the underlying cudaStream_t.

No behavioral change. Fix-up for f1338c7 (Phase 3 dispatch).
Contributor

Copilot AI left a comment


Pull request overview

This PR extends the CUDA contrib PagedAttention operator to work on SM<80 with fp16 by adding a CUTLASS MemoryEfficientAttention (MEA) fallback when FlashAttention is unavailable or explicitly disabled, and adds Python test coverage that forces the MEA path via the CUDA EP sdpa_kernel provider option.

Changes:

  • Add a Flash → MEA dispatch cascade in PagedAttention (CUDA) and plumb new MEA-specific scratch buffers into PagedAttentionData.
  • Implement a CUDA kernel to gather paged KV-cache into a packed-varlen buffer suitable for CUTLASS fMHA, then call run_memory_efficient_attention.
  • Add a new Python test class that forces MEA (sdpa_kernel=EFFICIENT_ATTENTION) to cover the fallback path.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| onnxruntime/contrib_ops/cuda/bert/paged_attention.cc | Select FlashAttention when supported, otherwise fall back to MEA; allocate/prepare shared and MEA-only scratch buffers. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention.h | Add disable_memory_efficient_attention_ member for backend gating. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu | Add KV gather/expand kernel and MEA runner path; remove FA-internal cumulative KV seqlens producer. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h | Expose LaunchGetCumulativeSeqlensKV so the caller can populate KV cumulative seqlens for both backends. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Extend PagedAttentionData with gathered KV buffers, fMHA workspace, total_kv_tokens, and a MEA enable flag. |
| onnxruntime/test/python/transformers/test_paged_attention_cuda.py | Add sdpa_kernel plumbing and a new MEA-forced parity test class. |


Comment thread onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
Comment thread onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
Comment thread onnxruntime/contrib_ops/cuda/bert/paged_attention.cc Outdated
@elwhyjay
Contributor Author

Hi @tianleiwu. Checking Copilot's review — all four comments are fair. Addressing them in a single follow-up commit.

1. Gather kernel: linear scan → binary search. Good catch on multiplying the scan cost by num_heads * head_size. The binary search preserves the same monotonicity assumption the original linear scan already relied on: cumulative_seqlens_kv is a prefix sum of non-negative per-batch KV lengths (past_seqlens[i] + new_tokens[i]), so it is monotonically non-decreasing for any valid op input. The previous if (token_id < cumulative_seqlens_kv[i + 1]) break loop would also return wrong results on non-monotonic input, so no new precondition is introduced. Making this explicit in a comment above the search.

2. Rename UnfusedAttentionEfficientAttention. Agreed — the old name clashed with the math-based "unfused" kernel concept used elsewhere in the attention code. Renaming the function, the QkvToContext dispatch site, and the fallthrough error message.

3. batch_size <= 256 precondition. Adding an explicit guard in paged_attention.cc immediately before LaunchGetCumulativeSeqlensKV with a clear error. Keeping this as a precondition rather than reworking to cub::DeviceScan to keep this PR scoped; the scan can be upgraded in a follow-up if a higher batch_size is needed in practice.

4. total_kv_tokens == 0 early return. Adding a Status::OK() path for the legal empty input (token_count == 0 && total_kv_tokens == 0). Keeping a distinct negative-value error branch as a defensive check.
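For reference, the batch-id lookup in (1) amounts to finding the first b with token_id < cumulative_seqlens_kv[b + 1]. A host-side sketch of the device binary search, alongside the original linear scan for comparison (names follow the discussion above; the real code runs per-element in the CUDA kernel):

```python
def find_batch_id(token_id, cumulative_seqlens_kv):
    """Binary search: first b with token_id < cumulative_seqlens_kv[b + 1].
    Valid only because the array is a prefix sum (monotone non-decreasing)."""
    lo, hi = 0, len(cumulative_seqlens_kv) - 2   # batch ids in [0, batch_size)
    while lo < hi:
        mid = (lo + hi) // 2
        if token_id < cumulative_seqlens_kv[mid + 1]:
            hi = mid
        else:
            lo = mid + 1
    return lo

def find_batch_id_linear(token_id, cumulative_seqlens_kv):
    """The original per-element linear scan, for comparison."""
    for b in range(len(cumulative_seqlens_kv) - 1):
        if token_id < cumulative_seqlens_kv[b + 1]:
            return b
    return len(cumulative_seqlens_kv) - 2
```

Both agree on every monotonic input, including zero-length batches (e.g. [0, 2, 2, 5] assigns token 2 to batch 2, skipping the empty batch 1).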

Pushing shortly once the local rebuild and parity tests are green.

- GatherAndExpandPagedKVCache: replace per-element linear scan over batch_size
  with a binary search on cumulative_seqlens_kv. The scan previously ran once
  per (token, head, h) element, multiplying its cost by num_heads * head_size;
  log2(batch_size) is strictly better. The binary search preserves the same
  monotonicity assumption the original scan relied on (prefix sum of
  non-negative per-batch KV lengths) — no new precondition is introduced.
  This is also documented in a comment above the search.
- Rename UnfusedAttention -> EfficientAttention (and update the QkvToContext
  dispatch site and the fallthrough error message). The function dispatches
  to CUTLASS MemoryEfficientAttention; the previous name collided with the
  math-based "unfused" kernel concept used elsewhere in the attention code.
- paged_attention.cc: add an explicit batch_size <= 256 precondition before
  LaunchGetCumulativeSeqlensKV with a descriptive error. The launcher uses
  per-block cub::BlockScan with a 256-thread block, so multi-block launches
  scan independently and produce wrong cumulative sums. A grid-wide scan
  (e.g. cub::DeviceScan) could lift this in a follow-up; for now the guard
  prevents silent corruption.
- EfficientAttention path: return Status::OK() for the legal empty input
  (token_count == 0 && total_kv_tokens == 0) rather than erroring out.
  Kept a distinct negative-value error branch as a defensive check.

Tested on A100 with CMAKE_CUDA_ARCHITECTURES=80 CMAKE_CXX_STANDARD=20:
TestPagedAttention 24/24 + TestPagedAttentionMEA 24/24 green.
Contributor

@tianleiwu tianleiwu left a comment


REQUEST CHANGES — Two correctness concerns, otherwise well-structured fallback implementation.

Clean extension of PagedAttention to SM<80 hardware. The dispatch cascade (FA preferred → MEA fallback) mirrors the established GQA pattern from #20012, and the test coverage via sdpa_kernel is good. The binary search for batch_id in the gather kernel and the explicit batch_size ≤ 256 guard are nice touches.

Two issues to fix before merge:

  1. Integer overflow in gather kernel: `int total_elems` overflows for realistic configurations within the batch_size=256 limit.
  2. max_query_len grid-sizing mismatch — the formula used is safe for FlashAttention (which doesn't use it for grid sizing) but CUTLASS MEA does use p.sequence_length for grid X dimension, so underestimation silently drops query tokens.

Comment thread onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu Outdated
…tianleiwu review)

Two silent-corruption bugs flagged in the PR microsoft#28200 review:

1. Int32 overflow in GatherAndExpandPagedKVCache / LaunchGatherAndExpandPagedKVCache.
   `total_kv_tokens * num_heads * head_size` was computed in int; for realistic large-
   context GQA configs (e.g. 2M tokens * 64 heads * 128 head_size = 16.4B) this
   overflows INT32_MAX, producing a wrong element count, wrong block count, and wrong
   `tid` bound — silent corruption or OOB reads.

   Kernel now takes total_elems as int64_t and uses a grid-stride loop instead of
   a per-thread (tid >= total_elems) early-exit. Launcher computes total_elems in
   int64_t and caps the grid at kMaxBlocks = 65535 (grid-stride loop covers the rest).
   paged_idx, page_stride, and the outer stride are all int64_t so no intermediate
   multiplication overflows.

2. max_query_len heuristic (token_count - batch_size + 1) silently drops query
   tokens in the MEA path.
   CUTLASS MEA uses `p.sequence_length` directly as grid_x (ceil_div(sequence_length,
   kQueriesPerBlock)); missing blocks are never launched, so if any batch has 0 new
   query tokens the heuristic underestimates the actual max and query tokens from
   larger batches are silently unprocessed. Same issue affects the rotary grid.
   The FA path is unaffected — max_query_len is a hint there.

   Added an int max_query_len field to PagedAttentionData. paged_attention.cc now
   D->H syncs the full cumulative_seqlens_q (and cumulative_seqlens_kv) — both are
   batch_size+1 ints so the extra copy is cheap and avoids a second sync. The host
   computes max per-batch new-query length and propagates via data.max_query_len;
   EfficientAttention uses data.max_query_len instead of the heuristic.
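Both fixes are easy to reproduce on the host. This sketch mimics C `int` truncation with ctypes, simulates grid-stride coverage, and contrasts the old max_query_len heuristic with the exact host-side computation; all names and dimensions are illustrative, not the actual code.

```python
import ctypes

def wrap_int32(x):
    """Truncate exact arithmetic to C int32 semantics: what the old
    `int total_elems` computation produced."""
    return ctypes.c_int32(x & 0xFFFFFFFF).value

def grid_stride_indices(total_elems, num_blocks, threads_per_block):
    """Element indices covered by a simulated grid-stride loop: the fixed
    kernel caps the grid and each thread strides over the remainder."""
    stride = num_blocks * threads_per_block
    covered = set()
    for tid in range(stride):
        i = tid
        while i < total_elems:
            covered.add(i)
            i += stride
    return covered

def max_query_len_exact(cumulative_seqlens_q):
    """Exact max per-batch new-query length from the D->H copy of
    cumulative_seqlens_q (batch_size + 1 ints)."""
    return max(cumulative_seqlens_q[b + 1] - cumulative_seqlens_q[b]
               for b in range(len(cumulative_seqlens_q) - 1))

def max_query_len_heuristic(token_count, batch_size):
    """The old heuristic: tight only when every batch has >= 1 new token."""
    return token_count - batch_size + 1
```

For the review's example, 2_000_000 * 64 * 128 = 16_384_000_000 exceeds INT32_MAX and truncates to a negative count, while a capped grid with striding still covers every element. And for per-batch query lengths [8, 0, 0], the heuristic yields 8 - 3 + 1 = 6 while the true max is 8, so MEA's grid X (ceil_div of sequence_length) would omit the last query blocks.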
@elwhyjay
Contributor Author

Thanks @tianleiwu — both of these were my mistakes, good catches.
Pushing the fix now.

tianleiwu previously approved these changes Apr 24, 2026
Contributor

@tianleiwu tianleiwu left a comment


APPROVE — Both correctness issues from round 1 are cleanly fixed.

Integer overflow (concern #1): total_elems, paged_idx, and the grid-stride arithmetic are now int64_t. The kernel uses a grid-stride loop that naturally handles element counts beyond INT32_MAX. Pre-computed loop-invariant values (num_heads_times_head, page_stride, q_kv_head_ratio) are a nice optimization touch.

max_query_len grid sizing (concern #2): paged_attention.cc now D→H copies the full cumulative_seqlens_q array and computes the actual per-batch max on the host, passing it via data.max_query_len. EfficientAttention uses this exact value for both p.sequence_length (MEA grid X dimension) and the rotary grid, eliminating the underestimation risk.

Also resolved 4 stale threads from the earlier automated review (all addressed in current head).

@elwhyjay
Contributor Author

Thanks for the quick approval! @tianleiwu

I noticed one tiny adjacent empty-query edge case during a final pass: token_count == 0 with non-zero past_seqlens can still enter backend preprocessing. I have a small early-return fix locally. Should I push it here, or keep the approved head unchanged and handle it separately?

@tianleiwu
Contributor

tianleiwu commented Apr 24, 2026

@elwhyjay, it is fine to push it here.

One suggestion (pre-existing, not introduced here): the FA path at paged_attention_impl.cu:361 still uses the token_count - batch_size + 1 heuristic for max_query_len, which is then passed to LaunchRotaryEmbeddingKernel as grid.x. Same silent-drop failure mode as the MEA bug this PR fixed — could reuse data.max_query_len in a follow-up. Not a blocker.

… == 0)

Found while verifying the MEA-path edge cases from the round-2 review:
token_count == 0 with non-zero past_seqlens would still enter backend
preprocessing — FA path's LaunchReshapeAndCache hits total_size = 0,
threads = min(0, max_threads) = 0, then blocks = (0 + 0 - 1) / 0
(division by zero). MEA path would also mis-report "total_kv_tokens
is zero for non-empty input" even though token_count == 0 is the
non-empty coordinate.

Move the empty-query check right after the cache-aliasing verification
(output is already [0, hidden_size] and the cache outputs alias the
inputs, so no backend work is needed). This protects both backends
with a single guard and removes the now-redundant nested check inside
the MEA block.
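A Python caricature of the failure mode and the guard; the launcher arithmetic mirrors the description above, and all names and the max_threads value are illustrative, not the real code.

```python
def launch_blocks(total_size, max_threads=256):
    """Grid sizing as described: threads = min(total, max), blocks = ceil-div.
    total_size == 0 makes threads 0 and the division blow up."""
    threads = min(total_size, max_threads)
    return (total_size + threads - 1) // threads  # ZeroDivisionError if total_size == 0

def run_paged_attention(token_count, total_size):
    # The fix: early return right after shape checks. Output is already
    # [0, hidden_size] and the cache outputs alias the inputs, so no
    # backend work is needed for an empty query.
    if token_count == 0:
        return "empty"
    launch_blocks(total_size)
    return "ran"

def unguarded_crashes(total_size):
    """Show what happens without the guard."""
    try:
        launch_blocks(total_size)
        return False
    except ZeroDivisionError:
        return True
```

A single guard placed before backend dispatch protects both the FA and MEA paths at once, which is why it lives in paged_attention.cc rather than inside either backend.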
@elwhyjay
Contributor Author

Thanks @tianleiwu! Pushing the empty-query early-return fix now.

Good call on the FA rotary grid — it's the same silent-drop pattern as the MEA case. I'll handle it as a follow-up PR so we don't drag additional scope here; lifting max_query_len into common code would add a forward-per-call D→H sync to the FA path, which feels worth its own perf discussion.
