hat.drivers.tcp

Asyncio TCP wrapper

  1"""Asyncio TCP wrapper"""
  2
  3import asyncio
  4import collections
  5import functools
  6import logging
  7import sys
  8import typing
  9
 10from hat import aio
 11from hat import util
 12
 13from hat.drivers import ssl
 14
 15
 16mlog: logging.Logger = logging.getLogger(__name__)
 17"""Module logger"""
 18
 19
 20class Address(typing.NamedTuple):
 21    host: str
 22    port: int
 23
 24
 25class ConnectionInfo(typing.NamedTuple):
 26    local_addr: Address
 27    remote_addr: Address
 28
 29
 30ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]
 31"""Connection callback"""
 32
 33
 34async def connect(addr: Address,
 35                  *,
 36                  input_buffer_limit: int = 64 * 1024,
 37                  **kwargs
 38                  ) -> 'Connection':
 39    """Create TCP connection
 40
 41    Argument `addr` specifies remote server listening address.
 42
 43    Argument `input_buffer_limit` defines number of bytes in input buffer
 44    that whill temporary pause data receiving. Once number of bytes
 45    drops bellow `input_buffer_limit`, data receiving is resumed. If this
 46    argument is ``0``, data receive pausing is disabled.
 47
 48    Additional arguments are passed directly to `asyncio.create_connection`.
 49
 50    """
 51    loop = asyncio.get_running_loop()
 52    create_transport = functools.partial(Protocol, None, input_buffer_limit)
 53    _, protocol = await loop.create_connection(create_transport,
 54                                               addr.host, addr.port,
 55                                               **kwargs)
 56    return Connection(protocol)
 57
 58
 59async def listen(connection_cb: ConnectionCb,
 60                 addr: Address,
 61                 *,
 62                 bind_connections: bool = False,
 63                 input_buffer_limit: int = 64 * 1024,
 64                 **kwargs
 65                 ) -> 'Server':
 66    """Create listening server
 67
 68    If `bind_connections` is ``True``, closing server will close all open
 69    incoming connections.
 70
 71    Argument `input_buffer_limit` is associated with newly created connections
 72    (see `connect`).
 73
 74    Additional arguments are passed directly to `asyncio.create_server`.
 75
 76    """
 77    server = Server()
 78    server._connection_cb = connection_cb
 79    server._bind_connections = bind_connections
 80    server._async_group = aio.Group()
 81
 82    on_connection = functools.partial(server.async_group.spawn,
 83                                      server._on_connection)
 84    create_transport = functools.partial(Protocol, on_connection,
 85                                         input_buffer_limit)
 86
 87    loop = asyncio.get_running_loop()
 88    server._srv = await loop.create_server(create_transport, addr.host,
 89                                           addr.port, **kwargs)
 90
 91    server.async_group.spawn(aio.call_on_cancel, server._on_close)
 92
 93    try:
 94        socknames = (socket.getsockname() for socket in server._srv.sockets)
 95        server._addresses = [Address(*sockname[:2]) for sockname in socknames]
 96
 97    except Exception:
 98        await aio.uncancellable(server.async_close())
 99        raise
