hat.drivers.tcp

Asyncio TCP wrapper

  1"""Asyncio TCP wrapper"""
  2
  3import asyncio
  4import collections
  5import functools
  6import logging
  7import sys
  8import typing
  9
 10from hat import aio
 11from hat import util
 12
 13from hat.drivers import common
 14from hat.drivers import ssl
 15
 16
 17mlog: logging.Logger = logging.getLogger(__name__)
 18"""Module logger"""
 19
 20
 21class Address(typing.NamedTuple):
 22    host: str
 23    port: int
 24
 25
 26class ConnectionInfo(typing.NamedTuple):
 27    name: str | None
 28    local_addr: Address
 29    remote_addr: Address
 30
 31
 32class ServerInfo(typing.NamedTuple):
 33    name: str | None
 34    addresses: list[Address]
 35
 36
 37ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]
 38"""Connection callback"""
 39
 40
 41async def connect(addr: Address,
 42                  *,
 43                  name: str | None = None,
 44                  input_buffer_limit: int = 64 * 1024,
 45                  **kwargs
 46                  ) -> 'Connection':
 47    """Create TCP connection
 48
 49    Argument `addr` specifies remote server listening address.
 50
 51    Argument `name` defines connection name available in property `info`.
 52
 53    Argument `input_buffer_limit` defines number of bytes in input buffer
 54    that whill temporary pause data receiving. Once number of bytes
 55    drops bellow `input_buffer_limit`, data receiving is resumed. If this
 56    argument is ``0``, data receive pausing is disabled.
 57
 58    Additional arguments are passed directly to `asyncio.create_connection`.
 59
 60    """
 61    loop = asyncio.get_running_loop()
 62    create_protocol = functools.partial(Protocol, None, name,
 63                                        input_buffer_limit)
 64    _, protocol = await loop.create_connection(create_protocol,
 65                                               addr.host, addr.port,
 66                                               **kwargs)
 67    return Connection(protocol)
 68
 69
 70async def listen(connection_cb: ConnectionCb,
 71                 addr: Address,
 72                 *,
 73                 name: str | None = None,
 74                 bind_connections: bool = False,
 75                 input_buffer_limit: int = 64 * 1024,
 76                 **kwargs
 77                 ) -> 'Server':
 78    """Create listening server
 79
 80    Argument `name` defines server name available in property `info`. This
 81    name is used for all incomming connections.
 82
 83    If `bind_connections` is ``True``, closing server will close all open
 84    incoming connections.
 85
 86    Argument `input_buffer_limit` is associated with newly created connections
 87    (see `connect`).
 88
 89    Additional arguments are passed directly to `asyncio.create_server`.
 90
 91    """
 92    server = Server()
 93    server._connection_cb = connection_cb
 94    server._bind_connections = bind_connections
 95    server._async_group = aio.Group()
 96    server._log = _create_server_logger(name, None)
 97
 98    on_connection = functools.partial(server.async_group.spawn,
 99                                      server._on_connection)
100    create_protocol = functools.partial(Protocol, on_connection, name,
101                                        input_buffer_limit)
102
103    loop = asyncio.get_running_loop()
104    server._srv = await loop.create_server(create_protocol, addr.host,
105                                           addr.port, **kwargs)
106
107    server.async_group.spawn(aio.call_on_cancel, server._on_close)
108
109    try:
110        socknames = (socket.getsockname() for socket in server._srv.sockets)
111        addresses = [Address(*sockname[:2]) for sockname in socknames]
112        server._info = ServerInfo(name=name,
113                                  addresses=addresses)
114        server._log = _create_server_logger(name, server._info)
115
116    except Exception:
117        await aio.uncancellable(server.async_close())
118        raise
119
120    server._log.debug('listening for incomming connections')
121
122    return server
123
124
125class Server(aio.Resource):
126    """TCP listening server
127
128    Closing server will cancel all running `connection_cb` coroutines.
129
130    """
131
132    @property
133    def async_group(self) -> aio.Group:
134        """Async group"""
135        return self._async_group
136
137    @property
138    def info(self) -> ServerInfo:
139        """Server info"""
140        return self._info
141
142    async def _on_close(self):
143        self._srv.close()
144
145        if self._bind_connections or sys.version_info[:2] < (3, 12):
146            await self._srv.wait_closed()
147
148    async def _on_connection(self, protocol):
149        self._log.debug('new incomming connection')
150
151        conn = Connection(protocol)
152
153        try:
154            await aio.call(self._connection_cb, conn)
155
156            if self._bind_connections:
157                await conn.wait_closing()
158
159            else:
160                conn = None
161
162        except Exception as e:
163            self._log.warning('connection callback error: %s', e, exc_info=e)
164
165        finally:
166            if conn:
167                await aio.uncancellable(conn.async_close())
168
169
170class Connection(aio.Resource):
171    """TCP connection"""
172
173    def __init__(self, protocol: 'Protocol'):
174        self._protocol = protocol
175        self._async_group = aio.Group()
176        self._log = _create_connection_logger(protocol.info.name,
177                                              protocol.info)
178        self._comm_log = _CommunicationLogger(protocol.info.name,
179                                              protocol.info)
180
181        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
182        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
183                               self.close)
184
185        self._comm_log.log(common.CommLogAction.OPEN)
186
187    @property
188    def async_group(self) -> aio.Group:
189        """Async group"""
190        return self._async_group
191
192    @property
193    def info(self) -> ConnectionInfo:
194        """Connection info"""
195        return self._protocol.info
196
197    @property
198    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
199        """SSL Object"""
200        return self._protocol.ssl_object
201
202    async def write(self, data: util.Bytes):
203        """Write data
204
205        This coroutine will wait until `data` can be added to output buffer.
206
207        """
208        if not self.is_open:
209            raise ConnectionError()
210
211        await self._protocol.write(data)
212
213    async def drain(self):
214        """Drain output buffer"""
215        await self._protocol.drain()
216
217    async def read(self, n: int = -1) -> util.Bytes:
218        """Read up to `n` bytes
219
220        If EOF is detected and no new bytes are available, `ConnectionError`
221        is raised.
222
223        """
224        return await self._protocol.read(n)
225
226    async def readexactly(self, n: int) -> util.Bytes:
227        """Read exactly `n` bytes
228
229        If exact number of bytes could not be read, `ConnectionError` is
230        raised.
231
232        """
233        return await self._protocol.readexactly(n)
234
235    def clear_input_buffer(self) -> int:
236        """Clear input buffer
237
238        Returns number of bytes cleared from buffer.
239
240        """
241        return self._protocol.clear_input_buffer()
242
243
244class Protocol(asyncio.Protocol):
245    """Asyncio protocol implementation"""
246
247    def __init__(self,
248                 on_connected: typing.Callable[['Protocol'], None] | None,
249                 name: str | None,
250                 input_buffer_limit: int):
251        self._on_connected = on_connected
252        self._name = name
253        self._input_buffer_limit = input_buffer_limit
254        self._loop = asyncio.get_running_loop()
255        self._input_buffer = util.BytesBuffer()
256        self._transport = None
257        self._read_queue = None
258        self._write_queue = None
259        self._drain_futures = None
260        self._closed_futures = None
261        self._info = None
262        self._ssl_object = None
263        self._log = _create_connection_logger(name, None)
264        self._comm_log = _CommunicationLogger(name, None)
265
266    @property
267    def info(self) -> ConnectionInfo:
268        return self._info
269
270    @property
271    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
272        return self._ssl_object
273
274    def connection_made(self, transport: asyncio.Transport):
275        self._transport = transport
276        self._read_queue = collections.deque()
277        self._closed_futures = collections.deque()
278
279        try:
280            sockname = transport.get_extra_info('sockname')
281            peername = transport.get_extra_info('peername')
282            self._info = ConnectionInfo(
283                name=self._name,
284                local_addr=Address(sockname[0], sockname[1]),
285                remote_addr=Address(peername[0], peername[1]))
286
287            self._log = _create_connection_logger(self._name, self._info)
288            self._comm_log = _CommunicationLogger(self._name, self._info)
289
290            self._ssl_object = transport.get_extra_info('ssl_object')
291
292            if self._on_connected:
293                self._on_connected(self)
294
295        except Exception:
296            transport.abort()
297            return
298
299    def connection_lost(self, exc: Exception | None):
300        self._transport = None
301        write_queue, self._write_queue = self._write_queue, None
302        drain_futures, self._drain_futures = self._drain_futures, None
303        closed_futures, self._closed_futures = self._closed_futures, None
304
305        self.eof_received()
306
307        while write_queue:
308            _, future = write_queue.popleft()
309            if not future.done():
310                future.set_exception(ConnectionError())
311
312        while drain_futures:
313            future = drain_futures.popleft()
314            if not future.done():
315                future.set_result(None)
316
317        while closed_futures:
318            future = closed_futures.popleft()
319            if not future.done():
320                future.set_result(None)
321
322        self._comm_log.log(common.CommLogAction.CLOSE)
323
324    def pause_writing(self):
325        self._log.debug('pause writing')
326
327        self._write_queue = collections.deque()
328        self._drain_futures = collections.deque()
329
330    def resume_writing(self):
331        self._log.debug('resume writing')
332
333        write_queue, self._write_queue = self._write_queue, None
334        drain_futures, self._drain_futures = self._drain_futures, None
335
336        while self._write_queue is None and write_queue:
337            data, future = write_queue.popleft()
338            if future.done():
339                continue
340
341            self._comm_log.log(common.CommLogAction.SEND, data)
342
343            self._transport.write(data)
344            future.set_result(None)
345
346        if write_queue:
347            write_queue.extend(self._write_queue)
348            self._write_queue = write_queue
349
350            drain_futures.extend(self._drain_futures)
351            self._drain_futures = drain_futures
352
353            return
354
355        while drain_futures:
356            future = drain_futures.popleft()
357            if not future.done():
358                future.set_result(None)
359
360    def data_received(self, data: util.Bytes):
361        self._comm_log.log(common.CommLogAction.RECEIVE, data)
362
363        self._input_buffer.add(data)
364        self._process_input_buffer()
365
366    def eof_received(self):
367        self._log.debug('eof received')
368
369        while self._read_queue:
370            exact, n, future = self._read_queue.popleft()
371            if future.done():
372                continue
373
374            if exact and n <= len(self._input_buffer):
375                future.set_result(self._input_buffer.read(n))
376
377            elif not exact and self._input_buffer:
378                future.set_result(self._input_buffer.read(n))
379
380            else:
381                future.set_exception(ConnectionError())
382
383        self._read_queue = None
384
385    async def write(self, data: util.Bytes):
386        if self._transport is None:
387            raise ConnectionError()
388
389        if self._write_queue is None:
390            self._comm_log.log(common.CommLogAction.SEND, data)
391
392            self._transport.write(data)
393            return
394
395        future = self._loop.create_future()
396        self._write_queue.append((data, future))
397        await future
398
399    async def drain(self):
400        if self._drain_futures is None:
401            return
402
403        future = self._loop.create_future()
404        self._drain_futures.append(future)
405        await future
406
407    async def read(self, n: int) -> util.Bytes:
408        if n == 0:
409            return b''
410
411        if self._input_buffer and not self._read_queue:
412            data = self._input_buffer.read(n)
413            self._process_input_buffer()
414            return data
415
416        if self._read_queue is None:
417            raise ConnectionError()
418
419        future = self._loop.create_future()
420        future.add_done_callback(self._on_read_future_done)
421        self._read_queue.append((False, n, future))
422        return await future
423
424    async def readexactly(self, n: int) -> util.Bytes:
425        if n == 0:
426            return b''
427
428        if n <= len(self._input_buffer) and not self._read_queue:
429            data = self._input_buffer.read(n)
430            self._process_input_buffer()
431            return data
432
433        if self._read_queue is None:
434            raise ConnectionError()
435
436        future = self._loop.create_future()
437        future.add_done_callback(self._on_read_future_done)
438        self._read_queue.append((True, n, future))
439        self._process_input_buffer()
440        return await future
441
442    def clear_input_buffer(self) -> int:
443        count = self._input_buffer.clear()
444        self._transport.resume_reading()
445        return count
446
447    async def async_close(self):
448        if self._transport is not None:
449            self._transport.close()
450
451        await self.wait_closed()
452
453    async def wait_closed(self):
454        if self._closed_futures is None:
455            return
456
457        future = self._loop.create_future()
458        self._closed_futures.append(future)
459        await future
460
461    def _on_read_future_done(self, future):
462        if not self._read_queue:
463            return
464
465        if not future.cancelled():
466            return
467
468        for _ in range(len(self._read_queue)):
469            i = self._read_queue.popleft()
470            if not i[2].done():
471                self._read_queue.append(i)
472
473        self._process_input_buffer()
474
475    def _process_input_buffer(self):
476        while self._input_buffer and self._read_queue:
477            exact, n, future = self._read_queue.popleft()
478            if future.done():
479                continue
480
481            if not exact:
482                future.set_result(self._input_buffer.read(n))
483
484            elif n <= len(self._input_buffer):
485                future.set_result(self._input_buffer.read(n))
486
487            else:
488                self._read_queue.appendleft((exact, n, future))
489                break
490
491        if not self._transport:
492            return
493
494        pause = (self._input_buffer_limit > 0 and
495                 len(self._input_buffer) > self._input_buffer_limit and
496                 not self._read_queue)
497
498        if pause:
499            self._transport.pause_reading()
500
501        else:
502            self._transport.resume_reading()
503
504
505def _create_server_logger(name, info):
506    extra = {'meta': {'type': 'TcpServer',
507                      'name': name}}
508
509    if info is not None:
510        extra['meta']['addresses'] = [{'host': addr.host,
511                                       'port': addr.port}
512                                      for addr in info.addresses]
513
514    return logging.LoggerAdapter(mlog, extra)
515
516
517def _create_connection_logger(name, info):
518    extra = {'meta': {'type': 'TcpConnection',
519                      'name': name}}
520
521    if info is not None:
522        extra['meta']['local_addr'] = {'host': info.local_addr.host,
523                                       'port': info.local_addr.port}
524        extra['meta']['remote_addr'] = {'host': info.remote_addr.host,
525                                        'port': info.remote_addr.port}
526
527    return logging.LoggerAdapter(mlog, extra)
528
529
530class _CommunicationLogger:
531
532    def __init__(self,
533                 name: str | None,
534                 info: ConnectionInfo | None):
535        extra = {'meta': {'type': 'TcpConnection',
536                          'communication': True,
537                          'name': name}}
538
539        if info is not None:
540            extra['meta']['local_addr'] = {'host': info.local_addr.host,
541                                           'port': info.local_addr.port}
542            extra['meta']['remote_addr'] = {'host': info.remote_addr.host,
543                                            'port': info.remote_addr.port}
544
545        self._log = logging.LoggerAdapter(mlog, extra)
546
547    def log(self,
548            action: common.CommLogAction,
549            data: util.Bytes | None = None):
550        if not self._log.isEnabledFor(logging.DEBUG):
551            return
552
553        if data is None:
554            self._log.debug(action.value, stacklevel=2)
555
556        else:
557            self._log.debug('%s (%s)', action.value, data.hex(' '),
558                            stacklevel=2)
mlog: logging.Logger = <Logger hat.drivers.tcp (WARNING)>

Module logger

class Address(typing.NamedTuple):
22class Address(typing.NamedTuple):
23    host: str
24    port: int

Address(host, port)

Address(host: str, port: int)

Create new instance of Address(host, port)

host: str

Alias for field number 0

port: int

Alias for field number 1

class ConnectionInfo(typing.NamedTuple):
27class ConnectionInfo(typing.NamedTuple):
28    name: str | None
29    local_addr: Address
30    remote_addr: Address

ConnectionInfo(name, local_addr, remote_addr)

ConnectionInfo( name: str | None, local_addr: Address, remote_addr: Address)

Create new instance of ConnectionInfo(name, local_addr, remote_addr)

name: str | None

Alias for field number 0

local_addr: Address

Alias for field number 1

remote_addr: Address

Alias for field number 2

class ServerInfo(typing.NamedTuple):
33class ServerInfo(typing.NamedTuple):
34    name: str | None
35    addresses: list[Address]

ServerInfo(name, addresses)

ServerInfo(name: str | None, addresses: list[Address])

Create new instance of ServerInfo(name, addresses)

name: str | None

Alias for field number 0

addresses: list[Address]

Alias for field number 1

ConnectionCb: TypeAlias = Callable[[ForwardRef('Connection')], None | Awaitable[None]]

Connection callback

async def connect( addr: Address, *, name: str | None = None, input_buffer_limit: int = 65536, **kwargs) -> Connection:
42async def connect(addr: Address,
43                  *,
44                  name: str | None = None,
45                  input_buffer_limit: int = 64 * 1024,
46                  **kwargs
47                  ) -> 'Connection':
48    """Create TCP connection
49
50    Argument `addr` specifies remote server listening address.
51
52    Argument `name` defines connection name available in property `info`.
53
54    Argument `input_buffer_limit` defines number of bytes in input buffer
55    that whill temporary pause data receiving. Once number of bytes
56    drops bellow `input_buffer_limit`, data receiving is resumed. If this
57    argument is ``0``, data receive pausing is disabled.
58
59    Additional arguments are passed directly to `asyncio.create_connection`.
60
61    """
62    loop = asyncio.get_running_loop()
63    create_protocol = functools.partial(Protocol, None, name,
64                                        input_buffer_limit)
65    _, protocol = await loop.create_connection(create_protocol,
66                                               addr.host, addr.port,
67                                               **kwargs)
68    return Connection(protocol)

Create TCP connection

Argument addr specifies remote server listening address.

Argument name defines the connection name, available in property info.

Argument input_buffer_limit defines the number of bytes in the input buffer above which data receiving is temporarily paused. Once the number of buffered bytes no longer exceeds input_buffer_limit, data receiving is resumed. If this argument is 0, data receive pausing is disabled.

Additional arguments are passed directly to asyncio.create_connection.

async def listen( connection_cb: Callable[[Connection], None | Awaitable[None]], addr: Address, *, name: str | None = None, bind_connections: bool = False, input_buffer_limit: int = 65536, **kwargs) -> Server:
 71async def listen(connection_cb: ConnectionCb,
 72                 addr: Address,
 73                 *,
 74                 name: str | None = None,
 75                 bind_connections: bool = False,
 76                 input_buffer_limit: int = 64 * 1024,
 77                 **kwargs
 78                 ) -> 'Server':
 79    """Create listening server
 80
 81    Argument `name` defines server name available in property `info`. This
 82    name is used for all incomming connections.
 83
 84    If `bind_connections` is ``True``, closing server will close all open
 85    incoming connections.
 86
 87    Argument `input_buffer_limit` is associated with newly created connections
 88    (see `connect`).
 89
 90    Additional arguments are passed directly to `asyncio.create_server`.
 91
 92    """
 93    server = Server()
 94    server._connection_cb = connection_cb
 95    server._bind_connections = bind_connections
 96    server._async_group = aio.Group()
 97    server._log = _create_server_logger(name, None)
 98
 99    on_connection = functools.partial(server.async_group.spawn,
100                                      server._on_connection)
101    create_protocol = functools.partial(Protocol, on_connection, name,
102                                        input_buffer_limit)
103
104    loop = asyncio.get_running_loop()
105    server._srv = await loop.create_server(create_protocol, addr.host,
106                                           addr.port, **kwargs)
107
108    server.async_group.spawn(aio.call_on_cancel, server._on_close)
109
110    try:
111        socknames = (socket.getsockname() for socket in server._srv.sockets)
112        addresses = [Address(*sockname[:2]) for sockname in socknames]
113        server._info = ServerInfo(name=name,
114                                  addresses=addresses)
115        server._log = _create_server_logger(name, server._info)
116
117    except Exception:
118        await aio.uncancellable(server.async_close())
119        raise
120
121    server._log.debug('listening for incomming connections')
122
123    return server

Create listening server

Argument name defines the server name, available in property info. This name is used for all incoming connections.

If bind_connections is True, closing the server will close all open incoming connections.

Argument input_buffer_limit is associated with newly created connections (see connect).

Additional arguments are passed directly to asyncio.create_server.

class Server(hat.aio.group.Resource):
126class Server(aio.Resource):
127    """TCP listening server
128
129    Closing server will cancel all running `connection_cb` coroutines.
130
131    """
132
133    @property
134    def async_group(self) -> aio.Group:
135        """Async group"""
136        return self._async_group
137
138    @property
139    def info(self) -> ServerInfo:
140        """Server info"""
141        return self._info
142
143    async def _on_close(self):
144        self._srv.close()
145
146        if self._bind_connections or sys.version_info[:2] < (3, 12):
147            await self._srv.wait_closed()
148
149    async def _on_connection(self, protocol):
150        self._log.debug('new incomming connection')
151
152        conn = Connection(protocol)
153
154        try:
155            await aio.call(self._connection_cb, conn)
156
157            if self._bind_connections:
158                await conn.wait_closing()
159
160            else:
161                conn = None
162
163        except Exception as e:
164            self._log.warning('connection callback error: %s', e, exc_info=e)
165
166        finally:
167            if conn:
168                await aio.uncancellable(conn.async_close())

TCP listening server

Closing the server will cancel all running connection_cb coroutines.

async_group: hat.aio.group.Group
133    @property
134    def async_group(self) -> aio.Group:
135        """Async group"""
136        return self._async_group

Async group

info: ServerInfo
138    @property
139    def info(self) -> ServerInfo:
140        """Server info"""
141        return self._info

Server info

class Connection(hat.aio.group.Resource):
171class Connection(aio.Resource):
172    """TCP connection"""
173
174    def __init__(self, protocol: 'Protocol'):
175        self._protocol = protocol
176        self._async_group = aio.Group()
177        self._log = _create_connection_logger(protocol.info.name,
178                                              protocol.info)
179        self._comm_log = _CommunicationLogger(protocol.info.name,
180                                              protocol.info)
181
182        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
183        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
184                               self.close)
185
186        self._comm_log.log(common.CommLogAction.OPEN)
187
188    @property
189    def async_group(self) -> aio.Group:
190        """Async group"""
191        return self._async_group
192
193    @property
194    def info(self) -> ConnectionInfo:
195        """Connection info"""
196        return self._protocol.info
197
198    @property
199    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
200        """SSL Object"""
201        return self._protocol.ssl_object
202
203    async def write(self, data: util.Bytes):
204        """Write data
205
206        This coroutine will wait until `data` can be added to output buffer.
207
208        """
209        if not self.is_open:
210            raise ConnectionError()
211
212        await self._protocol.write(data)
213
214    async def drain(self):
215        """Drain output buffer"""
216        await self._protocol.drain()
217
218    async def read(self, n: int = -1) -> util.Bytes:
219        """Read up to `n` bytes
220
221        If EOF is detected and no new bytes are available, `ConnectionError`
222        is raised.
223
224        """
225        return await self._protocol.read(n)
226
227    async def readexactly(self, n: int) -> util.Bytes:
228        """Read exactly `n` bytes
229
230        If exact number of bytes could not be read, `ConnectionError` is
231        raised.
232
233        """
234        return await self._protocol.readexactly(n)
235
236    def clear_input_buffer(self) -> int:
237        """Clear input buffer
238
239        Returns number of bytes cleared from buffer.
240
241        """
242        return self._protocol.clear_input_buffer()

TCP connection

Connection(protocol: Protocol)
174    def __init__(self, protocol: 'Protocol'):
175        self._protocol = protocol
176        self._async_group = aio.Group()
177        self._log = _create_connection_logger(protocol.info.name,
178                                              protocol.info)
179        self._comm_log = _CommunicationLogger(protocol.info.name,
180                                              protocol.info)
181
182        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
183        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
184                               self.close)
185
186        self._comm_log.log(common.CommLogAction.OPEN)
async_group: hat.aio.group.Group
188    @property
189    def async_group(self) -> aio.Group:
190        """Async group"""
191        return self._async_group

Async group

info: ConnectionInfo
193    @property
194    def info(self) -> ConnectionInfo:
195        """Connection info"""
196        return self._protocol.info

Connection info

ssl_object: ssl.SSLObject | ssl.SSLSocket | None
198    @property
199    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
200        """SSL Object"""
201        return self._protocol.ssl_object

SSL Object

async def write(self, data: bytes | bytearray | memoryview):
203    async def write(self, data: util.Bytes):
204        """Write data
205
206        This coroutine will wait until `data` can be added to output buffer.
207
208        """
209        if not self.is_open:
210            raise ConnectionError()
211
212        await self._protocol.write(data)

Write data

This coroutine will wait until data can be added to the output buffer.

async def drain(self):
214    async def drain(self):
215        """Drain output buffer"""
216        await self._protocol.drain()

Drain output buffer

async def read(self, n: int = -1) -> bytes | bytearray | memoryview:
218    async def read(self, n: int = -1) -> util.Bytes:
219        """Read up to `n` bytes
220
221        If EOF is detected and no new bytes are available, `ConnectionError`
222        is raised.
223
224        """
225        return await self._protocol.read(n)

Read up to n bytes

If EOF is detected and no new bytes are available, ConnectionError is raised.

async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
227    async def readexactly(self, n: int) -> util.Bytes:
228        """Read exactly `n` bytes
229
230        If exact number of bytes could not be read, `ConnectionError` is
231        raised.
232
233        """
234        return await self._protocol.readexactly(n)

Read exactly n bytes

If the exact number of bytes could not be read, ConnectionError is raised.

def clear_input_buffer(self) -> int:
236    def clear_input_buffer(self) -> int:
237        """Clear input buffer
238
239        Returns number of bytes cleared from buffer.
240
241        """
242        return self._protocol.clear_input_buffer()

Clear input buffer

Returns the number of bytes cleared from the buffer.

class Protocol(asyncio.protocols.Protocol):
class Protocol(asyncio.Protocol):
    """Asyncio protocol implementation

    Received data is buffered and served to `read`/`readexactly` requests
    in FIFO order. Transport reading is paused while more than
    `input_buffer_limit` bytes are buffered and no read is pending.
    While the transport has paused writing, `write` calls are queued and
    flushed by `resume_writing`.

    """

    def __init__(self,
                 on_connected: typing.Callable[['Protocol'], None] | None,
                 name: str | None,
                 input_buffer_limit: int):
        """Create protocol

        Argument `on_connected` (if set) is called with this protocol from
        `connection_made`. Argument `name` is used for `info` and logging.
        Argument `input_buffer_limit` is the buffered byte count above
        which reading is paused; ``0`` (or less) disables pausing.

        Must be called while an event loop is running.

        """
        self._on_connected = on_connected
        self._name = name
        self._input_buffer_limit = input_buffer_limit
        self._loop = asyncio.get_running_loop()
        self._input_buffer = util.BytesBuffer()
        self._transport = None
        # deque of (exact, n, future) read requests; None before
        # connection_made and after EOF
        self._read_queue = None
        # deque of (data, future); not None only while writing is paused
        self._write_queue = None
        # `drain` waiters; not None only while writing is paused
        self._drain_futures = None
        # `wait_closed` waiters; None before connection_made and after
        # connection_lost
        self._closed_futures = None
        self._info = None
        self._ssl_object = None
        self._log = _create_connection_logger(name, None)
        self._comm_log = _CommunicationLogger(name, None)

    @property
    def info(self) -> ConnectionInfo:
        """Connection info (``None`` before `connection_made`)"""
        return self._info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        """SSL object reported by the transport (``None`` if none)"""
        return self._ssl_object

    def connection_made(self, transport: asyncio.Transport):
        """Initialize connection state and call `on_connected`

        Any exception raised while collecting connection info or by
        `on_connected` is logged and the transport is aborted.

        """
        self._transport = transport
        self._read_queue = collections.deque()
        self._closed_futures = collections.deque()

        try:
            sockname = transport.get_extra_info('sockname')
            peername = transport.get_extra_info('peername')
            self._info = ConnectionInfo(
                name=self._name,
                local_addr=Address(sockname[0], sockname[1]),
                remote_addr=Address(peername[0], peername[1]))

            self._log = _create_connection_logger(self._name, self._info)
            self._comm_log = _CommunicationLogger(self._name, self._info)

            self._ssl_object = transport.get_extra_info('ssl_object')

            if self._on_connected:
                self._on_connected(self)

        except Exception as e:
            # without logging, a failing `on_connected` callback would only
            # show up as an unexplained connection abort
            mlog.error('error on connection made: %s', e, exc_info=e)
            transport.abort()

    def connection_lost(self, exc: Exception | None):
        """Release all waiters once the connection is gone

        Pending reads are served or failed (via `eof_received`), queued
        writes fail with `ConnectionError`, and `drain` and `wait_closed`
        waiters are resolved. Argument `exc` is not used.

        """
        self._transport = None
        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None
        closed_futures, self._closed_futures = self._closed_futures, None

        self.eof_received()

        while write_queue:
            _, future = write_queue.popleft()
            if not future.done():
                future.set_exception(ConnectionError())

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

        while closed_futures:
            future = closed_futures.popleft()
            if not future.done():
                future.set_result(None)

        self._comm_log.log(common.CommLogAction.CLOSE)

    def pause_writing(self):
        """Start queueing writes and `drain` waiters"""
        self._log.debug('pause writing')

        self._write_queue = collections.deque()
        self._drain_futures = collections.deque()

    def resume_writing(self):
        """Flush writes queued while paused and release `drain` waiters

        Queued data is passed to the transport in order. If a transport
        write pauses writing again, the remaining writes and all `drain`
        waiters stay queued until the next `resume_writing`.

        """
        self._log.debug('resume writing')

        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None

        # `transport.write` may call `pause_writing` synchronously, which
        # sets `self._write_queue` to a fresh deque - stop flushing then
        while self._write_queue is None and write_queue:
            data, future = write_queue.popleft()
            if future.done():
                continue

            self._comm_log.log(common.CommLogAction.SEND, data)

            self._transport.write(data)
            future.set_result(None)

        if self._write_queue is not None:
            # paused again (possibly by the very last queued write): keep
            # the remaining writes in order and keep `drain` waiters
            # blocked until writing is resumed
            write_queue.extend(self._write_queue)
            self._write_queue = write_queue

            drain_futures.extend(self._drain_futures)
            self._drain_futures = drain_futures

            return

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

    def data_received(self, data: util.Bytes):
        """Buffer received data and serve pending reads"""
        self._comm_log.log(common.CommLogAction.RECEIVE, data)

        self._input_buffer.add(data)
        self._process_input_buffer()

    def eof_received(self):
        """Complete all pending reads and stop accepting new ones

        Pending requests are served from the buffer when possible,
        otherwise failed with `ConnectionError`. Returns ``None``, so the
        transport closes itself.

        """
        self._log.debug('eof received')

        while self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if exact and n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            elif not exact and self._input_buffer:
                future.set_result(self._input_buffer.read(n))

            else:
                future.set_exception(ConnectionError())

        self._read_queue = None

    async def write(self, data: util.Bytes):
        """Write data, waiting in queue while writing is paused

        Raises:
            ConnectionError: if not connected, or if the connection is
                lost while the write is queued

        """
        if self._transport is None:
            raise ConnectionError()

        if self._write_queue is None:
            self._comm_log.log(common.CommLogAction.SEND, data)

            self._transport.write(data)
            return

        future = self._loop.create_future()
        self._write_queue.append((data, future))
        await future

    async def drain(self):
        """Wait until writing is resumed (immediately if not paused)"""
        if self._drain_futures is None:
            return

        future = self._loop.create_future()
        self._drain_futures.append(future)
        await future

    async def read(self, n: int) -> util.Bytes:
        """Read up to `n` bytes

        Raises:
            ConnectionError: if EOF was received and the buffer is empty

        """
        if n == 0:
            return b''

        if self._input_buffer and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((False, n, future))
        return await future

    async def readexactly(self, n: int) -> util.Bytes:
        """Read exactly `n` bytes

        Raises:
            ConnectionError: if EOF makes `n` bytes unavailable

        """
        if n == 0:
            return b''

        if n <= len(self._input_buffer) and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((True, n, future))
        # re-evaluate pause state: with a read pending, reading must resume
        # even if the buffer is over its limit
        self._process_input_buffer()
        return await future

    def clear_input_buffer(self) -> int:
        """Discard all buffered input data

        Returns number of bytes cleared. Safe to call after the connection
        is lost.

        """
        count = self._input_buffer.clear()

        # reading may have been paused because the buffer was over its limit
        if self._transport is not None:
            self._transport.resume_reading()

        return count

    async def async_close(self):
        """Close transport (if connected) and wait for connection loss"""
        if self._transport is not None:
            self._transport.close()

        await self.wait_closed()

    async def wait_closed(self):
        """Wait until `connection_lost` (immediately if already lost)"""
        if self._closed_futures is None:
            return

        future = self._loop.create_future()
        self._closed_futures.append(future)
        await future

    def _on_read_future_done(self, future):
        # when a queued read is cancelled, drop finished requests so later
        # reads are not blocked behind it, then serve them
        if not self._read_queue:
            return

        if not future.cancelled():
            return

        for _ in range(len(self._read_queue)):
            i = self._read_queue.popleft()
            if not i[2].done():
                self._read_queue.append(i)

        self._process_input_buffer()

    def _process_input_buffer(self):
        # serve queued reads in FIFO order, stopping at the first
        # `readexactly` that cannot be satisfied yet
        while self._input_buffer and self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if not exact:
                future.set_result(self._input_buffer.read(n))

            elif n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            else:
                self._read_queue.appendleft((exact, n, future))
                break

        if not self._transport:
            return

        # pause reading only when over the limit and nobody is waiting
        pause = (self._input_buffer_limit > 0 and
                 len(self._input_buffer) > self._input_buffer_limit and
                 not self._read_queue)

        if pause:
            self._transport.pause_reading()

        else:
            self._transport.resume_reading()

