hat.drivers.tcp

Asyncio TCP wrapper

  1"""Asyncio TCP wrapper"""
  2
  3import asyncio
  4import collections
  5import functools
  6import logging
  7import sys
  8import typing
  9
 10from hat import aio
 11from hat import util
 12
 13from hat.drivers import ssl
 14
 15
# Module-level logger. Servers and protocols log through it directly until
# they build their own `logging.LoggerAdapter` carrying address metadata.
mlog: logging.Logger = logging.getLogger(__name__)
"""Module logger"""
 18
 19
class Address(typing.NamedTuple):
    """TCP endpoint address (host and port pair)"""
    # host name or IP address, passed as `host` to the asyncio loop
    host: str
    # TCP port number
    port: int
 23
 24
class ConnectionInfo(typing.NamedTuple):
    """Description of an established connection"""
    # optional user-supplied name (see `connect` / `listen`)
    name: str | None
    # local endpoint, taken from the transport's `sockname`
    local_addr: Address
    # remote endpoint, taken from the transport's `peername`
    remote_addr: Address
 29
 30
class ServerInfo(typing.NamedTuple):
    """Description of a listening server"""
    # optional user-supplied name (see `listen`)
    name: str | None
    # one address per listening socket, taken from `getsockname()`
    addresses: list[Address]
 34
 35
# Invoked through `aio.call` with every newly accepted `Connection`; may be
# either a regular function or a coroutine function.
ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]
"""Connection callback"""
 38
 39
 40async def connect(addr: Address,
 41                  *,
 42                  name: str | None = None,
 43                  input_buffer_limit: int = 64 * 1024,
 44                  **kwargs
 45                  ) -> 'Connection':
 46    """Create TCP connection
 47
 48    Argument `addr` specifies remote server listening address.
 49
 50    Argument `name` defines connection name available in property `info`.
 51
 52    Argument `input_buffer_limit` defines number of bytes in input buffer
 53    that whill temporary pause data receiving. Once number of bytes
 54    drops bellow `input_buffer_limit`, data receiving is resumed. If this
 55    argument is ``0``, data receive pausing is disabled.
 56
 57    Additional arguments are passed directly to `asyncio.create_connection`.
 58
 59    """
 60    loop = asyncio.get_running_loop()
 61    create_protocol = functools.partial(Protocol, None, name,
 62                                        input_buffer_limit)
 63    _, protocol = await loop.create_connection(create_protocol,
 64                                               addr.host, addr.port,
 65                                               **kwargs)
 66    return Connection(protocol)
 67
 68
 69async def listen(connection_cb: ConnectionCb,
 70                 addr: Address,
 71                 *,
 72                 name: str | None = None,
 73                 bind_connections: bool = False,
 74                 input_buffer_limit: int = 64 * 1024,
 75                 **kwargs
 76                 ) -> 'Server':
 77    """Create listening server
 78
 79    Argument `name` defines server name available in property `info`. This
 80    name is used for all incomming connections.
 81
 82    If `bind_connections` is ``True``, closing server will close all open
 83    incoming connections.
 84
 85    Argument `input_buffer_limit` is associated with newly created connections
 86    (see `connect`).
 87
 88    Additional arguments are passed directly to `asyncio.create_server`.
 89
 90    """
 91    server = Server()
 92    server._connection_cb = connection_cb
 93    server._bind_connections = bind_connections
 94    server._async_group = aio.Group()
 95    server._log = mlog
 96
 97    on_connection = functools.partial(server.async_group.spawn,
 98                                      server._on_connection)
 99    create_protocol = functools.partial(Protocol, on_connection, name,
100                                        input_buffer_limit)
101
102    loop = asyncio.get_running_loop()
103    server._srv = await loop.create_server(create_protocol, addr.host,
104                                           addr.port, **kwargs)
105
106    server.async_group.spawn(aio.call_on_cancel, server._on_close)
107
108    try:
109        socknames = (socket.getsockname() for socket in server._srv.sockets)
110        addresses = [Address(*sockname[:2]) for sockname in socknames]
111        server._info = ServerInfo(name=name,
112                                  addresses=addresses)
113        server._log = _create_server_logger_adapter(server._info)
114
115    except Exception:
116        await aio.uncancellable(server.async_close())
117        raise
118
119    server._log.debug('listening for incomming connections')
120
121    return server
122
123
class Server(aio.Resource):
    """TCP listening server

    Instances are created and initialized by `listen`.

    Closing server will cancel all running `connection_cb` coroutines.

    """

    @property
    def async_group(self) -> aio.Group:
        """Async group"""
        return self._async_group

    @property
    def info(self) -> ServerInfo:
        """Server info"""
        return self._info

    async def _on_close(self):
        """Close the underlying asyncio server (runs on group cancel)"""
        self._srv.close()

        # NOTE(review): on Python 3.12+ `wait_closed` presumably also waits
        # for active connections, so it is awaited there only when the
        # connections are bound to (and closed with) this server - confirm
        # against asyncio documentation
        should_wait = (self._bind_connections or
                       sys.version_info[:2] < (3, 12))
        if should_wait:
            await self._srv.wait_closed()

    async def _on_connection(self, protocol):
        """Wrap accepted `protocol` in `Connection` and run the callback"""
        self._log.debug('new incomming connection')

        conn = Connection(protocol)

        try:
            await aio.call(self._connection_cb, conn)

            if not self._bind_connections:
                # unbound connection outlives this task - skip the cleanup
                conn = None

            else:
                await conn.wait_closing()

        except Exception as e:
            self._log.warning('connection callback error: %s', e, exc_info=e)

        finally:
            # reached with `conn` still set on callback error, cancellation
            # or once a bound connection started closing
            if conn:
                await aio.uncancellable(conn.async_close())
