Skip to content

Commit 6df3aee

Browse files
authored
Merge branch 'master' into fix/color-curves-shader-nested-sampler
2 parents b4156b8 + b353a7c commit 6df3aee

File tree

9 files changed: 61 additions and 43 deletions.

comfy/cli_args.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,13 @@ def from_string(cls, value: str):
110110

111111
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
112112

113+
CACHE_RAM_AUTO_GB = -1.0
114+
113115
cache_group = parser.add_mutually_exclusive_group()
114116
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
115117
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
116118
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
117-
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
119+
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
118120

119121
attn_group = parser.add_mutually_exclusive_group()
120122
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

comfy/memory_management.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
141141
return dest_views
142142

143143
aimdo_enabled = False
144+
145+
# Callable used to ask the active RAM cache to release memory, or None when
# nothing is registered.
extra_ram_release_callback = None
# Headroom registered together with the callback; always a non-negative int.
RAM_CACHE_HEADROOM = 0


def set_ram_cache_release_state(callback, headroom):
    """Register (or clear) the RAM release callback and its headroom.

    Args:
        callback: Callable accepting a single ``target`` argument, or
            ``None`` to disable extra RAM release.
        headroom: Headroom value; truncated with ``int()`` and clamped so
            the stored ``RAM_CACHE_HEADROOM`` is never negative.
    """
    global extra_ram_release_callback
    global RAM_CACHE_HEADROOM
    extra_ram_release_callback = callback
    RAM_CACHE_HEADROOM = max(0, int(headroom))


def extra_ram_release(target):
    """Invoke the registered RAM release callback, if there is one.

    Args:
        target: Passed through unchanged to the callback.

    Returns:
        The callback's result, or 0 when no callback is registered or the
        callback returns ``None``.
    """
    # Read the global once: checking it and then re-reading it for the call
    # would let a reset to None in between raise "NoneType is not callable".
    callback = extra_ram_release_callback
    if callback is None:
        return 0
    released = callback(target)
    # A callback that works purely by side effect returns None; report 0 so
    # the result has the same type as on the no-callback path.
    return 0 if released is None else released

comfy/model_management.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
669669

670670
for i in range(len(current_loaded_models) -1, -1, -1):
671671
shift_model = current_loaded_models[i]
672-
if shift_model.device == device:
672+
if device is None or shift_model.device == device:
673673
if shift_model not in keep_loaded and not shift_model.is_dead():
674674
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
675675
shift_model.currently_used = False
@@ -679,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
679679
i = x[-1]
680680
memory_to_free = 1e32
681681
pins_to_free = 1e32
682-
if not DISABLE_SMART_MEMORY:
683-
memory_to_free = memory_required - get_free_memory(device)
682+
if not DISABLE_SMART_MEMORY or device is None:
683+
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
684684
pins_to_free = pins_required - get_free_ram()
685685
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
686686
#don't actually unload dynamic models for the sake of other dynamic models
@@ -708,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
708708

709709
if len(unloaded_model) > 0:
710710
soft_empty_cache()
711-
else:
711+
elif device is not None:
712712
if vram_state != VRAMState.HIGH_VRAM:
713713
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
714714
if mem_free_torch > mem_free_total * 0.25:

comfy/model_patcher.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,6 @@ def model_size(self):
300300
def model_mmap_residency(self, free=False):
301301
return comfy.model_management.module_mmap_residency(self.model, free=free)
302302

303-
def get_ram_usage(self):
304-
return self.model_size()
305-
306303
def loaded_size(self):
307304
return self.model.model_loaded_weight_memory
308305

comfy/pinned_memory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import comfy.memory_management
33
import comfy_aimdo.host_buffer
44
import comfy_aimdo.torch
5+
import psutil
56

67
from comfy.cli_args import args
78

@@ -12,6 +13,11 @@ def pin_memory(module):
1213
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
1314
return
1415
#FIXME: This is a RAM cache trigger event
16+
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
17+
#we split the difference and assume half the RAM cache headroom is for us
18+
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
19+
comfy.memory_management.extra_ram_release(ram_headroom)
20+
1521
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
1622

1723
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:

comfy/sd.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,6 @@ def clone(self, disable_dynamic=False):
280280
n.apply_hooks_to_conds = self.apply_hooks_to_conds
281281
return n
282282

283-
def get_ram_usage(self):
284-
return self.patcher.get_ram_usage()
285-
286283
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
287284
return self.patcher.add_patches(patches, strength_patch, strength_model)
288285

@@ -840,9 +837,6 @@ def model_size(self):
840837
self.size = comfy.model_management.module_size(self.first_stage_model)
841838
return self.size
842839

843-
def get_ram_usage(self):
844-
return self.model_size()
845-
846840
def throw_exception_if_invalid(self):
847841
if self.first_stage_model is None:
848842
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")

