Skip to content

Commit a7bf339

Browse files
properly propagate contextvars for server protocols
1 parent 248ce9f commit a7bf339

2 files changed

Lines changed: 31 additions & 28 deletions

File tree

Lib/asyncio/base_events.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
import collections
17+
import contextvars
1718
import collections.abc
1819
import concurrent.futures
1920
import errno
@@ -289,6 +290,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
289290
self._ssl_shutdown_timeout = ssl_shutdown_timeout
290291
self._serving = False
291292
self._serving_forever_fut = None
293+
self._context = contextvars.copy_context()
292294

293295
def __repr__(self):
294296
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
@@ -318,7 +320,7 @@ def _start_serving(self):
318320
self._loop._start_serving(
319321
self._protocol_factory, sock, self._ssl_context,
320322
self, self._backlog, self._ssl_handshake_timeout,
321-
self._ssl_shutdown_timeout)
323+
self._ssl_shutdown_timeout, context=self._context)
322324

323325
def get_loop(self):
324326
return self._loop
@@ -1211,9 +1213,10 @@ async def _create_connection_transport(
12111213
self, sock, protocol_factory, ssl,
12121214
server_hostname, server_side=False,
12131215
ssl_handshake_timeout=None,
1214-
ssl_shutdown_timeout=None):
1216+
ssl_shutdown_timeout=None, context=None):
12151217

12161218
sock.setblocking(False)
1219+
context = context if context is not None else contextvars.copy_context()
12171220

12181221
protocol = protocol_factory()
12191222
waiter = self.create_future()
@@ -1225,7 +1228,7 @@ async def _create_connection_transport(
12251228
ssl_handshake_timeout=ssl_handshake_timeout,
12261229
ssl_shutdown_timeout=ssl_shutdown_timeout)
12271230
else:
1228-
transport = self._make_socket_transport(sock, protocol, waiter)
1231+
transport = self._make_socket_transport(sock, protocol, waiter, context=context)
12291232

12301233
try:
12311234
await waiter

Lib/asyncio/selector_events.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def __init__(self, selector=None):
6767
self._transports = weakref.WeakValueDictionary()
6868

6969
def _make_socket_transport(self, sock, protocol, waiter=None, *,
70-
extra=None, server=None):
70+
extra=None, server=None, context=None):
7171
self._ensure_fd_no_transport(sock)
7272
return _SelectorSocketTransport(self, sock, protocol, waiter,
73-
extra, server)
73+
extra, server, context=context)
7474

7575
def _make_ssl_transport(
7676
self, rawsock, protocol, sslcontext, waiter=None,
@@ -159,16 +159,16 @@ def _write_to_self(self):
159159
def _start_serving(self, protocol_factory, sock,
160160
sslcontext=None, server=None, backlog=100,
161161
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
162-
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
162+
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
163163
self._add_reader(sock.fileno(), self._accept_connection,
164164
protocol_factory, sock, sslcontext, server, backlog,
165-
ssl_handshake_timeout, ssl_shutdown_timeout)
165+
ssl_handshake_timeout, ssl_shutdown_timeout, context)
166166

167167
def _accept_connection(
168168
self, protocol_factory, sock,
169169
sslcontext=None, server=None, backlog=100,
170170
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
171-
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
171+
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
172172
# This method is only called once for each event loop tick where the
173173
# listening socket has triggered an EVENT_READ. There may be multiple
174174
# connections waiting for an .accept() so it is called in a loop.
@@ -204,21 +204,21 @@ def _accept_connection(
204204
self._start_serving,
205205
protocol_factory, sock, sslcontext, server,
206206
backlog, ssl_handshake_timeout,
207-
ssl_shutdown_timeout)
207+
ssl_shutdown_timeout, context)
208208
else:
209209
raise # The event loop will catch, log and ignore it.
210210
else:
211211
extra = {'peername': addr}
212212
accept = self._accept_connection2(
213213
protocol_factory, conn, extra, sslcontext, server,
214-
ssl_handshake_timeout, ssl_shutdown_timeout)
215-
self.create_task(accept)
214+
ssl_handshake_timeout, ssl_shutdown_timeout, context=context)
215+
self.create_task(accept, context=context)
216216

217217
async def _accept_connection2(
218218
self, protocol_factory, conn, extra,
219219
sslcontext=None, server=None,
220220
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
221-
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
221+
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
222222
protocol = None
223223
transport = None
224224
try:
@@ -233,7 +233,7 @@ async def _accept_connection2(
233233
else:
234234
transport = self._make_socket_transport(
235235
conn, protocol, waiter=waiter, extra=extra,
236-
server=server)
236+
server=server, context=context)
237237

238238
try:
239239
await waiter
@@ -275,9 +275,9 @@ def _ensure_fd_no_transport(self, fd):
275275
f'File descriptor {fd!r} is used by transport '
276276
f'{transport!r}')
277277

