Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,11 @@
"**/cmake/external/**": true,
"**/node_modules/**": true,
"**/.git/**": true
},

// CMake generator for ORT builds. Change to "Visual Studio 18 2026" if using VS 2026.
"ort.cmakeGenerator": "Visual Studio 17 2022",
"chat.tools.terminal.autoApprove": {
".\\build.bat": true
}
}
7 changes: 4 additions & 3 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ if (NOT CMAKE_C_STANDARD)
set(CMAKE_C_STANDARD 99)
endif()

if (NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 20)
endif()
# C++20 is required. Set unconditionally to override any cached values from dependencies
# (e.g., the 'date' library caches CMAKE_CXX_STANDARD=17, which would prevent the
# conditional set from taking effect on subsequent configures).
set(CMAKE_CXX_STANDARD 20)

# We don't use C++20 modules yet.
# There are some known issues to address first:
Expand Down
646 changes: 646 additions & 0 deletions docs/turbo_quant_plan.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ struct WebgpuAttentionParameters {
int* zero_ptr_ = nullptr;
// Computed values
int n_reps = 1;
bool turbo_quant_ = false;
int compressed_head_size_ = 0; // head_size/4 + 4 when turbo_quant (vec4-aligned compressed dim)
AttentionMaskType mask_type_ = MASK_NONE;
AttentionQkvFormat qkv_format_ = UNKNOWN;
};
Expand Down
147 changes: 134 additions & 13 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Large diffs are not rendered by default.

17 changes: 11 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
bool is_nvidia,
bool q_BNSH,
bool use_seqlen_k = false,
bool has_head_sink = false)
bool has_head_sink = false,
bool turbo_quant = false)
: Program{kernel_name},
has_attention_bias_(has_attention_bias),
is_qualcomm_(is_qualcomm),
Expand All @@ -89,7 +90,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
is_nvidia_(is_nvidia),
q_BNSH_(q_BNSH),
use_seqlen_k_(use_seqlen_k),
has_head_sink_(has_head_sink) {
has_head_sink_(has_head_sink),
turbo_quant_(turbo_quant) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -115,13 +117,14 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
bool q_BNSH_;
bool use_seqlen_k_;
bool has_head_sink_;
bool turbo_quant_;
};

class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
public:
FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch, bool turbo_quant = false)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch), turbo_quant_(turbo_quant) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -141,12 +144,13 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecode
bool has_attention_bias_;
uint32_t tile_size_;
bool use_indirect_dispatch_;
bool turbo_quant_;
};

class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
public:
FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false)
: Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) {
FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false, bool turbo_quant = false)
: Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), turbo_quant_(turbo_quant) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -164,6 +168,7 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDe
int head_size_vec_;
bool use_indirect_dispatch_;
bool has_head_sink_;
bool turbo_quant_;
};

class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
Expand Down
83 changes: 83 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#param q_BNSH
#param qkv_head_size
#param qkv_num_heads
#param turbo_quant
#param use_seqlen_k

const head_size : u32 = qkv_head_size;
Expand Down Expand Up @@ -42,6 +43,25 @@ const head_size_vec : u32 = head_size / vec_factor;
var<workgroup> k_tile : array<array<q_value_t, head_size_vec>, max_k_step>;
var<workgroup> v_tile : array<array<q_value_t, head_size_vec>, max_k_step>;

#if turbo_quant
#include "bert/turbo_quant_common.wgsl.template"

// Compressed KV cache: int4 indices packed as u32 pairs in fp16 carrier.
// compressed_head_size_vec = number of vec4s per token in compressed buffer.
// Layout per token (fp16 element units): elements [0,1] carry an f32 norm
// (bitcast across the two 16-bit halves); packed int4 payload starts at
// element index 2.  head_size/16 vec4s of payload + 1 vec4 for the norm/header.
const compressed_head_size_vec : u32 = head_size / 16u + 1u;

