hat.drivers.tpkt

Transport Service on top of TCP

  1"""Transport Service on top of TCP"""
  2
  3import asyncio
  4import itertools
  5import logging
  6import typing
  7
  8from hat import aio
  9from hat import util
 10
 11from hat.drivers import tcp
 12
 13
 14mlog: logging.Logger = logging.getLogger(__name__)
 15"""Module logger"""
 16
 17ConnectionCb: typing.TypeAlias = aio.AsyncCallable[['Connection'], None]
 18"""Connection callback"""
 19
 20
 21async def connect(addr: tcp.Address,
 22                  **kwargs
 23                  ) -> 'Connection':
 24    """Create new TPKT connection
 25
 26    Additional arguments are passed directly to `hat.drivers.tcp.connect`.
 27
 28    """
 29    conn = await tcp.connect(addr, **kwargs)
 30    return Connection(conn)
 31
 32
 33async def listen(connection_cb: ConnectionCb,
 34                 addr: tcp.Address = tcp.Address('0.0.0.0', 102),
 35                 **kwargs
 36                 ) -> 'Server':
 37    """Create new TPKT listening server
 38
 39    Additional arguments are passed directly to `hat.drivers.tcp.listen`.
 40
 41    """
 42    server = Server()
 43    server._connection_cb = connection_cb
 44
 45    server._srv = await tcp.listen(server._on_connection, addr, **kwargs)
 46
 47    return server
 48
 49
 50class Server(aio.Resource):
 51    """TPKT listening server
 52
 53    For creation of new instance see `listen` coroutine.
 54
 55    """
 56
 57    @property
 58    def async_group(self) -> aio.Group:
 59        """Async group"""
 60        return self._srv.async_group
 61
 62    @property
 63    def addresses(self) -> list[tcp.Address]:
 64        """Listening addresses"""
 65        return self._srv.addresses
 66
 67    async def _on_connection(self, conn):
 68        try:
 69            conn = Connection(conn)
 70            await aio.call(self._connection_cb, conn)
 71
 72        except Exception as e:
 73            mlog.warning('connection callback error: %s', e, exc_info=e)
 74            await aio.uncancellable(conn.async_close())
 75
 76        except asyncio.CancelledError:
 77            await aio.uncancellable(conn.async_close())
 78            raise
 79
 80
 81class Connection(aio.Resource):
 82    """TPKT connection"""
 83
 84    def __init__(self,
 85                 conn: tcp.Connection):
 86        self._conn = conn
 87        self._loop = asyncio.get_running_loop()
 88        self._receive_futures = aio.Queue()
 89
 90        self.async_group.spawn(self._read_loop)
 91
 92    @property
 93    def async_group(self) -> aio.Group:
 94        """Async group"""
 95        return self._conn.async_group
 96
 97    @property
 98    def info(self) -> tcp.ConnectionInfo:
 99        """Connection info"""
100        return self._conn.info
101
102    async def receive(self) -> util.Bytes:
103        """Receive data"""
104        try:
105            future = self._loop.create_future()
106            self._receive_futures.put_nowait(future)
107            return await future
108
109        except aio.QueueClosedError:
110            raise ConnectionError()
111
112    async def send(self, data: util.Bytes):
113        """Send data"""
114        data_len = len(data)
115
116        if data_len > 0xFFFB:
117            raise ValueError("data length greater than 0xFFFB")
118
119        if data_len < 3:
120            raise ValueError("data length less than 3")
121
122        packet_length = data_len + 4
123        packet = bytes(itertools.chain(
124            [3, 0, packet_length >> 8, packet_length & 0xFF],
125            data))
126
127        await self._conn.write(packet)
128
129    async def drain(self):
130        """Drain output buffer"""
131        await self._conn.drain()
132
133    async def _read_loop(self):
134        future = None
135        try:
136            while True:
137                header = await self._conn.readexactly(4)
138                if header[0] != 3:
139                    raise Exception(f"invalid vrsn number "
140                                    f"(received {header[0]})")
141
142                packet_length = (header[2] << 8) | header[3]
143                if packet_length < 7:
144                    raise Exception(f"invalid packet length "
145                                    f"(received {packet_length})")
146
147                data_length = packet_length - 4
148                data = await self._conn.readexactly(data_length)
149
150                while not future or future.done():
151                    future = await self._receive_futures.get()
152
153                future.set_result(data)
154
155        except ConnectionError:
156            pass
157
158        except Exception as e:
159            mlog.warning("read loop error: %s", e, exc_info=e)
160
161        finally:
162            self.close()
163            self._receive_futures.close()
164
165            while True:
166                if future and not future.done():
167                    future.set_exception(ConnectionError())
168                if self._receive_futures.empty():
169                    break
170                future = self._receive_futures.get_nowait()
mlog: logging.Logger = <Logger hat.drivers.tpkt (WARNING)>

