hat.drivers.tcp

Asyncio TCP wrapper

  1"""Asyncio TCP wrapper"""
  2
  3import asyncio
  4import collections
  5import functools
  6import logging
  7import typing
  8
  9from hat import aio
 10from hat import util
 11
 12from hat.drivers import ssl
 13
 14
 15mlog: logging.Logger = logging.getLogger(__name__)
 16"""Module logger"""
 17
 18
 19class Address(typing.NamedTuple):
 20    host: str
 21    port: int
 22
 23
 24class ConnectionInfo(typing.NamedTuple):
 25    local_addr: Address
 26    remote_addr: Address
 27
 28
 29ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]
 30"""Connection callback"""
 31
 32
 33async def connect(addr: Address,
 34                  *,
 35                  input_buffer_limit: int = 64 * 1024,
 36                  **kwargs
 37                  ) -> 'Connection':
 38    """Create TCP connection
 39
 40    Argument `addr` specifies remote server listening address.
 41
 42    Argument `input_buffer_limit` defines number of bytes in input buffer
 43    that whill temporary pause data receiving. Once number of bytes
 44    drops bellow `input_buffer_limit`, data receiving is resumed. If this
 45    argument is ``0``, data receive pausing is disabled.
 46
 47    Additional arguments are passed directly to `asyncio.create_connection`.
 48
 49    """
 50    loop = asyncio.get_running_loop()
 51    create_transport = functools.partial(Protocol, None, input_buffer_limit)
 52    _, protocol = await loop.create_connection(create_transport,
 53                                               addr.host, addr.port,
 54                                               **kwargs)
 55    return Connection(protocol)
 56
 57
 58async def listen(connection_cb: ConnectionCb,
 59                 addr: Address,
 60                 *,
 61                 bind_connections: bool = False,
 62                 input_buffer_limit: int = 64 * 1024,
 63                 **kwargs
 64                 ) -> 'Server':
 65    """Create listening server
 66
 67    If `bind_connections` is ``True``, closing server will close all open
 68    incoming connections.
 69
 70    Argument `input_buffer_limit` is associated with newly created connections
 71    (see `connect`).
 72
 73    Additional arguments are passed directly to `asyncio.create_server`.
 74
 75    """
 76    server = Server()
 77    server._connection_cb = connection_cb
 78    server._bind_connections = bind_connections
 79    server._async_group = aio.Group()
 80
 81    on_connection = functools.partial(server.async_group.spawn,
 82                                      server._on_connection)
 83    create_transport = functools.partial(Protocol, on_connection,
 84                                         input_buffer_limit)
 85
 86    loop = asyncio.get_running_loop()
 87    server._srv = await loop.create_server(create_transport, addr.host,
 88                                           addr.port, **kwargs)
 89
 90    server.async_group.spawn(aio.call_on_cancel, server._on_close)
 91
 92    try:
 93        socknames = (socket.getsockname() for socket in server._srv.sockets)
 94        server._addresses = [Address(*sockname[:2]) for sockname in socknames]
 95
 96    except Exception:
 97        await aio.uncancellable(server.async_close())
 98        raise
 99