100
101    return server
102
103
104class Server(aio.Resource):
105    """TCP listening server
106
107    Closing server will cancel all running `connection_cb` coroutines.
108
109    """
110
111    @property
112    def async_group(self) -> aio.Group:
113        """Async group"""
114        return self._async_group
115
116    @property
117    def addresses(self) -> list[Address]:
118        """Listening addresses"""
119        return self._addresses
120
121    async def _on_close(self):
122        self._srv.close()
123
124        if self._bind_connections or sys.version_info[:2] < (3, 12):
125            await self._srv.wait_closed()
126
127    async def _on_connection(self, protocol):
128        conn = Connection(protocol)
129
130        try:
131            await aio.call(self._connection_cb, conn)
132
133            if self._bind_connections:
134                await conn.wait_closing()
135
136            else:
137                conn = None
138
139        except Exception as e:
140            mlog.warning('connection callback error: %s', e, exc_info=e)
141
142        finally:
143            if conn:
144                await aio.uncancellable(conn.async_close())
145
146
147class Connection(aio.Resource):
148    """TCP connection"""
149
150    def __init__(self, protocol: 'Protocol'):
151        self._protocol = protocol
152        self._async_group = aio.Group()
153
154        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
155        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
156                               self.close)
157
158    @property
159    def async_group(self) -> aio.Group:
160        """Async group"""
161        return self._async_group
162
163    @property
164    def info(self) -> ConnectionInfo:
165        """Connection info"""
166        return self._protocol.info
167
168    @property
169    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
170        """SSL Object"""
171        return self._protocol.ssl_object
172
173    async def write(self, data: util.Bytes):
174        """Write data
175
176        This coroutine will wait until `data` can be added to output buffer.
177
178        """
179        if not self.is_open:
180            raise ConnectionError()
181
182        await self._protocol.write(data)
183
184    async def drain(self):
185        """Drain output buffer"""
186        await self._protocol.drain()
187
188    async def read(self, n: int = -1) -> util.Bytes:
189        """Read up to `n` bytes
190
191        If EOF is detected and no new bytes are available, `ConnectionError`
192        is raised.
193
194        """
195        return await self._protocol.read(n)
196
197    async def readexactly(self, n: int) -> util.Bytes:
198        """Read exactly `n` bytes
199
200        If exact number of bytes could not be read, `ConnectionError` is
201        raised.
202
203        """
204        return await self._protocol.readexactly(n)
205
206    def reset_input_buffer(self) -> int:
207        """Reset input buffer
208
209        Returns number of bytes cleared from buffer.
210
211        """
212        return self._protocol.reset_input_buffer()
213
214
215class Protocol(asyncio.Protocol):
216    """Asyncio protocol implementation"""
217
218    def __init__(self,
219                 on_connected: typing.Callable[['Protocol'], None] | None,
220                 input_buffer_limit: int):
221        self._on_connected = on_connected
222        self._input_buffer_limit = input_buffer_limit
223        self._loop = asyncio.get_running_loop()
224        self._input_buffer = util.BytesBuffer()
225        self._transport = None
226        self._read_queue = None
227        self._write_queue = None
228        self._drain_futures = None
229        self._closed_futures = None
230        self._info = None
231        self._ssl_object = None
232
233    @property
234    def info(self) -> ConnectionInfo:
235        return self._info
236
237    @property
238    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
239        return self._ssl_object
240
241    def connection_made(self, transport: asyncio.Transport):
242        self._transport = transport
243        self._read_queue = collections.deque()
244        self._closed_futures = collections.deque()
245
246        try:
247            sockname = transport.get_extra_info('sockname')
248            peername = transport.get_extra_info('peername')
249            self._info = ConnectionInfo(
250                local_addr=Address(sockname[0], sockname[1]),
251                remote_addr=Address(peername[0], peername[1]))
252            self._ssl_object = transport.get_extra_info('ssl_object')
253
254            if self._on_connected:
255                self._on_connected(self)
256
257        except Exception:
258            transport.abort()
259            return
260
261    def connection_lost(self, exc: Exception | None):
262        self._transport = None
263        write_queue, self._write_queue = self._write_queue, None
264        drain_futures, self._drain_futures = self._drain_futures, None
265        closed_futures, self._closed_futures = self._closed_futures, None
266
267        self.eof_received()
268
269        while write_queue:
270            _, future = write_queue.popleft()
271            if not future.done():
272                future.set_exception(ConnectionError())
273
274        while drain_futures:
275            future = drain_futures.popleft()
276            if not future.done():
277                future.set_result(None)
278
279        while closed_futures:
280            future = closed_futures.popleft()
281            if not future.done():
282                future.set_result(None)
283
284    def pause_writing(self):
285        self._write_queue = collections.deque()
286        self._drain_futures = collections.deque()
287
288    def resume_writing(self):
289        write_queue, self._write_queue = self._write_queue, None
290        drain_futures, self._drain_futures = self._drain_futures, None
291
292        while self._write_queue is None and write_queue:
293            data, future = write_queue.popleft()
294            if future.done():
295                continue
296
297            self._transport.write(data)
298            future.set_result(None)
299
300        if write_queue:
301            write_queue.extend(self._write_queue)
302            self._write_queue = write_queue
303
304            drain_futures.extend(self._drain_futures)
305            self._drain_futures = drain_futures
306
307            return
308
309        while drain_futures:
310            future = drain_futures.popleft()
311            if not future.done():
312                future.set_result(None)
313
314    def data_received(self, data: util.Bytes):
315        self._input_buffer.add(data)
316        self._process_input_buffer()
317
318    def eof_received(self):
319        while self._read_queue:
320            exact, n, future = self._read_queue.popleft()
321            if future.done():
322                continue
323
324            if exact and n <= len(self._input_buffer):
325                future.set_result(self._input_buffer.read(n))
326
327            elif not exact and self._input_buffer:
328                future.set_result(self._input_buffer.read(n))
329
330            else:
331                future.set_exception(ConnectionError())
332
333        self._read_queue = None
334
335    async def write(self, data: util.Bytes):
336        if self._transport is None:
337            raise ConnectionError()
338
339        if self._write_queue is None:
340            self._transport.write(data)
341            return
342
343        future = self._loop.create_future()
344        self._write_queue.append((data, future))
345        await future
346
347    async def drain(self):
348        if self._drain_futures is None:
349            return
350
351        future = self._loop.create_future()
352        self._drain_futures.append(future)
353        await future
354
355    async def read(self, n: int) -> util.Bytes:
356        if n == 0:
357            return b''
358
359        if self._input_buffer and not self._read_queue:
360            data = self._input_buffer.read(n)
361            self._process_input_buffer()
362            return data
363
364        if self._read_queue is None:
365            raise ConnectionError()
366
367        future = self._loop.create_future()
368        future.add_done_callback(self._on_read_future_done)
369        self._read_queue.append((False, n, future))
370        return await future
371
372    async def readexactly(self, n: int) -> util.Bytes:
373        if n == 0:
374            return b''
375
376        if n <= len(self._input_buffer) and not self._read_queue:
377            data = self._input_buffer.read(n)
378            self._process_input_buffer()
379            return data
380
381        if self._read_queue is None:
382            raise ConnectionError()
383
384        future = self._loop.create_future()
385        future.add_done_callback(self._on_read_future_done)
386        self._read_queue.append((True, n, future))
387        self._process_input_buffer()
388        return await future
389
390    def reset_input_buffer(self) -> int:
391        count = self._input_buffer.clear()
392        self._transport.resume_reading()
393        return count
394
395    async def async_close(self):
396        if self._transport is not None:
397            self._transport.close()
398
399        await self.wait_closed()
400
401    async def wait_closed(self):
402        if self._closed_futures is None:
403            return
404
405        future = self._loop.create_future()
406        self._closed_futures.append(future)
407        await future
408
409    def _on_read_future_done(self, future):
410        if not self._read_queue:
411            return
412
413        if not future.cancelled():
414            return
415
416        for _ in range(len(self._read_queue)):
417            i = self._read_queue.popleft()
418            if not i[2].done():
419                self._read_queue.append(i)
420
421        self._process_input_buffer()
422
423    def _process_input_buffer(self):
424        while self._input_buffer and self._read_queue:
425            exact, n, future = self._read_queue.popleft()
426            if future.done():
427                continue
428
429            if not exact:
430                future.set_result(self._input_buffer.read(n))
431
432            elif n <= len(self._input_buffer):
433                future.set_result(self._input_buffer.read(n))
434
435            else:
436                self._read_queue.appendleft((exact, n, future))
437                break
438
439        if not self._transport:
440            return
441
442        pause = (self._input_buffer_limit > 0 and
443                 len(self._input_buffer) > self._input_buffer_limit and
444                 not self._read_queue)
445
446        if pause:
447            self._transport.pause_reading()
448
449        else:
450            self._transport.resume_reading()
mlog: logging.Logger = <Logger hat.drivers.tcp (WARNING)>

