Merge pull request 'refactor: make connector use queues' (#4) from refactor/make-connector-use-queues into develop

Reviewed-on: #4
This commit is contained in:
2025-11-30 17:35:33 +00:00
2 changed files with 151 additions and 144 deletions

View File

@@ -2,24 +2,33 @@
from __future__ import annotations from __future__ import annotations
import logging as lg import logging as lg
import selectors
import socket import socket
import time import time
from typing import Callable from typing import Callable
from judas_protocol import Message from judas_protocol import Category, ControlAction, Message
class Connector: class Connector:
"""Connector class for managing TCP connection and message exchange."""
def __init__( def __init__(
self, self,
mac_address: str, mac_address: str,
host: str, host: str,
port: int, port: int,
*, *,
connect_timeout: float = 5.0,
ack_timeout: float | None = None,
on_message: Callable[[Message], None], on_message: Callable[[Message], None],
) -> None: ) -> None:
"""Initialize the Connector.
Args:
mac_address (str): The MAC address of the client.
host (str): The server host address.
port (int): The server port number.
on_message (Callable[[Message], None]): Callback for handling incoming messages.
"""
self.logger: lg.Logger = lg.getLogger( self.logger: lg.Logger = lg.getLogger(
f"{__name__}.{self.__class__.__name__}" f"{__name__}.{self.__class__.__name__}"
) )
@@ -27,172 +36,170 @@ class Connector:
self.host: str = host self.host: str = host
self.port: int = port self.port: int = port
self.socket_timeout: None = None
self.connect_timeout: float = connect_timeout
self.ack_timeout: float | None = ack_timeout
self.selector = selectors.DefaultSelector()
self.socket: socket.socket = socket.socket( self.socket: socket.socket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM socket.AF_INET, socket.SOCK_STREAM
) )
self.socket.setblocking(False)
self.selector.register(
self.socket,
selectors.EVENT_READ | selectors.EVENT_WRITE,
data=None,
)
self.mac_address: str = mac_address self.mac_address: str = mac_address
self.inbound_buffer: bytes = b""
self.outbound_buffer: bytes = b""
self.pending_acks: dict[str, tuple[Message, float]] = {}
self.running: bool = True
self.on_message: Callable[[Message], None] = on_message self.on_message: Callable[[Message], None] = on_message
def _send_ack(self) -> None: def _send_outbound(self) -> None:
self.logger.debug("[>] Sending ACK...") """Send data from the outbound buffer."""
try: while self.outbound_buffer:
self.socket.sendall(Message.ack().to_bytes())
self.logger.debug("[<] ACK sent")
except socket.error as e:
self.logger.error(f"[!] Failed to send ACK: {e}")
def _check_ack(self) -> bool:
self.logger.debug("[.] Waiting for ACK...")
try:
self.socket.settimeout(self.ack_timeout)
ack: bytes = self.socket.recv(1024)
self.socket.settimeout(self.socket_timeout)
if ack == Message.ack().to_bytes():
self.logger.debug("[<] ACK received")
return True
else:
self.logger.error(f"[!] Invalid ACK received: {ack}")
except TimeoutError as e:
self.logger.error(f"[!] ACK timeout: {e}")
except socket.error as e:
self.logger.error(f"[!] Failed to receive ACK: {e}")
return False
def connect(self, retry_interval: int = 1) -> None:
self.logger.debug(
f"Connecting to {self.host}:{self.port} with timeout {self.connect_timeout}s..."
)
try:
self.socket.settimeout(self.connect_timeout)
self.socket.connect((self.host, self.port))
self.socket.settimeout(self.socket_timeout)
self.logger.info(f"[+] Connected to {self.host}:{self.port}")
self.send_hello()
except (
socket.timeout,
ConnectionRefusedError,
ConnectionAbortedError,
) as e:
self.logger.error(
f"[!] Connection to {self.host}:{self.port} failed: {e}"
)
self.logger.info(
f"[.] Retrying connection in {retry_interval} s..."
)
time.sleep(retry_interval)
self.connect(retry_interval=min(30, retry_interval * 2))
def send(self, data: bytes, no_check_ack: bool = False) -> None:
self.logger.debug(f"[>] Sending data: {data}")
while True:
try: try:
self.socket.sendall(data) sent = self.socket.send(self.outbound_buffer)
self.logger.debug(
f"[>] Sent {sent} bytes: {self.outbound_buffer[:sent]!r}"
)
if no_check_ack: self.outbound_buffer = self.outbound_buffer[sent:]
self.logger.debug("[>] Data sent without ACK check") except BlockingIOError:
break # OS buffer full, wait for next EVENT_WRITE
else: break
self.logger.info("[>] Data sent") except socket.error as e:
self.logger.error(f"[!] Socket error: {e}")
self.reconnect()
break
acknowledged: bool = self._check_ack() def _receive_inbound(self) -> None:
if acknowledged: """Receive data into the inbound buffer."""
self.logger.debug("[.] Data acknowledged")
break
else:
self.logger.warning(
"[!] Data not acknowledged, retrying..."
)
except BrokenPipeError as e:
self.logger.error(f"[!] Broken pipe: {e}")
self.logger.info("[.] Reconnecting...")
self.connect()
except (socket.error, ValueError) as e:
self.logger.error(f"[!] Failed to send data: {e}")
time.sleep(1)
def receive(self) -> bytes:
self.logger.debug("[.] Waiting to receive data...")
try: try:
data: bytes = self.socket.recv(4096) data: bytes = self.socket.recv(4096)
if not data: if data:
self.logger.warning("[!] Received empty message") self.logger.debug(f"[<] Received {len(data)} bytes: {data!r}")
return b"" self.inbound_buffer += data
self.logger.debug(f"[<] Received data: {data}") else:
return data self.logger.debug("[!] Connection closed by the server.")
self.reconnect()
except socket.error as e: except socket.error as e:
self.logger.error(f"[!] Failed to receive data: {e}") self.logger.error(f"[!] Socket error: {e}")
return b"" self.reconnect()
def close(self) -> None: def send(self, message: Message) -> None:
self.logger.debug("Closing connection...") """Send a message to the server.
self.socket.close()
self.logger.info("Connection closed.")
def reconnect(self) -> None: Args:
self.logger.debug("Reconnecting...") message (Message): The message to send.
self.close() """
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.logger.debug(f"[>] Queueing message to send: {message}")
self.connect() if message.ack_required:
self.pending_acks[message.id] = (message, time.time())
self.outbound_buffer += message.to_bytes()
def send_hello(self) -> None: def send_hello(self) -> None:
self.logger.debug("[.] Sending hello message...") """Send a HELLO message to the server."""
self.logger.debug("[*] Sending HELLO message...")
hello_message: Message = Message.hello(self.mac_address) hello_message: Message = Message.hello(self.mac_address)
acknowledged: bool = False self.send(hello_message)
while not acknowledged:
self.send(hello_message.to_bytes(), no_check_ack=True)
self.logger.debug("[.] Hello message sent, waiting for ACK...")
acknowledged = self._check_ack()
if not acknowledged:
self.logger.warning(
"[!] Hello message not acknowledged, retrying..."
)
time.sleep(1)
def _loop(self) -> None: def close(self) -> None:
self.logger.debug("Starting connector loop...") """Close the connection and clean up resources."""
while True: self.logger.debug("[*] Closing connection...")
time.sleep(0.1) try:
data: bytes = self.receive() self.selector.unregister(self.socket)
if not data: except Exception as e:
self.reconnect() self.logger.error(f"[!] Error unregistering socket: {e}")
continue self.socket.close()
for line in data.split(b"\n"): self.logger.debug("[.] Connection closed.")
line: bytes = line.strip()
if not line: def reconnect(self) -> None:
continue """Reconnect to the server."""
self.logger.debug("[*] Reconnecting...")
self.close()
self.logger.debug(f"[.] Raw message data: {line}") # reinit socket
try: self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
message: Message = Message.from_bytes(line) self.socket.setblocking(False)
except ValueError as e: self.selector.register(
self.logger.error(f"[!] Failed to parse message: {e}") self.socket,
continue selectors.EVENT_READ | selectors.EVENT_WRITE,
self.logger.info(f"[*] Message received: {message}") data=None,
self.on_message(message) )
self.connect()
# if self._check_ack(): def connect(self) -> None:
# self.logger.debug("[.] ACK verified") """Establish a connection to the server."""
# else: self.logger.debug(f"Connecting to {self.host}:{self.port}...")
# self.logger.error("[!] ACK verification failed") connected: bool = False
delay: float = 1.0
while not connected:
try:
self.socket.connect((self.host, self.port))
connected = True
except BlockingIOError:
# Connection in progress
time.sleep(0.1)
except socket.error as e:
self.logger.error(f"[!] Connection error: {e}")
self.logger.debug(f"[.] Retrying in {delay} seconds...")
time.sleep(delay)
delay = min(delay * 2, 30) # exponential backoff
self.logger.debug("[*] Connected, sending HELLO...")
self.send_hello()
def run(self) -> None: def run(self) -> None:
self.logger.debug("Running Connector...") """Run the main event loop."""
self.connect()
try: try:
self.connect() while self.running:
self._loop() events = self.selector.select(timeout=1)
for key, mask in events:
if mask & selectors.EVENT_READ:
self._receive_inbound()
if mask & selectors.EVENT_WRITE:
self._send_outbound()
# Process inbound buffer for complete messages
while b"\n" in self.inbound_buffer:
message_bytes, self.inbound_buffer = (
self.inbound_buffer.split(b"\n", 1)
)
try:
message: Message = Message.from_bytes(message_bytes)
# handle incoming ACKs
if (
message.category == Category.CONTROL
and message.action == ControlAction.ACK
):
if (
message.payload.get("target_id")
in self.pending_acks
):
target_id = message.payload["target_id"]
self.logger.debug(
f"[.] Received ACK for message ID {target_id}"
)
del self.pending_acks[target_id]
else:
self.on_message(message)
if message.ack_required:
ack_message: Message = Message.ack(message.id)
self.send(ack_message)
self._send_outbound()
except Exception as e:
self.logger.error(f"[!] Failed to parse message: {e}")
time.sleep(0.1)
except KeyboardInterrupt: except KeyboardInterrupt:
self.logger.info("Interrupted by user.") self.logger.debug("[*] Interrupted by user.")
finally: finally:
self.close() self.close()

4
uv.lock generated
View File

@@ -292,8 +292,8 @@ test = [
[[package]] [[package]]
name = "judas-protocol" name = "judas-protocol"
version = "0.2.0" version = "0.5.0"
source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#bc1bf46388eb904738893a2f86b5050b4ce2489e" } source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#c48b69ecee16f5824ffd8bce8921341d5fa326b7" }
[[package]] [[package]]
name = "markdown-it-py" name = "markdown-it-py"