Skip to content

Commit 0ce2698

Browse files
committed
fix VideoFromFile stream source to _ReentrantBytesIO for parallel async use
1 parent 807538f commit 0ce2698

1 file changed

Lines changed: 170 additions & 34 deletions

File tree

comfy_api/latest/_input_impl/video_types.py

Lines changed: 170 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from av.container import InputContainer
33
from av.subtitles.stream import SubtitleStream
44
from fractions import Fraction
5-
from typing import Optional
5+
from typing import Optional, IO, Iterator
6+
from contextlib import contextmanager, suppress
67
from .._input import AudioInput, VideoInput
78
import av
89
import io
@@ -13,6 +14,129 @@
1314
from .._util import VideoContainer, VideoCodec, VideoComponents
1415

1516

17+
class _ReentrantBytesIO(io.BytesIO):
18+
"""Read-only, seekable BytesIO-compatible view over shared immutable bytes."""
19+
20+
def __init__(self, data: bytes):
21+
super().__init__(b"") # Initialize base BytesIO with an empty buffer; we do not use its internal storage.
22+
if data is None:
23+
raise TypeError("data must be bytes, not None")
24+
self._data = data
25+
self._view = memoryview(data)
26+
self._pos = 0
27+
28+
def getvalue(self) -> bytes:
29+
if self.closed:
30+
raise ValueError("I/O operation on closed file.")
31+
return self._data
32+
33+
def getbuffer(self) -> memoryview:
34+
if self.closed:
35+
raise ValueError("I/O operation on closed file.")
36+
return memoryview(self._data) # return a NEW view; external .release() won't break our internal _view.
37+
38+
def readable(self) -> bool:
39+
return True
40+
41+
def writable(self) -> bool:
42+
return False
43+
44+
def seekable(self) -> bool:
45+
return True
46+
47+
def tell(self) -> int:
48+
return self._pos
49+
50+
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
51+
if self.closed:
52+
raise ValueError("I/O operation on closed file.")
53+
if whence == io.SEEK_SET:
54+
new_pos = offset
55+
elif whence == io.SEEK_CUR:
56+
new_pos = self._pos + offset
57+
elif whence == io.SEEK_END:
58+
new_pos = len(self._view) + offset
59+
else:
60+
raise ValueError(f"Invalid whence: {whence}")
61+
if new_pos < 0:
62+
raise ValueError("Negative seek position")
63+
self._pos = new_pos
64+
return self._pos
65+
66+
def readinto(self, b) -> int:
67+
if self.closed:
68+
raise ValueError("I/O operation on closed file.")
69+
mv = memoryview(b)
70+
if mv.readonly:
71+
raise TypeError("readinto() argument must be writable")
72+
mv = mv.cast("B")
73+
if self._pos >= len(self._view):
74+
return 0
75+
n = min(len(mv), len(self._view) - self._pos)
76+
mv[:n] = self._view[self._pos:self._pos + n]
77+
self._pos += n
78+
return n
79+
80+
def readinto1(self, b) -> int:
81+
return self.readinto(b)
82+
83+
def read(self, size: int = -1) -> bytes:
84+
if self.closed:
85+
raise ValueError("I/O operation on closed file.")
86+
if size is None or size < 0:
87+
size = len(self._view) - self._pos
88+
if self._pos >= len(self._view):
89+
return b""
90+
end = min(self._pos + size, len(self._view))
91+
out = self._data[self._pos:end]
92+
self._pos = end
93+
return out
94+
95+
def read1(self, size: int = -1) -> bytes:
96+
return self.read(size)
97+
98+
def readline(self, size: int = -1) -> bytes:
99+
if self.closed:
100+
raise ValueError("I/O operation on closed file.")
101+
if self._pos >= len(self._view):
102+
return b""
103+
end_limit = len(self._view) if size is None or size < 0 else min(len(self._view), self._pos + size)
104+
nl = self._data.find(b"\n", self._pos, end_limit)
105+
end = (nl + 1) if nl != -1 else end_limit
106+
out = self._data[self._pos:end]
107+
self._pos = end
108+
return out
109+
110+
def readlines(self, hint: int = -1) -> list[bytes]:
111+
if self.closed:
112+
raise ValueError("I/O operation on closed file.")
113+
lines: list[bytes] = []
114+
total = 0
115+
while True:
116+
line = self.readline()
117+
if not line:
118+
break
119+
lines.append(line)
120+
total += len(line)
121+
if hint is not None and 0 <= hint <= total:
122+
break
123+
return lines
124+
125+
def write(self, b) -> int:
126+
raise io.UnsupportedOperation("not writable")
127+
128+
def writelines(self, lines) -> None:
129+
raise io.UnsupportedOperation("not writable")
130+
131+
def truncate(self, size: int | None = None) -> int:
132+
raise io.UnsupportedOperation("not writable")
133+
134+
def close(self) -> None:
135+
with suppress(Exception):
136+
self._view.release()
137+
super().close()
138+
139+
16140
def container_to_output_format(container_format: str | None) -> str | None:
17141
"""
18142
A container's `format` may be a comma-separated list of formats.
@@ -57,21 +181,34 @@ class VideoFromFile(VideoInput):
57181
Class representing video input from a file.
58182
"""
59183