Module logger

class Address(typing.NamedTuple):
21class Address(typing.NamedTuple):
22    host: str
23    port: int

Address(host, port)

Address(host: str, port: int)

Create new instance of Address(host, port)

host: str

Alias for field number 0

port: int

Alias for field number 1

class ConnectionInfo(typing.NamedTuple):
26class ConnectionInfo(typing.NamedTuple):
27    local_addr: Address
28    remote_addr: Address

ConnectionInfo(local_addr, remote_addr)

ConnectionInfo( local_addr: Address, remote_addr: Address)

Create new instance of ConnectionInfo(local_addr, remote_addr)

local_addr: Address

Alias for field number 0

remote_addr: Address

Alias for field number 1

ConnectionCb: TypeAlias = Callable[[ForwardRef('Connection')], None | Awaitable[None]]

Connection callback

async def connect( addr: Address, *, input_buffer_limit: int = 65536, **kwargs) -> Connection:
35async def connect(addr: Address,
36                  *,
37                  input_buffer_limit: int = 64 * 1024,
38                  **kwargs
39                  ) -> 'Connection':
40    """Create TCP connection
41
42    Argument `addr` specifies remote server listening address.
43
44    Argument `input_buffer_limit` defines number of bytes in input buffer
45    that whill temporary pause data receiving. Once number of bytes
46    drops bellow `input_buffer_limit`, data receiving is resumed. If this
47    argument is ``0``, data receive pausing is disabled.
48
49    Additional arguments are passed directly to `asyncio.create_connection`.
50
51    """
52    loop = asyncio.get_running_loop()
53    create_transport = functools.partial(Protocol, None, input_buffer_limit)
54    _, protocol = await loop.create_connection(create_transport,
55                                               addr.host, addr.port,
56                                               **kwargs)
57    return Connection(protocol)

Create TCP connection

Argument addr specifies the remote server's listening address.

Argument input_buffer_limit defines the number of bytes in the input buffer above which data receiving is temporarily paused. Once the number of buffered bytes drops back to input_buffer_limit or below, data receiving is resumed. If this argument is 0, data receive pausing is disabled.

Additional arguments are passed directly to asyncio.create_connection.

