Skip to content

Commit f345865

Browse files
committed
[CUDA] PagedAttention: fix int32 overflow + heuristic max_query_len (tianleiwu review)
Two silent-corruption bugs flagged in the PR #28200 review: 1. Int32 overflow in GatherAndExpandPagedKVCache / LaunchGatherAndExpandPagedKVCache. `total_kv_tokens * num_heads * head_size` was computed in int; for realistic large- context GQA configs (e.g. 2M tokens * 64 heads * 128 head_size = 16.4B) this overflows INT32_MAX, producing a wrong element count, wrong block count, and wrong `tid` bound — silent corruption or OOB reads. Kernel now takes total_elems as int64_t and uses a grid-stride loop instead of a per-thread (tid >= total_elems) early-exit. Launcher computes total_elems in int64_t and caps the grid at kMaxBlocks = 65535 (grid-stride loop covers the rest). paged_idx, page_stride, and the outer stride are all int64_t so no intermediate multiplication overflows. 2. max_query_len heuristic (token_count - batch_size + 1) silently drops query tokens in the MEA path. CUTLASS MEA uses `p.sequence_length` directly as grid_x (ceil_div(sequence_length, kQueriesPerBlock)); missing blocks are never launched, so if any batch has 0 new query tokens the heuristic underestimates the actual max and query tokens from larger batches are silently unprocessed. Same issue affects the rotary grid. The FA path is unaffected — max_query_len is a hint there. Added an int max_query_len field to PagedAttentionData. paged_attention.cc now D->H syncs the full cumulative_seqlens_q (and cumulative_seqlens_kv) — both are batch_size+1 ints so the extra copy is cheap and avoids a second sync. The host computes max per-batch new-query length and propagates via data.max_query_len; EfficientAttention uses data.max_query_len instead of the heuristic.
1 parent 93aff52 commit f345865

3 files changed

Lines changed: 93 additions & 51 deletions

File tree

onnxruntime/contrib_ops/cuda/bert/attention_data.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,13 @@ struct PagedAttentionData {
233233
// Populated by the caller after a D->H sync on cumulative_seqlens_kv[batch_size].
234234
int total_kv_tokens = 0;
235235

236+
// Actual max of per-batch new-query lengths (cumulative_seqlens_q[i+1] - cumulative_seqlens_q[i]).
237+
// Populated by the caller via the same D->H sync so the MEA path's rotary grid and MEA's
238+
// grid_x (ceil_div(sequence_length, kQueriesPerBlock)) cover every query token. The previous
239+
// heuristic `token_count - batch_size + 1` underestimates when any batch has 0 new tokens,
240+
// producing silent dropping of query tokens in the MEA and rotary paths.
241+
int max_query_len = 0;
242+
236243
// Output Tensors
237244
T* output = nullptr;
238245

onnxruntime/contrib_ops/cuda/bert/paged_attention.cc

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,18 +221,38 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
221221
parameters.batch_size, cuda_stream));
222222

223223
int total_kv_tokens = 0;
224+
int max_query_len = 0;
224225
IAllocatorUniquePtr<void> gathered_key_buffer;
225226
IAllocatorUniquePtr<void> gathered_value_buffer;
226227
IAllocatorUniquePtr<void> fmha_buffer;
227228

228229
#if USE_MEMORY_EFFICIENT_ATTENTION
229230
if (use_memory_efficient_attention) {
230-
auto total_kv_pinned = this->AllocateBufferOnCPUPinned<int>(1);
231-
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(total_kv_pinned.get(),
232-
cumulative_seqlens_kv_ptr + parameters.batch_size,
233-
sizeof(int), cudaMemcpyDeviceToHost, cuda_stream));
231+
// MEA needs two host-side quantities:
232+
// - total_kv_tokens (= cumulative_seqlens_kv[batch_size]) to size tight gather buffers.
233+
// - max_query_len (= max per-batch new-query length) to size the rotary and MEA grids
234+
// correctly. The heuristic `token_count - batch_size + 1` underestimates when any
235+
// batch has 0 new tokens (valid input), silently dropping query tokens from those
236+
// larger-than-average batches.
237+
// Both come from cumulative_seqlens_q / cumulative_seqlens_kv, which are tiny (batch+1
238+
// ints each), so one D->H copy of the full arrays is cheaper than issuing an extra
239+
// reduction kernel and avoids a second sync.
240+
const int kCumulativeCount = parameters.batch_size + 1;
241+
auto cum_q_pinned = this->AllocateBufferOnCPUPinned<int>(kCumulativeCount);
242+
auto cum_kv_pinned = this->AllocateBufferOnCPUPinned<int>(kCumulativeCount);
243+
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cum_q_pinned.get(),
244+
reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int>()),
245+
sizeof(int) * kCumulativeCount, cudaMemcpyDeviceToHost, cuda_stream));
246+
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cum_kv_pinned.get(), cumulative_seqlens_kv_ptr,
247+
sizeof(int) * kCumulativeCount, cudaMemcpyDeviceToHost, cuda_stream));
234248
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
235-
total_kv_tokens = total_kv_pinned.get()[0];
249+
total_kv_tokens = cum_kv_pinned.get()[parameters.batch_size];
250+
for (int i = 0; i < parameters.batch_size; ++i) {
251+
const int q_len_i = cum_q_pinned.get()[i + 1] - cum_q_pinned.get()[i];
252+
if (q_len_i > max_query_len) {
253+
max_query_len = q_len_i;
254+
}
255+
}
236256
if (total_kv_tokens == 0) {
237257
// Legal empty-input case: token_count == 0 and all past_seqlens == 0 — nothing to do.
238258
// The paged key/value caches are alias-outputs already bound to the input caches
@@ -305,6 +325,7 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
305325
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
306326
}
307327
data.total_kv_tokens = total_kv_tokens;
328+
data.max_query_len = max_query_len;
308329
}
309330

