-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_push.py
More file actions
602 lines (519 loc) · 27.8 KB
/
run_push.py
File metadata and controls
602 lines (519 loc) · 27.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
"""
run_push.py — SAM3 + image-space visual servoing cube push (no cube depth).
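Single-camera control pipeline (a second, side-view camera is recorded for
visual verification only). The attention window — a (2*DETECT_HALF)-px
square, 12×12 px with the current constants — is cropped from the thirdview
frame at the cube-table interface; Phase-B2 ray-following keeps the gripper's
image locked to the cube's front-face pixel so the gripper stays above the
window.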
Pipeline per episode:
1. Reset env, capture first thirdview frame
2. SAM3 on thirdview frame → cube mask + centroid + bottom-edge anchor.
Phases B1/B2 servo to the mask BOTTOM-edge midpoint (front-face-top
of the cube) so the gripper tip lands on the front face of the cube;
the centroid is shown only for diagnostics.
3. Phase A : move EEF to APPROACH_HEIGHT above the table at current tip XY
4. Phase B1: image-space XY alignment — at each step solve for the world XY
(at the gripper tip's CURRENT Z) that projects to the cube's
front-face-top pixel; step toward it until the tip pixel is
within STOP_PIXEL_THRESHOLD
5. Phase B2: descend along the thirdview camera ray (keeps front-face-top
pixel locked) until the near-white (table) pixel count in the
attention-window crop exceeds WHITE_TRIGGER_FRACTION of the window
AND has risen by more than WHITE_MIN_RISE since the start of B2
— i.e. the cube has translated up and out of the window.
On the trigger, re-run SAM3 on the current thirdview frame
and report the cube's 3D position from sim ground truth.
6. Phase C : grasp demo (lift → open → query env.get_cube_pos() → move
over cube → descend → close → final lift → hold). The push
phases (A/B1/B2) never use cube 3D coords; the grasp uses
the simulator's actual cube pose.
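The push phases (A/B1/B2) never use the cube's 3D position or height for
control — only the EEF's own (kinematically known) depth and the camera
intrinsics + extrinsics. Phase C's grasp demo is the one deliberate exception
and uses the simulator's cube pose. The 3D coords printed on contact are
simulator ground truth, for verification only.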
Outputs (per invocation, written to results/run_<N>/):
thirdview.mp4 — third-view recording with SAM3/EEF overlays
attention.mp4 — 12×12 thirdview crop of the attention window,
upscaled 12× (144×144)
sideview.mp4 — perpendicular side-view camera (visual sanity
check that the gripper tip contacts the cube
front face at the correct height)
debug_first_frame.png — raw thirdview at start of episode
debug_sam3.png — SAM3 mask + attention window on thirdview
Usage:
uv run python run_push.py [--checkpoint /path/to/sam3.pt]
"""
import argparse
from pathlib import Path
import numpy as np
import imageio
import cv2
from cube_push_env import FrankaCubePushEnv, TABLE_SURFACE_Z, CUBE_HALF
from segmentation import load_sam3_model, segment_cube_sam3
from servoing import pixel_to_world_at_z
# Output root: each invocation writes into a fresh RESULTS_ROOT/run_<N>/
# directory (created by _next_run_dir).
RESULTS_ROOT = Path(__file__).resolve().parent / "results"
# ---------------------------------------------------------------------------
# Recording / window constants
# ---------------------------------------------------------------------------
FPS = 20  # frame rate of every saved video
ATTENTION_HALF = 6  # half-size (px) of the green attention rectangle drawn on
                    # thirdview: corners at anchor ± 6 px (visualization only)
DETECT_HALF = 6  # half-size (px) of the contact-detection crop; the crop is
                 # frame[y-6:y+6, x-6:x+6], i.e. a 2*DETECT_HALF = 12 px square
# SAM3's mask on thirdview underestimates the cube (only catches the top
# face), so its bottom_y sits on the cube body rather than the actual
# cube-table line. Push the anchor down by this many pixels so the
# detection window straddles the real cube-bottom edge in image space — the
# upper half of the window contains cube body, the lower half contains
# the table strip just below the cube. When the cube is pushed in +Y
# (away from camera), the cube image translates UP and exits the window
# from the top, leaving pure table content behind.
WINDOW_DROP_PX = 31  # tuned for CUBE_HALF=0.035 (7 cm cube)
# ---------------------------------------------------------------------------
# Servoing parameters — purely image-space alignment + camera-ray descent.
# Cube depth is NEVER assumed.
# ---------------------------------------------------------------------------
APPROACH_HEIGHT = 0.20  # m above table surface (Phase A target Z)
STOP_PIXEL_THRESHOLD = 20  # px: Phase B1 done when the Euclidean distance
                           # |EEF px - front-face px| < this
SERVO_DZ = 0.008  # camera-frame depth step per Phase-B2 iteration (m).
                  # World-Z descent is ~SERVO_DZ * cos(camera_tilt),
                  # and (per the author's tuning notes) the OSC controller
                  # tracks at ~15% of command, so 0.008 → ~0.5 mm actual
                  # world-Z descent per step.
PHASE1_MAX_STEPS = 120  # iteration cap for Phase B1 (XY alignment)
PHASE2_MAX_STEPS = 600  # iteration cap for Phase B2 (camera-ray descent)
PHASE2_WARMUP_STEPS = 20  # Phase B2 steps before the white-count trigger is
                          # evaluated; skips only early alignment transients
                          # and does not block stopping for most of the descent
ATTN_VIDEO_UPSCALE = 12  # upscale the 12×12 detection crop → 144×144 in the saved video
# Detection: count near-white (table) pixels in the attention-window crop.
# The cube fills the upper portion of the window initially (cube body =
# coloured, not white), so initial white count is low. When the gripper
# pushes the cube +Y (away from camera, up in image), the cube exits the
# window from above and only table remains — white count rises sharply.
# Trigger fires once the window is dominantly white AND the count has risen
# enough from its Phase-B2 baseline (both conditions below must hold).
WHITE_RGB_THRESH = 200  # per-channel: R, G, AND B must each exceed
                        # this to be "white" (rejects cube body and
                        # gripper, both of which fail the B channel)
WHITE_TRIGGER_FRACTION = 0.60  # trigger when white fraction of window
                               # pixels exceeds this (cube is gone, table
                               # fills the window)
WHITE_MIN_RISE = 20  # absolute minimum rise in white count from
                     # initial — guards against false fires when
                     # the initial frame already happened to have
                     # mostly table (cube barely in window)
# --- Phase C: hardcoded grasp after servo stops ---
# After Phase B2 contact, lift up + open gripper, THEN navigate to the
# cube's actual ground-truth pose (env.get_cube_pos()) and grasp. The
# servo phases (A/B1/B2) never used any cube 3D coords; the grasp is a
# separate demo step that uses the simulator's true cube pose.
GRASP_LIFT_DZ = 0.10  # m: lift relative to post-contact tip
GRASP_HOVER_DZ = 0.10  # m: hover above cube centre before descending
GRIPPER_HOLD_STEPS = 15  # steps to settle open/close
GRASP_FINAL_DZ = 0.15  # m: final demonstration lift above the grasp point
POST_GRASP_HOLD = 20  # ~1 s @ 20 fps
# ---------------------------------------------------------------------------
# Robot motion helpers
# ---------------------------------------------------------------------------
def step_toward(env, obs, target_pos, gripper=-1.0, max_steps=80, tol=0.005):
    """Drive the gripper tip toward ``target_pos`` with a proportional command.

    On every iteration the current tip position is read; if the remaining
    error norm is already below ``tol`` the generator finishes. Otherwise one
    action is sent: the XYZ slots carry ``error / 0.05`` clipped to [-1, 1],
    slot 6 carries ``gripper`` and the rotation slots stay zero. At most
    ``max_steps`` actions are issued.

    The ``obs`` argument is accepted for call-site symmetry but not read.

    Yields:
        The observation returned by ``env.step`` after each action. Nothing is
        yielded when the tip already starts within ``tol`` of the target.
    """
    for _attempt in range(max_steps):
        remaining = target_pos - env.get_gripper_tip_pos()
        if np.linalg.norm(remaining) < tol:
            return
        command = np.zeros(7)
        command[:3] = np.clip(remaining / 0.05, -1.0, 1.0)
        command[6] = gripper
        obs, _reward, _done, _info = env.step(command)
        yield obs
def hold(env, obs, n_steps, gripper=-1.0):
    """Keep the arm still for ``n_steps`` simulator steps.

    Each step sends an all-zero motion command with only the gripper slot
    (index 6) set to ``gripper``. The ``obs`` argument is not read.

    Yields:
        The observation returned by ``env.step`` after every step.
    """
    for _tick in range(n_steps):
        idle_command = np.append(np.zeros(6), gripper)
        obs, _reward, _done, _info = env.step(idle_command)
        yield obs
# ---------------------------------------------------------------------------
# Debug overlay
# ---------------------------------------------------------------------------
def draw_debug_overlay(frame, mask, contact_px, attn_anchor_px, eef_px=None):
    """Return an annotated copy of ``frame``; the input is left untouched.

    Marks are painted in this order, so later ones cover earlier ones:
      1. red filled dot (radius 5) at ``contact_px`` — the front-face-top
         pixel (sam3_bottom_px) where the gripper tip should touch the cube;
      2. green 1-px rectangle spanning ``attn_anchor_px`` ± ATTENTION_HALF —
         the attention window;
      3. yellow filled dot (radius 5) at ``eef_px`` — the projected tip.

    Any of the three points may be None, in which case that mark is skipped.
    ``mask`` is part of the signature but is not drawn here.
    """
    canvas = frame.copy()

    if contact_px is not None:
        cv2.circle(canvas, contact_px, 5, (255, 50, 50), -1)

    if attn_anchor_px is not None:
        centre_x, centre_y = attn_anchor_px
        top_left = (centre_x - ATTENTION_HALF, centre_y - ATTENTION_HALF)
        bottom_right = (centre_x + ATTENTION_HALF, centre_y + ATTENTION_HALF)
        cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 1)

    if eef_px is not None:
        cv2.circle(canvas, eef_px, 5, (255, 255, 0), -1)

    return canvas
def save_sam3_debug(frame, mask, centroid_px, attn_anchor_px, path="debug_sam3.png"):
    """Save a static debug image showing the SAM3 detection (mask fill + keypoints).

    The mask region is alpha-blended 50/50 with (80, 80, 255) in RGB order.
    A red dot is drawn at ``centroid_px``, and a green rectangle spans
    ``attn_anchor_px`` ± ATTENTION_HALF. Each overlay is skipped when its
    input is None.

    Args:
        frame: RGB image (uint8 assumed — blended pixels are cast back to
            uint8); converted to BGR only for cv2.imwrite.
        mask: per-pixel cube mask with the same H×W as ``frame``. Any
            non-zero / True entry counts as cube.
        centroid_px: (x, y) pixel of the mask centroid, or None.
        attn_anchor_px: (x, y) centre of the attention window, or None.
        path: output file path (str or pathlib.Path).

    A failed write is reported with a warning rather than raised, so a debug
    artifact can never abort an episode.
    """
    vis = frame.copy()
    if mask is not None:
        # Coerce to bool: a 0/1 or 0/255 integer mask would otherwise be
        # interpreted by numpy as integer *row indices* (fancy indexing),
        # and the coverage percentage would be wrong for 0/255 masks.
        mask = np.asarray(mask, dtype=bool)
        vis[mask] = (
            0.5 * vis[mask].astype(float) + 0.5 * np.array([80, 80, 255])
        ).astype(np.uint8)
        print(f" SAM3 mask coverage: {mask.mean()*100:.1f}% of image")
    if centroid_px is not None:
        cv2.circle(vis, centroid_px, 6, (255, 50, 50), -1)
    if attn_anchor_px is not None:
        ax, ay = attn_anchor_px
        cv2.rectangle(vis,
                      (ax - ATTENTION_HALF, ay - ATTENTION_HALF),
                      (ax + ATTENTION_HALF, ay + ATTENTION_HALF),
                      (0, 255, 0), 2)
    path = str(path)  # cv2.imwrite expects a str filename
    # save as BGR for OpenCV; imwrite signals failure via its return value
    if cv2.imwrite(path, cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)):
        print(f" Saved SAM3 debug image: {path}")
    else:
        print(f" WARNING: failed to write SAM3 debug image: {path}")
# ---------------------------------------------------------------------------
# Main episode
# ---------------------------------------------------------------------------
def run_episode(env, processor, out_dir):
    """
    Execute the full SAM3-guided push + grasp demo and return the recordings.

    Phases A/B1/B2 (lift, image-space XY alignment, camera-ray descent) use no
    3D cube coordinates: all motion comes from the SAM3 detection, the gripper
    tip's own kinematic position and the thirdview camera model. Phase C (the
    grasp demo) deliberately navigates to the simulator's ground-truth cube
    pose via env.get_cube_pos().

    Args:
        env: FrankaCubePushEnv instance.
        processor: Sam3Processor returned by segmentation.load_sam3_model().
        out_dir: pathlib.Path where debug PNGs are written.

    Returns:
        (thirdview_frames, attention_frames, sideview_frames): lists with one
        image per recorded step. Thirdview frames carry the debug overlay.
        Attention frames are the raw (2*DETECT_HALF)-px square detection crop,
        upscaled only when saved. Sideview frames come from
        env.get_sideview_frame.
    """
    obs = env.reset()
    # Let the simulation settle and the robot reach its home pose before we
    # capture the SAM3 frame — the arm is still moving immediately after reset.
    for obs in hold(env, obs, n_steps=20, gripper=1.0):
        pass
    first_frame = env.get_frame(obs)
    # Save raw frame so you can inspect what SAM3 sees
    cv2.imwrite(str(out_dir / "debug_first_frame.png"),
                cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
    print(f"Saved debug_first_frame.png in {out_dir}")
    # ------------------------------------------------------------------
    # SAM3 detection on the thirdview frame. Returns the cube mask, its
    # centroid (diagnostics only) and the mask's bottom-edge midpoint —
    # the front-face-top pixel that Phases B1/B2 servo to. The bottom
    # point is shifted down by WINDOW_DROP_PX to anchor the detection
    # window at the real cube-table line.
    # ------------------------------------------------------------------
    print("\n--- SAM3 detection (thirdview camera) ---")
    # The cube is small (~25 px) in the full thirdview frame — too small for
    # SAM3 to reliably detect. Crop a generous image-space window over the
    # table region and upscale before running SAM3. Center is a fixed
    # image-space prior (lower-center of the frame, where the table sits in
    # this fixed-camera setup); no cube 3D pose is used. The same crop
    # parameters are reused for the post-contact re-segmentation so that the
    # two centroids are directly comparable.
    sam3_crop_half = 120
    sam3_upscale = 4
    H_img, W_img = first_frame.shape[:2]
    crop_center_px = (W_img // 2, int(H_img * 0.6))
    print(f" SAM3 crop_center={crop_center_px} half={sam3_crop_half} "
          f"upscale={sam3_upscale}")
    sam3_mask, centroid_px, sam3_bottom_px = segment_cube_sam3(
        first_frame, processor,
        crop_center=crop_center_px, crop_half=sam3_crop_half, upscale=sam3_upscale,
    )
    # Anchor the window at the cube-table line: SAM3 mask bottom (top-face
    # bottom edge in image) + WINDOW_DROP_PX to reach the actual cube body
    # bottom + a strip of table. Cube is in the upper half of the window
    # initially; when the gripper pushes the cube +Y (away from camera, up
    # in image), the cube exits the window from above and only table
    # remains.
    attn_anchor_px = (sam3_bottom_px[0], sam3_bottom_px[1] + WINDOW_DROP_PX)
    print(f" centroid_px={centroid_px} sam3_bottom_px={sam3_bottom_px} "
          f"attn_window_px={attn_anchor_px}")
    eef_ref = obs["robot0_eef_pos"]
    tip = env.get_gripper_tip_pos()
    print(f" [diag] eef_ref={eef_ref.round(4)} tip={tip.round(4)} "
          f"delta={(tip - eef_ref).round(4)}")
    # === DIAGNOSTIC: compare SAM3 detection to ground-truth pixel projection ===
    actual_cube = env.get_cube_pos()
    actual_third_px = env.world_to_pixel(actual_cube)
    print(f" [diag] actual cube 3D : {actual_cube.round(3)}")
    print(f" [diag] actual cube px : {actual_third_px} (SAM3 said {centroid_px})")
    print(f" [diag] pixel error : "
          f"({centroid_px[0]-actual_third_px[0]:+d}, {centroid_px[1]-actual_third_px[1]:+d})")
    save_sam3_debug(first_frame, sam3_mask, centroid_px, attn_anchor_px,
                    str(out_dir / "debug_sam3.png"))
    print("Saved debug_sam3.png — verify the attention window sits on "
          "the cube-bottom + table strip in the thirdview frame")
    # Thirdview camera params (used by Phases B1 + B2)
    cam_pos, cam_R, f, W, H = env.get_camera_params()
    # Detection-crop bounds in the THIRDVIEW frame: a (2*DETECT_HALF)-px
    # square around the anchor, clamped into [0, W] x [0, H] with hi >= lo so
    # the slice can never be inverted. An inverted slice (anchor pushed
    # below the frame by WINDOW_DROP_PX) would make window_total_px negative.
    awx, awy = attn_anchor_px
    adx1 = min(max(awx - DETECT_HALF, 0), W)
    ady1 = min(max(awy - DETECT_HALF, 0), H)
    adx2 = max(min(awx + DETECT_HALF, W), adx1)
    ady2 = max(min(awy + DETECT_HALF, H), ady1)
    if adx2 - adx1 < 2 * DETECT_HALF or ady2 - ady1 < 2 * DETECT_HALF:
        print(f" WARNING: detection window x=[{adx1},{adx2}) y=[{ady1},{ady2}) "
              f"is clipped by the {W}x{H} frame; contact detection may be unreliable")
    # ------------------------------------------------------------------
    # Recording helpers — attention.mp4 shows ONLY the detection crop at
    # the attention window (upscaled at save time), so the user sees
    # exactly what the contact algorithm sees. Thirdview overlay still
    # draws the full scene.
    # ------------------------------------------------------------------
    thirdview_frames = []
    attention_frames = []  # raw detection crops, upscaled only at save time
    sideview_frames = []  # perpendicular camera, for visual verification

    def record(obs, eef_px=None):
        """Append one frame to each recording list; return the raw thirdview frame."""
        frame = env.get_frame(obs)
        vis = draw_debug_overlay(frame, sam3_mask, sam3_bottom_px, attn_anchor_px, eef_px)
        thirdview_frames.append(vis)
        crop = frame[ady1:ady2, adx1:adx2]
        # Pad to a consistent (2*DETECT_HALF)-px square if the window was
        # clipped at the frame edge, so every attention frame has one shape.
        ph = (2 * DETECT_HALF) - crop.shape[0]
        pw = (2 * DETECT_HALF) - crop.shape[1]
        if ph > 0 or pw > 0:
            crop = np.pad(crop, ((0, ph), (0, pw), (0, 0)))
        attention_frames.append(crop)
        sideview_frames.append(env.get_sideview_frame(obs))
        return frame

    record(obs)
    # ------------------------------------------------------------------
    # Phase A: move EEF straight up to APPROACH_HEIGHT above the table.
    # No cube coords used — we just lift to a known-safe height before servoing.
    # ------------------------------------------------------------------
    print("\n--- Phase A: lift to approach height ---")
    tip_now = env.get_gripper_tip_pos()
    above_3d = np.array([tip_now[0], tip_now[1], TABLE_SURFACE_Z + APPROACH_HEIGHT])
    print(f" current tip={tip_now.round(3)} above_target={above_3d.round(3)}")
    for obs in step_toward(env, obs, above_3d, gripper=1.0, max_steps=100):
        eef_px = env.world_to_pixel(env.get_gripper_tip_pos())
        record(obs, eef_px)
    # ------------------------------------------------------------------
    # Phase B1: image-space XY alignment (no cube depth).
    #
    # At each step, solve for the world XY at the EEF's CURRENT Z that
    # projects to the cube's front-face pixel — that's the world point on
    # the camera ray at the gripper's known depth — and step toward it.
    # ------------------------------------------------------------------
    print("\n--- Phase B1: image-space XY alignment ---")
    pix_err = float("inf")  # defined even if PHASE1_MAX_STEPS == 0
    for step_i in range(PHASE1_MAX_STEPS):
        tip_pos = env.get_gripper_tip_pos()
        tip_z = float(tip_pos[2])
        target_world = pixel_to_world_at_z(
            sam3_bottom_px[0], sam3_bottom_px[1], tip_z,
            cam_pos, cam_R, f, W, H,
        )
        target_world[2] = tip_z  # alignment is pure XY — preserve Z
        err = target_world - tip_pos
        action = np.zeros(7)
        action[:3] = np.clip(err / 0.05, -1.0, 1.0)
        action[6] = 1.0  # keep gripper closed so fingers stay at the tip site
        obs, _, _, _ = env.step(action)
        eef_px = env.world_to_pixel(env.get_gripper_tip_pos())
        record(obs, eef_px)
        pix_err = float(np.hypot(eef_px[0] - sam3_bottom_px[0],
                                 eef_px[1] - sam3_bottom_px[1]))
        if step_i % 5 == 0:
            print(f" step={step_i:3d} tip px={eef_px} front-face px={sam3_bottom_px} "
                  f"pix_err={pix_err:.1f}")
        if pix_err < STOP_PIXEL_THRESHOLD:
            print(f"====== ALIGNED at step {step_i} pix_err={pix_err:.1f} ======")
            break
    else:
        print(f"[Phase B1] Max iterations reached. pix_err={pix_err:.1f}")
    # ------------------------------------------------------------------
    # Phase B2: descend along the thirdview camera ray through the cube's
    # front-face pixel. That pixel stays locked in the image while the EEF
    # moves deeper along the ray; stop when the detection window becomes
    # dominantly near-white (table) — fraction above WHITE_TRIGGER_FRACTION
    # and a rise of more than WHITE_MIN_RISE over the Phase-B2 baseline —
    # meaning the cube has been pushed out of the window.
    # ------------------------------------------------------------------
    print("\n--- Phase B2: camera-ray descent ---")
    # Ray direction in camera frame for the front-face-top pixel. Same
    # convention as pixel_to_world_at_z: p_cam[0] = ax_ray*z,
    # p_cam[1] = ay_ray*z, p_cam[2] = z.
    u_cube, v_cube = sam3_bottom_px
    v_gl = H - 1 - v_cube
    ax_ray = -(u_cube - W / 2.0) / f
    ay_ray = -(v_gl - H / 2.0) / f
    contact_confirmed = False
    start_z = float(env.get_gripper_tip_pos()[2])
    print(f" start tip z={start_z:.4f}, ray=({ax_ray:+.3f}, {ay_ray:+.3f})")
    window_total_px = (ady2 - ady1) * (adx2 - adx1)  # >= 0 thanks to clamping

    def white_count(rgb_crop):
        """Count near-white (table) pixels: R, G, AND B all above
        WHITE_RGB_THRESH. Cube body fails the B channel (cube b is much
        lower than r,g); gripper is gray (~128) and fails the per-channel
        threshold; cube shadow on table darkens R/G/B uniformly so it also
        fails."""
        return int(np.sum(np.all(rgb_crop > WHITE_RGB_THRESH, axis=-1)))

    # Capture initial white_count from the CURRENT frame at the start of
    # Phase B2 (after Phase A/B1). This keeps the baseline aligned with the
    # active attention-window content instead of using reset-time frame 0.
    phase2_start_frame = env.get_frame(obs)
    initial_crop = phase2_start_frame[ady1:ady2, adx1:adx2]
    initial_white_count = white_count(initial_crop)
    print(f" initial white_count (Phase B2 start) = {initial_white_count} / "
          f"{window_total_px} px")
    for step_i in range(PHASE2_MAX_STEPS):
        tip_pos = env.get_gripper_tip_pos()
        # Current tip camera-frame depth (negative; camera looks along -cam_Z).
        p_cam_eef = cam_R.T @ (tip_pos - cam_pos)
        # Step "deeper" along the ray: p_cam[2] becomes more negative.
        z_cam_target = float(p_cam_eef[2]) - SERVO_DZ
        target_cam = np.array([ax_ray * z_cam_target,
                               ay_ray * z_cam_target,
                               z_cam_target])
        target_world = cam_pos + cam_R @ target_cam
        err = target_world - tip_pos
        action = np.zeros(7)
        action[:3] = np.clip(err / 0.05, -1.0, 1.0)
        action[6] = 1.0  # keep gripper closed so fingers stay at the tip site
        obs, _, _, _ = env.step(action)
        frame = record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
        # Detection crop of the thirdview frame at the attention window
        crop_rgb = frame[ady1:ady2, adx1:adx2]
        if step_i % 25 == 0:
            tp = env.get_gripper_tip_pos()
            tip_px_now = env.world_to_pixel(tp)
            print(f" step={step_i:3d} tip=({tp[0]:+.3f},{tp[1]:+.3f},{tp[2]:.3f}) "
                  f"tip px={tip_px_now} front-face px={sam3_bottom_px}")
        if step_i >= PHASE2_WARMUP_STEPS:
            wc = white_count(crop_rgb)
            frac = wc / max(window_total_px, 1)
            rise = wc - initial_white_count
            if step_i % 25 == 0:
                print(f" white_count={wc:3d}/{window_total_px} "
                      f"frac={frac:.2f} rise={rise:+d} "
                      f"(trigger when frac>{WHITE_TRIGGER_FRACTION} and "
                      f"rise>{WHITE_MIN_RISE})")
            if frac > WHITE_TRIGGER_FRACTION and rise > WHITE_MIN_RISE:
                print(f"\n--- Cube exited attention window at step {step_i}, "
                      f"tip z={env.get_gripper_tip_pos()[2]:.4f} ---")
                print(f" white_count: {initial_white_count} → {wc} "
                      f"(frac {frac:.2f}, rise {rise:+d})")
                # SAM3 re-segmentation on the current thirdview frame, so
                # we can show the user that the cube's mask centroid has
                # shifted from its original position. Uses the SAME crop
                # and upscale as the initial detection: the full frame is
                # too small for SAM3 to find the cube reliably, and
                # matching parameters keep the two centroids comparable.
                print("Running SAM3 on current thirdview frame for comparison…")
                _, verify_centroid_px, _ = segment_cube_sam3(
                    frame, processor,
                    crop_center=crop_center_px, crop_half=sam3_crop_half,
                    upscale=sam3_upscale,
                )
                # Defensive: this comparison is diagnostic only, so a
                # failed detection must not abort the episode before the
                # grasp demo. (segment_cube_sam3's failure mode is not
                # visible here — TODO confirm whether it can return None.)
                if verify_centroid_px is None:
                    print(" SAM3 found no cube in the current frame; "
                          "skipping centroid comparison")
                else:
                    shift = float(np.hypot(
                        verify_centroid_px[0] - centroid_px[0],
                        verify_centroid_px[1] - centroid_px[1],
                    ))
                    print(f" original SAM3 centroid: {centroid_px}")
                    print(f" current SAM3 centroid: {verify_centroid_px}")
                    print(f" pixel shift : {shift:.1f} px")
                print("====== CUBE MOVEMENT CONFIRMED ======")
                cube_now = env.get_cube_pos()
                print(f"\nCube 3D position (simulator ground truth):")
                print(f" x = {cube_now[0]:+.4f} m")
                print(f" y = {cube_now[1]:+.4f} m")
                print(f" z = {cube_now[2]:+.4f} m")
                print(f" delta from start: {(cube_now - actual_cube).round(4)}")
                contact_confirmed = True
                break
    if not contact_confirmed:
        print(f"[Phase B2] Max iterations reached. tip z={env.get_gripper_tip_pos()[2]:.4f}")
        cube_now = env.get_cube_pos()
        print(f" cube position at end : {cube_now.round(4)}")
        print(f" cube delta from start: {(cube_now - actual_cube).round(4)}")
    # ------------------------------------------------------------------
    # Phase C: grasp demo using the cube's actual ground-truth pose.
    # C1 lift (closed) → C2 open → C3 query cube pose, move over cube
    # → C4 descend onto cube → C5 close → C6 final lift → C7 hold
    # ------------------------------------------------------------------
    print("\n--- Phase C: grasp using actual cube pose ---")
    tip0 = env.get_gripper_tip_pos()
    # C1: lift straight up (gripper still closed from Phase B2)
    lift_target = tip0 + np.array([0.0, 0.0, GRASP_LIFT_DZ])
    print(f" C1 lift → {lift_target.round(3)}")
    for obs in step_toward(env, obs, lift_target, gripper=1.0, max_steps=80):
        record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
    # C2: open gripper in place
    print(" C2 open gripper")
    for obs in hold(env, obs, n_steps=GRIPPER_HOLD_STEPS, gripper=-1.0):
        record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
    # C3: query the cube's actual pose AFTER the push and move horizontally
    # to hover above it (preserving the lifted Z so we don't graze the cube).
    cube_pose = env.get_cube_pos()
    hover_z = max(lift_target[2], cube_pose[2] + GRASP_HOVER_DZ)
    over_cube = np.array([cube_pose[0], cube_pose[1], hover_z])
    print(f" cube actual pose = {cube_pose.round(4)}")
    print(f" C3 over cube → {over_cube.round(3)}")
    for obs in step_toward(env, obs, over_cube, gripper=-1.0, max_steps=120):
        record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
    # C4: descend to actual cube center Z so the open fingers straddle the cube
    descend_target = np.array([cube_pose[0], cube_pose[1], cube_pose[2]])
    print(f" C4 descend → {descend_target.round(3)}")
    for obs in step_toward(env, obs, descend_target, gripper=-1.0, max_steps=120):
        record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
    # C5: close gripper to clamp cube
    print(" C5 close gripper")
    for obs in hold(env, obs, n_steps=GRIPPER_HOLD_STEPS, gripper=1.0):
        record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
    # C6: lift to demonstrate grasp
    final_target = descend_target + np.array([0.0, 0.0, GRASP_FINAL_DZ])
    print(f" C6 final lift → {final_target.round(3)}")
    for obs in step_toward(env, obs, final_target, gripper=1.0, max_steps=120):
        record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
    # C7: hold so the held cube is visible at end of video
    print(" C7 hold")
    for obs in hold(env, obs, n_steps=POST_GRASP_HOLD, gripper=1.0):
        record(obs, env.world_to_pixel(env.get_gripper_tip_pos()))
    cube_final = env.get_cube_pos()
    print(f" pre-grasp cube z = {cube_pose[2]:.4f}, "
          f"post-grasp cube z = {cube_final[2]:.4f}, "
          f"lift Δz = {cube_final[2] - cube_pose[2]:+.4f} m")
    print(f"\nTotal frames: {len(thirdview_frames)}")
    return thirdview_frames, attention_frames, sideview_frames
# ---------------------------------------------------------------------------
# Video helpers
# ---------------------------------------------------------------------------
def save_video(frames, path, fps, upscale=1):
    """Encode ``frames`` to an H.264 (libx264) video at ``path``.

    Args:
        frames: sequence of same-shaped H×W(×C) image arrays.
        path: output file path.
        fps: playback frame rate.
        upscale: integer nearest-neighbour upscale factor applied to every
            frame before encoding; values <= 1 leave frames unchanged.

    An empty ``frames`` sequence is reported and skipped. Previously it
    created an empty/invalid file and then crashed on ``frames[0]``. The
    writer is always closed, even if encoding a frame fails.
    """
    if len(frames) == 0:
        print(f"Skipping {path}: no frames to save")
        return
    writer = imageio.get_writer(path, fps=fps, codec="libx264", quality=8, macro_block_size=1)
    try:
        for frame in frames:
            if upscale > 1:
                frame = np.repeat(np.repeat(frame, upscale, axis=0), upscale, axis=1)
            writer.append_data(frame)
    finally:
        writer.close()
    # Report the size actually written: frames are only upscaled when > 1.
    scale = upscale if upscale > 1 else 1
    h, w = frames[0].shape[:2]
    print(f"Saved {path} ({len(frames)} frames @ {fps} fps, {w * scale}×{h * scale})")
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def _next_run_dir(root=None) -> Path:
    """Create and return a fresh ``run_<N>/`` directory (N is 1-based).

    N is one more than the largest existing ``run_<digits>`` directory, or 1
    if there is none. Entries whose suffix is not plain ASCII digits (e.g.
    ``run_old``, ``run_²``) and non-directories are ignored. If the chosen
    name is taken when ``mkdir`` runs — by a concurrent run or a stray file
    named ``run_<N>`` — the next index is tried instead of raising.

    Args:
        root: parent directory. Defaults to RESULTS_ROOT (``results/`` next
            to this script). Created, with parents, if missing.
    """
    root = RESULTS_ROOT if root is None else Path(root)
    root.mkdir(parents=True, exist_ok=True)
    existing = []
    for candidate in root.glob("run_*"):
        suffix = candidate.name[len("run_"):]
        # isdigit() alone also accepts e.g. "²", which int() rejects.
        if candidate.is_dir() and suffix.isascii() and suffix.isdigit():
            existing.append(int(suffix))
    n = max(existing, default=0) + 1
    while True:
        out = root / f"run_{n}"
        try:
            out.mkdir()
        except FileExistsError:
            n += 1  # name already taken; try the next index
        else:
            return out
def main():
    """CLI entry point: load SAM3, run one push episode, and save the videos.

    All outputs go to a fresh results/run_<N>/ directory. The simulator is
    closed even if the episode raises.
    """
    parser = argparse.ArgumentParser(description="SAM3 + visual servoing cube push")
    parser.add_argument(
        "--checkpoint", type=str, default=None,
        help="Local SAM3 checkpoint (.pt). Omit to load from HuggingFace.",
    )
    args = parser.parse_args()
    out_dir = _next_run_dir()
    print(f"Run output directory: {out_dir}\n")
    print("Loading SAM3 model...")
    _, processor = load_sam3_model(args.checkpoint)
    print("SAM3 loaded.\n")
    print("Creating FrankaCubePushEnv...")
    env = FrankaCubePushEnv(
        camera_height=480,
        camera_width=480,
        has_offscreen_renderer=True,
        use_camera_obs=True,
        has_renderer=False,
        horizon=2000,
        ignore_done=True,
    )
    try:
        thirdview_frames, attention_frames, sideview_frames = run_episode(
            env, processor, out_dir
        )
    finally:
        # Release the simulator / offscreen renderer even if the episode fails.
        env.close()
    print("\nSaving videos...")
    save_video(thirdview_frames, str(out_dir / "thirdview.mp4"), FPS)
    save_video(attention_frames, str(out_dir / "attention.mp4"), FPS,
               upscale=ATTN_VIDEO_UPSCALE)
    save_video(sideview_frames, str(out_dir / "sideview.mp4"), FPS)
    print(f"\nDone. All outputs in {out_dir}/")
if __name__ == "__main__":
    # Run the episode only when executed as a script, not when imported.
    main()