// Extract f32 norm from the x,y components of the first vec4 in a compressed KV token.
// The norm is the per-token dequantization scale consumed by tq_dequant_packed_half.
fn tq_extract_norm(first_vec: q_value_t) -> f32 {
#if is_fp16
  // Reassemble the f32 from its two f16-bit-pattern halves (x = low 16 bits,
  // y = high 16 bits) via a single 32-bit bitcast.
  return bitcast<f32>(vec2<f16>(q_element_t(first_vec.x), q_element_t(first_vec.y)));
#else
  // NOTE(review): this ORs the full 32-bit IEEE patterns of two f32 components
  // (hi shifted by only 16), so the halves overlap unless the writer stores
  // values whose f32 bit patterns occupy disjoint low/high 16-bit ranges.
  // Verify against the quantizer's f32-carrier encoding — TODO confirm; if the
  // carrier stores raw u16 payloads as integer-valued floats, u32(x) | (u32(y) << 16u)
  // would be the expected reconstruction instead.
  let lo = bitcast<u32>(f32(first_vec.x));
  let hi = bitcast<u32>(f32(first_vec.y));
  return bitcast<f32>(lo | (hi << 16u));
#endif
}
#endif

// Private memory per lane.
var<private> q_tile : array<q_value_t, head_size_vec>;
fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_t) {
Expand All @@ -63,24 +83,87 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_

// Load k_step tokens of K (starting at k_start) into the workgroup k_tile,
// cooperatively across the workgroup; slots beyond the total sequence length
// are zero-filled so subsequent dot products contribute nothing.
fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) {
  // Stored as float16[batch_size,num_heads,present_sequence_length,96]
#if turbo_quant
  // Compressed path: present_key holds compressed_head_size_vec vec4s per
  // token (norm header + packed int4 indices); dequantize while loading.
  let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * compressed_head_size_vec +
               k_start * compressed_head_size_vec;
  for (var slot : u32 = 0; slot < k_step; slot++) {
    if (k_start + slot < get_total_sequence_length()) {
      let token_offset = offset + slot * compressed_head_size_vec;
      // Per-token dequant scale from the first vec4's x,y header halves.
      let norm = tq_extract_norm(present_key[token_offset]);
      // Cooperative unpack: each thread handles a subset of output vec4s
      for (var idx : u32 = local_idx; idx < head_size_vec; idx += workgroup_size_x) {
        // Map output vec4 `idx` to its packed bits: each u32 (two fp16
        // carriers) holds eight int4s = two output vec4s; half_sel picks
        // which half of the u32 this vec4 decodes from.
        let element_start = idx * 4u;
        let u32_group = element_start / 8u;
        let half_sel = (element_start / 4u) % 2u;
        // +2u skips the two fp16 header elements holding the norm.
        let compressed_fp16_idx = 2u + u32_group * 2u;
        let cv_idx = compressed_fp16_idx / 4u;                     // which vec4 in the token
        let cv_component_pair = (compressed_fp16_idx % 4u) / 2u;   // 0 -> (x,y), 1 -> (z,w)
        let cv = present_key[token_offset + cv_idx];
        var packed_u32: u32;
        // NOTE(review): vec2<f16> bit-reassembly assumes an fp16 carrier;
        // in a non-fp16 build this value-converts instead of bit-preserving —
        // confirm turbo_quant is fp16-only.
        if (cv_component_pair == 0u) {
          packed_u32 = bitcast<u32>(vec2<f16>(q_element_t(cv.x), q_element_t(cv.y)));
        } else {
          packed_u32 = bitcast<u32>(vec2<f16>(q_element_t(cv.z), q_element_t(cv.w)));
        }
        k_tile[slot][idx] = q_value_t(tq_dequant_packed_half(packed_u32, half_sel, norm));
      }
    } else {
      // Out-of-range token: zero the whole tile row.
      for (var idx : u32 = local_idx; idx < head_size_vec; idx += workgroup_size_x) {
        k_tile[slot][idx] = q_value_t(0);
      }
    }
  }
#else
  // Uncompressed path: contiguous copy of head_size_vec vec4s per token,
  // flattened so each thread strides over (slot, component) pairs.
  let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
               k_start * head_size_vec;
  for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
    let slot = u32(idx / head_size_vec);
    let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length());
    k_tile[slot][idx % head_size_vec] = val;
  }
#endif
}

