Enable torch.export of the full SAM3 grounding pipeline #9
Bundle the image encoder, text encoder, encoder fusion, and decoder into a
single ExportedProgram so consumers can ship one .pt2. The graph accepts
dynamic batch and prompt dims; image H/W and CLIP context length are fixed.
Required model patches:
- sam3/model/encoder.py: fix x.dim → x.dim() (the code referenced the method object instead of calling it).
- sam3/model/geometry_encoders.py:
* Skip pin_memory + non_blocking copy under torch._dynamo.is_compiling()
and on CPU; both fail under torch.export.
* Convert ROIAlign input from a list-of-tensors (one per batch element) to the
  [N, 5] batched format with batch_idx prepended; lists of tensors are not
  traceable through torch.export (see the sketch after this list).
- sam3/model/position_encoding.py:
* Skip the ndim/length assert in _encode_xy when tracing — symbolic dims
from dynamic prompts trip GuardOnDataDependentSymNode.
* Skip the shape-keyed cache when shapes are SymInts (looking up a SymInt
in a dict of int keys raises the same guard).
- sam3/model/decoder.py:
* Allow 4D (bs, nheads, nq, hw) cross_attn_mask through TransformerDecoderLayer.
nn.MultiheadAttention rejects 4D additive bias, so add a manual SDPA
path (_cross_attn_with_rpb) that unpacks in/out projections and calls
F.scaled_dot_product_attention with the per-head bias.
* Forward the box-RPB matrix in 4D (drop the flatten(0, 1) that produced
the (bs*nheads, nq, hw) form for MHA).
* In _make_box_rpb_relative, branch is_dynamo_compiling() before the
tuple-equality cache check (SymInt tuple equality trips guards) and
drop the eager-only shape asserts on deltas/B that fired with
dynamic dims even with the is_dynamo_compiling guard.
- sam3/model_builder.py: thread num_feature_levels through
build_sam3_image_model → _create_sam3_transformer/_model so callers can
pin it explicitly (the production export passes 1).
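For reference, a minimal sketch of the ROIAlign input conversion mentioned in the geometry_encoders.py item above. The function name and the assumption that prompt boxes arrive as a dense (B, P, 4) tensor are illustrative, not the actual geometry_encoders.py code:

```python
import torch
from torchvision.ops import roi_align

def rois_from_prompts(boxes: torch.Tensor) -> torch.Tensor:
    """boxes: (B, P, 4) xyxy prompt boxes -> (B*P, 5) rows of [batch_idx, x1, y1, x2, y2],
    the single-tensor form roi_align accepts cleanly under torch.export."""
    B, P, _ = boxes.shape
    batch_idx = torch.arange(B, device=boxes.device, dtype=boxes.dtype).view(B, 1, 1).expand(B, P, 1)
    return torch.cat([batch_idx, boxes], dim=-1).reshape(B * P, 5)

# Toy usage: pooled has shape (B*P, C, 7, 7).
feats = torch.randn(2, 16, 64, 64)
prompt_boxes = torch.tensor([[[4.0, 4.0, 20.0, 20.0]], [[8.0, 8.0, 30.0, 30.0]]])  # (2, 1, 4)
pooled = roi_align(feats, rois_from_prompts(prompt_boxes), output_size=(7, 7), spatial_scale=1.0)
```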
New:
- scripts/export_sam3_full_pipeline.py — FullSam3PipelineWrapper that
unifies the four sub-modules into one forward, plus a CLI to save a
.pt2 with dynamic batch/prompt dims and fixed 1008x1008 / 32-token spec.
- tests/export/test_full_pipeline_export.py — slow pytest that builds
the real model, exports it, and runs the exported module on inputs
with batch=2 and num_prompts=4 to exercise both dynamic dims.
is_dynamo_compiling() returns False during non-strict export, so the SymInt tuple-equality check on compilable_stored_size still triggered GuardOnDataDependentSymNode. torch.compiler.is_compiling() is the broader API that covers torch.compile, strict export, and non-strict export.
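For illustration, a minimal sketch of the guard pattern, using a hypothetical RelPosBias module with a shape-keyed cache (not the literal sam3 code):

```python
import torch

class RelPosBias(torch.nn.Module):
    """Hypothetical module sketching the cache-skip pattern used in the patches."""
    def __init__(self):
        super().__init__()
        self._cache = {}  # keyed on concrete (nq, hw) shapes in eager mode

    def forward(self, nq: int, hw: int) -> torch.Tensor:
        # Under tracing, nq/hw may be SymInts: dict lookups and tuple equality on
        # them raise GuardOnDataDependentSymNode, so bypass the cache entirely.
        if not torch.compiler.is_compiling():
            cached = self._cache.get((nq, hw))
            if cached is not None:
                return cached
        bias = self._build_bias(nq, hw)
        if not torch.compiler.is_compiling():
            self._cache[(nq, hw)] = bias
        return bias

    def _build_bias(self, nq, hw):
        return torch.zeros(nq, hw)
```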
cross_attn is sam3's custom MultiheadAttention (model_misc.MultiheadAttention, aliased as MultiheadAttentionWrapper), not torch.nn.MultiheadAttention. The two classes share the relevant attributes (in_proj_weight/bias, out_proj, num_heads, head_dim, dropout), so duck-type instead.
Dim.AUTO with a batch=2 example input specializes the batch dim to min=2, which rejects batch=1 at runtime. A named Dim with min=1 keeps the dim dynamic across the full [1, +inf) range.
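A self-contained sketch of that dynamic-shape spec, using a toy stand-in module (ToyPipeline and its output are hypothetical; only the Dim declarations and the fixed 1008/32 spec mirror what the export script does):

```python
import torch
from torch.export import Dim, export

class ToyPipeline(torch.nn.Module):
    """Stand-in for FullSam3PipelineWrapper, just to demonstrate the dynamic-shape spec."""
    def forward(self, images: torch.Tensor, token_ids: torch.Tensor):
        return images.mean(dim=(1, 2, 3)), token_ids.float().sum(dim=-1)

# Named Dims with min=1 keep the full [1, +inf) range dynamic; Dim.AUTO with a
# batch=2 example input would specialize the dim to min=2 and reject batch=1 later.
batch = Dim("batch", min=1)
num_prompts = Dim("num_prompts", min=1)

example_images = torch.randn(2, 3, 1008, 1008)          # H/W pinned at 1008
example_tokens = torch.zeros(4, 32, dtype=torch.long)   # context length pinned at 32

ep = export(
    ToyPipeline(),
    (example_images, example_tokens),
    dynamic_shapes=({0: batch}, {0: num_prompts}),
)
torch.export.save(ep, "toy_pipeline.pt2")
```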
The image encoder runs under autocast and emits bf16, but the cross-attn weights stay fp32. nn.MultiheadAttention's forward casts internally; our manual SDPA path didn't, producing 'mat1 and mat2 must have the same dtype' at runtime. Cast in_proj/out_proj weights and the attn bias to match query dtype.
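A condensed sketch of a manual SDPA cross-attention with a per-head additive bias plus the dtype cast described above. It is illustrative rather than the literal _cross_attn_with_rpb; it assumes the attention module exposes in_proj_weight/in_proj_bias/out_proj/num_heads/head_dim (true of both sam3's wrapper and nn.MultiheadAttention), and it omits dropout:

```python
import torch
import torch.nn.functional as F

def cross_attn_with_bias(attn, query, key, value, bias_4d):
    """query/key/value: (B, L, E); bias_4d: (B, num_heads, Lq, Lk) additive bias.
    Duck-types on in_proj_weight/in_proj_bias/out_proj/num_heads/head_dim."""
    B, Lq, E = query.shape
    h, d = attn.num_heads, attn.head_dim

    # Cast fp32 projection weights to the activation dtype (bf16 under autocast);
    # otherwise addmm raises "mat1 and mat2 must have the same dtype".
    w = attn.in_proj_weight.to(query.dtype)
    b = attn.in_proj_bias.to(query.dtype) if attn.in_proj_bias is not None else None
    w_q, w_k, w_v = w.chunk(3, dim=0)
    b_q, b_k, b_v = b.chunk(3, dim=0) if b is not None else (None, None, None)

    q = F.linear(query, w_q, b_q).view(B, Lq, h, d).transpose(1, 2)   # (B, h, Lq, d)
    k = F.linear(key,   w_k, b_k).view(B, -1, h, d).transpose(1, 2)   # (B, h, Lk, d)
    v = F.linear(value, w_v, b_v).view(B, -1, h, d).transpose(1, 2)

    # SDPA broadcasts the 4D additive mask per head, which MHA's 3D path can't.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias_4d.to(query.dtype))
    out = out.transpose(1, 2).reshape(B, Lq, E)
    out_b = attn.out_proj.bias.to(query.dtype) if attn.out_proj.bias is not None else None
    return F.linear(out, attn.out_proj.weight.to(query.dtype), out_b)
```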
Upstream vitdet MLP uses sam3.perflib.fused.addmm_act which forces bf16 internally (added in the SAM 3.1 release commit 9f22cb9). Production runs the model inside Sam3TrackingPredictor whose __init__ enters a persistent bf16 autocast — so eager production code works. Our export wrapper calls the model directly, so we have to re-enter autocast at trace time, otherwise fc2 downstream of addmm_act sees a bf16 input against fp32 weights and raises 'mat1 and mat2 must have the same dtype'.
Tracing under outer autocast baked the autocast effects into op dtypes but missed the fp32→bf16 input boundary, causing the deserialized graph to demand bf16 input at runtime. Putting autocast inside the wrapper keeps the input/output fp32 contract while letting addmm_act's bf16 path coexist with the surrounding fp32 weights.
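A schematic of that arrangement (a sketch, not the actual FullSam3PipelineWrapper; it assumes the four sub-modules are already composed behind a single self.model call):

```python
import torch

class AutocastWrapper(torch.nn.Module):
    """Keeps the fp32 input/output contract while running the model's bf16-only
    kernels (e.g. the fused addmm_act path) under autocast during tracing."""
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, images: torch.Tensor, token_ids: torch.Tensor):
        # Autocast is entered *inside* forward so torch.export traces the
        # fp32 -> bf16 boundary; an outer autocast at trace time would bake
        # bf16 into the graph's input contract instead.
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            pred_logits, pred_boxes, pred_masks, presence = self.model(images, token_ids)
        # Cast back so consumers keep seeing fp32 outputs.
        return (pred_logits.float(), pred_boxes.float(),
                pred_masks.float(), presence.float())
```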
Summary

Bundle the SAM3 image encoder, text encoder, encoder fusion, and decoder into a single ExportedProgram so consumers can ship one .pt2. The graph accepts dynamic batch (B >= 1) and prompt (P >= 1) dims; image H/W and CLIP context length are fixed at 1008 and 32.
This is a minimal redo of #3: only the patches strictly needed for torch.export of the full pipeline plus a single test, on top of latest main. No artifact/sub-module export scripts, no benchmarks, no uv.lock.
The exported contract preserves what wherobots-rasterflow relies on (/home/rave/wherobots-rasterflow/.../adapters/torch/geometry_actors.py calls self.model(patch_batch, token_ids) and unpacks 4 outputs). Verified on RTX 3090 (torch 2.10.0+cu128): eager-vs-exported max abs diff is 0 on all 4 outputs (pred_logits, pred_boxes, pred_masks, presence_logit_dec).

Required model patches
- sam3/model/encoder.py: x.dim → x.dim() (called the method object instead of invoking it).
- sam3/model/geometry_encoders.py: skip the pin_memory + non_blocking copy under torch.compiler.is_compiling() and on CPU. Convert ROIAlign input from a list-of-tensors to the [N, 5] batched form (lists of tensors aren't traceable).
- sam3/model/position_encoding.py: skip the _encode_xy ndim/length assert when tracing — symbolic dims trip GuardOnDataDependentSymNode. Skip the shape-keyed forward cache when shape items are SymInts.
- sam3/model/decoder.py: (1) allow a (B, H, Q, K) cross_attn_mask through TransformerDecoderLayer and add _cross_attn_with_rpb, a manual SDPA path that preserves a per-head additive bias (nn.MultiheadAttention rejects 4D masks) and casts proj weights to the activation dtype to coexist with bf16 autocast. (2) Forward the box-RPB matrix in 4D (drop the flatten(0, 1) that produced the MHA-friendly 3D shape). (3) _make_box_rpb_relative: use torch.compiler.is_compiling() instead of is_dynamo_compiling() and drop the SymInt tuple-equality cache check plus the eager-only shape asserts on deltas/B.
- sam3/model_builder.py: thread num_feature_levels through build_sam3_image_model → _create_sam3_transformer/_create_sam3_model so callers can pin it explicitly.

New files
- scripts/export_sam3_full_pipeline.py — FullSam3PipelineWrapper that unifies the four sub-modules into one forward. Wraps the model call in torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) because upstream's ViT MLP uses sam3.perflib.fused.addmm_act (added in the SAM 3.1 release), which forces bf16 internally, and Sam3TrackingPredictor enters the same autocast in __init__. Outputs are cast back to fp32 so consumers see the same input/output dtype contract production already expects. The CLI saves a .pt2 with Dim("batch", min=1) and Dim("num_prompts", min=1).
- tests/export/test_full_pipeline_export.py — slow pytest marked with the slow marker (registered in pyproject.toml). Three tests:
  * test_full_pipeline_export_matches_eager: exact-equality (atol=0) check between the eager wrapper and the exported program on the same input.
  * test_full_pipeline_export_save_load_roundtrip: torch.export.save → torch.export.load → run on fresh inputs.
  * test_full_pipeline_export_supports_dynamic_shapes: parametric run across (1, 1), (1, 2), (3, 2), (2, 4) (batch, num_prompts) pairs.

Notes / caveats
- The manual SDPA path duck-types on the attribute set shared by sam3's MultiheadAttention and torch.nn.MultiheadAttention; for now it's load-bearing.
- Loading the .pt2 requires import torchvision.ops first to register torch.ops.torchvision.roi_align.default — same as the production deploy.