hat.drivers.tpkt

Transport Service on top of TCP

  1"""Transport Service on top of TCP"""
  2
  3import asyncio
  4import itertools
  5import logging
  6import typing
  7
  8from hat import aio
  9from hat import util
 10
 11from hat.drivers import tcp
 12
 13
 14mlog: logging.Logger = logging.getLogger(__name__)
 15"""Module logger"""
 16
 17ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]
 18"""Connection callback"""
 19
 20
 21async def connect(addr: tcp.Address,
 22                  *,
 23                  tpkt_receive_queue_size: int = 1024,
 24                  **kwargs
 25                  ) -> 'Connection':
 26    """Create new TPKT connection
 27
 28    Additional arguments are passed directly to `hat.drivers.tcp.connect`.
 29
 30    """
 31    conn = await tcp.connect(addr, **kwargs)
 32    return Connection(conn=conn,
 33                      receive_queue_size=tpkt_receive_queue_size)
 34
 35
 36async def listen(connection_cb: ConnectionCb,
 37                 addr: tcp.Address = tcp.Address('0.0.0.0', 102),
 38                 *,
 39                 tpkt_receive_queue_size: int = 1024,
 40                 **kwargs
 41                 ) -> 'Server':
 42    """Create new TPKT listening server
 43
 44    Additional arguments are passed directly to `hat.drivers.tcp.listen`.
 45
 46    """
 47    server = Server()
 48    server._connection_cb = connection_cb
 49    server._receive_queue_size = tpkt_receive_queue_size
 50    server._log = mlog
 51
 52    server._srv = await tcp.listen(server._on_connection, addr, **kwargs)
 53
 54    server._log = _create_server_logger_adapter(server._srv.info)
 55
 56    return server
 57
 58
 59class Server(aio.Resource):
 60    """TPKT listening server
 61
 62    For creation of new instance see `listen` coroutine.
 63
 64    """
 65
 66    @property
 67    def async_group(self) -> aio.Group:
 68        """Async group"""
 69        return self._srv.async_group
 70
 71    @property
 72    def info(self) -> tcp.ServerInfo:
 73        """Server info"""
 74        return self._srv.info
 75
 76    async def _on_connection(self, conn):
 77        try:
 78            conn = Connection(conn=conn,
 79                              receive_queue_size=self._receive_queue_size)
 80            await aio.call(self._connection_cb, conn)
 81
 82        except Exception as e:
 83            self._log.warning('connection callback error: %s', e, exc_info=e)
 84            await aio.uncancellable(conn.async_close())
 85
 86        except asyncio.CancelledError:
 87            await aio.uncancellable(conn.async_close())
 88            raise
 89
 90
 91class Connection(aio.Resource):
 92    """TPKT connection"""
 93
 94    def __init__(self,
 95                 conn: tcp.Connection,
 96                 receive_queue_size: int):
 97        self._conn = conn
 98        self._receive_queue = aio.Queue(receive_queue_size)
 99        self._log = _create_connection_logger_adapter(False, conn.info)
100        self._comm_log = _create_connection_logger_adapter(True, conn.info)
101
102        self.async_group.spawn(self._read_loop)
103
104        self._comm_log.debug('connection established')
105
106    @property
107    def async_group(self) -> aio.Group:
108        """Async group"""
109        return self._conn.async_group
110
111    @property
112    def info(self) -> tcp.ConnectionInfo:
113        """Connection info"""
114        return self._conn.info
115
116    async def receive(self) -> util.Bytes:
117        """Receive data"""
118        try:
119            return await self._receive_queue.get()
120
121        except aio.QueueClosedError:
122            raise ConnectionError()
123
124    async def send(self, data: util.Bytes):
125        """Send data"""
126        data_len = len(data)
127
128        if data_len > 0xFFFB:
129            raise ValueError("data length greater than 0xFFFB")
130
131        if data_len < 3:
132            raise ValueError("data length less than 3")
133
134        packet_length = data_len + 4
135        packet = bytes(itertools.chain(
136            [3, 0, packet_length >> 8, packet_length & 0xFF],
137            data))
138
139        if self._comm_log.isEnabledFor(logging.DEBUG):
140            self._log.debug('sending %s', data.hex(' '))
141
142        await self._conn.write(packet)
143
144    async def drain(self):
145        """Drain output buffer"""
146        await self._conn.drain()
147
148    async def _read_loop(self):
149        self._log.debug('starting read loop')
150
151        try:
152            while True:
153                header = await self._conn.readexactly(4)
154                if header[0] != 3:
155                    raise Exception(f"invalid vrsn number "
156                                    f"(received {header[0]})")
157
158                packet_length = (header[2] << 8) | header[3]
159                if packet_length < 7:
160                    raise Exception(f"invalid packet length "
161                                    f"(received {packet_length})")
162
163                data_length = packet_length - 4
164                data = await self._conn.readexactly(data_length)
165
166                if self._comm_log.isEnabledFor(logging.DEBUG):
167                    self._comm_log.debug('received %s', data.hex(' '))
168
169                await self._receive_queue.put(data)
170
171        except ConnectionError:
172            pass
173
174        except Exception as e:
175            self._log.warning("read loop error: %s", e, exc_info=e)
176
177        finally:
178            self._log.debug('stopping read loop')
179
180            self.close()
181            self._receive_queue.close()
182
183            self._comm_log.debug('connection closed')
184
185
def _create_server_logger_adapter(info):
    # Wrap the module logger so that records from a server carry its
    # name and listening addresses under the 'meta' extra key.
    meta = {'type': 'TpktServer', 'name': info.name}
    meta['addresses'] = [{'host': addr.host, 'port': addr.port}
                         for addr in info.addresses]
    return logging.LoggerAdapter(mlog, {'meta': meta})
