From 53912ed3394b0b1909f15976bd149b3a023c210c Mon Sep 17 00:00:00 2001 From: Artur Borecki Date: Tue, 23 Sep 2025 23:35:09 +0200 Subject: [PATCH] refactor(connector.py): rewrite `Connector` to use selectors --- src/judas_client/connector.py | 246 +++++++++++++--------------------- 1 file changed, 96 insertions(+), 150 deletions(-) diff --git a/src/judas_client/connector.py b/src/judas_client/connector.py index c86e8cd..67957cd 100644 --- a/src/judas_client/connector.py +++ b/src/judas_client/connector.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging as lg +import selectors import socket import time from typing import Callable @@ -16,8 +17,6 @@ class Connector: host: str, port: int, *, - connect_timeout: float = 5.0, - ack_timeout: float | None = None, on_message: Callable[[Message], None], ) -> None: self.logger: lg.Logger = lg.getLogger( @@ -27,178 +26,125 @@ class Connector: self.host: str = host self.port: int = port - self.socket_timeout: None = None - self.connect_timeout: float = connect_timeout - self.ack_timeout: float | None = ack_timeout + self.selector = selectors.DefaultSelector() self.socket: socket.socket = socket.socket( socket.AF_INET, socket.SOCK_STREAM ) + self.socket.setblocking(False) + + self.selector.register( + self.socket, + selectors.EVENT_READ | selectors.EVENT_WRITE, + data=None, + ) self.mac_address: str = mac_address + self.inbound_buffer: bytes = b"" + self.outbound_buffer: bytes = b"" + self.on_message: Callable[[Message], None] = on_message - def _send_ack(self) -> None: - self.logger.debug("[>] Sending ACK...") - try: - self.socket.sendall(Message.ack().to_bytes()) - self.logger.debug("[<] ACK sent") - except socket.error as e: - self.logger.error(f"[!] Failed to send ACK: {e}") - - def _check_ack(self) -> bool: - self.logger.debug("[.] Waiting for ACK...") - try: - self.socket.settimeout(self.ack_timeout) - ack: bytes = self.socket.recv(1024) - self.socket.settimeout(self.socket_timeout) - - if ack == Message.ack().to_bytes(): - self.logger.debug("[<] ACK received") - return True - else: - self.logger.error(f"[!] Invalid ACK received: {ack}") - - except TimeoutError as e: - self.logger.error(f"[!] ACK timeout: {e}") - - except socket.error as e: - self.logger.error(f"[!] Failed to receive ACK: {e}") - - return False - - def connect(self) -> None: - retry_interval: int = 1 - connected: bool = False - while not connected: - self.logger.debug( - f"[.] Connecting to {self.host}:{self.port} with timeout {self.connect_timeout}s..." - ) + def _send_outbound(self) -> None: + while self.outbound_buffer: try: - self.socket.settimeout(self.connect_timeout) - self.socket.connect((self.host, self.port)) - self.socket.settimeout(self.socket_timeout) - self.logger.info(f"[+] Connected to {self.host}:{self.port}") - self.send_hello() - connected = True - except ( - socket.timeout, - ConnectionRefusedError, - ConnectionAbortedError, - ) as e: - self.logger.error( - f"[!] Connection to {self.host}:{self.port} failed: {e}" + sent = self.socket.send(self.outbound_buffer) + self.logger.debug( + f"[>] Sent {sent} bytes: {self.outbound_buffer[:sent]!r}" ) - self.logger.info( - f"[.] Retrying connection in {retry_interval} s..." - ) - time.sleep(retry_interval) - retry_interval = min( - retry_interval * 2, 30 - ) # exponential backoff + self.outbound_buffer = self.outbound_buffer[sent:] + except BlockingIOError: + # OS buffer full, wait for next EVENT_WRITE + break + except socket.error as e: + self.logger.error(f"[!] Socket error: {e}") + self.reconnect() + break - def send(self, data: bytes, no_check_ack: bool = False) -> None: - self.logger.debug(f"[>] Sending data: {data}") - while True: - try: - self.socket.sendall(data) - - if no_check_ack: - self.logger.debug("[>] Data sent without ACK check") - break - else: - self.logger.info("[>] Data sent") - - acknowledged: bool = self._check_ack() - if acknowledged: - self.logger.debug("[.] Data acknowledged") - break - else: - self.logger.warning( - "[!] Data not acknowledged, retrying..." - ) - - except BrokenPipeError as e: - self.logger.error(f"[!] Broken pipe: {e}") - self.logger.info("[.] Reconnecting...") - self.connect() - except (socket.error, ValueError) as e: - self.logger.error(f"[!] Failed to send data: {e}") - time.sleep(1) - - def receive(self) -> bytes: - self.logger.debug("[.] Waiting to receive data...") + def _receive_inbound(self) -> None: try: data: bytes = self.socket.recv(4096) - if not data: - self.logger.warning("[!] Received empty message") - return b"" - self.logger.debug(f"[<] Received data: {data}") - return data + if data: + self.logger.debug(f"[<] Received {len(data)} bytes: {data!r}") + self.inbound_buffer += data + else: + self.logger.debug("[!] Connection closed by the server.") + self.reconnect() except socket.error as e: - self.logger.error(f"[!] Failed to receive data: {e}") - return b"" - - def close(self) -> None: - self.logger.debug("Closing connection...") - self.socket.close() - self.logger.info("Connection closed.") - - def reconnect(self) -> None: - self.logger.debug("Reconnecting...") - self.close() - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.connect() + self.logger.error(f"[!] Socket error: {e}") + self.reconnect() def send_hello(self) -> None: - self.logger.debug("[.] Sending hello message...") - hello_message: Message = Message.hello(self.mac_address) - acknowledged: bool = False - while not acknowledged: - self.send(hello_message.to_bytes(), no_check_ack=True) - self.logger.debug("[.] Hello message sent, waiting for ACK...") - acknowledged = self._check_ack() - if not acknowledged: - self.logger.warning( - "[!] Hello message not acknowledged, retrying..." - ) - time.sleep(1) + self.logger.debug("[*] Sending HELLO message...") + hello_message: bytes = Message.hello(self.mac_address).to_bytes() + self.outbound_buffer += hello_message + self._send_outbound() - def _loop(self) -> None: - self.logger.debug("Starting connector loop...") - while True: - time.sleep(0.1) - data: bytes = self.receive() - if not data: - self.reconnect() - continue - for line in data.split(b"\n"): - line: bytes = line.strip() + def close(self) -> None: + self.logger.debug("[*] Closing connection...") + self.selector.unregister(self.socket) + self.socket.close() + self.logger.debug("[.] Connection closed.") - if not line: - continue + def reconnect(self) -> None: + self.logger.debug("[*] Reconnecting...") + self.close() - self.logger.debug(f"[.] Raw message data: {line}") - try: - message: Message = Message.from_bytes(line) - except ValueError as e: - self.logger.error(f"[!] Failed to parse message: {e}") - continue - self.logger.info(f"[*] Message received: {message}") - self.on_message(message) + # reinit socket + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setblocking(False) + self.selector.register( + self.socket, + selectors.EVENT_READ | selectors.EVENT_WRITE, + data=None, + ) + self.connect() - # if self._check_ack(): - # self.logger.debug("[.] ACK verified") - # else: - # self.logger.error("[!] ACK verification failed") + def connect(self) -> None: + self.logger.debug(f"Connecting to {self.host}:{self.port}...") + connected: bool = False + delay: float = 1.0 + while not connected: + try: + self.socket.connect((self.host, self.port)) + connected = True + except BlockingIOError: + # Connection in progress + time.sleep(0.1) + except socket.error as e: + self.logger.error(f"[!] Connection error: {e}") + self.logger.debug(f"[.] Retrying in {delay} seconds...") + time.sleep(delay) + delay = min(delay * 2, 30) # exponential backoff + + self.logger.debug("[*] Connected, sending HELLO...") + self.send_hello() def run(self) -> None: - self.logger.debug("Running Connector...") + self.connect() try: - self.connect() - self._loop() + while True: + events = self.selector.select(timeout=1) + for key, mask in events: + if mask & selectors.EVENT_READ: + self._receive_inbound() + if mask & selectors.EVENT_WRITE: + self._send_outbound() + + # Process inbound buffer for complete messages + while b"\n" in self.inbound_buffer: + message_bytes, self.inbound_buffer = ( + self.inbound_buffer.split(b"\n", 1) + ) + try: + message = Message.from_bytes(message_bytes) + self.on_message(message) + except Exception as e: + self.logger.error(f"[!] Failed to parse message: {e}") + + time.sleep(0.1) except KeyboardInterrupt: - self.logger.info("Interrupted by user.") + self.logger.debug("[*] Interrupted by user.") finally: self.close()