7 Commits

4 changed files with 159 additions and 47 deletions

View File

@@ -6,13 +6,18 @@ import selectors
import socket import socket
import threading import threading
import time import time
from typing import TYPE_CHECKING, Any
import yaml import yaml
from typing import Any
from judas_protocol import Category, ControlAction, Message from judas_protocol import Category, ControlAction, Message
from judas_server.backend.client import Client, ClientStatus from judas_server.backend.client import Client, ClientStatus
from judas_server.backend.handler.hello_handler import HelloHandler
if TYPE_CHECKING:
from typing import Callable
from judas_protocol import ActionType
class BackendServer: class BackendServer:
@@ -41,13 +46,25 @@ class BackendServer:
) )
self.clients: dict[str, Client] = {} self.clients: dict[str, Client] = {}
self.known_clients: dict[str, dict[str, str | float]] = ( self.known_clients: dict[str, dict[str, str | float]] = (
self._load_known_clients() self._load_known_clients()
) )
self.message_handlers: dict[
tuple[Category, ActionType], Callable[[Client, Message], None]
] = {}
self.running: bool = False self.running: bool = False
def _initialize_handlers(self) -> None:
"""Initialize message handlers."""
hello_handler = HelloHandler(self)
self.message_handlers[(Category.CONTROL, ControlAction.HELLO)] = (
hello_handler.handle
)
def _load_known_clients(self) -> dict[str, dict[str, str | float]]: def _load_known_clients(self) -> dict[str, dict[str, str | float]]:
"""Load the list of known clients from a YAML file and validate.""" """Load the list of known clients from a YAML file and validate."""
known_clients: dict[str, dict[str, str | float]] = {} known_clients: dict[str, dict[str, str | float]] = {}
@@ -168,6 +185,12 @@ class BackendServer:
""" """
self.logger.info(f"[-] Disconnecting {client}...") self.logger.info(f"[-] Disconnecting {client}...")
if client.socket is None:
self.logger.warning(
f"Client {client} has no socket, nothing to disconnect."
)
return
try: try:
self.selector.unregister(client.socket) self.selector.unregister(client.socket)
except Exception as e: except Exception as e:
@@ -230,48 +253,6 @@ class BackendServer:
self._disconnect(client) self._disconnect(client)
return return
if client.id is None:
# expect HELLO message
try:
msg = Message.from_bytes(client.inbound)
except Exception as e:
self.logger.error(
f"Failed to parse HELLO message from {client}: {e}"
)
self._disconnect(client)
return
if (
msg.category == Category.CONTROL
and msg.action == ControlAction.HELLO
and msg.payload.get("id") is not None
):
client.id = msg.payload["id"]
if (
client.id in self.clients
and self.clients[client.id].status == "connected"
):
old_client: Client = self.clients[client.id]
self.logger.warning(
f"Client {client.id} is already connected from {old_client.addr}, disconnecting old client..."
)
self.send_close(old_client)
self.clients[client.id] = client
self.known_clients[client.id] = {
"last_seen": client.last_seen
}
self._save_known_clients()
client.status = ClientStatus.ONLINE
self.logger.info(f"[+] Registered new client {client}")
else:
self.logger.error(
f"Expected HELLO message from {client}, got {msg}"
)
self._disconnect(client)
return
while b"\n" in client.inbound: while b"\n" in client.inbound:
line, client.inbound = client.inbound.split(b"\n", 1) line, client.inbound = client.inbound.split(b"\n", 1)
self.logger.debug( self.logger.debug(
@@ -280,13 +261,35 @@ class BackendServer:
try: try:
msg = Message.from_bytes(line) msg = Message.from_bytes(line)
self.logger.info(f"[.] Parsed message {msg.id}") self.logger.info(f"[.] Parsed message {msg.id}")
if client.id is None:
self.logger.debug(
f"Client {client} has no ID, expecting HELLO message..."
)
if (
msg.category != Category.CONTROL
or msg.action != ControlAction.HELLO
):
self.logger.warning(
f"First message from {client} must be HELLO, disconnecting..."
)
self._disconnect(client)
return
handler: Callable[[Client, Message], None] | None = (
self.message_handlers.get(
(msg.category, msg.action), None
)
)
if handler is not None:
handler(client, msg)
if msg.ack_required: if msg.ack_required:
self.send_ack(client, target_id=msg.id) self.send_ack(client, target_id=msg.id)
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
f"Failed to parse message from {client}: {e}" f"Failed to parse message from {client}: {e}"
) )
self._disconnect(client)
return return
if mask & selectors.EVENT_WRITE and client.outbound: if mask & selectors.EVENT_WRITE and client.outbound:

View File

@@ -0,0 +1,4 @@
from .base_handler import BaseHandler
from .hello_handler import HelloHandler
__all__ = ["BaseHandler", "HelloHandler"]

View File

@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import logging as lg
from typing import TYPE_CHECKING
from judas_server.backend.client import Client
if TYPE_CHECKING:
from judas_protocol import Message
from judas_server.backend import BackendServer
class BaseHandler:
"""BaseHandler is the base class for all message handlers in the backend server.
It defines the interface for handling messages and provides common functionality for all handlers.
"""
def __init__(self, backend_server: BackendServer) -> None:
"""Initialize the BaseHandler with a reference to the backend server.
Args:
backend_server (BackendServer): The backend server instance that this handler belongs to.
"""
self.logger: lg.Logger = lg.getLogger(
f"{__name__}.{self.__class__.__name__}"
)
self.backend_server: BackendServer = backend_server
def handle(self, client: Client, message: Message) -> None:
"""Handle a message from a client.
This method must be implemented by subclasses to define the specific handling logic for different message types.
Args:
client (Client): The client that sent the message.
message (Message): The message to be handled.
"""
raise NotImplementedError("handle() must be implemented by subclasses")

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import TYPE_CHECKING, override
from judas_protocol import Category, ControlAction, Message
from judas_server.backend.client import ClientStatus
from judas_server.backend.handler import BaseHandler
if TYPE_CHECKING:
from judas_server.backend.backend_server import BackendServer
from judas_server.backend.client import Client
class HelloHandler(BaseHandler):
def __init__(self, backend_server: BackendServer) -> None:
super().__init__(backend_server)
@override
def handle(self, client: Client, message: Message) -> None:
if client.id is not None:
return
if (
message.category != Category.CONTROL
or message.action != ControlAction.HELLO
):
self.logger.error(
f"Expected HELLO message from {client}, got {message}, disconnecting client..."
)
self.backend_server._disconnect(client)
return
if message.payload.get("id") is None:
self.logger.error(
f"HELLO message from {client} missing 'id' field, disconnecting client..."
)
self.backend_server._disconnect(client)
return
client.id = message.payload["id"]
# check if client already connected, if so disconnect old client and register new one
if (
client.id in self.backend_server.clients
and self.backend_server.clients[client.id].status == "connected"
):
old_client: Client = self.backend_server.clients[client.id]
self.backend_server.logger.warning(
f"Client {client.id} is already connected from {old_client.addr}, disconnecting old client..."
)
self.backend_server.send_close(old_client)
return
self.backend_server.clients[client.id] = client # type: ignore
self.backend_server.known_clients[client.id] = { # type: ignore
"last_seen": client.last_seen
}
self.backend_server._save_known_clients()
client.status = ClientStatus.ONLINE
self.logger.info(f"[+] Registered new client {client}")