hat.drivers.tcp
Asyncio TCP wrapper
1"""Asyncio TCP wrapper""" 2 3import asyncio 4import collections 5import functools 6import logging 7import sys 8import typing 9 10from hat import aio 11from hat import util 12 13from hat.drivers import ssl 14 15 16mlog: logging.Logger = logging.getLogger(__name__) 17"""Module logger""" 18 19 20class Address(typing.NamedTuple): 21 host: str 22 port: int 23 24 25class ConnectionInfo(typing.NamedTuple): 26 local_addr: Address 27 remote_addr: Address 28 29 30ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None] 31"""Connection callback""" 32 33 34async def connect(addr: Address, 35 *, 36 input_buffer_limit: int = 64 * 1024, 37 **kwargs 38 ) -> 'Connection': 39 """Create TCP connection 40 41 Argument `addr` specifies remote server listening address. 42 43 Argument `input_buffer_limit` defines number of bytes in input buffer 44 that whill temporary pause data receiving. Once number of bytes 45 drops bellow `input_buffer_limit`, data receiving is resumed. If this 46 argument is ``0``, data receive pausing is disabled. 47 48 Additional arguments are passed directly to `asyncio.create_connection`. 49 50 """ 51 loop = asyncio.get_running_loop() 52 create_transport = functools.partial(Protocol, None, input_buffer_limit) 53 _, protocol = await loop.create_connection(create_transport, 54 addr.host, addr.port, 55 **kwargs) 56 return Connection(protocol) 57 58 59async def listen(connection_cb: ConnectionCb, 60 addr: Address, 61 *, 62 bind_connections: bool = False, 63 input_buffer_limit: int = 64 * 1024, 64 **kwargs 65 ) -> 'Server': 66 """Create listening server 67 68 If `bind_connections` is ``True``, closing server will close all open 69 incoming connections. 70 71 Argument `input_buffer_limit` is associated with newly created connections 72 (see `connect`). 73 74 Additional arguments are passed directly to `asyncio.create_server`. 75 76 """ 77 server = Server() 78 server._connection_cb = connection_cb 79 server._bind_connections = bind_connections 80 server._async_group = aio.Group() 81 82 on_connection = functools.partial(server.async_group.spawn, 83 server._on_connection) 84 create_transport = functools.partial(Protocol, on_connection, 85 input_buffer_limit) 86 87 loop = asyncio.get_running_loop() 88 server._srv = await loop.create_server(create_transport, addr.host, 89 addr.port, **kwargs) 90 91 server.async_group.spawn(aio.call_on_cancel, server._on_close) 92 93 try: 94 socknames = (socket.getsockname() for socket in server._srv.sockets) 95 server._addresses = [Address(*sockname[:2]) for sockname in socknames] 96 97 except Exception: 98 await aio.uncancellable(server.async_close()) 99 raise 100 101 return server 102 103 104class Server(aio.Resource): 105 """TCP listening server 106 107 Closing server will cancel all running `connection_cb` coroutines. 
108 109 """ 110 111 @property 112 def async_group(self) -> aio.Group: 113 """Async group""" 114 return self._async_group 115 116 @property 117 def addresses(self) -> list[Address]: 118 """Listening addresses""" 119 return self._addresses 120 121 async def _on_close(self): 122 self._srv.close() 123 124 if self._bind_connections or sys.version_info[:2] < (3, 12): 125 await self._srv.wait_closed() 126 127 async def _on_connection(self, protocol): 128 conn = Connection(protocol) 129 130 try: 131 await aio.call(self._connection_cb, conn) 132 133 if self._bind_connections: 134 await conn.wait_closing() 135 136 else: 137 conn = None 138 139 except Exception as e: 140 mlog.warning('connection callback error: %s', e, exc_info=e) 141 142 finally: 143 if conn: 144 await aio.uncancellable(conn.async_close()) 145 146 147class Connection(aio.Resource): 148 """TCP connection""" 149 150 def __init__(self, protocol: 'Protocol'): 151 self._protocol = protocol 152 self._async_group = aio.Group() 153 154 self.async_group.spawn(aio.call_on_cancel, protocol.async_close) 155 self.async_group.spawn(aio.call_on_done, protocol.wait_closed(), 156 self.close) 157 158 @property 159 def async_group(self) -> aio.Group: 160 """Async group""" 161 return self._async_group 162 163 @property 164 def info(self) -> ConnectionInfo: 165 """Connection info""" 166 return self._protocol.info 167 168 @property 169 def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None: 170 """SSL Object""" 171 return self._protocol.ssl_object 172 173 async def write(self, data: util.Bytes): 174 """Write data 175 176 This coroutine will wait until `data` can be added to output buffer. 177 178 """ 179 if not self.is_open: 180 raise ConnectionError() 181 182 await self._protocol.write(data) 183 184 async def drain(self): 185 """Drain output buffer""" 186 await self._protocol.drain() 187 188 async def read(self, n: int = -1) -> util.Bytes: 189 """Read up to `n` bytes 190 191 If EOF is detected and no new bytes are available, `ConnectionError` 192 is raised. 193 194 """ 195 return await self._protocol.read(n) 196 197 async def readexactly(self, n: int) -> util.Bytes: 198 """Read exactly `n` bytes 199 200 If exact number of bytes could not be read, `ConnectionError` is 201 raised. 202 203 """ 204 return await self._protocol.readexactly(n) 205 206 def reset_input_buffer(self) -> int: 207 """Reset input buffer 208 209 Returns number of bytes cleared from buffer. 
210 211 """ 212 return self._protocol.reset_input_buffer() 213 214 215class Protocol(asyncio.Protocol): 216 """Asyncio protocol implementation""" 217 218 def __init__(self, 219 on_connected: typing.Callable[['Protocol'], None] | None, 220 input_buffer_limit: int): 221 self._on_connected = on_connected 222 self._input_buffer_limit = input_buffer_limit 223 self._loop = asyncio.get_running_loop() 224 self._input_buffer = util.BytesBuffer() 225 self._transport = None 226 self._read_queue = None 227 self._write_queue = None 228 self._drain_futures = None 229 self._closed_futures = None 230 self._info = None 231 self._ssl_object = None 232 233 @property 234 def info(self) -> ConnectionInfo: 235 return self._info 236 237 @property 238 def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None: 239 return self._ssl_object 240 241 def connection_made(self, transport: asyncio.Transport): 242 self._transport = transport 243 self._read_queue = collections.deque() 244 self._closed_futures = collections.deque() 245 246 try: 247 sockname = transport.get_extra_info('sockname') 248 peername = transport.get_extra_info('peername') 249 self._info = ConnectionInfo( 250 local_addr=Address(sockname[0], sockname[1]), 251 remote_addr=Address(peername[0], peername[1])) 252 self._ssl_object = transport.get_extra_info('ssl_object') 253 254 if self._on_connected: 255 self._on_connected(self) 256 257 except Exception: 258 transport.abort() 259 return 260 261 def connection_lost(self, exc: Exception | None): 262 self._transport = None 263 write_queue, self._write_queue = self._write_queue, None 264 drain_futures, self._drain_futures = self._drain_futures, None 265 closed_futures, self._closed_futures = self._closed_futures, None 266 267 self.eof_received() 268 269 while write_queue: 270 _, future = write_queue.popleft() 271 if not future.done(): 272 future.set_exception(ConnectionError()) 273 274 while drain_futures: 275 future = drain_futures.popleft() 276 if not future.done(): 277 future.set_result(None) 278 279 while closed_futures: 280 future = closed_futures.popleft() 281 if not future.done(): 282 future.set_result(None) 283 284 def pause_writing(self): 285 self._write_queue = collections.deque() 286 self._drain_futures = collections.deque() 287 288 def resume_writing(self): 289 write_queue, self._write_queue = self._write_queue, None 290 drain_futures, self._drain_futures = self._drain_futures, None 291 292 while self._write_queue is None and write_queue: 293 data, future = write_queue.popleft() 294 if future.done(): 295 continue 296 297 self._transport.write(data) 298 future.set_result(None) 299 300 if write_queue: 301 write_queue.extend(self._write_queue) 302 self._write_queue = write_queue 303 304 drain_futures.extend(self._drain_futures) 305 self._drain_futures = drain_futures 306 307 return 308 309 while drain_futures: 310 future = drain_futures.popleft() 311 if not future.done(): 312 future.set_result(None) 313 314 def data_received(self, data: util.Bytes): 315 self._input_buffer.add(data) 316 self._process_input_buffer() 317 318 def eof_received(self): 319 while self._read_queue: 320 exact, n, future = self._read_queue.popleft() 321 if future.done(): 322 continue 323 324 if exact and n <= len(self._input_buffer): 325 future.set_result(self._input_buffer.read(n)) 326 327 elif not exact and self._input_buffer: 328 future.set_result(self._input_buffer.read(n)) 329 330 else: 331 future.set_exception(ConnectionError()) 332 333 self._read_queue = None 334 335 async def write(self, data: util.Bytes): 336 if 
self._transport is None: 337 raise ConnectionError() 338 339 if self._write_queue is None: 340 self._transport.write(data) 341 return 342 343 future = self._loop.create_future() 344 self._write_queue.append((data, future)) 345 await future 346 347 async def drain(self): 348 if self._drain_futures is None: 349 return 350 351 future = self._loop.create_future() 352 self._drain_futures.append(future) 353 await future 354 355 async def read(self, n: int) -> util.Bytes: 356 if n == 0: 357 return b'' 358 359 if self._input_buffer and not self._read_queue: 360 data = self._input_buffer.read(n) 361 self._process_input_buffer() 362 return data 363 364 if self._read_queue is None: 365 raise ConnectionError() 366 367 future = self._loop.create_future() 368 future.add_done_callback(self._on_read_future_done) 369 self._read_queue.append((False, n, future)) 370 return await future 371 372 async def readexactly(self, n: int) -> util.Bytes: 373 if n == 0: 374 return b'' 375 376 if n <= len(self._input_buffer) and not self._read_queue: 377 data = self._input_buffer.read(n) 378 self._process_input_buffer() 379 return data 380 381 if self._read_queue is None: 382 raise ConnectionError() 383 384 future = self._loop.create_future() 385 future.add_done_callback(self._on_read_future_done) 386 self._read_queue.append((True, n, future)) 387 self._process_input_buffer() 388 return await future 389 390 def reset_input_buffer(self) -> int: 391 count = self._input_buffer.clear() 392 self._transport.resume_reading() 393 return count 394 395 async def async_close(self): 396 if self._transport is not None: 397 self._transport.close() 398 399 await self.wait_closed() 400 401 async def wait_closed(self): 402 if self._closed_futures is None: 403 return 404 405 future = self._loop.create_future() 406 self._closed_futures.append(future) 407 await future 408 409 def _on_read_future_done(self, future): 410 if not self._read_queue: 411 return 412 413 if not future.cancelled(): 414 return 415 416 for _ in range(len(self._read_queue)): 417 i = self._read_queue.popleft() 418 if not i[2].done(): 419 self._read_queue.append(i) 420 421 self._process_input_buffer() 422 423 def _process_input_buffer(self): 424 while self._input_buffer and self._read_queue: 425 exact, n, future = self._read_queue.popleft() 426 if future.done(): 427 continue 428 429 if not exact: 430 future.set_result(self._input_buffer.read(n)) 431 432 elif n <= len(self._input_buffer): 433 future.set_result(self._input_buffer.read(n)) 434 435 else: 436 self._read_queue.appendleft((exact, n, future)) 437 break 438 439 if not self._transport: 440 return 441 442 pause = (self._input_buffer_limit > 0 and 443 len(self._input_buffer) > self._input_buffer_limit and 444 not self._read_queue) 445 446 if pause: 447 self._transport.pause_reading() 448 449 else: 450 self._transport.resume_reading()
mlog: logging.Logger = logging.getLogger(__name__)

Module logger
class Address(typing.NamedTuple):
    host: str
    port: int

Address(host, port)
class ConnectionInfo(typing.NamedTuple):
    local_addr: Address
    remote_addr: Address

ConnectionInfo(local_addr, remote_addr)
ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]

Connection callback
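For illustration, a coroutine matching the `ConnectionCb` shape could look like the following sketch (the handler body is hypothetical):

from hat.drivers import tcp


async def on_connection(conn: tcp.Connection):
    # hypothetical handler: log the peer address and echo one chunk of input
    print('new connection from', conn.info.remote_addr)
    data = await conn.read()
    await conn.write(data)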
async def connect(addr: Address,
                  *,
                  input_buffer_limit: int = 64 * 1024,
                  **kwargs
                  ) -> 'Connection':
    """Create TCP connection

    Argument `addr` specifies remote server listening address.

    Argument `input_buffer_limit` defines number of bytes in input buffer
    that will temporarily pause data receiving. Once number of bytes
    drops below `input_buffer_limit`, data receiving is resumed. If this
    argument is ``0``, data receive pausing is disabled.

    Additional arguments are passed directly to `loop.create_connection`.

    """
    loop = asyncio.get_running_loop()
    create_transport = functools.partial(Protocol, None, input_buffer_limit)
    _, protocol = await loop.create_connection(create_transport,
                                               addr.host, addr.port,
                                               **kwargs)
    return Connection(protocol)
Create TCP connection

Argument `addr` specifies the remote server's listening address.

Argument `input_buffer_limit` defines the number of bytes in the input buffer
that will temporarily pause data receiving. Once the number of bytes drops
below `input_buffer_limit`, data receiving is resumed. If this argument is
``0``, pausing of data receiving is disabled.

Additional arguments are passed directly to asyncio's `loop.create_connection`.
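A usage sketch (the address is illustrative and assumes a server is already listening there):

import asyncio

from hat.drivers import tcp


async def main():
    # connect to a hypothetical local server
    conn = await tcp.connect(tcp.Address('127.0.0.1', 12345))
    try:
        await conn.write(b'ping')
        data = await conn.read()
        print('received', bytes(data))
    finally:
        await conn.async_close()


asyncio.run(main())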
async def listen(connection_cb: ConnectionCb,
                 addr: Address,
                 *,
                 bind_connections: bool = False,
                 input_buffer_limit: int = 64 * 1024,
                 **kwargs
                 ) -> 'Server':
    """Create listening server

    If `bind_connections` is ``True``, closing server will close all open
    incoming connections.

    Argument `input_buffer_limit` is associated with newly created connections
    (see `connect`).

    Additional arguments are passed directly to `loop.create_server`.

    """
    server = Server()
    server._connection_cb = connection_cb
    server._bind_connections = bind_connections
    server._async_group = aio.Group()

    on_connection = functools.partial(server.async_group.spawn,
                                      server._on_connection)
    create_transport = functools.partial(Protocol, on_connection,
                                         input_buffer_limit)

    loop = asyncio.get_running_loop()
    server._srv = await loop.create_server(create_transport, addr.host,
                                           addr.port, **kwargs)

    server.async_group.spawn(aio.call_on_cancel, server._on_close)

    try:
        socknames = (socket.getsockname() for socket in server._srv.sockets)
        server._addresses = [Address(*sockname[:2]) for sockname in socknames]

    except Exception:
        await aio.uncancellable(server.async_close())
        raise

    return server
Create listening server

If `bind_connections` is ``True``, closing the server will close all open
incoming connections.

Argument `input_buffer_limit` is associated with newly created connections
(see `connect`).

Additional arguments are passed directly to asyncio's `loop.create_server`.
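A minimal listening sketch (the port and echo handler are illustrative):

import asyncio

from hat.drivers import tcp


async def echo_cb(conn: tcp.Connection):
    # echo everything received until the peer closes the connection
    try:
        while True:
            data = await conn.read()
            await conn.write(data)
    except ConnectionError:
        pass


async def main():
    server = await tcp.listen(echo_cb, tcp.Address('0.0.0.0', 12345),
                              bind_connections=True)
    try:
        await server.wait_closing()
    finally:
        await server.async_close()


asyncio.run(main())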
class Server(aio.Resource):
    """TCP listening server

    Closing server will cancel all running `connection_cb` coroutines.

    """

    @property
    def async_group(self) -> aio.Group:
        """Async group"""
        return self._async_group

    @property
    def addresses(self) -> list[Address]:
        """Listening addresses"""
        return self._addresses

    async def _on_close(self):
        self._srv.close()

        if self._bind_connections or sys.version_info[:2] < (3, 12):
            await self._srv.wait_closed()

    async def _on_connection(self, protocol):
        conn = Connection(protocol)

        try:
            await aio.call(self._connection_cb, conn)

            if self._bind_connections:
                await conn.wait_closing()

            else:
                conn = None

        except Exception as e:
            mlog.warning('connection callback error: %s', e, exc_info=e)

        finally:
            if conn:
                await aio.uncancellable(conn.async_close())
TCP listening server

Closing the server will cancel all running `connection_cb` coroutines.
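A sketch of listening on port ``0`` so the OS picks a free port, which can then be read back from `addresses` (the handler is hypothetical):

import asyncio

from hat.drivers import tcp


async def on_connection(conn: tcp.Connection):
    # hypothetical handler: close every incoming connection immediately
    await conn.async_close()


async def main():
    server = await tcp.listen(on_connection, tcp.Address('127.0.0.1', 0))
    print('listening on', server.addresses)
    await server.async_close()


asyncio.run(main())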
class Connection(aio.Resource):
    """TCP connection"""

    def __init__(self, protocol: 'Protocol'):
        self._protocol = protocol
        self._async_group = aio.Group()

        self.async_group.spawn(aio.call_on_cancel, protocol.async_close)
        self.async_group.spawn(aio.call_on_done, protocol.wait_closed(),
                               self.close)

    @property
    def async_group(self) -> aio.Group:
        """Async group"""
        return self._async_group

    @property
    def info(self) -> ConnectionInfo:
        """Connection info"""
        return self._protocol.info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        """SSL Object"""
        return self._protocol.ssl_object

    async def write(self, data: util.Bytes):
        """Write data

        This coroutine will wait until `data` can be added to output buffer.

        """
        if not self.is_open:
            raise ConnectionError()

        await self._protocol.write(data)

    async def drain(self):
        """Drain output buffer"""
        await self._protocol.drain()

    async def read(self, n: int = -1) -> util.Bytes:
        """Read up to `n` bytes

        If EOF is detected and no new bytes are available, `ConnectionError`
        is raised.

        """
        return await self._protocol.read(n)

    async def readexactly(self, n: int) -> util.Bytes:
        """Read exactly `n` bytes

        If exact number of bytes could not be read, `ConnectionError` is
        raised.

        """
        return await self._protocol.readexactly(n)

    def reset_input_buffer(self) -> int:
        """Reset input buffer

        Returns number of bytes cleared from buffer.

        """
        return self._protocol.reset_input_buffer()
TCP connection
@property
def async_group(self) -> aio.Group
Async group
@property
def info(self) -> ConnectionInfo
Connection info
@property
def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None
SSL Object
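Because additional `connect` arguments are forwarded to `loop.create_connection`, a TLS connection can be established by passing an SSL context; `ssl_object` then exposes the negotiated TLS state. A sketch using the standard library `ssl` module (host and port are illustrative):

import asyncio
import ssl

from hat.drivers import tcp


async def main():
    ctx = ssl.create_default_context()
    conn = await tcp.connect(tcp.Address('example.com', 443),
                             ssl=ctx, server_hostname='example.com')
    # print the negotiated TLS protocol version
    print('negotiated', conn.ssl_object.version())
    await conn.async_close()


asyncio.run(main())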
async def write(self, data: util.Bytes)
Write data

This coroutine will wait until `data` can be added to the output buffer.
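As a small sketch, `write` followed by `drain` waits until the queued data has been handed to the transport before continuing:

from hat.drivers import tcp


async def send_all(conn: tcp.Connection, chunks: list[bytes]):
    # queue every chunk, then wait for the output buffer to drain
    for chunk in chunks:
        await conn.write(chunk)
    await conn.drain()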
async def read(self, n: int = -1) -> util.Bytes
Read up to `n` bytes

If EOF is detected and no new bytes are available, `ConnectionError` is
raised.
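A typical receive loop (sketch) reads until EOF is signalled by `ConnectionError`:

from hat.drivers import tcp


async def receive_all(conn: tcp.Connection) -> bytes:
    # collect chunks until the peer closes the connection
    chunks = []
    try:
        while True:
            chunks.append(bytes(await conn.read()))
    except ConnectionError:
        pass
    return b''.join(chunks)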
async def readexactly(self, n: int) -> util.Bytes
Read exactly `n` bytes

If the exact number of bytes could not be read, `ConnectionError` is raised.
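`readexactly` is convenient for length-prefixed framing. A sketch assuming a 4-byte big-endian length header (the wire format is illustrative, not defined by this module):

from hat.drivers import tcp


async def read_frame(conn: tcp.Connection) -> bytes:
    # read the 4-byte length header, then the payload of that length
    header = await conn.readexactly(4)
    length = int.from_bytes(header, 'big')
    return bytes(await conn.readexactly(length))


async def write_frame(conn: tcp.Connection, payload: bytes):
    await conn.write(len(payload).to_bytes(4, 'big') + payload)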
class Protocol(asyncio.Protocol):
    """Asyncio protocol implementation"""

    def __init__(self,
                 on_connected: typing.Callable[['Protocol'], None] | None,
                 input_buffer_limit: int):
        self._on_connected = on_connected
        self._input_buffer_limit = input_buffer_limit
        self._loop = asyncio.get_running_loop()
        self._input_buffer = util.BytesBuffer()
        self._transport = None
        self._read_queue = None
        self._write_queue = None
        self._drain_futures = None
        self._closed_futures = None
        self._info = None
        self._ssl_object = None

    @property
    def info(self) -> ConnectionInfo:
        return self._info

    @property
    def ssl_object(self) -> ssl.SSLObject | ssl.SSLSocket | None:
        return self._ssl_object

    def connection_made(self, transport: asyncio.Transport):
        self._transport = transport
        self._read_queue = collections.deque()
        self._closed_futures = collections.deque()

        try:
            sockname = transport.get_extra_info('sockname')
            peername = transport.get_extra_info('peername')
            self._info = ConnectionInfo(
                local_addr=Address(sockname[0], sockname[1]),
                remote_addr=Address(peername[0], peername[1]))
            self._ssl_object = transport.get_extra_info('ssl_object')

            if self._on_connected:
                self._on_connected(self)

        except Exception:
            transport.abort()
            return

    def connection_lost(self, exc: Exception | None):
        self._transport = None
        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None
        closed_futures, self._closed_futures = self._closed_futures, None

        # serve pending reads from buffered data and fail the rest
        self.eof_received()

        while write_queue:
            _, future = write_queue.popleft()
            if not future.done():
                future.set_exception(ConnectionError())

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

        while closed_futures:
            future = closed_futures.popleft()
            if not future.done():
                future.set_result(None)

    def pause_writing(self):
        # start queueing writes until resume_writing is called
        self._write_queue = collections.deque()
        self._drain_futures = collections.deque()

    def resume_writing(self):
        write_queue, self._write_queue = self._write_queue, None
        drain_futures, self._drain_futures = self._drain_futures, None

        while self._write_queue is None and write_queue:
            data, future = write_queue.popleft()
            if future.done():
                continue

            self._transport.write(data)
            future.set_result(None)

        if write_queue:
            # writing was paused again while flushing; keep remaining writes queued
            write_queue.extend(self._write_queue)
            self._write_queue = write_queue

            drain_futures.extend(self._drain_futures)
            self._drain_futures = drain_futures

            return

        while drain_futures:
            future = drain_futures.popleft()
            if not future.done():
                future.set_result(None)

    def data_received(self, data: util.Bytes):
        self._input_buffer.add(data)
        self._process_input_buffer()

    def eof_received(self):
        while self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if exact and n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            elif not exact and self._input_buffer:
                future.set_result(self._input_buffer.read(n))

            else:
                future.set_exception(ConnectionError())

        self._read_queue = None

    async def write(self, data: util.Bytes):
        if self._transport is None:
            raise ConnectionError()

        if self._write_queue is None:
            self._transport.write(data)
            return

        future = self._loop.create_future()
        self._write_queue.append((data, future))
        await future

    async def drain(self):
        if self._drain_futures is None:
            return

        future = self._loop.create_future()
        self._drain_futures.append(future)
        await future

    async def read(self, n: int) -> util.Bytes:
        if n == 0:
            return b''

        if self._input_buffer and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((False, n, future))
        return await future

    async def readexactly(self, n: int) -> util.Bytes:
        if n == 0:
            return b''

        if n <= len(self._input_buffer) and not self._read_queue:
            data = self._input_buffer.read(n)
            self._process_input_buffer()
            return data

        if self._read_queue is None:
            raise ConnectionError()

        future = self._loop.create_future()
        future.add_done_callback(self._on_read_future_done)
        self._read_queue.append((True, n, future))
        self._process_input_buffer()
        return await future

    def reset_input_buffer(self) -> int:
        count = self._input_buffer.clear()
        self._transport.resume_reading()
        return count

    async def async_close(self):
        if self._transport is not None:
            self._transport.close()

        await self.wait_closed()

    async def wait_closed(self):
        if self._closed_futures is None:
            return

        future = self._loop.create_future()
        self._closed_futures.append(future)
        await future

    def _on_read_future_done(self, future):
        if not self._read_queue:
            return

        if not future.cancelled():
            return

        # remove cancelled read requests, preserving the order of the rest
        for _ in range(len(self._read_queue)):
            i = self._read_queue.popleft()
            if not i[2].done():
                self._read_queue.append(i)

        self._process_input_buffer()

    def _process_input_buffer(self):
        while self._input_buffer and self._read_queue:
            exact, n, future = self._read_queue.popleft()
            if future.done():
                continue

            if not exact:
                future.set_result(self._input_buffer.read(n))

            elif n <= len(self._input_buffer):
                future.set_result(self._input_buffer.read(n))

            else:
                self._read_queue.appendleft((exact, n, future))
                break

        if not self._transport:
            return

        # pause reading once the buffer exceeds the limit and no reads are pending
        pause = (self._input_buffer_limit > 0 and
                 len(self._input_buffer) > self._input_buffer_limit and
                 not self._read_queue)

        if pause:
            self._transport.pause_reading()

        else:
            self._transport.resume_reading()
Asyncio protocol implementation
def connection_made(self, transport: asyncio.Transport)
Called when a connection is made.
The argument is the transport representing the pipe connection. To receive data, wait for data_received() calls. When the connection is closed, connection_lost() is called.
def connection_lost(self, exc: Exception | None)
Called when the connection is lost or closed.
The argument is an exception object or None (the latter meaning a regular EOF is received or the connection was aborted or closed).
def pause_writing(self)
Called when the transport's buffer goes over the high-water mark.
Pause and resume calls are paired -- pause_writing() is called once when the buffer goes strictly over the high-water mark (even if subsequent writes increase the buffer size even more), and eventually resume_writing() is called once when the buffer size reaches the low-water mark.
Note that if the buffer size equals the high-water mark, pause_writing() is not called -- it must go strictly over. Conversely, resume_writing() is called when the buffer size is equal or lower than the low-water mark. These end conditions are important to ensure that things go as expected when either mark is zero.
NOTE: This is the only Protocol callback that is not called through EventLoop.call_soon() -- if it were, it would have no effect when it's most needed (when the app keeps writing without yielding until pause_writing() is called).
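For context, the water marks that trigger these callbacks can be tuned on the asyncio transport itself. A generic sketch, independent of this module, that only logs the flow-control events:

import asyncio


class LoggingProtocol(asyncio.Protocol):
    # minimal protocol that only reports write flow-control events

    def connection_made(self, transport):
        # pause_writing() fires once buffered bytes exceed 16 KiB,
        # resume_writing() once they fall back to 4 KiB or below
        transport.set_write_buffer_limits(high=16 * 1024, low=4 * 1024)

    def pause_writing(self):
        print('output buffer above high-water mark')

    def resume_writing(self):
        print('output buffer back below low-water mark')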
def resume_writing(self)
Called when the transport's buffer drains below the low-water mark.
See pause_writing() for details.
def data_received(self, data: util.Bytes)
Called when some data is received.
The argument is a bytes object.
def eof_received(self)
Called when the other end calls write_eof() or equivalent.
If this returns a false value (including None), the transport will close itself. If it returns a true value, closing the transport is up to the protocol.
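Finally, a small end-to-end sketch tying `listen`, `connect`, `readexactly` and `write` together (the address and wire format are illustrative):

import asyncio

from hat.drivers import tcp


async def echo_cb(conn: tcp.Connection):
    # echo a single length-prefixed message, then let the client close
    header = await conn.readexactly(4)
    payload = await conn.readexactly(int.from_bytes(header, 'big'))
    await conn.write(bytes(header) + bytes(payload))


async def main():
    server = await tcp.listen(echo_cb, tcp.Address('127.0.0.1', 0),
                              bind_connections=True)
    conn = await tcp.connect(server.addresses[0])

    msg = b'hello'
    await conn.write(len(msg).to_bytes(4, 'big') + msg)
    print(bytes(await conn.readexactly(4 + len(msg))))

    await conn.async_close()
    await server.async_close()


asyncio.run(main())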