60-
def __init__(self, file: str | io.BytesIO):
184+
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
61185
"""
62186
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
63187
containing the file contents.
64188
"""
65-
self.__file = file
189+
self.__path: Optional[str] = None
190+
self.__data: Optional[bytes] = None
191+
if isinstance(file, str):
192+
self.__path = file
193+
elif isinstance(file, io.BytesIO):
194+
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
195+
self.__data = file.getbuffer().tobytes()
196+
elif isinstance(file, (bytes, bytearray, memoryview)):
197+
self.__data = bytes(file)
198+
else:
199+
raise TypeError(f"Unsupported video source type: {type(file)!r}")
66200

67201
def get_stream_source(self) -> str | io.BytesIO:
68202
"""
69203
Return the underlying file source for efficient streaming.
70204
This avoids unnecessary memory copies when the source is already a file path.
71205
"""
72-
if isinstance(self.__file, io.BytesIO):
73-
self.__file.seek(0)
74-
return self.__file
206+
if self.__path is not None:
207+
return self.__path
208+
data = self.__data
209+
if data is None:
210+
raise RuntimeError("VideoFromFile: missing in-memory bytes (__data is None)")
211+
return _ReentrantBytesIO(data)
75212

76213
def get_dimensions(self) -> tuple[int, int]:
77214
"""
@@ -80,14 +217,12 @@ def get_dimensions(self) -> tuple[int, int]:
80217
Returns:
81218
Tuple of (width, height)
82219
"""
83-
if isinstance(self.__file, io.BytesIO):
84-
self.__file.seek(0) # Reset the BytesIO object to the beginning
85-
with av.open(self.__file, mode='r') as container:
220+
with self._open_source() as src, av.open(src, mode="r") as container:
86221
for stream in container.streams:
87222
if stream.type == 'video':
88223
assert isinstance(stream, av.VideoStream)
89224
return stream.width, stream.height
90-
raise ValueError(f"No video stream found in file '{self.__file}'")
225+
raise ValueError(f"No video stream found in {self._source_label()}")
91226

92227
def get_duration(self) -> float:
93228
"""
@@ -96,9 +231,7 @@ def get_duration(self) -> float:
96231
Returns:
97232
Duration in seconds
98233
"""
99-
if isinstance(self.__file, io.BytesIO):
100-
self.__file.seek(0)
101-
with av.open(self.__file, mode="r") as container:
234+
with self._open_source() as src, av.open(src, mode="r") as container:
102235
if container.duration is not None:
103236
return float(container.duration / av.time_base)
104237

@@ -119,17 +252,14 @@ def get_duration(self) -> float:
119252
if frame_count > 0:
120253
return float(frame_count / video_stream.average_rate)
121254

