hat.drivers.tpkt

ISO Transport Service on top of TCP (TPKT, RFC 1006)

  1"""Transport Service on top of TCP"""
  2
  3import asyncio
  4import itertools
  5import logging
  6import typing
  7
  8from hat import aio
  9from hat import util
 10
 11from hat.drivers import tcp
 12
 13
 14mlog: logging.Logger = logging.getLogger(__name__)
 15"""Module logger"""
 16
 17ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]
 18"""Connection callback"""
 19
 20
 21async def connect(addr: tcp.Address,
 22                  *,
 23                  tpkt_receive_queue_size: int = 1024,
 24                  **kwargs
 25                  ) -> 'Connection':
 26    """Create new TPKT connection
 27
 28    Additional arguments are passed directly to `hat.drivers.tcp.connect`.
 29
 30    """
 31    conn = await tcp.connect(addr, **kwargs)
 32    return Connection(conn=conn,
 33                      receive_queue_size=tpkt_receive_queue_size)
 34
 35
 36async def listen(connection_cb: ConnectionCb,
 37                 addr: tcp.Address = tcp.Address('0.0.0.0', 102),
 38                 *,
 39                 tpkt_receive_queue_size: int = 1024,
 40                 **kwargs
 41                 ) -> 'Server':
 42    """Create new TPKT listening server
 43
 44    Additional arguments are passed directly to `hat.drivers.tcp.listen`.
 45
 46    """
 47    server = Server()
 48    server._connection_cb = connection_cb
 49    server._receive_queue_size = tpkt_receive_queue_size
 50    server._log = _create_server_logger(kwargs.get('name'), None)
 51
 52    server._srv = await tcp.listen(server._on_connection, addr, **kwargs)
 53
 54    server._log = _create_server_logger(kwargs.get('name'), server._srv.info)
 55
 56    return server
 57
 58
 59class Server(aio.Resource):
 60    """TPKT listening server
 61
 62    For creation of new instance see `listen` coroutine.
 63
 64    """
 65
 66    @property
 67    def async_group(self) -> aio.Group:
 68        """Async group"""
 69        return self._srv.async_group
 70
 71    @property
 72    def info(self) -> tcp.ServerInfo:
 73        """Server info"""
 74        return self._srv.info
 75
 76    async def _on_connection(self, conn):
 77        try:
 78            conn = Connection(conn=conn,
 79                              receive_queue_size=self._receive_queue_size)
 80            await aio.call(self._connection_cb, conn)
 81
 82        except Exception as e:
 83            self._log.warning('connection callback error: %s', e, exc_info=e)
 84            await aio.uncancellable(conn.async_close())
 85
 86        except asyncio.CancelledError:
 87            await aio.uncancellable(conn.async_close())
 88            raise
 89
 90
 91class Connection(aio.Resource):
 92    """TPKT connection"""
 93
 94    def __init__(self,
 95                 conn: tcp.Connection,
 96                 receive_queue_size: int):
 97        self._conn = conn
 98        self._receive_queue = aio.Queue(receive_queue_size)
 99        self._log = _create_connection_logger(conn.info)
100
101        self.async_group.spawn(self._read_loop)
102
103    @property
104    def async_group(self) -> aio.Group:
105        """Async group"""
106        return self._conn.async_group
107
108    @property
109    def info(self) -> tcp.ConnectionInfo:
110        """Connection info"""
111        return self._conn.info
112
113    async def receive(self) -> util.Bytes:
114        """Receive data"""
115        try:
116            return await self._receive_queue.get()
117
118        except aio.QueueClosedError:
119            raise ConnectionError()
120
121    async def send(self, data: util.Bytes):
122        """Send data"""
123        data_len = len(data)
124
125        if data_len > 0xFFFB:
126            raise ValueError("data length greater than 0xFFFB")
127
128        if data_len < 3:
129            raise ValueError("data length less than 3")
130
131        packet_length = data_len + 4
132        packet = bytes(itertools.chain(
133            [3, 0, packet_length >> 8, packet_length & 0xFF],
134            data))
135
136        await self._conn.write(packet)
137
138    async def drain(self):
139        """Drain output buffer"""
140        await self._conn.drain()
141
142    async def _read_loop(self):
143        self._log.debug('starting read loop')
144
145        try:
146            while True:
147                header = await self._conn.readexactly(4)
148                if header[0] != 3:
149                    raise Exception(f"invalid vrsn number "
150                                    f"(received {header[0]})")
151
152                packet_length = (header[2] << 8) | header[3]
153                if packet_length < 7:
154                    raise Exception(f"invalid packet length "
155                                    f"(received {packet_length})")
156
157                data_length = packet_length - 4
158                data = await self._conn.readexactly(data_length)
159
160                await self._receive_queue.put(data)
161
162        except ConnectionError:
163            pass
164
165        except Exception as e:
166            self._log.warning("read loop error: %s", e, exc_info=e)
167
168        finally:
169            self._log.debug('stopping read loop')
170
171            self.close()
172            self._receive_queue.close()
173
174
def _create_server_logger(name, info):
    """Create a logger adapter carrying TPKT server metadata

    `info` may be None (before the TCP server is listening); otherwise its
    addresses are added to the metadata.

    """
    meta = {'type': 'TpktServer',
            'name': name}

    if info is not None:
        meta['addresses'] = [{'host': address.host, 'port': address.port}
                             for address in info.addresses]

    return logging.LoggerAdapter(mlog, {'meta': meta})
