Commit 7375578

[CUDA] PagedAttention: early-return on empty query input (token_count == 0)
Found while verifying the MEA-path edge cases from the round-2 review: token_count == 0 with non-zero past_seqlens would still enter backend preprocessing.

On the FA path, LaunchReshapeAndCache gets total_size = 0, so threads = min(0, max_threads) = 0 and blocks = (0 + 0 - 1) / 0, which is a division by zero. The MEA path would also mis-report "total_kv_tokens is zero for non-empty input" even though token_count == 0 means the query input is empty.

Move the empty-query check to right after the cache-aliasing verification. The output is already shaped [0, hidden_size] and the cache outputs alias the inputs, so no backend work is needed. This protects both backends with a single guard and removes the now-redundant nested check inside the MEA block.
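The launch-geometry arithmetic described above can be reproduced in isolation. The following is a minimal sketch, not the actual LaunchReshapeAndCache code: `LaunchDims` and `ComputeLaunchDims` are hypothetical names, and the function assumes the usual ceiling-division pattern for a 1-D launch. It returns an empty optional when there is no work, so the `(0 + 0 - 1) / 0` case is never evaluated. The commit itself takes a different route and returns early in ComputeInternal before any backend code runs.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for a 1-D kernel launch configuration.
struct LaunchDims {
  int64_t blocks;
  int64_t threads;
};

// Ceiling-division launch geometry (assumed pattern, not copied from ONNX Runtime).
// With total_size == 0 the unguarded formula would compute
// threads = min(0, max_threads) = 0 and then divide by zero,
// so the empty case is rejected before any division happens.
std::optional<LaunchDims> ComputeLaunchDims(int64_t total_size, int64_t max_threads) {
  if (total_size <= 0 || max_threads <= 0) {
    return std::nullopt;  // nothing to launch
  }
  const int64_t threads = std::min(total_size, max_threads);
  const int64_t blocks = (total_size + threads - 1) / threads;  // ceil(total_size / threads)
  return LaunchDims{blocks, threads};
}
```

A caller would skip the kernel launch whenever the returned optional is empty.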
1 parent f345865 commit 7375578

1 file changed

Lines changed: 6 additions & 6 deletions

onnxruntime/contrib_ops/cuda/bert/paged_attention.cc

@@ -143,6 +143,12 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
                       "value_cache and value_cache_out must be the same buffer");
   }
 
+  // Empty query input: output is already shaped [0, hidden_size], and the cache outputs
+  // alias the input caches (verified above), so no backend kernel or cache update is needed.
+  if (parameters.token_count == 0) {
+    return Status::OK();
+  }
+
   // Kernel backend selection — FlashAttention preferred, fall back to MemoryEfficientAttention.
 #if USE_FLASH_ATTENTION
   bool use_flash_attention = !disable_flash_attention_ &&
@@ -254,12 +260,6 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
       }
     }
     if (total_kv_tokens == 0) {
-      // Legal empty-input case: token_count == 0 and all past_seqlens == 0 — nothing to do.
-      // The paged key/value caches are alias-outputs already bound to the input caches
-      // (verified above), and the op's output is [0, hidden_size]; no kernel launches needed.
-      if (parameters.token_count == 0) {
-        return Status::OK();
-      }
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                              "PagedAttention MEA fallback: total_kv_tokens is zero for non-empty input.");
     }
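
The control flow after this change can be sketched as a single guard placed ahead of backend dispatch. This is a simplified model under stated assumptions, not the real kernel code: `Backend`, `Params`, `Result`, and `Compute` are hypothetical names, and the `launched` flag only stands in for "backend preprocessing and kernels would run".

```cpp
#include <string>

// Hypothetical model of the two backends named in the diff.
enum class Backend { kFlash, kMemoryEfficient };

// Hypothetical subset of the operator parameters.
struct Params {
  int token_count;
  int total_kv_tokens;
};

// Hypothetical result: status, error text, and whether backend work would start.
struct Result {
  bool ok;
  std::string message;
  bool launched;
};

Result Compute(const Params& p, Backend backend) {
  // Single guard shared by both backends: an empty query needs no backend work.
  if (p.token_count == 0) {
    return {true, "", false};
  }
  // Past the guard, token_count > 0, so this error now only fires for non-empty input.
  if (backend == Backend::kMemoryEfficient && p.total_kv_tokens == 0) {
    return {false, "total_kv_tokens is zero for non-empty input", false};
  }
  return {true, "", true};
}
```

With the guard first, `Compute({0, 5}, Backend::kFlash)` succeeds without starting backend work, which corresponds to the token_count == 0, non-zero past_seqlens case from the commit message.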
