diff --git a/src/judas_client/connector.py b/src/judas_client/connector.py index b9b7738..cb557e1 100644 --- a/src/judas_client/connector.py +++ b/src/judas_client/connector.py @@ -7,7 +7,7 @@ import socket import time from typing import Callable -from judas_protocol import Message +from judas_protocol import Category, ControlAction, Message class Connector: @@ -54,6 +54,10 @@ class Connector: self.inbound_buffer: bytes = b"" self.outbound_buffer: bytes = b"" + self.pending_acks: dict[str, tuple[Message, float]] = {} + + self.running: bool = True + self.on_message: Callable[[Message], None] = on_message def _send_outbound(self) -> None: @@ -64,6 +68,7 @@ class Connector: self.logger.debug( f"[>] Sent {sent} bytes: {self.outbound_buffer[:sent]!r}" ) + self.outbound_buffer = self.outbound_buffer[sent:] except BlockingIOError: # OS buffer full, wait for next EVENT_WRITE @@ -82,22 +87,37 @@ class Connector: self.inbound_buffer += data else: self.logger.debug("[!] Connection closed by the server.") - self.reconnect() + # TODO: close only when instructed by server + self.close() + self.running = False except socket.error as e: self.logger.error(f"[!] Socket error: {e}") self.reconnect() + def send(self, message: Message) -> None: + """Send a message to the server. + + Args: + message (Message): The message to send. + """ + self.logger.debug(f"[>] Queueing message to send: {message}") + if message.ack_required: + self.pending_acks[message.id] = (message, time.time()) + self.outbound_buffer += message.to_bytes() + def send_hello(self) -> None: """Send a HELLO message to the server.""" self.logger.debug("[*] Sending HELLO message...") - hello_message: bytes = Message.hello(self.mac_address).to_bytes() - self.outbound_buffer += hello_message - self._send_outbound() + hello_message: Message = Message.hello(self.mac_address) + self.send(hello_message) def close(self) -> None: """Close the connection and clean up resources.""" self.logger.debug("[*] Closing connection...") - self.selector.unregister(self.socket) + try: + self.selector.unregister(self.socket) + except Exception as e: + self.logger.error(f"[!] Error unregistering socket: {e}") self.socket.close() self.logger.debug("[.] Connection closed.") @@ -141,7 +161,7 @@ class Connector: """Run the main event loop.""" self.connect() try: - while True: + while self.running: events = self.selector.select(timeout=1) for key, mask in events: if mask & selectors.EVENT_READ: @@ -156,7 +176,27 @@ class Connector: ) try: message = Message.from_bytes(message_bytes) - self.on_message(message) + # handle incoming ACKs + if ( + message.category == Category.CONTROL + and message.action == ControlAction.ACK + ): + if ( + message.payload.get("target_id") + in self.pending_acks + ): + target_id = message.payload["target_id"] + self.logger.debug( + f"[.] Received ACK for message ID {target_id}" + ) + del self.pending_acks[target_id] + else: + self.on_message(message) + + if message.ack_required: + ack_message = Message.ack(message.id) + self.send(ack_message) + self._send_outbound() except Exception as e: self.logger.error(f"[!] Failed to parse message: {e}") diff --git a/uv.lock b/uv.lock index e8ad24e..7dc30b7 100644 --- a/uv.lock +++ b/uv.lock @@ -293,7 +293,7 @@ test = [ [[package]] name = "judas-protocol" version = "0.4.3" -source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#5ef300ff93bb43d4db28ae019fec30f48f88152b" } +source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#332ce3ffa16ba43d6af1ba71bce1bc633e1661a9" } [[package]] name = "markdown-it-py"