194
195
def _create_connection_logger_adapter(communication, info):
    # Wrap the module logger so that records from a connection carry its
    # name, both endpoints and the `communication` flag under the 'meta'
    # extra key.
    def as_dict(addr):
        return {'host': addr.host, 'port': addr.port}

    meta = {'type': 'TpktConnection',
            'communication': communication,
            'name': info.name,
            'local_addr': as_dict(info.local_addr),
            'remote_addr': as_dict(info.remote_addr)}
    return logging.LoggerAdapter(mlog, {'meta': meta})
mlog: logging.Logger = <Logger hat.drivers.tpkt (WARNING)>

Module logger

ConnectionCb: TypeAlias = Callable[[ForwardRef('Connection')], None | Awaitable[None]]

Connection callback

async def connect( addr: hat.drivers.tcp.Address, *, tpkt_receive_queue_size: int = 1024, **kwargs) -> Connection:
async def connect(addr: tcp.Address,
                  *,
                  tpkt_receive_queue_size: int = 1024,
                  **kwargs
                  ) -> 'Connection':
    """Create new TPKT connection

    Opens a TCP connection to `addr` and wraps it in a `Connection`
    whose receive queue is created with size `tpkt_receive_queue_size`.

    Additional arguments are passed directly to `hat.drivers.tcp.connect`.

    """
    tcp_conn = await tcp.connect(addr, **kwargs)
    tpkt_conn = Connection(conn=tcp_conn,
                           receive_queue_size=tpkt_receive_queue_size)
    return tpkt_conn

Create new TPKT connection

Additional arguments are passed directly to hat.drivers.tcp.connect.

async def listen( connection_cb: Callable[[Connection], None | Awaitable[None]], addr: hat.drivers.tcp.Address = Address(host='0.0.0.0', port=102), *, tpkt_receive_queue_size: int = 1024, **kwargs) -> Server:
async def listen(connection_cb: ConnectionCb,
                 addr: tcp.Address = tcp.Address('0.0.0.0', 102),
                 *,
                 tpkt_receive_queue_size: int = 1024,
                 **kwargs
                 ) -> 'Server':
    """Create new TPKT listening server

    Each connection handed over by the underlying TCP server is wrapped
    in a `Connection` and passed to `connection_cb`.

    Additional arguments are passed directly to `hat.drivers.tcp.listen`.

    """
    tpkt_server = Server()
    tpkt_server._connection_cb = connection_cb
    tpkt_server._receive_queue_size = tpkt_receive_queue_size
    # the TCP server info needed by the logger adapter does not exist
    # yet, so use the plain module logger until listening has started
    tpkt_server._log = mlog

    tcp_server = await tcp.listen(tpkt_server._on_connection, addr, **kwargs)
    tpkt_server._srv = tcp_server

    tpkt_server._log = _create_server_logger_adapter(tcp_server.info)

    return tpkt_server

Create new TPKT listening server

Additional arguments are passed directly to hat.drivers.tcp.listen.

class Server(hat.aio.group.Resource):
class Server(aio.Resource):
    """TPKT listening server

    Do not instantiate directly - use the `listen` coroutine, which
    attaches the underlying TCP server and the connection callback.

    """

    @property
    def async_group(self) -> aio.Group:
        """Async group

        Delegates to the async group of the underlying TCP server.

        """
        tcp_server = self._srv
        return tcp_server.async_group

    @property
    def info(self) -> tcp.ServerInfo:
        """Server info

        Delegates to the info of the underlying TCP server.

        """
        tcp_server = self._srv
        return tcp_server.info

    async def _on_connection(self, conn):
        # Passed to `tcp.listen` as its connection callback: wrap the TCP
        # connection as a TPKT connection and hand it to the user
        # callback. On failure or cancellation, close it so it isn't
        # leaked.
        #
        # `resource` is whatever must be closed on failure: the raw TCP
        # connection until the wrapper exists, then the wrapper itself.
        resource = conn
        try:
            resource = Connection(conn=conn,
                                  receive_queue_size=self._receive_queue_size)
            await aio.call(self._connection_cb, resource)

        except asyncio.CancelledError:
            await aio.uncancellable(resource.async_close())
            raise

        except Exception as e:
            self._log.warning('connection callback error: %s', e, exc_info=e)
            await aio.uncancellable(resource.async_close())

TPKT listening server

For creation of new instance see listen coroutine.