100    return server
101
102
103class Server(aio.Resource):
104    """TCP listening server
105
106    Closing server will cancel all running `connection_cb` coroutines.
107
108    """
109
110    @property
111    def async_group(self) -> aio.Group:
112        """Async group"""
113        return self._async_group
114
115    @property
116    def addresses(self) -> list[Address]:
117        """Listening addresses"""
118        return self._addresses
119
120    async def _on_close(self):
121        self._srv.close()
122        await self._srv.wait_closed()
123
124    async def _on_connection(self, protocol):
125        conn = Connection(protocol)
126
127        try:
128            await aio.call(self._connection_cb, conn)
129
130            if self._bind_connections:
131                await conn.wait_closing()
132
133            else:
134                conn = None
135
136        except Exception as e:
137            mlog.warning('connection callback error: %s', e, exc_info=e)
138
139        finally:
140            if conn:
141                await aio.uncancellable(conn.async_close())
142
143
144class Connection(aio.Resource):
145    """TCP connection"""
146
147    def __init__(self, protocol: 'Protocol'):
148        self._protocol = protocol
149        self._async_group = aio.Group()
150
151        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
152        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
153                               self.close)
154
155    @property
156    def async_group(self) -> aio.Group:
157        """Async group"""
158        return self._async_group
159
160    @property
161    def info(self) -> ConnectionInfo:
162        """Connection info"""
163        return self._protocol.info
164
165    @property
166    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
167        """SSL Object"""
168        return self._protocol.ssl_object
169
170    async def write(self, data: util.Bytes):
171        """Write data
172
173        This coroutine will wait until `data` can be added to output buffer.
174
175        """
176        if not self.is_open:
177            raise ConnectionError()
178
179        await self._protocol.write(data)
180
181    async def drain(self):
182        """Drain output buffer"""
183        await self._protocol.drain()
184
185    async def read(self, n: int = -1) -> util.Bytes:
186        """Read up to `n` bytes
187
188        If EOF is detected and no new bytes are available, `ConnectionError`
189        is raised.
190
191        """
192        return await self._protocol.read(n)
193
194    async def readexactly(self, n: int) -> util.Bytes:
195        """Read exactly `n` bytes
196
197        If exact number of bytes could not be read, `ConnectionError` is
198        raised.
199
200        """
201        return await self._protocol.readexactly(n)
202
203    def reset_input_buffer(self) -> int:
204        """Reset input buffer
205
206        Returns number of bytes cleared from buffer.
207
208        """
209        return self._protocol.reset_input_buffer()
210
211
212class Protocol(asyncio.Protocol):
213    """Asyncio protocol implementation"""
214
215    def __init__(self,
216                 on_connected: typing.Callable[['Protocol'], None] | None,
217                 input_buffer_limit: int):
218        self._on_connected = on_connected
219        self._input_buffer_limit = input_buffer_limit
220        self._loop = asyncio.get_running_loop()
221        self._input_buffer = util.BytesBuffer()
222        self._transport = None
223        self._read_queue = None
224        self._write_queue = None
225        self._drain_futures = None
226        self._closed_futures = None
227        self._info = None
228        self._ssl_object = None
229
230    @property
231    def info(self) -> ConnectionInfo:
232        return self._info
233
234    @property
235    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
236        return self._ssl_object
237
238    def connection_made(self, transport: asyncio.Transport):
239        self._transport = transport
240        self._read_queue = collections.deque()
241        self._closed_futures = collections.deque()
242
243        try:
244            sockname = transport.get_extra_info('sockname')
245            peername = transport.get_extra_info('peername')
246            self._info = ConnectionInfo(
247                local_addr=Address(sockname[0], sockname[1]),
248                remote_addr=Address(peername[0], peername[1]))
249            self._ssl_object = transport.get_extra_info('ssl_object')
250
251            if self._on_connected:
252                self._on_connected(self)
253
254        except Exception:
255            transport.abort()
256            return
257
258    def connection_lost(self, exc: Exception | None):
259        self._transport = None
260        write_queue, self._write_queue = self._write_queue, None
261        drain_futures, self._drain_futures = self._drain_futures, None
262        closed_futures, self._closed_futures = self._closed_futures, None
263
264        self.eof_received()
265
266        while write_queue:
267            _, future = write_queue.popleft()
268            if not future.done():
269                future.set_exception(ConnectionError())
270
271        while drain_futures:
272            future = drain_futures.popleft()
273            if not future.done():
274                future.set_result(None)
275
276        while closed_futures:
277            future = closed_futures.popleft()
278            if not future.done():
279                future.set_result(None)
280
281    def pause_writing(self):
282        self._write_queue = collections.deque()
283        self._drain_futures = collections.deque()
284
285    def resume_writing(self):
286        write_queue, self._write_queue = self._write_queue, None
287        drain_futures, self._drain_futures = self._drain_futures, None
288
289        while self._write_queue is None and write_queue:
290            data, future = write_queue.popleft()
291            if future.done():
292                continue
293
294            self._transport.write(data)
295            future.set_result(None)
296
297        if write_queue:
298            write_queue.extend(self._write_queue)
299            self._write_queue = write_queue
300
301            drain_futures.extend(self._drain_futures)
302            self._drain_futures = drain_futures
303
304            return
305
306        while drain_futures:
307            future = drain_futures.popleft()
308            if not future.done():
309                future.set_result(None)
310
311    def data_received(self, data: util.Bytes):
312        self._input_buffer.add(data)
313        self._process_input_buffer()
314
315    def eof_received(self):
316        while self._read_queue:
317            exact, n, future = self._read_queue.popleft()
318            if future.done():
319                continue
320
321            if exact and n <= len(self._input_buffer):
322                future.set_result(self._input_buffer.read(n))
323
324            elif not exact and self._input_buffer:
325                future.set_result(self._input_buffer.read(n))
326
327            else:
328                future.set_exception(ConnectionError())
329
330        self._read_queue = None
331
332    async def write(self, data: util.Bytes):
333        if self._transport is None:
334            raise ConnectionError()
335
336        if self._write_queue is None:
337            self._transport.write(data)
338            return
339
340        future = self._loop.create_future()
341        self._write_queue.append((data, future))
342        await future
343
344    async def drain(self):
345        if self._drain_futures is None:
346            return
347
348        future = self._loop.create_future()
349        self._drain_futures.append(future)
350        await future
351
352    async def read(self, n: int) -> util.Bytes:
353        if n == 0:
354            return b''
355
356        if self._input_buffer and not self._read_queue:
357            data = self._input_buffer.read(n)
358            self._process_input_buffer()
359            return data
360
361        if self._read_queue is None:
362            raise ConnectionError()
363
364        future = self._loop.create_future()
365        future.add_done_callback(self._on_read_future_done)
366        self._read_queue.append((False, n, future))
367        return await future
368
369    async def readexactly(self, n: int) -> util.Bytes:
370        if n == 0:
371            return b''
372
373        if n <= len(self._input_buffer) and not self._read_queue:
374            data = self._input_buffer.read(n)
375            self._process_input_buffer()
376            return data
377
378        if self._read_queue is None:
379            raise ConnectionError()
380
381        future = self._loop.create_future()
382        future.add_done_callback(self._on_read_future_done)
383        self._read_queue.append((True, n, future))
384        self._process_input_buffer()
385        return await future
386
387    def reset_input_buffer(self) -> int:
388        count = self._input_buffer.clear()
389        self._transport.resume_reading()
390        return count
391
392    async def async_close(self):
393        if self._transport is not None:
394            self._transport.close()
395
396        await self.wait_closed()
397
398    async def wait_closed(self):
399        if self._closed_futures is None:
400            return
401
402        future = self._loop.create_future()
403        self._closed_futures.append(future)
404        await future
405
406    def _on_read_future_done(self, future):
407        if not self._read_queue:
408            return
409
410        if not future.cancelled():
411            return
412
413        for _ in range(len(self._read_queue)):
414            i = self._read_queue.popleft()
415            if not i[2].done():
416                self._read_queue.append(i)
417
418        self._process_input_buffer()
419
420    def _process_input_buffer(self):
421        while self._input_buffer and self._read_queue:
422            exact, n, future = self._read_queue.popleft()
423            if future.done():
424                continue
425
426            if not exact:
427                future.set_result(self._input_buffer.read(n))
428
429            elif n <= len(self._input_buffer):
430                future.set_result(self._input_buffer.read(n))
431
432            else:
433                self._read_queue.appendleft((exact, n, future))
434                break
435
436        if not self._transport:
437            return
438
439        pause = (self._input_buffer_limit > 0 and
440                 len(self._input_buffer) > self._input_buffer_limit and
441                 not self._read_queue)
442
443        if pause:
444            self._transport.pause_reading()
445
446        else:
447            self._transport.resume_reading()
mlog: logging.Logger = <Logger hat.drivers.tcp (WARNING)>

