hat.drivers.tcp
Asyncio TCP wrapper
1"""Asyncio TCP wrapper""" 2 3import asyncio 4import collections 5import functools 6import logging 7import sys 8import typing 9 10from hat import aio 11from hat import util 12 13from hat.drivers import ssl 14 15 16mlog: logging.Logger = logging.getLogger(__name__) 17"""Module logger""" 18 19 20class Address(typing.NamedTuple): 21 host: str 22 port: int 23 24 25class ConnectionInfo(typing.NamedTuple): 26 name: str | None 27 local_addr: Address 28 remote_addr: Address 29 30 31class ServerInfo(typing.NamedTuple): 32 name: str | None 33 addresses: list[Address] 34 35 36ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None] 37"""Connection callback""" 38 39 40async def connect(addr: Address, 41 *, 42 name: str | None = None, 43 input_buffer_limit: int = 64 * 1024, 44 **kwargs 45 ) -> 'Connection': 46 """Create TCP connection 47 48 Argument `addr` specifies remote server listening address. 49 50 Argument `name` defines connection name available in property `info`. 51 52 Argument `input_buffer_limit` defines number of bytes in input buffer 53 that whill temporary pause data receiving. Once number of bytes 54 drops bellow `input_buffer_limit`, data receiving is resumed. If this 55 argument is ``0``, data receive pausing is disabled. 56 57 Additional arguments are passed directly to `asyncio.create_connection`. 58 59 """ 60 loop = asyncio.get_running_loop() 61 create_protocol = functools.partial(Protocol, None, name, 62 input_buffer_limit) 63 _, protocol = await loop.create_connection(create_protocol, 64 addr.host, addr.port, 65 **kwargs) 66 return Connection(protocol) 67 68 69async def listen(connection_cb: ConnectionCb, 70 addr: Address, 71 *, 72 name: str | None = None, 73 bind_connections: bool = False, 74 input_buffer_limit: int = 64 * 1024, 75 **kwargs 76 ) -> 'Server': 77 """Create listening server 78 79 Argument `name` defines server name available in property `info`. This 80 name is used for all incomming connections. 81 82 If `bind_connections` is ``True``, closing server will close all open 83 incoming connections. 84 85 Argument `input_buffer_limit` is associated with newly created connections 86 (see `connect`). 87 88 Additional arguments are passed directly to `asyncio.create_server`. 89 90 """ 91 server = Server() 92 server._connection_cb = connection_cb 93 server._bind_connections = bind_connections 94 server._async_group = aio.Group() 95 server._log = mlog 96 97 on_connection = functools.partial(server.async_group.spawn, 98 server._on_connection) 99 create_protocol = functools.partial(Protocol, on_connection, name, 100 input_buffer_limit) 101 102 loop = asyncio.get_running_loop() 103 server._srv = await loop.create_server(create_protocol, addr.host, 104 addr.port, **kwargs) 105 106 server.async_group.spawn(aio.call_on_cancel, server._on_close) 107 108 try: 109 socknames = (socket.getsockname() for socket in server._srv.sockets) 110 addresses = [Address(*sockname[:2]) for sockname in socknames] 111 server._info = ServerInfo(name=name, 112 addresses=addresses) 113 server._log = _create_server_logger_adapter(server._info) 114 115 except Exception: 116 await aio.uncancellable(server.async_close()) 117 raise 118 119 server._log.debug('listening for incomming connections') 120 121 return server 122 123 124class Server(aio.Resource): 125 """TCP listening server 126 127 Closing server will cancel all running `connection_cb` coroutines. 
128 129 """ 130 131 @property 132 def async_group(self) -> aio.Group: 133 """Async group""" 134 return self._async_group 135 136 @property 137 def info(self) -> ServerInfo: 138 """Server info""" 139 return self._info 140 141 async def _on_close(self): 142 self._srv.close() 143 144 if self._bind_connections or sys.version_info[:2] < (3, 12): 145 await self._srv.wait_closed() 146 147 async def _on_connection(self, protocol): 148 self._log.debug('new incomming connection') 149 150 conn = Connection(protocol) 151 152 try: 153 await aio.call(self._connection_cb, conn) 154 155 if self._bind_connections: 156 await conn.wait_closing() 157 158 else: 159 conn = None 160 161 except Exception as e: 162 self._log.warning('connection callback error: %s', e, exc_info=e) 163 164 finally: 165 if conn: 166 await aio.uncancellable(conn.async_close()) 167 168 169class Connection(aio.Resource): 170 """TCP connection""" 171 172 def __init__(self, protocol: 'Protocol'): 173 self._protocol = protocol 174 self._async_group = aio.Group() 175 self._log = _create_connection_logger_adapter(protocol.info) 176 177 self._log.debug('connection established') 178 179 self.async_group.spawn(aio.call_on_cancel, protocol.async_close) 180 self.async_group.spawn(aio.call_on_done, protocol.wait_closed(), 181 self.close) 182 183 @property 184 def async_group(self) -> aio.Group: 185 """Async group""" 186 return self._async_group 187 188 @property 189 def info(self) -> ConnectionInfo: 190 """Connection info""" 191 return self._protocol.info 192 193 @property 194 def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None: 195 """SSL Object""" 196 return self._protocol.ssl_object 197 198 async def write(self, data: util.Bytes): 199 """Write data 200 201 This coroutine will wait until `data` can be added to output buffer. 202 203 """ 204 if not self.is_open: 205 raise ConnectionError() 206 207 await self._protocol.write(data) 208 209 async def drain(self): 210 """Drain output buffer""" 211 await self._protocol.drain() 212 213 async def read(self, n: int = -1) -> util.Bytes: 214 """Read up to `n` bytes 215 216 If EOF is detected and no new bytes are available, `ConnectionError` 217 is raised. 218 219 """ 220 return await self._protocol.read(n) 221 222 async def readexactly(self, n: int) -> util.Bytes: 223 """Read exactly `n` bytes 224 225 If exact number of bytes could not be read, `ConnectionError` is 226 raised. 227 228 """ 229 return await self._protocol.readexactly(n) 230 231 def clear_input_buffer(self) -> int: 232 """Clear input buffer 233 234 Returns number of bytes cleared from buffer. 
235 236 """ 237 return self._protocol.clear_input_buffer() 238 239 240class Protocol(asyncio.Protocol): 241 """Asyncio protocol implementation""" 242 243 def __init__(self, 244 on_connected: typing.Callable[['Protocol'], None] | None, 245 name: str | None, 246 input_buffer_limit: int): 247 self._on_connected = on_connected 248 self._name = name 249 self._input_buffer_limit = input_buffer_limit 250 self._loop = asyncio.get_running_loop() 251 self._input_buffer = util.BytesBuffer() 252 self._transport = None 253 self._read_queue = None 254 self._write_queue = None 255 self._drain_futures = None 256 self._closed_futures = None 257 self._info = None 258 self._ssl_object = None 259 self._log = mlog 260 261 @property 262 def info(self) -> ConnectionInfo: 263 return self._info 264 265 @property 266 def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None: 267 return self._ssl_object 268 269 def connection_made(self, transport: asyncio.Transport): 270 self._transport = transport 271 self._read_queue = collections.deque() 272 self._closed_futures = collections.deque() 273 274 try: 275 sockname = transport.get_extra_info('sockname') 276 peername = transport.get_extra_info('peername') 277 self._info = ConnectionInfo( 278 name=self._name, 279 local_addr=Address(sockname[0], sockname[1]), 280 remote_addr=Address(peername[0], peername[1])) 281 282 self._log = _create_connection_logger_adapter(self._info) 283 284 self._ssl_object = transport.get_extra_info('ssl_object') 285 286 if self._on_connected: 287 self._on_connected(self) 288 289 except Exception: 290 transport.abort() 291 return 292 293 def connection_lost(self, exc: Exception | None): 294 self._transport = None 295 write_queue, self._write_queue = self._write_queue, None 296 drain_futures, self._drain_futures = self._drain_futures, None 297 closed_futures, self._closed_futures = self._closed_futures, None 298 299 self.eof_received() 300 301 while write_queue: 302 _, future = write_queue.popleft() 303 if not future.done(): 304 future.set_exception(ConnectionError()) 305 306 while drain_futures: 307 future = drain_futures.popleft() 308 if not future.done(): 309 future.set_result(None) 310 311 while closed_futures: 312 future = closed_futures.popleft() 313 if not future.done(): 314 future.set_result(None) 315 316 self._log.debug('connection closed') 317 318 def pause_writing(self): 319 self._log.debug('pause writing') 320 321 self._write_queue = collections.deque() 322 self._drain_futures = collections.deque() 323 324 def resume_writing(self): 325 self._log.debug('resume writing') 326 327 write_queue, self._write_queue = self._write_queue, None 328 drain_futures, self._drain_futures = self._drain_futures, None 329 330 while self._write_queue is None and write_queue: 331 data, future = write_queue.popleft() 332 if future.done(): 333 continue 334 335 self._log.debug('writing %s bytes', len(data)) 336 self._transport.write(data) 337 future.set_result(None) 338 339 if write_queue: 340 write_queue.extend(self._write_queue) 341 self._write_queue = write_queue 342 343 drain_futures.extend(self._drain_futures) 344 self._drain_futures = drain_futures 345 346 return 347 348 while drain_futures: 349 future = drain_futures.popleft() 350 if not future.done(): 351 future.set_result(None) 352 353 def data_received(self, data: util.Bytes): 354 self._log.debug('received %s bytes', len(data)) 355 356 self._input_buffer.add(data) 357 self._process_input_buffer() 358 359 def eof_received(self): 360 self._log.debug('eof received') 361 362 while self._read_queue: 
363 exact, n, future = self._read_queue.popleft() 364 if future.done(): 365 continue 366 367 if exact and n <= len(self._input_buffer): 368 future.set_result(self._input_buffer.read(n)) 369 370 elif not exact and self._input_buffer: 371 future.set_result(self._input_buffer.read(n)) 372 373 else: 374 future.set_exception(ConnectionError()) 375 376 self._read_queue = None 377 378 async def write(self, data: util.Bytes): 379 if self._transport is None: 380 raise ConnectionError() 381 382 if self._write_queue is None: 383 self._log.debug('writing %s bytes', len(data)) 384 self._transport.write(data) 385 return 386 387 future = self._loop.create_future() 388 self._write_queue.append((data, future)) 389 await future 390 391 async def drain(self): 392 if self._drain_futures is None: 393 return 394 395 future = self._loop.create_future() 396 self._drain_futures.append(future) 397 await future 398 399 async def read(self, n: int) -> util.Bytes: 400 if n == 0: 401 return b'' 402 403 if self._input_buffer and not self._read_queue: 404 data = self._input_buffer.read(n) 405 self._process_input_buffer() 406 return data 407 408 if self._read_queue is None: 409 raise ConnectionError() 410 411 future = self._loop.create_future() 412 future.add_done_callback(self._on_read_future_done) 413 self._read_queue.append((False, n, future)) 414 return await future 415 416 async def readexactly(self, n: int) -> util.Bytes: 417 if n == 0: 418 return b'' 419 420 if n <= len(self._input_buffer) and not self._read_queue: 421 data = self._input_buffer.read(n) 422 self._process_input_buffer() 423 return data 424 425 if self._read_queue is None: 426 raise ConnectionError() 427 428 future = self._loop.create_future() 429 future.add_done_callback(self._on_read_future_done) 430 self._read_queue.append((True, n, future)) 431 self._process_input_buffer() 432 return await future 433 434 def clear_input_buffer(self) -> int: 435 count = self._input_buffer.clear() 436 self._transport.resume_reading() 437 return count 438 439 async def async_close(self): 440 if self._transport is not None: 441 self._transport.close() 442 443 await self.wait_closed() 444 445 async def wait_closed(self): 446 if self._closed_futures is None: 447 return 448 449 future = self._loop.create_future() 450 self._closed_futures.append(future) 451 await future 452 453 def _on_read_future_done(self, future): 454 if not self._read_queue: 455 return 456 457 if not future.cancelled(): 458 return 459 460 for _ in range(len(self._read_queue)): 461 i = self._read_queue.popleft() 462 if not i[2].done(): 463 self._read_queue.append(i) 464 465 self._process_input_buffer() 466 467 def _process_input_buffer(self): 468 while self._input_buffer and self._read_queue: 469 exact, n, future = self._read_queue.popleft() 470 if future.done(): 471 continue 472 473 if not exact: 474 future.set_result(self._input_buffer.read(n)) 475 476 elif n <= len(self._input_buffer): 477 future.set_result(self._input_buffer.read(n)) 478 479 else: 480 self._read_queue.appendleft((exact, n, future)) 481 break 482 483 if not self._transport: 484 return 485 486 pause = (self._input_buffer_limit > 0 and 487 len(self._input_buffer) > self._input_buffer_limit and 488 not self._read_queue) 489 490 if pause: 491 self._transport.pause_reading() 492 493 else: 494 self._transport.resume_reading() 495 496 497def _create_server_logger_adapter(info): 498 extra = {'meta': {'type': 'TcpServer', 499 'name': info.name, 500 'addresses': [{'host': addr.host, 501 'port': addr.port} 502 for addr in info.addresses]}} 503 
504 return logging.LoggerAdapter(mlog, extra) 505 506 507def _create_connection_logger_adapter(info): 508 extra = {'meta': {'type': 'TcpConnection', 509 'name': info.name, 510 'local_addr': {'host': info.local_addr.host, 511 'port': info.local_addr.port}, 512 'remote_addr': {'host': info.remote_addr.host, 513 'port': info.remote_addr.port}}} 514 515 return logging.LoggerAdapter(mlog, extra)
Module logger
Address(host, port)
class ConnectionInfo(typing.NamedTuple):
    name: str | None
    local_addr: Address
    remote_addr: Address
ConnectionInfo(name, local_addr, remote_addr)
ServerInfo(name, addresses)
Create new instance of ServerInfo(name, addresses)
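Address, ConnectionInfo and ServerInfo are plain named tuples; a minimal construction sketch with illustrative values:

from hat.drivers import tcp

addr = tcp.Address(host='127.0.0.1', port=2020)
info = tcp.ServerInfo(name='example-server', addresses=[addr])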
Connection callback
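The callback is invoked through aio.call, so a coroutine function such as the following sketch can be used (the body is illustrative):

from hat.drivers import tcp


async def on_connection(conn: tcp.Connection):
    print('new connection from', conn.info.remote_addr)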
async def connect(addr: Address,
                  *,
                  name: str | None = None,
                  input_buffer_limit: int = 64 * 1024,
                  **kwargs
                  ) -> 'Connection':
    """Create TCP connection

    Argument `addr` specifies remote server listening address.

    Argument `name` defines connection name available in property `info`.

    Argument `input_buffer_limit` defines the number of bytes in the input
    buffer at which data receiving is temporarily paused. Once the number of
    buffered bytes drops below `input_buffer_limit`, receiving is resumed.
    If this argument is ``0``, receive pausing is disabled.

    Additional arguments are passed directly to `asyncio.create_connection`.

    """
    loop = asyncio.get_running_loop()
    create_protocol = functools.partial(Protocol, None, name,
                                        input_buffer_limit)
    _, protocol = await loop.create_connection(create_protocol,
                                               addr.host, addr.port,
                                               **kwargs)
    return Connection(protocol)
Create TCP connection
Argument addr specifies remote server listening address.
Argument name defines connection name available in property info.
Argument input_buffer_limit defines the number of bytes in the input buffer
at which data receiving is temporarily paused. Once the number of buffered
bytes drops below input_buffer_limit, receiving is resumed. If this
argument is 0, receive pausing is disabled.
Additional arguments are passed directly to asyncio.create_connection.
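A minimal client sketch, assuming some service is listening on 127.0.0.1:2020 (address and payload are illustrative):

import asyncio

from hat.drivers import tcp


async def main():
    conn = await tcp.connect(tcp.Address('127.0.0.1', 2020),
                             name='example-client')
    try:
        await conn.write(b'ping')
        data = await conn.read()
        print(conn.info, data)
    finally:
        await conn.async_close()


asyncio.run(main())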
async def listen(connection_cb: ConnectionCb,
                 addr: Address,
                 *,
                 name: str | None = None,
                 bind_connections: bool = False,
                 input_buffer_limit: int = 64 * 1024,
                 **kwargs
                 ) -> 'Server':
    """Create listening server

    Argument `name` defines server name available in property `info`. This
    name is used for all incoming connections.

    If `bind_connections` is ``True``, closing server will close all open
    incoming connections.

    Argument `input_buffer_limit` is associated with newly created connections
    (see `connect`).

    Additional arguments are passed directly to `asyncio.create_server`.

    """
    server = Server()
    server._connection_cb = connection_cb
    server._bind_connections = bind_connections
    server._async_group = aio.Group()
    server._log = mlog

    on_connection = functools.partial(server.async_group.spawn,
                                      server._on_connection)
    create_protocol = functools.partial(Protocol, on_connection, name,
                                        input_buffer_limit)

    loop = asyncio.get_running_loop()
    server._srv = await loop.create_server(create_protocol, addr.host,
                                           addr.port, **kwargs)

    server.async_group.spawn(aio.call_on_cancel, server._on_close)

    try:
        socknames = (socket.getsockname() for socket in server._srv.sockets)
        addresses = [Address(*sockname[:2]) for sockname in socknames]
        server._info = ServerInfo(name=name,
                                  addresses=addresses)
        server._log = _create_server_logger_adapter(server._info)

    except Exception:
        await aio.uncancellable(server.async_close())
        raise

    server._log.debug('listening for incoming connections')

    return server
Create listening server
Argument name defines server name available in property info. This
name is used for all incoming connections.
If bind_connections is True, closing server will close all open
incoming connections.
Argument input_buffer_limit is associated with newly created connections
(see connect).
Additional arguments are passed directly to asyncio.create_server.
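A minimal echo-server sketch; with bind_connections=True closing the server also closes accepted connections (port is illustrative):

import asyncio

from hat.drivers import tcp


async def echo(conn: tcp.Connection):
    # loop until the peer closes the connection (read raises ConnectionError)
    try:
        while True:
            data = await conn.read()
            await conn.write(data)
    except ConnectionError:
        pass


async def main():
    server = await tcp.listen(echo,
                              tcp.Address('0.0.0.0', 2020),
                              name='example-server',
                              bind_connections=True)
    try:
        await server.wait_closing()
    finally:
        await server.async_close()


asyncio.run(main())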
class Server(aio.Resource):
    """TCP listening server

    Closing server will cancel all running `connection_cb` coroutines.

    """

    @property
    def async_group(self) -> aio.Group:
        """Async group"""
        return self._async_group

    @property
    def info(self) -> ServerInfo:
        """Server info"""
        return self._info

    async def _on_close(self):
        self._srv.close()

        if self._bind_connections or sys.version_info[:2] < (3, 12):
            await self._srv.wait_closed()

    async def _on_connection(self, protocol):
        self._log.debug('new incoming connection')

        conn = Connection(protocol)

        try:
            await aio.call(self._connection_cb, conn)

            if self._bind_connections:
                await conn.wait_closing()

            else:
                conn = None

        except Exception as e:
            self._log.warning('connection callback error: %s', e, exc_info=e)

        finally:
            if conn:
                await aio.uncancellable(conn.async_close())
TCP listening server
Closing server will cancel all running connection_cb coroutines.
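A shutdown sketch under the semantics above (the helper name is hypothetical; with bind_connections=True accepted connections are closed as well):

from hat.drivers import tcp


async def shutdown(server: tcp.Server):
    # cancels connection_cb coroutines still running in the server's group
    await server.async_close()
    assert not server.is_open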
@property
def async_group(self) -> aio.Group:
    """Async group"""
    return self._async_group
Async group
class Connection(aio.Resource):
    """TCP connection"""

    def __init__(self, protocol: 'Protocol'):
        self._protocol = protocol
        self._async_group = aio.Group()
        self._log = _create_connection_logger_adapter(protocol.info)

        self._log.debug('connection established')

        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
                               self.close)

    @property
    def async_group(self) -> aio.Group:
        """Async group"""
        return self._async_group

    @property
    def info(self) -> ConnectionInfo:
        """Connection info"""
        return self._protocol.info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        """SSL Object"""
        return self._protocol.ssl_object

    async def write(self, data: util.Bytes):
        """Write data

        This coroutine will wait until `data` can be added to output buffer.

        """
        if not self.is_open:
            raise ConnectionError()

        await self._protocol.write(data)

    async def drain(self):
        """Drain output buffer"""
        await self._protocol.drain()

    async def read(self, n: int = -1) -> util.Bytes:
        """Read up to `n` bytes

        If EOF is detected and no new bytes are available, `ConnectionError`
        is raised.

        """
        return await self._protocol.read(n)

    async def readexactly(self, n: int) -> util.Bytes:
        """Read exactly `n` bytes

        If exact number of bytes could not be read, `ConnectionError` is
        raised.

        """
        return await self._protocol.readexactly(n)

    def clear_input_buffer(self) -> int:
        """Clear input buffer

        Returns number of bytes cleared from buffer.

        """
        return self._protocol.clear_input_buffer()
TCP connection
def __init__(self, protocol: 'Protocol'):
    self._protocol = protocol
    self._async_group = aio.Group()
    self._log = _create_connection_logger_adapter(protocol.info)

    self._log.debug('connection established')

    self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
    self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
                           self.close)
@property
def async_group(self) -> aio.Group:
    """Async group"""
    return self._async_group
Async group
@property
def info(self) -> ConnectionInfo:
    """Connection info"""
    return self._protocol.info
Connection info
@property
def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
    """SSL Object"""
    return self._protocol.ssl_object
SSL Object
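Since additional keyword arguments of connect are forwarded to loop.create_connection, a TLS connection can be created with a standard-library SSL context; a hedged sketch (host, port and context setup are illustrative):

import asyncio
import ssl

from hat.drivers import tcp


async def main():
    ctx = ssl.create_default_context()
    conn = await tcp.connect(tcp.Address('example.com', 443), ssl=ctx)
    try:
        # ssl_object is taken from the transport's 'ssl_object' extra info
        print(conn.ssl_object.cipher() if conn.ssl_object else None)
    finally:
        await conn.async_close()


asyncio.run(main())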
async def write(self, data: util.Bytes):
    """Write data

    This coroutine will wait until `data` can be added to output buffer.

    """
    if not self.is_open:
        raise ConnectionError()

    await self._protocol.write(data)
Write data
This coroutine will wait until data can be added to output buffer.
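A sketch of a backpressure-aware sender (the helper name is hypothetical): write queues data while the transport has paused writing, and drain returns once that queued data has been handed to the transport (immediately if writing was never paused):

from hat.drivers import tcp


async def send_all(conn: tcp.Connection, chunks: list[bytes]):
    for chunk in chunks:
        await conn.write(chunk)

    # no-op if writing was never paused; otherwise waits for resume_writing
    await conn.drain()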
async def read(self, n: int = -1) -> util.Bytes:
    """Read up to `n` bytes

    If EOF is detected and no new bytes are available, `ConnectionError`
    is raised.

    """
    return await self._protocol.read(n)
Read up to n bytes
If EOF is detected and no new bytes are available, ConnectionError
is raised.
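A sketch that reads until EOF, relying on read raising ConnectionError once the peer has closed and the input buffer is empty (the helper name is hypothetical):

from hat.drivers import tcp


async def read_until_eof(conn: tcp.Connection) -> bytes:
    chunks = []
    try:
        while True:
            chunks.append(await conn.read())
    except ConnectionError:
        return b''.join(chunks)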
async def readexactly(self, n: int) -> util.Bytes:
    """Read exactly `n` bytes

    If exact number of bytes could not be read, `ConnectionError` is
    raised.

    """
    return await self._protocol.readexactly(n)
Read exactly n bytes
If exact number of bytes could not be read, ConnectionError is
raised.
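readexactly is convenient for length-prefixed framing; a sketch assuming a hypothetical 2-byte big-endian length header:

from hat.drivers import tcp


async def read_frame(conn: tcp.Connection) -> bytes:
    header = await conn.readexactly(2)  # hypothetical frame header
    length = int.from_bytes(header, 'big')
    return await conn.readexactly(length)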
class Protocol(asyncio.Protocol):
    """Asyncio protocol implementation"""

    def __init__(self,
                 on_connected: typing.Callable[['Protocol'], None] | None,
                 name: str | None,
                 input_buffer_limit: int):
        self._on_connected = on_connected
        self._name = name
        self._input_buffer_limit = input_buffer_limit
        self._loop = asyncio.get_running_loop()
        self._input_buffer = util.BytesBuffer()
        self._transport = None
        self._read_queue = None
        self._write_queue = None
        self._drain_futures = None
        self._closed_futures = None
        self._info = None
        self._ssl_object = None
        self._log = mlog

    @property
    def info(self) -> ConnectionInfo:
        return self._info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        return self._ssl_object

    def connection_made(self, transport: asyncio.Transport):
        self._transport = transport
        self._read_queue = collections.deque()
        self._closed_futures = collections.deque()

        try:
            sockname = transport.get_extra_info('sockname')
            peername = transport.get_extra_info('peername')
            self._info = ConnectionInfo(
                name=self._name,
                local_addr=Address(sockname[0], sockname[1]),
                remote_addr=Address(peername[0], peername[1]))

            self._log = _create_connection_logger_adapter(self._info)

            self._ssl_object = transport.get_extra_info('ssl_object')

            if self._on_connected:
                self._on_connected(self)

        except Exception:
            transport.abort()
            return

    def connection_lost(self, exc: Exception | None):
        self._transport = None
        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None
        closed_futures, self._closed_futures = self._closed_futures, None

        self.eof_received()

        while write_queue:
            _, future = write_queue.popleft()
            if not future.done():
                future.set_exception(ConnectionError())

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

        while closed_futures:
            future = closed_futures.popleft()
            if not future.done():
                future.set_result(None)

        self._log.debug('connection closed')

    def pause_writing(self):
        self._log.debug('pause writing')

        self._write_queue = collections.deque()
        self._drain_futures = collections.deque()

    def resume_writing(self):
        self._log.debug('resume writing')

        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None

        while self._write_queue is None and write_queue:
            data, future = write_queue.popleft()
            if future.done():
                continue

            self._log.debug('writing %s bytes', len(data))
            self._transport.write(data)
            future.set_result(None)

        if write_queue:
            write_queue.extend(self._write_queue)
            self._write_queue = write_queue

            drain_futures.extend(self._drain_futures)
            self._drain_futures = drain_futures

            return

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

    def data_received(self, data: util.Bytes):
        self._log.debug('received %s bytes', len(data))

        self._input_buffer.add(data)
        self._process_input_buffer()

    def eof_received(self):
        self._log.debug('eof received')

        while self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if exact and n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            elif not exact and self._input_buffer:
                future.set_result(self._input_buffer.read(n))

            else:
                future.set_exception(ConnectionError())

        self._read_queue = None

    async def write(self, data: util.Bytes):
        if self._transport is None:
            raise ConnectionError()

        if self._write_queue is None:
            self._log.debug('writing %s bytes', len(data))
            self._transport.write(data)
            return

        future = self._loop.create_future()
        self._write_queue.append((data, future))
        await future

    async def drain(self):
        if self._drain_futures is None:
            return

        future = self._loop.create_future()
        self._drain_futures.append(future)
        await future

    async def read(self, n: int) -> util.Bytes:
        if n == 0:
            return b''

        if self._input_buffer and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((False, n, future))
        return await future

    async def readexactly(self, n: int) -> util.Bytes:
        if n == 0:
            return b''

        if n <= len(self._input_buffer) and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((True, n, future))
        self._process_input_buffer()
        return await future

    def clear_input_buffer(self) -> int:
        count = self._input_buffer.clear()
        self._transport.resume_reading()
        return count

    async def async_close(self):
        if self._transport is not None:
            self._transport.close()

        await self.wait_closed()

    async def wait_closed(self):
        if self._closed_futures is None:
            return

        future = self._loop.create_future()
        self._closed_futures.append(future)
        await future

    def _on_read_future_done(self, future):
        if not self._read_queue:
            return

        if not future.cancelled():
            return

        for _ in range(len(self._read_queue)):
            i = self._read_queue.popleft()
            if not i[2].done():
                self._read_queue.append(i)

        self._process_input_buffer()

    def _process_input_buffer(self):
        while self._input_buffer and self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if not exact:
                future.set_result(self._input_buffer.read(n))

            elif n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            else:
                self._read_queue.appendleft((exact, n, future))
                break

        if not self._transport:
            return

        pause = (self._input_buffer_limit > 0 and
                 len(self._input_buffer) > self._input_buffer_limit and
                 not self._read_queue)

        if pause:
            self._transport.pause_reading()

        else:
            self._transport.resume_reading()
Asyncio protocol implementation
def __init__(self,
             on_connected: typing.Callable[['Protocol'], None] | None,
             name: str | None,
             input_buffer_limit: int):
    self._on_connected = on_connected
    self._name = name
    self._input_buffer_limit = input_buffer_limit
    self._loop = asyncio.get_running_loop()
    self._input_buffer = util.BytesBuffer()
    self._transport = None
    self._read_queue = None
    self._write_queue = None
    self._drain_futures = None
    self._closed_futures = None
    self._info = None
    self._ssl_object = None
    self._log = mlog
def connection_made(self, transport: asyncio.Transport):
    self._transport = transport
    self._read_queue = collections.deque()
    self._closed_futures = collections.deque()

    try:
        sockname = transport.get_extra_info('sockname')
        peername = transport.get_extra_info('peername')
        self._info = ConnectionInfo(
            name=self._name,
            local_addr=Address(sockname[0], sockname[1]),
            remote_addr=Address(peername[0], peername[1]))

        self._log = _create_connection_logger_adapter(self._info)

        self._ssl_object = transport.get_extra_info('ssl_object')

        if self._on_connected:
            self._on_connected(self)

    except Exception:
        transport.abort()
        return
Called when a connection is made.
The argument is the transport representing the pipe connection. To receive data, wait for data_received() calls. When the connection is closed, connection_lost() is called.
def connection_lost(self, exc: Exception | None):
    self._transport = None
    write_queue, self._write_queue = self._write_queue, None
    drain_futures, self._drain_futures = self._drain_futures, None
    closed_futures, self._closed_futures = self._closed_futures, None

    self.eof_received()

    while write_queue:
        _, future = write_queue.popleft()
        if not future.done():
            future.set_exception(ConnectionError())

    while drain_futures:
        future = drain_futures.popleft()
        if not future.done():
            future.set_result(None)

    while closed_futures:
        future = closed_futures.popleft()
        if not future.done():
            future.set_result(None)

    self._log.debug('connection closed')
Called when the connection is lost or closed.
The argument is an exception object or None (the latter meaning a regular EOF is received or the connection was aborted or closed).
def pause_writing(self):
    self._log.debug('pause writing')

    self._write_queue = collections.deque()
    self._drain_futures = collections.deque()
Called when the transport's buffer goes over the high-water mark.
Pause and resume calls are paired -- pause_writing() is called once when the buffer goes strictly over the high-water mark (even if subsequent writes increase the buffer size even more), and eventually resume_writing() is called once when the buffer size reaches the low-water mark.
Note that if the buffer size equals the high-water mark, pause_writing() is not called -- it must go strictly over. Conversely, resume_writing() is called when the buffer size is equal or lower than the low-water mark. These end conditions are important to ensure that things go as expected when either mark is zero.
NOTE: This is the only Protocol callback that is not called through EventLoop.call_soon() -- if it were, it would have no effect when it's most needed (when the app keeps writing without yielding until pause_writing() is called).
def resume_writing(self):
    self._log.debug('resume writing')

    write_queue, self._write_queue = self._write_queue, None
    drain_futures, self._drain_futures = self._drain_futures, None

    while self._write_queue is None and write_queue:
        data, future = write_queue.popleft()
        if future.done():
            continue

        self._log.debug('writing %s bytes', len(data))
        self._transport.write(data)
        future.set_result(None)

    if write_queue:
        write_queue.extend(self._write_queue)
        self._write_queue = write_queue

        drain_futures.extend(self._drain_futures)
        self._drain_futures = drain_futures

        return

    while drain_futures:
        future = drain_futures.popleft()
        if not future.done():
            future.set_result(None)
Called when the transport's buffer drains below the low-water mark.
See pause_writing() for details.
def data_received(self, data: util.Bytes):
    self._log.debug('received %s bytes', len(data))

    self._input_buffer.add(data)
    self._process_input_buffer()
Called when some data is received.
The argument is a bytes object.
def eof_received(self):
    self._log.debug('eof received')

    while self._read_queue:
        exact, n, future = self._read_queue.popleft()
        if future.done():
            continue

        if exact and n <= len(self._input_buffer):
            future.set_result(self._input_buffer.read(n))

        elif not exact and self._input_buffer:
            future.set_result(self._input_buffer.read(n))

        else:
            future.set_exception(ConnectionError())

    self._read_queue = None
Called when the other end calls write_eof() or equivalent.
If this returns a false value (including None), the transport will close itself. If it returns a true value, closing the transport is up to the protocol.
async def write(self, data: util.Bytes):
    if self._transport is None:
        raise ConnectionError()

    if self._write_queue is None:
        self._log.debug('writing %s bytes', len(data))
        self._transport.write(data)
        return

    future = self._loop.create_future()
    self._write_queue.append((data, future))
    await future
async def read(self, n: int) -> util.Bytes:
    if n == 0:
        return b''

    if self._input_buffer and not self._read_queue:
        data = self._input_buffer.read(n)
        self._process_input_buffer()
        return data

    if self._read_queue is None:
        raise ConnectionError()

    future = self._loop.create_future()
    future.add_done_callback(self._on_read_future_done)
    self._read_queue.append((False, n, future))
    return await future
async def readexactly(self, n: int) -> util.Bytes:
    if n == 0:
        return b''

    if n <= len(self._input_buffer) and not self._read_queue:
        data = self._input_buffer.read(n)
        self._process_input_buffer()
        return data

    if self._read_queue is None:
        raise ConnectionError()

    future = self._loop.create_future()
    future.add_done_callback(self._on_read_future_done)
    self._read_queue.append((True, n, future))
    self._process_input_buffer()
    return await future