278-
def _add_reader(self, fd, callback, *args):
278+
def _add_reader(self, fd, callback, *args, context=None):
279279
self._check_closed()
280-
handle = events.Handle(callback, args, self, None)
280+
handle = events.Handle(callback, args, self, context=context)
281281
key = self._selector.get_map().get(fd)
282282
if key is None:
283283
self._selector.register(fd, selectors.EVENT_READ,
@@ -770,7 +770,7 @@ class _SelectorTransport(transports._FlowControlMixin,
770770
# exception)
771771
_sock = None
772772

773-
def __init__(self, loop, sock, protocol, extra=None, server=None):
773+
def __init__(self, loop, sock, protocol, extra=None, server=None, context=None):
774774
super().__init__(extra, loop)
775775
self._extra['socket'] = trsock.TransportSocket(sock)
776776
try:
@@ -784,7 +784,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
784784
self._extra['peername'] = None
785785
self._sock = sock
786786
self._sock_fd = sock.fileno()
787-
787+
self._context = context
788788
self._protocol_connected = False
789789
self.set_protocol(protocol)
790790

@@ -866,7 +866,7 @@ def close(self):
866866
if not self._buffer:
867867
self._conn_lost += 1
868868
self._loop._remove_writer(self._sock_fd)
869-
self._loop.call_soon(self._call_connection_lost, None)
869+
self._loop.call_soon(self._call_connection_lost, None, context=self._context)
870870

871871
def __del__(self, _warn=warnings.warn):
872872
if self._sock is not None:
@@ -899,7 +899,7 @@ def _force_close(self, exc):
899899
self._closing = True
900900
self._loop._remove_reader(self._sock_fd)
901901
self._conn_lost += 1
902-
self._loop.call_soon(self._call_connection_lost, exc)
902+
self._loop.call_soon(self._call_connection_lost, exc, context=self._context)
903903

904904
def _call_connection_lost(self, exc):
905905
try:
@@ -921,7 +921,7 @@ def get_write_buffer_size(self):
921921
def _add_reader(self, fd, callback, *args):
922922
if not self.is_reading():
923923
return
924-
self._loop._add_reader(fd, callback, *args)
924+
self._loop._add_reader(fd, callback, *args, context=self._context)
925925

926926

927927
class _SelectorSocketTransport(_SelectorTransport):
@@ -930,10 +930,10 @@ class _SelectorSocketTransport(_SelectorTransport):
930930
_sendfile_compatible = constants._SendfileMode.TRY_NATIVE
931931

932932
def __init__(self, loop, sock, protocol, waiter=None,
933-
extra=None, server=None):
934-
933+
extra=None, server=None, context=None):
934+
assert context is not None
935935
self._read_ready_cb = None
936-
super().__init__(loop, sock, protocol, extra, server)
936+
super().__init__(loop, sock, protocol, extra, server, context)
937937
self._eof = False
938938
self._empty_waiter = None
939939
if _HAS_SENDMSG:
@@ -945,14 +945,14 @@ def __init__(self, loop, sock, protocol, waiter=None,
945945
# decreases the latency (in some cases significantly.)
946946
base_events._set_nodelay(self._sock)
947947

948-
self._loop.call_soon(self._protocol.connection_made, self)
948+
self._loop.call_soon(self._protocol.connection_made, self, context=context)
949949
# only start reading when connection_made() has been called
950950
self._loop.call_soon(self._add_reader,
951-
self._sock_fd, self._read_ready)
951+
self._sock_fd, self._read_ready, context=context)
952952
if waiter is not None:
953953
# only wake up the waiter when connection_made() has been called
954954
self._loop.call_soon(futures._set_result_unless_cancelled,
955-
waiter, None)
955+
waiter, None, context=context)
956956

957957
def set_protocol(self, protocol):
958958
if isinstance(protocol, protocols.BufferedProtocol):
@@ -1081,7 +1081,7 @@ def write(self, data):
10811081
if not data:
10821082
return
10831083
# Not all was written; register write handler.
1084-
self._loop._add_writer(self._sock_fd, self._write_ready)
1084+
self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context)
10851085

10861086
# Add it to the buffer.
10871087
self._buffer.append(data)
@@ -1185,7 +1185,7 @@ def writelines(self, list_of_data):
11851185
self._write_ready()
11861186
# If the entire buffer couldn't be written, register a write handler
11871187
if self._buffer:
1188-
self._loop._add_writer(self._sock_fd, self._write_ready)
1188+
self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context)
11891189
self._maybe_pause_protocol()
11901190

11911191
def can_write_eof(self):

0 commit comments

Comments
 (0)