167
168
class Connection(aio.Resource):
    """TCP connection

    Thin resource wrapper around an already connected `Protocol`.

    """

    def __init__(self, protocol: 'Protocol'):
        self._protocol = protocol
        self._async_group = aio.Group()
        self._log = _create_connection_logger_adapter(protocol.info)

        self._log.debug('connection established')

        group = self.async_group
        # closing the connection closes the protocol ...
        group.spawn(aio.call_on_cancel, protocol.async_close)
        # ... and a closed protocol closes the connection
        group.spawn(aio.call_on_done, protocol.wait_closed(), self.close)

    @property
    def async_group(self) -> aio.Group:
        """Async group"""
        return self._async_group

    @property
    def info(self) -> ConnectionInfo:
        """Connection info"""
        return self._protocol.info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        """SSL Object"""
        return self._protocol.ssl_object

    async def write(self, data: util.Bytes):
        """Write data

        This coroutine will wait until `data` can be added to output buffer.

        Raises `ConnectionError` if the connection is not open.

        """
        if self.is_open:
            await self._protocol.write(data)
            return

        raise ConnectionError()

    async def drain(self):
        """Drain output buffer"""
        await self._protocol.drain()

    async def read(self, n: int = -1) -> util.Bytes:
        """Read up to `n` bytes

        If EOF is detected and no new bytes are available, `ConnectionError`
        is raised.

        """
        data = await self._protocol.read(n)
        return data

    async def readexactly(self, n: int) -> util.Bytes:
        """Read exactly `n` bytes

        If exact number of bytes could not be read, `ConnectionError` is
        raised.

        """
        data = await self._protocol.readexactly(n)
        return data

    def clear_input_buffer(self) -> int:
        """Clear input buffer

        Returns number of bytes cleared from buffer.

        """
        cleared_count = self._protocol.clear_input_buffer()
        return cleared_count
