diff --git a/CHANGES/12452.feature.rst b/CHANGES/12452.feature.rst new file mode 100644 index 00000000000..f3f619f3dbb --- /dev/null +++ b/CHANGES/12452.feature.rst @@ -0,0 +1 @@ +Added :attr:`~aiohttp.ClientResponse.output_size` -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 25160f198bc..e7cceb9c11e 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -212,6 +212,7 @@ class ClientResponse(HeadersMixin): _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8" __writer: asyncio.Task[None] | None = None + _body_writer: AbstractStreamWriter | None = None def __init__( self, @@ -226,6 +227,7 @@ def __init__( session: "ClientSession | None", request_headers: CIMultiDict[str], original_url: URL, + body_writer: AbstractStreamWriter | None = None, **kwargs: object, ) -> None: # kwargs exists so authors of subclasses should expect to pass through unknown @@ -242,6 +244,8 @@ def __init__( self._url = url.with_fragment(None) if url.raw_fragment else url if writer is not None: self._writer = writer + if body_writer is not None: + self._body_writer = body_writer if continue100 is not None: self._continue = continue100 self._request_headers = request_headers @@ -285,6 +289,13 @@ def _writer(self, writer: asyncio.Task[None] | None) -> None: else: writer.add_done_callback(self.__reset_writer) + @property + def output_size(self) -> int: + """Number of bytes sent for this request.""" + if self._body_writer is None: + return 0 + return self._body_writer.output_size + @property def cookies(self) -> SimpleCookie: if self._cookies is None: @@ -827,7 +838,11 @@ def _update_headers(self, headers: CIMultiDict[str]) -> None: self.headers[hdrs.HOST] = headers.pop(hdrs.HOST, host) self.headers.extend(headers) - def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse: + def _create_response( + self, + task: asyncio.Task[None] | None, + body_writer: AbstractStreamWriter | None = None, + ) -> ClientResponse: return self.response_class( self.method, self.original_url, @@ -839,6 +854,7 @@ def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse: session=None, request_headers=self.headers, original_url=self.original_url, + body_writer=body_writer, ) def _create_writer(self, protocol: BaseProtocol) -> StreamWriter: @@ -912,7 +928,7 @@ async def _send(self, conn: "Connection") -> ClientResponse: protocol.start_timeout() writer.set_eof() task = None - self._response = self._create_response(task) + self._response = self._create_response(task, body_writer=writer) return self._response async def _write_bytes( @@ -1291,7 +1307,11 @@ def _update_proxy( self.proxy_auth = proxy_auth self.proxy_headers = proxy_headers - def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse: + def _create_response( + self, + task: asyncio.Task[None] | None, + body_writer: AbstractStreamWriter | None = None, + ) -> ClientResponse: return self.response_class( self.method, self.original_url, @@ -1303,6 +1323,7 @@ def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse: session=self._session, request_headers=self.headers, original_url=self.original_url, + body_writer=body_writer, ) def _create_writer(self, protocol: BaseProtocol) -> StreamWriter: diff --git a/docs/client_reference.rst b/docs/client_reference.rst index ccbac8b6885..c8410143a10 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1573,6 +1573,21 @@ Response object .. versionadded:: 3.2 + .. attribute:: output_size + + Number of bytes sent for this request. + + Returns ``0`` if no body writer present (e.g. for some empty body requests). + + Useful to display upload progress:: + + async with session.post(url, data=mpwriter) as resp: + while not resp._writer.done(): + print(f"uploaded {resp.output_size} bytes") + await asyncio.sleep(0.5) + + .. versionadded:: 3.14 + .. attribute:: content_type Read-only property with *content* part of *Content-Type* header. diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 51dfc6d44c9..97d23eb30b0 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -5882,3 +5882,108 @@ async def handler(request: web.Request) -> web.Response: data = await resp.content.read() assert resp.content.total_raw_bytes == len(data) assert resp.content.total_raw_bytes == int(resp.headers["Content-Length"]) + + +async def test_output_size_bytes(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.Response: + await request.read() + return web.Response() + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + body = b"x" * 1024 + async with client.post("/", data=body) as resp: + assert resp.output_size >= len(body) + + +async def test_output_size_multipart(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.Response: + await request.read() + return web.Response() + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + mpwriter = aiohttp.MultipartWriter("form-data") + mpwriter.append(b"x" * 4096) + mpwriter.append(b"y" * 2048) + expected_body_size = mpwriter.size + assert expected_body_size is not None + + async with client.post("/", data=mpwriter) as resp: + assert resp.output_size >= expected_body_size + + +async def test_output_size_keepalive_isolated( + aiohttp_client: AiohttpClient, +) -> None: + """Each request on a keep-alive connection has its own counter.""" + transports: set[object] = set() + + async def handler(request: web.Request) -> web.Response: + transports.add(request.transport) + await request.read() + return web.Response() + + app = web.Application() + app.router.add_post("/", handler) + connector = aiohttp.TCPConnector(limit=1, force_close=False) + client = await aiohttp_client(app, connector=connector) + body = b"x" * 65536 + + async with client.post("/", data=body) as resp1: + size1 = resp1.output_size + + async with client.post("/", data=body) as resp2: + size2 = resp2.output_size + + assert len(transports) == 1 # Check keep-alive worked. + assert size1 >= len(body) + assert size1 == size2 + + +async def test_output_size_progress(aiohttp_client: AiohttpClient) -> None: + """output_size advances by exactly one chunk per yield.""" + + async def handler(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse() + await response.prepare(request) + # Flush headers + a chunk so resp.start() returns on the client + # side before we read the body. + await response.write(b"x") + await request.read() + return response + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + chunk_size = 4096 + chunk = b"z" * chunk_size + num_chunks = 8 + sample_taken = asyncio.Event() + next_chunk = asyncio.Event() + + async def gated_body() -> AsyncIterator[bytes]: + for _ in range(num_chunks): + yield chunk + sample_taken.clear() + next_chunk.set() + await sample_taken.wait() + + async with client.post("/", data=gated_body()) as resp: + samples: list[int] = [] + for _ in range(num_chunks): + await next_chunk.wait() + next_chunk.clear() + samples.append(resp.output_size) + sample_taken.set() + await resp.read() + + # Each sample after the first reflects exactly one more chunk on the wire. + chunked_framing = len(f"{chunk_size:x}".encode()) + 4 + deltas = [samples[i] - samples[i - 1] for i in range(1, len(samples))] + assert deltas == [chunk_size + chunked_framing] * (num_chunks - 1) diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 96a14ca56eb..10a70684865 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -1634,3 +1634,20 @@ def test_response_cookies_setter_updates_raw_headers( response.cookies = empty_cookies # Should not set _raw_cookie_headers for empty cookies assert response._raw_cookie_headers is None + + +def test_output_size_default_zero() -> None: + url = URL("http://def-cl-resp.org") + response = ClientResponse( + "get", + url, + writer=WriterMock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=None, + request_headers=CIMultiDict[str](), + original_url=url, + ) + assert response.output_size == 0