async def listen( connection_cb: Callable[[Connection], None | Awaitable[None]], addr: Address, *, bind_connections: bool = False, input_buffer_limit: int = 65536, **kwargs) -> Server:
 60async def listen(connection_cb: ConnectionCb,
 61                 addr: Address,
 62                 *,
 63                 bind_connections: bool = False,
 64                 input_buffer_limit: int = 64 * 1024,
 65                 **kwargs
 66                 ) -> 'Server':
 67    """Create listening server
 68
 69    If `bind_connections` is ``True``, closing server will close all open
 70    incoming connections.
 71
 72    Argument `input_buffer_limit` is associated with newly created connections
 73    (see `connect`).
 74
 75    Additional arguments are passed directly to `asyncio.create_server`.
 76
 77    """
 78    server = Server()
 79    server._connection_cb = connection_cb
 80    server._bind_connections = bind_connections
 81    server._async_group = aio.Group()
 82
 83    on_connection = functools.partial(server.async_group.spawn,
 84                                      server._on_connection)
 85    create_transport = functools.partial(Protocol, on_connection,
 86                                         input_buffer_limit)
 87
 88    loop = asyncio.get_running_loop()
 89    server._srv = await loop.create_server(create_transport, addr.host,
 90                                           addr.port, **kwargs)
 91
 92    server.async_group.spawn(aio.call_on_cancel, server._on_close)
 93
 94    try:
 95        socknames = (socket.getsockname() for socket in server._srv.sockets)
 96        server._addresses = [Address(*sockname[:2]) for sockname in socknames]
 97
 98    except Exception:
 99        await aio.uncancellable(server.async_close())
100        raise
101
102    return server

Create listening server

If bind_connections is True, closing the server will close all open incoming connections.

Argument input_buffer_limit is applied to each newly created incoming connection (see connect).

Additional arguments are passed directly to asyncio.create_server.

class Server(hat.aio.group.Resource):
105class Server(aio.Resource):
106    """TCP listening server
107
108    Closing server will cancel all running `connection_cb` coroutines.
109
110    """
111
112    @property
113    def async_group(self) -> aio.Group:
114        """Async group"""
115        return self._async_group
116
117    @property
118    def addresses(self) -> list[Address]:
119        """Listening addresses"""
120        return self._addresses
121
122    async def _on_close(self):
123        self._srv.close()
124
125        if self._bind_connections or sys.version_info[:2] < (3, 12):
126            await self._srv.wait_closed()
127
128    async def _on_connection(self, protocol):
129        conn = Connection(protocol)
130
131        try:
132            await aio.call(self._connection_cb, conn)
133
134            if self._bind_connections:
135                await conn.wait_closing()
136
137            else:
138                conn = None
139
140        except Exception as e:
141            mlog.warning('connection callback error: %s', e, exc_info=e)
142
143        finally:
144            if conn:
145                await aio.uncancellable(conn.async_close())

TCP listening server

Closing the server will cancel all running connection_cb coroutines.

async_group: hat.aio.group.Group
112    @property
113    def async_group(self) -> aio.Group:
114        """Async group"""
115        return self._async_group

Async group

addresses: list[Address]
117    @property
118    def addresses(self) -> list[Address]:
119        """Listening addresses"""
120        return self._addresses

Listening addresses

class Connection(hat.aio.group.Resource):
148class Connection(aio.Resource):
149    """TCP connection"""
150
151    def __init__(self, protocol: 'Protocol'):
152        self._protocol = protocol
153        self._async_group = aio.Group()
154
155        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
156        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
157                               self.close)
158
159    @property
160    def async_group(self) -> aio.Group:
161        """Async group"""
162        return self._async_group
163
164    @property
165    def info(self) -> ConnectionInfo:
166        """Connection info"""
167        return self._protocol.info
168
169    @property
170    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
171        """SSL Object"""
172        return self._protocol.ssl_object
173
174    async def write(self, data: util.Bytes):
175        """Write data
176
177        This coroutine will wait until `data` can be added to output buffer.
178
179        """
180        if not self.is_open:
181            raise ConnectionError()
182
183        await self._protocol.write(data)
184
185    async def drain(self):
186        """Drain output buffer"""
187        await self._protocol.drain()
188
189    async def read(self, n: int = -1) -> util.Bytes:
190        """Read up to `n` bytes
191
192        If EOF is detected and no new bytes are available, `ConnectionError`
193        is raised.
194
195        """
196        return await self._protocol.read(n)
197
198    async def readexactly(self, n: int) -> util.Bytes:
199        """Read exactly `n` bytes
200
201        If exact number of bytes could not be read, `ConnectionError` is
202        raised.
203
204        """
205        return await self._protocol.readexactly(n)
206
207    def reset_input_buffer(self) -> int:
208        """Reset input buffer
209
210        Returns number of bytes cleared from buffer.
211
212        """
213        return self._protocol.reset_input_buffer()

TCP connection

