Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ ROUGHLY_SERVER_PRIVATE_KEY="your_private_key_here" roughly -v server run

By default, the server will bind to `0.0.0.0:2002`. You can change this using the `--host` and `--port` flags.
I recommend running the server with verbose logging enabled (`-v`), so you can see incoming requests and debug any issues.
Additionally you might want to consider turning off response greasing while testing using the `--no-grease` flag.

### As a library

Expand Down
19 changes: 19 additions & 0 deletions roughly/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,29 @@ def server() -> None:
help="Validity period for the delegated key in seconds. "
"If not set, defaults to 3600 seconds (1 hour).",
)
@click.option(
"--no-grease",
is_flag=True,
help="Disable response greasing",
envvar="ROUGHLY_NO_GREASE",
)
@click.option(
"--grease-probability",
default=None,
type=float,
help="Probability of greasing responses (between 0.0 and 1.0). "
f"If not set, defaults to {roughly.server.GREASE_PROBABILITY} "
f"({roughly.server.GREASE_PROBABILITY * 100:.2f}%).",
envvar="ROUGHLY_GREASE_PROBABILITY",
)
def server_run(
host: str,
port: int,
private_key: str | None,
radius: int,
validity_seconds: int | None,
no_grease: bool,
grease_probability: float | None,
) -> None:
"""Run a Roughtime server."""
key_bytes = base64.b64decode(private_key) if private_key else None
Expand All @@ -277,6 +294,8 @@ def server_run(
key_bytes,
validity_seconds=validity_seconds,
radius=radius,
grease=not no_grease,
grease_probability=grease_probability,
)

pub_bytes = roughly.server.public_key_bytes(config.long_term_key)
Expand Down
92 changes: 83 additions & 9 deletions roughly/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import asyncio
import logging
import os
import string
import struct
import time
from typing import TYPE_CHECKING, NamedTuple, cast
from random import SystemRandom
from typing import TYPE_CHECKING, NamedTuple, TypeVar, cast

from cryptography.hazmat.primitives.asymmetric import ed25519

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence

from roughly import (
DELEGATION_CONTEXT_STRING,
Expand All @@ -27,6 +29,7 @@
PacketError,
Response,
SignedResponse,
Tag,
build_supported_versions,
format_versions,
partial_sha512,
Expand All @@ -37,18 +40,70 @@
tags,
)

random = SystemRandom()

logger = logging.getLogger(__name__)

T = TypeVar("T")

NONCE_SIZE = 32
VER_7_NONCE_SIZE = 64
MAX_DRAFT_VERSION = 0xFFFFFFFF

DEFAULT_RADIUS = int(os.environ.get("ROUGHLY_DEFAULT_RADIUS", "3"))
DEFAULT_RADIUS = 3
CLIENT_VERSIONS_SUPPORTED = build_supported_versions(10, 15)

CERT_VALIDITY = 60 * 60 # 1 hour

GREASE_PROBABILITY = 0.001


def grease_add_undefined_tag(message: Message) -> Message:
# undefined tags
# 4 byte tag name
tag_name = int.from_bytes(random.choices(string.ascii_uppercase.encode("ascii"), k=4))
tag_value = os.urandom(random.randint(1, 16) * 4)

message.tags.append(Tag(tag=tag_name, value=tag_value))
message.tags.sort(key=lambda t: t.tag)
return message


def grease_remove_random_tag(message: Message) -> Message:
if message.tags:
message.tags.remove(random.choice(message.tags))
return message


def grease_change_version(message: Message) -> Message:
# TODO: implement version grease, need to be able to resign packets
return message


def grease_change_time(message: Message) -> Message:
srep_raw = pop_by_tag(message.tags, tags.SREP)
srep = SignedResponse.from_bytes(srep_raw.value)
# from 0 to uint32 max
srep.midpoint = random.randint(0, 0x100000000)
new_srep_raw = srep.to_bytes()
message.tags.append(Tag(tag=tags.SREP, value=new_srep_raw))
message.tags.sort(key=lambda t: t.tag)
return message


GREASERS: list[Callable[[Message], Message]] = [
grease_add_undefined_tag,
grease_remove_random_tag,
grease_change_time,
grease_change_version, # TODO: implement
]


