[CUDA] PagedAttention: add SM<80 fp16 fallback via memory-efficient attention#28200
elwhyjay wants to merge 10 commits into microsoft:main
Conversation
… fallback Prepares the SM<80 / FlashAttention-disabled fallback path for PagedAttention. Adds a gather kernel that reads paged KV cache through block_table into a packed-varlen layout [total_kv_tokens, num_heads, head_size], with GQA head expansion folded in. The layout matches CUTLASS memory-efficient attention's varlen contract (seqstart_q_ptr / seqstart_k_ptr). Also adds gathered_key / gathered_value / fmha_buffer / use_memory_efficient_attention fields to PagedAttentionData. No dispatch changes yet; follow-up commit wires this into QkvToContext and paged_attention.cc.
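For reference, a host-side sketch of the gather's index math (the function name, parameter order, and the paged-cache layout `[num_blocks, block_size, num_kv_heads, head_size]` are assumptions for illustration, not the actual kernel code): each element `(token, head, h)` of the packed-varlen output is resolved through `block_table` to an offset in the paged cache, with the GQA head expansion folded into the `kv_head` computation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side model of the gather kernel's addressing.
// Assumed paged KV cache layout: [num_blocks, block_size, num_kv_heads, head_size].
// Output layout (from the commit message): [total_kv_tokens, num_heads, head_size].
int64_t PagedSourceOffset(int64_t token, int64_t head, int64_t h,
                          const std::vector<int32_t>& block_table_row,
                          int64_t seq_start,  // first packed token of this sequence
                          int64_t block_size, int64_t num_kv_heads,
                          int64_t num_heads, int64_t head_size) {
  const int64_t pos_in_seq = token - seq_start;              // logical KV position
  const int64_t block = block_table_row[pos_in_seq / block_size];
  const int64_t slot = pos_in_seq % block_size;
  const int64_t kv_head = head / (num_heads / num_kv_heads);  // GQA head expansion fold
  return ((block * block_size + slot) * num_kv_heads + kv_head) * head_size + h;
}
```

The packed output index `(token * num_heads + head) * head_size + h` then simply copies from this source offset, so downstream MEA sees a dense `num_heads`-wide layout.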
Adds UnfusedAttention<T> that dispatches through CUTLASS memory-efficient attention when FlashAttention is unavailable (SM<80 or ORT_DISABLE_FLASH_ATTENTION=1). Mirrors the FlashAttention preprocessing (rotary, unpack, ReshapeAndCache), then gathers the paged KV cache into a packed-varlen buffer via the new GatherAndExpandPagedKVCache kernel and calls run_memory_efficient_attention using the seqstart_q / seqstart_k varlen ABI. QkvToContext now dispatches to UnfusedAttention when data.use_memory_efficient_attention is set. total_kv_tokens added to PagedAttentionData, populated by the caller after a D->H sync on cumulative_seqlens_kv[batch_size]. Follow-up commit (paged_attention.cc) wires the dispatch, sync, and scratch buffer allocation. After that commit the op will actually exercise this path. Mirrors the sm<80 GQA pattern established in PR microsoft#20012.
…tion path In the prior Phase 2 commit, UnfusedAttention consumed data.cumulative_seqlens_kv (for the gather kernel's seqstart_k_ptr and for the MemoryEfficientAttentionParams) but never launched LaunchGetCumulativeSeqlensKV. That launcher lived only inside FlashAttention(), so on the MEA dispatch path the buffer was uninitialized memory -> silent corruption. Mirror FlashAttention by calling LaunchGetCumulativeSeqlensKV at the top of UnfusedAttention. Phase 3 will lift this out of both FA and MEA paths into paged_attention.cc (the host also needs total_kv_tokens, read from the last element, to size the tight gather scratch buffers before calling QkvToContext). At that point the per-path calls become redundant and are removed in a single symmetric edit. Keeps each commit on the branch individually buildable and correct: bisect-safe.
Completes the SM<80 fp16 fallback path by wiring the dispatch in paged_attention.cc and allocating the MEA-specific scratch buffers.

- Add disable_memory_efficient_attention_ member, initialized from kernel_options_->UseEfficientAttention() (honors ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION).
- Select backend: FA preferred; fall back to MEA via has_memory_efficient_attention(), which correctly encodes fp16 sm>=53 / bf16 sm>=80 / head_size<=1024 & %8==0. No custom bounds hardcoded; MEA's own helper decides.
- Lift LaunchGetCumulativeSeqlensKV out of both FlashAttention() and UnfusedAttention() into paged_attention.cc. Both consumers now see a pre-populated buffer; paged_attention.cc is the single producer.
- For the MEA path only, D->H sync cumulative_seqlens_kv[batch_size] via a pinned buffer to obtain total_kv_tokens on the host, then tight-allocate gathered_key/value ([total_kv_tokens, num_heads, head_size]) and the conditional fmha_buffer. The fmha_buffer uses sizeof(float), not sizeof(T), because MEA's output accumulator is fp32 regardless of input dtype (matches the GQA pattern at group_query_attention.cc:482).
- Surface a descriptive error message when both backends reject the configuration (e.g. SM<80 + bf16), naming the relevant gates and env vars so users can self-diagnose.

Scope:
- Opens: SM<80 fp16 PagedAttention (primary target); SM>=80 fp16/bf16 with ORT_DISABLE_FLASH_ATTENTION=1 (debug).
- Out of scope: SM<80 + bf16 (MEA structural limitation), custom paged kernel, quantized KV cache, FA3.

Mirrors the sm<80 GQA fallback pattern established in PR microsoft#20012.
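A sizing sketch of the scratch-buffer math described above (struct and function names are hypothetical; the shapes come from the commit message): `gathered_key`/`gathered_value` are tight-allocated in the input dtype `T` from the host-synced `total_kv_tokens`, while `fmha_buffer` is sized with `sizeof(float)` because MEA accumulates its output in fp32.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sizing helper modelling the MEA scratch allocation.
// Assumed shapes: gathered K/V are [total_kv_tokens, num_heads, head_size] in
// dtype T; fmha_buffer covers the query tokens in fp32 (MEA's accumulator
// dtype, independent of the input dtype).
struct MeaScratchBytes {
  int64_t gathered_kv_each;  // one of gathered_key / gathered_value
  int64_t fmha_buffer;       // fp32 output accumulator
};

MeaScratchBytes ComputeMeaScratchBytes(int64_t total_kv_tokens, int64_t token_count,
                                       int64_t num_heads, int64_t head_size,
                                       int64_t sizeof_t) {
  MeaScratchBytes s;
  s.gathered_kv_each = total_kv_tokens * num_heads * head_size * sizeof_t;
  s.fmha_buffer =
      token_count * num_heads * head_size * static_cast<int64_t>(sizeof(float));
  return s;
}
```

Sizing in `int64_t` from the start also sidesteps the 32-bit element-count overflow flagged later in this review thread.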
Adds a parallel TestPagedAttentionMEA class that runs the full parity matrix (rotary, packed_qkv, local_window, softcap, varied head sizes and shapes) against the CUTLASS memory-efficient attention path. Without this, SM<80 CI runs skip the existing TestPagedAttention class entirely (has_flash_attention gate) and the fallback would never be exercised. Mechanism: pass the CUDA provider option sdpa_kernel=2 (AttentionBackend:: EFFICIENT_ATTENTION bit) to force the kernel to select MEA even on SM>=80, where FlashAttention would otherwise take the traffic. Per-session provider options means FA and MEA test classes can coexist in the same pytest process (each InferenceSession creates its own CUDA EP instance with its own attention_kernel_options_ member). - paged_attention_func / parity_check_paged_attention gain sdpa_kernel param, defaulted to 0 (no override) so the existing TestPagedAttention class is byte-for-byte equivalent. - has_memory_efficient_attention() skip gate covers sm>=53 (fp16-only tests). bf16 tests would require sm>=80 but are out of scope for this file. Follows the benchmark_mha.py precedent (SdpaKernel.EFFICIENT_ATTENTION = 2).
…eader Phase 3 hoisted LaunchGetCumulativeSeqlensKV out of FlashAttention() / UnfusedAttention() into paged_attention.cc, which now calls the launcher from a different TU than its definition. The declaration was missing from paged_attention_impl.h, causing the Release build to fail with error: 'LaunchGetCumulativeSeqlensKV' was not declared in this scope at paged_attention.cc:203 during template instantiation. Signature matches the definition in paged_attention_impl.cu (non-templated, int32_t pointers, cudaStream_t). No behavioral change — fixes linkage only.
In paged_attention.cc the local `ort_stream` returned by GetOrtStream(context) is an OrtStreamAdapter (value type), not a pointer. The previous `ort_stream->GetHandle()` was a copy-paste from impl.cu where the equivalent variable is a `Stream*` parameter. The value-type access fails to compile: error: base operand of '->' has non-pointer type 'onnxruntime::cuda::OrtStreamAdapter' Use `ort_stream.get()->GetHandle()` — OrtStreamAdapter::get() returns a Stream*, and Stream::GetHandle() returns the underlying cudaStream_t. No behavioral change. Fix-up for f1338c7 (Phase 3 dispatch).
Pull request overview
This PR extends the CUDA contrib PagedAttention operator to work on SM<80 with fp16 by adding a CUTLASS MemoryEfficientAttention (MEA) fallback when FlashAttention is unavailable or explicitly disabled, and adds Python test coverage that forces the MEA path via the CUDA EP sdpa_kernel provider option.
Changes:
- Add a Flash → MEA dispatch cascade in `PagedAttention` (CUDA) and plumb new MEA-specific scratch buffers into `PagedAttentionData`.
- Implement a CUDA kernel to gather paged KV-cache into a packed-varlen buffer suitable for CUTLASS fMHA, then call `run_memory_efficient_attention`.
- Add a new Python test class that forces MEA (`sdpa_kernel=EFFICIENT_ATTENTION`) to cover the fallback path.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/paged_attention.cc | Select FlashAttention when supported, otherwise fall back to MEA; allocate/prepare shared and MEA-only scratch buffers. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention.h | Add disable_memory_efficient_attention_ member for backend gating. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu | Add KV gather/expand kernel and MEA runner path; remove FA-internal cumulative KV seqlens producer. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h | Expose LaunchGetCumulativeSeqlensKV so the caller can populate KV cumulative seqlens for both backends. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Extend PagedAttentionData with gathered KV buffers, fMHA workspace, total_kv_tokens, and a MEA enable flag. |
| onnxruntime/test/python/transformers/test_paged_attention_cuda.py | Add sdpa_kernel plumbing and a new MEA-forced parity test class. |
Hi @tianleiwu. Checking Copilot's review — all four comments are fair. Addressing them in a single follow-up commit: 1. gather kernel linear scan → binary search (good catch on the per-element scan cost multiplier); 2. the rename; 3–4. the remaining two, in the same commit. Pushing shortly once the local rebuild and parity tests are green.
- GatherAndExpandPagedKVCache: replace the per-element linear scan over batch_size with a binary search on cumulative_seqlens_kv. The scan previously ran once per (token, head, h) element, multiplying its cost by num_heads * head_size; log2(batch_size) is strictly better. The binary search preserves the same monotonicity assumption the original scan relied on (prefix sum of non-negative per-batch KV lengths) — no new precondition is introduced. This is also documented in a comment above the search.
- Rename UnfusedAttention -> EfficientAttention (and update the QkvToContext dispatch site and the fallthrough error message). The function dispatches to CUTLASS MemoryEfficientAttention; the previous name collided with the math-based "unfused" kernel concept used elsewhere in the attention code.
- paged_attention.cc: add an explicit batch_size <= 256 precondition before LaunchGetCumulativeSeqlensKV with a descriptive error. The launcher uses per-block cub::BlockScan with a 256-thread block, so multi-block launches scan independently and produce wrong cumulative sums. A grid-wide scan (e.g. cub::DeviceScan) could lift this in a follow-up; for now the guard prevents silent corruption.
- EfficientAttention path: return Status::OK() for the legal empty input (token_count == 0 && total_kv_tokens == 0) rather than erroring out. Kept a distinct negative-value error branch as a defensive check.

Tested on A100 with CMAKE_CUDA_ARCHITECTURES=80 CMAKE_CXX_STANDARD=20: TestPagedAttention 24/24 + TestPagedAttentionMEA 24/24 green.
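A minimal host-side sketch of the batch-id binary search described above (the function name is hypothetical; the contract is the one the commit states): find `b` such that `cum[b] <= token < cum[b+1]`, where `cum` is the monotone prefix sum `cumulative_seqlens_kv` with `cum[0] == 0`.

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the per-element batch lookup. Relies only on the monotonicity of
// the prefix sum (non-negative per-batch KV lengths), the same assumption the
// original linear scan used. Empty batches (cum[b] == cum[b+1]) are skipped
// naturally: no token index falls in an empty half-open interval.
int FindBatchId(const int32_t* cum, int batch_size, int32_t token) {
  int lo = 0, hi = batch_size - 1;
  while (lo < hi) {
    int mid = (lo + hi) / 2;
    if (cum[mid + 1] <= token)
      lo = mid + 1;  // token belongs to a later batch
    else
      hi = mid;      // token is at or before batch mid
  }
  return lo;
}
```

With cum = {0, 5, 5, 9} (batch 1 empty), tokens 0–4 resolve to batch 0 and tokens 5–8 to batch 2, matching the half-open interval contract.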
tianleiwu
left a comment
REQUEST CHANGES — Two correctness concerns, otherwise well-structured fallback implementation.
Clean extension of PagedAttention to SM<80 hardware. The dispatch cascade (FA preferred → MEA fallback) mirrors the established GQA pattern from #20012, and the test coverage via sdpa_kernel is good. The binary search for batch_id in the gather kernel and the explicit batch_size ≤ 256 guard are nice touches.
Two issues to fix before merge:
- Integer overflow in gather kernel — `int total_elems` overflows for realistic configurations within the batch_size=256 limit.
- `max_query_len` grid-sizing mismatch — the formula used is safe for FlashAttention (which doesn't use it for grid sizing), but CUTLASS MEA does use `p.sequence_length` for the grid X dimension, so underestimation silently drops query tokens.
…tianleiwu review) Two silent-corruption bugs flagged in the PR microsoft#28200 review:

1. Int32 overflow in GatherAndExpandPagedKVCache / LaunchGatherAndExpandPagedKVCache. `total_kv_tokens * num_heads * head_size` was computed in int; for realistic large-context GQA configs (e.g. 2M tokens * 64 heads * 128 head_size = 16.4B) this overflows INT32_MAX, producing a wrong element count, wrong block count, and wrong `tid` bound — silent corruption or OOB reads. The kernel now takes total_elems as int64_t and uses a grid-stride loop instead of a per-thread (tid >= total_elems) early-exit. The launcher computes total_elems in int64_t and caps the grid at kMaxBlocks = 65535 (the grid-stride loop covers the rest). paged_idx, page_stride, and the outer stride are all int64_t so no intermediate multiplication overflows.

2. The max_query_len heuristic (token_count - batch_size + 1) silently drops query tokens in the MEA path. CUTLASS MEA uses `p.sequence_length` directly as grid_x (ceil_div(sequence_length, kQueriesPerBlock)); missing blocks are never launched, so if any batch has 0 new query tokens the heuristic underestimates the actual max and query tokens from larger batches are silently unprocessed. The same issue affects the rotary grid. The FA path is unaffected — max_query_len is a hint there.

Added an int max_query_len field to PagedAttentionData. paged_attention.cc now D->H syncs the full cumulative_seqlens_q (and cumulative_seqlens_kv) — both are batch_size+1 ints so the extra copy is cheap and avoids a second sync. The host computes the max per-batch new-query length and propagates it via data.max_query_len; EfficientAttention uses data.max_query_len instead of the heuristic.
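A host-side sketch of the overflow fix in the launcher math (constants from the commit; the helper name is illustrative): `total_elems` is computed in `int64_t`, and the launch grid is capped at `kMaxBlocks` because a grid-stride loop inside the kernel covers whatever the capped grid does not.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

constexpr int64_t kMaxBlocks = 65535;       // grid cap from the commit message
constexpr int64_t kThreadsPerBlock = 256;   // assumed block size for illustration

// Launch-grid block count after the fix. The old code computed
// total_kv_tokens * num_heads * head_size in 32-bit int, so e.g.
// 2'000'000 * 64 * 128 (~16.4e9) wrapped past INT32_MAX.
int64_t GridBlocks(int64_t total_kv_tokens, int64_t num_heads, int64_t head_size) {
  const int64_t total_elems = total_kv_tokens * num_heads * head_size;  // no wrap
  return std::min(kMaxBlocks,
                  (total_elems + kThreadsPerBlock - 1) / kThreadsPerBlock);
}
```

Inside the kernel, each thread then iterates `for (int64_t i = tid; i < total_elems; i += gridDim.x * blockDim.x)` rather than early-exiting on `tid >= total_elems`, so the capped grid still touches every element exactly once.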
Thanks @tianleiwu — both of these were my mistakes, good catches.
tianleiwu
left a comment
APPROVE — Both correctness issues from round 1 are cleanly fixed.
Integer overflow (concern #1): total_elems, paged_idx, and the grid-stride arithmetic are now int64_t. The kernel uses a grid-stride loop that naturally handles element counts beyond INT32_MAX. Pre-computed loop-invariant values (num_heads_times_head, page_stride, q_kv_head_ratio) are a nice optimization touch.
max_query_len grid sizing (concern #2): paged_attention.cc now D→H copies the full cumulative_seqlens_q array and computes the actual per-batch max on the host, passing it via data.max_query_len. EfficientAttention uses this exact value for both p.sequence_length (MEA grid X dimension) and the rotary grid, eliminating the underestimation risk.
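The host-side max computation described above can be sketched as follows (function name assumed; the input is the D→H-copied `cumulative_seqlens_q`, `batch_size + 1` ints): the exact per-batch maximum replaces the `token_count - batch_size + 1` heuristic, which underestimates whenever some batch entries have zero new query tokens.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Exact max new-query length per batch from the cumulative query seqlens.
// Each batch's length is the difference of adjacent prefix-sum entries.
int32_t MaxQueryLen(const std::vector<int32_t>& cum_seqlens_q) {
  int32_t max_len = 0;
  for (size_t b = 0; b + 1 < cum_seqlens_q.size(); ++b)
    max_len = std::max(max_len, cum_seqlens_q[b + 1] - cum_seqlens_q[b]);
  return max_len;
}
```

Example: with cumulative seqlens {0, 9, 9, 10} (batch 1 empty), the exact max is 9, while the old heuristic gives token_count(10) - batch_size(3) + 1 = 8 — an underestimate that would shrink MEA's grid X and drop query tokens.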
Also resolved 4 stale threads from the earlier automated review (all addressed in current head).
Thanks for the quick approval! @tianleiwu I noticed one tiny adjacent empty-query edge case during a final pass:
@elwhyjay, it is fine to push it here. One suggestion (pre-existing, not introduced here): the FA path at paged_attention_impl.cu:361 still uses the token_count - batch_size + 1 heuristic for max_query_len, which is then passed to LaunchRotaryEmbeddingKernel as grid.x. Same silent-drop failure mode as the MEA bug this PR fixed — could reuse data.max_query_len in a follow-up. Not a blocker.
… == 0) Found while verifying the MEA-path edge cases from the round-2 review: token_count == 0 with non-zero past_seqlens would still enter backend preprocessing — the FA path's LaunchReshapeAndCache hits total_size = 0, threads = min(0, max_threads) = 0, then blocks = (0 + 0 - 1) / 0 (division by zero). The MEA path would also mis-report "total_kv_tokens is zero for non-empty input" even though token_count == 0 means the input is in fact empty on the query axis. Move the empty-query check right after the cache-aliasing verification (the output is already [0, hidden_size] and the cache outputs alias the inputs, so no backend work is needed). This protects both backends with a single guard and removes the now-redundant nested check inside the MEA block.
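A sketch of the launch math that makes the guard placement necessary (assumed from the commit's description of LaunchReshapeAndCache; the helper name is illustrative): with `token_count == 0`, `total_size` is 0, so `threads` becomes 0 and the ceil-div block count divides by zero unless the empty case returns first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Block count for an assumed element-per-thread launch. Without the guard,
// total_size == 0 gives threads == 0 and (0 + 0 - 1) / 0 — division by zero.
int64_t ReshapeAndCacheBlocks(int64_t total_size, int64_t max_threads) {
  if (total_size == 0) return 0;  // empty-query early return: no backend work
  const int64_t threads = std::min(total_size, max_threads);
  return (total_size + threads - 1) / threads;
}
```

Hoisting the equivalent check to right after the cache-aliasing verification, as the commit does, protects both the FA and MEA paths with a single guard.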
Thanks @tianleiwu! Pushing the empty-query early-return fix now. Good call on the FA rotary grid — it's the same silent-drop pattern as the MEA case. I'll handle it as a follow-up PR so we don't drag additional scope here; lifting the FA path onto data.max_query_len belongs in that change.
Description
Adds a CUTLASS memory-efficient attention (MEA) fallback to the CUDA PagedAttention op, enabling the operator on sm<80 (Turing / Volta / Pascal) with fp16 for the first time. On sm>=80 the default FlashAttention path is unchanged; MEA is reachable via
`ORT_DISABLE_FLASH_ATTENTION=1` or the `sdpa_kernel` CUDA provider option for debugging and perf comparison.

Motivation and Context
The original PagedAttention PR (#24595) landed with the title "CUDA SM80 support" — the op errors out immediately whenever FlashAttention isn't available (sm<80 or
`USE_FLASH_ATTENTION=0` builds). During that review, @tianleiwu flagged that the interface was too FlashAttention-specific ("not good for other EP like WebGPU, CPU etc.") and @aciddelgado agreed the FA-specific dependencies could be lifted at the kernel level.

This PR closes that gap for sm<80 fp16 by mirroring the exact pattern established in #20012 ("Packed QKV and Rotary Embedding Support for sm<80 GQA"). The same CUTLASS memory-efficient attention backend that covers GQA's sm<80 path now covers PagedAttention.
Related work:
Implementation
Dispatch cascade in `paged_attention.cc`: FlashAttention preferred; fall back to MemoryEfficientAttention via `has_memory_efficient_attention(sm, is_half, is_bf16, head_size, head_size)`. No custom head-size or dtype bounds hardcoded — MEA's own helper gates fp16 sm>=53 / bf16 sm>=80 / head_size <= 1024 and `% 8 == 0`. This keeps us forward-compatible with any future expansion of MEA's supported range.

MEA path (`UnfusedAttention<T>`):
- Preprocessing: `LaunchGetCumulativeSeqlensKV` (hoisted to `paged_attention.cc` so both FA and MEA paths consume a pre-populated buffer — single-producer refactor), rotary, packed-QKV unpack, `ReshapeAndCache`.
- Gather: the `GatherAndExpandPagedKVCache` CUDA kernel walks `block_table` to gather paged K/V into a packed-varlen `[total_kv_tokens, num_heads, head_size]` buffer, folding in GQA head expansion (so downstream MEA sees `num_heads` uniformly).
- Attention: `run_memory_efficient_attention` in varlen mode via `seqstart_q_ptr = cumulative_seqlens_q` + `seqstart_k_ptr = cumulative_seqlens_kv` (and `has_custom_right_padding = false`). No padding required; the layout matches the kernel's expected `[total_tokens, num_heads, head_size]` with BSNH strides.

Scratch allocation: the MEA path D->H syncs `cumulative_seqlens_kv[batch_size]` via a pinned buffer to obtain `total_kv_tokens` on the host for tight `gathered_key` / `gathered_value` / `fmha_buffer` allocation. This adds one `cudaStreamSynchronize` per forward call — acceptable for a compatibility fallback (FA remains the hot path on supported hardware). Over-allocation (the no-sync alternative) would consume `B × max_num_blocks_per_seq × block_size × num_heads × head_size × 2 × sizeof(T)`, which reaches GB-scale for realistic GQA models and was rejected. `fmha_buffer` is sized with `sizeof(float)` (matching the GQA EfficientAttention pattern at `group_query_attention.cc:482`) because MEA's output accumulator is fp32 regardless of input dtype.
New `TestPagedAttentionMEA` class in `test_paged_attention_cuda.py` runs the existing parity matrix (rotary on/off, rotary_interleaved on/off, packed-QKV on/off, local window on/off, softcap 0/50, varied head sizes/shapes) against the MEA path via the `sdpa_kernel` CUDA provider option set to `EFFICIENT_ATTENTION` (=2, from the `AttentionBackend` enum). Using a per-session provider option instead of an env var means both FA and MEA test classes coexist in the same pytest process — each InferenceSession creates its own CUDA EP with its own `attention_kernel_options_`.

The existing `TestPagedAttention` class is skipped wholesale on sm<80 by its `has_flash_attention()` gate, so without the new MEA class the fallback path would have no CI coverage.

Local verification (NVIDIA A100 80GB, CUDA 12.8, GCC 13.3):
Tolerance: `rtol = atol = 5e-3` against the same torch reference used by the FA parity test. All combinations match.

sm<80 hardware coverage: I don't have local Turing / Volta / Pascal hardware, so real-SM coverage relies on MS CI. The code path exercised on A100 via `sdpa_kernel=EFFICIENT_ATTENTION` is the same one taken on sm<80; only the underlying CUTLASS kernel (`run_memory_efficient_attention_sm50/70/75/80`) differs per SM, and those are upstream and unmodified by this change.
--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 CMAKE_CXX_STANDARD=20. The explicit C++20 define was needed because the initial configure resolvedCMAKE_CXX_STANDARD=17, under whichort_version_check.h'sconstevalusage fails to compile. Unrelated to this change.