Connection(protocol: Protocol)
151    def __init__(self, protocol: 'Protocol'):
152        self._protocol = protocol
153        self._async_group = aio.Group()
154
155        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
156        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
157                               self.close)
async_group: hat.aio.group.Group
159    @property
160    def async_group(self) -> aio.Group:
161        """Async group"""
162        return self._async_group

Async group

info: ConnectionInfo
164    @property
165    def info(self) -> ConnectionInfo:
166        """Connection info"""
167        return self._protocol.info

Connection info

ssl_object: ssl.SSLObject | ssl.SSLSocket | None
169    @property
170    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
171        """SSL Object"""
172        return self._protocol.ssl_object

SSL Object

async def write(self, data: bytes | bytearray | memoryview):
174    async def write(self, data: util.Bytes):
175        """Write data
176
177        This coroutine will wait until `data` can be added to output buffer.
178
179        """
180        if not self.is_open:
181            raise ConnectionError()
182
183        await self._protocol.write(data)

Write data

This coroutine will wait until data can be added to the output buffer.

async def drain(self):
185    async def drain(self):
186        """Drain output buffer"""
187        await self._protocol.drain()

Drain output buffer

async def read(self, n: int = -1) -> bytes | bytearray | memoryview:
189    async def read(self, n: int = -1) -> util.Bytes:
190        """Read up to `n` bytes
191
192        If EOF is detected and no new bytes are available, `ConnectionError`
193        is raised.
194
195        """
196        return await self._protocol.read(n)

Read up to n bytes

If EOF is detected and no new bytes are available, ConnectionError is raised.

async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
198    async def readexactly(self, n: int) -> util.Bytes:
199        """Read exactly `n` bytes
200
201        If exact number of bytes could not be read, `ConnectionError` is
202        raised.
203
204        """
205        return await self._protocol.readexactly(n)

Read exactly n bytes

If the exact number of bytes cannot be read, ConnectionError is raised.

def reset_input_buffer(self) -> int:
207    def reset_input_buffer(self) -> int:
208        """Reset input buffer
209
210        Returns number of bytes cleared from buffer.
211
212        """
213        return self._protocol.reset_input_buffer()

Reset input buffer

Returns the number of bytes cleared from the buffer.