Asyncio protocol implementation

Protocol( on_connected: Optional[Callable[[Protocol], NoneType]], name: str | None, input_buffer_limit: int)
def __init__(self,
             on_connected: typing.Callable[['Protocol'], None] | None,
             name: str | None,
             input_buffer_limit: int):
    """Create protocol

    Argument `on_connected` (if set) is called with this protocol from
    `connection_made`. Argument `name` is used for `info` and logging.
    Argument `input_buffer_limit` is the buffered byte count above which
    reading is paused; ``0`` disables pausing.

    Must be called while an event loop is running.

    """
    # state populated later by connection_made / pause_writing
    self._transport = None
    self._read_queue = None
    self._write_queue = None
    self._drain_futures = None
    self._closed_futures = None
    self._info = None
    self._ssl_object = None

    # constructor arguments
    self._on_connected = on_connected
    self._name = name
    self._input_buffer_limit = input_buffer_limit

    self._loop = asyncio.get_running_loop()
    self._input_buffer = util.BytesBuffer()
    self._log = _create_connection_logger(name, None)
    self._comm_log = _CommunicationLogger(name, None)
info: ConnectionInfo
@property
def info(self) -> ConnectionInfo:
    """Connection info

    ``None`` until `connection_made` has been called.

    """
    current = self._info
    return current
ssl_object: ssl.SSLObject | ssl.SSLSocket | None
@property
def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
    """SSL object reported by the transport

    ``None`` before `connection_made` or when the transport has none.

    """
    current = self._ssl_object
    return current
def connection_made(self, transport: asyncio.transports.Transport):
def connection_made(self, transport: asyncio.Transport):
    """Initialize connection state and call `on_connected`

    Any exception raised while collecting connection info or by
    `on_connected` is logged and the transport is aborted.

    """
    self._transport = transport
    self._read_queue = collections.deque()
    self._closed_futures = collections.deque()

    try:
        sockname = transport.get_extra_info('sockname')
        peername = transport.get_extra_info('peername')
        self._info = ConnectionInfo(
            name=self._name,
            local_addr=Address(sockname[0], sockname[1]),
            remote_addr=Address(peername[0], peername[1]))

        self._log = _create_connection_logger(self._name, self._info)
        self._comm_log = _CommunicationLogger(self._name, self._info)

        self._ssl_object = transport.get_extra_info('ssl_object')

        if self._on_connected:
            self._on_connected(self)

    except Exception as e:
        # without logging, a failing `on_connected` callback would only
        # show up as an unexplained connection abort
        mlog.error('error on connection made: %s', e, exc_info=e)
        transport.abort()

Called when a connection is made.

The argument is the transport representing the pipe connection. To receive data, wait for data_received() calls. When the connection is closed, connection_lost() is called.

def connection_lost(self, exc: Exception | None):
def connection_lost(self, exc: Exception | None):
    """Release all waiters once the connection is gone

    Pending reads are served or failed (via `eof_received`), queued writes
    fail with `ConnectionError`, and `drain` and `wait_closed` waiters are
    resolved. Argument `exc` is not used.

    """
    self._transport = None

    pending_writes = self._write_queue
    drain_waiters = self._drain_futures
    close_waiters = self._closed_futures
    self._write_queue = self._drain_futures = self._closed_futures = None

    self.eof_received()

    for _, future in pending_writes or ():
        if not future.done():
            future.set_exception(ConnectionError())

    # drain waiters are released before wait_closed waiters
    for future in (*(drain_waiters or ()), *(close_waiters or ())):
        if not future.done():
            future.set_result(None)

    self._comm_log.log(common.CommLogAction.CLOSE)

Called when the connection is lost or closed.

The argument is an exception object or None (the latter meaning a regular EOF is received or the connection was aborted or closed).

def pause_writing(self):
def pause_writing(self):
    """Start queueing writes and `drain` waiters until resume_writing"""
    self._log.debug('pause writing')

    self._write_queue, self._drain_futures = (collections.deque(),
                                              collections.deque())

Called when the transport's buffer goes over the high-water mark.

Pause and resume calls are paired -- pause_writing() is called once when the buffer goes strictly over the high-water mark (even if subsequent writes increases the buffer size even more), and eventually resume_writing() is called once when the buffer size reaches the low-water mark.

Note that if the buffer size equals the high-water mark, pause_writing() is not called -- it must go strictly over. Conversely, resume_writing() is called when the buffer size is equal or lower than the low-water mark. These end conditions are important to ensure that things go as expected when either mark is zero.

NOTE: This is the only Protocol callback that is not called through EventLoop.call_soon() -- if it were, it would have no effect when it's most needed (when the app keeps writing without yielding until pause_writing() is called).