122-
raise ValueError(f"Could not determine duration for file '{self.__file}'")
255+
raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
123256

124257
def get_frame_count(self) -> int:
125258
"""
126259
Returns the number of frames in the video without materializing them as
127260
torch tensors.
128261
"""
129-
if isinstance(self.__file, io.BytesIO):
130-
self.__file.seek(0)
131-
132-
with av.open(self.__file, mode="r") as container:
262+
with self._open_source() as src, av.open(src, mode="r") as container:
133263
video_stream = self._get_first_video_stream(container)
134264
# 1. Prefer the frames field if available
135265
if video_stream.frames and video_stream.frames > 0:
@@ -160,18 +290,15 @@ def get_frame_count(self) -> int:
160290
frame_count += 1
161291

162292
if frame_count == 0:
163-
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
293+
raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
164294
return frame_count
165295

166296
def get_frame_rate(self) -> Fraction:
167297
"""
168298
Returns the average frame rate of the video using container metadata
169299
without decoding all frames.
170300
"""
171-
if isinstance(self.__file, io.BytesIO):
172-
self.__file.seek(0)
173-
174-
with av.open(self.__file, mode="r") as container:
301+
with self._open_source() as src, av.open(src, mode="r") as container:
175302
video_stream = self._get_first_video_stream(container)
176303
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
177304
if video_stream.average_rate:
@@ -193,9 +320,7 @@ def get_container_format(self) -> str:
193320
Returns:
194321
Container format as string
195322
"""
196-
if isinstance(self.__file, io.BytesIO):
197-
self.__file.seek(0)
198-
with av.open(self.__file, mode='r') as container:
323+
with self._open_source() as src, av.open(src, mode='r') as container:
199324
return container.format.name
200325

201326
def get_components_internal(self, container: InputContainer) -> VideoComponents:
@@ -239,11 +364,8 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
239364
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
240365

241366
def get_components(self) -> VideoComponents:
242-
if isinstance(self.__file, io.BytesIO):
243-
self.__file.seek(0) # Reset the BytesIO object to the beginning
244-
with av.open(self.__file, mode='r') as container:
367+
with self._open_source() as src, av.open(src, mode='r') as container:
245368
return self.get_components_internal(container)
246-
raise ValueError(f"No video stream found in file '{self.__file}'")
247369

248370
def save_to(
249371
self,
@@ -252,9 +374,7 @@ def save_to(
252374
codec: VideoCodec = VideoCodec.AUTO,
253375
metadata: Optional[dict] = None
254376
):
255-
if isinstance(self.__file, io.BytesIO):
256-
self.__file.seek(0) # Reset the BytesIO object to the beginning
257-
with av.open(self.__file, mode='r') as container:
377+
with self._open_source() as src, av.open(src, mode='r') as container:
258378
container_format = container.format.name
259379
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
260380
reuse_streams = True
@@ -306,9 +426,25 @@ def save_to(
306426
def _get_first_video_stream(self, container: InputContainer):
307427
video_stream = next((s for s in container.streams if s.type == "video"), None)
308428
if video_stream is None:
309-
raise ValueError(f"No video stream found in file '{self.__file}'")
429+
raise ValueError(f"No video stream found in file '{self._source_label()}'")
310430
return video_stream
311431

432+
def _source_label(self) -> str:
433+
if self.__path is not None:
434+
return self.__path
435+
return f"<in-memory video: {len(self.__data)} bytes>"
436+
437+
@contextmanager
438+
def _open_source(self) -> Iterator[str | IO[bytes]]:
439+
"""Internal helper to ensure file-like sources are closed after use."""
440+
src = self.get_stream_source()
441+
try:
442+
yield src
443+
finally:
444+
if not isinstance(src, str):
445+
with suppress(Exception):
446+
src.close()
447+
312448

313449
class VideoFromComponents(VideoInput):
314450
"""

0 commit comments

Comments
 (0)