class Protocol(asyncio.protocols.Protocol):
216class Protocol(asyncio.Protocol):
217    """Asyncio protocol implementation"""
218
219    def __init__(self,
220                 on_connected: typing.Callable[['Protocol'], None] | None,
221                 input_buffer_limit: int):
222        self._on_connected = on_connected
223        self._input_buffer_limit = input_buffer_limit
224        self._loop = asyncio.get_running_loop()
225        self._input_buffer = util.BytesBuffer()
226        self._transport = None
227        self._read_queue = None
228        self._write_queue = None
229        self._drain_futures = None
230        self._closed_futures = None
231        self._info = None
232        self._ssl_object = None
233
234    @property
235    def info(self) -> ConnectionInfo:
236        return self._info
237
238    @property
239    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
240        return self._ssl_object
241
242    def connection_made(self, transport: asyncio.Transport):
243        self._transport = transport
244        self._read_queue = collections.deque()
245        self._closed_futures = collections.deque()
246
247        try:
248            sockname = transport.get_extra_info('sockname')
249            peername = transport.get_extra_info('peername')
250            self._info = ConnectionInfo(
251                local_addr=Address(sockname[0], sockname[1]),
252                remote_addr=Address(peername[0], peername[1]))
253            self._ssl_object = transport.get_extra_info('ssl_object')
254
255            if self._on_connected:
256                self._on_connected(self)
257
258        except Exception:
259            transport.abort()
260            return
261
262    def connection_lost(self, exc: Exception | None):
263        self._transport = None
264        write_queue, self._write_queue = self._write_queue, None
265        drain_futures, self._drain_futures = self._drain_futures, None
266        closed_futures, self._closed_futures = self._closed_futures, None
267
268        self.eof_received()
269
270        while write_queue:
271            _, future = write_queue.popleft()
272            if not future.done():
273                future.set_exception(ConnectionError())
274
275        while drain_futures:
276            future = drain_futures.popleft()
277            if not future.done():
278                future.set_result(None)
279
280        while closed_futures:
281            future = closed_futures.popleft()
282            if not future.done():
283                future.set_result(None)
284
285    def pause_writing(self):
286        self._write_queue = collections.deque()
287        self._drain_futures = collections.deque()
288
289    def resume_writing(self):
290        write_queue, self._write_queue = self._write_queue, None
291        drain_futures, self._drain_futures = self._drain_futures, None
292
293        while self._write_queue is None and write_queue:
294            data, future = write_queue.popleft()
295            if future.done():
296                continue
297
298            self._transport.write(data)
299            future.set_result(None)
300
301        if write_queue:
302            write_queue.extend(self._write_queue)
303            self._write_queue = write_queue
304
305            drain_futures.extend(self._drain_futures)
306            self._drain_futures = drain_futures
307
308            return
309
310        while drain_futures:
311            future = drain_futures.popleft()
312            if not future.done():
313                future.set_result(None)
314
315    def data_received(self, data: util.Bytes):
316        self._input_buffer.add(data)
317        self._process_input_buffer()
318
319    def eof_received(self):
320        while self._read_queue:
321            exact, n, future = self._read_queue.popleft()
322            if future.done():
323                continue
324
325            if exact and n <= len(self._input_buffer):
326                future.set_result(self._input_buffer.read(n))
327
328            elif not exact and self._input_buffer:
329                future.set_result(self._input_buffer.read(n))
330
331            else:
332                future.set_exception(ConnectionError())
333
334        self._read_queue = None
335
336    async def write(self, data: util.Bytes):
337        if self._transport is None:
338            raise ConnectionError()
339
340        if self._write_queue is None:
341            self._transport.write(data)
342            return
343
344        future = self._loop.create_future()
345        self._write_queue.append((data, future))
346        await future
347
348    async def drain(self):
349        if self._drain_futures is None:
350            return
351
352        future = self._loop.create_future()
353        self._drain_futures.append(future)
354        await future
355
356    async def read(self, n: int) -> util.Bytes:
357        if n == 0:
358            return b''
359
360        if self._input_buffer and not self._read_queue:
361            data = self._input_buffer.read(n)
362            self._process_input_buffer()
363            return data
364
365        if self._read_queue is None:
366            raise ConnectionError()
367
368        future = self._loop.create_future()
369        future.add_done_callback(self._on_read_future_done)
370        self._read_queue.append((False, n, future))
371        return await future
372
373    async def readexactly(self, n: int) -> util.Bytes:
374        if n == 0:
375            return b''
376
377        if n <= len(self._input_buffer) and not self._read_queue:
378            data = self._input_buffer.read(n)
379            self._process_input_buffer()
380            return data
381
382        if self._read_queue is None:
383            raise ConnectionError()
384
385        future = self._loop.create_future()
386        future.add_done_callback(self._on_read_future_done)
387        self._read_queue.append((True, n, future))
388        self._process_input_buffer()
389        return await future
390
391    def reset_input_buffer(self) -> int:
392        count = self._input_buffer.clear()
393        self._transport.resume_reading()
394        return count
395
396    async def async_close(self):
397        if self._transport is not None:
398            self._transport.close()
399
400        await self.wait_closed()
401
402    async def wait_closed(self):
403        if self._closed_futures is None:
404            return
405
406        future = self._loop.create_future()
407        self._closed_futures.append(future)
408        await future
409
410    def _on_read_future_done(self, future):
411        if not self._read_queue:
412            return
413
414        if not future.cancelled():
415            return
416
417        for _ in range(len(self._read_queue)):
418            i = self._read_queue.popleft()
419            if not i[2].done():
420                self._read_queue.append(i)
421
422        self._process_input_buffer()
423
424    def _process_input_buffer(self):
425        while self._input_buffer and self._read_queue:
426            exact, n, future = self._read_queue.popleft()
427            if future.done():
428                continue
429
430            if not exact:
431                future.set_result(self._input_buffer.read(n))
432
433            elif n <= len(self._input_buffer):
434                future.set_result(self._input_buffer.read(n))
435
436            else:
437                self._read_queue.appendleft((exact, n, future))
438                break
439
440        if not self._transport:
441            return
442
443        pause = (self._input_buffer_limit > 0 and
444                 len(self._input_buffer) > self._input_buffer_limit and
445                 not self._read_queue)
446
447        if pause:
448            self._transport.pause_reading()
449
450        else:
451            self._transport.resume_reading()

Asyncio protocol implementation

Protocol( on_connected: Optional[Callable[[Protocol], NoneType]], input_buffer_limit: int)
219    def __init__(self,
220                 on_connected: typing.Callable[['Protocol'], None] | None,
221                 input_buffer_limit: int):
222        self._on_connected = on_connected
223        self._input_buffer_limit = input_buffer_limit
224        self._loop = asyncio.get_running_loop()
225        self._input_buffer = util.BytesBuffer()
226        self._transport = None
227        self._read_queue = None
228        self._write_queue = None
229        self._drain_futures = None
230        self._closed_futures = None
231        self._info = None
232        self._ssl_object = None
info: ConnectionInfo
234    @property
235    def info(self) -> ConnectionInfo:
236        return self._info
ssl_object: ssl.SSLObject | ssl.SSLSocket | None
238    @property
239    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
240        return self._ssl_object
def connection_made(self, transport: asyncio.transports.Transport):
def connection_made(self, transport: asyncio.Transport):
    """Initialize connection state once the transport is available

    Prepares the read queue and the queue of `wait_closed` waiters, then
    records the local/remote addresses (`info`) and the SSL object
    (`ssl_object`) and invokes the optional `on_connected` callback.

    If any of these steps raises, the transport is aborted and the error
    is logged instead of being raised.

    """
    self._transport = transport
    self._read_queue = collections.deque()
    self._closed_futures = collections.deque()

    try:
        sockname = transport.get_extra_info('sockname')
        peername = transport.get_extra_info('peername')
        self._info = ConnectionInfo(
            local_addr=Address(sockname[0], sockname[1]),
            remote_addr=Address(peername[0], peername[1]))
        self._ssl_object = transport.get_extra_info('ssl_object')

        if self._on_connected:
            self._on_connected(self)

    except Exception as e:
        transport.abort()
        # the connection is dropped on purpose, but the cause (for example
        # a failing on_connected callback) must not vanish silently
        mlog.warning('connection made error: %s', e, exc_info=e)