comfy_execution/caching.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import bisect
3-
import gc
43
import itertools
54
import psutil
65
import time
@@ -475,6 +474,10 @@ async def set(self, node_id, value):
475474
self._mark_used(node_id)
476475
return await self._set_immediate(node_id, value)
477476

477+
def set_local(self, node_id, value):
478+
self._mark_used(node_id)
479+
BasicCache.set_local(self, node_id, value)
480+
478481
async def ensure_subcache_for(self, node_id, children_ids):
479482
# Just uses subcaches for tracking 'live' nodes
480483
await super()._ensure_subcache(node_id, children_ids)
@@ -489,15 +492,10 @@ async def ensure_subcache_for(self, node_id, children_ids):
489492
return self
490493

491494

492-
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
493-
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
494-
495-
RAM_CACHE_HYSTERESIS = 1.1
495+
#Small baseline weight used when a cache entry has no measurable CPU tensors.
496+
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
496497

497-
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
498-
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
499-
500-
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
498+
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
501499

502500
#Exponential bias towards evicting older workflows so garbage will be taken out
503501
#in constantly changing setups.
@@ -521,19 +519,17 @@ async def get(self, node_id):
521519
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
522520
return await super().get(node_id)
523521

524-
def poll(self, ram_headroom):
525-
def _ram_gb():
526-
return psutil.virtual_memory().available / (1024**3)
522+
def set_local(self, node_id, value):
523+
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
524+
super().set_local(node_id, value)
527525

528-
if _ram_gb() > ram_headroom:
529-
return
530-
gc.collect()
531-
if _ram_gb() > ram_headroom:
526+
def ram_release(self, target):
527+
if psutil.virtual_memory().available >= target:
532528
return
533529

534530
clean_list = []
535531

536-
for key, (outputs, _), in self.cache.items():
532+
for key, cache_entry in self.cache.items():
537533
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
538534

539535
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -542,22 +538,20 @@ def scan_list_for_ram_usage(outputs):
542538
if outputs is None:
543539
return
544540
for output in outputs:
545-
if isinstance(output, list):
541+
if isinstance(output, (list, tuple)):
546542
scan_list_for_ram_usage(output)
547543
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
548-
#score Tensors at a 50% discount for RAM usage as they are likely to
549-
#be high value intermediates
550-
ram_usage += (output.numel() * output.element_size()) * 0.5
551-
elif hasattr(output, "get_ram_usage"):
552-
ram_usage += output.get_ram_usage()
553-
scan_list_for_ram_usage(outputs)
544+
ram_usage += output.numel() * output.element_size()
545+
scan_list_for_ram_usage(cache_entry.outputs)
554546

555547
oom_score *= ram_usage
556548
#In the case where we have no information on the node ram usage at all,
557549
#break OOM score ties on the last touch timestamp (pure LRU)
558550
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
559551

560-
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
552+
while psutil.virtual_memory().available < target and clean_list:
561553
_, _, key = clean_list.pop()
562554
del self.cache[key]
563-
gc.collect()
555+
self.used_generation.pop(key, None)
556+
self.timestamps.pop(key, None)
557+
self.children.pop(key, None)

execution.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,9 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
724724
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
725725

726726
self._notify_prompt_lifecycle("start", prompt_id)
727+
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
728+
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
729+
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
727730

728731
try:
729732
with torch.inference_mode():
@@ -773,7 +776,10 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
773776
execution_list.unstage_node_execution()
774777
else: # result == ExecutionResult.SUCCESS:
775778
execution_list.complete_node_execution()
776-
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
779+
780+
if self.cache_type == CacheType.RAM_PRESSURE:
781+
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
782+
comfy.memory_management.extra_ram_release(ram_headroom)
777783
else:
778784
# Only execute when the while-loop ends without break
779785
# Send cached UI for intermediate output nodes that weren't executed
@@ -801,6 +807,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
801807
if comfy.model_management.DISABLE_SMART_MEMORY:
802808
comfy.model_management.unload_all_models()
803809
finally:
810+
comfy.memory_management.set_ram_cache_release_state(None, 0)
804811
self._notify_prompt_lifecycle("end", prompt_id)
805812

806813

main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,15 +275,19 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
275275

276276
def prompt_worker(q, server_instance):
277277
current_time: float = 0.0
278+
cache_ram = args.cache_ram
279+
if cache_ram < 0:
280+
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
281+
278282
cache_type = execution.CacheType.CLASSIC
279283
if args.cache_lru > 0:
280284
cache_type = execution.CacheType.LRU
281-
elif args.cache_ram > 0:
285+
elif cache_ram > 0:
282286
cache_type = execution.CacheType.RAM_PRESSURE
283287
elif args.cache_none:
284288
cache_type = execution.CacheType.NONE
285289

286-
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
290+
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
287291
last_gc_collect = 0
288292
need_gc = False
289293
gc_collect_interval = 10.0

0 commit comments

Comments
 (0)