// Load v_step tokens of V (starting at v_start) into the workgroup v_tile.
// Mirrors loadk: cooperative dequantizing load in the turbo_quant path,
// zero-fill for slots past the total sequence length.
fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) {
  // Stored as float16[batch_size,num_heads,present_sequence_length,96]
#if turbo_quant
  // Compressed path: compressed_head_size_vec vec4s per token
  // (norm header + packed int4 indices).
  let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * compressed_head_size_vec +
               v_start * compressed_head_size_vec;
  for (var slot : u32 = 0; slot < v_step; slot++) {
    if (v_start + slot < get_total_sequence_length()) {
      let token_offset = offset + slot * compressed_head_size_vec;
      // Per-token dequant scale from the first vec4's x,y header halves.
      let norm = tq_extract_norm(present_value[token_offset]);
      // Cooperative unpack: each thread handles a subset of output vec4s.
      for (var idx : u32 = local_idx; idx < head_size_vec; idx += workgroup_size_x) {
        // Same index mapping as loadk: one u32 carries two output vec4s;
        // half_sel selects which half decodes this vec4.
        let element_start = idx * 4u;
        let u32_group = element_start / 8u;
        let half_sel = (element_start / 4u) % 2u;
        // +2u skips the two fp16 header elements holding the norm.
        let compressed_fp16_idx = 2u + u32_group * 2u;
        let cv_idx = compressed_fp16_idx / 4u;                     // which vec4 in the token
        let cv_component_pair = (compressed_fp16_idx % 4u) / 2u;   // 0 -> (x,y), 1 -> (z,w)
        let cv = present_value[token_offset + cv_idx];
        var packed_u32: u32;
        // NOTE(review): bit-preserving only for an fp16 carrier — see loadk.
        if (cv_component_pair == 0u) {
          packed_u32 = bitcast<u32>(vec2<f16>(q_element_t(cv.x), q_element_t(cv.y)));
        } else {
          packed_u32 = bitcast<u32>(vec2<f16>(q_element_t(cv.z), q_element_t(cv.w)));
        }
        v_tile[slot][idx] = q_value_t(tq_dequant_packed_half(packed_u32, half_sel, norm));
      }
    } else {
      // Out-of-range token: zero the whole tile row.
      for (var idx : u32 = local_idx; idx < head_size_vec; idx += workgroup_size_x) {
        v_tile[slot][idx] = q_value_t(0);
      }
    }
  }
#else
  // Uncompressed path: contiguous copy, flattened over (slot, component).
  let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
               v_start * head_size_vec;
  for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) {
    let slot = u32(idx / head_size_vec);
    let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length());
    v_tile[slot][idx % head_size_vec] = val;
  }
#endif
}

#if is_qualcomm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#param tile_size
#param tile_size_k_vec
#param sub_tile_count
#param turbo_quant
#param use_indirect_dispatch

// Note that this shader adopts similar algorithm with dp4a generation shader.
Expand Down Expand Up @@ -34,6 +35,12 @@ var<workgroup> tile_q: array<q_value_t, tile_size_k_vec>;
var<workgroup> inner_qk_values: array<array<q_element_t, tile_size_k_vec>, tile_size>;
var<workgroup> tile_qk: array<q_element_t, tile_size>;

#if turbo_quant
#include "bert/turbo_quant_common.wgsl.template"

var<workgroup> tq_k_norms: array<f32, tile_size>;
#endif

#if has_attention_bias
fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
{
Expand Down Expand Up @@ -73,7 +80,22 @@ $MAIN {
return;
}
let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec;
#if turbo_quant
let compressed_hsv = uniforms.head_size_vec / 4u + 1u;
let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * compressed_hsv;
#else
let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec;
#endif

#if turbo_quant
// Preload f32 norms from first vec4 of each compressed token in this tile.
if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
let first_k = present_key[present_offset + (total_seq_offset + local_idx) * compressed_hsv];
tq_k_norms[local_idx] = bitcast<f32>(vec2<f16>(first_k.x, first_k.y));
}
workgroupBarrier();
#endif