Called when a connection is made.

The argument is the transport representing the pipe connection. To receive data, wait for data_received() calls. When the connection is closed, connection_lost() is called.

def connection_lost(self, exc: Exception | None):
262    def connection_lost(self, exc: Exception | None):
263        self._transport = None
264        write_queue, self._write_queue = self._write_queue, None
265        drain_futures, self._drain_futures = self._drain_futures, None
266        closed_futures, self._closed_futures = self._closed_futures, None
267
268        self.eof_received()
269
270        while write_queue:
271            _, future = write_queue.popleft()
272            if not future.done():
273                future.set_exception(ConnectionError())
274
275        while drain_futures:
276            future = drain_futures.popleft()
277            if not future.done():
278                future.set_result(None)
279
280        while closed_futures:
281            future = closed_futures.popleft()
282            if not future.done():
283                future.set_result(None)

Called when the connection is lost or closed.

The argument is an exception object or None (the latter meaning a regular EOF is received or the connection was aborted or closed).

def pause_writing(self):
def pause_writing(self):
    """Start queueing writes and drain waiters

    Invoked by the transport when its write buffer goes over the
    high-water mark. From now until `resume_writing`, `write` data and
    `drain` waiters are queued instead of being handled immediately.

    """
    self._write_queue, self._drain_futures = (collections.deque(),
                                              collections.deque())

Called when the transport's buffer goes over the high-water mark.

Pause and resume calls are paired -- pause_writing() is called once when the buffer goes strictly over the high-water mark (even if subsequent writes increases the buffer size even more), and eventually resume_writing() is called once when the buffer size reaches the low-water mark.

Note that if the buffer size equals the high-water mark, pause_writing() is not called -- it must go strictly over. Conversely, resume_writing() is called when the buffer size is equal or lower than the low-water mark. These end conditions are important to ensure that things go as expected when either mark is zero.

NOTE: This is the only Protocol callback that is not called through EventLoop.call_soon() -- if it were, it would have no effect when it's most needed (when the app keeps writing without yielding until pause_writing() is called).

def resume_writing(self):
def resume_writing(self):
    """Flush queued writes and release drain waiters

    Invoked by the transport when its write buffer drains to the low-water
    mark. Queued data is handed to the transport in order and each
    corresponding `write` future is resolved. Entries whose future is
    already done (for example cancelled) are skipped.

    Writing to the transport may synchronously trigger `pause_writing`
    again. In that case the unsent data and *all* drain waiters stay
    queued until the next `resume_writing`, ahead of anything queued after
    the new pause. Drain waiters are released only when every queued write
    was flushed and writing is no longer paused.

    """
    write_queue, self._write_queue = self._write_queue, None
    drain_futures, self._drain_futures = self._drain_futures, None

    while self._write_queue is None and write_queue:
        data, future = write_queue.popleft()
        if future.done():
            continue

        self._transport.write(data)
        future.set_result(None)

    if self._write_queue is not None:
        # paused again while flushing - this must be checked even when
        # nothing is left to write, otherwise drain() would return while
        # the transport buffer is still over the high-water mark
        write_queue.extend(self._write_queue)
        self._write_queue = write_queue

        drain_futures.extend(self._drain_futures)
        self._drain_futures = drain_futures

        return

    while drain_futures:
        future = drain_futures.popleft()
        if not future.done():
            future.set_result(None)

Called when the transport's buffer drains below the low-water mark.

See pause_writing() for details.

def data_received(self, data: bytes | bytearray | memoryview):
def data_received(self, data: util.Bytes):
    """Append received data to the input buffer and serve waiting reads"""
    buffer = self._input_buffer
    buffer.add(data)
    self._process_input_buffer()

Called when some data is received.

The argument is a bytes object.

