Compare commits

..

13 Commits

2 changed files with 151 additions and 144 deletions

View File

@@ -2,24 +2,33 @@
from __future__ import annotations from __future__ import annotations
import logging as lg import logging as lg
import selectors
import socket import socket
import time import time
from typing import Callable from typing import Callable
from judas_protocol import Message from judas_protocol import Category, ControlAction, Message
class Connector: class Connector:
"""Connector class for managing TCP connection and message exchange."""
def __init__( def __init__(
self, self,
mac_address: str, mac_address: str,
host: str, host: str,
port: int, port: int,
*, *,
connect_timeout: float = 5.0,
ack_timeout: float | None = None,
on_message: Callable[[Message], None], on_message: Callable[[Message], None],
) -> None: ) -> None:
"""Initialize the Connector.
Args:
mac_address (str): The MAC address of the client.
host (str): The server host address.
port (int): The server port number.
on_message (Callable[[Message], None]): Callback for handling incoming messages.
"""
self.logger: lg.Logger = lg.getLogger( self.logger: lg.Logger = lg.getLogger(
f"{__name__}.{self.__class__.__name__}" f"{__name__}.{self.__class__.__name__}"
) )
@@ -27,172 +36,170 @@ class Connector:
self.host: str = host self.host: str = host
self.port: int = port self.port: int = port
self.socket_timeout: None = None
self.connect_timeout: float = connect_timeout
self.ack_timeout: float | None = ack_timeout
self.selector = selectors.DefaultSelector()
self.socket: socket.socket = socket.socket( self.socket: socket.socket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM socket.AF_INET, socket.SOCK_STREAM
) )
self.socket.setblocking(False)
self.selector.register(
self.socket,
selectors.EVENT_READ | selectors.EVENT_WRITE,
data=None,
)
self.mac_address: str = mac_address self.mac_address: str = mac_address
self.inbound_buffer: bytes = b""
self.outbound_buffer: bytes = b""
self.pending_acks: dict[str, tuple[Message, float]] = {}
self.running: bool = True
self.on_message: Callable[[Message], None] = on_message self.on_message: Callable[[Message], None] = on_message
def _send_ack(self) -> None: def _send_outbound(self) -> None:
self.logger.debug("[>] Sending ACK...") """Send data from the outbound buffer."""
while self.outbound_buffer:
try: try:
self.socket.sendall(Message.ack().to_bytes()) sent = self.socket.send(self.outbound_buffer)
self.logger.debug("[<] ACK sent")
except socket.error as e:
self.logger.error(f"[!] Failed to send ACK: {e}")
def _check_ack(self) -> bool:
self.logger.debug("[.] Waiting for ACK...")
try:
self.socket.settimeout(self.ack_timeout)
ack: bytes = self.socket.recv(1024)
self.socket.settimeout(self.socket_timeout)
if ack == Message.ack().to_bytes():
self.logger.debug("[<] ACK received")
return True
else:
self.logger.error(f"[!] Invalid ACK received: {ack}")
except TimeoutError as e:
self.logger.error(f"[!] ACK timeout: {e}")
except socket.error as e:
self.logger.error(f"[!] Failed to receive ACK: {e}")
return False
def connect(self, retry_interval: int = 1) -> None:
self.logger.debug( self.logger.debug(
f"Connecting to {self.host}:{self.port} with timeout {self.connect_timeout}s..." f"[>] Sent {sent} bytes: {self.outbound_buffer[:sent]!r}"
) )
try:
self.socket.settimeout(self.connect_timeout)
self.socket.connect((self.host, self.port))
self.socket.settimeout(self.socket_timeout)
self.logger.info(f"[+] Connected to {self.host}:{self.port}")
self.send_hello()
except (
socket.timeout,
ConnectionRefusedError,
ConnectionAbortedError,
) as e:
self.logger.error(
f"[!] Connection to {self.host}:{self.port} failed: {e}"
)
self.logger.info(
f"[.] Retrying connection in {retry_interval} s..."
)
time.sleep(retry_interval)
self.connect(retry_interval=min(30, retry_interval * 2))
def send(self, data: bytes, no_check_ack: bool = False) -> None: self.outbound_buffer = self.outbound_buffer[sent:]
self.logger.debug(f"[>] Sending data: {data}") except BlockingIOError:
while True: # OS buffer full, wait for next EVENT_WRITE
try:
self.socket.sendall(data)
if no_check_ack:
self.logger.debug("[>] Data sent without ACK check")
break break
else: except socket.error as e:
self.logger.info("[>] Data sent") self.logger.error(f"[!] Socket error: {e}")
self.reconnect()
acknowledged: bool = self._check_ack()
if acknowledged:
self.logger.debug("[.] Data acknowledged")
break break
else:
self.logger.warning(
"[!] Data not acknowledged, retrying..."
)
except BrokenPipeError as e: def _receive_inbound(self) -> None:
self.logger.error(f"[!] Broken pipe: {e}") """Receive data into the inbound buffer."""
self.logger.info("[.] Reconnecting...")
self.connect()
except (socket.error, ValueError) as e:
self.logger.error(f"[!] Failed to send data: {e}")
time.sleep(1)
def receive(self) -> bytes:
self.logger.debug("[.] Waiting to receive data...")
try: try:
data: bytes = self.socket.recv(4096) data: bytes = self.socket.recv(4096)
if not data: if data:
self.logger.warning("[!] Received empty message") self.logger.debug(f"[<] Received {len(data)} bytes: {data!r}")
return b"" self.inbound_buffer += data
self.logger.debug(f"[<] Received data: {data}") else:
return data self.logger.debug("[!] Connection closed by the server.")
self.reconnect()
except socket.error as e: except socket.error as e:
self.logger.error(f"[!] Failed to receive data: {e}") self.logger.error(f"[!] Socket error: {e}")
return b"" self.reconnect()
def close(self) -> None: def send(self, message: Message) -> None:
self.logger.debug("Closing connection...") """Send a message to the server.
self.socket.close()
self.logger.info("Connection closed.")
def reconnect(self) -> None: Args:
self.logger.debug("Reconnecting...") message (Message): The message to send.
self.close() """
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.logger.debug(f"[>] Queueing message to send: {message}")
self.connect() if message.ack_required:
self.pending_acks[message.id] = (message, time.time())
self.outbound_buffer += message.to_bytes()
def send_hello(self) -> None: def send_hello(self) -> None:
self.logger.debug("[.] Sending hello message...") """Send a HELLO message to the server."""
self.logger.debug("[*] Sending HELLO message...")
hello_message: Message = Message.hello(self.mac_address) hello_message: Message = Message.hello(self.mac_address)
acknowledged: bool = False self.send(hello_message)
while not acknowledged:
self.send(hello_message.to_bytes(), no_check_ack=True)
self.logger.debug("[.] Hello message sent, waiting for ACK...")
acknowledged = self._check_ack()
if not acknowledged:
self.logger.warning(
"[!] Hello message not acknowledged, retrying..."
)
time.sleep(1)
def _loop(self) -> None: def close(self) -> None:
self.logger.debug("Starting connector loop...") """Close the connection and clean up resources."""
while True: self.logger.debug("[*] Closing connection...")
time.sleep(0.1)
data: bytes = self.receive()
if not data:
self.reconnect()
continue
for line in data.split(b"\n"):
line: bytes = line.strip()
if not line:
continue
self.logger.debug(f"[.] Raw message data: {line}")
try: try:
message: Message = Message.from_bytes(line) self.selector.unregister(self.socket)
except ValueError as e: except Exception as e:
self.logger.error(f"[!] Failed to parse message: {e}") self.logger.error(f"[!] Error unregistering socket: {e}")
continue self.socket.close()
self.logger.info(f"[*] Message received: {message}") self.logger.debug("[.] Connection closed.")
self.on_message(message)
# if self._check_ack(): def reconnect(self) -> None:
# self.logger.debug("[.] ACK verified") """Reconnect to the server."""
# else: self.logger.debug("[*] Reconnecting...")
# self.logger.error("[!] ACK verification failed") self.close()
# reinit socket
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setblocking(False)
self.selector.register(
self.socket,
selectors.EVENT_READ | selectors.EVENT_WRITE,
data=None,
)
self.connect()
def connect(self) -> None:
"""Establish a connection to the server."""
self.logger.debug(f"Connecting to {self.host}:{self.port}...")
connected: bool = False
delay: float = 1.0
while not connected:
try:
self.socket.connect((self.host, self.port))
connected = True
except BlockingIOError:
# Connection in progress
time.sleep(0.1)
except socket.error as e:
self.logger.error(f"[!] Connection error: {e}")
self.logger.debug(f"[.] Retrying in {delay} seconds...")
time.sleep(delay)
delay = min(delay * 2, 30) # exponential backoff
self.logger.debug("[*] Connected, sending HELLO...")
self.send_hello()
def run(self) -> None: def run(self) -> None:
self.logger.debug("Running Connector...") """Run the main event loop."""
try:
self.connect() self.connect()
self._loop() try:
while self.running:
events = self.selector.select(timeout=1)
for key, mask in events:
if mask & selectors.EVENT_READ:
self._receive_inbound()
if mask & selectors.EVENT_WRITE:
self._send_outbound()
# Process inbound buffer for complete messages
while b"\n" in self.inbound_buffer:
message_bytes, self.inbound_buffer = (
self.inbound_buffer.split(b"\n", 1)
)
try:
message: Message = Message.from_bytes(message_bytes)
# handle incoming ACKs
if (
message.category == Category.CONTROL
and message.action == ControlAction.ACK
):
if (
message.payload.get("target_id")
in self.pending_acks
):
target_id = message.payload["target_id"]
self.logger.debug(
f"[.] Received ACK for message ID {target_id}"
)
del self.pending_acks[target_id]
else:
self.on_message(message)
if message.ack_required:
ack_message: Message = Message.ack(message.id)
self.send(ack_message)
self._send_outbound()
except Exception as e:
self.logger.error(f"[!] Failed to parse message: {e}")
time.sleep(0.1)
except KeyboardInterrupt: except KeyboardInterrupt:
self.logger.info("Interrupted by user.") self.logger.debug("[*] Interrupted by user.")
finally: finally:
self.close() self.close()

4
uv.lock generated
View File

@@ -292,8 +292,8 @@ test = [
[[package]] [[package]]
name = "judas-protocol" name = "judas-protocol"
version = "0.2.0" version = "0.5.0"
source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#bc1bf46388eb904738893a2f86b5050b4ce2489e" } source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#c48b69ecee16f5824ffd8bce8921341d5fa326b7" }
[[package]] [[package]]
name = "markdown-it-py" name = "markdown-it-py"