185
186
def _create_connection_logger(info):
    """Create a logger adapter carrying TPKT connection metadata"""

    def address_meta(address):
        return {'host': address.host, 'port': address.port}

    meta = {'type': 'TpktConnection',
            'name': info.name,
            'local_addr': address_meta(info.local_addr),
            'remote_addr': address_meta(info.remote_addr)}

    return logging.LoggerAdapter(mlog, {'meta': meta})
mlog: logging.Logger = <Logger hat.drivers.tpkt (WARNING)>

Module logger

ConnectionCb: TypeAlias = Callable[[ForwardRef('Connection')], None | Awaitable[None]]

Connection callback

async def connect( addr: hat.drivers.tcp.Address, *, tpkt_receive_queue_size: int = 1024, **kwargs) -> Connection:
async def connect(addr: tcp.Address,
                  *,
                  tpkt_receive_queue_size: int = 1024,
                  **kwargs
                  ) -> 'Connection':
    """Open a TCP connection to `addr` and wrap it as a TPKT connection

    `tpkt_receive_queue_size` is the size of the queue holding received,
    not yet consumed TPKT payloads. Additional keyword arguments are
    forwarded unchanged to `hat.drivers.tcp.connect`.

    """
    tcp_conn = await tcp.connect(addr, **kwargs)
    tpkt_conn = Connection(conn=tcp_conn,
                           receive_queue_size=tpkt_receive_queue_size)
    return tpkt_conn

Create new TPKT connection

Additional keyword arguments are passed directly to `hat.drivers.tcp.connect`.

async def listen( connection_cb: Callable[[Connection], None | Awaitable[None]], addr: hat.drivers.tcp.Address = Address(host='0.0.0.0', port=102), *, tpkt_receive_queue_size: int = 1024, **kwargs) -> Server:
async def listen(connection_cb: ConnectionCb,
                 addr: tcp.Address = tcp.Address('0.0.0.0', 102),
                 *,
                 tpkt_receive_queue_size: int = 1024,
                 **kwargs
                 ) -> 'Server':
    """Start a TPKT server listening on `addr`

    Every accepted TCP connection is wrapped in a `Connection` and passed to
    `connection_cb`. `tpkt_receive_queue_size` is the receive queue size of
    each such connection. Additional keyword arguments are forwarded
    unchanged to `hat.drivers.tcp.listen`.

    """
    name = kwargs.get('name')

    server = Server()
    server._connection_cb = connection_cb
    server._receive_queue_size = tpkt_receive_queue_size
    # server addresses are not known until `tcp.listen` returns; a logger
    # without them is installed first (presumably so `_on_connection` can
    # already log if it runs before the listen call completes)
    server._log = _create_server_logger(name, None)

    server._srv = await tcp.listen(server._on_connection, addr, **kwargs)

    # now that the TCP server exists, include its addresses in the log meta
    server._log = _create_server_logger(name, server._srv.info)

    return server

Create new TPKT listening server

Additional keyword arguments are passed directly to `hat.drivers.tcp.listen`.

class Server(hat.aio.group.Resource):
class Server(aio.Resource):
    """TPKT listening server

    Instances are created by the `listen` coroutine, which sets up the
    underlying TCP server; do not instantiate this class directly.

    """

    @property
    def async_group(self) -> aio.Group:
        """Async group (the underlying TCP server's group)"""
        srv = self._srv
        return srv.async_group

    @property
    def info(self) -> tcp.ServerInfo:
        """Server info (of the underlying TCP server)"""
        srv = self._srv
        return srv.info

    async def _on_connection(self, conn):
        # `conn` is rebound to the TPKT connection once it is created, so
        # the cleanup below closes whichever object exists at failure time.
        try:
            conn = Connection(conn=conn,
                              receive_queue_size=self._receive_queue_size)
            await aio.call(self._connection_cb, conn)

        except asyncio.CancelledError:
            await aio.uncancellable(conn.async_close())
            raise

        except Exception as e:
            self._log.warning('connection callback error: %s', e, exc_info=e)
            await aio.uncancellable(conn.async_close())