def eof_received(self):
def eof_received(self):
    """Settle every pending read once no more data will arrive

    Each pending request whose future is not yet done is resolved from the
    data still buffered if it can be satisfied, and failed with
    `ConnectionError` otherwise. An exact read can be satisfied when `n`
    bytes are buffered; an inexact read can be satisfied when any data is
    buffered. The read queue is then dropped, so later reads can only
    consume data that is already buffered.

    Returns ``None``, which lets the transport close itself.

    """
    while self._read_queue:
        exact, n, future = self._read_queue.popleft()
        if future.done():
            continue

        if exact:
            satisfiable = n <= len(self._input_buffer)
        else:
            satisfiable = bool(self._input_buffer)

        if satisfiable:
            future.set_result(self._input_buffer.read(n))
        else:
            future.set_exception(ConnectionError())

    self._read_queue = None

Called when the other end calls write_eof() or equivalent.

If this returns a false value (including None), the transport will close itself. If it returns a true value, closing the transport is up to the protocol.

async def write(self, data: bytes | bytearray | memoryview):
async def write(self, data: util.Bytes):
    """Write data to the transport

    Data is written immediately unless writing is paused. While paused,
    the data is queued and this call waits until the queued entry is
    resolved (by `resume_writing`) or failed (by `connection_lost`).

    Raises:
        ConnectionError: no transport is attached (before `connection_made`
            or after `connection_lost`), or the connection was lost while
            the data was still queued.

    """
    transport = self._transport
    if transport is None:
        raise ConnectionError()

    queue = self._write_queue
    if queue is not None:
        future = self._loop.create_future()
        queue.append((data, future))
        await future
        return

    transport.write(data)
async def drain(self):
async def drain(self):
    """Wait until writing is no longer paused

    Returns immediately when writing is not paused. Otherwise it waits
    until the queued drain future is resolved, which `resume_writing` and
    `connection_lost` do.

    """
    waiters = self._drain_futures
    if waiters is None:
        return

    future = self._loop.create_future()
    waiters.append(future)
    await future
async def read(self, n: int) -> bytes | bytearray | memoryview:
async def read(self, n: int) -> util.Bytes:
    """Read up to `n` bytes

    If data is buffered and no other read is pending, it is returned right
    away. Otherwise the request is queued behind earlier reads and
    completes once some data is available. ``n == 0`` returns empty bytes
    immediately.

    Raises:
        ConnectionError: the request cannot be served from the buffer and
            the protocol has no read queue (before `connection_made` or
            after EOF), or EOF arrived while the request was pending and
            no data was left.

    """
    if n == 0:
        return b''

    ready_now = bool(self._input_buffer) and not self._read_queue
    if ready_now:
        chunk = self._input_buffer.read(n)
        self._process_input_buffer()
        return chunk

    if self._read_queue is None:
        raise ConnectionError()

    pending = self._loop.create_future()
    pending.add_done_callback(self._on_read_future_done)
    self._read_queue.append((False, n, pending))
    return await pending
async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
async def readexactly(self, n: int) -> util.Bytes:
    """Read exactly `n` bytes

    If at least `n` bytes are buffered and no other read is pending, they
    are returned right away. Otherwise the request is queued behind
    earlier reads. Queued requests are processed immediately, which also
    re-enables receiving if it was paused. ``n == 0`` returns empty bytes
    immediately.

    Raises:
        ConnectionError: fewer than `n` bytes are buffered and the protocol
            has no read queue (before `connection_made` or after EOF), or
            EOF arrived before `n` bytes were available.

    """
    if n == 0:
        return b''

    enough_buffered = n <= len(self._input_buffer)
    if enough_buffered and not self._read_queue:
        chunk = self._input_buffer.read(n)
        self._process_input_buffer()
        return chunk

    if self._read_queue is None:
        raise ConnectionError()

    pending = self._loop.create_future()
    pending.add_done_callback(self._on_read_future_done)
    self._read_queue.append((True, n, pending))
    # may serve the request right away and keeps read pausing in sync
    self._process_input_buffer()
    return await pending
def reset_input_buffer(self) -> int:
def reset_input_buffer(self) -> int:
    """Discard all buffered input data

    Receiving is resumed afterwards, because the now-empty buffer can no
    longer be over `input_buffer_limit`. It is safe to call after the
    connection was lost, when there is no transport left to resume.

    Returns:
        the value returned by clearing the input buffer - presumably the
        number of discarded bytes

    """
    count = self._input_buffer.clear()

    # connection_lost detaches the transport; without this guard a reset
    # on a lost connection raises AttributeError
    if self._transport is not None:
        self._transport.resume_reading()

    return count
async def async_close(self):
async def async_close(self):
    """Close the transport, if still attached, and wait until closed"""
    transport = self._transport
    if transport is not None:
        transport.close()

    await self.wait_closed()
async def wait_closed(self):
async def wait_closed(self):
    """Wait until the connection is closed

    Returns immediately when there is no queue of close waiters, which is
    the case before `connection_made` and after `connection_lost`.
    Otherwise it waits until `connection_lost` resolves the queued future.

    """
    waiters = self._closed_futures
    if waiters is not None:
        closed = self._loop.create_future()
        waiters.append(closed)
        await closed