Module logger

ConnectionCb: TypeAlias = Callable[[ForwardRef('Connection')], None | Awaitable[None]]

Connection callback

async def connect(addr: hat.drivers.tcp.Address, **kwargs) -> Connection:
async def connect(addr: tcp.Address,
                  **kwargs
                  ) -> 'Connection':
    """Create new TPKT connection

    Opens a TCP connection to `addr` and wraps it in a `Connection`.
    Additional arguments are passed directly to `hat.drivers.tcp.connect`.

    """
    return Connection(await tcp.connect(addr, **kwargs))

Create new TPKT connection

Additional arguments are passed directly to hat.drivers.tcp.connect.

async def listen(connection_cb: Callable[[hat.drivers.tpkt.Connection], None | Awaitable[None]], addr: hat.drivers.tcp.Address = Address(host='0.0.0.0', port=102), **kwargs) -> Server:
async def listen(connection_cb: ConnectionCb,
                 addr: tcp.Address = tcp.Address('0.0.0.0', 102),
                 **kwargs
                 ) -> 'Server':
    """Create new TPKT listening server

    Every accepted connection is wrapped in a `Connection` and passed to
    `connection_cb`. Additional arguments are passed directly to
    `hat.drivers.tcp.listen`.

    """
    srv = Server()
    # the callback must be in place before the TCP server starts accepting
    srv._connection_cb = connection_cb
    srv._srv = await tcp.listen(srv._on_connection, addr, **kwargs)
    return srv

Create new TPKT listening server

Additional arguments are passed directly to hat.drivers.tcp.listen.

class Server(hat.aio.group.Resource):
class Server(aio.Resource):
    """TPKT listening server

    Instances are created by the `listen` coroutine. Every connection
    accepted by the underlying TCP server is wrapped in a `Connection` and
    passed to the connection callback given to `listen`.

    """

    @property
    def async_group(self) -> aio.Group:
        """Async group (that of the underlying TCP server)"""
        return self._srv.async_group

    @property
    def addresses(self) -> list[tcp.Address]:
        """Listening addresses (as reported by the underlying TCP server)"""
        return self._srv.addresses

    async def _on_connection(self, conn):
        # Registered with `tcp.listen`. `conn` is rebound to the TPKT wrapper
        # once it is created; the wrapper shares the TCP connection's async
        # group (see `Connection.async_group`).
        try:
            conn = Connection(conn)
            await aio.call(self._connection_cb, conn)

        except asyncio.CancelledError:
            # close the connection before letting cancellation propagate
            await aio.uncancellable(conn.async_close())
            raise

        except Exception as e:
            # callback (or wrapper) failure: log it and close only this
            # connection; the error is not re-raised
            mlog.warning('connection callback error: %s', e, exc_info=e)
            await aio.uncancellable(conn.async_close())

TPKT listening server

To create a new instance, use the listen coroutine.

async_group: hat.aio.group.Group
58    @property
59    def async_group(self) -> aio.Group:
60        """Async group"""
61        return self._srv.async_group

Async group

addresses: list[hat.drivers.tcp.Address]
63    @property
64    def addresses(self) -> list[tcp.Address]:
65        """Listening addresses"""
66        return self._srv.addresses

Listening addresses