Module logger

class Address(typing.NamedTuple):
20class Address(typing.NamedTuple):
21    host: str
22    port: int

Address(host, port)

Address(host: str, port: int)

Create new instance of Address(host, port)

host: str

Alias for field number 0

port: int

Alias for field number 1

Inherited Members
builtins.tuple
index
count
class ConnectionInfo(typing.NamedTuple):
25class ConnectionInfo(typing.NamedTuple):
26    local_addr: Address
27    remote_addr: Address

ConnectionInfo(local_addr, remote_addr)

ConnectionInfo( local_addr: Address, remote_addr: Address)

Create new instance of ConnectionInfo(local_addr, remote_addr)

local_addr: Address

Alias for field number 0

remote_addr: Address

Alias for field number 1

Inherited Members
builtins.tuple
index
count
ConnectionCb: TypeAlias = Callable[[ForwardRef('Connection')], Optional[Awaitable[NoneType]]]

Connection callback

async def connect( addr: Address, *, input_buffer_limit: int = 65536, **kwargs) -> Connection:
34async def connect(addr: Address,
35                  *,
36                  input_buffer_limit: int = 64 * 1024,
37                  **kwargs
38                  ) -> 'Connection':
39    """Create TCP connection
40
41    Argument `addr` specifies remote server listening address.
42
43    Argument `input_buffer_limit` defines number of bytes in input buffer
44    that whill temporary pause data receiving. Once number of bytes
45    drops bellow `input_buffer_limit`, data receiving is resumed. If this
46    argument is ``0``, data receive pausing is disabled.
47
48    Additional arguments are passed directly to `asyncio.create_connection`.
49
50    """
51    loop = asyncio.get_running_loop()
52    create_transport = functools.partial(Protocol, None, input_buffer_limit)
53    _, protocol = await loop.create_connection(create_transport,
54                                               addr.host, addr.port,
55                                               **kwargs)
56    return Connection(protocol)

Create TCP connection

Argument addr specifies the remote server's listening address.

Argument input_buffer_limit defines the number of bytes in the input buffer above which data receiving is temporarily paused. Once the number of buffered bytes drops to input_buffer_limit or below, data receiving is resumed. If this argument is 0, data receive pausing is disabled.

Additional keyword arguments are passed directly to the event loop's create_connection method (asyncio.loop.create_connection).