310331
cublasHandle_t cublas = GetCublasHandle(context);

onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ Status LaunchReshapeAndCache(const T* key, const T* value, T* key_cache, T* valu
239239
}
240240

241241
// Gather paged KV into packed-varlen [total_kv_tokens, num_heads, head_size], expanding GQA heads.
242+
// total_elems = total_kv_tokens * num_heads * head_size can exceed INT32_MAX for realistic
243+
// large-context GQA configs (e.g., 2M tokens * 64 * 128 = 16.4B), so the linear index is int64_t
244+
// and the kernel uses a grid-stride loop instead of a single (tid >= total_elems) early-exit.
242245
template <typename T>
243246
__global__ void GatherAndExpandPagedKVCache(const T* __restrict__ key_cache,
244247
const T* __restrict__ value_cache,
@@ -252,52 +255,54 @@ __global__ void GatherAndExpandPagedKVCache(const T* __restrict__ key_cache,
252255
const int head_size,
253256
const int block_size,
254257
const int max_num_blocks_per_seq,
255-
const int total_kv_tokens) {
256-
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
257-
const int total_elems = total_kv_tokens * num_heads * head_size;
258-
if (tid >= total_elems) {
259-
return;
260-
}
261-
262-
const int h = tid % head_size;
263-
const int head_id = (tid / head_size) % num_heads;
264-
const int token_id = tid / (num_heads * head_size);
265-
266-
// cumulative_seqlens_kv is a prefix sum of non-negative per-batch KV lengths
267-
// (past_seqlens[i] + new_tokens[i]), so it is monotonically non-decreasing for
268-
// any valid op input — the same assumption the previous linear scan made.
269-
// Binary-search for the batch this token belongs to: log2(batch_size) is strictly
270-
// better than the linear scan, which ran once per (token, head, h) element and
271-
// multiplied its cost by num_heads * head_size.
272-
int left = 0;
273-
int right = batch_size;
274-
while (left < right) {
275-
const int mid = left + (right - left) / 2;
276-
if (token_id < cumulative_seqlens_kv[mid + 1]) {
277-
right = mid;
278-
} else {
279-
left = mid + 1;
258+
const int64_t total_elems) {
259+
const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
260+
const int64_t num_heads_times_head = static_cast<int64_t>(num_heads) * head_size;
261+
const int q_kv_head_ratio = num_heads / kv_num_heads;
262+
const int64_t page_stride = static_cast<int64_t>(block_size) * kv_num_heads * head_size;
263+
264+
for (int64_t tid = threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
265+
tid < total_elems;
266+
tid += stride) {
267+
const int h = static_cast<int>(tid % head_size);
268+
const int head_id = static_cast<int>((tid / head_size) % num_heads);
269+
const int token_id = static_cast<int>(tid / num_heads_times_head);
270+
271+
// cumulative_seqlens_kv is a prefix sum of non-negative per-batch KV lengths
272+
// (past_seqlens[i] + new_tokens[i]), so it is monotonically non-decreasing for
273+
// any valid op input — the same assumption the previous linear scan made.
274+
// Binary-search for the batch this token belongs to: log2(batch_size) is strictly
275+
// better than the linear scan, which ran once per (token, head, h) element and
276+
// multiplied its cost by num_heads * head_size.
277+
int left = 0;
278+
int right = batch_size;
279+
while (left < right) {
280+
const int mid = left + (right - left) / 2;
281+
if (token_id < cumulative_seqlens_kv[mid + 1]) {
282+
right = mid;
283+
} else {
284+
left = mid + 1;
285+
}
280286
}
281-
}
282-
const int batch_id = left;
287+
const int batch_id = left;
283288

284-
const int pos = token_id - cumulative_seqlens_kv[batch_id];
285-
const int block_idx_in_seq = pos / block_size;
286-
const int block_offset = pos % block_size;
287-
const int block_id = block_table[batch_id * max_num_blocks_per_seq + block_idx_in_seq];
289+
const int pos = token_id - cumulative_seqlens_kv[batch_id];
290+
const int block_idx_in_seq = pos / block_size;
291+
const int block_offset = pos % block_size;
292+
const int block_id = block_table[batch_id * max_num_blocks_per_seq + block_idx_in_seq];
288293

289-
// GQA expansion: each output head maps to kv_head_id = head_id / (num_heads / kv_num_heads).
290-
// For MHA (num_heads == kv_num_heads) this is the identity.
291-
const int q_kv_head_ratio = num_heads / kv_num_heads;
292-
const int kv_head_id = head_id / q_kv_head_ratio;
294+
// GQA expansion: each output head maps to kv_head_id = head_id / (num_heads / kv_num_heads).
295+
// For MHA (num_heads == kv_num_heads) this is the identity.
296+
const int kv_head_id = head_id / q_kv_head_ratio;
293297

294-
const int paged_idx = block_id * block_size * kv_num_heads * head_size +
295-
block_offset * kv_num_heads * head_size +
296-
kv_head_id * head_size +
297-
h;
298+
const int64_t paged_idx = static_cast<int64_t>(block_id) * page_stride +
299+
static_cast<int64_t>(block_offset) * kv_num_heads * head_size +
300+
kv_head_id * head_size +
301+
h;
298302

299-
gathered_key[tid] = key_cache[paged_idx];
300-
gathered_value[tid] = value_cache[paged_idx];
303+
gathered_key[tid] = key_cache[paged_idx];
304+
gathered_value[tid] = value_cache[paged_idx];
305+
}
301306
}
302307

303308
template <typename T>
@@ -309,17 +314,22 @@ Status LaunchGatherAndExpandPagedKVCache(const T* key_cache, const T* value_cach
309314
const int block_size, const int max_num_blocks_per_seq,
310315
const int total_kv_tokens, cudaStream_t stream,
311316
const int max_threads_per_block) {
312-
const int total_elems = total_kv_tokens * num_heads * head_size;
317+
const int64_t total_elems = static_cast<int64_t>(total_kv_tokens) * num_heads * head_size;
313318
if (total_elems == 0) {
314319
return Status::OK();
315320
}
316-
const int threads = std::min(total_elems, max_threads_per_block);
317-
const int blocks = (total_elems + threads - 1) / threads;
321+
// With the op's batch_size <= 256 precondition (paged_attention.cc) and MEA's
322+
// head_size <= 1024 cap, blocks_needed = ceil(total_elems / threads) stays comfortably
323+
// within int range for any realistic input, so no explicit clamp is needed. The kernel
324+
// uses a grid-stride loop so launching fewer blocks than total_elems / threads would
325+
// also be correct — we don't need an artificial "keep SMs busy" cap.
326+
const int threads = static_cast<int>(std::min<int64_t>(max_threads_per_block, total_elems));
327+
const int blocks = static_cast<int>((total_elems + threads - 1) / threads);
318328
GatherAndExpandPagedKVCache<T><<<blocks, threads, 0, stream>>>(
319329
key_cache, value_cache, gathered_key, gathered_value,
320330
block_table, cumulative_seqlens_kv,
321331
batch_size, num_heads, kv_num_heads, head_size,
322-
block_size, max_num_blocks_per_seq, total_kv_tokens);
332+
block_size, max_num_blocks_per_seq, total_elems);
323333
return CUDA_CALL(cudaGetLastError());
324334
}
325335

@@ -445,7 +455,11 @@ Status EfficientAttention(
445455
const int max_num_blocks_per_seq = parameters.max_num_blocks_per_seq;
446456
const int local_window_size = parameters.local_window_size;
447457
const int total_kv_tokens = data.total_kv_tokens;
448-
const int max_query_len = token_count - batch_size + 1;
458+
// Use the caller-computed actual max of per-batch new-query lengths, not the
459+
// `token_count - batch_size + 1` heuristic: the heuristic assumes >=1 new token per batch
460+
// and underestimates otherwise, which would silently drop query tokens from the
461+
// rotary grid and from MEA's `grid_x = ceil_div(sequence_length, kQueriesPerBlock)`.
462+
const int max_query_len = data.max_query_len;
449463

450464
T* query = const_cast<T*>(data.query);
451465
T* key;

0 commit comments

Comments (0)