238
239
class Protocol(asyncio.Protocol):
    """Asyncio protocol implementation

    Buffers received bytes in an input buffer, serves queued read requests
    from it and queues write/drain requests while the transport has writing
    paused.

    Queue attributes double as state flags:

    * `_read_queue` is ``None`` before the connection is made and after EOF
      or connection loss
    * `_write_queue` / `_drain_futures` are not ``None`` only while writing
      is paused
    * `_closed_futures` is not ``None`` only while the connection exists

    """

    def __init__(self,
                 on_connected: typing.Callable[['Protocol'], None] | None,
                 name: str | None,
                 input_buffer_limit: int):
        """Create protocol instance

        Must be called while an event loop is running.

        Argument `on_connected`, if set, is called with this instance once
        the connection is made. Argument `name` is stored in connection info.
        Argument `input_buffer_limit` is the buffered byte count above which
        reading is paused (``0`` disables pausing).

        """
        self._on_connected = on_connected
        self._name = name
        self._input_buffer_limit = input_buffer_limit
        self._loop = asyncio.get_running_loop()
        self._input_buffer = util.BytesBuffer()
        self._transport = None
        self._read_queue = None
        self._write_queue = None
        self._drain_futures = None
        self._closed_futures = None
        self._info = None
        self._ssl_object = None
        self._log = mlog

    @property
    def info(self) -> ConnectionInfo:
        """Connection info (``None`` until the connection is made)"""
        return self._info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        """Transport's ``ssl_object`` extra info"""
        return self._ssl_object

    def connection_made(self, transport: asyncio.Transport):
        """Store transport, build connection info and notify `on_connected`

        Any exception raised while doing so aborts the transport.

        """
        self._transport = transport
        self._read_queue = collections.deque()
        self._closed_futures = collections.deque()

        try:
            sockname = transport.get_extra_info('sockname')
            peername = transport.get_extra_info('peername')
            self._info = ConnectionInfo(
                name=self._name,
                local_addr=Address(sockname[0], sockname[1]),
                remote_addr=Address(peername[0], peername[1]))

            self._log = _create_connection_logger_adapter(self._info)

            self._ssl_object = transport.get_extra_info('ssl_object')

            if self._on_connected:
                self._on_connected(self)

        except Exception:
            transport.abort()
            return

    def connection_lost(self, exc: Exception | None):
        """Release transport and complete all pending requests

        Pending reads are finished as on EOF, pending writes fail with
        `ConnectionError`, pending drain and closed futures are resolved.

        """
        self._transport = None
        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None
        closed_futures, self._closed_futures = self._closed_futures, None

        # finishes pending reads and disables further read queueing
        self.eof_received()

        while write_queue:
            _, future = write_queue.popleft()
            if not future.done():
                future.set_exception(ConnectionError())

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

        while closed_futures:
            future = closed_futures.popleft()
            if not future.done():
                future.set_result(None)

        self._log.debug('connection closed')

    def pause_writing(self):
        """Start queueing write and drain requests"""
        self._log.debug('pause writing')

        self._write_queue = collections.deque()
        self._drain_futures = collections.deque()

    def resume_writing(self):
        """Flush queued writes and release drain waiters

        Queued writes are passed to the transport until the queue is empty
        or the transport pauses writing again (which happens synchronously
        inside `transport.write` by calling `pause_writing`). Drain waiters
        are released only if writing is still resumed afterwards.

        """
        self._log.debug('resume writing')

        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None

        # `self._write_queue` stops being `None` as soon as a write below
        # causes the transport to call `pause_writing`
        while self._write_queue is None and write_queue:
            data, future = write_queue.popleft()
            if future.done():
                continue

            self._log.debug('writing %s bytes', len(data))
            self._transport.write(data)
            future.set_result(None)

        if self._write_queue is not None:
            # writing was paused again: keep the remaining writes (if any)
            # and all drain waiters pending, preserving their order. This
            # also covers the case where the last queued write triggered
            # the pause - drain waiters must not be released then.
            write_queue.extend(self._write_queue)
            self._write_queue = write_queue

            drain_futures.extend(self._drain_futures)
            self._drain_futures = drain_futures

            return

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

    def data_received(self, data: util.Bytes):
        """Buffer received `data` and serve pending reads"""
        self._log.debug('received %s bytes', len(data))

        self._input_buffer.add(data)
        self._process_input_buffer()

    def eof_received(self):
        """Finish all pending reads and disable further read queueing

        Each pending read is resolved from already buffered data when
        possible, otherwise it fails with `ConnectionError`.

        """
        self._log.debug('eof received')

        while self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if exact and n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            elif not exact and self._input_buffer:
                future.set_result(self._input_buffer.read(n))

            else:
                future.set_exception(ConnectionError())

        self._read_queue = None

    async def write(self, data: util.Bytes):
        """Write `data`, waiting while the transport has writing paused

        Raises `ConnectionError` if there is no transport or if the
        connection is lost while waiting.

        """
        if self._transport is None:
            raise ConnectionError()

        if self._write_queue is None:
            self._log.debug('writing %s bytes', len(data))
            self._transport.write(data)
            return

        future = self._loop.create_future()
        self._write_queue.append((data, future))
        await future

    async def drain(self):
        """Wait until writing is resumed (returns at once if not paused)"""
        if self._drain_futures is None:
            return

        future = self._loop.create_future()
        self._drain_futures.append(future)
        await future

    async def read(self, n: int) -> util.Bytes:
        """Read up to `n` bytes

        Buffered data is returned immediately when no earlier read is
        pending. Raises `ConnectionError` if no data is buffered and EOF
        was received or the connection is not established.

        """
        if n == 0:
            return b''

        if self._input_buffer and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((False, n, future))
        return await future

    async def readexactly(self, n: int) -> util.Bytes:
        """Read exactly `n` bytes

        Raises `ConnectionError` if `n` bytes are not buffered and EOF was
        received or the connection is not established.

        """
        if n == 0:
            return b''

        if n <= len(self._input_buffer) and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((True, n, future))
        # a pending read prevents read pausing - resume reading if needed
        self._process_input_buffer()
        return await future

    def clear_input_buffer(self) -> int:
        """Discard buffered input and resume reading

        Returns number of bytes cleared from buffer. Safe to call after the
        connection is lost.

        """
        count = self._input_buffer.clear()

        # transport is `None` before connection is made and after it is lost
        if self._transport is not None:
            self._transport.resume_reading()

        return count

    async def async_close(self):
        """Close the transport (if any) and wait until connection is lost"""
        if self._transport is not None:
            self._transport.close()

        await self.wait_closed()

    async def wait_closed(self):
        """Wait until connection is lost (returns at once if not connected)"""
        if self._closed_futures is None:
            return

        future = self._loop.create_future()
        self._closed_futures.append(future)
        await future

    def _on_read_future_done(self, future):
        """Drop finished entries from read queue when a read is cancelled"""
        if not self._read_queue:
            return

        if not future.cancelled():
            return

        # rotate through the queue once, keeping only unfinished requests
        for _ in range(len(self._read_queue)):
            i = self._read_queue.popleft()
            if not i[2].done():
                self._read_queue.append(i)

        self._process_input_buffer()

    def _process_input_buffer(self):
        """Serve pending reads from input buffer and update read pausing

        Reads are served in order; an exact read that can not yet be
        satisfied blocks all reads queued behind it. Reading is paused only
        if the limit is enabled, the buffer exceeds it and no read is
        pending; otherwise reading is resumed.

        """
        while self._input_buffer and self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if not exact:
                future.set_result(self._input_buffer.read(n))

            elif n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            else:
                self._read_queue.appendleft((exact, n, future))
                break

        if not self._transport:
            return

        pause = (self._input_buffer_limit > 0 and
                 len(self._input_buffer) > self._input_buffer_limit and
                 not self._read_queue)

        if pause:
            self._transport.pause_reading()

        else:
            self._transport.resume_reading()
495
496
def _create_server_logger_adapter(info):
    """Create logger adapter attaching server `info` as ``meta`` extra"""
    addresses = [{'host': address.host,
                  'port': address.port}
                 for address in info.addresses]

    meta = {'type': 'TcpServer',
            'name': info.name,
            'addresses': addresses}

    return logging.LoggerAdapter(mlog, {'meta': meta})