async def listen( connection_cb: Callable[[Connection], Optional[Awaitable[NoneType]]], addr: Address, *, bind_connections: bool = False, input_buffer_limit: int = 65536, **kwargs) -> Server:
 59async def listen(connection_cb: ConnectionCb,
 60                 addr: Address,
 61                 *,
 62                 bind_connections: bool = False,
 63                 input_buffer_limit: int = 64 * 1024,
 64                 **kwargs
 65                 ) -> 'Server':
 66    """Create listening server
 67
 68    If `bind_connections` is ``True``, closing server will close all open
 69    incoming connections.
 70
 71    Argument `input_buffer_limit` is associated with newly created connections
 72    (see `connect`).
 73
 74    Additional arguments are passed directly to `asyncio.create_server`.
 75
 76    """
 77    server = Server()
 78    server._connection_cb = connection_cb
 79    server._bind_connections = bind_connections
 80    server._async_group = aio.Group()
 81
 82    on_connection = functools.partial(server.async_group.spawn,
 83                                      server._on_connection)
 84    create_transport = functools.partial(Protocol, on_connection,
 85                                         input_buffer_limit)
 86
 87    loop = asyncio.get_running_loop()
 88    server._srv = await loop.create_server(create_transport, addr.host,
 89                                           addr.port, **kwargs)
 90
 91    server.async_group.spawn(aio.call_on_cancel, server._on_close)
 92
 93    try:
 94        socknames = (socket.getsockname() for socket in server._srv.sockets)
 95        server._addresses = [Address(*sockname[:2]) for sockname in socknames]
 96
 97    except Exception:
 98        await aio.uncancellable(server.async_close())
 99        raise
100
101    return server

Create listening server

If bind_connections is True, closing the server will also close all open incoming connections.

Argument input_buffer_limit is applied to each newly accepted connection (see connect).

Additional keyword arguments are passed directly to the event loop's create_server method (asyncio.loop.create_server).

class Server(hat.aio.group.Resource):
104class Server(aio.Resource):
105    """TCP listening server
106
107    Closing server will cancel all running `connection_cb` coroutines.
108
109    """
110
111    @property
112    def async_group(self) -> aio.Group:
113        """Async group"""
114        return self._async_group
115
116    @property
117    def addresses(self) -> list[Address]:
118        """Listening addresses"""
119        return self._addresses
120
121    async def _on_close(self):
122        self._srv.close()
123        await self._srv.wait_closed()
124
125    async def _on_connection(self, protocol):
126        conn = Connection(protocol)
127
128        try:
129            await aio.call(self._connection_cb, conn)
130
131            if self._bind_connections:
132                await conn.wait_closing()
133
134            else:
135                conn = None
136
137        except Exception as e:
138            mlog.warning('connection callback error: %s', e, exc_info=e)
139
140        finally:
141            if conn:
142                await aio.uncancellable(conn.async_close())

TCP listening server

Closing the server cancels all running connection_cb coroutines.

async_group: hat.aio.group.Group
111    @property
112    def async_group(self) -> aio.Group:
113        """Async group"""
114        return self._async_group

Async group

addresses: list[Address]
116    @property
117    def addresses(self) -> list[Address]:
118        """Listening addresses"""
119        return self._addresses

Listening addresses

Inherited Members
hat.aio.group.Resource
is_open
is_closing
is_closed
wait_closing
wait_closed
close
async_close
class Connection(hat.aio.group.Resource):
145class Connection(aio.Resource):
146    """TCP connection"""
147
148    def __init__(self, protocol: 'Protocol'):
149        self._protocol = protocol
150        self._async_group = aio.Group()
151
152        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
153        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
154                               self.close)
155
156    @property
157    def async_group(self) -> aio.Group:
158        """Async group"""
159        return self._async_group
160
161    @property
162    def info(self) -> ConnectionInfo:
163        """Connection info"""
164        return self._protocol.info
165
166    @property
167    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
168        """SSL Object"""
169        return self._protocol.ssl_object
170
171    async def write(self, data: util.Bytes):
172        """Write data
173
174        This coroutine will wait until `data` can be added to output buffer.
175
176        """
177        if not self.is_open:
178            raise ConnectionError()
179
180        await self._protocol.write(data)
181
182    async def drain(self):
183        """Drain output buffer"""
184        await self._protocol.drain()
185
186    async def read(self, n: int = -1) -> util.Bytes:
187        """Read up to `n` bytes
188
189        If EOF is detected and no new bytes are available, `ConnectionError`
190        is raised.
191
192        """
193        return await self._protocol.read(n)
194
195    async def readexactly(self, n: int) -> util.Bytes:
196        """Read exactly `n` bytes
197
198        If exact number of bytes could not be read, `ConnectionError` is
199        raised.
200
201        """
202        return await self._protocol.readexactly(n)
203
204    def reset_input_buffer(self) -> int:
205        """Reset input buffer
206
207        Returns number of bytes cleared from buffer.
208
209        """
210        return self._protocol.reset_input_buffer()

