
Enable torch.export of the full SAM3 grounding pipeline #9

Open
rbavery wants to merge 10 commits into main from export-pipeline-minimal

Conversation

@rbavery (Member) commented Apr 29, 2026

Summary

Bundle the SAM3 image encoder, text encoder, encoder fusion, and decoder
into a single ExportedProgram so consumers can ship one .pt2. The
graph accepts dynamic batch (B>=1) and prompt (P>=1) dims; image H/W and
CLIP context length are fixed at 1008 and 32.

This is a minimal redo of #3: only the patches strictly needed for
torch.export of the full pipeline + a single test, on top of latest main.
No artifact/sub-module export scripts, no benchmarks, no uv.lock.

The exported contract preserves what wherobots-rasterflow relies on
(/home/rave/wherobots-rasterflow/.../adapters/torch/geometry_actors.py
calls self.model(patch_batch, token_ids) and unpacks 4 outputs).
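For reference, the call-site contract looks like this (variable names hypothetical, mirroring the rasterflow call site):

```python
# patch_batch: fp32 [B, 3, 1008, 1008] image patches
# token_ids:   int64 [P, 32] CLIP token ids
pred_logits, pred_boxes, pred_masks, presence_logit_dec = model(patch_batch, token_ids)
```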

Verified on RTX 3090 (torch 2.10.0+cu128)

range_constraints: {s34: VR[1, int_oo], s69: VR[1, int_oo]}
$ uv run pytest tests/export/test_full_pipeline_export.py -m slow
test_full_pipeline_export_matches_eager                          PASSED  (atol=0)
test_full_pipeline_export_save_load_roundtrip                    PASSED
test_full_pipeline_export_supports_dynamic_shapes[1-1]           PASSED
test_full_pipeline_export_supports_dynamic_shapes[1-2]           PASSED
test_full_pipeline_export_supports_dynamic_shapes[3-2]           PASSED
test_full_pipeline_export_supports_dynamic_shapes[2-4]           PASSED
======== 6 passed in ~3 min ========

Eager-vs-exported max abs diff is 0 on all 4 outputs (pred_logits,
pred_boxes, pred_masks, presence_logit_dec).

Required model patches

  • sam3/model/encoder.py: fix x.dim → x.dim() (the code referenced the method object instead of invoking it).
  • sam3/model/geometry_encoders.py: skip the pin_memory + non_blocking copy under torch.compiler.is_compiling() and on CPU. Convert the ROIAlign input from a list of per-image tensors to the batched [N, 5] form, since lists of tensors aren't traceable (see the first sketch below).
  • sam3/model/position_encoding.py: skip the _encode_xy ndim/length assert when tracing, because symbolic dims trip GuardOnDataDependentSymNode; skip the shape-keyed forward cache when the shape items are SymInts.
  • sam3/model/decoder.py: (1) allow a 4D (B, H, Q, K) cross_attn_mask through TransformerDecoderLayer and add _cross_attn_with_rpb, a manual SDPA path that preserves the per-head additive bias (nn.MultiheadAttention rejects 4D masks) and casts projection weights to the activation dtype to coexist with bf16 autocast (see the second sketch below); (2) forward the box-RPB matrix in 4D, dropping the flatten(0, 1) that produced the MHA-friendly 3D shape; (3) in _make_box_rpb_relative, use torch.compiler.is_compiling() instead of is_dynamo_compiling() and drop the SymInt tuple-equality cache check and the eager-only shape asserts on deltas/B.
  • sam3/model_builder.py: thread num_feature_levels through build_sam3_image_model → _create_sam3_transformer / _create_sam3_model so callers can pin it explicitly.
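A minimal sketch of the list-to-batched ROIAlign conversion, assuming the prompt boxes arrive as a dense [B, P, 4] tensor (helper name hypothetical, not the exact patch):

```python
import torch

def to_roi_format(boxes):
    # boxes: [B, P, 4] -> [B*P, 5] with the batch index prepended: the
    # single-tensor form torchvision.ops.roi_align accepts, avoiding the
    # untraceable list-of-tensors input.
    B, P, _ = boxes.shape
    batch_idx = torch.arange(B, dtype=boxes.dtype, device=boxes.device)
    batch_idx = batch_idx.view(B, 1, 1).expand(B, P, 1)
    return torch.cat([batch_idx, boxes], dim=-1).reshape(-1, 5)
```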

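A minimal sketch of the manual SDPA path with the 4D relative-position bias. The duck-typed attribute names come from the commit notes below; shapes and structure are assumptions, not the exact patch:

```python
import torch.nn.functional as F

def _cross_attn_with_rpb(attn, q, k, v, rpb_bias):
    # Sketch only. `attn` duck-types sam3's MultiheadAttention
    # (in_proj_weight/bias, out_proj, num_heads, head_dim); q/k/v are
    # [B, L, E]; rpb_bias is the 4D (B, nheads, Lq, Lk) additive bias.
    B, Lq, E = q.shape
    w = attn.in_proj_weight.to(q.dtype)  # cast fp32 weights to the
    b = attn.in_proj_bias.to(q.dtype)    # bf16 activation dtype
    q_p = F.linear(q, w[:E], b[:E])
    k_p = F.linear(k, w[E:2 * E], b[E:2 * E])
    v_p = F.linear(v, w[2 * E:], b[2 * E:])

    def heads(x):  # [B, L, E] -> [B, nheads, L, head_dim]
        return x.view(x.shape[0], x.shape[1], attn.num_heads, attn.head_dim).transpose(1, 2)

    out = F.scaled_dot_product_attention(
        heads(q_p), heads(k_p), heads(v_p),
        attn_mask=rpb_bias.to(q.dtype),  # MHA rejects this 4D mask; SDPA accepts it
    )
    out = out.transpose(1, 2).reshape(B, Lq, E)
    return F.linear(out, attn.out_proj.weight.to(q.dtype), attn.out_proj.bias.to(q.dtype))
```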
New files

  • scripts/export_sam3_full_pipeline.py — FullSam3PipelineWrapper unifies the four sub-modules into one forward (a simplified sketch follows the test list below). It wraps the model call in torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) because upstream's ViT MLP uses sam3.perflib.fused.addmm_act (added in the SAM 3.1 release), which forces bf16 internally, and Sam3TrackingPredictor enters the same autocast in __init__. Outputs are cast back to fp32 so consumers see the input/output dtype contract production already expects. The CLI saves a .pt2 with Dim("batch", min=1) and Dim("num_prompts", min=1).

  • tests/export/test_full_pipeline_export.py — a slow pytest (the slow marker is registered in pyproject.toml). Three tests:

    1. test_full_pipeline_export_matches_eager: exact-equality (atol=0) check between the eager wrapper and the exported program on the same input.
    2. test_full_pipeline_export_save_load_roundtrip: torch.export.save → torch.export.load → run on fresh inputs.
    3. test_full_pipeline_export_supports_dynamic_shapes: parametric run across (1,1), (1,2), (3,2), (2,4) (batch, num_prompts) pairs.
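A simplified sketch of the wrapper and export call (structure per this PR's description; the real FullSam3PipelineWrapper wires the four sub-modules explicitly, so `model` here is a stand-in):

```python
import torch
from torch.export import Dim, export

class FullSam3PipelineWrapper(torch.nn.Module):
    # Simplified: the real wrapper composes image encoder, text encoder,
    # fusion, and decoder; `model` stands in for that pipeline here.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, patch_batch, token_ids):
        # Autocast sits inside forward so the traced graph keeps the
        # fp32 input/output contract while addmm_act runs in bf16.
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, boxes, masks, presence = self.model(patch_batch, token_ids)
        return logits.float(), boxes.float(), masks.float(), presence.float()

wrapper = FullSam3PipelineWrapper(model).eval().cuda()  # `model` built elsewhere
ep = export(
    wrapper,
    (torch.randn(2, 3, 1008, 1008, device="cuda"),          # fixed 1008x1008
     torch.zeros(4, 32, dtype=torch.long, device="cuda")),  # fixed 32-token context
    dynamic_shapes={"patch_batch": {0: Dim("batch", min=1)},
                    "token_ids": {0: Dim("num_prompts", min=1)}},
)
torch.export.save(ep, "sam3_full_pipeline.pt2")
```

Named Dims with min=1 (rather than Dim.AUTO inferred from the example inputs) keep batch=1 and num_prompts=1 valid at runtime; see the commit notes below.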

Notes / caveats

  • The 2.11 release-notes review from #8 (Bump torch to 2.11 and torchvision to 0.26) still holds: none of these workarounds has a clean 2.11-specific revert. Once #8 merges, we can re-test whether the manual SDPA RPB path can fall back to plain nn.MultiheadAttention; for now it is load-bearing.
  • AOTI (issue pytorch/pytorch#174608) is the next step; this PR sets up a clean base for that experiment.
  • Loading the saved .pt2 requires import torchvision.ops first to register torch.ops.torchvision.roi_align.default — same as the production deploy.
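A minimal load-side sketch (file path hypothetical):

```python
import torch
import torchvision.ops  # noqa: F401  (registers torch.ops.torchvision.roi_align.default)

ep = torch.export.load("sam3_full_pipeline.pt2")
patch_batch = torch.randn(1, 3, 1008, 1008, device="cuda")
token_ids = torch.zeros(2, 32, dtype=torch.long, device="cuda")
pred_logits, pred_boxes, pred_masks, presence_logit_dec = ep.module()(patch_batch, token_ids)
```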

rbavery added 10 commits April 29, 2026 16:22
Bundle the image encoder, text encoder, encoder fusion, and decoder into a
single ExportedProgram so consumers can ship one .pt2. The graph accepts
dynamic batch and prompt dims; image H/W and CLIP context length are fixed.

Required model patches:

- sam3/model/encoder.py: fix x.dim → x.dim() (called the method object).
- sam3/model/geometry_encoders.py:
  * Skip pin_memory + non_blocking copy under torch._dynamo.is_compiling()
    and on CPU; both fail under torch.export.
  * Convert ROIAlign input from a list-of-tensors (one per batch) to the
    [N, 5] batched format with batch_idx prepended. Lists of tensors are
    not traceable through torch.export.
- sam3/model/position_encoding.py:
  * Skip the ndim/length assert in _encode_xy when tracing — symbolic dims
    from dynamic prompts trip GuardOnDataDependentSymNode.
  * Skip the shape-keyed cache when shapes are SymInts (looking up a SymInt
    in a dict of int keys raises the same guard).
- sam3/model/decoder.py:
  * Allow 4D (bs, nheads, nq, hw) cross_attn_mask through TransformerDecoderLayer.
    nn.MultiheadAttention rejects 4D additive bias, so add a manual SDPA
    path (_cross_attn_with_rpb) that unpacks in/out projections and calls
    F.scaled_dot_product_attention with the per-head bias.
  * Forward the box-RPB matrix in 4D (drop the flatten(0, 1) that produced
    the (bs*nheads, nq, hw) form for MHA).
  * In _make_box_rpb_relative, branch is_dynamo_compiling() before the
    tuple-equality cache check (SymInt tuple equality trips guards) and
    drop the eager-only shape asserts on deltas/B that fired with
    dynamic dims even with the is_dynamo_compiling guard.
- sam3/model_builder.py: thread num_feature_levels through
  build_sam3_image_model → _create_sam3_transformer/_model so callers can
  pin it explicitly (the production export passes 1).

New:
- scripts/export_sam3_full_pipeline.py — FullSam3PipelineWrapper that
  unifies the four sub-modules into one forward, plus a CLI to save a
  .pt2 with dynamic batch/prompt dims and fixed 1008x1008 / 32-token spec.
- tests/export/test_full_pipeline_export.py — slow pytest that builds
  the real model, exports it, and runs the exported module on inputs
  with batch=2 and num_prompts=4 to exercise both dynamic dims.

… export

is_dynamo_compiling() returns False during non-strict export, so the
SymInt tuple-equality on compilable_stored_size still triggered
GuardOnDataDependentSymNode. is_compiling() is the broader API that
covers torch.compile, strict export, and non-strict export.
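A sketch of the resulting guard pattern (cache attributes and the recompute helper are hypothetical):

```python
import torch

def box_rpb(self, deltas):
    # torch.compiler.is_dynamo_compiling() stays False during non-strict
    # export, so the tuple equality below still ran on SymInts and raised
    # GuardOnDataDependentSymNode; is_compiling() also covers that mode.
    if not torch.compiler.is_compiling() and tuple(deltas.shape) == self._stored_size:
        return self._cached_rpb       # eager-only shape cache
    return self._compute_rpb(deltas)  # hypothetical recompute helper
```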
cross_attn is sam3's custom MultiheadAttention (model_misc.MultiheadAttention,
aliased as MultiheadAttentionWrapper), not torch.nn.MultiheadAttention.
The two classes share the relevant attributes (in_proj_weight/bias,
out_proj, num_heads, head_dim, dropout), so duck-type instead.

Dim.AUTO with batch=2 example specializes the batch dim to min=2, which
rejects batch=1 at runtime. A named Dim with min=1 keeps the dim dynamic
across the full [1, +inf) range.

The image encoder runs under autocast and emits bf16, but the cross-attn
weights stay fp32. nn.MultiheadAttention's forward casts internally; our
manual SDPA path didn't, producing 'mat1 and mat2 must have the same
dtype' at runtime. Cast in_proj/out_proj weights and the attn bias to
match query dtype.

Upstream vitdet MLP uses sam3.perflib.fused.addmm_act which forces bf16
internally (added in the SAM 3.1 release commit 9f22cb9). Production
runs the model inside Sam3TrackingPredictor whose __init__ enters a
persistent bf16 autocast — so eager production code works. Our export
wrapper calls the model directly, so we have to re-enter autocast at
trace time, otherwise fc2 downstream of addmm_act sees a bf16 input
against fp32 weights and raises 'mat1 and mat2 must have the same dtype'.

Tracing under outer autocast baked the autocast effects into op dtypes
but missed the fp32→bf16 input boundary, causing the deserialized graph
to demand bf16 input at runtime. Putting autocast inside the wrapper
keeps the input/output fp32 contract while letting addmm_act's bf16
path coexist with the surrounding fp32 weights.