feat(gpu): optimize BLS12-446 field arithmetic for MSM performance #3448
bbarbakadze wants to merge 1 commit into main from
Conversation
This PR should change this line to 32 by default.
@bbarbakadze Something is wrong with benchmarks: https://github.com/zama-ai/tfhe-rs/actions/runs/24038290191
Force-pushed 8cfdd1b to a5fd85a
pdroalves left a comment:
I added a batch of high-level comments. Let me know when you are done with this PR so I can do a careful review line by line.
// largest field alignment (4 bytes in 32-bit limb mode, 8 bytes in 64-bit).
// Forcing alignas(8) ensures sizeof(G1Affine)==120 in both modes, matching
// the Rust FFI bindings which are always generated from the 64-bit layout.
struct alignas(8) G1Affine {
Can you replace this magic number with a function based on LIMB_BITS_CONFIG?
So this is actually not dependent on LIMB_BITS_CONFIG; it depends on the layout Rust is using. Once 64-bit limbs are used, we need the same alignment. Still, the magic number is now replaced with sizeof.
// across all 14 limbs.
// Operand map: %0..%13 = c[0..13], %14 = carry_out,
// %15..%28 = a[0..13], %29..%42 = b[0..13].
uint32_t carry_out;
@guillermo-oyarzun do you want to double-check this PTX? It seems ok to me.
Yup, the PTX looks good!
// Operand map: %0..%13 = c[0..13], %14 = borrow_out,
// %15..%28 = a[0..13], %29..%42 = b[0..13].
uint32_t borrow_out;
asm("sub.cc.u32 %0, %15, %29;\n\t" // c[0] = a[0] - b[0], set BF
Same here, it looks good too!
#endif // LIMB_BITS_CONFIG == 64
#endif // __CUDA_ARCH__

// 32-bit dual MAD-chain Montgomery multiplication (device path)
Do you have a reference for this MAD-chain multiplication? If so, a link as a comment would help.
fp_qad_row_32(&wtemp[2 * i], &wide[2 * i + 2], &a32[i + 1], a32[i], n - i);
}

asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"
I don't like PTX in the middle of a function like this one. Maybe you could move it to a macro and add comments explaining what it is.
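One way to follow this suggestion is to hide the inline PTX behind a named, documented macro with a host fallback; the macro name below is illustrative, not part of the PR:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch: wrap the 32x32 -> 64-bit widening multiply in a macro
// so the function body stays readable. On device it emits the same PTX pair
// (mul.lo.u32 / mul.hi.u32); on host it falls back to plain 64-bit math.
#if defined(__CUDA_ARCH__)
#define MUL_WIDE_U32(lo, hi, x, y)                                             \
  asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"                          \
      : "=r"(lo), "=r"(hi)                                                     \
      : "r"(x), "r"(y))
#else
#define MUL_WIDE_U32(lo, hi, x, y)                                             \
  do {                                                                         \
    uint64_t w_ = static_cast<uint64_t>(x) * (y);                              \
    (lo) = static_cast<uint32_t>(w_);                                          \
    (hi) = static_cast<uint32_t>(w_ >> 32);                                    \
  } while (0)
#endif
```

A host-side fallback also makes the macro unit-testable without a GPU.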
p4 = DEVICE_MODULUS.limb[4], p5 = DEVICE_MODULUS.limb[5],
p6 = DEVICE_MODULUS.limb[6];
uint64_t r0, r1, r2, r3, r4, r5, r6, mask64;
asm("sub.cc.u64 %0, %8, %15;\n\t"
This diff is full of PTX. We need to read it carefully and, if possible, move it out of the function bodies.
Force-pushed a5fd85a to 38e4101
#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64
// Device path: fully unrolled PTX with hardware carry flags
fp_mont_mul_cios_ptx(c, a, b);
#ifdef __CUDA_ARCH__
I understand that we now have 2 versions for 32- and 64-bit limbs; can we add a panic in the correct place in case someone attempts to use it with 128-bit?

@guillermo-oyarzun you mean if someone tries to set a value other than 32 or 64 for LIMB_BITS_CONFIG?

In that case maybe use an enum with two values, 32BIT and 64BIT?

Yup, an enum should work; I'm just trying to be extra safe because the code shouldn't work with 128-bit, right? We would need to emulate them somehow.

For now limbs can only be 32 or 64. I will rewrite it with an enum; that should be better than a panic.

By the way, there is already a protection implemented for this inside fp.h at line 55:
static_assert(LIMB_BITS == 32 || LIMB_BITS == 64, "LIMB_BITS_CONFIG must be 32 or 64");
So I guess it is fine to leave it as it is.
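A minimal sketch of the enum idea discussed above, with illustrative names (not the PR's actual identifiers): a scoped enum can only hold the two supported widths, and a static_assert mirrors the existing guard in fp.h:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical guard: LimbWidth can only name the two supported widths, so an
// unsupported configuration such as 128-bit limbs cannot be expressed at all.
enum class LimbWidth : uint32_t { Bits32 = 32, Bits64 = 64 };

constexpr LimbWidth kLimbWidth = LimbWidth::Bits32;

// Belt-and-braces check, mirroring the static_assert already in fp.h line 55.
static_assert(kLimbWidth == LimbWidth::Bits32 || kLimbWidth == LimbWidth::Bits64,
              "limb width must be 32 or 64");

constexpr uint32_t limb_bits() { return static_cast<uint32_t>(kLimbWidth); }
```

The compile-time failure mode is arguably friendlier than a runtime panic, which matches the conclusion of the thread.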
- Replace 64-bit CIOS Montgomery multiplication with 32-bit MAD chains
(mad.lo.cc/madc.hi.cc), exploiting native 2x throughput of 32-bit ops
on NVIDIA GPUs via even/odd accumulator separation
- Add fp_mont_sqr using a triangular MAD chain (upper triangle computed
once and doubled, diagonal added separately), saving ~40% of the
multiplications versus treating squaring as a general multiplication
- Add fp_add_lazy/fp_sub_lazy (and Fp2 variants): skip the final
conditional subtraction when the result feeds fp_mont_mul, which
accepts inputs in [0, 2p). Wired into fp2_mont_mul, fp2_mont_square,
and G1/G2 projective_point_double
- Replace all fp_mont_mul(c, a, a) squaring patterns with fp_mont_sqr
across curve.cu and fp2.cu (point addition, doubling, inversion)
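The lazy add/sub idea in the commit message can be illustrated with a toy single-limb sketch (the real fields use 7x64 or 14x32 limbs, and the modulus below is illustrative, not the BLS12-446 prime):

```cpp
#include <cassert>
#include <cstdint>

// fp_add_lazy skips the final conditional subtraction, so its result lies in
// [0, 2p) instead of [0, p). That is safe whenever the value only feeds a
// Montgomery multiplication that accepts inputs in [0, 2p), as described above.
constexpr uint64_t P = 0x1FFFFFFFFFFFFFFFull; // toy modulus

uint64_t fp_add(uint64_t a, uint64_t b) {      // strict: result in [0, p)
  uint64_t c = a + b;
  return c >= P ? c - P : c;
}

uint64_t fp_add_lazy(uint64_t a, uint64_t b) { // lazy: result in [0, 2p)
  return a + b;                                // conditional subtraction skipped
}
```

Dropping the compare-and-subtract removes a data-dependent branch per addition, which is what makes wiring this into fp2_mont_mul and point doubling profitable.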
Force-pushed 38e4101 to e716051
#define LIMB_BITS_CONFIG 32
#endif

#if LIMB_BITS_CONFIG == 64
We cannot forget to deprecate this. Once we merge this PR we should completely remove the 64-bit mode. Do you agree @bbarbakadze?
pdroalves left a comment:
Just a few minor comments and style changes.
The PR is quite good in my opinion. Is there anything else you need to do here? Otherwise it's a good moment to rebase. We can merge after these changes.
// Uses the complex-squaring identity: c0 = (a0+a1)(a0-a1), c1 = 2*a0*a1
// Only 2 Fp multiplications vs 3 for fp2_mont_mul(c, a, a).
// NOTE: All inputs and outputs are in Montgomery form (no conversions)
// NOTE: All inputs should be in Montgomery form
}

// Montgomery squaring using CIOS with triangular 32-bit MAD chains.
// See fp_mont_mul_mad32 for the algorithm reference (Koç et al., 1996).
Where can I find this reference? It would be good to add the paper name and venue with full author names.
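The complex-squaring identity quoted in the diff above can be checked with a toy sketch over a small prime (not BLS12-446): for a = a0 + a1*u with u^2 = -1, a^2 = (a0+a1)(a0-a1) + (2*a0*a1)*u, which needs 2 base-field multiplications instead of the 3 of a generic fp2 product.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint64_t P = 101; // illustrative prime, small enough to avoid overflow

uint64_t add(uint64_t a, uint64_t b) { return (a + b) % P; }
uint64_t sub(uint64_t a, uint64_t b) { return (a + P - b) % P; }
uint64_t mul(uint64_t a, uint64_t b) { return (a * b) % P; }

// c = a^2 in Fp2 = Fp[u]/(u^2 + 1), using the complex-squaring identity.
void fp2_sqr(uint64_t c[2], const uint64_t a[2]) {
  c[0] = mul(add(a[0], a[1]), sub(a[0], a[1])); // (a0+a1)(a0-a1) = a0^2 - a1^2
  c[1] = mul(2, mul(a[0], a[1]));               // 2 * a0 * a1
}
```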
for (int j = 0; j < FP_LIMBS; j++) {
uint64_t acc =
(uint64_t)t[i + j] + (uint64_t)u * (uint64_t)p.limb[j] + carry;
t[i + j] = (UNSIGNED_LIMB)acc;
We've been slowly trying to avoid this type of cast in new C++ code. In the ZK backend that's a convention that we should be following.
Instead of
a = (type_t) b
you should be using
a = static_cast<type_t>(b)
I asked code to point out the other lines changed in this PR that need to be fixed.
● 14 C-style casts in new PR lines, across two functions:
fp_mont_reduce (32-bit path), lines 596-611:
┌──────┬──────────────────────────────────────────────────────┐
│ Line │ Cast │
├──────┼──────────────────────────────────────────────────────┤
│ 600 │ (uint64_t)t[i + j], (uint64_t)u, (uint64_t)p.limb[j] │
├──────┼──────────────────────────────────────────────────────┤
│ 601 │ (UNSIGNED_LIMB)acc │
├──────┼──────────────────────────────────────────────────────┤
│ 607 │ (uint64_t)t[idx] │
├──────┼──────────────────────────────────────────────────────┤
│ 608 │ (UNSIGNED_LIMB)acc │
└──────┴──────────────────────────────────────────────────────┘
fp_mont_mul_cios (32-bit Step 1), lines 1263-1273:
┌──────┬──────────────────────────────────────────────────────────┐
│ Line │ Cast │
├──────┼──────────────────────────────────────────────────────────┤
│ 1267 │ (uint64_t)t[j], (uint64_t)a.limb[j], (uint64_t)b.limb[i] │
├──────┼──────────────────────────────────────────────────────────┤
│ 1268 │ (UNSIGNED_LIMB)acc │
├──────┼──────────────────────────────────────────────────────────┤
│ 1271 │ (uint64_t)t[FP_LIMBS] │
├──────┼──────────────────────────────────────────────────────────┤
│ 1272 │ (UNSIGNED_LIMB)(sum64 >> LIMB_BITS) │
├──────┼──────────────────────────────────────────────────────────┤
│ 1273 │ (UNSIGNED_LIMB)sum64 │
└──────┴──────────────────────────────────────────────────────────┘
fp_mont_mul_cios (32-bit Step 2), lines 1298-1310:
┌──────┬──────────────────────────────────────────────────┐
│ Line │ Cast │
├──────┼──────────────────────────────────────────────────┤
│ 1302 │ (uint64_t)t[j], (uint64_t)m, (uint64_t)p.limb[j] │
├──────┼──────────────────────────────────────────────────┤
│ 1303 │ (UNSIGNED_LIMB)acc │
├──────┼──────────────────────────────────────────────────┤
│ 1308 │ (uint64_t)t[FP_LIMBS], (uint64_t)overflow │
├──────┼──────────────────────────────────────────────────┤
│ 1309 │ (UNSIGNED_LIMB)s64 │
├──────┼──────────────────────────────────────────────────┤
│ 1310 │ (UNSIGNED_LIMB)(s64 >> LIMB_BITS) │
└──────┴──────────────────────────────────────────────────┘
All 14 are widening (uint32_t → uint64_t) or truncating (uint64_t → uint32_t) integer casts that should use static_cast<>.
int idx = i + FP_LIMBS;
while (carry != 0 && idx <= 2 * FP_LIMBS) {
uint64_t acc = (uint64_t)t[idx] + carry;
t[idx] = (UNSIGNED_LIMB)acc;
// Add reduced lower half into upper half wide[n..2n-1]; the result lives
// in wide[n..2n-1] and is in [0, 2p).
fp_cadd_n_32(&wide[n], &wide[0], n);
FP_CARRY_32(wide[0]); // discard overflow (always 0 for p<2^446)
Do we need this line? The comment says the overflow is always 0, is that right?
Ahh, I think I get it, you are just consuming the carry, right? Maybe the comment should be replaced by "consume the carry flag so CC is clean".
PR content/description
Optimize BLS12-446 field arithmetic for MSM performance
Replace 64-bit CIOS Montgomery multiplication with 32-bit MAD chains
(mad.lo.cc/madc.hi.cc), exploiting native 2x throughput of 32-bit ops
on NVIDIA GPUs via even/odd accumulator separation
Add fp_mont_sqr using a triangular MAD chain (upper triangle computed
once and doubled, diagonal added separately), saving ~40% of the
multiplications versus treating squaring as a general multiplication
Add fp_add_lazy/fp_sub_lazy (and Fp2 variants): skip the final
conditional subtraction when the result feeds fp_mont_mul, which
accepts inputs in [0, 2p). Wired into fp2_mont_mul, fp2_mont_square,
and G1/G2 projective_point_double
Replace all fp_mont_mul(c, a, a) squaring patterns with fp_mont_sqr
across curve.cu and fp2.cu (point addition, doubling, inversion)
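The triangular squaring described above can be sketched on 4 tiny "limbs": the upper-triangle cross products a[i]*a[j] (i < j) are computed once and then doubled, and the diagonal squares are added afterwards, for n*(n+1)/2 multiplications instead of n^2. The 8-bit limbs below keep everything inside a uint64_t; the real fp_mont_sqr works on 32-bit limbs with PTX MAD chains and hardware carries.

```cpp
#include <cassert>
#include <cstdint>

constexpr int N = 4; // toy limb count; the real field has 14 32-bit limbs

uint64_t square(const uint8_t a[N]) {
  uint64_t acc = 0;
  for (int i = 0; i < N; i++)
    for (int j = i + 1; j < N; j++) // upper triangle, each product once
      acc += static_cast<uint64_t>(a[i]) * a[j] << (8 * (i + j));
  acc *= 2;                         // double the cross products
  for (int i = 0; i < N; i++)       // then add the diagonal squares
    acc += static_cast<uint64_t>(a[i]) * a[i] << (8 * (2 * i));
  return acc;
}
```

This computes 6 cross products plus 4 squares (10 multiplications) where a generic 4x4 product would need 16, which is the ~40% saving the description refers to.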
Check-list: