diff --git a/README.md b/README.md index e0fdca3..5982029 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ ROUGHLY_SERVER_PRIVATE_KEY="your_private_key_here" roughly -v server run By default, the server will bind to `0.0.0.0:2002`. You can change this using the `--host` and `--port` flags. I recommend running the server with verbose logging enabled (`-v`), so you can see incoming requests and debug any issues. +Additionally you might want to consider turning off response greasing while testing using the `--no-grease` flag. ### As a library diff --git a/roughly/cli.py b/roughly/cli.py index 59f0a4b..6532338 100644 --- a/roughly/cli.py +++ b/roughly/cli.py @@ -263,12 +263,29 @@ def server() -> None: help="Validity period for the delegated key in seconds. " "If not set, defaults to 3600 seconds (1 hour).", ) +@click.option( + "--no-grease", + is_flag=True, + help="Disable response greasing", + envvar="ROUGHLY_NO_GREASE", +) +@click.option( + "--grease-probability", + default=None, + type=float, + help="Probability of greasing responses (between 0.0 and 1.0). " + f"If not set, defaults to {roughly.server.GREASE_PROBABILITY} " + f"({roughly.server.GREASE_PROBABILITY * 100:.2f}%).", + envvar="ROUGHLY_GREASE_PROBABILITY", +) def server_run( host: str, port: int, private_key: str | None, radius: int, validity_seconds: int | None, + no_grease: bool, + grease_probability: float | None, ) -> None: """Run a Roughtime server.""" key_bytes = base64.b64decode(private_key) if private_key else None @@ -277,6 +294,8 @@ def server_run( key_bytes, validity_seconds=validity_seconds, radius=radius, + grease=not no_grease, + grease_probability=grease_probability, ) pub_bytes = roughly.server.public_key_bytes(config.long_term_key) diff --git a/roughly/server.py b/roughly/server.py index 01e696e..2d796af 100644 --- a/roughly/server.py +++ b/roughly/server.py @@ -3,14 +3,16 @@ import asyncio import logging import os +import string import struct import time -from typing import TYPE_CHECKING, NamedTuple, cast +from random import SystemRandom +from typing import TYPE_CHECKING, NamedTuple, TypeVar, cast from cryptography.hazmat.primitives.asymmetric import ed25519 if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence from roughly import ( DELEGATION_CONTEXT_STRING, @@ -27,6 +29,7 @@ PacketError, Response, SignedResponse, + Tag, build_supported_versions, format_versions, partial_sha512, @@ -37,18 +40,70 @@ tags, ) +random = SystemRandom() + logger = logging.getLogger(__name__) +T = TypeVar("T") NONCE_SIZE = 32 VER_7_NONCE_SIZE = 64 MAX_DRAFT_VERSION = 0xFFFFFFFF -DEFAULT_RADIUS = int(os.environ.get("ROUGHLY_DEFAULT_RADIUS", "3")) +DEFAULT_RADIUS = 3 CLIENT_VERSIONS_SUPPORTED = build_supported_versions(10, 15) CERT_VALIDITY = 60 * 60 # 1 hour +GREASE_PROBABILITY = 0.001 + + +def grease_add_undefined_tag(message: Message) -> Message: + # undefined tags + # 4 byte tag name + tag_name = int.from_bytes(random.choices(string.ascii_uppercase.encode("ascii"), k=4)) + tag_value = os.urandom(random.randint(1, 16) * 4) + + message.tags.append(Tag(tag=tag_name, value=tag_value)) + message.tags.sort(key=lambda t: t.tag) + return message + + +def grease_remove_random_tag(message: Message) -> Message: + if message.tags: + message.tags.remove(random.choice(message.tags)) + return message + + +def grease_change_version(message: Message) -> Message: + # TODO: implement version grease, need to be able to resign packets + return message + + +def grease_change_time(message: Message) -> Message: + srep_raw = pop_by_tag(message.tags, tags.SREP) + srep = SignedResponse.from_bytes(srep_raw.value) + # from 0 to uint32 max + srep.midpoint = random.randint(0, 0x100000000) + new_srep_raw = srep.to_bytes() + message.tags.append(Tag(tag=tags.SREP, value=new_srep_raw)) + message.tags.sort(key=lambda t: t.tag) + return message + + +GREASERS: list[Callable[[Message], Message]] = [ + grease_add_undefined_tag, + grease_remove_random_tag, + grease_change_time, + grease_change_version, # TODO: implement +] + + +def grease_message(message: Message) -> Message: + greaser = random.choice(GREASERS) + logger.debug("Applying greaser: %s", greaser.__name__) + return greaser(message) + class CertificateStore(NamedTuple): old: Certificate @@ -119,23 +174,29 @@ class Server(NamedTuple): validity_seconds: int | None radius: int versions: tuple[int, ...] + grease: bool + grease_probability: float @staticmethod def get_time() -> int: return int(time.time()) @classmethod - def create( + def create( # noqa: PLR0913 cls, private_key: bytes | None = None, *, validity_seconds: int | None = None, radius: int = DEFAULT_RADIUS, versions: Sequence[int] | None = None, + grease: bool = False, + grease_probability: float | None = None, ) -> Server: cert_validity_seconds = validity_seconds if cert_validity_seconds is None: cert_validity_seconds = CERT_VALIDITY + if grease_probability is None: + grease_probability = GREASE_PROBABILITY long_term = load_key(private_key) if private_key else generate_key() delegated = generate_key() @@ -166,6 +227,8 @@ def make_cert(string: bytes, *, google: bool | None = None) -> Certificate: validity_seconds=validity_seconds, radius=radius, versions=tuple(versions or CLIENT_VERSIONS_SUPPORTED), + grease=grease, + grease_probability=grease_probability, ) def refresh(self) -> Server: @@ -174,6 +237,8 @@ def refresh(self) -> Server: validity_seconds=self.validity_seconds, radius=self.radius, versions=self.versions, + grease=self.grease, + grease_probability=self.grease_probability, ) @@ -300,7 +365,7 @@ def build_response( # noqa: PLR0913 root: bytes, path: list[bytes], index: int, -) -> bytes: +) -> Packet: # We very much expect the client to ignore unknown tags # we could also be a good programmer and handle versions properly # but let's expect clients to be well-built :3 @@ -324,7 +389,7 @@ def build_response( # noqa: PLR0913 cert = pick_cert(certificates=server.certificates, version=version) resp = make_response(server, nonce, version, path, index, srep, cert) - return Packet(message=resp).dump(google=(version == GOOGLE_ROUGHTIME_SENTINEL)) + return Packet(message=resp) def pick_cert(*, certificates: CertificateStore, version: int) -> Certificate: @@ -364,7 +429,10 @@ def handle_request(server: Server, data: bytes) -> bytes | None: return handle_batch(server, (data,))[0] -def handle_batch(server: Server, requests: Sequence[bytes]) -> list[bytes | None]: +def handle_batch( # noqa: C901 TODO: refactor this function + server: Server, + requests: Sequence[bytes], +) -> list[bytes | None]: # TODO(batching): we need to ensure that a batch is compatible # i.e. having to pick different hashers would break the merkle tree # for now, we don't batch requests at all @@ -441,7 +509,7 @@ def handle_batch(server: Server, requests: Sequence[bytes]) -> list[bytes | None for merkle_idx, (req_id, req) in enumerate(zip(valid_idx, valid_requests, strict=True)): path = get_merkle_path(levels, merkle_idx) - response = build_response( + packet = build_response( server, nonce=req.nonce, # TODO(batching): select the right version here @@ -451,6 +519,13 @@ def handle_batch(server: Server, requests: Sequence[bytes]) -> list[bytes | None path=path, index=merkle_idx, ) + + if server.grease and random.random() < server.grease_probability: + logger.debug("Greasing response for request %d", req_id) + grease_message(packet.message) + + response = packet.dump(google=(version == GOOGLE_ROUGHTIME_SENTINEL)) + if len(response) > len(req.raw): # we drop responses larger than requests to avoid amplification attacks logger.debug( @@ -482,7 +557,6 @@ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: logger.debug("Received datagram from %s:%d", host, port) try: resp = handle_request(self.server, data) - if resp and self.transport: self.transport.sendto(resp, addr) logger.debug("Sent response to %s:%d", host, port)