505
506
def _create_connection_logger_adapter(info):
    """Create logger adapter attaching connection `info` as ``meta`` extra"""

    def encode_address(address):
        return {'host': address.host,
                'port': address.port}

    meta = {'type': 'TcpConnection',
            'name': info.name,
            'local_addr': encode_address(info.local_addr),
            'remote_addr': encode_address(info.remote_addr)}

    return logging.LoggerAdapter(mlog, {'meta': meta})
mlog: logging.Logger = <Logger hat.drivers.tcp (WARNING)>

Module logger

class Address(typing.NamedTuple):
21class Address(typing.NamedTuple):
22    host: str
23    port: int

Address(host, port)

Address(host: str, port: int)

Create new instance of Address(host, port)

host: str

Alias for field number 0

port: int

Alias for field number 1

class ConnectionInfo(typing.NamedTuple):
26class ConnectionInfo(typing.NamedTuple):
27    name: str | None
28    local_addr: Address
29    remote_addr: Address

ConnectionInfo(name, local_addr, remote_addr)

ConnectionInfo( name: str | None, local_addr: Address, remote_addr: Address)

Create new instance of ConnectionInfo(name, local_addr, remote_addr)

name: str | None

Alias for field number 0

local_addr: Address

Alias for field number 1

remote_addr: Address

Alias for field number 2

class ServerInfo(typing.NamedTuple):
32class ServerInfo(typing.NamedTuple):
33    name: str | None
34    addresses: list[Address]

ServerInfo(name, addresses)

ServerInfo(name: str | None, addresses: list[Address])

Create new instance of ServerInfo(name, addresses)

name: str | None

Alias for field number 0

addresses: list[Address]

Alias for field number 1

ConnectionCb: TypeAlias = Callable[[ForwardRef('Connection')], None | Awaitable[None]]

Connection callback

async def connect( addr: Address, *, name: str | None = None, input_buffer_limit: int = 65536, **kwargs) -> Connection:
41async def connect(addr: Address,
42                  *,
43                  name: str | None = None,
44                  input_buffer_limit: int = 64 * 1024,
45                  **kwargs
46                  ) -> 'Connection':
47    """Create TCP connection
48
49    Argument `addr` specifies remote server listening address.
50
51    Argument `name` defines connection name available in property `info`.
52
53    Argument `input_buffer_limit` defines number of bytes in input buffer
54    that whill temporary pause data receiving. Once number of bytes
55    drops bellow `input_buffer_limit`, data receiving is resumed. If this
56    argument is ``0``, data receive pausing is disabled.
57
58    Additional arguments are passed directly to `asyncio.create_connection`.
59
60    """
61    loop = asyncio.get_running_loop()
62    create_protocol = functools.partial(Protocol, None, name,
63                                        input_buffer_limit)
64    _, protocol = await loop.create_connection(create_protocol,
65                                               addr.host, addr.port,
66                                               **kwargs)
67    return Connection(protocol)

Create TCP connection

Argument addr specifies remote server listening address.

Argument name defines connection name available in property info.

Argument input_buffer_limit defines the number of bytes in the input buffer that will temporarily pause data receiving. Once the number of bytes drops below input_buffer_limit, data receiving is resumed. If this argument is 0, data receive pausing is disabled.

Additional arguments are passed directly to asyncio.create_connection.

async def listen( connection_cb: Callable[[Connection], None | Awaitable[None]], addr: Address, *, name: str | None = None, bind_connections: bool = False, input_buffer_limit: int = 65536, **kwargs) -> Server:
 70async def listen(connection_cb: ConnectionCb,
 71                 addr: Address,
 72                 *,
 73                 name: str | None = None,
 74                 bind_connections: bool = False,
 75                 input_buffer_limit: int = 64 * 1024,
 76                 **kwargs
 77                 ) -> 'Server':
 78    """Create listening server
 79
 80    Argument `name` defines server name available in property `info`. This
 81    name is used for all incomming connections.
 82
 83    If `bind_connections` is ``True``, closing server will close all open
 84    incoming connections.
 85
 86    Argument `input_buffer_limit` is associated with newly created connections
 87    (see `connect`).
 88
 89    Additional arguments are passed directly to `asyncio.create_server`.
 90
 91    """
 92    server = Server()
 93    server._connection_cb = connection_cb
 94    server._bind_connections = bind_connections
 95    server._async_group = aio.Group()
 96    server._log = mlog
 97
 98    on_connection = functools.partial(server.async_group.spawn,
 99                                      server._on_connection)
100    create_protocol = functools.partial(Protocol, on_connection, name,
101                                        input_buffer_limit)
102
103    loop = asyncio.get_running_loop()
104    server._srv = await loop.create_server(create_protocol, addr.host,
105                                           addr.port, **kwargs)
106
107    server.async_group.spawn(aio.call_on_cancel, server._on_close)
108
109    try:
110        socknames = (socket.getsockname() for socket in server._srv.sockets)
111        addresses = [Address(*sockname[:2]) for sockname in socknames]
112        server._info = ServerInfo(name=name,
113                                  addresses=addresses)
114        server._log = _create_server_logger_adapter(server._info)
115
116    except Exception:
117        await aio.uncancellable(server.async_close())
118        raise
119
120    server._log.debug('listening for incomming connections')
121
122    return server