for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
tile_q[local_idx] = q[q_offset + k + local_idx];
Expand All @@ -83,7 +105,27 @@ $MAIN {
if (k + local_col < uniforms.head_size_vec) {
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
if (total_seq_offset + row_offset + local_row < total_sequence_length) {
#if turbo_quant
let k_vec_idx = k + local_col;
let k_norm = tq_k_norms[row_offset + local_row];
let element_start = k_vec_idx * 4u;
let u32_group = element_start / 8u;
let half_sel = (element_start / 4u) % 2u;
let compressed_fp16_idx = 2u + u32_group * 2u;
let cv_idx = compressed_fp16_idx / 4u;
let cv_component_pair = (compressed_fp16_idx % 4u) / 2u;
let cv = present_key[present_offset + (total_seq_offset + row_offset + local_row) * compressed_hsv + cv_idx];
var packed_u32: u32;
if (cv_component_pair == 0u) {
packed_u32 = bitcast<u32>(vec2<f16>(cv.x, cv.y));
} else {
packed_u32 = bitcast<u32>(vec2<f16>(cv.z, cv.w));
}
let dequant_k = q_value_t(tq_dequant_packed_half(packed_u32, half_sel, k_norm));
inner_qk_values[row_offset + local_row][local_col] += dot(dequant_k, q_data);
#else
inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data);
#endif
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#param head_size_vec
#param tile_size_k_vec
#param sub_tile_count
#param turbo_quant
#param use_indirect_dispatch

// Note that this shader adopts similar algorithm with dp4a generation shader.
Expand Down Expand Up @@ -39,6 +40,13 @@ var<workgroup> tile_qk: array<present_value_element_t, tile_size>;
var<workgroup> tile_output: array<present_value_value_t, head_size_vec>;
var<workgroup> qkv_values: array<array<present_value_value_t, tile_size_k_vec>, sub_tile_count>;

#if turbo_quant
#include "bert/turbo_quant_common.wgsl.template"

const compressed_head_size_vec : u32 = head_size_vec / 4u + 1u;
var<workgroup> tq_v_norms: array<f32, tile_size>;
#endif

$MAIN {
let local_row = u32(local_idx / tile_size_k_vec);
let local_col = local_idx % tile_size_k_vec;
Expand All @@ -53,7 +61,11 @@ $MAIN {
if (batch_head_idx >= uniforms.batch_heads) {
return;
}
#if turbo_quant
let present_offset = u32(batch_head_idx / uniforms.n_reps) * compressed_head_size_vec * uniforms.present_sequence_length;
#else
let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;
#endif

// Calculate the global max and sum in qk.
var g_max = f32(-3.4028234663852886e+38f);
Expand Down Expand Up @@ -82,6 +94,15 @@ $MAIN {
tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum);
}

#if turbo_quant
// Preload f32 norms from first vec4 of each compressed V token in this tile.
if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
let first_v = present_value[present_offset + (total_seq_offset + local_idx) * compressed_head_size_vec];
tq_v_norms[local_idx] = bitcast<f32>(vec2<f16>(present_value_element_t(first_v.x), present_value_element_t(first_v.y)));
}
workgroupBarrier();
#endif

for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) {
var value = present_value_value_t(0);
qkv_values[local_row][local_col] = present_value_value_t(0);
Expand All @@ -90,7 +111,27 @@ $MAIN {
if (k + local_col < head_size_vec) {
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
if (total_seq_offset + row_offset + local_row < total_sequence_length) {
#if turbo_quant
let v_vec_idx = k + local_col;
let v_norm = tq_v_norms[row_offset + local_row];
let element_start = v_vec_idx * 4u;
let u32_group = element_start / 8u;
let half_sel = (element_start / 4u) % 2u;
let compressed_fp16_idx = 2u + u32_group * 2u;
let cv_idx = compressed_fp16_idx / 4u;
let cv_component_pair = (compressed_fp16_idx % 4u) / 2u;
let cv = present_value[present_offset + (total_seq_offset + row_offset + local_row) * compressed_head_size_vec + cv_idx];
var packed_u32: u32;
if (cv_component_pair == 0u) {
packed_u32 = bitcast<u32>(vec2<f16>(present_value_element_t(cv.x), present_value_element_t(cv.y)));
} else {
packed_u32 = bitcast<u32>(vec2<f16>(present_value_element_t(cv.z), present_value_element_t(cv.w)));
}
let dequant_v = present_value_value_t(tq_dequant_packed_half(packed_u32, half_sel, v_norm));
value += dequant_v * tile_qk[row_offset + local_row];
#else
value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * head_size_vec + k + local_col] * tile_qk[row_offset + local_row];
#endif
}
}
}
Expand Down
Loading
Loading