def grease_message(message: Message) -> Message:
greaser = random.choice(GREASERS)
logger.debug("Applying greaser: %s", greaser.__name__)
return greaser(message)


class CertificateStore(NamedTuple):
old: Certificate
Expand Down Expand Up @@ -119,23 +174,29 @@ class Server(NamedTuple):
validity_seconds: int | None
radius: int
versions: tuple[int, ...]
grease: bool
grease_probability: float

@staticmethod
def get_time() -> int:
return int(time.time())

@classmethod
def create(
def create( # noqa: PLR0913
cls,
private_key: bytes | None = None,
*,
validity_seconds: int | None = None,
radius: int = DEFAULT_RADIUS,
versions: Sequence[int] | None = None,
grease: bool = False,
grease_probability: float | None = None,
) -> Server:
cert_validity_seconds = validity_seconds
if cert_validity_seconds is None:
cert_validity_seconds = CERT_VALIDITY
if grease_probability is None:
grease_probability = GREASE_PROBABILITY

long_term = load_key(private_key) if private_key else generate_key()
delegated = generate_key()
Expand Down Expand Up @@ -166,6 +227,8 @@ def make_cert(string: bytes, *, google: bool | None = None) -> Certificate:
validity_seconds=validity_seconds,
radius=radius,
versions=tuple(versions or CLIENT_VERSIONS_SUPPORTED),
grease=grease,
grease_probability=grease_probability,
)

def refresh(self) -> Server:
Expand All @@ -174,6 +237,8 @@ def refresh(self) -> Server:
validity_seconds=self.validity_seconds,
radius=self.radius,
versions=self.versions,
grease=self.grease,
grease_probability=self.grease_probability,
)


Expand Down Expand Up @@ -300,7 +365,7 @@ def build_response( # noqa: PLR0913
root: bytes,
path: list[bytes],
index: int,
) -> bytes:
) -> Packet:
# We very much expect the client to ignore unknown tags
# we could also be a good programmer and handle versions properly
# but let's expect clients to be well-built :3
Expand All @@ -324,7 +389,7 @@ def build_response( # noqa: PLR0913
cert = pick_cert(certificates=server.certificates, version=version)

resp = make_response(server, nonce, version, path, index, srep, cert)
return Packet(message=resp).dump(google=(version == GOOGLE_ROUGHTIME_SENTINEL))
return Packet(message=resp)


def pick_cert(*, certificates: CertificateStore, version: int) -> Certificate:
Expand Down Expand Up @@ -364,7 +429,10 @@ def handle_request(server: Server, data: bytes) -> bytes | None:
return handle_batch(server, (data,))[0]


def handle_batch(server: Server, requests: Sequence[bytes]) -> list[bytes | None]:
def handle_batch( # noqa: C901 TODO: refactor this function
server: Server,
requests: Sequence[bytes],
) -> list[bytes | None]:
# TODO(batching): we need to ensure that a batch is compatible
# i.e. having to pick different hashers would break the merkle tree
# for now, we don't batch requests at all
Expand Down Expand Up @@ -441,7 +509,7 @@ def handle_batch(server: Server, requests: Sequence[bytes]) -> list[bytes | None

for merkle_idx, (req_id, req) in enumerate(zip(valid_idx, valid_requests, strict=True)):
path = get_merkle_path(levels, merkle_idx)
response = build_response(
packet = build_response(
server,
nonce=req.nonce,
# TODO(batching): select the right version here
Expand All @@ -451,6 +519,13 @@ def handle_batch(server: Server, requests: Sequence[bytes]) -> list[bytes | None
path=path,
index=merkle_idx,
)

if server.grease and random.random() < server.grease_probability:
logger.debug("Greasing response for request %d", req_id)
grease_message(packet.message)

response = packet.dump(google=(version == GOOGLE_ROUGHTIME_SENTINEL))

if len(response) > len(req.raw):
# we drop responses larger than requests to avoid amplification attacks
logger.debug(
Expand Down Expand Up @@ -482,7 +557,6 @@ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
logger.debug("Received datagram from %s:%d", host, port)
try:
resp = handle_request(self.server, data)

if resp and self.transport:
self.transport.sendto(resp, addr)
logger.debug("Sent response to %s:%d", host, port)
Expand Down