Two silent-corruption bugs flagged in the PR #28200 review:
1. Int32 overflow in GatherAndExpandPagedKVCache / LaunchGatherAndExpandPagedKVCache.
`total_kv_tokens * num_heads * head_size` was computed in 32-bit int; for realistic
large-context GQA configs (e.g. 2M tokens * 64 heads * 128 head_size = 16.4B
elements) the product exceeds INT32_MAX and wraps, yielding a wrong element count,
wrong block count, and wrong `tid` bound: silent corruption or OOB reads.
Kernel now takes total_elems as int64_t and uses a grid-stride loop instead of
a per-thread (tid >= total_elems) early-exit. Launcher computes total_elems in
int64_t and caps the grid at kMaxBlocks = 65535 (grid-stride loop covers the rest).
paged_idx, page_stride, and the outer stride are all int64_t so no intermediate
multiplication overflows.
2. max_query_len heuristic (token_count - batch_size + 1) silently drops query
tokens in the MEA path.
CUTLASS MEA uses `p.sequence_length` directly to size grid_x
(ceil_div(sequence_length, kQueriesPerBlock)); blocks beyond that are never
launched. The heuristic assumes every batch contributes at least one new query
token, so when any batch has 0 new query tokens it underestimates the true max,
and query tokens in longer batches are silently left unprocessed. The same issue
affects the rotary grid.
The FA path is unaffected — max_query_len is a hint there.
Added an int max_query_len field to PagedAttentionData. paged_attention.cc now
D->H syncs the full cumulative_seqlens_q (and cumulative_seqlens_kv) — both are
batch_size+1 ints so the extra copy is cheap and avoids a second sync. The host
computes max per-batch new-query length and propagates via data.max_query_len;
EfficientAttention uses data.max_query_len instead of the heuristic.