async_group: hat.aio.group.Group
67    @property
68    def async_group(self) -> aio.Group:
69        """Async group"""
70        return self._srv.async_group

Async group

info: hat.drivers.tcp.ServerInfo
72    @property
73    def info(self) -> tcp.ServerInfo:
74        """Server info"""
75        return self._srv.info

Server info

class Connection(hat.aio.group.Resource):
 92class Connection(aio.Resource):
 93    """TPKT connection"""
 94
 95    def __init__(self,
 96                 conn: tcp.Connection,
 97                 receive_queue_size: int):
 98        self._conn = conn
 99        self._receive_queue = aio.Queue(receive_queue_size)
100        self._log = _create_connection_logger_adapter(False, conn.info)
101        self._comm_log = _create_connection_logger_adapter(True, conn.info)
102
103        self.async_group.spawn(self._read_loop)
104
105        self._comm_log.debug('connection established')
106
107    @property
108    def async_group(self) -> aio.Group:
109        """Async group"""
110        return self._conn.async_group
111
112    @property
113    def info(self) -> tcp.ConnectionInfo:
114        """Connection info"""
115        return self._conn.info
116
117    async def receive(self) -> util.Bytes:
118        """Receive data"""
119        try:
120            return await self._receive_queue.get()
121
122        except aio.QueueClosedError:
123            raise ConnectionError()
124
125    async def send(self, data: util.Bytes):
126        """Send data"""
127        data_len = len(data)
128
129        if data_len > 0xFFFB:
130            raise ValueError("data length greater than 0xFFFB")
131
132        if data_len < 3:
133            raise ValueError("data length less than 3")
134
135        packet_length = data_len + 4
136        packet = bytes(itertools.chain(
137            [3, 0, packet_length >> 8, packet_length & 0xFF],
138            data))
139
140        if self._comm_log.isEnabledFor(logging.DEBUG):
141            self._log.debug('sending %s', data.hex(' '))
142
143        await self._conn.write(packet)
144
145    async def drain(self):
146        """Drain output buffer"""
147        await self._conn.drain()
148
149    async def _read_loop(self):
150        self._log.debug('starting read loop')
151
152        try:
153            while True:
154                header = await self._conn.readexactly(4)
155                if header[0] != 3:
156                    raise Exception(f"invalid vrsn number "
157                                    f"(received {header[0]})")
158
159                packet_length = (header[2] << 8) | header[3]
160                if packet_length < 7:
161                    raise Exception(f"invalid packet length "
162                                    f"(received {packet_length})")
163
164                data_length = packet_length - 4
165                data = await self._conn.readexactly(data_length)
166
167                if self._comm_log.isEnabledFor(logging.DEBUG):
168                    self._comm_log.debug('received %s', data.hex(' '))
169
170                await self._receive_queue.put(data)
171
172        except ConnectionError:
173            pass
174
175        except Exception as e:
176            self._log.warning("read loop error: %s", e, exc_info=e)
177
178        finally:
179            self._log.debug('stopping read loop')
180
181            self.close()
182            self._receive_queue.close()
183
184            self._comm_log.debug('connection closed')

TPKT connection

Connection(conn: hat.drivers.tcp.Connection, receive_queue_size: int)
 95    def __init__(self,
 96                 conn: tcp.Connection,
 97                 receive_queue_size: int):
 98        self._conn = conn
 99        self._receive_queue = aio.Queue(receive_queue_size)
100        self._log = _create_connection_logger_adapter(False, conn.info)
101        self._comm_log = _create_connection_logger_adapter(True, conn.info)
102
103        self.async_group.spawn(self._read_loop)
104
105        self._comm_log.debug('connection established')
async_group: hat.aio.group.Group
107    @property
108    def async_group(self) -> aio.Group:
109        """Async group"""
110        return self._conn.async_group

Async group

info: hat.drivers.tcp.ConnectionInfo
112    @property
113    def info(self) -> tcp.ConnectionInfo:
114        """Connection info"""
115        return self._conn.info

Connection info

async def receive(self) -> bytes | bytearray | memoryview:
117    async def receive(self) -> util.Bytes:
118        """Receive data"""
119        try:
120            return await self._receive_queue.get()
121
122        except aio.QueueClosedError:
123            raise ConnectionError()

Receive data

async def send(self, data: bytes | bytearray | memoryview):
125    async def send(self, data: util.Bytes):
126        """Send data"""
127        data_len = len(data)
128
129        if data_len > 0xFFFB:
130            raise ValueError("data length greater than 0xFFFB")
131
132        if data_len < 3:
133            raise ValueError("data length less than 3")
134
135        packet_length = data_len + 4
136        packet = bytes(itertools.chain(
137            [3, 0, packet_length >> 8, packet_length & 0xFF],
138            data))
139
140        if self._comm_log.isEnabledFor(logging.DEBUG):
141            self._log.debug('sending %s', data.hex(' '))
142
143        await self._conn.write(packet)

Send data

async def drain(self):
145    async def drain(self):
146        """Drain output buffer"""
147        await self._conn.drain()

Drain output buffer