TCP connection

Connection(protocol: Protocol)
148    def __init__(self, protocol: 'Protocol'):
149        self._protocol = protocol
150        self._async_group = aio.Group()
151
152        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
153        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
154                               self.close)
async_group: hat.aio.group.Group
156    @property
157    def async_group(self) -> aio.Group:
158        """Async group"""
159        return self._async_group

Async group

info: ConnectionInfo
161    @property
162    def info(self) -> ConnectionInfo:
163        """Connection info"""
164        return self._protocol.info

Connection info

ssl_object: ssl.SSLObject | ssl.SSLSocket | None
166    @property
167    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
168        """SSL Object"""
169        return self._protocol.ssl_object

SSL Object

async def write(self, data: bytes | bytearray | memoryview):
171    async def write(self, data: util.Bytes):
172        """Write data
173
174        This coroutine will wait until `data` can be added to output buffer.
175
176        """
177        if not self.is_open:
178            raise ConnectionError()
179
180        await self._protocol.write(data)

Write data

This coroutine will wait until data can be added to the output buffer.

async def drain(self):
182    async def drain(self):
183        """Drain output buffer"""
184        await self._protocol.drain()

Drain output buffer

async def read(self, n: int = -1) -> bytes | bytearray | memoryview:
186    async def read(self, n: int = -1) -> util.Bytes:
187        """Read up to `n` bytes
188
189        If EOF is detected and no new bytes are available, `ConnectionError`
190        is raised.
191
192        """
193        return await self._protocol.read(n)

Read up to n bytes

If EOF is detected and no new bytes are available, ConnectionError is raised.

async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
195    async def readexactly(self, n: int) -> util.Bytes:
196        """Read exactly `n` bytes
197
198        If exact number of bytes could not be read, `ConnectionError` is
199        raised.
200
201        """
202        return await self._protocol.readexactly(n)

Read exactly n bytes

If the exact number of bytes cannot be read, ConnectionError is raised.

def reset_input_buffer(self) -> int:
204    def reset_input_buffer(self) -> int:
205        """Reset input buffer
206
207        Returns number of bytes cleared from buffer.
208
209        """
210        return self._protocol.reset_input_buffer()

Reset input buffer

Returns the number of bytes cleared from the buffer.

