hat.drivers.icmp
class
EndpointInfo(typing.NamedTuple):
EndpointInfo(name, local_host)
async def create_endpoint(local_host: str = '0.0.0.0',
                          *,
                          name: str | None = None
                          ) -> 'Endpoint':
    """Create and start an ICMP endpoint.

    `local_host` is resolved with `_get_host_addr` and a socket is created
    for the resolved address. The endpoint's receive loop is spawned in its
    async group, CLOSE logging is registered through `aio.call_on_cancel`,
    and an OPEN event is written to the communication log.

    Args:
        local_host: local host name or address
        name: optional endpoint name, exposed through `Endpoint.info`

    Returns:
        the newly created endpoint
    """
    running_loop = asyncio.get_running_loop()
    resolved_addr = await _get_host_addr(running_loop, local_host)

    ep = Endpoint()
    group = aio.Group()
    ep._async_group = group
    ep._loop = running_loop
    ep._echo_data = _echo_data_iter()
    ep._echo_futures = {}

    ep_info = common.EndpointInfo(name=name, local_host=resolved_addr[0])
    ep._info = ep_info

    ep._log = logger.create_logger(mlog, ep_info)
    comm_log = logger.CommunicationLogger(mlog, ep_info)
    ep._comm_log = comm_log

    ep._socket = _create_socket(resolved_addr)

    group.spawn(ep._receive_loop)
    group.spawn(aio.call_on_cancel, comm_log.log,
                common.CommLogAction.CLOSE)
    comm_log.log(common.CommLogAction.OPEN)

    return ep
class
Endpoint(hat.aio.group.Resource):
class Endpoint(aio.Resource):
    """ICMP echo (ping) endpoint.

    Holds a socket and a receive loop that decodes incoming messages and
    resolves pending `ping` calls. Echo replies are matched to requests by
    their data payload only (identifier and sequence number are not yet
    checked).

    The private attributes used here (`_async_group`, `_loop`, `_echo_data`,
    `_echo_futures`, `_info`, `_log`, `_comm_log`, `_socket`) are set up by
    `create_endpoint`.
    """

    @property
    def async_group(self) -> aio.Group:
        """Async group controlling the endpoint's lifetime."""
        return self._async_group

    @property
    def info(self) -> common.EndpointInfo:
        """Endpoint information (name and local host address)."""
        return self._info

    async def ping(self, remote_host: str):
        """Send an ICMP echo request and wait for the matching reply.

        Args:
            remote_host: remote host name or address, resolved with
                `_get_host_addr`

        Raises:
            ConnectionError: if the endpoint is not open before sending, or
                if the receive loop terminates (e.g. the endpoint is closed)
                while the reply is pending

        No timeout is applied; wrap the call (e.g. with `asyncio.wait_for`)
        if one is needed.
        """
        if not self.is_open:
            raise ConnectionError()

        remote_addr = await _get_host_addr(self._loop, remote_host)

        # the endpoint may have been closed while the address was resolved
        if not self.is_open:
            raise ConnectionError()

        data = next(self._echo_data)

        # on linux, echo message identifier is changed to
        # `self._socket.getsockname()[1]`
        req = common.EchoMsg(is_reply=False,
                             identifier=1,
                             sequence_number=1,
                             data=data)
        req_bytes = encoder.encode_msg(req)

        future = self._loop.create_future()

        try:
            # register before sending so that neither a fast reply nor the
            # receive loop's shutdown can be missed
            self._echo_futures[data] = future

            self._comm_log.log(common.CommLogAction.SEND, req)

            if sys.version_info[:2] >= (3, 11):
                await self._loop.sock_sendto(self._socket, req_bytes,
                                             remote_addr)

            else:
                # loop.sock_sendto is only available since Python 3.11
                self._socket.sendto(req_bytes, remote_addr)

            await future

        finally:
            self._echo_futures.pop(data)

    async def _receive_loop(self):
        """Receive and dispatch incoming messages until an error or close.

        On exit the endpoint is closed, every still-pending ping future is
        failed with `ConnectionError`, and the socket is closed.
        """
        try:
            while True:
                msg_bytes = await self._loop.sock_recv(self._socket, 1024)

                try:
                    msg = encoder.decode_msg(memoryview(msg_bytes))

                except Exception as e:
                    # undecodable packets are logged and skipped, not fatal
                    self._log.warning("error decoding message: %s",
                                      e, exc_info=e)
                    continue

                self._comm_log.log(common.CommLogAction.RECEIVE, msg)

                if isinstance(msg, common.EchoMsg):
                    self._process_echo_msg(msg)

        except Exception as e:
            self._log.error("receive loop error: %s", e, exc_info=e)

        finally:
            self.close()

            for future in self._echo_futures.values():
                if not future.done():
                    future.set_exception(ConnectionError())

            self._socket.close()

    def _process_echo_msg(self, msg):
        """Resolve the pending ping whose request data equals `msg.data`.

        Echo requests and replies with no matching pending ping are ignored.
        """
        if not msg.is_reply:
            return

        # TODO check identifier and sequence number

        data = bytes(msg.data)

        future = self._echo_futures.get(data)
        # unknown/late reply, or the ping was already resolved
        if future is None or future.done():
            return

        future.set_result(None)
Resource with lifetime control based on Group.
async def
ping(self, remote_host: str):
async def ping(self, remote_host: str):
    """Send an ICMP echo request and wait for the matching reply.

    Args:
        remote_host: remote host name or address, resolved with
            `_get_host_addr`

    Raises:
        ConnectionError: if the endpoint is not open before sending, or if
            the pending reply future is failed with it (the receive loop
            does this when it terminates)

    No timeout is applied; wrap the call (e.g. with `asyncio.wait_for`) if
    one is needed.
    """
    if not self.is_open:
        raise ConnectionError()

    remote_addr = await _get_host_addr(self._loop, remote_host)

    # the endpoint may have been closed while the address was resolved
    if not self.is_open:
        raise ConnectionError()

    data = next(self._echo_data)

    # on linux, echo message identifier is changed to
    # `self._socket.getsockname()[1]`
    req = common.EchoMsg(is_reply=False,
                         identifier=1,
                         sequence_number=1,
                         data=data)
    req_bytes = encoder.encode_msg(req)

    future = self._loop.create_future()

    try:
        # register before sending so a fast reply cannot be missed
        self._echo_futures[data] = future

        self._comm_log.log(common.CommLogAction.SEND, req)

        if sys.version_info[:2] >= (3, 11):
            await self._loop.sock_sendto(self._socket, req_bytes,
                                         remote_addr)

        else:
            # loop.sock_sendto is only available since Python 3.11
            self._socket.sendto(req_bytes, remote_addr)

        await future

    finally:
        self._echo_futures.pop(data)