diff --git a/.gitignore b/.gitignore index 07febb3..181b2a7 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ logs/ # Sphinx docs/_build/ docs/ref/modules/ + +# known clients +config/known_clients.yaml diff --git a/config/.gitkeep b/config/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/judas_server/__init__.py b/src/judas_server/__init__.py index c66c5db..5bcf5d2 100644 --- a/src/judas_server/__init__.py +++ b/src/judas_server/__init__.py @@ -1 +1 @@ -__version__: str = "0.1.0" +__version__: str = "0.5.0" diff --git a/src/judas_server/__main__.py b/src/judas_server/__main__.py index d03289a..e850375 100644 --- a/src/judas_server/__main__.py +++ b/src/judas_server/__main__.py @@ -12,6 +12,8 @@ if __name__ == "__main__": format="%(asctime)s : [%(levelname)s] : %(threadName)s : %(name)s :: %(message)s", ) + lg.getLogger("werkzeug").setLevel(lg.WARNING) + ladygaga_logger = lg.getLogger(f"{__name__}.LAGA_DYGA") ladygaga_logger.info(LADY_GAGA) diff --git a/src/judas_server/backend/__init__.py b/src/judas_server/backend/__init__.py index 8e31517..60e15d3 100644 --- a/src/judas_server/backend/__init__.py +++ b/src/judas_server/backend/__init__.py @@ -1,3 +1,5 @@ from .backend_server import BackendServer +from .client import Client +from .client_status import ClientStatus -__all__ = ["BackendServer"] +__all__ = ["BackendServer", "Client", "ClientStatus"] diff --git a/src/judas_server/backend/backend_server.py b/src/judas_server/backend/backend_server.py index b1bea71..cce28a7 100644 --- a/src/judas_server/backend/backend_server.py +++ b/src/judas_server/backend/backend_server.py @@ -6,16 +6,24 @@ import selectors import socket import threading import time +from typing import TYPE_CHECKING, Any, Final + import yaml - -from typing import Any - from judas_protocol import Category, ControlAction, Message from judas_server.backend.client import Client, ClientStatus +from judas_server.backend.handler.hello_handler import HelloHandler + +if TYPE_CHECKING: + from typing import Callable + + from judas_protocol import ActionType class BackendServer: + ACK_TIMEOUT: Final[float] = 5.0 # seconds + HELLO_TIMEOUT: Final[float] = 3.0 # seconds + def __init__(self, host: str = "0.0.0.0", port: int = 3692) -> None: """Initialize the backend server. @@ -28,27 +36,6 @@ class BackendServer: ) self.logger.debug("Initializing Server...") - self.known_clients: dict[str, dict[str, str | float]] = {} - try: - with open("cache/known_clients.yaml", "r") as f: - self.known_clients = ( - yaml.safe_load(f).get("known_clients", {}) or {} - ) - self.logger.debug( - f"Loaded known clients: {self.known_clients}" - ) - self.logger.info( - f"Loaded {len(self.known_clients)} known clients" - ) - except FileNotFoundError: - self.logger.warning( - "known_clients.yaml not found, creating empty known clients list" - ) - with open("cache/known_clients.yaml", "w") as f: - yaml.safe_dump({"known_clients": {}}, f) - except Exception as e: - self.logger.error(f"Error loading known clients: {e}") - self.selector = selectors.DefaultSelector() self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server_socket.setsockopt( @@ -63,20 +50,85 @@ class BackendServer: self.clients: dict[str, Client] = {} - if self.known_clients: - for client_id in self.known_clients: - client = Client(id=client_id, addr=None, socket=None) - client.status = ClientStatus.OFFLINE - client.last_seen = float( - self.known_clients[client_id].get("last_seen", 0.0) - ) - self.clients[client_id] = client + self.known_clients: dict[str, dict[str, str | float]] = {} + self.known_clients = self._load_known_clients() + + self.message_handlers: dict[ + tuple[Category, ActionType], Callable[[Client, Message], None] + ] = {} + self._initialize_handlers() + + self.pending_acks: list[tuple[Client, Message, float]] = [] + self.pending_hello: dict[Client, float] = {} self.running: bool = False + def _initialize_handlers(self) -> None: + """Initialize message handlers.""" + + hello_handler = HelloHandler(self) + + self.message_handlers[(Category.CONTROL, ControlAction.HELLO)] = ( + hello_handler.handle + ) + + def _load_known_clients(self) -> dict[str, dict[str, str | float]]: + """Load the list of known clients from a YAML file and validate.""" + known_clients: dict[str, dict[str, str | float]] = {} + + try: + with open("config/known_clients.yaml", "r") as f: + data = yaml.safe_load(f) + + if not isinstance(data, dict): + raise ValueError("YAML root must be a dict") + + known_clients = data.get("known_clients", {}) or {} + + if not isinstance(known_clients, dict): + raise ValueError("'known_clients' must be a dict") + + for client_id, client_data in known_clients.items(): + if not isinstance(client_data, dict): + raise ValueError( + f"Client {client_id} data must be a dict" + ) + last_seen = client_data.get("last_seen", 0.0) + if not isinstance(last_seen, (float, int)): + raise ValueError( + f"Client {client_id} 'last_seen' must be a float or int" + ) + + self.logger.debug(f"Loaded known clients: {known_clients}") + self.logger.info(f"Loaded {len(known_clients)} known clients") + + for client_id in known_clients: + client = Client(id=client_id, addr=None, socket=None) + client.status = ClientStatus.OFFLINE + client.last_seen = float( + known_clients[client_id].get("last_seen", 0.0) + ) + self.clients[client_id] = client + + except FileNotFoundError: + self.logger.warning( + "known_clients.yaml not found, creating empty known clients list" + ) + self._save_known_clients() + except Exception as e: + self.logger.error(f"Error loading known clients: {e}") + raise + + return known_clients + def _save_known_clients(self) -> None: """Save the list of known clients to a YAML file.""" - with open("cache/known_clients.yaml", "w") as f: + with open("config/known_clients.yaml", "w") as f: + f.write( + "# This file is automatically generated by BackendServer.\n" + + "# Do not edit manually.\n" + + f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}\n\n" + ) yaml.safe_dump({"known_clients": self.known_clients}, f) self.logger.debug("Saved known clients") @@ -99,6 +151,22 @@ class BackendServer: ) time.sleep(1) + def send(self, client: Client, msg: Message) -> None: + """Send a message to a client. + + Args: + client (Client): The client to send the message to. + msg (Message): The message to send. + """ + msg_bytes: bytes = msg.to_bytes() + self.logger.info( + f"[>] Sending message {msg.id} to {client}, category: {msg.category}, action: {msg.action}, ack_required: {msg.ack_required}" + ) + self.logger.debug(f"[>] Message bytes: {msg_bytes!r}") + if msg.ack_required: + self.pending_acks.append((client, msg, time.time())) + client.outbound += msg_bytes + def send_ack(self, client: Client, target_id: str) -> None: """Send an ACK message to a client. @@ -106,9 +174,9 @@ class BackendServer: client (Client): The client to send the ACK to. target_id (str): The id of the ACK'd message. """ - ack: bytes = Message.ack(target_id=target_id).to_bytes() + ack: Message = Message.Control.ack(target_id=target_id) self.logger.info(f"[>] Sending ACK to {client}") - client.outbound += ack + self.send(client, ack) def send_close(self, client: Client) -> None: """Send a CLOSE message to a client. @@ -116,9 +184,9 @@ class BackendServer: Args: client (Client): The client to send the CLOSE message to. """ - close_msg: bytes = Message.close().to_bytes() + close_msg: Message = Message.Control.close() self.logger.info(f"[>] Sending CLOSE to {client}") - client.outbound += close_msg + self.send(client, close_msg) def _accept_connection(self, sock: socket.socket) -> None: """Accept a new client connection. @@ -134,6 +202,8 @@ class BackendServer: events = selectors.EVENT_READ | selectors.EVENT_WRITE self.selector.register(conn, events, data=client) + self.pending_hello[client] = time.time() + self.logger.info(f"[+] Registered client {client}, HELLO pending...") def _disconnect(self, client: Client) -> None: @@ -144,6 +214,12 @@ class BackendServer: """ self.logger.info(f"[-] Disconnecting {client}...") + if client.socket is None or client.socket._closed: + self.logger.warning( + f"Client {client} has no socket, nothing to disconnect." + ) + return + try: self.selector.unregister(client.socket) except Exception as e: @@ -202,51 +278,6 @@ class BackendServer: try: if mask & selectors.EVENT_READ: self._receive_inbound(sock, client) - if not client.inbound: - self._disconnect(client) - return - - if client.id is None: - # expect HELLO message - try: - msg = Message.from_bytes(client.inbound) - except Exception as e: - self.logger.error( - f"Failed to parse HELLO message from {client}: {e}" - ) - self._disconnect(client) - return - - if ( - msg.category == Category.CONTROL - and msg.action == ControlAction.HELLO - and msg.payload.get("id") is not None - ): - client.id = msg.payload["id"] - if ( - client.id in self.clients - and self.clients[client.id].status == "connected" - ): - old_client: Client = self.clients[client.id] - self.logger.warning( - f"Client {client.id} is already connected from {old_client.addr}, disconnecting old client..." - ) - self.send_close(old_client) - - self.clients[client.id] = client - self.known_clients[client.id] = { - "last_seen": client.last_seen - } - self._save_known_clients() - client.status = ClientStatus.ONLINE - - self.logger.info(f"[+] Registered new client {client}") - else: - self.logger.error( - f"Expected HELLO message from {client}, got {msg}" - ) - self._disconnect(client) - return while b"\n" in client.inbound: line, client.inbound = client.inbound.split(b"\n", 1) @@ -256,13 +287,40 @@ class BackendServer: try: msg = Message.from_bytes(line) self.logger.info(f"[.] Parsed message {msg.id}") + + if client.id is None: + self.logger.debug( + f"Client {client} has no ID, expecting HELLO message..." + ) + if ( + msg.category != Category.CONTROL + or msg.action != ControlAction.HELLO + ): + self.logger.warning( + f"First message from {client} must be HELLO, disconnecting..." + ) + self._disconnect(client) + continue + + handler: Callable[[Client, Message], None] | None = ( + self.message_handlers.get( + (msg.category, msg.action), None + ) + ) + if handler is not None: + handler(client, msg) + else: + self.logger.warning( + f"No handler for message {msg.id} with category {msg.category} and action {msg.action}" + ) + continue + if msg.ack_required: self.send_ack(client, target_id=msg.id) except Exception as e: self.logger.error( f"Failed to parse message from {client}: {e}" ) - self._disconnect(client) return if mask & selectors.EVENT_WRITE and client.outbound: @@ -302,6 +360,25 @@ class BackendServer: and now - client.last_seen > 60 * 60 * 24 # 24 hours ): self.clients[client.id].status = ClientStatus.STALE + + # check pending ACKs + for client, msg, timestamp in self.pending_acks[:]: + if time.time() - timestamp > self.ACK_TIMEOUT: + self.logger.warning( + f"ACK timeout for message {msg.id} to {client}, resending..." + ) + self.send(client, msg) + self.pending_acks.remove((client, msg, timestamp)) + + # check pending HELLOs + for client, timestamp in list(self.pending_hello.items()): + if time.time() - timestamp > self.HELLO_TIMEOUT: + self.logger.warning( + f"HELLO timeout for {client}, disconnecting..." + ) + self._disconnect(client) + del self.pending_hello[client] + time.sleep(0.001) # prevent 100% CPU usage except Exception as e: diff --git a/src/judas_server/backend/client.py b/src/judas_server/backend/client.py index 500f9f4..1c20fa2 100644 --- a/src/judas_server/backend/client.py +++ b/src/judas_server/backend/client.py @@ -5,24 +5,19 @@ from __future__ import annotations import logging as lg import socket -from enum import Enum import time - -class ClientStatus(str, Enum): - """Enumeration of client connection statuses.""" - - ONLINE = "online" - PENDING = "pending" - OFFLINE = "offline" - STALE = "stale" +from judas_server.backend.client_status import ClientStatus class Client: """Represents a client.""" def __init__( - self, id: str | None, addr: tuple[str, int], socket: socket.socket + self, + id: str | None, + addr: tuple[str, int] | None, + socket: socket.socket | None, ) -> None: """Initialize the client. @@ -41,13 +36,15 @@ class Client: self.last_seen: float = 0.0 # unix timestanp of last inbound message self.status: ClientStatus = ClientStatus.PENDING - self.socket: socket.socket = socket - self.addr: tuple[str, int] = addr + self.socket: socket.socket | None = socket + self.addr: tuple[str, int] | None = addr self.inbound: bytes = b"" self.outbound: bytes = b"" def __str__(self) -> str: - return f"Client({self.id} ({self.addr[0]}:{self.addr[1]}))" + if self.addr: + return f"Client({self.id} ({self.addr[0]}:{self.addr[1]}))" + return f"Client({self.id} (not connected))" def __repr__(self) -> str: return f"Client({self.id}, {self.addr})" @@ -55,6 +52,11 @@ class Client: def disconnect(self) -> None: """Disconnect the client and close the socket.""" self.logger.debug(f"Disconnecting Client {self}...") + if self.socket is None: + self.logger.warning( + f"Client {self} not connected, nothing to disconnect." + ) + return try: self.socket.close() except Exception as e: diff --git a/src/judas_server/backend/client_status.py b/src/judas_server/backend/client_status.py new file mode 100644 index 0000000..1dabba6 --- /dev/null +++ b/src/judas_server/backend/client_status.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +from enum import Enum + + +class ClientStatus(str, Enum): + """Enumeration of client connection statuses.""" + + ONLINE = "online" + PENDING = "pending" + OFFLINE = "offline" + STALE = "stale" diff --git a/src/judas_server/backend/handler/__init__.py b/src/judas_server/backend/handler/__init__.py new file mode 100644 index 0000000..dc07972 --- /dev/null +++ b/src/judas_server/backend/handler/__init__.py @@ -0,0 +1,5 @@ +from .base_handler import BaseHandler +from .hello_handler import HelloHandler +from .ack_handler import AckHandler + +__all__ = ["BaseHandler", "HelloHandler", "AckHandler"] diff --git a/src/judas_server/backend/handler/ack_handler.py b/src/judas_server/backend/handler/ack_handler.py new file mode 100644 index 0000000..8204be0 --- /dev/null +++ b/src/judas_server/backend/handler/ack_handler.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + +from typing import TYPE_CHECKING + +from .base_handler import BaseHandler + +if TYPE_CHECKING: + from judas_protocol import Message + + from judas_server.backend import BackendServer, Client + + +class AckHandler(BaseHandler): + def __init__(self, backend_server: BackendServer) -> None: + super().__init__(backend_server) + + def handle(self, client: Client, message: Message) -> None: + pending_acks = self.backend_server.pending_acks + if message.id in pending_acks: + del pending_acks[message.id] + self.logger.debug( + f"[*] Received ACK for message {message.id} from {client}." + ) + else: + self.logger.warning( + f"[!] Received ACK for unknown (or ACK'd) message {message.id} from {client}." + ) diff --git a/src/judas_server/backend/handler/base_handler.py b/src/judas_server/backend/handler/base_handler.py new file mode 100644 index 0000000..e7dfa8c --- /dev/null +++ b/src/judas_server/backend/handler/base_handler.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging as lg +from typing import TYPE_CHECKING + +from judas_server.backend.client import Client + +if TYPE_CHECKING: + from judas_protocol import Message + + from judas_server.backend import BackendServer + + +class BaseHandler: + """BaseHandler is the base class for all message handlers in the backend server. + + It defines the interface for handling messages and provides common functionality for all handlers. + """ + + def __init__(self, backend_server: BackendServer) -> None: + """Initialize the BaseHandler with a reference to the backend server. + + Args: + backend_server (BackendServer): The backend server instance that this handler belongs to. + """ + self.logger: lg.Logger = lg.getLogger( + f"{__name__}.{self.__class__.__name__}" + ) + self.backend_server: BackendServer = backend_server + + def handle(self, client: Client, message: Message) -> None: + """Handle a message from a client. + + This method must be implemented by subclasses to define the specific handling logic for different message types. + + Args: + client (Client): The client that sent the message. + message (Message): The message to be handled. + """ + raise NotImplementedError("handle() must be implemented by subclasses") diff --git a/src/judas_server/backend/handler/hello_handler.py b/src/judas_server/backend/handler/hello_handler.py new file mode 100644 index 0000000..a9d78a1 --- /dev/null +++ b/src/judas_server/backend/handler/hello_handler.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import TYPE_CHECKING, override + +from judas_protocol import Category, ControlAction, Message + +from judas_server.backend.client import ClientStatus +from judas_server.backend.handler import BaseHandler + +if TYPE_CHECKING: + from judas_server.backend.backend_server import BackendServer + from judas_server.backend.client import Client + + +class HelloHandler(BaseHandler): + def __init__(self, backend_server: BackendServer) -> None: + super().__init__(backend_server) + + @override + def handle(self, client: Client, message: Message) -> None: + if client.id is not None: + return + + if ( + message.category != Category.CONTROL + or message.action != ControlAction.HELLO + ): + self.logger.error( + f"Expected HELLO message from {client}, got {message}, disconnecting client..." + ) + self.backend_server._disconnect(client) + return + + if message.payload.get("id") is None: + self.logger.error( + f"HELLO message from {client} missing 'id' field, disconnecting client..." + ) + self.backend_server._disconnect(client) + return + + client.id = message.payload["id"] + + # check if client already connected, if so disconnect old client and register new one + if ( + client.id in self.backend_server.clients + and self.backend_server.clients[client.id].status + == ClientStatus.ONLINE + ): + old_client: Client = self.backend_server.clients[client.id] + self.backend_server.logger.warning( + f"Client {client.id} is already connected from {old_client.addr}, disconnecting old client..." + ) + self.backend_server.send_close(old_client) + return + + self.backend_server.clients[client.id] = client # type: ignore + self.backend_server.known_clients[client.id] = { # type: ignore + "last_seen": client.last_seen + } + + del self.backend_server.pending_hello[client] + self.backend_server._save_known_clients() + client.status = ClientStatus.ONLINE + + self.logger.info(f"[+] Registered new client {client}") diff --git a/uv.lock b/uv.lock index a96478f..1c9901a 100644 --- a/uv.lock +++ b/uv.lock @@ -358,8 +358,8 @@ wheels = [ [[package]] name = "judas-protocol" -version = "0.6.0" -source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#d16c1914ba343aed300f1c5fae0201370c3274de" } +version = "0.8.0" +source = { git = "https://gitea.pufereq.pl/judas/judas_protocol.git#a805ccf38edffadc1b8c8b276e60758c86516cd3" } [[package]] name = "judas-server"