Inherited Members
hat.aio.group.Resource
is_open
is_closing
is_closed
wait_closing
wait_closed
close
async_close
class Protocol(asyncio.protocols.Protocol):
213class Protocol(asyncio.Protocol):
214    """Asyncio protocol implementation"""
215
216    def __init__(self,
217                 on_connected: typing.Callable[['Protocol'], None] | None,
218                 input_buffer_limit: int):
219        self._on_connected = on_connected
220        self._input_buffer_limit = input_buffer_limit
221        self._loop = asyncio.get_running_loop()
222        self._input_buffer = util.BytesBuffer()
223        self._transport = None
224        self._read_queue = None
225        self._write_queue = None
226        self._drain_futures = None
227        self._closed_futures = None
228        self._info = None
229        self._ssl_object = None
230
231    @property
232    def info(self) -> ConnectionInfo:
233        return self._info
234
235    @property
236    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
237        return self._ssl_object
238
239    def connection_made(self, transport: asyncio.Transport):
240        self._transport = transport
241        self._read_queue = collections.deque()
242        self._closed_futures = collections.deque()
243
244        try:
245            sockname = transport.get_extra_info('sockname')
246            peername = transport.get_extra_info('peername')
247            self._info = ConnectionInfo(
248                local_addr=Address(sockname[0], sockname[1]),
249                remote_addr=Address(peername[0], peername[1]))
250            self._ssl_object = transport.get_extra_info('ssl_object')
251
252            if self._on_connected:
253                self._on_connected(self)
254
255        except Exception:
256            transport.abort()
257            return
258
259    def connection_lost(self, exc: Exception | None):
260        self._transport = None
261        write_queue, self._write_queue = self._write_queue, None
262        drain_futures, self._drain_futures = self._drain_futures, None
263        closed_futures, self._closed_futures = self._closed_futures, None
264
265        self.eof_received()
266
267        while write_queue:
268            _, future = write_queue.popleft()
269            if not future.done():
270                future.set_exception(ConnectionError())
271
272        while drain_futures:
273            future = drain_futures.popleft()
274            if not future.done():
275                future.set_result(None)
276
277        while closed_futures:
278            future = closed_futures.popleft()
279            if not future.done():
280                future.set_result(None)
281
282    def pause_writing(self):
283        self._write_queue = collections.deque()
284        self._drain_futures = collections.deque()
285
286    def resume_writing(self):
287        write_queue, self._write_queue = self._write_queue, None
288        drain_futures, self._drain_futures = self._drain_futures, None
289
290        while self._write_queue is None and write_queue:
291            data, future = write_queue.popleft()
292            if future.done():
293                continue
294
295            self._transport.write(data)
296            future.set_result(None)
297
298        if write_queue:
299            write_queue.extend(self._write_queue)
300            self._write_queue = write_queue
301
302            drain_futures.extend(self._drain_futures)
303            self._drain_futures = drain_futures
304
305            return
306
307        while drain_futures:
308            future = drain_futures.popleft()
309            if not future.done():
310                future.set_result(None)
311
312    def data_received(self, data: util.Bytes):
313        self._input_buffer.add(data)
314        self._process_input_buffer()
315
316    def eof_received(self):
317        while self._read_queue:
318            exact, n, future = self._read_queue.popleft()
319            if future.done():
320                continue
321
322            if exact and n <= len(self._input_buffer):
323                future.set_result(self._input_buffer.read(n))
324
325            elif not exact and self._input_buffer:
326                future.set_result(self._input_buffer.read(n))
327
328            else:
329                future.set_exception(ConnectionError())
330
331        self._read_queue = None
332
333    async def write(self, data: util.Bytes):
334        if self._transport is None:
335            raise ConnectionError()
336
337        if self._write_queue is None:
338            self._transport.write(data)
339            return
340
341        future = self._loop.create_future()
342        self._write_queue.append((data, future))
343        await future
344
345    async def drain(self):
346        if self._drain_futures is None:
347            return
348
349        future = self._loop.create_future()
350        self._drain_futures.append(future)
351        await future
352
353    async def read(self, n: int) -> util.Bytes:
354        if n == 0:
355            return b''
356
357        if self._input_buffer and not self._read_queue:
358            data = self._input_buffer.read(n)
359            self._process_input_buffer()
360            return data
361
362        if self._read_queue is None:
363            raise ConnectionError()
364
365        future = self._loop.create_future()
366        future.add_done_callback(self._on_read_future_done)
367        self._read_queue.append((False, n, future))
368        return await future
369
370    async def readexactly(self, n: int) -> util.Bytes:
371        if n == 0:
372            return b''
373
374        if n <= len(self._input_buffer) and not self._read_queue:
375            data = self._input_buffer.read(n)
376            self._process_input_buffer()
377            return data
378
379        if self._read_queue is None:
380            raise ConnectionError()
381
382        future = self._loop.create_future()
383        future.add_done_callback(self._on_read_future_done)
384        self._read_queue.append((True, n, future))
385        self._process_input_buffer()
386        return await future
387
388    def reset_input_buffer(self) -> int:
389        count = self._input_buffer.clear()
390        self._transport.resume_reading()
391        return count
392
393    async def async_close(self):
394        if self._transport is not None:
395            self._transport.close()
396
397        await self.wait_closed()
398
399    async def wait_closed(self):
400        if self._closed_futures is None:
401            return
402
403        future = self._loop.create_future()
404        self._closed_futures.append(future)
405        await future
406
407    def _on_read_future_done(self, future):
408        if not self._read_queue:
409            return
410
411        if not future.cancelled():
412            return
413
414        for _ in range(len(self._read_queue)):
415            i = self._read_queue.popleft()
416            if not i[2].done():
417                self._read_queue.append(i)
418
419        self._process_input_buffer()
420
421    def _process_input_buffer(self):
422        while self._input_buffer and self._read_queue:
423            exact, n, future = self._read_queue.popleft()
424            if future.done():
425                continue
426
427            if not exact:
428                future.set_result(self._input_buffer.read(n))
429
430            elif n <= len(self._input_buffer):
431                future.set_result(self._input_buffer.read(n))
432
433            else:
434                self._read_queue.appendleft((exact, n, future))
435                break
436
437        if not self._transport:
438            return
439
440        pause = (self._input_buffer_limit > 0 and
441                 len(self._input_buffer) > self._input_buffer_limit and
442                 not self._read_queue)
443
444        if pause:
445            self._transport.pause_reading()
446
447        else:
448            self._transport.resume_reading()

Asyncio protocol implementation

Protocol( on_connected: Optional[Callable[[Protocol], NoneType]], input_buffer_limit: int)
216    def __init__(self,
217                 on_connected: typing.Callable[['Protocol'], None] | None,
218                 input_buffer_limit: int):
219        self._on_connected = on_connected
220        self._input_buffer_limit = input_buffer_limit
221        self._loop = asyncio.get_running_loop()
222        self._input_buffer = util.BytesBuffer()
223        self._transport = None
224        self._read_queue = None
225        self._write_queue = None
226        self._drain_futures = None
227        self._closed_futures = None
228        self._info = None
229        self._ssl_object = None
info: ConnectionInfo
231    @property
232    def info(self) -> ConnectionInfo:
233        return self._info
ssl_object: ssl.SSLObject | ssl.SSLSocket | None
235    @property
236    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
237        return self._ssl_object
def connection_made(self, transport: asyncio.transports.Transport):
def connection_made(self, transport: asyncio.Transport):
    """Attach the transport, collect connection info and notify the owner.

    Initializes the read queue and close waiters, reads the local/remote
    addresses and the SSL object from the transport, and calls the
    ``on_connected`` callback if one was provided. If any of these steps
    raises, the error is logged and the transport is aborted.

    """
    self._transport = transport
    self._read_queue = collections.deque()
    self._closed_futures = collections.deque()

    try:
        sockname = transport.get_extra_info('sockname')
        peername = transport.get_extra_info('peername')
        self._info = ConnectionInfo(
            local_addr=Address(sockname[0], sockname[1]),
            remote_addr=Address(peername[0], peername[1]))
        self._ssl_object = transport.get_extra_info('ssl_object')

        if self._on_connected:
            self._on_connected(self)

    except Exception as e:
        # log before aborting so failures (e.g. raised by the user supplied
        # on_connected callback) are not silently discarded
        mlog.error("error initializing connection: %s", e, exc_info=e)
        transport.abort()

Called when a connection is made.

The argument is the transport representing the pipe connection. To receive data, wait for data_received() calls. When the connection is closed, connection_lost() is called.

def connection_lost(self, exc: Exception | None):
259    def connection_lost(self, exc: Exception | None):
260        self._transport = None
261        write_queue, self._write_queue = self._write_queue, None
262        drain_futures, self._drain_futures = self._drain_futures, None
263        closed_futures, self._closed_futures = self._closed_futures, None
264
265        self.eof_received()
266
267        while write_queue:
268            _, future = write_queue.popleft()
269            if not future.done():
270                future.set_exception(ConnectionError())
271
272        while drain_futures:
273            future = drain_futures.popleft()
274            if not future.done():
275                future.set_result(None)
276
277        while closed_futures:
278            future = closed_futures.popleft()
279            if not future.done():
280                future.set_result(None)

Called when the connection is lost or closed.

The argument is an exception object or None (the latter meaning a regular EOF is received or the connection was aborted or closed).

def pause_writing(self):
def pause_writing(self):
    """Start queuing writes because the transport buffer is full.

    While paused, `write` appends to ``_write_queue`` and `drain` waits on
    ``_drain_futures``; `resume_writing` flushes both.

    """
    self._drain_futures = collections.deque()
    self._write_queue = collections.deque()

Called when the transport's buffer goes over the high-water mark.

Pause and resume calls are paired -- pause_writing() is called once when the buffer goes strictly over the high-water mark (even if subsequent writes increases the buffer size even more), and eventually resume_writing() is called once when the buffer size reaches the low-water mark.

Note that if the buffer size equals the high-water mark, pause_writing() is not called -- it must go strictly over. Conversely, resume_writing() is called when the buffer size is equal or lower than the low-water mark. These end conditions are important to ensure that things go as expected when either mark is zero.

NOTE: This is the only Protocol callback that is not called through EventLoop.call_soon() -- if it were, it would have no effect when it's most needed (when the app keeps writing without yielding until pause_writing() is called).

def resume_writing(self):
def resume_writing(self):
    """Flush queued writes once the transport buffer has drained.

    Queued data is written to the transport in order, and each write's
    future is resolved. Writing may make the transport call
    `pause_writing` again synchronously. In that case flushing stops, and
    the remaining writes and all drain waiters are moved onto the new
    queues. The drain waiters then stay pending until the next
    `resume_writing`. Otherwise all drain waiters are released.

    """
    write_queue, self._write_queue = self._write_queue, None
    drain_futures, self._drain_futures = self._drain_futures, None

    # `self._write_queue` becomes non-None again if `transport.write`
    # synchronously triggers `pause_writing`
    while self._write_queue is None and write_queue:
        data, future = write_queue.popleft()
        if future.done():
            continue

        self._transport.write(data)
        future.set_result(None)

    if self._write_queue is not None:
        # writing was paused again (possibly by the very last write) - keep
        # unsent data and drain waiters queued so that drain waiters are
        # not released while the transport buffer is still over its limit
        write_queue.extend(self._write_queue)
        self._write_queue = write_queue

        drain_futures.extend(self._drain_futures)
        self._drain_futures = drain_futures

        return

    while drain_futures:
        future = drain_futures.popleft()
        if not future.done():
            future.set_result(None)

Called when the transport's buffer drains below the low-water mark.

See pause_writing() for details.

def data_received(self, data: bytes | bytearray | memoryview):
def data_received(self, data: util.Bytes):
    """Append received bytes to the input buffer and serve pending reads."""
    buffer = self._input_buffer
    buffer.add(data)
    self._process_input_buffer()

Called when some data is received.

The argument is a bytes object.

def eof_received(self):
def eof_received(self):
    """Resolve every pending read and stop accepting queued reads.

    Each pending read is served from the remaining buffered data when
    possible: an exact read needs at least ``n`` bytes, and a non-exact
    read needs any data. Otherwise the read fails with `ConnectionError`.
    Afterwards ``_read_queue`` is ``None``, so later reads that cannot be
    served from the buffer raise `ConnectionError`. Returns ``None``, which
    lets the transport close itself.

    """
    pending = self._read_queue
    buf = self._input_buffer

    while pending:
        exact, n, future = pending.popleft()
        if future.done():
            continue

        satisfiable = (n <= len(buf)) if exact else bool(buf)
        if satisfiable:
            future.set_result(buf.read(n))
        else:
            future.set_exception(ConnectionError())

    self._read_queue = None

Called when the other end calls write_eof() or equivalent.

If this returns a false value (including None), the transport will close itself. If it returns a true value, closing the transport is up to the protocol.

async def write(self, data: bytes | bytearray | memoryview):
async def write(self, data: util.Bytes):
    """Write data to the transport, waiting while writing is paused.

    When writing is not paused, data is passed to the transport at once.
    Otherwise it is queued, and this call waits until `resume_writing`
    has written it.

    Raises:
        ConnectionError: no transport is attached, or the connection was
            lost while the data was still queued

    """
    if self._transport is None:
        raise ConnectionError()

    queue = self._write_queue
    if queue is not None:
        future = self._loop.create_future()
        queue.append((data, future))
        await future
        return

    self._transport.write(data)
async def drain(self):
async def drain(self):
    """Wait while writing is paused.

    Returns immediately when writing is not paused. Otherwise waits until
    a waiter registered in ``_drain_futures`` is released by
    `resume_writing` or `connection_lost`.

    """
    waiters = self._drain_futures
    if waiters is not None:
        future = self._loop.create_future()
        waiters.append(future)
        await future
async def read(self, n: int) -> bytes | bytearray | memoryview:
async def read(self, n: int) -> util.Bytes:
    """Read up to `n` bytes.

    Returns ``b''`` when `n` is ``0``. When data is buffered and no other
    reads are pending, at most `n` buffered bytes are returned at once.
    Otherwise the request is queued behind earlier reads and resolved with
    up to `n` bytes once data is available.

    Raises:
        ConnectionError: the connection is closed and no data is buffered,
            or it is lost before data becomes available

    """
    if n == 0:
        return b''

    readers = self._read_queue
    if self._input_buffer and not readers:
        data = self._input_buffer.read(n)
        self._process_input_buffer()
        return data

    if readers is None:
        raise ConnectionError()

    future = self._loop.create_future()
    future.add_done_callback(self._on_read_future_done)
    readers.append((False, n, future))
    return await future
async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
async def readexactly(self, n: int) -> util.Bytes:
    """Read exactly `n` bytes.

    Returns ``b''`` when `n` is ``0``. When at least `n` bytes are buffered
    and no other reads are pending, they are returned at once. Otherwise
    the request is queued behind earlier reads until enough data arrives.

    Raises:
        ConnectionError: the connection is closed and fewer than `n` bytes
            are buffered, or it is lost before enough data arrives

    """
    if n == 0:
        return b''

    if not self._read_queue and len(self._input_buffer) >= n:
        data = self._input_buffer.read(n)
        self._process_input_buffer()
        return data

    if self._read_queue is None:
        raise ConnectionError()

    future = self._loop.create_future()
    future.add_done_callback(self._on_read_future_done)
    self._read_queue.append((True, n, future))
    # serve this request from already buffered data if possible (earlier
    # entries that are already done get skipped) and update flow control
    self._process_input_buffer()
    return await future
def reset_input_buffer(self) -> int:
def reset_input_buffer(self) -> int:
    """Discard all buffered input data.

    If a transport is attached, reading is resumed because the buffer can
    no longer be over its limit.

    Returns:
        number of discarded bytes, as returned by the buffer's ``clear``

    """
    count = self._input_buffer.clear()

    # transport is None once the connection is lost - nothing to resume
    if self._transport is not None:
        self._transport.resume_reading()

    return count
async def async_close(self):
async def async_close(self):
    """Close the transport (if still attached) and wait until closed."""
    transport = self._transport
    if transport is not None:
        transport.close()

    await self.wait_closed()
async def wait_closed(self):
async def wait_closed(self):
    """Wait until `connection_lost` releases close waiters.

    Returns immediately when ``_closed_futures`` is ``None``, i.e. before
    `connection_made` or after `connection_lost`.

    """
    waiters = self._closed_futures
    if waiters is None:
        return

    closed = self._loop.create_future()
    waiters.append(closed)
    await closed