hat.drivers.tcp
Asyncio TCP wrapper
1"""Asyncio TCP wrapper""" 2 3import asyncio 4import collections 5import functools 6import logging 7import sys 8import typing 9 10from hat import aio 11from hat import util 12 13from hat.drivers import common 14from hat.drivers import ssl 15 16 17mlog: logging.Logger = logging.getLogger(__name__) 18"""Module logger""" 19 20 21class Address(typing.NamedTuple): 22 host: str 23 port: int 24 25 26class ConnectionInfo(typing.NamedTuple): 27 name: str | None 28 local_addr: Address 29 remote_addr: Address 30 31 32class ServerInfo(typing.NamedTuple): 33 name: str | None 34 addresses: list[Address] 35 36 37ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None] 38"""Connection callback""" 39 40 41async def connect(addr: Address, 42 *, 43 name: str | None = None, 44 input_buffer_limit: int = 64 * 1024, 45 **kwargs 46 ) -> 'Connection': 47 """Create TCP connection 48 49 Argument `addr` specifies remote server listening address. 50 51 Argument `name` defines connection name available in property `info`. 52 53 Argument `input_buffer_limit` defines number of bytes in input buffer 54 that whill temporary pause data receiving. Once number of bytes 55 drops bellow `input_buffer_limit`, data receiving is resumed. If this 56 argument is ``0``, data receive pausing is disabled. 57 58 Additional arguments are passed directly to `asyncio.create_connection`. 59 60 """ 61 loop = asyncio.get_running_loop() 62 create_protocol = functools.partial(Protocol, None, name, 63 input_buffer_limit) 64 _, protocol = await loop.create_connection(create_protocol, 65 addr.host, addr.port, 66 **kwargs) 67 return Connection(protocol) 68 69 70async def listen(connection_cb: ConnectionCb, 71 addr: Address, 72 *, 73 name: str | None = None, 74 bind_connections: bool = False, 75 input_buffer_limit: int = 64 * 1024, 76 **kwargs 77 ) -> 'Server': 78 """Create listening server 79 80 Argument `name` defines server name available in property `info`. This 81 name is used for all incomming connections. 82 83 If `bind_connections` is ``True``, closing server will close all open 84 incoming connections. 85 86 Argument `input_buffer_limit` is associated with newly created connections 87 (see `connect`). 88 89 Additional arguments are passed directly to `asyncio.create_server`. 90 91 """ 92 server = Server() 93 server._connection_cb = connection_cb 94 server._bind_connections = bind_connections 95 server._async_group = aio.Group() 96 server._log = _create_server_logger(name, None) 97 98 on_connection = functools.partial(server.async_group.spawn, 99 server._on_connection) 100 create_protocol = functools.partial(Protocol, on_connection, name, 101 input_buffer_limit) 102 103 loop = asyncio.get_running_loop() 104 server._srv = await loop.create_server(create_protocol, addr.host, 105 addr.port, **kwargs) 106 107 server.async_group.spawn(aio.call_on_cancel, server._on_close) 108 109 try: 110 socknames = (socket.getsockname() for socket in server._srv.sockets) 111 addresses = [Address(*sockname[:2]) for sockname in socknames] 112 server._info = ServerInfo(name=name, 113 addresses=addresses) 114 server._log = _create_server_logger(name, server._info) 115 116 except Exception: 117 await aio.uncancellable(server.async_close()) 118 raise 119 120 server._log.debug('listening for incomming connections') 121 122 return server 123 124 125class Server(aio.Resource): 126 """TCP listening server 127 128 Closing server will cancel all running `connection_cb` coroutines. 
129 130 """ 131 132 @property 133 def async_group(self) -> aio.Group: 134 """Async group""" 135 return self._async_group 136 137 @property 138 def info(self) -> ServerInfo: 139 """Server info""" 140 return self._info 141 142 async def _on_close(self): 143 self._srv.close() 144 145 if self._bind_connections or sys.version_info[:2] < (3, 12): 146 await self._srv.wait_closed() 147 148 async def _on_connection(self, protocol): 149 self._log.debug('new incomming connection') 150 151 conn = Connection(protocol) 152 153 try: 154 await aio.call(self._connection_cb, conn) 155 156 if self._bind_connections: 157 await conn.wait_closing() 158 159 else: 160 conn = None 161 162 except Exception as e: 163 self._log.warning('connection callback error: %s', e, exc_info=e) 164 165 finally: 166 if conn: 167 await aio.uncancellable(conn.async_close()) 168 169 170class Connection(aio.Resource): 171 """TCP connection""" 172 173 def __init__(self, protocol: 'Protocol'): 174 self._protocol = protocol 175 self._async_group = aio.Group() 176 self._log = _create_connection_logger(protocol.info.name, 177 protocol.info) 178 self._comm_log = _CommunicationLogger(protocol.info.name, 179 protocol.info) 180 181 self.async_group.spawn(aio.call_on_cancel, protocol.async_close) 182 self.async_group.spawn(aio.call_on_done, protocol.wait_closed(), 183 self.close) 184 185 self._comm_log.log(common.CommLogAction.OPEN) 186 187 @property 188 def async_group(self) -> aio.Group: 189 """Async group""" 190 return self._async_group 191 192 @property 193 def info(self) -> ConnectionInfo: 194 """Connection info""" 195 return self._protocol.info 196 197 @property 198 def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None: 199 """SSL Object""" 200 return self._protocol.ssl_object 201 202 async def write(self, data: util.Bytes): 203 """Write data 204 205 This coroutine will wait until `data` can be added to output buffer. 206 207 """ 208 if not self.is_open: 209 raise ConnectionError() 210 211 await self._protocol.write(data) 212 213 async def drain(self): 214 """Drain output buffer""" 215 await self._protocol.drain() 216 217 async def read(self, n: int = -1) -> util.Bytes: 218 """Read up to `n` bytes 219 220 If EOF is detected and no new bytes are available, `ConnectionError` 221 is raised. 222 223 """ 224 return await self._protocol.read(n) 225 226 async def readexactly(self, n: int) -> util.Bytes: 227 """Read exactly `n` bytes 228 229 If exact number of bytes could not be read, `ConnectionError` is 230 raised. 231 232 """ 233 return await self._protocol.readexactly(n) 234 235 def clear_input_buffer(self) -> int: 236 """Clear input buffer 237 238 Returns number of bytes cleared from buffer. 
239 240 """ 241 return self._protocol.clear_input_buffer() 242 243 244class Protocol(asyncio.Protocol): 245 """Asyncio protocol implementation""" 246 247 def __init__(self, 248 on_connected: typing.Callable[['Protocol'], None] | None, 249 name: str | None, 250 input_buffer_limit: int): 251 self._on_connected = on_connected 252 self._name = name 253 self._input_buffer_limit = input_buffer_limit 254 self._loop = asyncio.get_running_loop() 255 self._input_buffer = util.BytesBuffer() 256 self._transport = None 257 self._read_queue = None 258 self._write_queue = None 259 self._drain_futures = None 260 self._closed_futures = None 261 self._info = None 262 self._ssl_object = None 263 self._log = _create_connection_logger(name, None) 264 self._comm_log = _CommunicationLogger(name, None) 265 266 @property 267 def info(self) -> ConnectionInfo: 268 return self._info 269 270 @property 271 def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None: 272 return self._ssl_object 273 274 def connection_made(self, transport: asyncio.Transport): 275 self._transport = transport 276 self._read_queue = collections.deque() 277 self._closed_futures = collections.deque() 278 279 try: 280 sockname = transport.get_extra_info('sockname') 281 peername = transport.get_extra_info('peername') 282 self._info = ConnectionInfo( 283 name=self._name, 284 local_addr=Address(sockname[0], sockname[1]), 285 remote_addr=Address(peername[0], peername[1])) 286 287 self._log = _create_connection_logger(self._name, self._info) 288 self._comm_log = _CommunicationLogger(self._name, self._info) 289 290 self._ssl_object = transport.get_extra_info('ssl_object') 291 292 if self._on_connected: 293 self._on_connected(self) 294 295 except Exception: 296 transport.abort() 297 return 298 299 def connection_lost(self, exc: Exception | None): 300 self._transport = None 301 write_queue, self._write_queue = self._write_queue, None 302 drain_futures, self._drain_futures = self._drain_futures, None 303 closed_futures, self._closed_futures = self._closed_futures, None 304 305 self.eof_received() 306 307 while write_queue: 308 _, future = write_queue.popleft() 309 if not future.done(): 310 future.set_exception(ConnectionError()) 311 312 while drain_futures: 313 future = drain_futures.popleft() 314 if not future.done(): 315 future.set_result(None) 316 317 while closed_futures: 318 future = closed_futures.popleft() 319 if not future.done(): 320 future.set_result(None) 321 322 self._comm_log.log(common.CommLogAction.CLOSE) 323 324 def pause_writing(self): 325 self._log.debug('pause writing') 326 327 self._write_queue = collections.deque() 328 self._drain_futures = collections.deque() 329 330 def resume_writing(self): 331 self._log.debug('resume writing') 332 333 write_queue, self._write_queue = self._write_queue, None 334 drain_futures, self._drain_futures = self._drain_futures, None 335 336 while self._write_queue is None and write_queue: 337 data, future = write_queue.popleft() 338 if future.done(): 339 continue 340 341 self._comm_log.log(common.CommLogAction.SEND, data) 342 343 self._transport.write(data) 344 future.set_result(None) 345 346 if write_queue: 347 write_queue.extend(self._write_queue) 348 self._write_queue = write_queue 349 350 drain_futures.extend(self._drain_futures) 351 self._drain_futures = drain_futures 352 353 return 354 355 while drain_futures: 356 future = drain_futures.popleft() 357 if not future.done(): 358 future.set_result(None) 359 360 def data_received(self, data: util.Bytes): 361 
self._comm_log.log(common.CommLogAction.RECEIVE, data) 362 363 self._input_buffer.add(data) 364 self._process_input_buffer() 365 366 def eof_received(self): 367 self._log.debug('eof received') 368 369 while self._read_queue: 370 exact, n, future = self._read_queue.popleft() 371 if future.done(): 372 continue 373 374 if exact and n <= len(self._input_buffer): 375 future.set_result(self._input_buffer.read(n)) 376 377 elif not exact and self._input_buffer: 378 future.set_result(self._input_buffer.read(n)) 379 380 else: 381 future.set_exception(ConnectionError()) 382 383 self._read_queue = None 384 385 async def write(self, data: util.Bytes): 386 if self._transport is None: 387 raise ConnectionError() 388 389 if self._write_queue is None: 390 self._comm_log.log(common.CommLogAction.SEND, data) 391 392 self._transport.write(data) 393 return 394 395 future = self._loop.create_future() 396 self._write_queue.append((data, future)) 397 await future 398 399 async def drain(self): 400 if self._drain_futures is None: 401 return 402 403 future = self._loop.create_future() 404 self._drain_futures.append(future) 405 await future 406 407 async def read(self, n: int) -> util.Bytes: 408 if n == 0: 409 return b'' 410 411 if self._input_buffer and not self._read_queue: 412 data = self._input_buffer.read(n) 413 self._process_input_buffer() 414 return data 415 416 if self._read_queue is None: 417 raise ConnectionError() 418 419 future = self._loop.create_future() 420 future.add_done_callback(self._on_read_future_done) 421 self._read_queue.append((False, n, future)) 422 return await future 423 424 async def readexactly(self, n: int) -> util.Bytes: 425 if n == 0: 426 return b'' 427 428 if n <= len(self._input_buffer) and not self._read_queue: 429 data = self._input_buffer.read(n) 430 self._process_input_buffer() 431 return data 432 433 if self._read_queue is None: 434 raise ConnectionError() 435 436 future = self._loop.create_future() 437 future.add_done_callback(self._on_read_future_done) 438 self._read_queue.append((True, n, future)) 439 self._process_input_buffer() 440 return await future 441 442 def clear_input_buffer(self) -> int: 443 count = self._input_buffer.clear() 444 self._transport.resume_reading() 445 return count 446 447 async def async_close(self): 448 if self._transport is not None: 449 self._transport.close() 450 451 await self.wait_closed() 452 453 async def wait_closed(self): 454 if self._closed_futures is None: 455 return 456 457 future = self._loop.create_future() 458 self._closed_futures.append(future) 459 await future 460 461 def _on_read_future_done(self, future): 462 if not self._read_queue: 463 return 464 465 if not future.cancelled(): 466 return 467 468 for _ in range(len(self._read_queue)): 469 i = self._read_queue.popleft() 470 if not i[2].done(): 471 self._read_queue.append(i) 472 473 self._process_input_buffer() 474 475 def _process_input_buffer(self): 476 while self._input_buffer and self._read_queue: 477 exact, n, future = self._read_queue.popleft() 478 if future.done(): 479 continue 480 481 if not exact: 482 future.set_result(self._input_buffer.read(n)) 483 484 elif n <= len(self._input_buffer): 485 future.set_result(self._input_buffer.read(n)) 486 487 else: 488 self._read_queue.appendleft((exact, n, future)) 489 break 490 491 if not self._transport: 492 return 493 494 pause = (self._input_buffer_limit > 0 and 495 len(self._input_buffer) > self._input_buffer_limit and 496 not self._read_queue) 497 498 if pause: 499 self._transport.pause_reading() 500 501 else: 502 
self._transport.resume_reading() 503 504 505def _create_server_logger(name, info): 506 extra = {'meta': {'type': 'TcpServer', 507 'name': name}} 508 509 if info is not None: 510 extra['meta']['addresses'] = [{'host': addr.host, 511 'port': addr.port} 512 for addr in info.addresses] 513 514 return logging.LoggerAdapter(mlog, extra) 515 516 517def _create_connection_logger(name, info): 518 extra = {'meta': {'type': 'TcpConnection', 519 'name': name}} 520 521 if info is not None: 522 extra['meta']['local_addr'] = {'host': info.local_addr.host, 523 'port': info.local_addr.port} 524 extra['meta']['remote_addr'] = {'host': info.remote_addr.host, 525 'port': info.remote_addr.port} 526 527 return logging.LoggerAdapter(mlog, extra) 528 529 530class _CommunicationLogger: 531 532 def __init__(self, 533 name: str | None, 534 info: ConnectionInfo | None): 535 extra = {'meta': {'type': 'TcpConnection', 536 'communication': True, 537 'name': name}} 538 539 if info is not None: 540 extra['meta']['local_addr'] = {'host': info.local_addr.host, 541 'port': info.local_addr.port} 542 extra['meta']['remote_addr'] = {'host': info.remote_addr.host, 543 'port': info.remote_addr.port} 544 545 self._log = logging.LoggerAdapter(mlog, extra) 546 547 def log(self, 548 action: common.CommLogAction, 549 data: util.Bytes | None = None): 550 if not self._log.isEnabledFor(logging.DEBUG): 551 return 552 553 if data is None: 554 self._log.debug(action.value, stacklevel=2) 555 556 else: 557 self._log.debug('%s (%s)', action.value, data.hex(' '), 558 stacklevel=2)
Module logger
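At DEBUG level this logger also carries per-message communication records (see _CommunicationLogger in the source above), in addition to lifecycle messages. A minimal sketch of enabling that output in an application; the basicConfig setup is only illustrative:

import logging

# show this module's DEBUG records (including send/receive communication
# logs) without making the whole application verbose
logging.basicConfig(level=logging.INFO)
logging.getLogger('hat.drivers.tcp').setLevel(logging.DEBUG)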
Address(host, port)
class ConnectionInfo(typing.NamedTuple):
    name: str | None
    local_addr: Address
    remote_addr: Address
ConnectionInfo(name, local_addr, remote_addr)
ServerInfo(name, addresses)
Create new instance of ServerInfo(name, addresses)
Connection callback
async def connect(addr: Address, *, name: str | None = None, input_buffer_limit: int = 64 * 1024, **kwargs) -> Connection
Create TCP connection
Argument addr specifies remote server listening address.
Argument name defines connection name available in property info.
Argument input_buffer_limit defines the number of bytes in the input buffer
that will temporarily pause data receiving. Once the number of bytes
drops below input_buffer_limit, data receiving is resumed. If this
argument is 0, data receive pausing is disabled.
Additional arguments are passed directly to loop.create_connection.
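For illustration, a minimal client sketch built on connect; the address 127.0.0.1:9999 and the b'ping' payload are placeholders and assume some peer is already listening there:

import asyncio

from hat.drivers import tcp


async def main():
    conn = await tcp.connect(tcp.Address('127.0.0.1', 9999))
    try:
        await conn.write(b'ping')
        await conn.drain()
        data = await conn.read()          # up to all currently available bytes
        print('received:', bytes(data))
    finally:
        await conn.async_close()


asyncio.run(main())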
async def listen(connection_cb: ConnectionCb, addr: Address, *, name: str | None = None, bind_connections: bool = False, input_buffer_limit: int = 64 * 1024, **kwargs) -> Server
Create listening server
Argument name defines server name available in property info. This
name is used for all incoming connections.
If bind_connections is True, closing the server will close all open
incoming connections.
Argument input_buffer_limit is associated with newly created connections
(see connect).
Additional arguments are passed directly to loop.create_server.
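For illustration, a sketch of an echo server built on listen; with bind_connections=True the spawned connections are closed together with the server (host, port and the 60 second lifetime are arbitrary choices):

import asyncio

from hat.drivers import tcp


async def echo(conn):
    # connection callback: echo received data until the peer closes
    try:
        while True:
            data = await conn.read()
            await conn.write(data)
    except ConnectionError:
        pass


async def main():
    server = await tcp.listen(echo, tcp.Address('127.0.0.1', 9999),
                              bind_connections=True)
    try:
        await asyncio.sleep(60)
    finally:
        await server.async_close()


asyncio.run(main())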
class Server(aio.Resource)
TCP listening server
Closing the server will cancel all running connection_cb coroutines.
@property
def async_group(self) -> aio.Group:
    """Async group"""
    return self._async_group
Async group
class Connection(aio.Resource)
TCP connection
def __init__(self, protocol: Protocol)
@property
def async_group(self) -> aio.Group:
    """Async group"""
    return self._async_group
Async group
@property
def info(self) -> ConnectionInfo:
    """Connection info"""
    return self._protocol.info
Connection info
@property
def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
    """SSL Object"""
    return self._protocol.ssl_object
SSL Object
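Because additional keyword arguments of connect are forwarded to the event loop, TLS can be enabled by passing an ssl.SSLContext, after which this property exposes the negotiated SSL object. A sketch with address and certificate settings as placeholders only:

import ssl

from hat.drivers import tcp


async def connect_tls() -> tcp.Connection:
    ctx = ssl.create_default_context()
    ctx.check_hostname = False        # placeholders only: verification is
    ctx.verify_mode = ssl.CERT_NONE   # disabled just for this sketch

    conn = await tcp.connect(tcp.Address('127.0.0.1', 9999), ssl=ctx)
    print('cipher:', conn.ssl_object.cipher())
    return conn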
async def write(self, data: util.Bytes)
Write data
This coroutine will wait until data can be added to the output buffer.
async def read(self, n: int = -1) -> util.Bytes
Read up to n bytes
If EOF is detected and no new bytes are available, ConnectionError
is raised.
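Since read raises ConnectionError once EOF is reached and the input buffer is empty, a receive-everything helper can be sketched as follows (the helper name read_all is hypothetical):

from hat.drivers import tcp


async def read_all(conn: tcp.Connection) -> bytes:
    # accumulate chunks until the peer closes the connection
    chunks = []
    try:
        while True:
            chunks.append(bytes(await conn.read()))
    except ConnectionError:
        pass
    return b''.join(chunks)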
async def readexactly(self, n: int) -> util.Bytes
Read exactly n bytes
If the exact number of bytes could not be read, ConnectionError is
raised.
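readexactly is convenient for fixed-size headers. A sketch of reading one frame with a 2-byte big-endian length prefix; the framing is only an example, not something this module defines:

from hat.drivers import tcp


async def read_frame(conn: tcp.Connection) -> bytes:
    # 2-byte big-endian length prefix followed by the payload
    header = await conn.readexactly(2)
    length = int.from_bytes(bytes(header), 'big')
    return bytes(await conn.readexactly(length))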
class Protocol(asyncio.Protocol)
Asyncio protocol implementation
def __init__(self, on_connected: typing.Callable[[Protocol], None] | None, name: str | None, input_buffer_limit: int)
def connection_made(self, transport: asyncio.Transport)
Called when a connection is made.
The argument is the transport representing the pipe connection. To receive data, wait for data_received() calls. When the connection is closed, connection_lost() is called.
def connection_lost(self, exc: Exception | None)
Called when the connection is lost or closed.
The argument is an exception object or None (the latter meaning a regular EOF is received or the connection was aborted or closed).
def pause_writing(self):
    self._log.debug('pause writing')

    self._write_queue = collections.deque()
    self._drain_futures = collections.deque()
Called when the transport's buffer goes over the high-water mark.
Pause and resume calls are paired -- pause_writing() is called once when the buffer goes strictly over the high-water mark (even if subsequent writes increase the buffer size even more), and eventually resume_writing() is called once when the buffer size reaches the low-water mark.
Note that if the buffer size equals the high-water mark, pause_writing() is not called -- it must go strictly over. Conversely, resume_writing() is called when the buffer size is equal or lower than the low-water mark. These end conditions are important to ensure that things go as expected when either mark is zero.
NOTE: This is the only Protocol callback that is not called through EventLoop.call_soon() -- if it were, it would have no effect when it's most needed (when the app keeps writing without yielding until pause_writing() is called).
def resume_writing(self)
Called when the transport's buffer drains below the low-water mark.
See pause_writing() for details.
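At the Connection level this flow control surfaces through write and drain: write only waits while the transport reports a full buffer, and drain returns once any data queued during a pause has been written. A sketch of sending a large payload in chunks while respecting backpressure (the chunk size and helper name are arbitrary):

from hat.drivers import tcp


async def send_all(conn: tcp.Connection, payload: bytes, chunk_size: int = 4096):
    # write blocks only while the transport is paused; drain then waits
    # for the remaining queued data to be written out
    for i in range(0, len(payload), chunk_size):
        await conn.write(payload[i:i + chunk_size])
    await conn.drain()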
def data_received(self, data: util.Bytes):
    self._comm_log.log(common.CommLogAction.RECEIVE, data)

    self._input_buffer.add(data)
    self._process_input_buffer()
Called when some data is received.
The argument is a bytes object.
def eof_received(self)
Called when the other end calls write_eof() or equivalent.
If this returns a false value (including None), the transport will close itself. If it returns a true value, closing the transport is up to the protocol.
async def write(self, data: util.Bytes)
async def read(self, n: int) -> util.Bytes
async def readexactly(self, n: int) -> util.Bytes