Create listening server

Argument name defines the server name available in property info. This name is used for all incoming connections.

If bind_connections is True, closing server will close all open incoming connections.

Argument input_buffer_limit is associated with newly created connections (see connect).

Additional arguments are passed directly to asyncio.create_server.

class Server(hat.aio.group.Resource):
125class Server(aio.Resource):
126    """TCP listening server
127
128    Closing server will cancel all running `connection_cb` coroutines.
129
130    """
131
132    @property
133    def async_group(self) -> aio.Group:
134        """Async group"""
135        return self._async_group
136
137    @property
138    def info(self) -> ServerInfo:
139        """Server info"""
140        return self._info
141
142    async def _on_close(self):
143        self._srv.close()
144
145        if self._bind_connections or sys.version_info[:2] < (3, 12):
146            await self._srv.wait_closed()
147
148    async def _on_connection(self, protocol):
149        self._log.debug('new incomming connection')
150
151        conn = Connection(protocol)
152
153        try:
154            await aio.call(self._connection_cb, conn)
155
156            if self._bind_connections:
157                await conn.wait_closing()
158
159            else:
160                conn = None
161
162        except Exception as e:
163            self._log.warning('connection callback error: %s', e, exc_info=e)
164
165        finally:
166            if conn:
167                await aio.uncancellable(conn.async_close())

TCP listening server

Closing server will cancel all running connection_cb coroutines.

async_group: hat.aio.group.Group
132    @property
133    def async_group(self) -> aio.Group:
134        """Async group"""
135        return self._async_group

Async group

info: ServerInfo
137    @property
138    def info(self) -> ServerInfo:
139        """Server info"""
140        return self._info

Server info

class Connection(hat.aio.group.Resource):
170class Connection(aio.Resource):
171    """TCP connection"""
172
173    def __init__(self, protocol: 'Protocol'):
174        self._protocol = protocol
175        self._async_group = aio.Group()
176        self._log = _create_connection_logger_adapter(protocol.info)
177
178        self._log.debug('connection established')
179
180        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
181        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
182                               self.close)
183
184    @property
185    def async_group(self) -> aio.Group:
186        """Async group"""
187        return self._async_group
188
189    @property
190    def info(self) -> ConnectionInfo:
191        """Connection info"""
192        return self._protocol.info
193
194    @property
195    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
196        """SSL Object"""
197        return self._protocol.ssl_object
198
199    async def write(self, data: util.Bytes):
200        """Write data
201
202        This coroutine will wait until `data` can be added to output buffer.
203
204        """
205        if not self.is_open:
206            raise ConnectionError()
207
208        await self._protocol.write(data)
209
210    async def drain(self):
211        """Drain output buffer"""
212        await self._protocol.drain()
213
214    async def read(self, n: int = -1) -> util.Bytes:
215        """Read up to `n` bytes
216
217        If EOF is detected and no new bytes are available, `ConnectionError`
218        is raised.
219
220        """
221        return await self._protocol.read(n)
222
223    async def readexactly(self, n: int) -> util.Bytes:
224        """Read exactly `n` bytes
225
226        If exact number of bytes could not be read, `ConnectionError` is
227        raised.
228
229        """
230        return await self._protocol.readexactly(n)
231
232    def clear_input_buffer(self) -> int:
233        """Clear input buffer
234
235        Returns number of bytes cleared from buffer.
236
237        """
238        return self._protocol.clear_input_buffer()

TCP connection

Connection(protocol: Protocol)
173    def __init__(self, protocol: 'Protocol'):
174        self._protocol = protocol
175        self._async_group = aio.Group()
176        self._log = _create_connection_logger_adapter(protocol.info)
177
178        self._log.debug('connection established')
179
180        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
181        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
182                               self.close)
async_group: hat.aio.group.Group
184    @property
185    def async_group(self) -> aio.Group:
186        """Async group"""
187        return self._async_group

Async group

info: ConnectionInfo
189    @property
190    def info(self) -> ConnectionInfo:
191        """Connection info"""
192        return self._protocol.info

Connection info

ssl_object: ssl.SSLObject | ssl.SSLSocket | None
194    @property
195    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
196        """SSL Object"""
197        return self._protocol.ssl_object

SSL Object

async def write(self, data: bytes | bytearray | memoryview):
199    async def write(self, data: util.Bytes):
200        """Write data
201
202        This coroutine will wait until `data` can be added to output buffer.
203
204        """
205        if not self.is_open:
206            raise ConnectionError()
207
208        await self._protocol.write(data)

Write data

This coroutine will wait until data can be added to output buffer.

async def drain(self):
210    async def drain(self):
211        """Drain output buffer"""
212        await self._protocol.drain()

Drain output buffer

async def read(self, n: int = -1) -> bytes | bytearray | memoryview:
214    async def read(self, n: int = -1) -> util.Bytes:
215        """Read up to `n` bytes
216
217        If EOF is detected and no new bytes are available, `ConnectionError`
218        is raised.
219
220        """
221        return await self._protocol.read(n)

Read up to n bytes

If EOF is detected and no new bytes are available, ConnectionError is raised.

async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
223    async def readexactly(self, n: int) -> util.Bytes:
224        """Read exactly `n` bytes
225
226        If exact number of bytes could not be read, `ConnectionError` is
227        raised.
228
229        """
230        return await self._protocol.readexactly(n)

Read exactly n bytes

If exact number of bytes could not be read, ConnectionError is raised.

def clear_input_buffer(self) -> int:
232    def clear_input_buffer(self) -> int:
233        """Clear input buffer
234
235        Returns number of bytes cleared from buffer.
236
237        """
238        return self._protocol.clear_input_buffer()

Clear input buffer

Returns number of bytes cleared from buffer.

class Protocol(asyncio.protocols.Protocol):
class Protocol(asyncio.Protocol):
    """Asyncio protocol implementation

    Received bytes are collected in an input buffer and handed to pending
    `read`/`readexactly` calls in the order in which the calls were made.
    While the transport has writing paused, `write` and `drain` calls wait
    on queued futures.

    """

    def __init__(self,
                 on_connected: typing.Callable[['Protocol'], None] | None,
                 name: str | None,
                 input_buffer_limit: int):
        self._on_connected = on_connected
        self._name = name
        self._input_buffer_limit = input_buffer_limit
        self._loop = asyncio.get_running_loop()
        self._input_buffer = util.BytesBuffer()
        self._transport = None
        # deque of (exact, n, future); None before connection_made and
        # after eof_received
        self._read_queue = None
        # both deques exist only while writing is paused
        self._write_queue = None
        self._drain_futures = None
        # deque between connection_made and connection_lost
        self._closed_futures = None
        self._info = None
        self._ssl_object = None
        self._log = mlog

    @property
    def info(self) -> ConnectionInfo:
        """Connection info (``None`` until `connection_made` sets it)"""
        return self._info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        """Transport's ``ssl_object`` extra info"""
        return self._ssl_object

    def connection_made(self, transport: asyncio.Transport):
        """Remember transport, build connection info and notify listener

        Any error raised while collecting transport info or calling the
        `on_connected` callback is logged and the transport is aborted.

        """
        self._transport = transport
        self._read_queue = collections.deque()
        self._closed_futures = collections.deque()

        try:
            sockname = transport.get_extra_info('sockname')
            peername = transport.get_extra_info('peername')
            self._info = ConnectionInfo(
                name=self._name,
                local_addr=Address(sockname[0], sockname[1]),
                remote_addr=Address(peername[0], peername[1]))

            self._log = _create_connection_logger_adapter(self._info)

            self._ssl_object = transport.get_extra_info('ssl_object')

            if self._on_connected:
                self._on_connected(self)

        except Exception as e:
            # report the failure instead of dropping it silently
            self._log.error('connection made error: %s', e, exc_info=e)
            transport.abort()

    def connection_lost(self, exc: Exception | None):
        """Release transport and settle all pending futures

        Pending reads are settled by `eof_received`, queued writes fail
        with `ConnectionError`, drain and close waiters are resolved.

        """
        self._transport = None
        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None
        closed_futures, self._closed_futures = self._closed_futures, None

        self.eof_received()

        while write_queue:
            _, future = write_queue.popleft()
            if not future.done():
                future.set_exception(ConnectionError())

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

        while closed_futures:
            future = closed_futures.popleft()
            if not future.done():
                future.set_result(None)

        self._log.debug('connection closed')

    def pause_writing(self):
        """Start queueing `write` and `drain` calls"""
        self._log.debug('pause writing')

        self._write_queue = collections.deque()
        self._drain_futures = collections.deque()

    def resume_writing(self):
        """Flush queued writes and release drain waiters

        If flushing pauses writing again, remaining writes and all drain
        waiters stay queued until the next `resume_writing`.

        """
        self._log.debug('resume writing')

        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None

        # transport.write may call pause_writing synchronously, which
        # installs fresh queues - stop flushing as soon as that happens
        while self._write_queue is None and write_queue:
            data, future = write_queue.popleft()
            if future.done():
                continue

            self._log.debug('writing %s bytes', len(data))
            self._transport.write(data)
            future.set_result(None)

        if self._write_queue is not None:
            # paused again (possibly by the very last flushed write):
            # keep older entries in front of the newly created queues and
            # do not release drain waiters yet
            write_queue.extend(self._write_queue)
            self._write_queue = write_queue

            drain_futures.extend(self._drain_futures)
            self._drain_futures = drain_futures

            return

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

    def data_received(self, data: util.Bytes):
        """Add received bytes to input buffer and serve pending reads"""
        self._log.debug('received %s bytes', len(data))

        self._input_buffer.add(data)
        self._process_input_buffer()

    def eof_received(self):
        """Settle all pending reads and disable further read queueing

        Pending reads that can be satisfied from the input buffer get their
        data; all others fail with `ConnectionError`.

        """
        self._log.debug('eof received')

        while self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if exact and n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            elif not exact and self._input_buffer:
                future.set_result(self._input_buffer.read(n))

            else:
                future.set_exception(ConnectionError())

        self._read_queue = None

    async def write(self, data: util.Bytes):
        """Write data or wait in queue while writing is paused

        Raises `ConnectionError` if there is no transport.

        """
        if self._transport is None:
            raise ConnectionError()

        if self._write_queue is None:
            self._log.debug('writing %s bytes', len(data))
            self._transport.write(data)
            return

        future = self._loop.create_future()
        self._write_queue.append((data, future))
        await future

    async def drain(self):
        """Wait until writing is no longer paused"""
        if self._drain_futures is None:
            return

        future = self._loop.create_future()
        self._drain_futures.append(future)
        await future

    async def read(self, n: int) -> util.Bytes:
        """Read up to `n` bytes

        Raises `ConnectionError` if no data is buffered and EOF was
        already received.

        """
        if n == 0:
            return b''

        # serve immediately only if no earlier read is still waiting
        if self._input_buffer and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((False, n, future))
        return await future

    async def readexactly(self, n: int) -> util.Bytes:
        """Read exactly `n` bytes

        Raises `ConnectionError` if not enough data is buffered and EOF
        was already received.

        """
        if n == 0:
            return b''

        # serve immediately only if no earlier read is still waiting
        if n <= len(self._input_buffer) and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((True, n, future))
        # a waiting reader re-evaluates (and may lift) the reading pause
        self._process_input_buffer()
        return await future

    def clear_input_buffer(self) -> int:
        """Clear input buffer and return number of cleared bytes"""
        count = self._input_buffer.clear()

        # transport is released in connection_lost - nothing to resume then
        if self._transport is not None:
            self._transport.resume_reading()

        return count

    async def async_close(self):
        """Close transport (if any) and wait for connection to be lost"""
        if self._transport is not None:
            self._transport.close()

        await self.wait_closed()

    async def wait_closed(self):
        """Wait until `connection_lost` is called

        Returns immediately if there is no active connection.

        """
        if self._closed_futures is None:
            return

        future = self._loop.create_future()
        self._closed_futures.append(future)
        await future

    def _on_read_future_done(self, future):
        # only cancelled reads need cleanup - completed ones were already
        # removed from the queue by whoever completed them
        if not self._read_queue:
            return

        if not future.cancelled():
            return

        # rotate through the queue once, dropping finished entries while
        # preserving order of the remaining ones
        for _ in range(len(self._read_queue)):
            i = self._read_queue.popleft()
            if not i[2].done():
                self._read_queue.append(i)

        self._process_input_buffer()

    def _process_input_buffer(self):
        # serve pending reads in order for as long as data is available
        while self._input_buffer and self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if not exact:
                future.set_result(self._input_buffer.read(n))

            elif n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            else:
                # exact read at head of queue blocks the ones behind it
                self._read_queue.appendleft((exact, n, future))
                break

        if not self._transport:
            return

        # pause receiving only if buffer is over the limit and nobody is
        # waiting for data
        pause = (self._input_buffer_limit > 0 and
                 len(self._input_buffer) > self._input_buffer_limit and
                 not self._read_queue)

        if pause:
            self._transport.pause_reading()

        else:
            self._transport.resume_reading()

Asyncio protocol implementation

Protocol( on_connected: Optional[Callable[[Protocol], NoneType]], name: str | None, input_buffer_limit: int)
244    def __init__(self,
245                 on_connected: typing.Callable[['Protocol'], None] | None,
246                 name: str | None,
247                 input_buffer_limit: int):
248        self._on_connected = on_connected
249        self._name = name
250        self._input_buffer_limit = input_buffer_limit
251        self._loop = asyncio.get_running_loop()
252        self._input_buffer = util.BytesBuffer()
253        self._transport = None
254        self._read_queue = None
255        self._write_queue = None
256        self._drain_futures = None
257        self._closed_futures = None
258        self._info = None
259        self._ssl_object = None
260        self._log = mlog
info: ConnectionInfo
262    @property
263    def info(self) -> ConnectionInfo:
264        return self._info
ssl_object: ssl.SSLObject | ssl.SSLSocket | None
266    @property
267    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
268        return self._ssl_object
def connection_made(self, transport: asyncio.transports.Transport):
270    def connection_made(self, transport: asyncio.Transport):
271        self._transport = transport
272        self._read_queue = collections.deque()
273        self._closed_futures = collections.deque()
274
275        try:
276            sockname = transport.get_extra_info('sockname')
277            peername = transport.get_extra_info('peername')
278            self._info = ConnectionInfo(
279                name=self._name,
280                local_addr=Address(sockname[0], sockname[1]),
281                remote_addr=Address(peername[0], peername[1]))
282
283            self._log = _create_connection_logger_adapter(self._info)
284
285            self._ssl_object = transport.get_extra_info('ssl_object')
286
287            if self._on_connected:
288                self._on_connected(self)
289
290        except Exception:
291            transport.abort()
292            return

Called when a connection is made.

The argument is the transport representing the pipe connection. To receive data, wait for data_received() calls. When the connection is closed, connection_lost() is called.

def connection_lost(self, exc: Exception | None):
294    def connection_lost(self, exc: Exception | None):
295        self._transport = None
296        write_queue, self._write_queue = self._write_queue, None
297        drain_futures, self._drain_futures = self._drain_futures, None
298        closed_futures, self._closed_futures = self._closed_futures, None
299
300        self.eof_received()
301
302        while write_queue:
303            _, future = write_queue.popleft()
304            if not future.done():
305                future.set_exception(ConnectionError())
306
307        while drain_futures:
308            future = drain_futures.popleft()
309            if not future.done():
310                future.set_result(None)
311
312        while closed_futures:
313            future = closed_futures.popleft()
314            if not future.done():
315                future.set_result(None)
316
317        self._log.debug('connection closed')

Called when the connection is lost or closed.

The argument is an exception object or None (the latter meaning a regular EOF is received or the connection was aborted or closed).

def pause_writing(self):
319    def pause_writing(self):
320        self._log.debug('pause writing')
321
322        self._write_queue = collections.deque()
323        self._drain_futures = collections.deque()

Called when the transport's buffer goes over the high-water mark.

Pause and resume calls are paired -- pause_writing() is called once when the buffer goes strictly over the high-water mark (even if subsequent writes increase the buffer size even more), and eventually resume_writing() is called once when the buffer size reaches the low-water mark.

Note that if the buffer size equals the high-water mark, pause_writing() is not called -- it must go strictly over. Conversely, resume_writing() is called when the buffer size is equal or lower than the low-water mark. These end conditions are important to ensure that things go as expected when either mark is zero.

NOTE: This is the only Protocol callback that is not called through EventLoop.call_soon() -- if it were, it would have no effect when it's most needed (when the app keeps writing without yielding until pause_writing() is called).

def resume_writing(self):
325    def resume_writing(self):
326        self._log.debug('resume writing')
327
328        write_queue, self._write_queue = self._write_queue, None
329        drain_futures, self._drain_futures = self._drain_futures, None
330
331        while self._write_queue is None and write_queue:
332            data, future = write_queue.popleft()
333            if future.done():
334                continue
335
336            self._log.debug('writing %s bytes', len(data))
337            self._transport.write(data)
338            future.set_result(None)
339
340        if write_queue:
341            write_queue.extend(self._write_queue)
342            self._write_queue = write_queue
343
344            drain_futures.extend(self._drain_futures)
345            self._drain_futures = drain_futures
346
347            return
348
349        while drain_futures:
350            future = drain_futures.popleft()
351            if not future.done():
352                future.set_result(None)

Called when the transport's buffer drains below the low-water mark.

See pause_writing() for details.

def data_received(self, data: bytes | bytearray | memoryview):
354    def data_received(self, data: util.Bytes):
355        self._log.debug('received %s bytes', len(data))
356
357        self._input_buffer.add(data)
358        self._process_input_buffer()

Called when some data is received.

The argument is a bytes object.

def eof_received(self):
360    def eof_received(self):
361        self._log.debug('eof received')
362
363        while self._read_queue:
364            exact, n, future = self._read_queue.popleft()
365            if future.done():
366                continue
367
368            if exact and n <= len(self._input_buffer):
369                future.set_result(self._input_buffer.read(n))
370
371            elif not exact and self._input_buffer:
372                future.set_result(self._input_buffer.read(n))
373
374            else:
375                future.set_exception(ConnectionError())
376
377        self._read_queue = None

Called when the other end calls write_eof() or equivalent.

If this returns a false value (including None), the transport will close itself. If it returns a true value, closing the transport is up to the protocol.

async def write(self, data: bytes | bytearray | memoryview):
379    async def write(self, data: util.Bytes):
380        if self._transport is None:
381            raise ConnectionError()
382
383        if self._write_queue is None:
384            self._log.debug('writing %s bytes', len(data))
385            self._transport.write(data)
386            return
387
388        future = self._loop.create_future()
389        self._write_queue.append((data, future))
390        await future
async def drain(self):
392    async def drain(self):
393        if self._drain_futures is None:
394            return
395
396        future = self._loop.create_future()
397        self._drain_futures.append(future)
398        await future
async def read(self, n: int) -> bytes | bytearray | memoryview:
400    async def read(self, n: int) -> util.Bytes:
401        if n == 0:
402            return b''
403
404        if self._input_buffer and not self._read_queue:
405            data = self._input_buffer.read(n)
406            self._process_input_buffer()
407            return data
408
409        if self._read_queue is None:
410            raise ConnectionError()
411
412        future = self._loop.create_future()
413        future.add_done_callback(self._on_read_future_done)
414        self._read_queue.append((False, n, future))
415        return await future
async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
417    async def readexactly(self, n: int) -> util.Bytes:
418        if n == 0:
419            return b''
420
421        if n <= len(self._input_buffer) and not self._read_queue:
422            data = self._input_buffer.read(n)
423            self._process_input_buffer()
424            return data
425
426        if self._read_queue is None:
427            raise ConnectionError()
428
429        future = self._loop.create_future()
430        future.add_done_callback(self._on_read_future_done)
431        self._read_queue.append((True, n, future))
432        self._process_input_buffer()
433        return await future
def clear_input_buffer(self) -> int:
435    def clear_input_buffer(self) -> int:
436        count = self._input_buffer.clear()
437        self._transport.resume_reading()
438        return count
async def async_close(self):
440    async def async_close(self):
441        if self._transport is not None:
442            self._transport.close()
443
444        await self.wait_closed()
async def wait_closed(self):
446    async def wait_closed(self):
447        if self._closed_futures is None:
448            return
449
450        future = self._loop.create_future()
451        self._closed_futures.append(future)
452        await future