def resume_writing(self):
def resume_writing(self):
    """Flush writes queued while paused and release `drain` waiters

    Queued data is passed to the transport in order. If a transport write
    pauses writing again, the remaining writes and all `drain` waiters
    stay queued until the next `resume_writing`.

    """
    self._log.debug('resume writing')

    write_queue, self._write_queue = self._write_queue, None
    drain_futures, self._drain_futures = self._drain_futures, None

    # `transport.write` may call `pause_writing` synchronously, which sets
    # `self._write_queue` to a fresh deque - stop flushing then
    while self._write_queue is None and write_queue:
        data, future = write_queue.popleft()
        if future.done():
            continue

        self._comm_log.log(common.CommLogAction.SEND, data)

        self._transport.write(data)
        future.set_result(None)

    if self._write_queue is not None:
        # paused again (possibly by the very last queued write): keep the
        # remaining writes in order and keep `drain` waiters blocked until
        # writing is resumed
        write_queue.extend(self._write_queue)
        self._write_queue = write_queue

        drain_futures.extend(self._drain_futures)
        self._drain_futures = drain_futures

        return

    while drain_futures:
        future = drain_futures.popleft()
        if not future.done():
            future.set_result(None)

Called when the transport's buffer drains below the low-water mark.

See pause_writing() for details.

def data_received(self, data: bytes | bytearray | memoryview):
def data_received(self, data: util.Bytes):
    """Log received data, buffer it and serve pending reads"""
    comm_log, buffer = self._comm_log, self._input_buffer

    comm_log.log(common.CommLogAction.RECEIVE, data)
    buffer.add(data)

    self._process_input_buffer()

Called when some data is received.

The argument is a bytes object.

def eof_received(self):
def eof_received(self):
    """Complete all pending reads and stop accepting new ones

    Each still-pending request is served from the input buffer when
    possible (`readexactly` needs at least `n` buffered bytes, `read` needs
    any); otherwise it fails with `ConnectionError`. Returns ``None``, so
    the transport closes itself.

    """
    self._log.debug('eof received')

    requests, buffer = self._read_queue, self._input_buffer
    while requests:
        exact, n, future = requests.popleft()
        if future.done():
            continue

        satisfiable = len(buffer) >= n if exact else bool(buffer)
        if satisfiable:
            future.set_result(buffer.read(n))
        else:
            future.set_exception(ConnectionError())

    # no further reads can be queued from now on
    self._read_queue = None

Called when the other end calls write_eof() or equivalent.

If this returns a false value (including None), the transport will close itself. If it returns a true value, closing the transport is up to the protocol.

async def write(self, data: bytes | bytearray | memoryview):
async def write(self, data: util.Bytes):
    """Write data, waiting in queue while writing is paused

    Raises:
        ConnectionError: if not connected, or if the connection is lost
            while the write is queued

    """
    transport = self._transport
    if transport is None:
        raise ConnectionError()

    queue = self._write_queue
    if queue is not None:
        # writing is paused - resume_writing passes data to the transport
        pending = self._loop.create_future()
        queue.append((data, pending))
        await pending
        return

    self._comm_log.log(common.CommLogAction.SEND, data)
    transport.write(data)
async def drain(self):
async def drain(self):
    """Wait until writing is resumed (immediately if not paused)"""
    waiters = self._drain_futures
    if waiters is not None:
        resumed = self._loop.create_future()
        waiters.append(resumed)
        await resumed
async def read(self, n: int) -> bytes | bytearray | memoryview:
async def read(self, n: int) -> util.Bytes:
    """Read up to `n` bytes

    Raises:
        ConnectionError: if EOF was received and the buffer is empty

    """
    if n == 0:
        return b''

    queue = self._read_queue
    if self._input_buffer and not queue:
        # fast path: data available and no earlier request waiting
        chunk = self._input_buffer.read(n)
        self._process_input_buffer()
        return chunk

    if queue is None:
        raise ConnectionError()

    request = self._loop.create_future()
    request.add_done_callback(self._on_read_future_done)
    queue.append((False, n, request))
    return await request
async def readexactly(self, n: int) -> bytes | bytearray | memoryview:
async def readexactly(self, n: int) -> util.Bytes:
    """Read exactly `n` bytes

    Raises:
        ConnectionError: if EOF makes `n` bytes unavailable

    """
    if n == 0:
        return b''

    buffer, queue = self._input_buffer, self._read_queue
    if not queue and len(buffer) >= n:
        # fast path: enough data and no earlier request waiting
        chunk = buffer.read(n)
        self._process_input_buffer()
        return chunk

    if queue is None:
        raise ConnectionError()

    request = self._loop.create_future()
    request.add_done_callback(self._on_read_future_done)
    queue.append((True, n, request))
    # re-evaluate pause state so reading resumes for the pending request
    self._process_input_buffer()
    return await request
def clear_input_buffer(self) -> int:
def clear_input_buffer(self) -> int:
    """Discard all buffered input data

    Returns number of bytes cleared. Safe to call after the connection is
    lost.

    """
    count = self._input_buffer.clear()

    # reading may have been paused because the buffer was over its limit
    if self._transport is not None:
        self._transport.resume_reading()

    return count
async def async_close(self):
async def async_close(self):
    """Close transport (if connected) and wait for connection loss"""
    transport = self._transport
    if transport is not None:
        transport.close()

    await self.wait_closed()
async def wait_closed(self):
async def wait_closed(self):
    """Wait until `connection_lost` (immediately if already lost)"""
    waiters = self._closed_futures
    if waiters is None:
        return

    closed = self._loop.create_future()
    waiters.append(closed)
    await closed