Skip to content

Commit 48deb15

Browse files
authored
Simplify multigpu dispatch: run all devices on pool threads (#13340)
Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090 with SD1.5 and NetaYume models. No meaningful performance difference (within noise). All-pool is simpler: eliminates the main_device special case, main_batch_tuple deferred execution, and the 3-way branch in the dispatch loop.
1 parent 4b93c43 commit 48deb15

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

comfy/samplers.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -516,25 +516,17 @@ def _handle_batch_pooled(device, batch_tuple):
516516

517517
results: list[thread_result] = []
518518
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
519-
main_device = output_device
520-
main_batch_tuple = None
521519

522-
# Submit extra GPU work to pool first, then run main device on this thread
520+
# Submit all GPU work to pool threads
523521
pool_devices = []
524522
for device, batch_tuple in device_batched_hooked_to_run.items():
525-
if device == main_device and thread_pool is not None:
526-
main_batch_tuple = batch_tuple
527-
elif thread_pool is not None:
523+
if thread_pool is not None:
528524
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
529525
pool_devices.append(device)
530526
else:
531527
# Fallback: no pool, run everything on main thread
532528
_handle_batch(device, batch_tuple, results)
533529

534-
# Run main device batch on this thread (parallel with pool workers)
535-
if main_batch_tuple is not None:
536-
_handle_batch(main_device, main_batch_tuple, results)
537-
538530
# Collect results from pool workers
539531
for device in pool_devices:
540532
worker_results, error = thread_pool.get_result(device)
@@ -1210,10 +1202,11 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None,
12101202

12111203
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
12121204

1213-
# Create persistent thread pool for extra GPU devices
1205+
# Create persistent thread pool for all GPU devices (main + extras)
12141206
if multigpu_patchers:
12151207
extra_devices = [p.load_device for p in multigpu_patchers]
1216-
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices)
1208+
all_devices = [device] + extra_devices
1209+
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
12171210

12181211
try:
12191212
noise = noise.to(device=device, dtype=torch.float32)

0 commit comments

Comments (0)