class Connection(hat.aio.group.Resource):
 82class Connection(aio.Resource):
 83    """TPKT connection"""
 84
 85    def __init__(self,
 86                 conn: tcp.Connection):
 87        self._conn = conn
 88        self._loop = asyncio.get_running_loop()
 89        self._receive_futures = aio.Queue()
 90
 91        self.async_group.spawn(self._read_loop)
 92
 93    @property
 94    def async_group(self) -> aio.Group:
 95        """Async group"""
 96        return self._conn.async_group
 97
 98    @property
 99    def info(self) -> tcp.ConnectionInfo:
100        """Connection info"""
101        return self._conn.info
102
103    async def receive(self) -> util.Bytes:
104        """Receive data"""
105        try:
106            future = self._loop.create_future()
107            self._receive_futures.put_nowait(future)
108            return await future
109
110        except aio.QueueClosedError:
111            raise ConnectionError()
112
113    async def send(self, data: util.Bytes):
114        """Send data"""
115        data_len = len(data)
116
117        if data_len > 0xFFFB:
118            raise ValueError("data length greater than 0xFFFB")
119
120        if data_len < 3:
121            raise ValueError("data length less than 3")
122
123        packet_length = data_len + 4
124        packet = bytes(itertools.chain(
125            [3, 0, packet_length >> 8, packet_length & 0xFF],
126            data))
127
128        await self._conn.write(packet)
129
130    async def drain(self):
131        """Drain output buffer"""
132        await self._conn.drain()
133
134    async def _read_loop(self):
135        future = None
136        try:
137            while True:
138                header = await self._conn.readexactly(4)
139                if header[0] != 3:
140                    raise Exception(f"invalid vrsn number "
141                                    f"(received {header[0]})")
142
143                packet_length = (header[2] << 8) | header[3]
144                if packet_length < 7:
145                    raise Exception(f"invalid packet length "
146                                    f"(received {packet_length})")
147
148                data_length = packet_length - 4
149                data = await self._conn.readexactly(data_length)
150
151                while not future or future.done():
152                    future = await self._receive_futures.get()
153
154                future.set_result(data)
155
156        except ConnectionError:
157            pass
158
159        except Exception as e:
160            mlog.warning("read loop error: %s", e, exc_info=e)
161
162        finally:
163            self.close()
164            self._receive_futures.close()
165
166            while True:
167                if future and not future.done():
168                    future.set_exception(ConnectionError())
169                if self._receive_futures.empty():
170                    break
171                future = self._receive_futures.get_nowait()

TPKT connection

Connection(conn: hat.drivers.tcp.Connection)
85    def __init__(self,
86                 conn: tcp.Connection):
87        self._conn = conn
88        self._loop = asyncio.get_running_loop()
89        self._receive_futures = aio.Queue()
90
91        self.async_group.spawn(self._read_loop)
async_group: hat.aio.group.Group
93    @property
94    def async_group(self) -> aio.Group:
95        """Async group"""
96        return self._conn.async_group

Async group

info: hat.drivers.tcp.ConnectionInfo
 98    @property
 99    def info(self) -> tcp.ConnectionInfo:
100        """Connection info"""
101        return self._conn.info

Connection info

async def receive(self) -> bytes | bytearray | memoryview:
103    async def receive(self) -> util.Bytes:
104        """Receive data"""
105        try:
106            future = self._loop.create_future()
107            self._receive_futures.put_nowait(future)
108            return await future
109
110        except aio.QueueClosedError:
111            raise ConnectionError()

Receive data

async def send(self, data: bytes | bytearray | memoryview):
113    async def send(self, data: util.Bytes):
114        """Send data"""
115        data_len = len(data)
116
117        if data_len > 0xFFFB:
118            raise ValueError("data length greater than 0xFFFB")
119
120        if data_len < 3:
121            raise ValueError("data length less than 3")
122
123        packet_length = data_len + 4
124        packet = bytes(itertools.chain(
125            [3, 0, packet_length >> 8, packet_length & 0xFF],
126            data))
127
128        await self._conn.write(packet)

Send data

async def drain(self):
130    async def drain(self):
131        """Drain output buffer"""
132        await self._conn.drain()

Drain output buffer