TPKT listening server

To create a new instance, use the `listen` coroutine.

async_group: hat.aio.group.Group
67    @property
68    def async_group(self) -> aio.Group:
69        """Async group"""
70        return self._srv.async_group

Async group

info: hat.drivers.tcp.ServerInfo
72    @property
73    def info(self) -> tcp.ServerInfo:
74        """Server info"""
75        return self._srv.info

Server info

class Connection(hat.aio.group.Resource):
 92class Connection(aio.Resource):
 93    """TPKT connection"""
 94
 95    def __init__(self,
 96                 conn: tcp.Connection,
 97                 receive_queue_size: int):
 98        self._conn = conn
 99        self._receive_queue = aio.Queue(receive_queue_size)
100        self._log = _create_connection_logger(conn.info)
101
102        self.async_group.spawn(self._read_loop)
103
104    @property
105    def async_group(self) -> aio.Group:
106        """Async group"""
107        return self._conn.async_group
108
109    @property
110    def info(self) -> tcp.ConnectionInfo:
111        """Connection info"""
112        return self._conn.info
113
114    async def receive(self) -> util.Bytes:
115        """Receive data"""
116        try:
117            return await self._receive_queue.get()
118
119        except aio.QueueClosedError:
120            raise ConnectionError()
121
122    async def send(self, data: util.Bytes):
123        """Send data"""
124        data_len = len(data)
125
126        if data_len > 0xFFFB:
127            raise ValueError("data length greater than 0xFFFB")
128
129        if data_len < 3:
130            raise ValueError("data length less than 3")
131
132        packet_length = data_len + 4
133        packet = bytes(itertools.chain(
134            [3, 0, packet_length >> 8, packet_length & 0xFF],
135            data))
136
137        await self._conn.write(packet)
138
139    async def drain(self):
140        """Drain output buffer"""
141        await self._conn.drain()
142
143    async def _read_loop(self):
144        self._log.debug('starting read loop')
145
146        try:
147            while True:
148                header = await self._conn.readexactly(4)
149                if header[0] != 3:
150                    raise Exception(f"invalid vrsn number "
151                                    f"(received {header[0]})")
152
153                packet_length = (header[2] << 8) | header[3]
154                if packet_length < 7:
155                    raise Exception(f"invalid packet length "
156                                    f"(received {packet_length})")
157
158                data_length = packet_length - 4
159                data = await self._conn.readexactly(data_length)
160
161                await self._receive_queue.put(data)
162
163        except ConnectionError:
164            pass
165
166        except Exception as e:
167            self._log.warning("read loop error: %s", e, exc_info=e)
168
169        finally:
170            self._log.debug('stopping read loop')
171
172            self.close()
173            self._receive_queue.close()

TPKT connection

Connection(conn: hat.drivers.tcp.Connection, receive_queue_size: int)
 95    def __init__(self,
 96                 conn: tcp.Connection,
 97                 receive_queue_size: int):
 98        self._conn = conn
 99        self._receive_queue = aio.Queue(receive_queue_size)
100        self._log = _create_connection_logger(conn.info)
101
102        self.async_group.spawn(self._read_loop)
async_group: hat.aio.group.Group
104    @property
105    def async_group(self) -> aio.Group:
106        """Async group"""
107        return self._conn.async_group

Async group

info: hat.drivers.tcp.ConnectionInfo
109    @property
110    def info(self) -> tcp.ConnectionInfo:
111        """Connection info"""
112        return self._conn.info

Connection info

async def receive(self) -> bytes | bytearray | memoryview:
114    async def receive(self) -> util.Bytes:
115        """Receive data"""
116        try:
117            return await self._receive_queue.get()
118
119        except aio.QueueClosedError:
120            raise ConnectionError()

Receive data

async def send(self, data: bytes | bytearray | memoryview):
122    async def send(self, data: util.Bytes):
123        """Send data"""
124        data_len = len(data)
125
126        if data_len > 0xFFFB:
127            raise ValueError("data length greater than 0xFFFB")
128
129        if data_len < 3:
130            raise ValueError("data length less than 3")
131
132        packet_length = data_len + 4
133        packet = bytes(itertools.chain(
134            [3, 0, packet_length >> 8, packet_length & 0xFF],
135            data))
136
137        await self._conn.write(packet)

Send data

async def drain(self):
139    async def drain(self):
140        """Drain output buffer"""
141        await self._conn.drain()

Drain output buffer