Skip to content

Commit 84813da

Browse files
committed
Sending the message as JSON and encode the name using base64 to keep the sent length below 512
1 parent 9ac85d8 commit 84813da

1 file changed

Lines changed: 52 additions & 20 deletions

File tree

Lib/multiprocessing/resource_tracker.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import warnings
2424
from collections import deque
2525

26+
import json
27+
2628
from . import spawn
2729
from . import util
2830

@@ -192,6 +194,17 @@ def _launch(self):
192194
finally:
193195
os.close(r)
194196

197+
def _make_probe_message(self):
198+
"""Return a JSON-encoded probe message."""
199+
return (
200+
json.dumps(
201+
{"cmd": "PROBE", "rtype": "noop", "base64_name": ""},
202+
ensure_ascii=True,
203+
separators=(",", ":"),
204+
)
205+
+ "\n"
206+
).encode("ascii")
207+
195208
def _ensure_running_and_write(self, msg=None):
196209
with self._lock:
197210
if self._lock._recursion_count() > 1:
@@ -203,7 +216,7 @@ def _ensure_running_and_write(self, msg=None):
203216
if self._fd is not None:
204217
# resource tracker was launched before, is it still running?
205218
if msg is None:
206-
to_send = b'PROBE:0:noop\n'
219+
to_send = self._make_probe_message()
207220
else:
208221
to_send = msg
209222
try:
@@ -230,7 +243,7 @@ def _check_alive(self):
230243
try:
231244
# We cannot use send here as it calls ensure_running, creating
232245
# a cycle.
233-
os.write(self._fd, b'PROBE:0:noop\n')
246+
os.write(self._fd, self._make_probe_message())
234247
except OSError:
235248
return False
236249
else:
@@ -249,17 +262,23 @@ def _write(self, msg):
249262
assert nbytes == len(msg), f"{nbytes=} != {len(msg)=}"
250263

251264
def _send(self, cmd, name, rtype):
252-
# Encode shared_memory names as they are created by the user and may contain
253-
# colons or newlines.
254-
if rtype == "shared_memory":
255-
b = name.encode('utf-8', 'surrogateescape')
256-
name = base64.urlsafe_b64encode(b).decode('ascii')
257-
258-
msg = f"{cmd}:{name}:{rtype}\n".encode("ascii")
259-
if len(msg) > 512:
260-
# posix guarantees that writes to a pipe of less than PIPE_BUF
261-
# bytes are atomic, and that PIPE_BUF >= 512
262-
raise ValueError('msg too long')
265+
# POSIX shm_open() and sem_open() require the name, including its leading slash,
266+
# to be at most NAME_MAX bytes (255 on Linux)
267+
# With json.dump(..., ensure_ascii=True) every non-ASCII byte becomes a 6-char
268+
# escape like \uDC80.
269+
# As we want the overall message to be kept atomic and therefore smaller than 512,
270+
# we encode encode the raw name bytes with URL-safe Base64 - so a 255 long name
271+
# will not exceed 340 bytes.
272+
b = name.encode('utf-8', 'surrogateescape')
273+
if len(b) > 255:
274+
raise ValueError('shared memory name too long (max 255 bytes)')
275+
b64 = base64.urlsafe_b64encode(b).decode('ascii')
276+
277+
payload = {"cmd": cmd, "rtype": rtype, "base64_name": b64}
278+
msg = (json.dumps(payload, ensure_ascii=True, separators=(",", ":")) + "\n").encode("ascii")
279+
280+
# The entire JSON message is guaranteed < PIPE_BUF (512 bytes) by construction.
281+
assert len(msg) <= 512, f"internal error: message too long ({len(msg)} bytes)"
263282

264283
self._ensure_running_and_write(msg)
265284

@@ -290,14 +309,27 @@ def main(fd):
290309
try:
291310
# keep track of registered/unregistered resources
292311
with open(fd, 'rb') as f:
293-
for line in f:
312+
for raw in f:
294313
try:
295-
cmd, enc_name, rtype = line.rstrip(b'\n').decode('ascii').split(':', 2)
296-
if rtype == "shared_memory":
297-
name = base64.urlsafe_b64decode(enc_name.encode('ascii')).decode('utf-8', 'surrogateescape')
298-
else:
299-
# Semaphore names are generated internally, so no encoding is needed.
300-
name = enc_name
314+
line = raw.rstrip(b'\n')
315+
try:
316+
obj = json.loads(line.decode('ascii'))
317+
except Exception as e:
318+
raise ValueError("malformed resource_tracker message: %r" % (line,)) from e
319+
320+
cmd = obj.get("cmd")
321+
rtype = obj.get("rtype")
322+
b64 = obj.get("base64_name")
323+
324+
if not isinstance(cmd, str) or not isinstance(rtype, str) or not isinstance(b64, str):
325+
raise ValueError("malformed resource_tracker fields: %r" % (obj,))
326+
327+
enc = b64.encode('ascii')
328+
enc += b'=' * (-len(enc) % 4) # normalize padding
329+
try:
330+
name = base64.urlsafe_b64decode(enc).decode('utf-8', 'surrogateescape')
331+
except ValueError as e:
332+
raise ValueError("malformed resource_tracker base64_name: %r" % (b64,)) from e
301333

302334
cleanup_func = _CLEANUP_FUNCS.get(rtype, None)
303335
if cleanup_func is None:

0 commit comments

Comments
 (0)