From e32ad0c7aacf2bde2482fbc8b25e74e5b032d35d Mon Sep 17 00:00:00 2001 From: AI Agent Bot Date: Wed, 18 Feb 2026 03:13:56 -0600 Subject: [PATCH 1/5] add .gitignore --- .gitignore | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4a173ba --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Rust build artifacts +target/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db + +# PSP build artifacts +*.PBP +*.SFO +*.PRX +psp_output_file.log + +# Secrets +*.key +*.pem +*.cert +.env +*.env.local + +# AI agent session/cache directories +.codex/ +.opencode/ +.claude/ +.gemini/* +!.gemini/settings.json +.crush/ + +# CI review artifacts +codex-review.md +gemini-review.md + +# Docker overrides +docker-compose.override.yml + +# Node (for agent tooling) +node_modules/ + +# Scratch +scratch/ +.local/ From e63884513b346d2f6a9ad5d856a55e1da2c96158 Mon Sep 17 00:00:00 2001 From: AI Agent Bot Date: Wed, 18 Feb 2026 04:24:14 -0600 Subject: [PATCH 2/5] migrate injection toolkit and NMS cockpit video mod into game-mods repo Add all ITK core libraries, framework templates, NMS cockpit video project, CI/Docker configuration, and tooling. Fix all remnant references from the previous template-repo: update Docker image names, project URLs, documentation paths, and source code comments to reflect the game-mods repository structure. Co-Authored-By: Claude Opus 4.6 --- .agents.yaml | 184 +++ .agents.yaml.example | 118 ++ .env.example | 45 + .gitattributes | 13 + .../actions/agent-iteration-check/action.yml | 139 +++ .github/workflows/ci.yml | 71 ++ .github/workflows/main-ci.yml | 117 ++ .github/workflows/pr-validation.yml | 466 ++++++++ .gitignore | 9 +- .gitmodules | 0 .mcp.json | 167 +++ .nvmrc | 1 + .pre-commit-config.yaml | 53 + .rgignore | 1 + .secrets.yaml | 201 ++++ AGENTS.md | 111 ++ CLAUDE.md | 85 ++ CONTRIBUTING.md | 19 + Cargo.toml | 106 ++ LICENSE-MIT | 19 + README.md | 114 +- core/itk-ipc/Cargo.toml | 28 + core/itk-ipc/src/lib.rs | 322 ++++++ core/itk-ipc/src/unix_impl.rs | 364 ++++++ core/itk-ipc/src/windows_impl.rs | 461 ++++++++ core/itk-net/Cargo.toml | 39 + core/itk-net/src/discovery.rs | 269 +++++ core/itk-net/src/error.rs | 40 + core/itk-net/src/lib.rs | 53 + core/itk-net/src/peer.rs | 342 ++++++ core/itk-net/src/session.rs | 392 +++++++ core/itk-net/src/sync_manager.rs | 297 +++++ core/itk-protocol/Cargo.toml | 17 + core/itk-protocol/fuzz/Cargo.toml | 36 + .../fuzz/fuzz_targets/fuzz_decode.rs | 16 + .../fuzz/fuzz_targets/fuzz_header.rs | 11 + .../fuzz/fuzz_targets/fuzz_screen_rect.rs | 54 + core/itk-protocol/src/lib.rs | 612 ++++++++++ core/itk-shmem/Cargo.toml | 25 + core/itk-shmem/src/lib.rs | 795 +++++++++++++ core/itk-shmem/src/unix_impl.rs | 165 +++ core/itk-shmem/src/windows_impl.rs | 172 +++ core/itk-sync/Cargo.toml | 16 + core/itk-sync/src/lib.rs | 446 ++++++++ core/itk-video/Cargo.toml | 39 + core/itk-video/src/decoder.rs | 361 ++++++ core/itk-video/src/error.rs | 77 ++ core/itk-video/src/frame_writer.rs | 234 ++++ core/itk-video/src/hwaccel.rs | 144 +++ core/itk-video/src/lib.rs | 67 ++ core/itk-video/src/scaler.rs | 173 +++ core/itk-video/src/stream.rs | 113 ++ core/itk-video/src/youtube.rs | 141 +++ daemon/Cargo.toml | 25 + daemon/src/main.rs | 595 ++++++++++ daemon/src/video/mod.rs | 13 + daemon/src/video/player.rs | 388 +++++++ daemon/src/video/state.rs | 144 +++ deny.toml | 47 + docker-compose.yml | 153 +++ docker/rust-ci.Dockerfile | 47 + docs/ARCHITECTURE.md | 403 +++++++ docs/MIGRATION.md | 176 +++ injectors/linux/ld-preload/Cargo.toml | 21 + injectors/linux/ld-preload/src/lib.rs | 95 ++ injectors/windows/native-dll/Cargo.toml | 21 + injectors/windows/native-dll/src/lib.rs | 128 +++ overlay/Cargo.toml | 40 + overlay/src/lib.rs | 100 ++ overlay/src/main.rs | 212 ++++ overlay/src/platform/linux.rs | 269 +++++ overlay/src/platform/mod.rs | 30 + overlay/src/platform/windows.rs | 73 ++ overlay/src/render.rs | 364 ++++++ overlay/src/shaders/overlay.wgsl | 29 + overlay/src/video.rs | 123 ++ projects/nms-cockpit-video/README.md | 254 +++++ projects/nms-cockpit-video/daemon/Cargo.toml | 54 + projects/nms-cockpit-video/daemon/src/main.rs | 410 +++++++ .../daemon/src/video/audio.rs | 467 ++++++++ .../nms-cockpit-video/daemon/src/video/mod.rs | 14 + .../daemon/src/video/player.rs | 659 +++++++++++ .../daemon/src/video/state.rs | 148 +++ .../docs/nms-reverse-engineering.md | 373 ++++++ .../nms-cockpit-video/injector/Cargo.toml | 28 + projects/nms-cockpit-video/injector/build.rs | 66 ++ .../injector/shaders/quad.frag.wgsl | 7 + .../injector/shaders/quad.vert.wgsl | 23 + .../injector/src/camera/mod.rs | 187 +++ .../injector/src/camera/pattern_scan.rs | 166 +++ .../injector/src/camera/projection.rs | 183 +++ .../injector/src/hooks/mod.rs | 41 + .../injector/src/hooks/openvr.rs | 371 ++++++ .../injector/src/hooks/vulkan.rs | 832 ++++++++++++++ .../nms-cockpit-video/injector/src/input.rs | 233 ++++ .../nms-cockpit-video/injector/src/lib.rs | 112 ++ .../nms-cockpit-video/injector/src/log.rs | 53 + .../injector/src/renderer/geometry.rs | 77 ++ .../injector/src/renderer/mod.rs | 1000 +++++++++++++++++ .../injector/src/renderer/texture.rs | 519 +++++++++ .../injector/src/shmem_reader.rs | 106 ++ .../nms-cockpit-video/launcher/Cargo.toml | 15 + .../nms-cockpit-video/launcher/src/main.rs | 490 ++++++++ .../mod/NmsCockpitOverlay.sln | 18 + .../CockpitTracker/MatrixReader.cs | 222 ++++ .../CockpitTracker/ScreenProjection.cs | 178 +++ .../mod/NmsCockpitOverlay/Ipc/PipeClient.cs | 190 ++++ .../mod/NmsCockpitOverlay/Mod.cs | 196 ++++ .../mod/NmsCockpitOverlay/ModConfig.json | 33 + .../NmsCockpitOverlay.csproj | 23 + projects/nms-cockpit-video/overlay/Cargo.toml | 62 + .../nms-cockpit-video/overlay/src/main.rs | 435 +++++++ .../nms-cockpit-video/overlay/src/platform.rs | 119 ++ .../nms-cockpit-video/overlay/src/render.rs | 503 +++++++++ projects/nms-cockpit-video/overlay/src/ui.rs | 258 +++++ .../nms-cockpit-video/overlay/src/video.rs | 116 ++ rustfmt.toml | 66 ++ tools/cli/agents/run_claude.bat | 53 + tools/cli/agents/run_claude.sh | 53 + tools/cli/agents/run_codex.sh | 215 ++++ tools/cli/agents/run_crush.sh | 217 ++++ tools/cli/agents/run_gemini.sh | 96 ++ tools/cli/agents/run_opencode.sh | 147 +++ tools/cli/agents/stop_claude.sh | 43 + tools/cli/containers/README.md | 116 ++ tools/cli/containers/run_codex_container.sh | 84 ++ tools/cli/containers/run_crush_container.sh | 180 +++ tools/cli/containers/run_gemini_container.sh | 187 +++ .../cli/containers/run_opencode_container.sh | 201 ++++ tools/cli/containers/run_opencode_simple.sh | 138 +++ tools/mem-scanner/Cargo.toml | 25 + tools/mem-scanner/src/main.rs | 33 + tools/mem-scanner/src/scanner.rs | 372 ++++++ 133 files changed, 23315 insertions(+), 8 deletions(-) create mode 100644 .agents.yaml create mode 100644 .agents.yaml.example create mode 100644 .env.example create mode 100644 .gitattributes create mode 100644 .github/actions/agent-iteration-check/action.yml create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/main-ci.yml create mode 100644 .github/workflows/pr-validation.yml create mode 100644 .gitmodules create mode 100644 .mcp.json create mode 100644 .nvmrc create mode 100644 .pre-commit-config.yaml create mode 100644 .rgignore create mode 100644 .secrets.yaml create mode 100644 AGENTS.md create mode 100644 CLAUDE.md create mode 100644 CONTRIBUTING.md create mode 100644 Cargo.toml create mode 100644 LICENSE-MIT create mode 100644 core/itk-ipc/Cargo.toml create mode 100644 core/itk-ipc/src/lib.rs create mode 100644 core/itk-ipc/src/unix_impl.rs create mode 100644 core/itk-ipc/src/windows_impl.rs create mode 100644 core/itk-net/Cargo.toml create mode 100644 core/itk-net/src/discovery.rs create mode 100644 core/itk-net/src/error.rs create mode 100644 core/itk-net/src/lib.rs create mode 100644 core/itk-net/src/peer.rs create mode 100644 core/itk-net/src/session.rs create mode 100644 core/itk-net/src/sync_manager.rs create mode 100644 core/itk-protocol/Cargo.toml create mode 100644 core/itk-protocol/fuzz/Cargo.toml create mode 100644 core/itk-protocol/fuzz/fuzz_targets/fuzz_decode.rs create mode 100644 core/itk-protocol/fuzz/fuzz_targets/fuzz_header.rs create mode 100644 core/itk-protocol/fuzz/fuzz_targets/fuzz_screen_rect.rs create mode 100644 core/itk-protocol/src/lib.rs create mode 100644 core/itk-shmem/Cargo.toml create mode 100644 core/itk-shmem/src/lib.rs create mode 100644 core/itk-shmem/src/unix_impl.rs create mode 100644 core/itk-shmem/src/windows_impl.rs create mode 100644 core/itk-sync/Cargo.toml create mode 100644 core/itk-sync/src/lib.rs create mode 100644 core/itk-video/Cargo.toml create mode 100644 core/itk-video/src/decoder.rs create mode 100644 core/itk-video/src/error.rs create mode 100644 core/itk-video/src/frame_writer.rs create mode 100644 core/itk-video/src/hwaccel.rs create mode 100644 core/itk-video/src/lib.rs create mode 100644 core/itk-video/src/scaler.rs create mode 100644 core/itk-video/src/stream.rs create mode 100644 core/itk-video/src/youtube.rs create mode 100644 daemon/Cargo.toml create mode 100644 daemon/src/main.rs create mode 100644 daemon/src/video/mod.rs create mode 100644 daemon/src/video/player.rs create mode 100644 daemon/src/video/state.rs create mode 100644 deny.toml create mode 100644 docker-compose.yml create mode 100644 docker/rust-ci.Dockerfile create mode 100644 docs/ARCHITECTURE.md create mode 100644 docs/MIGRATION.md create mode 100644 injectors/linux/ld-preload/Cargo.toml create mode 100644 injectors/linux/ld-preload/src/lib.rs create mode 100644 injectors/windows/native-dll/Cargo.toml create mode 100644 injectors/windows/native-dll/src/lib.rs create mode 100644 overlay/Cargo.toml create mode 100644 overlay/src/lib.rs create mode 100644 overlay/src/main.rs create mode 100644 overlay/src/platform/linux.rs create mode 100644 overlay/src/platform/mod.rs create mode 100644 overlay/src/platform/windows.rs create mode 100644 overlay/src/render.rs create mode 100644 overlay/src/shaders/overlay.wgsl create mode 100644 overlay/src/video.rs create mode 100644 projects/nms-cockpit-video/README.md create mode 100644 projects/nms-cockpit-video/daemon/Cargo.toml create mode 100644 projects/nms-cockpit-video/daemon/src/main.rs create mode 100644 projects/nms-cockpit-video/daemon/src/video/audio.rs create mode 100644 projects/nms-cockpit-video/daemon/src/video/mod.rs create mode 100644 projects/nms-cockpit-video/daemon/src/video/player.rs create mode 100644 projects/nms-cockpit-video/daemon/src/video/state.rs create mode 100644 projects/nms-cockpit-video/docs/nms-reverse-engineering.md create mode 100644 projects/nms-cockpit-video/injector/Cargo.toml create mode 100644 projects/nms-cockpit-video/injector/build.rs create mode 100644 projects/nms-cockpit-video/injector/shaders/quad.frag.wgsl create mode 100644 projects/nms-cockpit-video/injector/shaders/quad.vert.wgsl create mode 100644 projects/nms-cockpit-video/injector/src/camera/mod.rs create mode 100644 projects/nms-cockpit-video/injector/src/camera/pattern_scan.rs create mode 100644 projects/nms-cockpit-video/injector/src/camera/projection.rs create mode 100644 projects/nms-cockpit-video/injector/src/hooks/mod.rs create mode 100644 projects/nms-cockpit-video/injector/src/hooks/openvr.rs create mode 100644 projects/nms-cockpit-video/injector/src/hooks/vulkan.rs create mode 100644 projects/nms-cockpit-video/injector/src/input.rs create mode 100644 projects/nms-cockpit-video/injector/src/lib.rs create mode 100644 projects/nms-cockpit-video/injector/src/log.rs create mode 100644 projects/nms-cockpit-video/injector/src/renderer/geometry.rs create mode 100644 projects/nms-cockpit-video/injector/src/renderer/mod.rs create mode 100644 projects/nms-cockpit-video/injector/src/renderer/texture.rs create mode 100644 projects/nms-cockpit-video/injector/src/shmem_reader.rs create mode 100644 projects/nms-cockpit-video/launcher/Cargo.toml create mode 100644 projects/nms-cockpit-video/launcher/src/main.rs create mode 100644 projects/nms-cockpit-video/mod/NmsCockpitOverlay.sln create mode 100644 projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/MatrixReader.cs create mode 100644 projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/ScreenProjection.cs create mode 100644 projects/nms-cockpit-video/mod/NmsCockpitOverlay/Ipc/PipeClient.cs create mode 100644 projects/nms-cockpit-video/mod/NmsCockpitOverlay/Mod.cs create mode 100644 projects/nms-cockpit-video/mod/NmsCockpitOverlay/ModConfig.json create mode 100644 projects/nms-cockpit-video/mod/NmsCockpitOverlay/NmsCockpitOverlay.csproj create mode 100644 projects/nms-cockpit-video/overlay/Cargo.toml create mode 100644 projects/nms-cockpit-video/overlay/src/main.rs create mode 100644 projects/nms-cockpit-video/overlay/src/platform.rs create mode 100644 projects/nms-cockpit-video/overlay/src/render.rs create mode 100644 projects/nms-cockpit-video/overlay/src/ui.rs create mode 100644 projects/nms-cockpit-video/overlay/src/video.rs create mode 100644 rustfmt.toml create mode 100644 tools/cli/agents/run_claude.bat create mode 100755 tools/cli/agents/run_claude.sh create mode 100755 tools/cli/agents/run_codex.sh create mode 100755 tools/cli/agents/run_crush.sh create mode 100755 tools/cli/agents/run_gemini.sh create mode 100755 tools/cli/agents/run_opencode.sh create mode 100755 tools/cli/agents/stop_claude.sh create mode 100644 tools/cli/containers/README.md create mode 100755 tools/cli/containers/run_codex_container.sh create mode 100755 tools/cli/containers/run_crush_container.sh create mode 100755 tools/cli/containers/run_gemini_container.sh create mode 100755 tools/cli/containers/run_opencode_container.sh create mode 100755 tools/cli/containers/run_opencode_simple.sh create mode 100644 tools/mem-scanner/Cargo.toml create mode 100644 tools/mem-scanner/src/main.rs create mode 100644 tools/mem-scanner/src/scanner.rs diff --git a/.agents.yaml b/.agents.yaml new file mode 100644 index 0000000..33de2f2 --- /dev/null +++ b/.agents.yaml @@ -0,0 +1,184 @@ +# Multi-Agent System Configuration +# This file configures the AI agents available for automated CI/CD tasks +# +# IMPORTANT: All agents run in AUTONOMOUS MODE for CI/CD automation +# - Agents must run without human interaction +# - All agents operate in sandboxed environments (containers/VMs) +# - Interactive prompts are disabled (e.g., --dangerously-skip-permissions) + +# List of enabled agents +# NOTE: Some agents are only available in specific environments: +# - Host-only agents: claude (needs subscription auth), gemini (needs Docker access) +# - Container-only agents: opencode, crush (installed in openrouter-agents container) +enabled_agents: + - claude # Anthropic's Claude Code (host-only, primary agent) + - gemini # Google's Gemini CLI (host-only, for reviews) + # Containerized agents (automatically run in Docker when invoked) + - opencode # Open-source alternative (runs in container) + - crush # Charm Bracelet's multi-provider tool (runs in container) + +# Agent priorities for different task types +agent_priorities: + # For creating PRs from issues + issue_creation: + - claude # Best for implementation + + # For reviewing PRs + pr_reviews: + - gemini # Excellent for code review + - claude # Fallback option + + # For implementing code fixes + code_fixes: + - claude # Most capable for fixes + +# Security settings +security: + # Users authorized to trigger agent actions via [Approved][Agent] keywords + # CRITICAL: Only add trusted human users - these can execute code via agents + # Agent admins can also use [CONTINUE] in PR comments to extend iteration limits + agent_admins: + - AndrewAltimit # Repository owner + + # Trusted sources for comment context (used in PR reviews) + # Comments from these accounts are marked as trusted when providing context to AI + # This does NOT grant them ability to trigger agent actions + trusted_sources: + - AndrewAltimit # Repository owner + - github-actions[bot] # GitHub Actions bot + - dependabot[bot] # Dependabot + - renovate[bot] # Renovate bot + + # Autonomous mode is REQUIRED for CI/CD operation + # All agents run in sandboxed environments for security + autonomous_mode: true # Enables non-interactive flags like --dangerously-skip-permissions + + # Confirm running in sandboxed environment + require_sandbox: true + + # Maximum prompt length to prevent abuse + max_prompt_length: 10000 + + # Cleanup temporary files after use + temp_file_cleanup: true + + # Maximum execution time per agent call + subprocess_timeout: 600 # 10 minutes + + # Memory limit for subprocess execution + memory_limit_mb: 500 + +# Rate limiting (per agent) +rate_limits: + requests_per_minute: 10 + requests_per_hour: 100 + + # Agent-specific overrides + claude: + requests_per_minute: 20 # Claude can handle more + gemini: + requests_per_minute: 5 # More conservative for Gemini + +# Model configuration overrides +model_overrides: + # Gemini models + # Using explicit model specification with API key authentication + # Model names verified from gemini-config.json + gemini: + pro_model: gemini-3-pro-preview # Latest preview model (NOT 3.0!) + flash_model: gemini-3-flash-preview # Fast fallback model (Gemini 3 Flash) + default_model: gemini-3-flash-preview # Primary model for PR reviews (faster, lower rate limits) + + # OpenRouter agents configuration + opencode: + model: qwen/qwen-2.5-coder-32b-instruct + temperature: 0.2 + + crush: + model: qwen/qwen-2.5-coder-32b-instruct + temperature: 0.1 # Lower temperature for Crush + +# OpenRouter configuration (for future agents) +openrouter: + # API key must be provided as OPENROUTER_API_KEY environment variable + + # Default model for OpenRouter-compatible agents + default_model: qwen/qwen-2.5-coder-32b-instruct + + # Fallback models if primary is unavailable + fallback_models: + - deepseek/deepseek-coder-v2-instruct + - meta-llama/llama-3.1-70b-instruct + +# PR Review configuration +pr_review: + # Default agent for reviews (from enabled_agents) + default_agent: gemini + + # Review constraints + max_words: 500 + condensation_threshold: 600 + + # Incremental review settings + incremental_enabled: true + + # Trust bucketing (uses security.agent_admins and security.trusted_sources) + include_comment_context: true + + # Hallucination detection + verify_claims: true + + # Reaction images config + reaction_config_url: "https://raw.githubusercontent.com/AndrewAltimit/Media/refs/heads/main/reaction/config.yaml" + +# Automation settings for inline agent feedback loop +automation: + # Enable inline agent feedback loop in PR validation + inline_feedback_loop: true + + # Maximum iterations before stopping (prevents infinite loops) + max_auto_fix_iterations: 5 + + # Try autoformat (black, isort) before invoking AI + autoformat_first: true + + # Categories of failures to auto-fix + # Options: formatting, linting, type_errors, unused_imports, test_failures + auto_fix_categories: + - formatting + - linting + - type_errors + - unused_imports + + # Skip auto-fix for PRs with these labels + skip_labels: + - no-auto-fix + - needs-human-review + +# Advanced settings +advanced: + # Enable debug logging + debug_mode: false + + # Custom paths + # Set AGENT_TEMP_DIR environment variable to override + temp_directory: /tmp/agents + + # Retry configuration + max_retries: 2 + retry_delay_seconds: 5 + + # Subprocess environment isolation + isolate_environment: true + + # Enable telemetry (metrics collection) + enable_telemetry: false + + # Non-interactive mode settings + non_interactive_flags: + claude: ["--print", "--dangerously-skip-permissions"] + # For gemini: Use -p for prompt and -m for model selection (both work reliably with API keys) + # Note: Only -p is passed by default; model selection uses .env configuration (GEMINI_PRIMARY_MODEL) + gemini: ["-p"] + opencode: ["--non-interactive"] + crush: ["--non-interactive", "--no-update"] diff --git a/.agents.yaml.example b/.agents.yaml.example new file mode 100644 index 0000000..2ee1572 --- /dev/null +++ b/.agents.yaml.example @@ -0,0 +1,118 @@ +# Multi-Agent System Configuration Example +# Copy this to .agents.yaml and customize for your needs + +# List of enabled agents (claude is always enabled by default) +enabled_agents: + - claude # Anthropic's Claude Code (host-only) + - gemini # Google's Gemini CLI (host-only) + - opencode # SST's OpenCode (containerized) + - crush # Charm Bracelet's Crush AI (containerized) + +# Agent priorities for different task types +# Agents are tried in order until one succeeds +agent_priorities: + # For creating PRs from issues + issue_creation: + - claude + - opencode + + # For reviewing PRs + pr_reviews: + - gemini + - claude + + # For implementing code fixes + code_fixes: + - claude + - crush + - opencode + +# Model configuration overrides per agent +model_overrides: + # Example: Use a different model for OpenCode + # opencode: + # model: deepseek/deepseek-coder-v2-instruct + # temperature: 0.3 + + # Example: Configure Crush for faster responses + # crush: + # model: qwen/qwen-2.5-coder-7b-instruct + # temperature: 0.1 + +# OpenRouter configuration for agents that support it +openrouter: + # API key must be provided as OPENROUTER_API_KEY environment variable + + # Default model for OpenRouter-compatible agents + default_model: qwen/qwen-2.5-coder-32b-instruct + + # Fallback models if primary is unavailable + fallback_models: + - deepseek/deepseek-coder-v2-instruct + - meta-llama/llama-3.1-70b-instruct + + # Per-agent model overrides + agent_overrides: + opencode: + model: qwen/qwen-2.5-coder-32b-instruct + temperature: 0.2 + crush: + model: qwen/qwen-2.5-coder-32b-instruct + temperature: 0.1 + +# Security settings +security: + # Users authorized to trigger agent actions via [Approved][Agent] keywords + # CRITICAL: Only add trusted human users - these can execute code via agents + agent_admins: + - your-github-username # Replace with your GitHub username + + # Trusted sources for comment context (used in PR reviews) + # Comments from these accounts are marked as trusted when providing context to AI + # This does NOT grant them ability to trigger agent actions + trusted_sources: + - your-github-username # Replace with your GitHub username + - github-actions[bot] # GitHub Actions bot + - dependabot[bot] # Dependabot + - renovate[bot] # Renovate bot + + # Autonomous mode is REQUIRED for CI/CD operation + # All agents run in sandboxed environments for security + autonomous_mode: true # Enables non-interactive flags like --dangerously-skip-permissions + + # Confirm running in sandboxed environment + require_sandbox: true + + # Maximum prompt length to prevent abuse + max_prompt_length: 10000 + + # Cleanup temporary files after use + temp_file_cleanup: true + + # Maximum execution time per agent call + subprocess_timeout: 600 # 10 minutes + + # Memory limit for subprocess execution + memory_limit_mb: 500 + +# Rate limiting (per agent) +rate_limits: + requests_per_minute: 10 + requests_per_hour: 100 + + # Agent-specific overrides + # claude: + # requests_per_minute: 20 + # gemini: + # requests_per_minute: 5 + +# Advanced settings (optional) +advanced: + # Timeout for agent commands (seconds) + command_timeout: 300 + + # Maximum retries for failed agent calls + max_retries: 2 + + # Enable detailed logging + debug_mode: false diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..c41da0b --- /dev/null +++ b/.env.example @@ -0,0 +1,45 @@ +# Docker user permissions +USER_ID=1000 +GROUP_ID=1000 + +# GitHub Configuration +OPENROUTER_API_KEY=your_api_key_here +GITHUB_REPOSITORY=AndrewAltimit/game-mods +GITHUB_PROJECTS_TOKEN=your_api_key_here +GITHUB_TOKEN=your_api_key_here + +# Codex Configuration +# WARNING: Only set to true if running in a controlled sandboxed VM environment +# Defaults to false for security - only bypass sandbox if you understand the risks +CODEX_BYPASS_SANDBOX=false + +# Optional: ElevenLabs Configuration +ELEVENLABS_API_KEY=your_api_key_here +ELEVENLABS_DEFAULT_MODEL=eleven_v3 + +# Virtual Character Storage Service +STORAGE_SECRET_KEY=your_api_key_here +STORAGE_BASE_URL=http://192.168.0.222:8021 + +# GPU device selection for multi-GPU systems (default: 0) +GPU_DEVICE=0 + +# Gemini API Key (Free Tier from Google AI Studio) +GOOGLE_API_KEY=your_api_key_here +GEMINI_API_KEY=your_api_key_here + +# Gemini Model Configuration +GEMINI_PRIMARY_MODEL=gemini-3-pro-preview +GEMINI_FALLBACK_MODEL=gemini-3-flash-preview + +# ============================================================================= +# AgentCore Memory Configuration +# ============================================================================= +MEMORY_PROVIDER=agentcore +AWS_REGION=us-east-1 +AGENTCORE_MEMORY_ID=mem-xxxxxxxxxxxx +AGENTCORE_DATA_PLANE_ENDPOINT= +CHROMADB_HOST=chromadb +CHROMADB_PORT=8000 +CHROMADB_COLLECTION=agent_memory +HUGGINGFACE_TOKEN= diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..2c31c89 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,13 @@ +# Repository stores everything with LF (Unix line endings) +# Git auto-converts line endings on checkout based on platform + +# Batch files: MUST have CRLF line endings (Windows requirement) +*.bat text eol=crlf + +# Everything else: stored as LF, stays LF on all platforms +*.sh text eol=lf +*.py text eol=lf +*.md text eol=lf +*.json text eol=lf +*.yaml text eol=lf +*.yml text eol=lf diff --git a/.github/actions/agent-iteration-check/action.yml b/.github/actions/agent-iteration-check/action.yml new file mode 100644 index 0000000..22c54bf --- /dev/null +++ b/.github/actions/agent-iteration-check/action.yml @@ -0,0 +1,139 @@ +name: 'Agent Iteration Check' +description: 'Track agent auto-fix iterations via PR comments to prevent infinite loops' + +inputs: + pr_number: + description: 'PR number' + required: true + max_iterations: + description: 'Maximum allowed iterations before stopping' + required: false + default: '5' + agent_type: + description: 'Agent type to track (review-fix or failure-fix)' + required: true + github_token: + description: 'GitHub token for API operations' + required: true + +outputs: + iteration_count: + description: 'Current iteration count' + value: ${{ steps.check.outputs.iteration_count }} + effective_max: + description: 'Effective max iterations after [CONTINUE] multipliers' + value: ${{ steps.check.outputs.effective_max }} + continue_count: + description: 'Number of [CONTINUE] comments from admins' + value: ${{ steps.check.outputs.continue_count }} + is_agent_commit: + description: 'Whether the last commit was made by an agent' + value: ${{ steps.check.outputs.is_agent_commit }} + should_skip: + description: 'Whether to skip this run' + value: ${{ steps.check.outputs.should_skip }} + exceeded_max: + description: 'Whether max iterations have been exceeded' + value: ${{ steps.check.outputs.exceeded_max }} + +runs: + using: 'composite' + steps: + - name: Check iteration count + id: check + shell: bash + env: + GH_TOKEN: ${{ inputs.github_token }} + PR_NUMBER: ${{ inputs.pr_number }} + MAX_ITERATIONS: ${{ inputs.max_iterations }} + AGENT_TYPE: ${{ inputs.agent_type }} + run: | + echo "=== Agent Iteration Check ===" + echo "PR: $PR_NUMBER | Max: $MAX_ITERATIONS | Type: $AGENT_TYPE" + + # Check if last commit is an agent commit + AGENT_AUTHORS=("AI Review Agent" "AI Pipeline Agent" "AI Agent Bot") + LAST_COMMIT_AUTHOR=$(git log -1 --format='%an' 2>/dev/null || echo "") + + IS_AGENT_COMMIT="false" + for author in "${AGENT_AUTHORS[@]}"; do + if [[ "$LAST_COMMIT_AUTHOR" == "$author" ]]; then + IS_AGENT_COMMIT="true" + break + fi + done + echo "is_agent_commit=$IS_AGENT_COMMIT" >> $GITHUB_OUTPUT + + # Use github-agents CLI if available (installed on runner) + if command -v github-agents &>/dev/null; then + RESULT=$(github-agents iteration-check \ + --pr "$PR_NUMBER" \ + --agent-type "$AGENT_TYPE" \ + --max-iterations "$MAX_ITERATIONS" \ + --format json 2>/dev/null || echo "") + + if [ -n "$RESULT" ] && echo "$RESULT" | jq -e . >/dev/null 2>&1; then + echo "iteration_count=$(echo "$RESULT" | jq -r '.iteration_count')" >> $GITHUB_OUTPUT + echo "effective_max=$(echo "$RESULT" | jq -r '.effective_max')" >> $GITHUB_OUTPUT + echo "continue_count=$(echo "$RESULT" | jq -r '.continue_count')" >> $GITHUB_OUTPUT + echo "exceeded_max=$(echo "$RESULT" | jq -r '.exceeded_max')" >> $GITHUB_OUTPUT + echo "should_skip=$(echo "$RESULT" | jq -r '.should_skip')" >> $GITHUB_OUTPUT + echo "=== Iteration Check Complete (via CLI) ===" + exit 0 + fi + fi + + # Fallback: count agent commit metadata comments on the PR + echo "Falling back to comment-based iteration counting..." + ITERATION_COUNT=$(gh api "repos/$GITHUB_REPOSITORY/issues/$PR_NUMBER/comments" \ + --paginate --jq "[.[] | select(.body | contains(\"agent-metadata:type=${AGENT_TYPE}\"))] | length" \ + 2>/dev/null || echo "0") + + # Count [CONTINUE] comments from admins + CONTINUE_COUNT=$(gh api "repos/$GITHUB_REPOSITORY/issues/$PR_NUMBER/comments" \ + --paginate --jq '[.[] | select(.body | test("\\[CONTINUE\\]"))] | length' \ + 2>/dev/null || echo "0") + + EFFECTIVE_MAX=$(( MAX_ITERATIONS + CONTINUE_COUNT * MAX_ITERATIONS )) + EXCEEDED_MAX="false" + SHOULD_SKIP="false" + if [ "$ITERATION_COUNT" -ge "$EFFECTIVE_MAX" ]; then + EXCEEDED_MAX="true" + SHOULD_SKIP="true" + fi + + echo "iteration_count=$ITERATION_COUNT" >> $GITHUB_OUTPUT + echo "effective_max=$EFFECTIVE_MAX" >> $GITHUB_OUTPUT + echo "continue_count=$CONTINUE_COUNT" >> $GITHUB_OUTPUT + echo "exceeded_max=$EXCEEDED_MAX" >> $GITHUB_OUTPUT + echo "should_skip=$SHOULD_SKIP" >> $GITHUB_OUTPUT + + # Post comment if max iterations exceeded + if [ "$EXCEEDED_MAX" = "true" ]; then + if [ "$AGENT_TYPE" = "review-fix" ]; then + AGENT_NAME="Review Response Agent" + else + AGENT_NAME="Failure Handler Agent" + fi + + COMMENT_FILE=$(mktemp) + cat > "$COMMENT_FILE" < + + The **$AGENT_NAME** has reached the iteration limit (**${ITERATION_COUNT}/${EFFECTIVE_MAX}**). + Further automated fixes from this agent have been paused. + + **To allow more iterations:** + - An admin can comment \`[CONTINUE]\` to extend the limit + - Or address the issues manually + - Or add the \`no-auto-fix\` label to disable automated fixes + COMMENT_EOF + + gh pr comment "$PR_NUMBER" --body-file "$COMMENT_FILE" || \ + echo "Warning: Failed to post iteration limit comment" + rm -f "$COMMENT_FILE" + fi + + echo "=== Iteration Check Complete ===" + echo " Count: $ITERATION_COUNT | Max: $EFFECTIVE_MAX | Exceeded: $EXCEEDED_MAX" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..96c891c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,71 @@ +--- +name: CI + +on: + push: + branches: [main] + workflow_dispatch: + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +env: + DOCKER_BUILDKIT: 1 + COMPOSE_DOCKER_CLI_BUILD: 1 + +jobs: + ci: + name: Game Mods CI + runs-on: self-hosted + timeout-minutes: 30 + + steps: + - name: Pre-checkout cleanup + run: | + for item in outputs target .git/index.lock; do + if [ -d "$item" ] || [ -f "$item" ]; then + docker run --rm -v "$(pwd):/workspace" busybox:1.36.1 sh -c \ + "rm -rf /workspace/$item" 2>/dev/null || \ + sudo rm -rf "$item" 2>/dev/null || true + fi + done + + - name: Checkout + uses: actions/checkout@v4 + + - name: Set UID/GID + run: | + echo "USER_ID=$(id -u)" >> $GITHUB_ENV + echo "GROUP_ID=$(id -g)" >> $GITHUB_ENV + + # -- Formatting ------------------------------------------------------- + - name: Format check + run: docker compose --profile ci run --rm rust-ci cargo fmt --all -- --check + + # -- Linting ----------------------------------------------------------- + - name: Clippy + run: docker compose --profile ci run --rm rust-ci cargo clippy --all-targets -- -D warnings + + # -- Tests ------------------------------------------------------------- + - name: Test + run: docker compose --profile ci run --rm rust-ci cargo test + + # -- Build ------------------------------------------------------------- + - name: Build + run: docker compose --profile ci run --rm rust-ci cargo build --release + + # -- License / Advisory ------------------------------------------------ + - name: cargo-deny + run: docker compose --profile ci run --rm rust-ci cargo deny check + + # -- Cleanup ----------------------------------------------------------- + - name: Fix Docker file ownership + if: always() + run: | + for dir in target outputs; do + if [ -d "$dir" ]; then + docker run --rm -v "$(pwd)/$dir:/workspace" busybox:1.36.1 \ + chown -Rh "$(id -u):$(id -g)" /workspace 2>/dev/null || true + fi + done diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml new file mode 100644 index 0000000..588539f --- /dev/null +++ b/.github/workflows/main-ci.yml @@ -0,0 +1,117 @@ +--- +name: Main CI + +on: + push: + branches: [main] + tags: + - 'v*' + workflow_dispatch: + +concurrency: + group: main-ci-${{ github.sha }} + cancel-in-progress: false + +env: + DOCKER_BUILDKIT: 1 + COMPOSE_DOCKER_CLI_BUILD: 1 + +jobs: + # -- CI Stages (containerized) ------------------------------------------- + ci: + name: Game Mods CI + runs-on: self-hosted + timeout-minutes: 30 + steps: + - name: Pre-checkout cleanup + run: | + for item in outputs target .git/index.lock; do + if [ -d "$item" ] || [ -f "$item" ]; then + docker run --rm -v "$(pwd):/workspace" busybox:1.36.1 sh -c \ + "rm -rf /workspace/$item" 2>/dev/null || \ + sudo rm -rf "$item" 2>/dev/null || true + fi + done + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + clean: true + + - name: Set UID/GID + run: | + echo "USER_ID=$(id -u)" >> $GITHUB_ENV + echo "GROUP_ID=$(id -g)" >> $GITHUB_ENV + + # -- Formatting ------------------------------------------------------- + - name: Format check + run: docker compose --profile ci run --rm rust-ci cargo fmt --all -- --check + + # -- Linting ----------------------------------------------------------- + - name: Clippy + run: docker compose --profile ci run --rm rust-ci cargo clippy --all-targets -- -D warnings + + # -- Tests ------------------------------------------------------------- + - name: Test + run: docker compose --profile ci run --rm rust-ci cargo test + + # -- Build ------------------------------------------------------------- + - name: Build + run: docker compose --profile ci run --rm rust-ci cargo build --release + + # -- License / Advisory ------------------------------------------------ + - name: cargo-deny + run: docker compose --profile ci run --rm rust-ci cargo deny check + + - name: Fix Docker file ownership + if: always() + run: | + for dir in target outputs; do + if [ -d "$dir" ]; then + docker run --rm -v "$(pwd)/$dir:/workspace" busybox:1.36.1 \ + chown -Rh "$(id -u):$(id -g)" /workspace 2>/dev/null || true + fi + done + + # -- CI Summary ---------------------------------------------------------- + notify: + name: CI Summary + needs: [ci] + if: always() + runs-on: self-hosted + steps: + - name: Generate summary + run: | + echo "## Main CI Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "**Commit**: ${{ github.sha }}" >> $GITHUB_STEP_SUMMARY + echo "**Branch**: ${{ github.ref_name }}" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Stage | Status |" >> $GITHUB_STEP_SUMMARY + echo "|-------|--------|" >> $GITHUB_STEP_SUMMARY + echo "| CI | ${{ needs.ci.result }} |" >> $GITHUB_STEP_SUMMARY + + if [[ "${{ needs.ci.result }}" == "failure" ]]; then + echo "" >> $GITHUB_STEP_SUMMARY + echo "CI failed - please review the logs" >> $GITHUB_STEP_SUMMARY + exit 1 + fi + + - name: Create issue on failure + if: failure() && github.ref == 'refs/heads/main' + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + gh issue create \ + --title "Main CI Failed: $(date +%Y-%m-%d)" \ + --body "$(cat <<'ISSUE_EOF' + ## CI Failure Report + + **Commit**: ${{ github.sha }} + **Run**: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + + Please investigate and fix the issues. + ISSUE_EOF + )" \ + --label "ci-failure,automated" || echo "::warning::Could not create failure issue" diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml new file mode 100644 index 0000000..d3a0062 --- /dev/null +++ b/.github/workflows/pr-validation.yml @@ -0,0 +1,466 @@ +--- +name: Pull Request Validation + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + branches: [main] + workflow_dispatch: + +permissions: + contents: write + pull-requests: write + issues: write + +concurrency: + group: pr-${{ github.event.pull_request.number || github.run_id }} + cancel-in-progress: true + +env: + DOCKER_BUILDKIT: 1 + COMPOSE_DOCKER_CLI_BUILD: 1 + +jobs: + # Fork guard: block fork PRs from running on self-hosted runners with write perms + fork-guard: + name: Fork PR Guard + runs-on: ubuntu-latest + if: >- + github.event_name != 'pull_request' || + github.event.pull_request.head.repo.full_name == github.repository + steps: + - run: echo "Not a fork PR - proceeding" + + # -- CI Stages (containerized) ------------------------------------------- + ci: + name: Game Mods CI + needs: fork-guard + runs-on: self-hosted + timeout-minutes: 30 + steps: + - name: Pre-checkout cleanup + run: | + for item in outputs target .git/index.lock; do + if [ -d "$item" ] || [ -f "$item" ]; then + docker run --rm -v "$(pwd):/workspace" busybox:1.36.1 sh -c \ + "rm -rf /workspace/$item" 2>/dev/null || \ + sudo rm -rf "$item" 2>/dev/null || true + fi + done + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + clean: true + token: ${{ secrets.GITHUB_TOKEN }} + ref: ${{ github.head_ref }} + + - name: Set UID/GID + run: | + echo "USER_ID=$(id -u)" >> $GITHUB_ENV + echo "GROUP_ID=$(id -g)" >> $GITHUB_ENV + + # -- Formatting ------------------------------------------------------- + - name: Format check + run: docker compose --profile ci run --rm rust-ci cargo fmt --all -- --check + + # -- Linting ----------------------------------------------------------- + - name: Clippy + run: docker compose --profile ci run --rm rust-ci cargo clippy --all-targets -- -D warnings + + # -- Tests ------------------------------------------------------------- + - name: Test + run: docker compose --profile ci run --rm rust-ci cargo test + + # -- Build ------------------------------------------------------------- + - name: Build + run: docker compose --profile ci run --rm rust-ci cargo build --release + + # -- License / Advisory ------------------------------------------------ + - name: cargo-deny + run: docker compose --profile ci run --rm rust-ci cargo deny check + + - name: Fix Docker file ownership + if: always() + run: | + for dir in target outputs; do + if [ -d "$dir" ]; then + docker run --rm -v "$(pwd)/$dir:/workspace" busybox:1.36.1 \ + chown -Rh "$(id -u):$(id -g)" /workspace 2>/dev/null || true + fi + done + + # -- Gemini AI Code Review ----------------------------------------------- + gemini-review: + name: Gemini AI Code Review + needs: fork-guard + if: >- + github.event_name == 'pull_request' && + !github.event.pull_request.draft + runs-on: self-hosted + timeout-minutes: 30 + outputs: + status: ${{ steps.review.outputs.status }} + steps: + - name: Pre-checkout cleanup + run: | + for item in outputs target .git/index.lock; do + if [ -d "$item" ] || [ -f "$item" ]; then + docker run --rm -v "$(pwd):/workspace" busybox:1.36.1 sh -c \ + "rm -rf /workspace/$item" 2>/dev/null || \ + sudo rm -rf "$item" 2>/dev/null || true + fi + done + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + clean: true + token: ${{ secrets.GITHUB_TOKEN }} + ref: ${{ github.head_ref }} + + - name: Log commit info + id: check-agent + run: | + LAST_AUTHOR=$(git log -1 --format='%an') + echo "Last commit author: $LAST_AUTHOR" + IS_AGENT="false" + if [[ "$LAST_AUTHOR" == "AI Review Agent" ]] || \ + [[ "$LAST_AUTHOR" == "AI Pipeline Agent" ]] || \ + [[ "$LAST_AUTHOR" == "AI Agent Bot" ]]; then + IS_AGENT="true" + fi + echo "is_agent_commit=$IS_AGENT" >> $GITHUB_OUTPUT + + - name: Run Gemini review + id: review + env: + GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + run: | + if ! command -v github-agents &>/dev/null; then + echo "::warning::github-agents not found on PATH - skipping Gemini review" + echo "status=skipped" >> $GITHUB_OUTPUT + exit 0 + fi + + set +e + OUTPUT=$(github-agents pr-review "$PR_NUMBER" --editor 2>&1) + EXIT_CODE=$? + set -e + + echo "$OUTPUT" + echo "$OUTPUT" > gemini-review.md + + if [ $EXIT_CODE -ne 0 ]; then + if echo "$OUTPUT" | grep -qiE '429|rate.?limit|quota|resource.?exhausted'; then + echo "::warning::Gemini API rate limit hit - skipping review" + echo "status=rate_limited" >> $GITHUB_OUTPUT + exit 0 + elif echo "$OUTPUT" | grep -qiE '503|502|service.?unavailable|ECONNREFUSED|ETIMEDOUT'; then + echo "::warning::Gemini API unavailable - skipping review" + echo "status=unavailable" >> $GITHUB_OUTPUT + exit 0 + elif echo "$OUTPUT" | grep -qiE 'panicked at|thread.*panic'; then + echo "::warning::Gemini review tool crashed - skipping review" + echo "status=tool_crash" >> $GITHUB_OUTPUT + exit 0 + else + echo "status=failure" >> $GITHUB_OUTPUT + exit $EXIT_CODE + fi + fi + + echo "status=success" >> $GITHUB_OUTPUT + + - name: Upload review artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: gemini-review-${{ github.run_id }}-${{ github.run_attempt }} + path: gemini-review.md + retention-days: 7 + if-no-files-found: ignore + + # -- Codex AI Code Review (secondary) ------------------------------------ + codex-review: + name: Codex AI Code Review + needs: [fork-guard, gemini-review] + if: >- + github.event_name == 'pull_request' && + !github.event.pull_request.draft && + needs.gemini-review.result != 'skipped' + runs-on: self-hosted + timeout-minutes: 15 + continue-on-error: true + outputs: + status: ${{ steps.review.outputs.status }} + steps: + - name: Pre-checkout cleanup + run: | + for item in outputs target .git/index.lock; do + if [ -d "$item" ] || [ -f "$item" ]; then + docker run --rm -v "$(pwd):/workspace" busybox:1.36.1 sh -c \ + "rm -rf /workspace/$item" 2>/dev/null || \ + sudo rm -rf "$item" 2>/dev/null || true + fi + done + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + clean: true + token: ${{ secrets.GITHUB_TOKEN }} + ref: ${{ github.head_ref }} + + - name: Run Codex review + id: review + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + run: | + if ! command -v github-agents &>/dev/null; then + echo "::warning::github-agents not found on PATH - skipping Codex review" + echo "status=skipped" >> $GITHUB_OUTPUT + exit 0 + fi + + set +e + OUTPUT=$(github-agents pr-review "$PR_NUMBER" --agent codex 2>&1) + EXIT_CODE=$? + set -e + + echo "$OUTPUT" + echo "$OUTPUT" > codex-review.md + + if [ $EXIT_CODE -ne 0 ]; then + if echo "$OUTPUT" | grep -qiE '429|rate.?limit|quota|resource.?exhausted'; then + echo "::warning::Codex API rate limit hit" + echo "status=rate_limited" >> $GITHUB_OUTPUT + exit 0 + elif echo "$OUTPUT" | grep -qiE '503|502|service.?unavailable|ECONNREFUSED|ETIMEDOUT'; then + echo "::warning::Codex API unavailable" + echo "status=unavailable" >> $GITHUB_OUTPUT + exit 0 + elif echo "$OUTPUT" | grep -qiE 'panicked at|thread.*panic'; then + echo "::warning::Codex review tool crashed - skipping review" + echo "status=tool_crash" >> $GITHUB_OUTPUT + exit 0 + else + echo "status=failure" >> $GITHUB_OUTPUT + exit $EXIT_CODE + fi + fi + + echo "status=success" >> $GITHUB_OUTPUT + + - name: Upload review artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: codex-review-${{ github.run_id }}-${{ github.run_attempt }} + path: codex-review.md + retention-days: 7 + if-no-files-found: ignore + + # -- Agent Review Response (responds to Gemini/Codex feedback) ----------- + agent-review-response: + name: Agent Review Response + needs: [ci, gemini-review, codex-review] + if: | + always() && + !cancelled() && + github.event_name == 'pull_request' && + !github.event.pull_request.draft && + !contains(github.event.pull_request.labels.*.name, 'no-auto-fix') + runs-on: self-hosted + timeout-minutes: 30 + outputs: + made_changes: ${{ steps.agent-fix.outputs.made_changes }} + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + GITHUB_TOKEN: ${{ secrets.AGENT_TOKEN }} + GITHUB_REPOSITORY: ${{ github.repository }} + steps: + - name: Pre-checkout cleanup + run: | + for item in outputs target .git/index.lock; do + if [ -d "$item" ] || [ -f "$item" ]; then + docker run --rm -v "$(pwd):/workspace" busybox:1.36.1 sh -c \ + "rm -rf /workspace/$item" 2>/dev/null || \ + sudo rm -rf "$item" 2>/dev/null || true + fi + done + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + clean: false + token: ${{ secrets.AGENT_TOKEN }} + ref: ${{ github.head_ref }} + + - name: Ensure clean working directory + run: | + git checkout -- . 2>/dev/null || true + git clean -fd 2>/dev/null || true + + - name: Check iteration count + id: iteration + uses: ./.github/actions/agent-iteration-check + with: + pr_number: ${{ github.event.pull_request.number }} + max_iterations: '5' + agent_type: 'review-fix' + github_token: ${{ secrets.GITHUB_TOKEN }} + + - name: Skip if max iterations reached + if: steps.iteration.outputs.exceeded_max == 'true' + run: | + echo "Maximum iterations (5) reached for review-fix agent. Manual intervention required." + echo "made_changes=false" >> $GITHUB_OUTPUT + exit 0 + + - name: Download review artifacts + if: steps.iteration.outputs.should_skip != 'true' + uses: actions/download-artifact@v4 + continue-on-error: true + with: + pattern: '*-review-${{ github.run_id }}-${{ github.run_attempt }}' + merge-multiple: true + path: . + + - name: Run agent review response + id: agent-fix + if: steps.iteration.outputs.should_skip != 'true' + env: + GEMINI_REVIEW_PATH: gemini-review.md + CODEX_REVIEW_PATH: codex-review.md + BRANCH_NAME: ${{ github.head_ref }} + ITERATION_COUNT: ${{ steps.iteration.outputs.iteration_count }} + run: | + if ! command -v automation-cli &>/dev/null; then + echo "::warning::automation-cli not found on PATH - skipping review response" + echo "made_changes=false" >> $GITHUB_OUTPUT + exit 0 + fi + + echo "Running agent review response..." + automation-cli review respond \ + "$PR_NUMBER" \ + "$BRANCH_NAME" \ + "$ITERATION_COUNT" \ + "5" + + # -- Agent Failure Handler ----------------------------------------------- + agent-failure-handler: + name: Agent Failure Handler + needs: [ci, gemini-review, codex-review, agent-review-response] + if: | + failure() && + github.event_name == 'pull_request' && + !github.event.pull_request.draft && + !contains(github.event.pull_request.labels.*.name, 'no-auto-fix') && + needs.ci.result == 'failure' + runs-on: self-hosted + timeout-minutes: 30 + outputs: + made_changes: ${{ steps.agent-fix.outputs.made_changes }} + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + GITHUB_TOKEN: ${{ secrets.AGENT_TOKEN }} + GITHUB_REPOSITORY: ${{ github.repository }} + steps: + - name: Pre-checkout cleanup + run: | + for item in outputs target .git/index.lock; do + if [ -d "$item" ] || [ -f "$item" ]; then + docker run --rm -v "$(pwd):/workspace" busybox:1.36.1 sh -c \ + "rm -rf /workspace/$item" 2>/dev/null || \ + sudo rm -rf "$item" 2>/dev/null || true + fi + done + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + clean: false + token: ${{ secrets.AGENT_TOKEN }} + ref: ${{ github.head_ref }} + + - name: Ensure clean working directory + run: | + git checkout -- . 2>/dev/null || true + git clean -fd 2>/dev/null || true + + - name: Check iteration count + id: iteration + uses: ./.github/actions/agent-iteration-check + with: + pr_number: ${{ github.event.pull_request.number }} + max_iterations: '5' + agent_type: 'failure-fix' + github_token: ${{ secrets.GITHUB_TOKEN }} + + - name: Skip if max iterations reached + if: steps.iteration.outputs.exceeded_max == 'true' + run: | + echo "Maximum iterations (5) reached for failure-fix agent. Manual intervention required." + echo "made_changes=false" >> $GITHUB_OUTPUT + exit 0 + + - name: Run agent failure handler + id: agent-fix + if: steps.iteration.outputs.exceeded_max != 'true' + env: + BRANCH_NAME: ${{ github.head_ref }} + ITERATION_COUNT: ${{ steps.iteration.outputs.iteration_count }} + run: | + if ! command -v automation-cli &>/dev/null; then + echo "::warning::automation-cli not found on PATH - skipping failure handler" + echo "made_changes=false" >> $GITHUB_OUTPUT + exit 0 + fi + + echo "Running agent failure handler..." + automation-cli review failure \ + "$PR_NUMBER" \ + "$BRANCH_NAME" \ + "$ITERATION_COUNT" \ + "5" \ + "format,lint,test" + + # -- PR Status Summary --------------------------------------------------- + pr-status: + name: PR Status Summary + needs: [ci, gemini-review, codex-review, agent-review-response, agent-failure-handler] + if: always() + runs-on: self-hosted + steps: + - name: Generate status summary + run: | + echo "## PR Validation Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Check | Status |" >> $GITHUB_STEP_SUMMARY + echo "|-------|--------|" >> $GITHUB_STEP_SUMMARY + echo "| CI | ${{ needs.ci.result }} |" >> $GITHUB_STEP_SUMMARY + echo "| Gemini Review | ${{ needs.gemini-review.result }} |" >> $GITHUB_STEP_SUMMARY + echo "| Codex Review | ${{ needs.codex-review.result }} |" >> $GITHUB_STEP_SUMMARY + echo "| Review Response | ${{ needs.agent-review-response.result }} |" >> $GITHUB_STEP_SUMMARY + echo "| Failure Handler | ${{ needs.agent-failure-handler.result }} |" >> $GITHUB_STEP_SUMMARY + + # CI failure is blocking; reviews are advisory + if [[ "${{ needs.ci.result }}" == "failure" ]]; then + echo "" >> $GITHUB_STEP_SUMMARY + echo "CI failed - please review the logs" >> $GITHUB_STEP_SUMMARY + exit 1 + fi + + echo "" >> $GITHUB_STEP_SUMMARY + echo "PR validation completed successfully" >> $GITHUB_STEP_SUMMARY diff --git a/.gitignore b/.gitignore index 4a173ba..da540e4 100644 --- a/.gitignore +++ b/.gitignore @@ -12,12 +12,6 @@ target/ .DS_Store Thumbs.db -# PSP build artifacts -*.PBP -*.SFO -*.PRX -psp_output_file.log - # Secrets *.key *.pem @@ -46,3 +40,6 @@ node_modules/ # Scratch scratch/ .local/ + +# Build outputs +outputs/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e69de29 diff --git a/.mcp.json b/.mcp.json new file mode 100644 index 0000000..8e1e38e --- /dev/null +++ b/.mcp.json @@ -0,0 +1,167 @@ +{ + "mcpServers": { + "code-quality": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-code-quality", + "mcp-code-quality", + "--mode", + "stdio" + ] + }, + "content-creation": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-content-creation", + "mcp-content-creation", + "--mode", + "stdio" + ] + }, + "gemini": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-gemini", + "mcp-gemini", + "--mode", + "stdio" + ], + "env": { + "GEMINI_TIMEOUT": "300", + "GEMINI_USE_CONTAINER": "false" + } + }, + "opencode": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-opencode", + "mcp-opencode", + "--mode", + "stdio" + ] + }, + "crush": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-crush", + "mcp-crush", + "--mode", + "stdio" + ] + }, + "codex": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-codex", + "mcp-codex", + "--mode", + "stdio" + ] + }, + "github-board": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-github-board", + "mcp-github-board", + "--mode", + "stdio" + ], + "env": { + "GITHUB_TOKEN": "${GITHUB_TOKEN}", + "GITHUB_REPOSITORY": "${GITHUB_REPOSITORY}", + "GITHUB_PROJECT_NUMBER": "1" + } + }, + "agentcore-memory": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--env-file", + "./.env", + "--profile", + "memory", + "run", + "--rm", + "-T", + "mcp-agentcore-memory", + "mcp-agentcore-memory", + "--mode", + "stdio" + ] + }, + "reaction-search": { + "command": "docker", + "args": [ + "compose", + "-f", + "./docker-compose.yml", + "--profile", + "services", + "run", + "--rm", + "-T", + "mcp-reaction-search", + "--mode", + "stdio" + ] + } + } +} diff --git a/.nvmrc b/.nvmrc new file mode 100644 index 0000000..5b54067 --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +22.16.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..81309a7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,53 @@ +# Pre-commit hooks for game-mods +# Run manually: pre-commit run --all-files +# Install hooks: pre-commit install + +repos: + # General file hygiene + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + args: [--unsafe] + - id: check-added-large-files + args: [--maxkb=1000] + - id: check-json + - id: pretty-format-json + args: [--autofix, --no-sort-keys] + - id: check-merge-conflict + - id: check-case-conflict + - id: mixed-line-ending + args: [--fix=lf] + + # GitHub Actions workflow linting + - repo: https://github.com/rhysd/actionlint + rev: v1.7.10 + hooks: + - id: actionlint + + # Shell script linting + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.10.0.1 + hooks: + - id: shellcheck + types: [shell] + args: [-x] + + # Rust formatting and linting (containerized) + - repo: local + hooks: + - id: rust-fmt + name: Rust format check + entry: docker compose --profile ci run --rm rust-ci cargo fmt --all -- --check + language: system + files: '\.rs$' + pass_filenames: false + + - id: rust-clippy + name: Rust clippy lint + entry: docker compose --profile ci run --rm rust-ci cargo clippy --all-targets -- -D warnings + language: system + files: '\.rs$' + pass_filenames: false diff --git a/.rgignore b/.rgignore new file mode 100644 index 0000000..2f7896d --- /dev/null +++ b/.rgignore @@ -0,0 +1 @@ +target/ diff --git a/.secrets.yaml b/.secrets.yaml new file mode 100644 index 0000000..44c04f7 --- /dev/null +++ b/.secrets.yaml @@ -0,0 +1,201 @@ +# Secrets Configuration for Automatic Masking +# This configuration ensures secrets are never exposed in public GitHub comments +# Used by all AI agents and automation tools + +version: "1.0.0" +description: "Central configuration for automatic secret masking in public outputs" + +# Environment variables that contain secrets +# These will be masked if their values appear in any GitHub comment +environment_variables: + # GitHub tokens and authentication + - GITHUB_TOKEN + - GH_TOKEN + - AGENT_TOKEN + + # API keys for various services + - OPENROUTER_API_KEY + - ANTHROPIC_API_KEY + - OPENAI_API_KEY + - GEMINI_API_KEY + - CLAUDE_API_KEY + + # Database passwords + - DB_PASSWORD + - DATABASE_PASSWORD + - POSTGRES_PASSWORD + - MYSQL_PASSWORD + - REDIS_PASSWORD + + # Application secrets + - SECRET_KEY + - JWT_SECRET + - SESSION_SECRET + - WEBHOOK_SECRET + - WEBHOOK_URL + + # Cloud provider credentials + - AWS_ACCESS_KEY_ID + - AWS_SECRET_ACCESS_KEY + - AWS_SESSION_TOKEN + - GOOGLE_API_KEY + - GOOGLE_CLIENT_SECRET + - AZURE_CLIENT_SECRET + + # Container registry credentials + - DOCKER_PASSWORD + - DOCKER_HUB_PASSWORD + + # Package registry tokens + - NPM_TOKEN + - PYPI_TOKEN + + # Communication service tokens + - SLACK_TOKEN + - SLACK_WEBHOOK_URL + - DISCORD_TOKEN + - DISCORD_WEBHOOK_URL + - TELEGRAM_BOT_TOKEN + + # Payment service keys + - STRIPE_SECRET_KEY + - STRIPE_WEBHOOK_SECRET + - PAYPAL_CLIENT_SECRET + + # Email service keys + - SENDGRID_API_KEY + - MAILGUN_API_KEY + - TWILIO_AUTH_TOKEN + + # Cryptographic keys + - SSH_PRIVATE_KEY + - GPG_PRIVATE_KEY + - ENCRYPTION_KEY + + # Generic secrets + - API_SECRET + - CLIENT_SECRET + - PRIVATE_KEY + - ACCESS_TOKEN + - REFRESH_TOKEN + - BEARER_TOKEN + - AUTH_TOKEN + + # Service URLs (may contain embedded credentials) + - COMFYUI_SERVER_URL + - AI_TOOLKIT_SERVER_URL + - GAEA2_REMOTE_URL + +# Patterns for detecting secrets by format +# These regex patterns identify common secret formats +patterns: + - name: GITHUB_TOKEN + pattern: "ghp_[A-Za-z0-9_]{36,}" + description: "GitHub personal access token" + + - name: GITHUB_SECRET + pattern: "ghs_[A-Za-z0-9_]{36,}" + description: "GitHub secret" + + - name: GITHUB_PAT + pattern: "github_pat_[A-Za-z0-9_]{22,}" + description: "GitHub personal access token (new format)" + + - name: GITHUB_OAUTH + pattern: "gho_[A-Za-z0-9_]{36,}" + description: "GitHub OAuth token" + + - name: API_KEY + pattern: "sk-[A-Za-z0-9\\-]{32,}" + description: "Generic API key (OpenAI style)" + + - name: PRIVATE_KEY + pattern: "pk-[A-Za-z0-9\\-]{32,}" + description: "Generic private key" + + - name: BEARER_TOKEN + pattern: "Bearer\\s+[A-Za-z0-9\\-_=]{20,}" + description: "Bearer authentication token" + + - name: URL_WITH_AUTH + pattern: "https?://[^:\\s]+:[^@\\s]+@[^\\s]+" + description: "URL with embedded credentials" + + - name: AWS_ACCESS_KEY + pattern: "AKIA[0-9A-Z]{16}" + description: "AWS access key ID" + + - name: JWT_TOKEN + pattern: "eyJ[A-Za-z0-9\\-_=]+\\.[A-Za-z0-9\\-_=]+\\.[A-Za-z0-9\\-_=]+" + description: "JSON Web Token" + + - name: PRIVATE_KEY_BLOCK + pattern: "-----BEGIN[A-Z\\s]+PRIVATE KEY-----[\\s\\S]+?-----END[A-Z\\s]+PRIVATE KEY-----" + description: "Private key block (RSA, EC, etc.)" + + - name: SSH_KEY + pattern: "ssh-rsa\\s+[A-Za-z0-9+/]{100,}" + description: "SSH public key" + + - name: SLACK_TOKEN + pattern: "xox[baprs]-[0-9]{10,}-[A-Za-z0-9]{24,}" + description: "Slack token" + + - name: STRIPE_KEY + pattern: "(sk|pk)_(test|live)_[A-Za-z0-9]{24,}" + description: "Stripe API key" + + - name: NPM_TOKEN + pattern: "npm_[A-Za-z0-9]{36}" + description: "NPM access token" + +# Auto-detection rules for environment variables +auto_detection: + enabled: true + + # Patterns that indicate a variable contains secrets + # Uses glob-style patterns (* = any characters) + include_patterns: + - "*_TOKEN" + - "*_SECRET" + - "*_KEY" + - "*_PASSWORD" + - "*_PASS" + - "*_PWD" + - "*_CREDENTIAL*" + - "*_AUTH" + - "TOKEN_*" + - "SECRET_*" + - "KEY_*" + - "PASSWORD_*" + - "API_*" + - "PRIVATE_*" + + # Patterns to exclude from auto-detection + # These are typically safe to expose + exclude_patterns: + - "PUBLIC_KEY" + - "PUBLIC_*" + - "*_PUBLIC" + - "*_PUBLIC_*" + - "ENABLE_*" + - "*_ENABLED" + - "*_DISABLED" + +# Settings for the masking behavior +settings: + # Minimum length for a value to be considered a secret + minimum_secret_length: 4 + + # Whether pattern matching is case-sensitive + case_sensitive_patterns: false + + # Whether to mask partial matches + mask_partial_matches: false + + # Whether to log when secrets are masked (to stderr) + log_masked_secrets: true + + # Format for masked values + # {name} will be replaced with the variable/pattern name + mask_format: "[MASKED_{name}]" diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..6c3791e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,111 @@ +# AGENTS.md + +This file provides guidance to AI coding agents working with code in this repository. + +## Project Overview + +Rust monorepo for game modification projects built on the **Injection Toolkit (ITK)** framework. The architecture follows "minimal injection, maximal external processing" — injected code in the target game process is kept minimal, with heavy processing in external daemon and overlay processes connected via IPC and shared memory. + +All code is authored by AI agents under human direction. No external contributions are accepted (see `CONTRIBUTING.md`). + +## Build and Test Commands + +This is a Cargo workspace. All CI runs containerized via Docker but the commands work locally: + +```bash +# Format check +cargo fmt --all -- --check + +# Lint (warnings are errors in CI) +cargo clippy --all-targets -- -D warnings + +# Run all tests +cargo test + +# Run a single crate's tests +cargo test -p itk-protocol + +# Run a specific test by name +cargo test -p itk-shmem -- seqlock_concurrent + +# Build release +cargo build --release + +# License and advisory audit +cargo deny check +``` + +CI pipeline order: **fmt → clippy → test → build → cargo-deny**. + +To match CI exactly (containerized): +```bash +docker compose --profile ci run --rm rust-ci cargo test +``` + +## Code Style + +- Rust Edition 2024. Formatting enforced by `rustfmt.toml`: 100-char max line width, 4-space indentation, Unix newlines, `Tall` fn params layout. +- Run `cargo fmt --all` before committing. CI rejects unformatted code. +- Clippy warnings treated as errors in CI: `cargo clippy --all-targets -- -D warnings`. +- Workspace-level lints in root `Cargo.toml`: `clippy::dbg_macro`, `clippy::todo`, `clippy::unimplemented`, and `unsafe_op_in_unsafe_fn` are all warnings. + +## Workspace Structure + +13 crates organized into four layers: + +**Core libraries** (`core/`): +- `itk-protocol` — Wire protocol (20-byte header + bincode payload, CRC32 validated, 1 MB max) +- `itk-shmem` — Cross-platform shared memory with seqlock (single-writer, multi-reader) +- `itk-ipc` — Named pipes (Windows) / Unix sockets (Linux) +- `itk-sync` — Clock synchronization and drift correction +- `itk-video` — Video decoding via ffmpeg-next +- `itk-net` — P2P networking via laminar + +**Framework templates**: `daemon/`, `overlay/`, `injectors/windows/native-dll/`, `injectors/linux/ld-preload/` + +**Active project**: `projects/nms-cockpit-video/` — No Man's Sky cockpit video player (daemon, injector, overlay, launcher) + +**Tools**: `tools/mem-scanner/` — Memory pattern scanning for reverse engineering + +## Architecture + +``` +Launcher (orchestration) + ├── Daemon (external) — video decode, audio, IPC server, shared memory writer + ├── Injector (DLL/SO) — Vulkan hooks, minimal state extraction, IPC client + └── Overlay (optional) — egui + wgpu transparent window, shared memory reader +``` + +Components run in separate processes. A crash in any one component does not bring down the others. + +## Security Considerations + +- All data from injectors is **untrusted**. Validate: NaN/Inf, numeric bounds, string lengths (256-byte cap), data sizes (64 KB cap). +- Seqlock shared memory is **single-writer only**. Multiple concurrent writers corrupt the sequence counter. +- Unsafe code requires detailed `// SAFETY:` comments explaining the invariant. +- Named pipes use process-token ACLs (Windows); Unix sockets use `0600` permissions (Linux). + +## Conventions + +- Dependencies are managed at workspace level in the root `Cargo.toml`. Add new dependencies there. +- Error handling: `thiserror` for library error types, `anyhow` for application-level errors. +- Logging: `tracing` crate with `tracing-subscriber` env-filter. Not `log` or `println!`. +- Platform-specific code uses `cfg_if!` blocks in `itk-shmem` and `itk-ipc`. +- Protocol changes: add variant to `MessageType` enum in `itk-protocol`, define payload struct with serde derives, update daemon handlers. +- New injector platforms go in `injectors/` and implement IPC client + platform-specific init. + +## CI/CD Pipeline + +Three GitHub Actions workflows on self-hosted runners: + +- **`ci.yml`** — Runs on push to main. Format, lint, test, build, cargo-deny. +- **`main-ci.yml`** — Runs on main push and tags. Same CI stages plus auto-creates GitHub issues on failure. +- **`pr-validation.yml`** — Runs on PRs. CI stages + Gemini AI review (primary) + Codex AI review (secondary) + automated agent fix iterations (max 5, extendable with `[CONTINUE]` comment). Add `no-auto-fix` label to disable automated fixes. + +Agent commit authors: `AI Review Agent`, `AI Pipeline Agent`, `AI Agent Bot`. + +## Known Advisory Exemptions + +Two advisories ignored in `deny.toml`: +- `RUSTSEC-2025-0141` — bincode unmaintained (migration planned) +- `RUSTSEC-2026-0007` — bytes integer overflow (update pending) diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..758a820 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,85 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Rust monorepo for game modification projects built on the **Injection Toolkit (ITK)** framework. Follows a "minimal injection, maximal external processing" architecture where injected code in the target game process is kept minimal, with heavy processing delegated to external daemon and overlay processes communicating via IPC and shared memory. + +All code is authored by AI agents under human direction. No external contributions are accepted. + +## Build & CI Commands + +All CI runs inside Docker containers via `docker compose --profile ci`. The local equivalents: + +```bash +# Format check +cargo fmt --all -- --check + +# Lint (warnings are errors in CI) +cargo clippy --all-targets -- -D warnings + +# Test +cargo test + +# Build release +cargo build --release + +# License/advisory check +cargo deny check + +# Run a single test +cargo test -p itk-protocol -- test_name + +# Containerized (matches CI exactly) +docker compose --profile ci run --rm rust-ci cargo test +``` + +CI pipeline order: fmt check → clippy → test → build → cargo-deny. + +## Workspace Structure + +**Core libraries** (`core/`): Shared ITK crates used by all projects. +- `itk-protocol` — Wire protocol (20-byte header + bincode payload, CRC32 validated) +- `itk-shmem` — Cross-platform shared memory with seqlock (single-writer, multi-reader) +- `itk-ipc` — Named pipes (Windows) / Unix sockets (Linux) +- `itk-sync` — Clock synchronization and drift correction +- `itk-video` — Video decoding via ffmpeg-next +- `itk-net` — P2P networking via laminar + +**Framework templates**: `daemon/`, `overlay/`, `injectors/windows/native-dll/`, `injectors/linux/ld-preload/` + +**Active project**: `projects/nms-cockpit-video/` (daemon, injector, overlay, launcher) + +**Tools**: `tools/mem-scanner/` — Memory pattern scanning for reverse engineering + +## Architecture + +``` +Launcher (orchestration) + ├── Daemon (external) — video decode, audio, IPC server, shared memory writer + ├── Injector (DLL/SO) — Vulkan hooks, minimal state extraction, IPC client + └── Overlay (optional) — egui + wgpu transparent window, shared memory reader +``` + +Key design constraints: +- Injector must stay under 5 MB memory, no blocking operations, no complex processing +- Seqlock shared memory is **single-writer only** — multiple writers corrupt data +- All data from injectors is treated as **untrusted** — validate NaN/Inf, bounds, string lengths (256 byte cap), data size (64 KB cap) +- Components run in separate processes; crashes are isolated (overlay crash doesn't crash target) + +## Code Conventions + +- Rust Edition 2024, max line width 100 chars, 4-space indentation (see `rustfmt.toml`) +- Workspace-level dependency versions in root `Cargo.toml` +- Workspace lints: `clippy::dbg_macro`, `clippy::todo`, `clippy::unimplemented` are warnings; `unsafe_op_in_unsafe_fn` is a warning +- Error handling: `thiserror` for structured errors, `anyhow` for application-level +- Logging: `tracing` + `tracing-subscriber` with env-filter +- Platform abstraction via `cfg_if!` blocks in shmem and IPC crates +- Unsafe code requires detailed safety comments + +## Known Advisory Exemptions + +See `deny.toml` — two advisories are currently ignored: +- `RUSTSEC-2025-0141` (bincode unmaintained) — migration planned +- `RUSTSEC-2026-0007` (bytes integer overflow) — update pending diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..3d37c88 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,19 @@ +# Contributing + +This repository does not accept external contributions. All code changes are authored by AI agents (Claude, Gemini, Codex, OpenCode, Crush) operating under human direction. + +## No External Contributions + +This is a single-maintainer project. The CI/CD pipeline, agent review system, and security model are designed around autonomous agent authorship with human oversight. Accepting external PRs would break assumptions the tooling is built on. + +## No Feature Requests or Support + +Feature requests, guidance, consulting, and support are not provided. This policy is not negotiable. + +## What You Can Do + +- **Fork it.** Clone the repo and adapt it however you want. This is the recommended path if you need features that don't exist here. +- **Study it.** The codebase demonstrates Rust injection toolkit patterns, Vulkan hooking, cross-platform IPC, and containerized CI/CD with agent-driven development. +- **Use it.** You are free to use any component under the terms of the [MIT License](LICENSE-MIT). + +You do so entirely at your own risk and without any expectation of support, maintenance, or acknowledgment from the maintainer. diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ff13e4d --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,106 @@ +[workspace] +resolver = "2" +members = [ + # Core libraries (injection toolkit foundation) + "core/itk-shmem", + "core/itk-ipc", + "core/itk-protocol", + "core/itk-sync", + "core/itk-video", + "core/itk-net", + # Framework templates + "daemon", + "overlay", + "injectors/linux/ld-preload", + "injectors/windows/native-dll", + # NMS Cockpit Video Player + "projects/nms-cockpit-video/daemon", + "projects/nms-cockpit-video/overlay", + "projects/nms-cockpit-video/injector", + "projects/nms-cockpit-video/launcher", + # Development tools + "tools/mem-scanner", +] + +[workspace.package] +version = "0.1.0" +edition = "2024" +license = "MIT" +repository = "https://github.com/AndrewAltimit/game-mods" +authors = ["AndrewAltimit"] + +[workspace.dependencies] +# Serialization +serde = { version = "1.0", features = ["derive"] } +bincode = "1.3" + +# Async runtime +tokio = { version = "1.0", features = ["full"] } + +# Error handling +thiserror = "1.0" +anyhow = "1.0" + +# Logging +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +# Cross-platform utilities +cfg-if = "1.0" + +# CRC for protocol validation +crc32fast = "1.4" + +# Video decoding +ffmpeg-next = "7.0" + +# Audio output +cpal = "0.15" +ringbuf = "0.4" + +# Networking +laminar = "0.5" +parking_lot = "0.12" + +# Vulkan +ash = { version = "0.38", features = ["loaded"] } + +# Hooking +retour = { version = "0.3", features = ["static-detour"] } + +# Platform-specific (Windows) +[workspace.dependencies.windows] +version = "0.58" +features = [ + "Win32_Foundation", + "Win32_System_Memory", + "Win32_System_Pipes", + "Win32_System_IO", + "Win32_System_SystemServices", + "Win32_System_LibraryLoader", + "Win32_System_Threading", + "Win32_System_Diagnostics_Debug", + "Win32_System_Diagnostics_ToolHelp", + "Win32_Security", + "Win32_Security_Authorization", + "Win32_Storage_FileSystem", + "Win32_System_DataExchange", + "Win32_UI_Input_KeyboardAndMouse", +] + +# Platform-specific (Unix) +[workspace.dependencies.libc] +version = "0.2" + +[workspace.dependencies.nix] +version = "0.29" +features = ["mman", "fs"] + +[workspace.lints.clippy] +clone_on_ref_ptr = "warn" +dbg_macro = "warn" +todo = "warn" +unimplemented = "warn" + +[workspace.lints.rust] +unsafe_op_in_unsafe_fn = "warn" diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..b8c15c8 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,19 @@ +Copyright 2025 Andrew Showers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index f887ec5..e545481 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,112 @@ -# game-mods -Game mods +# Game Mods + +**Monorepo for game modification projects built on the Injection Toolkit framework.** + +Game-specific mods that inject custom rendering, overlays, and automation into game processes. Built in Rust with cross-platform IPC, shared memory, and Vulkan/OpenVR hooking. + +All code is authored by AI agents under human direction. + +## Projects + +| Mod | Game | Description | +|-----|------|-------------| +| [NMS Cockpit Video](projects/nms-cockpit-video/) | No Man's Sky | In-cockpit video player via Vulkan injection + desktop overlay | + +## Architecture + +The injection toolkit provides a **minimal injection, maximal external processing** framework: + +``` +Launcher (orchestration) + | + +-- Daemon (external process) <-- video decode, audio, IPC server + | | + | +-- Shared Memory <-- lock-free frame transport (seqlock) + | +-- IPC (named pipes) <-- commands, projection data + | + +-- Injector (DLL in target) <-- Vulkan hooks, texture rendering + | + +-- Overlay (optional desktop) <-- egui + wgpu transparent window +``` + +### Core Libraries + +| Crate | Purpose | +|-------|---------| +| `itk-protocol` | Wire protocol definitions (serde + bincode) | +| `itk-shmem` | Cross-platform shared memory (Windows/Linux) | +| `itk-ipc` | Cross-platform IPC channels (named pipes/Unix sockets) | +| `itk-sync` | Clock synchronization and drift correction | +| `itk-video` | Video decoding via ffmpeg + frame management | +| `itk-net` | P2P networking for multiplayer sync (laminar) | + +### Framework Templates + +| Crate | Purpose | +|-------|---------| +| `itk-daemon` | Central coordinator daemon template | +| `itk-overlay` | wgpu-based transparent overlay window template | +| `itk-native-dll` | Windows DLL injection template | +| `itk-ld-preload` | Linux LD_PRELOAD injection template | + +### Tools + +| Tool | Purpose | +|------|---------| +| `mem-scanner` | Memory pattern scanning utility for reverse engineering | + +## Project Structure + +``` +game-mods/ ++-- core/ # Shared libraries +| +-- itk-protocol/ +| +-- itk-shmem/ +| +-- itk-ipc/ +| +-- itk-sync/ +| +-- itk-video/ +| +-- itk-net/ ++-- daemon/ # Framework: coordinator daemon ++-- overlay/ # Framework: transparent overlay ++-- injectors/ # Framework: injection templates +| +-- windows/native-dll/ +| +-- linux/ld-preload/ ++-- projects/ # Game-specific mods +| +-- nms-cockpit-video/ +| +-- daemon/ # NMS video playback daemon +| +-- injector/ # Vulkan DLL injection (cdylib) +| +-- overlay/ # Desktop overlay (egui + wgpu) +| +-- launcher/ # Process orchestrator +| +-- mod/ # Reloaded-II C# mod (optional) +| +-- docs/ # Reverse engineering notes ++-- tools/ +| +-- mem-scanner/ # Memory scanning utility ++-- docker/ # CI Dockerfiles ++-- .github/ # GitHub Actions workflows +``` + +## Development + +```bash +# Containerized CI (matches GitHub Actions) +docker compose --profile ci run --rm rust-ci cargo fmt --all -- --check +docker compose --profile ci run --rm rust-ci cargo clippy --all-targets -- -D warnings +docker compose --profile ci run --rm rust-ci cargo test +docker compose --profile ci run --rm rust-ci cargo build --release +docker compose --profile ci run --rm rust-ci cargo deny check +``` + +## Technology Stack + +| Layer | Technology | +|-------|-----------| +| Language | Rust (Edition 2024) | +| Video | ffmpeg-next 7.0, cpal 0.15 (audio) | +| Graphics | ash 0.38 (Vulkan), wgpu 0.20, egui 0.28 | +| Hooking | retour 0.3 (function detours) | +| Platform | windows 0.58 (Win32), nix 0.29 (Unix) | +| CI/CD | GitHub Actions (self-hosted runner, Docker containers) | + +## License + +Dual-licensed under [Unlicense](LICENSE) and [MIT](LICENSE-MIT). diff --git a/core/itk-ipc/Cargo.toml b/core/itk-ipc/Cargo.toml new file mode 100644 index 0000000..6dc635d --- /dev/null +++ b/core/itk-ipc/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "itk-ipc" +description = "Cross-platform IPC channels for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[dependencies] +itk-protocol = { path = "../itk-protocol" } +thiserror = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +cfg-if = { workspace = true } + +[target.'cfg(windows)'.dependencies] +windows = { workspace = true } + +[target.'cfg(unix)'.dependencies] +libc = { workspace = true } +nix = { workspace = true } + +[dev-dependencies] +rand = "0.8" + +[lints] +workspace = true diff --git a/core/itk-ipc/src/lib.rs b/core/itk-ipc/src/lib.rs new file mode 100644 index 0000000..478ab75 --- /dev/null +++ b/core/itk-ipc/src/lib.rs @@ -0,0 +1,322 @@ +//! # ITK IPC +//! +//! Cross-platform IPC channels for the Injection Toolkit. +//! +//! This crate provides: +//! - Platform-agnostic IPC channel abstraction +//! - Async support via tokio +//! - Message framing with the ITK protocol +//! +//! ## Platform Support +//! +//! - **Windows**: Named pipes via `\\.\pipe\itk_*` +//! - **Linux**: Unix domain sockets via `/tmp/itk_*.sock` + +use itk_protocol::{HEADER_SIZE, Header}; +use std::io; +use thiserror::Error; + +/// IPC errors +#[derive(Error, Debug)] +pub enum IpcError { + #[error("connection failed: {0}")] + ConnectionFailed(String), + + #[error("channel closed")] + ChannelClosed, + + #[error("timeout waiting for connection")] + Timeout, + + #[error("invalid channel name: {0}")] + InvalidName(String), + + #[error("protocol error: {0}")] + Protocol(#[from] itk_protocol::ProtocolError), + + #[error("IO error: {0}")] + Io(#[from] io::Error), + + #[error("already listening")] + AlreadyListening, + + #[error("not connected")] + NotConnected, + + #[error("platform error: {0}")] + Platform(String), +} + +/// Result type for IPC operations +pub type Result = std::result::Result; + +/// IPC channel trait for platform-agnostic messaging +pub trait IpcChannel: Send + Sync { + /// Send raw bytes over the channel + fn send(&self, data: &[u8]) -> Result<()>; + + /// Receive raw bytes from the channel + /// + /// Blocks until data is available or the channel is closed. + fn recv(&self) -> Result>; + + /// Try to receive without blocking + /// + /// Returns None if no data is available. + fn try_recv(&self) -> Result>>; + + /// Check if the channel is connected + fn is_connected(&self) -> bool; + + /// Close the channel + fn close(&self); +} + +// Note: Async IPC support can be added in the future via a feature flag + +/// IPC server that accepts connections +pub trait IpcServer: Send + Sync { + /// The channel type returned when accepting connections + type Channel: IpcChannel; + + /// Accept a new connection + /// + /// Blocks until a client connects. + fn accept(&self) -> Result; + + /// Stop listening and close the server + fn close(&self); +} + +/// Create a platform-appropriate channel name +pub fn make_channel_name(base_name: &str) -> String { + cfg_if::cfg_if! { + if #[cfg(windows)] { + format!(r"\\.\pipe\itk_{}", base_name) + } else { + format!("/tmp/itk_{}.sock", base_name) + } + } +} + +// Platform-specific implementations +cfg_if::cfg_if! { + if #[cfg(windows)] { + mod windows_impl; + pub use windows_impl::{NamedPipeClient, NamedPipeServer}; + + /// Create a client channel connected to the given name + pub fn connect(name: &str) -> Result { + windows_impl::NamedPipeClient::connect(name) + } + + /// Create a server listening on the given name + pub fn listen(name: &str) -> Result { + windows_impl::NamedPipeServer::new(name) + } + } else if #[cfg(unix)] { + mod unix_impl; + pub use unix_impl::{UnixSocketClient, UnixSocketServer, UnixSocketConnection}; + + /// Create a client channel connected to the given name + pub fn connect(name: &str) -> Result { + unix_impl::UnixSocketClient::connect(name) + } + + /// Create a server listening on the given name + pub fn listen(name: &str) -> Result { + unix_impl::UnixSocketServer::new(name) + } + } +} + +/// Helper for reading length-prefixed messages +pub fn read_message(reader: &mut impl io::Read) -> Result> { + // Read header first + let mut header_buf = [0u8; HEADER_SIZE]; + reader.read_exact(&mut header_buf)?; + + let header = Header::from_bytes(&header_buf)?; + + // Read payload + let mut payload = vec![0u8; header.payload_len as usize]; + reader.read_exact(&mut payload)?; + + // Return full message (header + payload) + let mut message = Vec::with_capacity(HEADER_SIZE + payload.len()); + message.extend_from_slice(&header_buf); + message.extend_from_slice(&payload); + + Ok(message) +} + +/// Helper for writing length-prefixed messages +pub fn write_message(writer: &mut impl io::Write, data: &[u8]) -> Result<()> { + writer.write_all(data)?; + writer.flush()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_channel_name_format() { + let name = make_channel_name("test"); + + #[cfg(windows)] + assert!(name.starts_with(r"\\.\pipe\itk_")); + + #[cfg(unix)] + assert!(name.starts_with("/tmp/itk_") && name.ends_with(".sock")); + } +} + +/// Integration tests for IPC communication +/// These tests use actual OS sockets/pipes with proper protocol framing +#[cfg(test)] +#[allow(unused_imports, dead_code)] +mod integration_tests { + use super::*; + use itk_protocol::{MessageType, decode, encode}; + use rand::Rng; + use std::thread; + use std::time::{Duration, Instant}; + + /// Generate a unique channel name for testing + fn test_channel_name() -> String { + let id: u32 = rand::thread_rng().r#gen(); + format!("itk_test_{}", id) + } + + /// Connect with retry loop instead of sleep (more robust) + #[cfg(unix)] + fn connect_with_retry(name: &str, timeout: Duration) -> UnixSocketClient { + let start = Instant::now(); + while start.elapsed() < timeout { + if let Ok(client) = connect(name) { + return client; + } + thread::sleep(Duration::from_millis(10)); + } + panic!("Failed to connect to IPC server within {:?}", timeout); + } + + #[test] + #[cfg(unix)] + fn test_unix_socket_ping_pong() { + let channel_name = test_channel_name(); + let channel_name_client = channel_name.clone(); + + // Start server in background thread + let server_handle = thread::spawn(move || { + let server = listen(&channel_name).expect("Failed to create server"); + let conn = server.accept().expect("Failed to accept connection"); + + // Receive message (should be a Ping) + let msg = conn.recv().expect("Failed to receive"); + let (msg_type, _): (MessageType, ()) = decode(&msg).expect("Failed to decode"); + assert_eq!(msg_type, MessageType::Ping); + + // Send Pong response + let pong = encode(MessageType::Pong, &()).expect("Failed to encode"); + conn.send(&pong).expect("Failed to send"); + }); + + // Connect with retry loop (robust against slow server startup) + let client = connect_with_retry(&channel_name_client, Duration::from_secs(2)); + + // Send ping with proper protocol framing + let ping = encode(MessageType::Ping, &()).expect("Failed to encode"); + client.send(&ping).expect("Failed to send"); + + // Receive pong + let response = client.recv().expect("Failed to receive"); + let (msg_type, _): (MessageType, ()) = decode(&response).expect("Failed to decode"); + assert_eq!(msg_type, MessageType::Pong); + + server_handle.join().expect("Server thread panicked"); + } + + #[test] + #[cfg(unix)] + fn test_unix_socket_screen_rect() { + use itk_protocol::ScreenRect; + + let channel_name = test_channel_name(); + let channel_name_client = channel_name.clone(); + + let server_handle = thread::spawn(move || { + let server = listen(&channel_name).expect("Failed to create server"); + let conn = server.accept().expect("Failed to accept"); + + let msg = conn.recv().expect("Failed to receive"); + let (msg_type, rect): (MessageType, ScreenRect) = + decode(&msg).expect("Failed to decode"); + + assert_eq!(msg_type, MessageType::ScreenRect); + assert_eq!(rect.x, 100.0); + assert_eq!(rect.y, 200.0); + assert_eq!(rect.width, 640.0); + assert_eq!(rect.height, 480.0); + + // Acknowledge + let pong = encode(MessageType::Pong, &()).expect("Failed to encode"); + conn.send(&pong).expect("Failed to send"); + }); + + let client = connect_with_retry(&channel_name_client, Duration::from_secs(2)); + + let rect = ScreenRect { + x: 100.0, + y: 200.0, + width: 640.0, + height: 480.0, + rotation: 0.0, + visible: true, + }; + let msg = encode(MessageType::ScreenRect, &rect).expect("Failed to encode"); + client.send(&msg).expect("Failed to send"); + + let response = client.recv().expect("Failed to receive"); + let (msg_type, _): (MessageType, ()) = decode(&response).expect("Failed to decode"); + assert_eq!(msg_type, MessageType::Pong); + + server_handle.join().expect("Server thread panicked"); + } + + #[test] + #[cfg(unix)] + fn test_unix_socket_multiple_pings() { + let channel_name = test_channel_name(); + let channel_name_client = channel_name.clone(); + + let server_handle = thread::spawn(move || { + let server = listen(&channel_name).expect("Failed to create server"); + let conn = server.accept().expect("Failed to accept"); + + for _ in 0..5 { + let msg = conn.recv().expect("Failed to receive"); + let (msg_type, _): (MessageType, ()) = decode(&msg).expect("Failed to decode"); + assert_eq!(msg_type, MessageType::Ping); + } + + let pong = encode(MessageType::Pong, &()).expect("Failed to encode"); + conn.send(&pong).expect("Failed to send"); + }); + + let client = connect_with_retry(&channel_name_client, Duration::from_secs(2)); + + for _ in 0..5 { + let ping = encode(MessageType::Ping, &()).expect("Failed to encode"); + client.send(&ping).expect("Failed to send"); + } + + let response = client.recv().expect("Failed to receive"); + let (msg_type, _): (MessageType, ()) = decode(&response).expect("Failed to decode"); + assert_eq!(msg_type, MessageType::Pong); + + server_handle.join().expect("Server thread panicked"); + } +} diff --git a/core/itk-ipc/src/unix_impl.rs b/core/itk-ipc/src/unix_impl.rs new file mode 100644 index 0000000..4382b7b --- /dev/null +++ b/core/itk-ipc/src/unix_impl.rs @@ -0,0 +1,364 @@ +//! Unix domain socket implementation + +use super::{IpcChannel, IpcError, IpcServer, Result, read_message}; +use std::fs; +use std::io::Write; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::sync::Mutex; +use std::sync::atomic::{AtomicBool, Ordering}; + +/// Validate and create a socket path from a name. +/// +/// For security, this function: +/// - Rejects names with path traversal sequences (..) +/// - Rejects absolute paths (prevents arbitrary file operations) +/// - Creates paths only in /tmp/itk_*.sock +fn make_socket_path(name: &str) -> Result { + // Reject path traversal attempts + if name.contains("..") { + return Err(IpcError::Platform( + "socket name cannot contain path traversal sequences".into(), + )); + } + + // Reject absolute paths - prevents arbitrary file deletion/creation + if name.starts_with('/') { + return Err(IpcError::Platform( + "socket name cannot be an absolute path".into(), + )); + } + + // Reject names with slashes to prevent subdirectory traversal + if name.contains('/') || name.contains('\\') { + return Err(IpcError::Platform( + "socket name cannot contain path separators".into(), + )); + } + + Ok(format!("/tmp/itk_{}.sock", name)) +} + +/// Safely remove a socket file, verifying it's actually a socket first. +/// +/// Uses symlink_metadata to avoid following symlinks, preventing any +/// interaction with symlink targets. +fn remove_socket_file(path: &str) { + // Use symlink_metadata to not follow symlinks - strictly safer + if let Ok(metadata) = fs::symlink_metadata(path) { + use std::os::unix::fs::FileTypeExt; + if metadata.file_type().is_socket() { + let _ = fs::remove_file(path); + } + // If it's not a socket, don't remove it - could be a regular file + } + // If metadata fails (file doesn't exist), that's fine +} + +/// Non-blocking receive helper that keeps the lock held while consuming data. +/// +/// This prevents race conditions where another thread could consume data between +/// peeking and actually reading. +fn try_recv_with_fd(fd: std::os::unix::io::RawFd) -> Result>> { + use itk_protocol::HEADER_SIZE; + + // Use MSG_PEEK to check if enough data is available without consuming bytes. + let mut peek_buf = [0u8; HEADER_SIZE]; + let peeked = unsafe { + libc::recv( + fd, + peek_buf.as_mut_ptr() as *mut libc::c_void, + HEADER_SIZE, + libc::MSG_PEEK | libc::MSG_DONTWAIT, + ) + }; + + if peeked < 0 { + let err = std::io::Error::last_os_error(); + if err.kind() == std::io::ErrorKind::WouldBlock + || err.kind() == std::io::ErrorKind::Interrupted + { + return Ok(None); + } + return Err(IpcError::Io(err)); + } + + // recv returning 0 means EOF (connection closed) + if peeked == 0 { + return Err(IpcError::ChannelClosed); + } + + if (peeked as usize) < HEADER_SIZE { + // Partial header available - not enough data yet + return Ok(None); + } + + // Parse header to determine total message size + let header = itk_protocol::Header::from_bytes(&peek_buf).map_err(IpcError::Protocol)?; + let total_size = HEADER_SIZE + header.payload_len as usize; + + // Peek again to check if full message is available + let mut message = vec![0u8; total_size]; + let peeked_full = unsafe { + libc::recv( + fd, + message.as_mut_ptr() as *mut libc::c_void, + total_size, + libc::MSG_PEEK | libc::MSG_DONTWAIT, + ) + }; + + if peeked_full < 0 { + let err = std::io::Error::last_os_error(); + if err.kind() == std::io::ErrorKind::WouldBlock + || err.kind() == std::io::ErrorKind::Interrupted + { + return Ok(None); + } + return Err(IpcError::Io(err)); + } + + if (peeked_full as usize) < total_size { + return Ok(None); + } + + // Full message is available - consume it (we still hold the lock in the caller) + let received = unsafe { + libc::recv( + fd, + message.as_mut_ptr() as *mut libc::c_void, + total_size, + 0, // Blocking read, but we know data is available + ) + }; + + if received < 0 { + let err = std::io::Error::last_os_error(); + // EINTR on final recv is unusual but handle it by reporting no data available + if err.kind() == std::io::ErrorKind::Interrupted { + return Ok(None); + } + return Err(IpcError::Io(err)); + } + + if (received as usize) != total_size { + return Err(IpcError::Protocol( + itk_protocol::ProtocolError::IncompletePayload { + need: total_size - itk_protocol::HEADER_SIZE, + have: (received as usize).saturating_sub(itk_protocol::HEADER_SIZE), + }, + )); + } + + Ok(Some(message)) +} + +/// Unix domain socket client +pub struct UnixSocketClient { + stream: Mutex, + connected: AtomicBool, + #[allow(dead_code)] + path: String, +} + +impl UnixSocketClient { + /// Connect to a Unix socket server + pub fn connect(name: &str) -> Result { + let path = make_socket_path(name)?; + + let stream = + UnixStream::connect(&path).map_err(|e| IpcError::ConnectionFailed(e.to_string()))?; + + Ok(Self { + stream: Mutex::new(stream), + connected: AtomicBool::new(true), + path, + }) + } +} + +impl IpcChannel for UnixSocketClient { + fn send(&self, data: &[u8]) -> Result<()> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let mut stream = self.stream.lock().unwrap(); + stream.write_all(data)?; + stream.flush()?; + + Ok(()) + } + + fn recv(&self) -> Result> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let mut stream = self.stream.lock().unwrap(); + read_message(&mut *stream) + } + + fn try_recv(&self) -> Result>> { + use std::os::unix::io::AsRawFd; + + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + // Keep lock held during entire operation to prevent race conditions + let stream = self.stream.lock().unwrap(); + let fd = stream.as_raw_fd(); + try_recv_with_fd(fd) + } + + fn is_connected(&self) -> bool { + self.connected.load(Ordering::SeqCst) + } + + fn close(&self) { + if self.connected.swap(false, Ordering::SeqCst) + && let Ok(stream) = self.stream.lock() + { + let _ = stream.shutdown(std::net::Shutdown::Both); + } + } +} + +impl Drop for UnixSocketClient { + fn drop(&mut self) { + self.close(); + } +} + +/// Unix domain socket server +pub struct UnixSocketServer { + listener: UnixListener, + path: String, + listening: AtomicBool, +} + +impl UnixSocketServer { + /// Create a new Unix socket server + /// + /// The socket is created with restricted permissions (0o600) to prevent + /// other users on the system from connecting. + pub fn new(name: &str) -> Result { + let path = make_socket_path(name)?; + + // Safely remove existing socket file if present (verifies it's actually a socket) + remove_socket_file(&path); + + // Set restrictive umask before creating socket to prevent race condition. + // Without this, the socket would briefly exist with default permissions + // (potentially allowing other users to connect) before set_permissions runs. + let old_umask = unsafe { libc::umask(0o077) }; + + let bind_result = UnixListener::bind(&path); + + // Restore original umask immediately after bind + unsafe { + libc::umask(old_umask); + } + + let listener = bind_result.map_err(|e| IpcError::Platform(e.to_string()))?; + + Ok(Self { + listener, + path, + listening: AtomicBool::new(true), + }) + } +} + +impl IpcServer for UnixSocketServer { + type Channel = UnixSocketConnection; + + fn accept(&self) -> Result { + if !self.listening.load(Ordering::SeqCst) { + return Err(IpcError::ChannelClosed); + } + + let (stream, _addr) = self + .listener + .accept() + .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?; + + Ok(UnixSocketConnection { + stream: Mutex::new(stream), + connected: AtomicBool::new(true), + }) + } + + fn close(&self) { + self.listening.store(false, Ordering::SeqCst); + // The listener will be cleaned up on drop + } +} + +impl Drop for UnixSocketServer { + fn drop(&mut self) { + self.close(); + // Clean up socket file (verify it's actually a socket before deletion) + remove_socket_file(&self.path); + } +} + +/// A connected Unix socket (server-side) +pub struct UnixSocketConnection { + stream: Mutex, + connected: AtomicBool, +} + +impl IpcChannel for UnixSocketConnection { + fn send(&self, data: &[u8]) -> Result<()> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let mut stream = self.stream.lock().unwrap(); + stream.write_all(data)?; + stream.flush()?; + + Ok(()) + } + + fn recv(&self) -> Result> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let mut stream = self.stream.lock().unwrap(); + read_message(&mut *stream) + } + + fn try_recv(&self) -> Result>> { + use std::os::unix::io::AsRawFd; + + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + // Keep lock held during entire operation to prevent race conditions + let stream = self.stream.lock().unwrap(); + let fd = stream.as_raw_fd(); + try_recv_with_fd(fd) + } + + fn is_connected(&self) -> bool { + self.connected.load(Ordering::SeqCst) + } + + fn close(&self) { + if self.connected.swap(false, Ordering::SeqCst) + && let Ok(stream) = self.stream.lock() + { + let _ = stream.shutdown(std::net::Shutdown::Both); + } + } +} + +impl Drop for UnixSocketConnection { + fn drop(&mut self) { + self.close(); + } +} diff --git a/core/itk-ipc/src/windows_impl.rs b/core/itk-ipc/src/windows_impl.rs new file mode 100644 index 0000000..1cbe335 --- /dev/null +++ b/core/itk-ipc/src/windows_impl.rs @@ -0,0 +1,461 @@ +//! Windows named pipe implementation +//! +//! Named pipes are created with security descriptors that restrict access +//! to the current user only, preventing other users on the system from +//! connecting to or injecting commands into the daemon. + +use super::{IpcChannel, IpcError, IpcServer, Result, read_message}; +use std::ffi::OsStr; +use std::io::Read; +use std::os::windows::ffi::OsStrExt; +use std::sync::Mutex; +use std::sync::atomic::{AtomicBool, Ordering}; +use windows::Win32::Foundation::{ + CloseHandle, HANDLE, HLOCAL, INVALID_HANDLE_VALUE, LocalFree, WIN32_ERROR, +}; +use windows::Win32::Security::Authorization::{ + ConvertStringSecurityDescriptorToSecurityDescriptorW, SDDL_REVISION_1, +}; +use windows::Win32::Security::{PSECURITY_DESCRIPTOR, SECURITY_ATTRIBUTES}; +use windows::Win32::Storage::FileSystem::{ + CreateFileW, FILE_ATTRIBUTE_NORMAL, FILE_GENERIC_READ, FILE_GENERIC_WRITE, FILE_SHARE_NONE, + FlushFileBuffers, OPEN_EXISTING, PIPE_ACCESS_DUPLEX, ReadFile, WriteFile, +}; +use windows::Win32::System::Pipes::{ + ConnectNamedPipe, CreateNamedPipeW, DisconnectNamedPipe, PIPE_READMODE_BYTE, PIPE_TYPE_BYTE, + PIPE_UNLIMITED_INSTANCES, PIPE_WAIT, PeekNamedPipe, +}; +use windows::core::PCWSTR; + +const BUFFER_SIZE: u32 = 65536; + +/// Check if a pipe has data available using PeekNamedPipe. +/// +/// Returns the number of bytes available, or 0 if none. +fn peek_pipe_bytes(handle: HANDLE) -> std::result::Result { + let mut bytes_available = 0u32; + + unsafe { + // PeekNamedPipe with null buffers just checks byte availability + PeekNamedPipe(handle, None, 0, None, Some(&mut bytes_available), None) + .map_err(|e| IpcError::Io(std::io::Error::other(e.to_string())))?; + } + + Ok(bytes_available) +} + +/// Write all data to a handle, looping until complete. +/// +/// Windows WriteFile may write fewer bytes than requested (partial write). +/// This function ensures all data is written before returning. +fn write_all_to_handle(handle: HANDLE, data: &[u8]) -> std::result::Result<(), IpcError> { + let mut offset = 0; + while offset < data.len() { + let mut bytes_written = 0u32; + let remaining = &data[offset..]; + + unsafe { + WriteFile(handle, Some(remaining), Some(&mut bytes_written), None) + .map_err(|e| IpcError::Io(std::io::Error::other(e.to_string())))?; + } + + if bytes_written == 0 { + return Err(IpcError::Io(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "WriteFile wrote zero bytes", + ))); + } + + offset += bytes_written as usize; + } + Ok(()) +} + +/// ERROR_PIPE_CONNECTED (535) - The pipe is already connected. +/// This occurs when a client connects between CreateNamedPipeW and ConnectNamedPipe. +/// It's not an error condition - it means the pipe is ready for use. +const ERROR_PIPE_CONNECTED: WIN32_ERROR = WIN32_ERROR(535); + +fn to_wide_string(s: &str) -> Vec { + OsStr::new(s).encode_wide().chain(Some(0)).collect() +} + +/// Validate and create a pipe name from a name. +/// +/// For security, this function: +/// - Rejects names with path traversal sequences (..) +/// - Rejects names that look like full pipe paths (prevents escaping itk_ namespace) +/// - Creates paths only in \\.\pipe\itk_* namespace +fn make_pipe_name(name: &str) -> Result { + // Reject path traversal attempts + if name.contains("..") { + return Err(IpcError::Platform( + "pipe name cannot contain path traversal sequences".into(), + )); + } + + // Reject names that look like they're trying to specify a full pipe path + // This prevents callers from escaping the itk_ namespace + if name.starts_with(r"\\") || name.starts_with(r"//") { + return Err(IpcError::Platform( + "pipe name cannot be a full pipe path".into(), + )); + } + + // Reject names with path separators + if name.contains('\\') || name.contains('/') { + return Err(IpcError::Platform( + "pipe name cannot contain path separators".into(), + )); + } + + Ok(format!(r"\\.\pipe\itk_{}", name)) +} + +/// RAII wrapper for security descriptor allocated by Windows APIs +struct SecurityDescriptorGuard { + ptr: PSECURITY_DESCRIPTOR, +} + +impl Drop for SecurityDescriptorGuard { + fn drop(&mut self) { + if !self.ptr.0.is_null() { + unsafe { + let _ = LocalFree(HLOCAL(self.ptr.0)); + } + } + } +} + +/// Create security attributes that restrict pipe access to the current user only. +/// +/// Uses SDDL (Security Descriptor Definition Language) to define: +/// - D: = DACL (Discretionary Access Control List) +/// - (A;;GA;;;CO) = Allow Generic All access to Creator Owner +/// +/// This prevents other users on the system from connecting to the pipe, +/// similar to Unix socket permissions of 0o600. +fn create_restricted_security_attributes() -> Result<(SECURITY_ATTRIBUTES, SecurityDescriptorGuard)> +{ + // SDDL: D:(A;;GA;;;CO) + // D: = DACL + // A = Allow + // GA = Generic All (full access) + // CO = Creator Owner (the user who created the pipe) + // + // This restricts access to only the user who created the pipe. + let sddl = to_wide_string("D:(A;;GA;;;CO)"); + + let mut sd_ptr = PSECURITY_DESCRIPTOR::default(); + + unsafe { + ConvertStringSecurityDescriptorToSecurityDescriptorW( + PCWSTR(sddl.as_ptr()), + SDDL_REVISION_1, + &mut sd_ptr, + None, + ) + .map_err(|e| IpcError::Platform(format!("Failed to create security descriptor: {}", e)))?; + } + + let guard = SecurityDescriptorGuard { ptr: sd_ptr }; + + let sa = SECURITY_ATTRIBUTES { + nLength: std::mem::size_of::() as u32, + lpSecurityDescriptor: sd_ptr.0, + bInheritHandle: false.into(), + }; + + Ok((sa, guard)) +} + +/// Windows named pipe client +pub struct NamedPipeClient { + handle: Mutex, + connected: AtomicBool, + #[allow(dead_code)] + name: String, +} + +// SAFETY: HANDLE is a raw pointer but we protect all access with a Mutex. +// The handle is only accessed through the Mutex, ensuring thread-safe access. +unsafe impl Send for NamedPipeClient {} +unsafe impl Sync for NamedPipeClient {} + +impl NamedPipeClient { + /// Connect to a named pipe server + pub fn connect(name: &str) -> Result { + let pipe_name = make_pipe_name(name)?; + let wide_name = to_wide_string(&pipe_name); + + unsafe { + let handle = CreateFileW( + PCWSTR(wide_name.as_ptr()), + (FILE_GENERIC_READ | FILE_GENERIC_WRITE).0, + FILE_SHARE_NONE, + None, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + None, + ) + .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?; + + if handle == INVALID_HANDLE_VALUE { + return Err(IpcError::ConnectionFailed("Invalid handle".into())); + } + + Ok(Self { + handle: Mutex::new(handle), + connected: AtomicBool::new(true), + name: pipe_name, + }) + } + } +} + +impl IpcChannel for NamedPipeClient { + fn send(&self, data: &[u8]) -> Result<()> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let handle = self.handle.lock().unwrap(); + + // Use helper to ensure all data is written (handles partial writes) + write_all_to_handle(*handle, data)?; + + unsafe { + FlushFileBuffers(*handle) + .map_err(|e| IpcError::Io(std::io::Error::other(e.to_string())))?; + } + + Ok(()) + } + + fn recv(&self) -> Result> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let handle = self.handle.lock().unwrap(); + let mut reader = PipeReader { handle: *handle }; + read_message(&mut reader) + } + + fn try_recv(&self) -> Result>> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let handle = self.handle.lock().unwrap(); + + // Use PeekNamedPipe for true non-blocking check + let bytes_available = peek_pipe_bytes(*handle)?; + if bytes_available == 0 { + return Ok(None); + } + + // Data is available, read it + let mut reader = PipeReader { handle: *handle }; + read_message(&mut reader).map(Some) + } + + fn is_connected(&self) -> bool { + self.connected.load(Ordering::SeqCst) + } + + fn close(&self) { + if self.connected.swap(false, Ordering::SeqCst) { + let handle = self.handle.lock().unwrap(); + unsafe { + let _ = CloseHandle(*handle); + } + } + } +} + +impl Drop for NamedPipeClient { + fn drop(&mut self) { + self.close(); + } +} + +/// Windows named pipe server +pub struct NamedPipeServer { + name: String, + listening: AtomicBool, +} + +impl NamedPipeServer { + /// Create a new named pipe server + pub fn new(name: &str) -> Result { + let pipe_name = make_pipe_name(name)?; + + Ok(Self { + name: pipe_name, + listening: AtomicBool::new(true), + }) + } + + /// Create a new pipe instance. + /// + /// Uses restricted security attributes that only allow the creating user + /// to connect, preventing unauthorized access from other users. + fn create_pipe_instance(&self) -> Result { + let wide_name = to_wide_string(&self.name); + + // Create restricted security attributes (owner-only access) + let (mut sa, _guard) = create_restricted_security_attributes()?; + + unsafe { + let handle = CreateNamedPipeW( + PCWSTR(wide_name.as_ptr()), + PIPE_ACCESS_DUPLEX, + PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT, + PIPE_UNLIMITED_INSTANCES, + BUFFER_SIZE, + BUFFER_SIZE, + 0, + Some(&mut sa), + ); + + if handle == INVALID_HANDLE_VALUE { + return Err(IpcError::Platform("Failed to create named pipe".into())); + } + + Ok(handle) + } + } +} + +impl IpcServer for NamedPipeServer { + type Channel = NamedPipeConnection; + + fn accept(&self) -> Result { + if !self.listening.load(Ordering::SeqCst) { + return Err(IpcError::ChannelClosed); + } + + let handle = self.create_pipe_instance()?; + + unsafe { + // Wait for client to connect. + // If the client connects between CreateNamedPipeW and ConnectNamedPipe, + // we get ERROR_PIPE_CONNECTED (535), which is not an error - it means + // the pipe is already connected and ready for use. + if let Err(e) = ConnectNamedPipe(handle, None) { + if e.code() != windows::core::HRESULT::from(ERROR_PIPE_CONNECTED) { + let _ = CloseHandle(handle); + return Err(IpcError::ConnectionFailed(e.to_string())); + } + // ERROR_PIPE_CONNECTED is success - pipe is already connected + } + } + + Ok(NamedPipeConnection { + handle: Mutex::new(handle), + connected: AtomicBool::new(true), + }) + } + + fn close(&self) { + self.listening.store(false, Ordering::SeqCst); + } +} + +/// A connected named pipe (server-side) +pub struct NamedPipeConnection { + handle: Mutex, + connected: AtomicBool, +} + +// SAFETY: HANDLE is a raw pointer but we protect all access with a Mutex. +// The handle is only accessed through the Mutex, ensuring thread-safe access. +unsafe impl Send for NamedPipeConnection {} +unsafe impl Sync for NamedPipeConnection {} + +impl IpcChannel for NamedPipeConnection { + fn send(&self, data: &[u8]) -> Result<()> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let handle = self.handle.lock().unwrap(); + + // Use helper to ensure all data is written (handles partial writes) + write_all_to_handle(*handle, data)?; + + unsafe { + FlushFileBuffers(*handle) + .map_err(|e| IpcError::Io(std::io::Error::other(e.to_string())))?; + } + + Ok(()) + } + + fn recv(&self) -> Result> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let handle = self.handle.lock().unwrap(); + let mut reader = PipeReader { handle: *handle }; + read_message(&mut reader) + } + + fn try_recv(&self) -> Result>> { + if !self.is_connected() { + return Err(IpcError::NotConnected); + } + + let handle = self.handle.lock().unwrap(); + + // Use PeekNamedPipe for true non-blocking check + let bytes_available = peek_pipe_bytes(*handle)?; + if bytes_available == 0 { + return Ok(None); + } + + // Data is available, read it + let mut reader = PipeReader { handle: *handle }; + read_message(&mut reader).map(Some) + } + + fn is_connected(&self) -> bool { + self.connected.load(Ordering::SeqCst) + } + + fn close(&self) { + if self.connected.swap(false, Ordering::SeqCst) { + let handle = self.handle.lock().unwrap(); + unsafe { + let _ = DisconnectNamedPipe(*handle); + let _ = CloseHandle(*handle); + } + } + } +} + +impl Drop for NamedPipeConnection { + fn drop(&mut self) { + self.close(); + } +} + +/// Helper for reading from a pipe handle +struct PipeReader { + handle: HANDLE, +} + +impl Read for PipeReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let mut bytes_read = 0u32; + + unsafe { + ReadFile(self.handle, Some(buf), Some(&mut bytes_read), None) + .map_err(|e| std::io::Error::other(e.to_string()))?; + } + + Ok(bytes_read as usize) + } +} diff --git a/core/itk-net/Cargo.toml b/core/itk-net/Cargo.toml new file mode 100644 index 0000000..8205cf2 --- /dev/null +++ b/core/itk-net/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "itk-net" +description = "P2P networking for multiplayer video sync" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[features] +default = [] + +[dependencies] +# Core ITK libraries +itk-protocol = { path = "../itk-protocol" } +itk-sync = { path = "../itk-sync" } + +# Networking - laminar for game-focused UDP +laminar = "0.5" + +# Async runtime +tokio = { workspace = true } + +# Serialization +serde = { workspace = true } +bincode = { workspace = true } + +# Logging +tracing = { workspace = true } + +# Error handling +thiserror = { workspace = true } + +# Utilities +parking_lot = "0.12" +socket2 = { version = "0.5", features = ["all"] } + +[lints] +workspace = true diff --git a/core/itk-net/src/discovery.rs b/core/itk-net/src/discovery.rs new file mode 100644 index 0000000..942e881 --- /dev/null +++ b/core/itk-net/src/discovery.rs @@ -0,0 +1,269 @@ +//! LAN peer discovery via UDP broadcast +//! +//! Broadcasts presence on the local network to find other peers. + +use crate::{DISCOVERY_PORT, NetError, Result}; +use serde::{Deserialize, Serialize}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::net::{SocketAddr, UdpSocket}; +use std::time::{Duration, Instant, SystemTime}; +use tracing::{debug, info, warn}; + +/// Discovery announcement message +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiscoveryAnnounce { + /// Magic identifier for our protocol + pub magic: [u8; 4], + /// Protocol version + pub version: u32, + /// Unique peer identifier (distinguishes instances in same session) + pub peer_id: u64, + /// Session ID (content hash or random) + pub session_id: String, + /// Peer's game port (for laminar connection) + pub game_port: u16, + /// Whether this peer is the session leader + pub is_leader: bool, + /// Peer's display name + pub name: String, +} + +impl DiscoveryAnnounce { + pub const MAGIC: [u8; 4] = *b"ITKD"; // ITK Discovery + + pub fn new(session_id: String, game_port: u16, is_leader: bool, name: String) -> Self { + // Generate a unique peer ID from process ID and high-resolution time + let time_nanos = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() as u64; + let peer_id = time_nanos ^ (std::process::id() as u64) << 32; + + Self { + magic: Self::MAGIC, + version: 1, + peer_id, + session_id, + game_port, + is_leader, + name, + } + } + + pub fn is_valid(&self) -> bool { + self.magic == Self::MAGIC && self.version == 1 + } +} + +/// Discovered peer information +#[derive(Debug, Clone)] +pub struct DiscoveredPeer { + /// Peer's address for game connection + pub addr: SocketAddr, + /// Session ID + pub session_id: String, + /// Whether this peer is the leader + pub is_leader: bool, + /// Peer's display name + pub name: String, + /// When we discovered this peer + pub discovered_at: Instant, +} + +/// LAN discovery service +pub struct Discovery { + /// UDP socket for broadcast + socket: UdpSocket, + /// Our announcement + announce: DiscoveryAnnounce, + /// Discovered peers + discovered: Vec, + /// Last broadcast time + last_broadcast: Instant, + /// Broadcast interval + broadcast_interval: Duration, +} + +impl Discovery { + /// Create a new discovery service + pub fn new(session_id: String, game_port: u16, is_leader: bool, name: String) -> Result { + // Use socket2 to set SO_REUSEADDR/SO_REUSEPORT before binding, allowing + // multiple instances to listen on the same discovery port simultaneously. + let sock2 = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .map_err(|e| NetError::BindFailed(format!("Failed to create socket: {}", e)))?; + + sock2 + .set_reuse_address(true) + .map_err(|e| NetError::BindFailed(format!("Failed to set SO_REUSEADDR: {}", e)))?; + + // SO_REUSEPORT allows multiple processes to bind to the same port (Linux/macOS) + #[cfg(not(windows))] + sock2 + .set_reuse_port(true) + .map_err(|e| NetError::BindFailed(format!("Failed to set SO_REUSEPORT: {}", e)))?; + + let addr: std::net::SocketAddrV4 = format!("0.0.0.0:{}", DISCOVERY_PORT).parse().unwrap(); + sock2.bind(&socket2::SockAddr::from(addr)).map_err(|e| { + NetError::BindFailed(format!( + "Failed to bind discovery port {}: {}", + DISCOVERY_PORT, e + )) + })?; + + let socket: UdpSocket = sock2.into(); + + // Enable broadcast + socket + .set_broadcast(true) + .map_err(|e| NetError::BindFailed(format!("Failed to enable broadcast: {}", e)))?; + + // Non-blocking mode + socket + .set_nonblocking(true) + .map_err(|e| NetError::BindFailed(format!("Failed to set non-blocking: {}", e)))?; + + let announce = DiscoveryAnnounce::new(session_id, game_port, is_leader, name); + + info!(port = DISCOVERY_PORT, "Discovery service started"); + + Ok(Self { + socket, + announce, + discovered: Vec::new(), + last_broadcast: Instant::now() - Duration::from_secs(10), // Trigger immediate broadcast + broadcast_interval: Duration::from_secs(2), + }) + } + + /// Update our announcement (e.g., when becoming leader) + pub fn update_announce(&mut self, is_leader: bool, session_id: Option) { + self.announce.is_leader = is_leader; + if let Some(sid) = session_id { + self.announce.session_id = sid; + } + } + + /// Poll for discovery events + pub fn poll(&mut self) -> Vec { + // Broadcast if interval elapsed + if self.last_broadcast.elapsed() >= self.broadcast_interval { + self.broadcast(); + self.last_broadcast = Instant::now(); + } + + // Receive announcements + let mut buf = [0u8; 1024]; + let mut new_peers = Vec::new(); + + loop { + match self.socket.recv_from(&mut buf) { + Ok((len, src_addr)) => { + if let Some(peer) = self.handle_packet(&buf[..len], src_addr) { + new_peers.push(peer); + } + }, + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + break; + }, + Err(e) => { + warn!(error = ?e, "Discovery recv error"); + break; + }, + } + } + + // Prune old discoveries (older than 10 seconds) + let now = Instant::now(); + self.discovered + .retain(|p| now.duration_since(p.discovered_at) < Duration::from_secs(10)); + + new_peers + } + + fn broadcast(&self) { + let data = match bincode::serialize(&self.announce) { + Ok(d) => d, + Err(e) => { + warn!(error = ?e, "Failed to serialize announce"); + return; + }, + }; + + // Broadcast to LAN + let broadcast_addr: SocketAddr = format!("255.255.255.255:{}", DISCOVERY_PORT) + .parse() + .unwrap(); + + if let Err(e) = self.socket.send_to(&data, broadcast_addr) { + debug!(error = ?e, "Broadcast send failed (may be normal on some networks)"); + } + } + + fn handle_packet(&mut self, data: &[u8], src_addr: SocketAddr) -> Option { + let announce: DiscoveryAnnounce = match bincode::deserialize(data) { + Ok(a) => a, + Err(_) => return None, + }; + + if !announce.is_valid() { + return None; + } + + // Ignore our own broadcasts (compare unique peer ID, not session+port) + if announce.peer_id == self.announce.peer_id { + return None; + } + + // Build peer address with their game port + let game_addr = SocketAddr::new(src_addr.ip(), announce.game_port); + + let peer = DiscoveredPeer { + addr: game_addr, + session_id: announce.session_id.clone(), + is_leader: announce.is_leader, + name: announce.name.clone(), + discovered_at: Instant::now(), + }; + + // Check if already discovered + let existing = self + .discovered + .iter_mut() + .find(|p| p.addr == game_addr && p.session_id == announce.session_id); + + if let Some(existing) = existing { + existing.discovered_at = Instant::now(); + existing.is_leader = announce.is_leader; + None // Not a new discovery + } else { + info!( + peer = %game_addr, + session = %announce.session_id, + leader = announce.is_leader, + "Discovered peer" + ); + self.discovered.push(peer.clone()); + Some(peer) + } + } + + /// Get all currently known peers + pub fn peers(&self) -> &[DiscoveredPeer] { + &self.discovered + } + + /// Find peers in a specific session + pub fn peers_in_session(&self, session_id: &str) -> Vec<&DiscoveredPeer> { + self.discovered + .iter() + .filter(|p| p.session_id == session_id) + .collect() + } + + /// Find the leader for a session + pub fn find_leader(&self, session_id: &str) -> Option<&DiscoveredPeer> { + self.discovered + .iter() + .find(|p| p.session_id == session_id && p.is_leader) + } +} diff --git a/core/itk-net/src/error.rs b/core/itk-net/src/error.rs new file mode 100644 index 0000000..6c3bf43 --- /dev/null +++ b/core/itk-net/src/error.rs @@ -0,0 +1,40 @@ +//! Network error types + +use thiserror::Error; + +/// Network errors +#[derive(Error, Debug)] +pub enum NetError { + #[error("failed to bind socket: {0}")] + BindFailed(String), + + #[error("connection failed: {0}")] + ConnectionFailed(String), + + #[error("send failed: {0}")] + SendFailed(String), + + #[error("serialization error: {0}")] + Serialization(#[from] bincode::Error), + + #[error("peer not found: {0}")] + PeerNotFound(String), + + #[error("session not found")] + SessionNotFound, + + #[error("not the leader")] + NotLeader, + + #[error("already in session")] + AlreadyInSession, + + #[error("io error: {0}")] + Io(#[from] std::io::Error), + + #[error("laminar error: {0}")] + Laminar(String), +} + +/// Result type for network operations +pub type Result = std::result::Result; diff --git a/core/itk-net/src/lib.rs b/core/itk-net/src/lib.rs new file mode 100644 index 0000000..84e9e3f --- /dev/null +++ b/core/itk-net/src/lib.rs @@ -0,0 +1,53 @@ +//! # ITK Net +//! +//! P2P networking for multiplayer video synchronization. +//! +//! This crate provides: +//! - LAN peer discovery via UDP broadcast +//! - Reliable and unreliable messaging via laminar +//! - Session management with leader election +//! - Integration with itk-sync for clock/playback sync +//! +//! ## Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────┐ +//! │ Session │ +//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ +//! │ │ Discovery │ │ Peers │ │ SyncManager │ │ +//! │ │ (UDP bcast)│ │ (laminar) │ │ (ClockSync + │ │ +//! │ │ │ │ │ │ PlaybackSync) │ │ +//! │ └─────────────┘ └─────────────┘ └─────────────────────┘ │ +//! └─────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! ## Sync Flow +//! +//! 1. Leader broadcasts `SyncState` every 500ms (unreliable) +//! 2. Followers use `DriftCorrector` for smooth catch-up +//! 3. `ClockPing/Pong` exchanges estimate clock offset +//! 4. Commands (play/pause/seek) sent reliably + +pub mod discovery; +pub mod error; +pub mod peer; +pub mod session; +pub mod sync_manager; + +pub use discovery::Discovery; +pub use error::{NetError, Result}; +pub use peer::{Peer, PeerEvent, PeerManager}; +pub use session::{Session, SessionConfig, SessionEvent, SessionRole}; +pub use sync_manager::SyncManager; + +/// Default port for multiplayer sync +pub const DEFAULT_PORT: u16 = 7331; + +/// Discovery broadcast port +pub const DISCOVERY_PORT: u16 = 7332; + +/// Sync broadcast interval in milliseconds +pub const SYNC_INTERVAL_MS: u64 = 500; + +/// Clock ping interval in milliseconds +pub const CLOCK_PING_INTERVAL_MS: u64 = 2000; diff --git a/core/itk-net/src/peer.rs b/core/itk-net/src/peer.rs new file mode 100644 index 0000000..3e3f5a7 --- /dev/null +++ b/core/itk-net/src/peer.rs @@ -0,0 +1,342 @@ +//! Peer connection management using laminar +//! +//! Provides reliable and unreliable messaging between peers. + +use crate::{DEFAULT_PORT, NetError, Result}; +use laminar::{Config, Packet, Socket, SocketEvent}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tracing::{debug, info, warn}; + +/// Peer identifier (socket address as string) +pub type PeerId = String; + +/// Information about a connected peer +#[derive(Debug, Clone)] +pub struct Peer { + /// Peer's network address + pub addr: SocketAddr, + /// Display name (if known) + pub name: Option, + /// Last time we received a message from this peer + pub last_seen: Instant, + /// Whether this peer is the session leader + pub is_leader: bool, + /// Estimated latency in milliseconds + pub latency_ms: Option, +} + +impl Peer { + pub fn new(addr: SocketAddr) -> Self { + Self { + addr, + name: None, + last_seen: Instant::now(), + is_leader: false, + latency_ms: None, + } + } + + pub fn id(&self) -> PeerId { + self.addr.to_string() + } +} + +/// Events from the peer manager +#[derive(Debug, Clone)] +pub enum PeerEvent { + /// A new peer connected + PeerConnected(PeerId), + /// A peer disconnected + PeerDisconnected(PeerId), + /// Received a message from a peer + Message { from: PeerId, data: Vec }, + /// Connection timeout + Timeout(PeerId), +} + +/// Network message types +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum NetMessage { + /// Heartbeat/keepalive + Ping { timestamp_ms: u64 }, + /// Response to ping + Pong { + timestamp_ms: u64, + peer_time_ms: u64, + }, + /// Announce presence with name + Announce { name: String }, + /// Sync state broadcast (unreliable, frequent) + SyncState(itk_protocol::SyncState), + /// Clock ping for time sync + ClockPing(itk_protocol::ClockPing), + /// Clock pong response + ClockPong(itk_protocol::ClockPong), + /// Video command (reliable) + VideoCommand(VideoCommand), +} + +/// Video control commands +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum VideoCommand { + Load { url: String }, + Play, + Pause, + Seek { position_ms: u64 }, +} + +/// Manages peer connections +pub struct PeerManager { + /// Laminar socket + socket: Socket, + /// Connected peers + peers: Arc>>, + /// Our local address + local_addr: SocketAddr, + /// Pending events + events: Arc>>, +} + +impl PeerManager { + /// Create a new peer manager bound to the given port + pub fn new(port: u16) -> Result { + let addr: SocketAddr = format!("0.0.0.0:{}", port) + .parse() + .map_err(|e| NetError::BindFailed(format!("{}", e)))?; + + let config = Config { + heartbeat_interval: Some(Duration::from_millis(500)), + idle_connection_timeout: Duration::from_secs(10), + ..Default::default() + }; + + let socket = Socket::bind_with_config(addr, config) + .map_err(|e| NetError::BindFailed(e.to_string()))?; + + let local_addr = socket + .local_addr() + .map_err(|e| NetError::BindFailed(e.to_string()))?; + + info!(addr = %local_addr, "PeerManager bound"); + + Ok(Self { + socket, + peers: Arc::new(RwLock::new(HashMap::new())), + local_addr, + events: Arc::new(RwLock::new(Vec::new())), + }) + } + + /// Create with default port + pub fn with_default_port() -> Result { + Self::new(DEFAULT_PORT) + } + + /// Get our local address + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + + /// Connect to a peer + pub fn connect(&mut self, addr: SocketAddr) -> Result<()> { + // Send announce message to initiate connection + let msg = NetMessage::Announce { + name: format!("peer-{}", self.local_addr.port()), + }; + self.send_reliable(addr, &msg)?; + info!(peer = %addr, "Connecting to peer"); + Ok(()) + } + + /// Send a message reliably (guaranteed delivery, ordered) + pub fn send_reliable(&mut self, addr: SocketAddr, msg: &NetMessage) -> Result<()> { + let data = bincode::serialize(msg)?; + let packet = Packet::reliable_ordered(addr, data, Some(0)); + self.socket + .send(packet) + .map_err(|e| NetError::SendFailed(e.to_string()))?; + Ok(()) + } + + /// Send a message unreliably (may be lost, unordered - good for sync state) + pub fn send_unreliable(&mut self, addr: SocketAddr, msg: &NetMessage) -> Result<()> { + let data = bincode::serialize(msg)?; + let packet = Packet::unreliable(addr, data); + self.socket + .send(packet) + .map_err(|e| NetError::SendFailed(e.to_string()))?; + Ok(()) + } + + /// Broadcast a message to all connected peers (unreliable) + pub fn broadcast(&mut self, msg: &NetMessage) -> Result<()> { + let peers: Vec = self.peers.read().values().map(|p| p.addr).collect(); + for addr in peers { + if let Err(e) = self.send_unreliable(addr, msg) { + warn!(peer = %addr, error = ?e, "Failed to broadcast to peer"); + } + } + Ok(()) + } + + /// Broadcast a message reliably to all peers + pub fn broadcast_reliable(&mut self, msg: &NetMessage) -> Result<()> { + let peers: Vec = self.peers.read().values().map(|p| p.addr).collect(); + for addr in peers { + if let Err(e) = self.send_reliable(addr, msg) { + warn!(peer = %addr, error = ?e, "Failed to broadcast reliable to peer"); + } + } + Ok(()) + } + + /// Poll for network events (non-blocking) + pub fn poll(&mut self) -> Vec { + // Process socket events + self.socket.manual_poll(Instant::now()); + + while let Some(event) = self.socket.recv() { + self.handle_socket_event(event); + } + + // Return accumulated events + std::mem::take(&mut *self.events.write()) + } + + fn handle_socket_event(&mut self, event: SocketEvent) { + match event { + SocketEvent::Packet(packet) => { + let addr = packet.addr(); + let peer_id = addr.to_string(); + + // Update last seen + { + let mut peers = self.peers.write(); + if let Some(peer) = peers.get_mut(&peer_id) { + peer.last_seen = Instant::now(); + } + } + + // Try to deserialize message + match bincode::deserialize::(packet.payload()) { + Ok(msg) => { + self.handle_message(addr, msg); + }, + Err(e) => { + warn!(peer = %addr, error = ?e, "Failed to deserialize message"); + }, + } + }, + + SocketEvent::Connect(addr) => { + let peer_id = addr.to_string(); + info!(peer = %addr, "Peer connected"); + + { + let mut peers = self.peers.write(); + peers.insert(peer_id.clone(), Peer::new(addr)); + } + + self.events.write().push(PeerEvent::PeerConnected(peer_id)); + }, + + SocketEvent::Disconnect(addr) => { + let peer_id = addr.to_string(); + info!(peer = %addr, "Peer disconnected"); + + { + let mut peers = self.peers.write(); + peers.remove(&peer_id); + } + + self.events + .write() + .push(PeerEvent::PeerDisconnected(peer_id)); + }, + + SocketEvent::Timeout(addr) => { + let peer_id = addr.to_string(); + warn!(peer = %addr, "Peer timeout"); + + { + let mut peers = self.peers.write(); + peers.remove(&peer_id); + } + + self.events.write().push(PeerEvent::Timeout(peer_id)); + }, + } + } + + fn handle_message(&mut self, addr: SocketAddr, msg: NetMessage) { + let peer_id = addr.to_string(); + + match msg { + NetMessage::Announce { name } => { + debug!(peer = %addr, name = %name, "Peer announced"); + let mut peers = self.peers.write(); + if let Some(peer) = peers.get_mut(&peer_id) { + peer.name = Some(name); + } else { + let mut peer = Peer::new(addr); + peer.name = Some(name); + peers.insert(peer_id.clone(), peer); + drop(peers); + self.events.write().push(PeerEvent::PeerConnected(peer_id)); + } + }, + + NetMessage::Ping { timestamp_ms } => { + // Respond with pong + let pong = NetMessage::Pong { + timestamp_ms, + peer_time_ms: itk_sync::now_ms(), + }; + let _ = self.send_unreliable(addr, &pong); + }, + + NetMessage::Pong { + timestamp_ms, + peer_time_ms: _, + } => { + // Calculate latency + let now = itk_sync::now_ms(); + let rtt = now.saturating_sub(timestamp_ms) as u32; + + let mut peers = self.peers.write(); + if let Some(peer) = peers.get_mut(&peer_id) { + peer.latency_ms = Some(rtt / 2); + } + }, + + // Forward other messages as events + _ => { + let data = bincode::serialize(&msg).unwrap_or_default(); + self.events.write().push(PeerEvent::Message { + from: peer_id, + data, + }); + }, + } + } + + /// Get list of connected peers + pub fn peers(&self) -> Vec { + self.peers.read().values().cloned().collect() + } + + /// Get a specific peer + pub fn get_peer(&self, id: &str) -> Option { + self.peers.read().get(id).cloned() + } + + /// Number of connected peers + pub fn peer_count(&self) -> usize { + self.peers.read().len() + } +} diff --git a/core/itk-net/src/session.rs b/core/itk-net/src/session.rs new file mode 100644 index 0000000..12ba096 --- /dev/null +++ b/core/itk-net/src/session.rs @@ -0,0 +1,392 @@ +//! Session management +//! +//! A session represents a group of peers watching the same content together. +//! One peer is the leader who controls playback; others follow. + +use crate::discovery::{DiscoveredPeer, Discovery}; +use crate::peer::{NetMessage, PeerEvent, PeerManager, VideoCommand}; +use crate::sync_manager::SyncManager; +use crate::{DEFAULT_PORT, NetError, Result}; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; +use tracing::{info, warn}; + +/// Session configuration +#[derive(Debug, Clone)] +pub struct SessionConfig { + /// Our display name + pub name: String, + /// Port to use for game connections + pub port: u16, + /// Whether to enable LAN discovery + pub enable_discovery: bool, +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + name: format!("Player-{}", std::process::id() % 10000), + port: DEFAULT_PORT, + enable_discovery: true, + } + } +} + +/// Our role in the session +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionRole { + /// We control playback + Leader, + /// We follow the leader + Follower, +} + +/// Events from the session +#[derive(Debug, Clone)] +pub enum SessionEvent { + /// Session started + Started { + session_id: String, + role: SessionRole, + }, + /// Peer joined the session + PeerJoined { name: String, addr: SocketAddr }, + /// Peer left the session + PeerLeft { name: String }, + /// We became the leader (leader left) + BecameLeader, + /// Sync state updated + SyncUpdated { position_ms: u64, is_playing: bool }, + /// Should seek to position (large drift) + SeekRequired { position_ms: u64 }, + /// Video command received + Command(VideoCommand), +} + +/// Multiplayer session +pub struct Session { + /// Session configuration + _config: SessionConfig, + /// Session ID (content URL hash) + session_id: String, + /// Our role + role: SessionRole, + /// Peer manager + peers: PeerManager, + /// Discovery service + discovery: Option, + /// Sync manager + sync: SyncManager, + /// Leader's address (if we're a follower) + leader_addr: Option, + /// Pending events + events: Vec, + /// Last time we checked for leader timeout + last_leader_check: Instant, +} + +impl Session { + /// Create a new session as leader + pub fn create(config: SessionConfig, content_id: String) -> Result { + let peers = PeerManager::new(config.port)?; + + let discovery = if config.enable_discovery { + Some(Discovery::new( + content_id.clone(), + config.port, + true, // We're the leader + config.name.clone(), + )?) + } else { + None + }; + + let sync = SyncManager::new(content_id.clone(), true); + + info!(session_id = %content_id, "Created session as leader"); + + let mut session = Self { + _config: config, + session_id: content_id.clone(), + role: SessionRole::Leader, + peers, + discovery, + sync, + leader_addr: None, + events: Vec::new(), + last_leader_check: Instant::now(), + }; + + session.events.push(SessionEvent::Started { + session_id: content_id, + role: SessionRole::Leader, + }); + + Ok(session) + } + + /// Join an existing session + pub fn join( + config: SessionConfig, + leader_addr: SocketAddr, + content_id: String, + ) -> Result { + let mut peers = PeerManager::new(config.port)?; + + // Connect to the leader + peers.connect(leader_addr)?; + + let discovery = if config.enable_discovery { + Some(Discovery::new( + content_id.clone(), + config.port, + false, // We're a follower + config.name.clone(), + )?) + } else { + None + }; + + let sync = SyncManager::new(content_id.clone(), false); + + info!(session_id = %content_id, leader = %leader_addr, "Joining session"); + + let mut session = Self { + _config: config, + session_id: content_id.clone(), + role: SessionRole::Follower, + peers, + discovery, + sync, + leader_addr: Some(leader_addr), + events: Vec::new(), + last_leader_check: Instant::now(), + }; + + session.events.push(SessionEvent::Started { + session_id: content_id, + role: SessionRole::Follower, + }); + + Ok(session) + } + + /// Join a discovered session + pub fn join_discovered(config: SessionConfig, peer: &DiscoveredPeer) -> Result { + Self::join(config, peer.addr, peer.session_id.clone()) + } + + /// Get our role + pub fn role(&self) -> SessionRole { + self.role + } + + /// Get session ID + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Get current playback position + pub fn position_ms(&self) -> u64 { + self.sync.current_position_ms() + } + + /// Get recommended playback rate + pub fn playback_rate(&self) -> f64 { + self.sync.playback_rate() + } + + /// Check if playing + pub fn is_playing(&self) -> bool { + self.sync.is_playing() + } + + /// Get current drift (followers only) + pub fn drift_ms(&self) -> Option { + self.sync.drift_ms() + } + + /// Get peer count + pub fn peer_count(&self) -> usize { + self.peers.peer_count() + } + + // === Leader commands === + + /// Load content (leader only) + pub fn load(&mut self, url: &str) -> Result<()> { + if self.role != SessionRole::Leader { + return Err(NetError::NotLeader); + } + self.sync.send_command( + VideoCommand::Load { + url: url.to_string(), + }, + &mut self.peers, + ) + } + + /// Play (leader only) + pub fn play(&mut self) -> Result<()> { + if self.role != SessionRole::Leader { + return Err(NetError::NotLeader); + } + self.sync.send_command(VideoCommand::Play, &mut self.peers) + } + + /// Pause (leader only) + pub fn pause(&mut self) -> Result<()> { + if self.role != SessionRole::Leader { + return Err(NetError::NotLeader); + } + self.sync.send_command(VideoCommand::Pause, &mut self.peers) + } + + /// Seek (leader only) + pub fn seek(&mut self, position_ms: u64) -> Result<()> { + if self.role != SessionRole::Leader { + return Err(NetError::NotLeader); + } + self.sync + .send_command(VideoCommand::Seek { position_ms }, &mut self.peers) + } + + /// Poll for events + pub fn poll(&mut self) -> Vec { + // Poll discovery + if let Some(ref mut discovery) = self.discovery { + for peer in discovery.poll() { + // Auto-connect to peers in same session + if peer.session_id == self.session_id + && let Err(e) = self.peers.connect(peer.addr) + { + warn!(peer = %peer.addr, error = ?e, "Failed to connect to discovered peer"); + } + } + } + + // Poll network + for event in self.peers.poll() { + self.handle_peer_event(event); + } + + // Update sync manager + if let Err(e) = self.sync.update(&mut self.peers) { + warn!(error = ?e, "Sync update failed"); + } + + // Check for seek requirement (large drift) + if let Some(target_pos) = self.sync.should_seek() { + self.sync.seek(target_pos); + self.events.push(SessionEvent::SeekRequired { + position_ms: target_pos, + }); + } + + // Check for leader timeout (followers only) + if self.role == SessionRole::Follower { + self.check_leader_timeout(); + } + + std::mem::take(&mut self.events) + } + + fn handle_peer_event(&mut self, event: PeerEvent) { + match event { + PeerEvent::PeerConnected(peer_id) => { + if let Some(peer) = self.peers.get_peer(&peer_id) { + info!(peer = %peer_id, "Peer connected to session"); + self.events.push(SessionEvent::PeerJoined { + name: peer.name.unwrap_or_else(|| peer_id.clone()), + addr: peer.addr, + }); + } + }, + + PeerEvent::PeerDisconnected(peer_id) | PeerEvent::Timeout(peer_id) => { + info!(peer = %peer_id, "Peer left session"); + + // Check if it was the leader + if self.role == SessionRole::Follower + && let Some(leader) = self.leader_addr + && peer_id == leader.to_string() + { + self.promote_to_leader(); + } + + self.events.push(SessionEvent::PeerLeft { name: peer_id }); + }, + + PeerEvent::Message { from, data } => { + self.handle_message(&from, &data); + }, + } + } + + fn handle_message(&mut self, from: &str, data: &[u8]) { + let msg: NetMessage = match bincode::deserialize(data) { + Ok(m) => m, + Err(_) => return, + }; + + let peer_addr = match self.peers.get_peer(from) { + Some(p) => p.addr, + None => return, + }; + + match msg { + NetMessage::SyncState(state) => { + self.sync.receive_sync_state(state); + self.events.push(SessionEvent::SyncUpdated { + position_ms: self.sync.current_position_ms(), + is_playing: self.sync.is_playing(), + }); + }, + + NetMessage::ClockPing(ping) => { + self.sync + .receive_clock_ping(from, ping, &mut self.peers, peer_addr); + }, + + NetMessage::ClockPong(pong) => { + self.sync.receive_clock_pong(from, pong); + }, + + NetMessage::VideoCommand(cmd) => { + self.sync.receive_command(cmd.clone()); + self.events.push(SessionEvent::Command(cmd)); + }, + + _ => {}, + } + } + + fn check_leader_timeout(&mut self) { + if self.last_leader_check.elapsed() < Duration::from_secs(5) { + return; + } + self.last_leader_check = Instant::now(); + + // If we haven't received sync state in a while, assume leader is gone + // This is a simplified check - real implementation would track last sync time + if let Some(leader) = self.leader_addr + && self.peers.get_peer(&leader.to_string()).is_none() + { + info!("Leader disconnected, promoting self"); + self.promote_to_leader(); + } + } + + fn promote_to_leader(&mut self) { + self.role = SessionRole::Leader; + self.leader_addr = None; + self.sync.set_leader(true); + + if let Some(ref mut discovery) = self.discovery { + discovery.update_announce(true, None); + } + + info!("Promoted to leader"); + self.events.push(SessionEvent::BecameLeader); + } +} diff --git a/core/itk-net/src/sync_manager.rs b/core/itk-net/src/sync_manager.rs new file mode 100644 index 0000000..95c92b7 --- /dev/null +++ b/core/itk-net/src/sync_manager.rs @@ -0,0 +1,297 @@ +//! Sync manager - integrates itk-sync with network messaging +//! +//! Handles clock synchronization and playback sync across peers. + +use crate::peer::{NetMessage, PeerManager, VideoCommand}; +use crate::{CLOCK_PING_INTERVAL_MS, Result, SYNC_INTERVAL_MS}; +use itk_protocol::SyncState; +use itk_sync::{ClockSync, DriftCorrector, PlaybackSync}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::Instant; +use tracing::{debug, info}; + +/// Manages synchronization state for a session +pub struct SyncManager { + /// Our local playback state + local_sync: PlaybackSync, + /// Whether we are the leader + is_leader: bool, + /// Drift corrector (for followers) + drift_corrector: DriftCorrector, + /// Per-peer clock sync (for leader to track followers) + peer_clocks: HashMap, + /// Last sync broadcast time + last_sync_broadcast: Instant, + /// Last clock ping time + last_clock_ping: Instant, + /// Pending clock pings (peer_id -> send_time_ms) + pending_pings: HashMap, +} + +impl SyncManager { + /// Create a new sync manager + pub fn new(content_id: String, is_leader: bool) -> Self { + Self { + local_sync: PlaybackSync::new(content_id), + is_leader, + drift_corrector: DriftCorrector::new(), + peer_clocks: HashMap::new(), + last_sync_broadcast: Instant::now(), + last_clock_ping: Instant::now(), + pending_pings: HashMap::new(), + } + } + + /// Set whether we are the leader + pub fn set_leader(&mut self, is_leader: bool) { + self.is_leader = is_leader; + info!(is_leader, "Leader status changed"); + } + + /// Get current local position (with drift correction if follower) + pub fn current_position_ms(&self) -> u64 { + self.local_sync.current_position_ms() + } + + /// Get the recommended playback rate (for drift correction) + pub fn playback_rate(&self) -> f64 { + if self.is_leader { + 1.0 + } else { + self.drift_corrector + .calculate_rate(self.local_sync.current_position_ms()) + } + } + + /// Check if we should seek to correct large drift + pub fn should_seek(&self) -> Option { + if self.is_leader { + None + } else { + self.drift_corrector + .should_seek(self.local_sync.current_position_ms()) + } + } + + /// Get current drift in milliseconds (followers only) + pub fn drift_ms(&self) -> Option { + if self.is_leader { + None + } else { + self.drift_corrector + .current_drift_ms(self.local_sync.current_position_ms()) + } + } + + /// Load new content + pub fn load(&mut self, content_id: String) { + self.local_sync = PlaybackSync::new(content_id); + } + + /// Start playback + pub fn play(&mut self) { + self.local_sync.set_playing(true); + } + + /// Pause playback + pub fn pause(&mut self) { + self.local_sync.set_playing(false); + } + + /// Seek to position + pub fn seek(&mut self, position_ms: u64) { + self.local_sync.seek(position_ms); + } + + /// Check if currently playing + pub fn is_playing(&self) -> bool { + self.local_sync.is_playing + } + + /// Get the content ID + pub fn content_id(&self) -> &str { + &self.local_sync.content_id + } + + /// Process received sync state (from leader) + pub fn receive_sync_state(&mut self, state: SyncState) { + if self.is_leader { + // Leaders ignore sync states from others + return; + } + + debug!( + position = state.position_at_ref_ms, + playing = state.is_playing, + "Received sync state" + ); + + // Update drift corrector target + let playback_sync = PlaybackSync { + content_id: state.content_id.clone(), + position_at_ref_ms: state.position_at_ref_ms, + ref_wallclock_ms: state.ref_wallclock_ms, + is_playing: state.is_playing, + playback_rate: state.playback_rate, + }; + self.drift_corrector.update_target(playback_sync.clone()); + + // Update local sync if content changed + if self.local_sync.content_id != state.content_id { + self.local_sync.content_id = state.content_id; + } + + // Match play/pause state + if self.local_sync.is_playing != state.is_playing { + self.local_sync.set_playing(state.is_playing); + } + } + + /// Process received clock ping + pub fn receive_clock_ping( + &mut self, + _from: &str, + ping: itk_protocol::ClockPing, + peer_manager: &mut PeerManager, + peer_addr: SocketAddr, + ) { + let pong = itk_protocol::ClockPong { + sender_time_ms: ping.sender_time_ms, + receiver_time_ms: itk_sync::now_ms(), + }; + let msg = NetMessage::ClockPong(pong); + let _ = peer_manager.send_unreliable(peer_addr, &msg); + } + + /// Process received clock pong + pub fn receive_clock_pong(&mut self, from: &str, pong: itk_protocol::ClockPong) { + let recv_time = itk_sync::now_ms(); + + if self.is_leader { + // Leader tracks clock offset to each peer + let clock = self.peer_clocks.entry(from.to_string()).or_default(); + clock.process_pong(pong.sender_time_ms, pong.receiver_time_ms, recv_time); + } else { + // Follower uses drift corrector's clock sync + self.drift_corrector.clock_sync_mut().process_pong( + pong.sender_time_ms, + pong.receiver_time_ms, + recv_time, + ); + } + + self.pending_pings.remove(from); + } + + /// Periodic update - call this regularly + pub fn update(&mut self, peer_manager: &mut PeerManager) -> Result<()> { + let now = Instant::now(); + + // Leader broadcasts sync state + if self.is_leader + && now.duration_since(self.last_sync_broadcast).as_millis() as u64 >= SYNC_INTERVAL_MS + { + self.broadcast_sync_state(peer_manager)?; + self.last_sync_broadcast = now; + } + + // Send clock pings periodically + if now.duration_since(self.last_clock_ping).as_millis() as u64 >= CLOCK_PING_INTERVAL_MS { + self.send_clock_pings(peer_manager)?; + self.last_clock_ping = now; + } + + Ok(()) + } + + fn broadcast_sync_state(&mut self, peer_manager: &mut PeerManager) -> Result<()> { + let state = SyncState { + content_id: self.local_sync.content_id.clone(), + position_at_ref_ms: self.local_sync.position_at_ref_ms, + ref_wallclock_ms: self.local_sync.ref_wallclock_ms, + is_playing: self.local_sync.is_playing, + playback_rate: self.local_sync.playback_rate, + }; + + let msg = NetMessage::SyncState(state); + peer_manager.broadcast(&msg) + } + + fn send_clock_pings(&mut self, peer_manager: &mut PeerManager) -> Result<()> { + let now_ms = itk_sync::now_ms(); + let peers: Vec<_> = peer_manager + .peers() + .into_iter() + .map(|p| (p.id(), p.addr)) + .collect(); + + for (peer_id, addr) in peers { + let ping = itk_protocol::ClockPing { + sender_time_ms: now_ms, + }; + self.pending_pings.insert(peer_id, now_ms); + let msg = NetMessage::ClockPing(ping); + let _ = peer_manager.send_unreliable(addr, &msg); + } + + Ok(()) + } + + /// Send a video command to all peers (leader only) + pub fn send_command( + &mut self, + cmd: VideoCommand, + peer_manager: &mut PeerManager, + ) -> Result<()> { + if !self.is_leader { + return Ok(()); + } + + // Apply locally first + match &cmd { + VideoCommand::Load { url } => { + self.load(url.clone()); + }, + VideoCommand::Play => { + self.play(); + }, + VideoCommand::Pause => { + self.pause(); + }, + VideoCommand::Seek { position_ms } => { + self.seek(*position_ms); + }, + } + + // Broadcast to peers + let msg = NetMessage::VideoCommand(cmd); + peer_manager.broadcast_reliable(&msg) + } + + /// Process a received video command (followers) + pub fn receive_command(&mut self, cmd: VideoCommand) { + if self.is_leader { + return; // Leaders don't take commands + } + + match cmd { + VideoCommand::Load { url } => { + info!(url = %url, "Received load command"); + self.load(url); + }, + VideoCommand::Play => { + info!("Received play command"); + self.play(); + }, + VideoCommand::Pause => { + info!("Received pause command"); + self.pause(); + }, + VideoCommand::Seek { position_ms } => { + info!(position_ms, "Received seek command"); + self.seek(position_ms); + }, + } + } +} diff --git a/core/itk-protocol/Cargo.toml b/core/itk-protocol/Cargo.toml new file mode 100644 index 0000000..3644ae1 --- /dev/null +++ b/core/itk-protocol/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "itk-protocol" +description = "Wire protocol definitions for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[dependencies] +serde = { workspace = true } +bincode = { workspace = true } +thiserror = { workspace = true } +crc32fast = { workspace = true } + +[lints] +workspace = true diff --git a/core/itk-protocol/fuzz/Cargo.toml b/core/itk-protocol/fuzz/Cargo.toml new file mode 100644 index 0000000..2edb1b5 --- /dev/null +++ b/core/itk-protocol/fuzz/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "itk-protocol-fuzz" +version = "0.0.0" +publish = false +edition = "2021" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = "0.4" +arbitrary = { version = "1", features = ["derive"] } + +[dependencies.itk-protocol] +path = ".." + +[[bin]] +name = "fuzz_header" +path = "fuzz_targets/fuzz_header.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "fuzz_decode" +path = "fuzz_targets/fuzz_decode.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "fuzz_screen_rect" +path = "fuzz_targets/fuzz_screen_rect.rs" +test = false +doc = false +bench = false diff --git a/core/itk-protocol/fuzz/fuzz_targets/fuzz_decode.rs b/core/itk-protocol/fuzz/fuzz_targets/fuzz_decode.rs new file mode 100644 index 0000000..ac7b374 --- /dev/null +++ b/core/itk-protocol/fuzz/fuzz_targets/fuzz_decode.rs @@ -0,0 +1,16 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use itk_protocol::{decode, ScreenRect, StateEvent, StateSnapshot}; + +// Fuzz the full decode path with arbitrary bytes +// Goal: Ensure decode never panics and handles all malformed input gracefully +fuzz_target!(|data: &[u8]| { + // Try decoding as various message types + // All should return Ok or Err, never panic + + let _: Result<(_, ScreenRect), _> = decode(data); + let _: Result<(_, StateEvent), _> = decode(data); + let _: Result<(_, StateSnapshot), _> = decode(data); + let _: Result<(_, ()), _> = decode(data); +}); diff --git a/core/itk-protocol/fuzz/fuzz_targets/fuzz_header.rs b/core/itk-protocol/fuzz/fuzz_targets/fuzz_header.rs new file mode 100644 index 0000000..8fdc00e --- /dev/null +++ b/core/itk-protocol/fuzz/fuzz_targets/fuzz_header.rs @@ -0,0 +1,11 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use itk_protocol::Header; + +// Fuzz the header parser with arbitrary bytes +// Goal: Ensure the parser never panics on malformed input +fuzz_target!(|data: &[u8]| { + // This should never panic, only return Ok or Err + let _ = Header::from_bytes(data); +}); diff --git a/core/itk-protocol/fuzz/fuzz_targets/fuzz_screen_rect.rs b/core/itk-protocol/fuzz/fuzz_targets/fuzz_screen_rect.rs new file mode 100644 index 0000000..d71c008 --- /dev/null +++ b/core/itk-protocol/fuzz/fuzz_targets/fuzz_screen_rect.rs @@ -0,0 +1,54 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use arbitrary::Arbitrary; +use itk_protocol::{encode, decode, MessageType, ScreenRect}; + +// Fuzz ScreenRect encoding/decoding with structured input +// Uses Arbitrary to generate valid-ish ScreenRects that test edge cases +#[derive(Arbitrary, Debug)] +struct FuzzScreenRect { + x: f32, + y: f32, + width: f32, + height: f32, + rotation: f32, + visible: bool, +} + +fuzz_target!(|input: FuzzScreenRect| { + let rect = ScreenRect { + x: input.x, + y: input.y, + width: input.width, + height: input.height, + rotation: input.rotation, + visible: input.visible, + }; + + // Encoding should never panic + if let Ok(encoded) = encode(MessageType::ScreenRect, &rect) { + // If encoding succeeded, decoding the same bytes should work + let result: Result<(_, ScreenRect), _> = decode(&encoded); + + // If we encoded valid data, we should be able to decode it + if let Ok((msg_type, decoded)) = result { + assert_eq!(msg_type, MessageType::ScreenRect); + + // NaN != NaN, so we need special handling + if !input.x.is_nan() { + assert_eq!(decoded.x, rect.x); + } + if !input.y.is_nan() { + assert_eq!(decoded.y, rect.y); + } + if !input.width.is_nan() { + assert_eq!(decoded.width, rect.width); + } + if !input.height.is_nan() { + assert_eq!(decoded.height, rect.height); + } + assert_eq!(decoded.visible, rect.visible); + } + } +}); diff --git a/core/itk-protocol/src/lib.rs b/core/itk-protocol/src/lib.rs new file mode 100644 index 0000000..9436cbf --- /dev/null +++ b/core/itk-protocol/src/lib.rs @@ -0,0 +1,612 @@ +//! # ITK Protocol +//! +//! Wire protocol definitions for the Injection Toolkit. +//! +//! This crate defines the message format used for IPC communication between: +//! - Injected DLL/SO and daemon +//! - Daemon and overlay +//! - Daemon and MCP server +//! +//! ## Wire Format +//! +//! ```text +//! ┌─────────┬─────────┬──────────┬─────────────┬─────────┬───────────┐ +//! │ Magic │ Version │ MsgType │ PayloadLen │ CRC32 │ Payload │ +//! │ 4 bytes │ 4 bytes │ 4 bytes │ 4 bytes │ 4 bytes │ N bytes │ +//! │ "ITKP" │ 1 │ enum │ ≤ 1MB │ crc32 │ bincode │ +//! └─────────┴─────────┴──────────┴─────────────┴─────────┴───────────┘ +//! ``` + +use bincode::Options; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +/// Protocol magic bytes: "ITKP" (Injection Toolkit Protocol) +pub const MAGIC: [u8; 4] = *b"ITKP"; + +/// Current protocol version +pub const VERSION: u32 = 1; + +/// Maximum payload size (1 MB) +pub const MAX_PAYLOAD_SIZE: usize = 1024 * 1024; + +/// Header size in bytes +pub const HEADER_SIZE: usize = 20; // 4 + 4 + 4 + 4 + 4 + +/// Protocol errors +#[derive(Error, Debug)] +pub enum ProtocolError { + #[error("invalid magic bytes: expected {expected:?}, got {got:?}")] + InvalidMagic { expected: [u8; 4], got: [u8; 4] }, + + #[error("unsupported protocol version: {0}")] + UnsupportedVersion(u32), + + #[error("payload too large: {size} bytes (max {max})")] + PayloadTooLarge { size: usize, max: usize }, + + #[error("CRC mismatch: expected {expected:#x}, got {got:#x}")] + CrcMismatch { expected: u32, got: u32 }, + + #[error("unknown message type: {0}")] + UnknownMessageType(u32), + + #[error("serialization error: {0}")] + Serialization(#[from] bincode::Error), + + #[error("incomplete header: need {need} bytes, have {have}")] + IncompleteHeader { need: usize, have: usize }, + + #[error("incomplete payload: need {need} bytes, have {have}")] + IncompletePayload { need: usize, have: usize }, +} + +/// Message type identifiers +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[repr(u32)] +pub enum MessageType { + /// Heartbeat/keepalive + Ping = 0, + /// Response to ping + Pong = 1, + + /// Screen rectangle update (from injector to daemon/overlay) + ScreenRect = 10, + /// Window state update + WindowState = 11, + /// Combined overlay update + OverlayUpdate = 12, + + /// Application state snapshot + StateSnapshot = 20, + /// State change event + StateEvent = 21, + /// State query request + StateQuery = 22, + /// State query response + StateResponse = 23, + + /// Multiplayer sync state + SyncState = 30, + /// Clock synchronization ping + ClockPing = 31, + /// Clock synchronization pong + ClockPong = 32, + + // Video playback messages (40-49) + /// Load a video from URL or file path + VideoLoad = 40, + /// Start/resume video playback + VideoPlay = 41, + /// Pause video playback + VideoPause = 42, + /// Seek to a position in the video + VideoSeek = 43, + /// Video state update (position, duration, playing status) + VideoState = 44, + /// Video metadata (dimensions, duration, codec info) + VideoMetadata = 45, + /// Video playback error + VideoError = 46, + + /// Error response + Error = 255, +} + +impl TryFrom for MessageType { + type Error = ProtocolError; + + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(Self::Ping), + 1 => Ok(Self::Pong), + 10 => Ok(Self::ScreenRect), + 11 => Ok(Self::WindowState), + 12 => Ok(Self::OverlayUpdate), + 20 => Ok(Self::StateSnapshot), + 21 => Ok(Self::StateEvent), + 22 => Ok(Self::StateQuery), + 23 => Ok(Self::StateResponse), + 30 => Ok(Self::SyncState), + 31 => Ok(Self::ClockPing), + 32 => Ok(Self::ClockPong), + 40 => Ok(Self::VideoLoad), + 41 => Ok(Self::VideoPlay), + 42 => Ok(Self::VideoPause), + 43 => Ok(Self::VideoSeek), + 44 => Ok(Self::VideoState), + 45 => Ok(Self::VideoMetadata), + 46 => Ok(Self::VideoError), + 255 => Ok(Self::Error), + _ => Err(ProtocolError::UnknownMessageType(value)), + } + } +} + +/// Message header +#[derive(Debug, Clone, Copy)] +pub struct Header { + pub magic: [u8; 4], + pub version: u32, + pub msg_type: MessageType, + pub payload_len: u32, + pub crc32: u32, +} + +impl Header { + /// Parse header from bytes + pub fn from_bytes(bytes: &[u8]) -> Result { + if bytes.len() < HEADER_SIZE { + return Err(ProtocolError::IncompleteHeader { + need: HEADER_SIZE, + have: bytes.len(), + }); + } + + let magic: [u8; 4] = bytes[0..4].try_into().unwrap(); + if magic != MAGIC { + return Err(ProtocolError::InvalidMagic { + expected: MAGIC, + got: magic, + }); + } + + let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap()); + if version != VERSION { + return Err(ProtocolError::UnsupportedVersion(version)); + } + + let msg_type_raw = u32::from_le_bytes(bytes[8..12].try_into().unwrap()); + let msg_type = MessageType::try_from(msg_type_raw)?; + + let payload_len = u32::from_le_bytes(bytes[12..16].try_into().unwrap()); + if payload_len as usize > MAX_PAYLOAD_SIZE { + return Err(ProtocolError::PayloadTooLarge { + size: payload_len as usize, + max: MAX_PAYLOAD_SIZE, + }); + } + + let crc32 = u32::from_le_bytes(bytes[16..20].try_into().unwrap()); + + Ok(Self { + magic, + version, + msg_type, + payload_len, + crc32, + }) + } + + /// Serialize header to bytes + pub fn to_bytes(&self) -> [u8; HEADER_SIZE] { + let mut bytes = [0u8; HEADER_SIZE]; + bytes[0..4].copy_from_slice(&self.magic); + bytes[4..8].copy_from_slice(&self.version.to_le_bytes()); + bytes[8..12].copy_from_slice(&(self.msg_type as u32).to_le_bytes()); + bytes[12..16].copy_from_slice(&self.payload_len.to_le_bytes()); + bytes[16..20].copy_from_slice(&self.crc32.to_le_bytes()); + bytes + } +} + +/// Screen rectangle message +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScreenRect { + /// X coordinate in window pixels + pub x: f32, + /// Y coordinate in window pixels + pub y: f32, + /// Width in window pixels + pub width: f32, + /// Height in window pixels + pub height: f32, + /// Rotation in radians (for perspective correction) + pub rotation: f32, + /// Whether the rect is valid/visible + pub visible: bool, +} + +/// Window state message +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WindowState { + /// Window X position on screen + pub x: i32, + /// Window Y position on screen + pub y: i32, + /// Window width + pub width: u32, + /// Window height + pub height: u32, + /// DPI scaling factor + pub dpi_scale: f32, + /// Whether fullscreen (overlay may not work) + pub is_fullscreen: bool, + /// Whether borderless windowed + pub is_borderless: bool, + /// Whether window is focused + pub is_focused: bool, +} + +/// Combined overlay update +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OverlayUpdate { + /// Screen rectangle for rendering + pub rect: ScreenRect, + /// Window state + pub window: WindowState, + /// Timestamp (monotonic ms) + pub timestamp_ms: u64, +} + +/// Application state snapshot (generic, app-specific fields in `data`) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateSnapshot { + /// Application identifier (e.g., "vrchat", "nms") + pub app_id: String, + /// Snapshot timestamp + pub timestamp_ms: u64, + /// Application-specific state as JSON + pub data: String, +} + +/// State change event +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateEvent { + /// Application identifier + pub app_id: String, + /// Event type (app-specific) + pub event_type: String, + /// Event timestamp + pub timestamp_ms: u64, + /// Event data as JSON + pub data: String, +} + +/// State query request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateQuery { + /// Application identifier + pub app_id: String, + /// Query type (app-specific) + pub query_type: String, + /// Query parameters as JSON + pub params: String, +} + +/// State query response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateResponse { + /// Whether query succeeded + pub success: bool, + /// Response data as JSON (if success) + pub data: Option, + /// Error message (if !success) + pub error: Option, +} + +/// Multiplayer sync state +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SyncState { + /// Content identifier (URL, file hash, etc.) + pub content_id: String, + /// Position at reference time (milliseconds) + pub position_at_ref_ms: u64, + /// Reference wallclock time (milliseconds since epoch) + pub ref_wallclock_ms: u64, + /// Whether currently playing + pub is_playing: bool, + /// Playback rate (1.0 = normal) + pub playback_rate: f64, +} + +/// Clock synchronization ping +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClockPing { + /// Sender's local time (milliseconds) + pub sender_time_ms: u64, +} + +/// Clock synchronization pong +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClockPong { + /// Original sender's time from ping + pub sender_time_ms: u64, + /// Receiver's local time when ping was received + pub receiver_time_ms: u64, +} + +/// Error message +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorMessage { + /// Error code + pub code: u32, + /// Human-readable message + pub message: String, +} + +// ============================================================================= +// Video Playback Messages +// ============================================================================= + +/// Load a video from URL or file path +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoLoad { + /// Video source (file path or URL) + pub source: String, + /// Start position in milliseconds (0 = beginning) + pub start_position_ms: u64, + /// Whether to start playing immediately + pub autoplay: bool, +} + +/// Start or resume video playback +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoPlay { + /// Optional position to start from (None = current position) + pub from_position_ms: Option, +} + +/// Pause video playback +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoPause { + // Empty struct - just a command +} + +/// Seek to a position in the video +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoSeek { + /// Target position in milliseconds + pub position_ms: u64, +} + +/// Video state update (broadcast periodically and on state changes) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoState { + /// Content identifier (URL hash or file path) + pub content_id: String, + /// Current playback position in milliseconds + pub position_ms: u64, + /// Total duration in milliseconds (0 if unknown/live) + pub duration_ms: u64, + /// Whether currently playing + pub is_playing: bool, + /// Whether currently buffering + pub is_buffering: bool, + /// Playback rate (1.0 = normal) + pub playback_rate: f64, + /// Volume (0.0 - 1.0) + pub volume: f32, +} + +/// Video metadata (sent once when video is loaded) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoMetadata { + /// Content identifier (URL hash or file path) + pub content_id: String, + /// Video width in pixels + pub width: u32, + /// Video height in pixels + pub height: u32, + /// Duration in milliseconds (0 if unknown/live) + pub duration_ms: u64, + /// Frames per second (0 if unknown) + pub fps: f32, + /// Codec name (e.g., "h264", "vp9") + pub codec: String, + /// Whether this is a live stream + pub is_live: bool, + /// Human-readable title (if available from metadata) + pub title: Option, +} + +/// Video playback error +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VideoError { + /// Error code + pub code: VideoErrorCode, + /// Human-readable error message + pub message: String, + /// Whether playback can be retried + pub is_recoverable: bool, +} + +/// Video error codes +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[repr(u32)] +pub enum VideoErrorCode { + /// Unknown error + Unknown = 0, + /// Failed to open the source (file not found, network error, etc.) + OpenFailed = 1, + /// No video stream found in the source + NoVideoStream = 2, + /// Codec not supported + UnsupportedCodec = 3, + /// Decode error (corrupted data) + DecodeError = 4, + /// Network error during streaming + NetworkError = 5, + /// Source requires authentication + AuthenticationRequired = 6, + /// Geographic restriction + GeoRestricted = 7, + /// YouTube extraction failed (yt-dlp error) + YoutubeExtractionFailed = 8, + /// YouTube support not enabled + YoutubeNotEnabled = 9, +} + +/// Bincode configuration with size limits to prevent allocation bombs +fn bincode_config() -> impl bincode::Options { + bincode::options() + .with_limit(MAX_PAYLOAD_SIZE as u64) + .with_little_endian() + .with_fixint_encoding() +} + +/// Encode a message to wire format +pub fn encode(msg_type: MessageType, payload: &T) -> Result, ProtocolError> { + let payload_bytes = bincode_config().serialize(payload)?; + + if payload_bytes.len() > MAX_PAYLOAD_SIZE { + return Err(ProtocolError::PayloadTooLarge { + size: payload_bytes.len(), + max: MAX_PAYLOAD_SIZE, + }); + } + + let crc = crc32fast::hash(&payload_bytes); + + let header = Header { + magic: MAGIC, + version: VERSION, + msg_type, + payload_len: payload_bytes.len() as u32, + crc32: crc, + }; + + let mut result = Vec::with_capacity(HEADER_SIZE + payload_bytes.len()); + result.extend_from_slice(&header.to_bytes()); + result.extend_from_slice(&payload_bytes); + + Ok(result) +} + +/// Decode a message from wire format +/// +/// Returns the message type and deserialized payload +pub fn decode Deserialize<'de>>( + bytes: &[u8], +) -> Result<(MessageType, T), ProtocolError> { + let header = Header::from_bytes(bytes)?; + + let payload_start = HEADER_SIZE; + let payload_end = payload_start + header.payload_len as usize; + + if bytes.len() < payload_end { + return Err(ProtocolError::IncompletePayload { + need: header.payload_len as usize, + have: bytes.len() - HEADER_SIZE, + }); + } + + let payload_bytes = &bytes[payload_start..payload_end]; + + // Verify CRC + let computed_crc = crc32fast::hash(payload_bytes); + if computed_crc != header.crc32 { + return Err(ProtocolError::CrcMismatch { + expected: header.crc32, + got: computed_crc, + }); + } + + // Use bincode with size limits to prevent allocation bombs + let payload: T = bincode_config().deserialize(payload_bytes)?; + + Ok((header.msg_type, payload)) +} + +/// Decode only the header (useful for routing without deserializing) +pub fn decode_header(bytes: &[u8]) -> Result { + Header::from_bytes(bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_decode_screen_rect() { + let rect = ScreenRect { + x: 100.0, + y: 200.0, + width: 640.0, + height: 480.0, + rotation: 0.0, + visible: true, + }; + + let encoded = encode(MessageType::ScreenRect, &rect).unwrap(); + let (msg_type, decoded): (_, ScreenRect) = decode(&encoded).unwrap(); + + assert_eq!(msg_type, MessageType::ScreenRect); + assert_eq!(decoded.x, rect.x); + assert_eq!(decoded.y, rect.y); + assert_eq!(decoded.width, rect.width); + assert_eq!(decoded.height, rect.height); + assert_eq!(decoded.visible, rect.visible); + } + + #[test] + fn test_header_roundtrip() { + let header = Header { + magic: MAGIC, + version: VERSION, + msg_type: MessageType::StateSnapshot, + payload_len: 1234, + crc32: 0xDEADBEEF, + }; + + let bytes = header.to_bytes(); + let parsed = Header::from_bytes(&bytes).unwrap(); + + assert_eq!(parsed.magic, header.magic); + assert_eq!(parsed.version, header.version); + assert_eq!(parsed.msg_type, header.msg_type); + assert_eq!(parsed.payload_len, header.payload_len); + assert_eq!(parsed.crc32, header.crc32); + } + + #[test] + fn test_invalid_magic() { + let mut bytes = [0u8; HEADER_SIZE]; + bytes[0..4].copy_from_slice(b"NOPE"); + + let result = Header::from_bytes(&bytes); + assert!(matches!(result, Err(ProtocolError::InvalidMagic { .. }))); + } + + #[test] + fn test_crc_validation() { + let rect = ScreenRect { + x: 100.0, + y: 200.0, + width: 640.0, + height: 480.0, + rotation: 0.0, + visible: true, + }; + + let mut encoded = encode(MessageType::ScreenRect, &rect).unwrap(); + + // Corrupt the payload + if let Some(last) = encoded.last_mut() { + *last ^= 0xFF; + } + + let result: Result<(_, ScreenRect), _> = decode(&encoded); + assert!(matches!(result, Err(ProtocolError::CrcMismatch { .. }))); + } +} diff --git a/core/itk-shmem/Cargo.toml b/core/itk-shmem/Cargo.toml new file mode 100644 index 0000000..5c7c795 --- /dev/null +++ b/core/itk-shmem/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "itk-shmem" +description = "Cross-platform shared memory for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(loom)'] } + +[dependencies] +thiserror = { workspace = true } +cfg-if = { workspace = true } + +[target.'cfg(windows)'.dependencies] +windows = { workspace = true } + +[target.'cfg(unix)'.dependencies] +libc = { workspace = true } +nix = { workspace = true } + +[dev-dependencies] +loom = "0.7" diff --git a/core/itk-shmem/src/lib.rs b/core/itk-shmem/src/lib.rs new file mode 100644 index 0000000..ba4e57b --- /dev/null +++ b/core/itk-shmem/src/lib.rs @@ -0,0 +1,795 @@ +//! # ITK Shared Memory +//! +//! Cross-platform shared memory primitives for the Injection Toolkit. +//! +//! This crate provides: +//! - Platform-agnostic shared memory regions +//! - Seqlock-based lock-free synchronization +//! - Triple-buffered frame transfer +//! +//! ## Platform Support +//! +//! - **Windows**: Named shared memory via `CreateFileMappingW` +//! - **Linux**: POSIX shared memory via `shm_open` + +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use thiserror::Error; + +/// Shared memory errors +#[derive(Error, Debug)] +pub enum ShmemError { + #[error("failed to create shared memory: {0}")] + CreateFailed(String), + + #[error("failed to open shared memory: {0}")] + OpenFailed(String), + + #[error("failed to map shared memory: {0}")] + MapFailed(String), + + #[error("shared memory size mismatch: expected {expected}, got {actual}")] + SizeMismatch { expected: usize, actual: usize }, + + #[error("invalid shared memory name: {0}")] + InvalidName(String), + + #[error("shared memory already exists")] + AlreadyExists, + + #[error("shared memory not found")] + NotFound, + + #[error("platform error: {0}")] + Platform(String), + + #[error("size calculation overflow: dimensions {width}x{height} exceed maximum")] + SizeOverflow { width: u32, height: u32 }, + + #[error("seqlock read contention: writer may have crashed or is holding lock too long")] + SeqlockContention, +} + +/// Result type for shared memory operations +pub type Result = std::result::Result; + +/// Shared memory region handle +pub struct SharedMemory { + ptr: *mut u8, + size: usize, + name: String, + #[cfg(windows)] + handle: windows::Win32::Foundation::HANDLE, + #[cfg(unix)] + fd: std::os::unix::io::RawFd, + #[allow(dead_code)] // Used for cleanup logic, may be utilized in future + owner: bool, +} + +// SAFETY: SharedMemory uses raw pointers but the memory region is process-shared +// and we use atomic operations for synchronization. +unsafe impl Send for SharedMemory {} +unsafe impl Sync for SharedMemory {} + +impl SharedMemory { + /// Create a new shared memory region + /// + /// # Arguments + /// * `name` - Unique identifier for the shared memory + /// * `size` - Size in bytes + /// + /// # Platform Notes + /// - Windows: Name becomes `Global\{name}` or `Local\{name}` + /// - Linux: Name becomes `/itk_{name}` in `/dev/shm` + pub fn create(name: &str, size: usize) -> Result { + Self::create_impl(name, size) + } + + /// Open an existing shared memory region + pub fn open(name: &str, size: usize) -> Result { + Self::open_impl(name, size) + } + + /// Get a raw pointer to the shared memory + /// + /// # Safety + /// Caller must ensure proper synchronization when accessing the memory. + pub fn as_ptr(&self) -> *mut u8 { + self.ptr + } + + /// Get the size of the shared memory region + pub fn size(&self) -> usize { + self.size + } + + /// Get the name of the shared memory region + pub fn name(&self) -> &str { + &self.name + } + + /// Read bytes from the shared memory + /// + /// # Safety + /// Caller must ensure no concurrent writes or use appropriate synchronization. + pub unsafe fn read(&self, offset: usize, buf: &mut [u8]) -> Result<()> { + // Use checked_add to prevent overflow bypassing the bounds check + let end = offset + .checked_add(buf.len()) + .ok_or(ShmemError::Platform("offset + length overflow".into()))?; + + if end > self.size { + return Err(ShmemError::SizeMismatch { + expected: end, + actual: self.size, + }); + } + + unsafe { + std::ptr::copy_nonoverlapping(self.ptr.add(offset), buf.as_mut_ptr(), buf.len()); + } + + Ok(()) + } + + /// Write bytes to the shared memory + /// + /// # Safety + /// Caller must ensure no concurrent reads or use appropriate synchronization. + pub unsafe fn write(&self, offset: usize, buf: &[u8]) -> Result<()> { + // Use checked_add to prevent overflow bypassing the bounds check + let end = offset + .checked_add(buf.len()) + .ok_or(ShmemError::Platform("offset + length overflow".into()))?; + + if end > self.size { + return Err(ShmemError::SizeMismatch { + expected: end, + actual: self.size, + }); + } + + unsafe { + std::ptr::copy_nonoverlapping(buf.as_ptr(), self.ptr.add(offset), buf.len()); + } + + Ok(()) + } +} + +// Platform-specific implementations +cfg_if::cfg_if! { + if #[cfg(windows)] { + mod windows_impl; + + impl SharedMemory { + fn create_impl(name: &str, size: usize) -> Result { + windows_impl::create(name, size) + } + + fn open_impl(name: &str, size: usize) -> Result { + windows_impl::open(name, size) + } + } + + impl Drop for SharedMemory { + fn drop(&mut self) { + windows_impl::cleanup(self); + } + } + } else if #[cfg(unix)] { + mod unix_impl; + + impl SharedMemory { + fn create_impl(name: &str, size: usize) -> Result { + unix_impl::create(name, size) + } + + fn open_impl(name: &str, size: usize) -> Result { + unix_impl::open(name, size) + } + } + + impl Drop for SharedMemory { + fn drop(&mut self) { + unix_impl::cleanup(self); + } + } + } +} + +/// Seqlock header for lock-free synchronization +/// +/// This is placed at the start of the shared memory region. +/// Uses carefully chosen memory orderings for ARM compatibility. +/// +/// # Thread Safety +/// +/// **SINGLE-WRITER ONLY**: This seqlock assumes exactly one writer thread/process. +/// Multiple concurrent writers will corrupt the sequence counter and cause +/// undefined behavior. If you need multiple writers, protect the write path +/// with an external mutex. +/// +/// Multiple readers are safe and supported - the seqlock is designed for +/// one-writer-many-readers scenarios (e.g., frame buffer synchronization). +#[repr(C)] +pub struct SeqlockHeader { + /// Sequence number (odd = write in progress, even = consistent) + pub seq: AtomicU32, + /// Index of the current read buffer (0-2 for triple buffering) + pub read_idx: AtomicU32, + /// Presentation timestamp in milliseconds + pub pts_ms: AtomicU64, + /// Frame width (for validation) + pub frame_width: AtomicU32, + /// Frame height (for validation) + pub frame_height: AtomicU32, + /// Playback state (1 = playing, 0 = paused) + pub is_playing: AtomicU32, + /// Quick hash of content ID for change detection + pub content_id_hash: AtomicU64, + /// Total duration in milliseconds (0 if unknown/live) + pub duration_ms: AtomicU64, + /// Padding to cache line (64 bytes) + _padding: [u8; 12], +} + +impl SeqlockHeader { + /// Size of the header in bytes (cache-line aligned) + pub const SIZE: usize = 64; + + /// Initialize a seqlock header at the given memory location + /// + /// # Safety + /// - The pointer must be valid and aligned for SeqlockHeader. + /// - The returned reference is only valid while the underlying memory is valid. + /// Caller must ensure the memory outlives the returned reference. + pub unsafe fn init<'a>(ptr: *mut u8) -> &'a Self { + let header = ptr as *mut SeqlockHeader; + + unsafe { + // Zero-initialize + std::ptr::write_bytes(header, 0, 1); + + // Set initial values + (*header).seq = AtomicU32::new(0); + (*header).read_idx = AtomicU32::new(0); + (*header).pts_ms = AtomicU64::new(0); + (*header).frame_width = AtomicU32::new(0); + (*header).frame_height = AtomicU32::new(0); + (*header).is_playing = AtomicU32::new(0); + (*header).content_id_hash = AtomicU64::new(0); + (*header).duration_ms = AtomicU64::new(0); + + &*header + } + } + + /// Get a reference to an existing seqlock header + /// + /// # Safety + /// - The pointer must be valid and point to an initialized SeqlockHeader. + /// - The returned reference is only valid while the underlying memory is valid. + /// Caller must ensure the memory outlives the returned reference. + pub unsafe fn from_ptr<'a>(ptr: *mut u8) -> &'a Self { + unsafe { &*(ptr as *const SeqlockHeader) } + } + + /// Begin a write operation (marks sequence as odd) + /// + /// Uses Acquire ordering to prevent subsequent data writes from being + /// reordered before the sequence increment. This ensures readers see + /// the odd sequence before any new data is written. + pub fn begin_write(&self) { + self.seq.fetch_add(1, Ordering::Acquire); + } + + /// End a write operation (marks sequence as even) + /// + /// Uses Release ordering to ensure all data writes are visible + /// before the even sequence number. + pub fn end_write(&self) { + self.seq.fetch_add(1, Ordering::Release); + } + + /// Read the current sequence number with Acquire ordering + fn read_seq_acquire(&self) -> u32 { + self.seq.load(Ordering::Acquire) + } + + /// Check if a write is in progress (sequence is odd) + pub fn is_write_in_progress(&self) -> bool { + self.read_seq_acquire() & 1 != 0 + } + + /// Try to read consistent state + /// + /// Returns None if a write was in progress. + /// Caller should retry in a loop. + /// + /// Memory ordering for ARM compatibility: + /// - Acquire on first seq read synchronizes with writer's Release + /// - Relaxed for data reads (bounded by seqlock) + /// - Acquire fence before second seq check prevents data reads from + /// being reordered past the validation + pub fn try_read(&self) -> Option { + let seq1 = self.read_seq_acquire(); + if seq1 & 1 != 0 { + return None; // Write in progress + } + + // Data reads can be Relaxed - bounded by seqlock fence below + let state = SeqlockState { + read_idx: self.read_idx.load(Ordering::Relaxed), + pts_ms: self.pts_ms.load(Ordering::Relaxed), + frame_width: self.frame_width.load(Ordering::Relaxed), + frame_height: self.frame_height.load(Ordering::Relaxed), + is_playing: self.is_playing.load(Ordering::Relaxed) != 0, + content_id_hash: self.content_id_hash.load(Ordering::Relaxed), + duration_ms: self.duration_ms.load(Ordering::Relaxed), + }; + + // Critical: Prevent data loads from sinking past the sequence check. + // Without this fence, the CPU could reorder seq2 read before data reads, + // causing us to validate against an old sequence while reading new data. + std::sync::atomic::fence(Ordering::Acquire); + + // Relaxed is sufficient here - the fence above provides ordering + let seq2 = self.seq.load(Ordering::Relaxed); + if seq1 != seq2 { + return None; // Write happened during read + } + + Some(state) + } + + /// Read state, spinning until consistent + /// + /// **Warning**: This can spin indefinitely if the writer crashes while holding + /// the seqlock (seq stuck on odd). Use `read_with_timeout` for bounded waiting. + pub fn read_blocking(&self) -> SeqlockState { + loop { + if let Some(state) = self.try_read() { + return state; + } + std::hint::spin_loop(); + } + } + + /// Read state with bounded retries to prevent infinite spinning. + /// + /// If the writer crashes while holding the seqlock (seq stuck on odd), + /// this will return `SeqlockContention` after `max_attempts` iterations + /// instead of spinning forever. + /// + /// # Arguments + /// * `max_attempts` - Maximum number of read attempts before giving up + /// + /// # Recommended Values + /// - 1000 for tight polling loops + /// - 10000 for looser real-time requirements + /// - 100000 for batch processing with tolerance for delays + pub fn read_with_timeout(&self, max_attempts: u32) -> Result { + for _ in 0..max_attempts { + if let Some(state) = self.try_read() { + return Ok(state); + } + std::hint::spin_loop(); + } + Err(ShmemError::SeqlockContention) + } +} + +/// Snapshot of seqlock state +#[derive(Debug, Clone)] +pub struct SeqlockState { + pub read_idx: u32, + pub pts_ms: u64, + pub frame_width: u32, + pub frame_height: u32, + pub is_playing: bool, + pub content_id_hash: u64, + pub duration_ms: u64, +} + +/// Triple-buffered frame storage +/// +/// Layout in shared memory: +/// ```text +/// [SeqlockHeader: 64 bytes] +/// [Buffer 0: width * height * 4 bytes] +/// [Buffer 1: width * height * 4 bytes] +/// [Buffer 2: width * height * 4 bytes] +/// ``` +pub struct FrameBuffer { + shmem: SharedMemory, + frame_size: usize, + width: u32, + height: u32, +} + +impl FrameBuffer { + /// Calculate total shared memory size needed for given frame dimensions + /// + /// Returns an error if the dimensions would cause arithmetic overflow. + pub fn calculate_size(width: u32, height: u32) -> Result { + let frame_size = (width as usize) + .checked_mul(height as usize) + .and_then(|s| s.checked_mul(4)) // RGBA + .ok_or(ShmemError::SizeOverflow { width, height })?; + + SeqlockHeader::SIZE + .checked_add( + frame_size + .checked_mul(3) + .ok_or(ShmemError::SizeOverflow { width, height })?, + ) + .ok_or(ShmemError::SizeOverflow { width, height }) + } + + /// Create a new frame buffer + pub fn create(name: &str, width: u32, height: u32) -> Result { + let frame_size = (width as usize) + .checked_mul(height as usize) + .and_then(|s| s.checked_mul(4)) + .ok_or(ShmemError::SizeOverflow { width, height })?; + let total_size = Self::calculate_size(width, height)?; + + let shmem = SharedMemory::create(name, total_size)?; + + // Initialize the header (Relaxed is fine during init - no readers yet) + unsafe { + let header = SeqlockHeader::init(shmem.as_ptr()); + header.frame_width.store(width, Ordering::Relaxed); + header.frame_height.store(height, Ordering::Relaxed); + } + + Ok(Self { + shmem, + frame_size, + width, + height, + }) + } + + /// Open an existing frame buffer + pub fn open(name: &str, width: u32, height: u32) -> Result { + let frame_size = (width as usize) + .checked_mul(height as usize) + .and_then(|s| s.checked_mul(4)) + .ok_or(ShmemError::SizeOverflow { width, height })?; + let total_size = Self::calculate_size(width, height)?; + + let shmem = SharedMemory::open(name, total_size)?; + + Ok(Self { + shmem, + frame_size, + width, + height, + }) + } + + /// Get the frame width + pub fn width(&self) -> u32 { + self.width + } + + /// Get the frame height + pub fn height(&self) -> u32 { + self.height + } + + /// Get the size of a single frame in bytes + pub fn frame_size(&self) -> usize { + self.frame_size + } + + /// Get the seqlock header + pub fn header(&self) -> &SeqlockHeader { + unsafe { SeqlockHeader::from_ptr(self.shmem.as_ptr()) } + } + + /// Set the duration in milliseconds (metadata, not per-frame). + /// + /// This is written outside the seqlock since it only changes on load. + pub fn set_duration_ms(&self, duration_ms: u64) { + self.header() + .duration_ms + .store(duration_ms, Ordering::Release); + } + + /// Get the duration in milliseconds. + pub fn duration_ms(&self) -> u64 { + self.header().duration_ms.load(Ordering::Acquire) + } + + /// Get a pointer to a specific buffer + fn buffer_ptr(&self, idx: u32) -> *mut u8 { + let offset = SeqlockHeader::SIZE + (idx as usize % 3) * self.frame_size; + unsafe { self.shmem.as_ptr().add(offset) } + } + + /// Write a frame (producer side) + /// + /// # Safety + /// - `data` must be exactly `frame_size` bytes + /// - Only one writer should be active + pub unsafe fn write_frame(&self, data: &[u8], pts_ms: u64, content_id_hash: u64) -> Result<()> { + if data.len() != self.frame_size { + return Err(ShmemError::SizeMismatch { + expected: self.frame_size, + actual: data.len(), + }); + } + + let header = self.header(); + + // Get next write buffer (round-robin) + // Relaxed is fine here - we're the only writer + let current_idx = header.read_idx.load(Ordering::Relaxed); + let write_idx = (current_idx + 1) % 3; + + // Begin critical section - marks seq odd + // Acquire ordering prevents subsequent data writes from floating up + header.begin_write(); + + // Write buffer data inside the seqlock critical section + // This ensures proper ordering on weak memory models (ARM) + let buf_ptr = self.buffer_ptr(write_idx); + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), buf_ptr, data.len()); + } + + // Update header fields + header.read_idx.store(write_idx, Ordering::Relaxed); + header.pts_ms.store(pts_ms, Ordering::Relaxed); + header + .content_id_hash + .store(content_id_hash, Ordering::Relaxed); + + // End critical section - marks seq even + // Release ordering ensures all writes are visible before this + header.end_write(); + + Ok(()) + } + + /// Maximum retry attempts before returning contention error + const MAX_READ_ATTEMPTS: u32 = 10000; + + /// Read the current frame (consumer side) + /// + /// Returns (pts_ms, data_changed) where data_changed indicates + /// if this is a new frame since the last read. + /// + /// This method has bounded retries to prevent infinite spinning if the + /// writer crashes or holds the seqlock for too long. + pub fn read_frame(&self, last_pts: u64, buf: &mut [u8]) -> Result<(u64, bool)> { + if buf.len() != self.frame_size { + return Err(ShmemError::SizeMismatch { + expected: self.frame_size, + actual: buf.len(), + }); + } + + for _ in 0..Self::MAX_READ_ATTEMPTS { + // Use bounded read to prevent spinning on crashed writer + let state = self.header().read_with_timeout(1000)?; + + // Skip copy if same frame + if state.pts_ms == last_pts { + return Ok((state.pts_ms, false)); + } + + // Copy frame data + let buf_ptr = self.buffer_ptr(state.read_idx); + unsafe { + std::ptr::copy_nonoverlapping(buf_ptr, buf.as_mut_ptr(), self.frame_size); + } + + // Verify consistency after copy + if let Some(state2) = self.header().try_read() + && state2.read_idx == state.read_idx + && state2.pts_ms == state.pts_ms + { + return Ok((state.pts_ms, true)); + } + + // State changed during read, retry + std::hint::spin_loop(); + } + + Err(ShmemError::SeqlockContention) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_seqlock_header_size() { + assert_eq!(std::mem::size_of::(), SeqlockHeader::SIZE); + } + + #[test] + fn test_calculate_size() { + // 1280x720 RGBA = 3,686,400 bytes per frame + // Triple buffered + 64 byte header + let size = FrameBuffer::calculate_size(1280, 720).unwrap(); + assert_eq!(size, 64 + (1280 * 720 * 4 * 3)); + + // Test overflow detection with extremely large dimensions + let overflow_result = FrameBuffer::calculate_size(u32::MAX, u32::MAX); + assert!(overflow_result.is_err()); + } + + /// Aligned storage for SeqlockHeader in tests + #[repr(C, align(64))] + struct AlignedHeaderMem([u8; SeqlockHeader::SIZE]); + + #[test] + fn test_seqlock_write_makes_odd_then_even() { + let mut header_mem = AlignedHeaderMem([0u8; SeqlockHeader::SIZE]); + let header = unsafe { SeqlockHeader::init(header_mem.0.as_mut_ptr()) }; + + assert_eq!(header.seq.load(Ordering::SeqCst), 0); + assert!(!header.is_write_in_progress()); + + header.begin_write(); + assert_eq!(header.seq.load(Ordering::SeqCst), 1); + assert!(header.is_write_in_progress()); + + header.end_write(); + assert_eq!(header.seq.load(Ordering::SeqCst), 2); + assert!(!header.is_write_in_progress()); + } + + #[test] + fn test_seqlock_try_read_returns_none_during_write() { + let mut header_mem = AlignedHeaderMem([0u8; SeqlockHeader::SIZE]); + let header = unsafe { SeqlockHeader::init(header_mem.0.as_mut_ptr()) }; + + // Set some initial values + header.read_idx.store(1, Ordering::Relaxed); + header.pts_ms.store(12345, Ordering::Relaxed); + + // Before write - should succeed + let state = header.try_read(); + assert!(state.is_some()); + assert_eq!(state.unwrap().read_idx, 1); + + // During write - should fail + header.begin_write(); + assert!(header.try_read().is_none()); + + // After write - should succeed + header.end_write(); + let state = header.try_read(); + assert!(state.is_some()); + } + + #[test] + fn test_seqlock_detects_concurrent_modification() { + let mut header_mem = AlignedHeaderMem([0u8; SeqlockHeader::SIZE]); + let header = unsafe { SeqlockHeader::init(header_mem.0.as_mut_ptr()) }; + + // Simulate a "torn read" scenario: + // Reader gets seq1, then writer completes a full write cycle + let seq1 = header.seq.load(Ordering::Acquire); + assert_eq!(seq1, 0); + + // Writer does a complete write + header.begin_write(); + header.read_idx.store(42, Ordering::Relaxed); + header.end_write(); + + // seq changed from 0 to 2, so seq1 != seq2 + let seq2 = header.seq.load(Ordering::Acquire); + assert_ne!(seq1, seq2); + } +} + +/// Loom-based concurrency tests for verifying seqlock correctness +/// Run with: RUSTFLAGS="--cfg loom" cargo test --lib loom_tests +#[cfg(all(test, loom))] +mod loom_tests { + use loom::sync::Arc; + use loom::sync::atomic::{AtomicU32, AtomicU64, Ordering}; + use loom::thread; + + /// Simplified seqlock for loom testing + struct LoomSeqlock { + seq: AtomicU32, + value: AtomicU64, + } + + impl LoomSeqlock { + fn new() -> Self { + Self { + seq: AtomicU32::new(0), + value: AtomicU64::new(0), + } + } + + fn write(&self, val: u64) { + // Begin write (odd) - Acquire prevents data writes from floating up + self.seq.fetch_add(1, Ordering::Acquire); + + // Write data + self.value.store(val, Ordering::Relaxed); + + // End write (even) - Release ensures data writes are visible + self.seq.fetch_add(1, Ordering::Release); + } + + fn try_read(&self) -> Option { + let seq1 = self.seq.load(Ordering::Acquire); + if seq1 & 1 != 0 { + return None; // Write in progress + } + + let val = self.value.load(Ordering::Relaxed); + + // Fence prevents data loads from sinking past seq2 check + loom::sync::atomic::fence(Ordering::Acquire); + + let seq2 = self.seq.load(Ordering::Relaxed); + if seq1 != seq2 { + return None; // Write happened during read + } + + Some(val) + } + } + + #[test] + fn loom_seqlock_single_writer_single_reader() { + loom::model(|| { + let lock = Arc::new(LoomSeqlock::new()); + let lock2 = Arc::clone(&lock); + + let writer = thread::spawn(move || { + lock2.write(42); + }); + + // Reader may see 0 (initial) or 42 (written), but never garbage + loop { + if let Some(val) = lock.try_read() { + assert!(val == 0 || val == 42, "Got unexpected value: {}", val); + break; + } + // Retry if write was in progress + loom::thread::yield_now(); + } + + writer.join().unwrap(); + }); + } + + #[test] + fn loom_seqlock_multiple_writes() { + loom::model(|| { + let lock = Arc::new(LoomSeqlock::new()); + let lock2 = Arc::clone(&lock); + + let writer = thread::spawn(move || { + lock2.write(1); + lock2.write(2); + }); + + // Reader should only see valid states: 0, 1, or 2 + let mut last_seen = 0; + for _ in 0..3 { + if let Some(val) = lock.try_read() { + assert!(val == 0 || val == 1 || val == 2, "Invalid value: {}", val); + assert!(val >= last_seen, "Values should not decrease"); + last_seen = val; + } + loom::thread::yield_now(); + } + + writer.join().unwrap(); + }); + } +} diff --git a/core/itk-shmem/src/unix_impl.rs b/core/itk-shmem/src/unix_impl.rs new file mode 100644 index 0000000..1184e60 --- /dev/null +++ b/core/itk-shmem/src/unix_impl.rs @@ -0,0 +1,165 @@ +//! Unix shared memory implementation using POSIX shm_open + +use super::{Result, SharedMemory, ShmemError}; +use nix::fcntl::OFlag; +use nix::sys::mman::{MapFlags, ProtFlags, mmap, munmap, shm_open, shm_unlink}; +use nix::sys::stat::{Mode, fstat}; +use nix::unistd::ftruncate; +use std::ffi::CString; +use std::num::NonZeroUsize; +use std::os::fd::{AsRawFd, IntoRawFd}; + +/// Validate and create the shared memory name (/dev/shm/itk_name on Linux). +/// +/// For security, this function: +/// - Rejects names with path traversal sequences (..) +/// - Rejects absolute paths (starting with /) +/// - Rejects names with path separators +/// - Creates names only in /itk_* namespace +fn make_name(name: &str) -> Result { + // Reject path traversal attempts + if name.contains("..") { + return Err(ShmemError::InvalidName( + "name cannot contain path traversal sequences".into(), + )); + } + + // Reject absolute paths + if name.starts_with('/') { + return Err(ShmemError::InvalidName( + "name cannot be an absolute path".into(), + )); + } + + // Reject names with path separators + if name.contains('/') || name.contains('\\') { + return Err(ShmemError::InvalidName( + "name cannot contain path separators".into(), + )); + } + + Ok(format!("/itk_{}", name)) +} + +pub fn create(name: &str, size: usize) -> Result { + let full_name = make_name(name)?; + let c_name = CString::new(full_name.clone()) + .map_err(|_| ShmemError::InvalidName("Name contains null bytes".into()))?; + + // Try to unlink first in case it exists from a previous crash + let _ = shm_unlink(c_name.as_c_str()); + + // Create shared memory object + let fd = shm_open( + c_name.as_c_str(), + OFlag::O_CREAT | OFlag::O_RDWR | OFlag::O_EXCL, + Mode::S_IRUSR | Mode::S_IWUSR, + ) + .map_err(|e| ShmemError::CreateFailed(e.to_string()))?; + + // Set size + if let Err(e) = ftruncate(&fd, size as i64) { + // OwnedFd will close the fd when dropped, no need to manually close + // (manual close would cause double-close since OwnedFd also closes on drop) + let _ = shm_unlink(c_name.as_c_str()); + return Err(ShmemError::CreateFailed(format!("ftruncate failed: {}", e))); + } + + // Map memory + let ptr = unsafe { + mmap( + None, + NonZeroUsize::new(size) + .ok_or_else(|| ShmemError::CreateFailed("Size is zero".into()))?, + ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, + MapFlags::MAP_SHARED, + &fd, + 0, + ) + .map_err(|e| ShmemError::MapFailed(e.to_string()))? + }; + + // Take ownership of the fd - prevents OwnedFd from closing it on drop + let raw_fd = fd.into_raw_fd(); + + Ok(SharedMemory { + ptr: ptr.as_ptr() as *mut u8, + size, + name: full_name, + fd: raw_fd, + owner: true, + }) +} + +pub fn open(name: &str, size: usize) -> Result { + let full_name = make_name(name)?; + let c_name = CString::new(full_name.clone()) + .map_err(|_| ShmemError::InvalidName("Name contains null bytes".into()))?; + + // Open existing shared memory object + let fd = shm_open(c_name.as_c_str(), OFlag::O_RDWR, Mode::empty()) + .map_err(|e| ShmemError::OpenFailed(e.to_string()))?; + + // Validate the actual size of the shared memory object + let stat = fstat(fd.as_raw_fd()) + .map_err(|e| ShmemError::OpenFailed(format!("fstat failed: {}", e)))?; + let actual_size = stat.st_size as usize; + + if actual_size < size { + // OwnedFd will close the fd when dropped, no need to manually close + return Err(ShmemError::SizeMismatch { + expected: size, + actual: actual_size, + }); + } + + // Map memory + let ptr = unsafe { + mmap( + None, + NonZeroUsize::new(size).ok_or_else(|| ShmemError::OpenFailed("Size is zero".into()))?, + ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, + MapFlags::MAP_SHARED, + &fd, + 0, + ) + .map_err(|e| ShmemError::MapFailed(e.to_string()))? + }; + + // Take ownership of the fd - prevents OwnedFd from closing it on drop + let raw_fd = fd.into_raw_fd(); + + Ok(SharedMemory { + ptr: ptr.as_ptr() as *mut u8, + size, + name: full_name, + fd: raw_fd, + owner: false, + }) +} + +pub fn cleanup(shmem: &mut SharedMemory) { + // Unmap memory + if !shmem.ptr.is_null() { + unsafe { + let ptr = std::ptr::NonNull::new(shmem.ptr as *mut _); + if let Some(ptr) = ptr { + let _ = munmap(ptr, shmem.size); + } + } + shmem.ptr = std::ptr::null_mut(); + } + + // Close file descriptor + if shmem.fd >= 0 { + let _ = nix::unistd::close(shmem.fd); + shmem.fd = -1; + } + + // Unlink if owner + if shmem.owner + && let Ok(c_name) = CString::new(shmem.name.clone()) + { + let _ = shm_unlink(c_name.as_c_str()); + } +} diff --git a/core/itk-shmem/src/windows_impl.rs b/core/itk-shmem/src/windows_impl.rs new file mode 100644 index 0000000..257c5f7 --- /dev/null +++ b/core/itk-shmem/src/windows_impl.rs @@ -0,0 +1,172 @@ +//! Windows shared memory implementation using CreateFileMappingW + +use super::{Result, SharedMemory, ShmemError}; +use std::ffi::OsStr; +use std::os::windows::ffi::OsStrExt; +use windows::Win32::Foundation::{ + CloseHandle, ERROR_ALREADY_EXISTS, GetLastError, HANDLE, INVALID_HANDLE_VALUE, +}; +use windows::Win32::System::Memory::{ + CreateFileMappingW, FILE_MAP_ALL_ACCESS, MEMORY_BASIC_INFORMATION, MEMORY_MAPPED_VIEW_ADDRESS, + MapViewOfFile, OpenFileMappingW, PAGE_READWRITE, UnmapViewOfFile, VirtualQuery, +}; + +/// Convert a Rust string to a Windows wide string +fn to_wide_string(s: &str) -> Vec { + OsStr::new(s).encode_wide().chain(Some(0)).collect() +} + +/// Validate and create the shared memory name (Local\ namespace for same-session access). +/// +/// For security, this function: +/// - Rejects names with path traversal sequences (..) +/// - Rejects names that look like full object paths (starting with Local\ or Global\) +/// - Rejects names with path separators +/// - Creates names only in Local\itk_* namespace +fn make_name(name: &str) -> Result { + // Reject path traversal attempts + if name.contains("..") { + return Err(ShmemError::InvalidName( + "name cannot contain path traversal sequences".into(), + )); + } + + // Reject names that look like they're trying to specify a full object path + if name.starts_with("Local\\") || name.starts_with("Global\\") { + return Err(ShmemError::InvalidName( + "name cannot be a full object path".into(), + )); + } + + // Reject names with path separators + if name.contains('\\') || name.contains('/') { + return Err(ShmemError::InvalidName( + "name cannot contain path separators".into(), + )); + } + + Ok(format!("Local\\itk_{}", name)) +} + +pub fn create(name: &str, size: usize) -> Result { + let full_name = make_name(name)?; + let wide_name = to_wide_string(&full_name); + + unsafe { + // Create file mapping backed by system paging file + // INVALID_HANDLE_VALUE (-1) is required to use the paging file + // HANDLE::default() (0) is NOT valid for this purpose + let handle = CreateFileMappingW( + INVALID_HANDLE_VALUE, // Use system paging file + None, // Default security + PAGE_READWRITE, + (size >> 32) as u32, + size as u32, + windows::core::PCWSTR(wide_name.as_ptr()), + ) + .map_err(|e| ShmemError::CreateFailed(e.to_string()))?; + + if handle.is_invalid() { + return Err(ShmemError::CreateFailed("Invalid handle returned".into())); + } + + // Check if we attached to an existing mapping instead of creating a new one. + // This prevents silently attaching to a stale region with wrong size/semantics. + if GetLastError() == ERROR_ALREADY_EXISTS { + CloseHandle(handle).ok(); + return Err(ShmemError::AlreadyExists); + } + + // Map view + let ptr = MapViewOfFile(handle, FILE_MAP_ALL_ACCESS, 0, 0, size); + + if ptr.Value.is_null() { + CloseHandle(handle).ok(); + return Err(ShmemError::MapFailed("MapViewOfFile returned null".into())); + } + + Ok(SharedMemory { + ptr: ptr.Value as *mut u8, + size, + name: full_name, + handle, + owner: true, + }) + } +} + +pub fn open(name: &str, size: usize) -> Result { + let full_name = make_name(name)?; + let wide_name = to_wide_string(&full_name); + + unsafe { + // Open existing file mapping + let handle = OpenFileMappingW( + FILE_MAP_ALL_ACCESS.0, + false, + windows::core::PCWSTR(wide_name.as_ptr()), + ) + .map_err(|e| ShmemError::OpenFailed(e.to_string()))?; + + if handle.is_invalid() { + return Err(ShmemError::NotFound); + } + + // Map the entire section first (size=0) to query its actual size + let ptr = MapViewOfFile(handle, FILE_MAP_ALL_ACCESS, 0, 0, 0); + + if ptr.Value.is_null() { + CloseHandle(handle).ok(); + return Err(ShmemError::MapFailed("MapViewOfFile returned null".into())); + } + + // Query the actual region size + let mut mbi: MEMORY_BASIC_INFORMATION = std::mem::zeroed(); + let query_result = VirtualQuery( + Some(ptr.Value), + &mut mbi, + std::mem::size_of::(), + ); + + if query_result == 0 { + UnmapViewOfFile(ptr).ok(); + CloseHandle(handle).ok(); + return Err(ShmemError::OpenFailed("VirtualQuery failed".into())); + } + + let actual_size = mbi.RegionSize; + + if actual_size < size { + UnmapViewOfFile(ptr).ok(); + CloseHandle(handle).ok(); + return Err(ShmemError::SizeMismatch { + expected: size, + actual: actual_size, + }); + } + + Ok(SharedMemory { + ptr: ptr.Value as *mut u8, + size, + name: full_name, + handle, + owner: false, + }) + } +} + +pub fn cleanup(shmem: &mut SharedMemory) { + unsafe { + if !shmem.ptr.is_null() { + let _ = UnmapViewOfFile(MEMORY_MAPPED_VIEW_ADDRESS { + Value: shmem.ptr as *mut _, + }); + shmem.ptr = std::ptr::null_mut(); + } + + if !shmem.handle.is_invalid() { + let _ = CloseHandle(shmem.handle); + shmem.handle = HANDLE::default(); + } + } +} diff --git a/core/itk-sync/Cargo.toml b/core/itk-sync/Cargo.toml new file mode 100644 index 0000000..54b348e --- /dev/null +++ b/core/itk-sync/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "itk-sync" +description = "Clock synchronization and drift correction for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[dependencies] +itk-protocol = { path = "../itk-protocol" } +thiserror = { workspace = true } +tracing = { workspace = true } + +[lints] +workspace = true diff --git a/core/itk-sync/src/lib.rs b/core/itk-sync/src/lib.rs new file mode 100644 index 0000000..78edb8a --- /dev/null +++ b/core/itk-sync/src/lib.rs @@ -0,0 +1,446 @@ +//! # ITK Sync +//! +//! Clock synchronization and drift correction for multiplayer scenarios. +//! +//! This crate provides: +//! - NTP-lite clock offset estimation +//! - Reference-point based position calculation +//! - Drift correction with smooth rate adjustment +//! +//! ## Sync Model +//! +//! Rather than constantly syncing "current position", we sync a reference point: +//! +//! ```text +//! SyncState { +//! position_at_ref_ms: u64, // Position at reference time +//! ref_wallclock_ms: u64, // When that position was valid +//! is_playing: bool, +//! playback_rate: f64, // Usually 1.0 +//! } +//! ``` +//! +//! Each client computes current position locally: +//! ```text +//! current_pos = position_at_ref + (now - ref_time) * rate +//! ``` + +use std::collections::VecDeque; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[cfg(test)] +use std::time::Duration; +use thiserror::Error; + +/// Sync errors +#[derive(Error, Debug)] +pub enum SyncError { + #[error("clock offset not yet estimated")] + NoClockOffset, + + #[error("insufficient samples for estimation")] + InsufficientSamples, + + #[error("reference time is in the future")] + FutureReference, + + #[error("time conversion overflow: result would be negative or exceed u64::MAX")] + TimeConversionOverflow, +} + +/// Result type for sync operations +pub type Result = std::result::Result; + +/// Get current time in milliseconds since UNIX epoch +pub fn now_ms() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64 +} + +/// Clock synchronization state +/// +/// Uses NTP-lite algorithm to estimate offset between local and remote clocks. +pub struct ClockSync { + /// Estimated offset: remote_time = local_time + offset + offset_ms: Option, + + /// Recent RTT samples for filtering + samples: VecDeque, + + /// Maximum samples to keep + max_samples: usize, +} + +#[derive(Debug, Clone, Copy)] +struct ClockSample { + offset_ms: i64, + rtt_ms: u64, +} + +impl ClockSync { + /// Create a new clock sync instance + pub fn new() -> Self { + Self { + offset_ms: None, + samples: VecDeque::with_capacity(10), + max_samples: 10, + } + } + + /// Process a ping/pong exchange + /// + /// # Arguments + /// * `send_time_ms` - Local time when ping was sent + /// * `remote_time_ms` - Remote time when pong was created + /// * `recv_time_ms` - Local time when pong was received + pub fn process_pong(&mut self, send_time_ms: u64, remote_time_ms: u64, recv_time_ms: u64) { + let rtt_ms = recv_time_ms.saturating_sub(send_time_ms); + let one_way_ms = rtt_ms / 2; + + // Estimate: at what local time was remote_time_ms? + // Answer: send_time_ms + one_way_ms + // So: offset = remote_time_ms - (send_time_ms + one_way_ms) + let estimated_local_at_remote = send_time_ms + one_way_ms; + let offset_ms = remote_time_ms as i64 - estimated_local_at_remote as i64; + + let sample = ClockSample { offset_ms, rtt_ms }; + + // Add sample + if self.samples.len() >= self.max_samples { + self.samples.pop_front(); + } + self.samples.push_back(sample); + + // Update estimated offset using median of recent samples + // (median is more robust to outliers than mean) + self.update_offset(); + } + + fn update_offset(&mut self) { + if self.samples.is_empty() { + return; + } + + // Use samples with lowest RTT (most reliable) + let mut sorted: Vec<_> = self.samples.iter().collect(); + sorted.sort_by_key(|s| s.rtt_ms); + + // Take median of best half + let best_count = sorted.len().div_ceil(2); + let best: Vec<_> = sorted.into_iter().take(best_count).collect(); + + // Median offset + let mut offsets: Vec<_> = best.iter().map(|s| s.offset_ms).collect(); + offsets.sort(); + + let median_offset = if offsets.len() % 2 == 0 { + (offsets[offsets.len() / 2 - 1] + offsets[offsets.len() / 2]) / 2 + } else { + offsets[offsets.len() / 2] + }; + + self.offset_ms = Some(median_offset); + } + + /// Get estimated clock offset + /// + /// `remote_time = local_time + offset` + pub fn offset_ms(&self) -> Option { + self.offset_ms + } + + /// Convert local time to estimated remote time + /// + /// Returns an error if the result would overflow or underflow. + pub fn local_to_remote(&self, local_ms: u64) -> Result { + let offset = self.offset_ms.ok_or(SyncError::NoClockOffset)?; + + // Use checked arithmetic to handle both positive and negative offsets + if offset >= 0 { + local_ms + .checked_add(offset as u64) + .ok_or(SyncError::TimeConversionOverflow) + } else { + let abs_offset = offset.unsigned_abs(); + local_ms + .checked_sub(abs_offset) + .ok_or(SyncError::TimeConversionOverflow) + } + } + + /// Convert remote time to estimated local time + /// + /// Returns an error if the result would overflow or underflow. + pub fn remote_to_local(&self, remote_ms: u64) -> Result { + let offset = self.offset_ms.ok_or(SyncError::NoClockOffset)?; + + // Use checked arithmetic to handle both positive and negative offsets + if offset >= 0 { + remote_ms + .checked_sub(offset as u64) + .ok_or(SyncError::TimeConversionOverflow) + } else { + let abs_offset = offset.unsigned_abs(); + remote_ms + .checked_add(abs_offset) + .ok_or(SyncError::TimeConversionOverflow) + } + } + + /// Check if clock is synchronized + pub fn is_synced(&self) -> bool { + self.offset_ms.is_some() + } + + /// Clear all samples and reset + pub fn reset(&mut self) { + self.offset_ms = None; + self.samples.clear(); + } +} + +impl Default for ClockSync { + fn default() -> Self { + Self::new() + } +} + +/// Playback synchronization state +#[derive(Debug, Clone)] +pub struct PlaybackSync { + /// Content identifier (URL, file hash, etc.) + pub content_id: String, + + /// Position at reference time (milliseconds into content) + pub position_at_ref_ms: u64, + + /// Reference wallclock time (milliseconds since epoch) + pub ref_wallclock_ms: u64, + + /// Whether currently playing + pub is_playing: bool, + + /// Playback rate (1.0 = normal, 0.95-1.05 for drift correction) + pub playback_rate: f64, +} + +impl PlaybackSync { + /// Create a new playback sync state + pub fn new(content_id: String) -> Self { + Self { + content_id, + position_at_ref_ms: 0, + ref_wallclock_ms: now_ms(), + is_playing: false, + playback_rate: 1.0, + } + } + + /// Calculate current position based on reference point + pub fn current_position_ms(&self) -> u64 { + if !self.is_playing { + return self.position_at_ref_ms; + } + + let elapsed = now_ms().saturating_sub(self.ref_wallclock_ms); + let adjusted_elapsed = (elapsed as f64 * self.playback_rate) as u64; + + // Use saturating_add to prevent overflow on large values + self.position_at_ref_ms.saturating_add(adjusted_elapsed) + } + + /// Update from a received sync state (e.g., from network) + pub fn update_from(&mut self, other: &PlaybackSync) { + self.content_id = other.content_id.clone(); + self.position_at_ref_ms = other.position_at_ref_ms; + self.ref_wallclock_ms = other.ref_wallclock_ms; + self.is_playing = other.is_playing; + // Don't copy playback_rate - that's local drift correction + } + + /// Set playing state + pub fn set_playing(&mut self, playing: bool) { + // When changing state, update reference point to current position + self.position_at_ref_ms = self.current_position_ms(); + self.ref_wallclock_ms = now_ms(); + self.is_playing = playing; + } + + /// Seek to a position + pub fn seek(&mut self, position_ms: u64) { + self.position_at_ref_ms = position_ms; + self.ref_wallclock_ms = now_ms(); + } +} + +/// Drift correction calculator +pub struct DriftCorrector { + /// Target position (from sync leader) + target_sync: Option, + + /// Clock sync for time conversion + clock_sync: ClockSync, +} + +impl DriftCorrector { + /// Create a new drift corrector + pub fn new() -> Self { + Self { + target_sync: None, + clock_sync: ClockSync::new(), + } + } + + /// Update target sync state from leader + pub fn update_target(&mut self, sync: PlaybackSync) { + self.target_sync = Some(sync); + } + + /// Get the clock sync instance for processing pongs + pub fn clock_sync_mut(&mut self) -> &mut ClockSync { + &mut self.clock_sync + } + + /// Calculate recommended playback rate to correct drift + /// + /// Returns the rate adjustment (1.0 = no change) + pub fn calculate_rate(&self, current_position_ms: u64) -> f64 { + let Some(target) = &self.target_sync else { + return 1.0; + }; + + if !target.is_playing { + return 1.0; + } + + // Calculate target position, correcting for clock offset between local and remote + let target_position = self.target_position_ms(target); + + // Calculate drift (positive = we're ahead, negative = we're behind) + let drift_ms = current_position_ms as i64 - target_position as i64; + + // Apply correction based on drift magnitude + match drift_ms.abs() { + 0..=150 => 1.0, // Within tolerance, no correction + 151..=500 => { + // Gentle correction + if drift_ms > 0 { 0.98 } else { 1.02 } + }, + 501..=1500 => { + // Moderate correction + if drift_ms > 0 { 0.95 } else { 1.05 } + }, + _ => { + // Large drift - recommend hard seek instead + 0.0 + }, + } + } + + /// Check if a hard seek is recommended (drift too large) + pub fn should_seek(&self, current_position_ms: u64) -> Option { + let Some(target) = &self.target_sync else { + return None; + }; + + let target_position = self.target_position_ms(target); + let drift_ms = (current_position_ms as i64 - target_position as i64).abs(); + + if drift_ms > 1500 { + Some(target_position) + } else { + None + } + } + + /// Get current drift in milliseconds (positive = ahead, negative = behind) + pub fn current_drift_ms(&self, current_position_ms: u64) -> Option { + let target = self.target_sync.as_ref()?; + let target_position = self.target_position_ms(target); + Some(current_position_ms as i64 - target_position as i64) + } + + /// Calculate target position corrected for clock offset. + /// + /// The target's `ref_wallclock_ms` is in the remote clock's time domain. + /// We convert it to local time using the estimated clock offset before + /// computing elapsed time since the reference point. + fn target_position_ms(&self, target: &PlaybackSync) -> u64 { + if !target.is_playing { + return target.position_at_ref_ms; + } + + // Convert remote reference wallclock to local time using clock offset. + // If no offset is available yet, fall back to using the remote time directly + // (which may be slightly inaccurate but avoids blocking on sync). + let local_ref_time = match self.clock_sync.remote_to_local(target.ref_wallclock_ms) { + Ok(local_time) => local_time, + Err(_) => target.ref_wallclock_ms, + }; + + let elapsed = now_ms().saturating_sub(local_ref_time); + let adjusted_elapsed = (elapsed as f64 * target.playback_rate) as u64; + + target.position_at_ref_ms.saturating_add(adjusted_elapsed) + } +} + +impl Default for DriftCorrector { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_clock_sync_basic() { + let mut sync = ClockSync::new(); + + // Simulate ping/pong with 100ms RTT and 50ms offset + // local send: 1000, remote: 1050, local recv: 1100 + sync.process_pong(1000, 1050, 1100); + + // Offset should be approximately 0 (50 - 50 = 0) + // local_at_remote = 1000 + 50 = 1050 + // offset = 1050 - 1050 = 0 + assert!(sync.is_synced()); + assert_eq!(sync.offset_ms(), Some(0)); + } + + #[test] + fn test_playback_sync_position() { + let mut sync = PlaybackSync::new("test".to_string()); + sync.position_at_ref_ms = 10000; + sync.ref_wallclock_ms = now_ms() - 1000; // 1 second ago + sync.is_playing = true; + sync.playback_rate = 1.0; + + // Should be approximately 11000 (10000 + 1000) + let pos = sync.current_position_ms(); + assert!((10900..=11100).contains(&pos)); + } + + #[test] + fn test_drift_correction_rates() { + let corrector = DriftCorrector::new(); + + // No target = no correction + assert_eq!(corrector.calculate_rate(5000), 1.0); + } + + #[test] + fn test_playback_sync_paused() { + let mut sync = PlaybackSync::new("test".to_string()); + sync.position_at_ref_ms = 10000; + sync.is_playing = false; + + // Paused - position doesn't change + std::thread::sleep(Duration::from_millis(100)); + assert_eq!(sync.current_position_ms(), 10000); + } +} diff --git a/core/itk-video/Cargo.toml b/core/itk-video/Cargo.toml new file mode 100644 index 0000000..915fb6b --- /dev/null +++ b/core/itk-video/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "itk-video" +description = "Video decoding and frame management for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[features] +default = [] +youtube = ["tokio/process"] +# Build ffmpeg from source (slow, but works without system ffmpeg) +build-ffmpeg = ["ffmpeg-next/build"] + +[dependencies] +# Internal dependencies +itk-protocol = { path = "../itk-protocol" } +itk-shmem = { path = "../itk-shmem" } +itk-sync = { path = "../itk-sync" } + +# Video decoding +ffmpeg-next = { workspace = true } + +# Async runtime +tokio = { workspace = true } + +# Error handling +thiserror = { workspace = true } +anyhow = { workspace = true } + +# Logging +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } + +[lints] +workspace = true diff --git a/core/itk-video/src/decoder.rs b/core/itk-video/src/decoder.rs new file mode 100644 index 0000000..8262dae --- /dev/null +++ b/core/itk-video/src/decoder.rs @@ -0,0 +1,361 @@ +//! Video decoder using ffmpeg-next. + +use crate::error::{VideoError, VideoResult}; +use crate::hwaccel::HwDeviceContext; +use crate::scaler::FrameScaler; +use crate::stream::StreamSource; +use crate::{DEFAULT_HEIGHT, DEFAULT_WIDTH}; +use ffmpeg_next::format::context::Input as FormatContext; +use ffmpeg_next::media::Type as MediaType; +use ffmpeg_next::util::frame::video::Video as VideoFrame; +use ffmpeg_next::{Codec, Packet, Rational, codec, decoder}; +use std::sync::Once; +use tracing::{debug, info, warn}; + +/// Initialize ffmpeg (called once). +static FFMPEG_INIT: Once = Once::new(); + +fn init_ffmpeg() { + FFMPEG_INIT.call_once(|| { + ffmpeg_next::init().expect("failed to initialize ffmpeg"); + info!("ffmpeg initialized"); + }); +} + +/// A decoded video frame with metadata. +#[derive(Debug)] +pub struct DecodedFrame { + /// The presentation timestamp in milliseconds. + pub pts_ms: u64, + /// The frame data in RGBA format. + pub data: Vec, + /// Frame width. + pub width: u32, + /// Frame height. + pub height: u32, +} + +/// Video decoder that reads frames from a stream source. +pub struct VideoDecoder { + format_ctx: FormatContext, + decoder: decoder::Video, + video_stream_index: usize, + time_base: Rational, + scaler: FrameScaler, + /// Duration in milliseconds, if known. + duration_ms: Option, + /// Frames per second, if known. + fps: Option, + /// Reusable frame buffer. + frame: VideoFrame, + /// Reusable packet. + _packet: Packet, + /// Hardware device context (kept alive for the decoder's lifetime). + _hw_ctx: Option, + /// Whether hardware acceleration is active. + hw_accel_active: bool, +} + +impl VideoDecoder { + /// Create a new decoder for the given source. + pub fn new(source: StreamSource) -> VideoResult { + Self::with_size(source, DEFAULT_WIDTH, DEFAULT_HEIGHT) + } + + /// Create a new decoder with custom output dimensions. + pub fn with_size(source: StreamSource, width: u32, height: u32) -> VideoResult { + init_ffmpeg(); + + // Handle YouTube URLs + #[allow(unused_mut)] + let mut actual_source = source.clone(); + + if source.is_youtube() { + #[cfg(feature = "youtube")] + { + // YouTube extraction happens asynchronously, so we need to use block_on + // In practice, the caller should handle this before calling new() + return Err(VideoError::YoutubeExtraction( + "call youtube::extract_url() first, then pass the direct URL".to_string(), + )); + } + #[cfg(not(feature = "youtube"))] + { + return Err(VideoError::YoutubeNotEnabled); + } + } + + let input_path = actual_source.as_ffmpeg_input(); + debug!(path = %input_path, "opening video source"); + + // Open the input format context + let format_ctx = ffmpeg_next::format::input(&input_path) + .map_err(|e| VideoError::OpenFailed(format!("{}: {}", input_path, e)))?; + + // Find the best video stream + let video_stream = format_ctx + .streams() + .best(MediaType::Video) + .ok_or(VideoError::NoVideoStream)?; + + let video_stream_index = video_stream.index(); + let time_base = video_stream.time_base(); + + // Get duration if available + let duration_ms = if format_ctx.duration() > 0 { + Some((format_ctx.duration() as u64) / 1000) // Convert from microseconds + } else { + None + }; + + // Calculate FPS + let fps = { + let rate = video_stream.avg_frame_rate(); + if rate.denominator() > 0 { + Some(rate.numerator() as f64 / rate.denominator() as f64) + } else { + None + } + }; + + // Get decoder parameters + let codec_params = video_stream.parameters(); + let codec_id = codec_params.id(); + + // Try D3D11VA hardware acceleration first + let hw_ctx = HwDeviceContext::create_d3d11va(); + let hw_accel_active; + + // Find the decoder - when hw accel is available, use default decoder + // (it will use the hw device context). Otherwise prefer software decoders. + let codec = if hw_ctx.is_some() { + decoder::find(codec_id) + .ok_or_else(|| VideoError::NoDecoder(format!("{:?}", codec_id)))? + } else { + find_software_decoder(codec_id) + .or_else(|| decoder::find(codec_id)) + .ok_or_else(|| VideoError::NoDecoder(format!("{:?}", codec_id)))? + }; + + debug!(codec_name = ?codec.name(), hw_accel = hw_ctx.is_some(), "using decoder"); + + // Create decoder context + let mut decoder_ctx = codec::context::Context::new_with_codec(codec); + decoder_ctx.set_parameters(codec_params)?; + + // Set hardware device context if available + if let Some(ref hw) = hw_ctx { + unsafe { + let raw = decoder_ctx.as_mut_ptr(); + (*raw).hw_device_ctx = hw.new_ref(); + } + hw_accel_active = true; + info!("Hardware acceleration enabled (D3D11VA)"); + } else { + // Software mode: set thread_count for better performance + unsafe { + let raw = decoder_ctx.as_mut_ptr(); + (*raw).thread_count = 0; // auto-detect + } + hw_accel_active = false; + } + + let decoder = decoder_ctx.decoder().video()?; + + info!( + width = decoder.width(), + height = decoder.height(), + fps = ?fps, + duration_ms = ?duration_ms, + codec = ?codec_id, + hw_accel = hw_accel_active, + "video decoder initialized" + ); + + Ok(Self { + format_ctx, + decoder, + video_stream_index, + time_base, + scaler: FrameScaler::with_size(width, height), + duration_ms, + fps, + frame: VideoFrame::empty(), + _packet: Packet::empty(), + _hw_ctx: hw_ctx, + hw_accel_active, + }) + } + + /// Get the video duration in milliseconds, if known. + pub fn duration_ms(&self) -> Option { + self.duration_ms + } + + /// Get the video FPS, if known. + pub fn fps(&self) -> Option { + self.fps + } + + /// Get the source video width. + pub fn source_width(&self) -> u32 { + self.decoder.width() + } + + /// Get the source video height. + pub fn source_height(&self) -> u32 { + self.decoder.height() + } + + /// Get the output width (after scaling). + pub fn output_width(&self) -> u32 { + self.scaler.width() + } + + /// Get the output height (after scaling). + pub fn output_height(&self) -> u32 { + self.scaler.height() + } + + /// Decode and return the next frame. + /// + /// Returns `None` when the end of the stream is reached. + pub fn next_frame(&mut self) -> VideoResult> { + loop { + // Try to receive a decoded frame first + match self.decoder.receive_frame(&mut self.frame) { + Ok(()) => { + // Transfer from GPU to CPU memory if using hardware acceleration + if self.hw_accel_active { + unsafe { + crate::hwaccel::transfer_hw_frame_if_needed(self.frame.as_mut_ptr()); + } + } + + // Calculate PTS in milliseconds + let pts = self.frame.pts().unwrap_or(0); + let pts_ms = self.pts_to_ms(pts); + + // Scale the frame to output resolution + let scaled_data = self.scaler.scale(&self.frame)?; + + return Ok(Some(DecodedFrame { + pts_ms, + data: scaled_data.to_vec(), + width: self.scaler.width(), + height: self.scaler.height(), + })); + }, + Err(ffmpeg_next::Error::Other { errno }) if errno == ffmpeg_next::error::EAGAIN => { + // Need more input, continue to read packets + }, + Err(ffmpeg_next::Error::Eof) => { + // End of stream + return Ok(None); + }, + Err(e) => { + return Err(VideoError::DecodeError(e.to_string())); + }, + } + + // Read the next packet + match self.format_ctx.packets().next() { + Some((stream, packet)) => { + if stream.index() == self.video_stream_index { + // Send packet to decoder + self.decoder.send_packet(&packet)?; + } + // Non-video packets are ignored + }, + None => { + // No more packets, flush the decoder + self.decoder.send_eof()?; + }, + } + } + } + + /// Seek to a position in milliseconds. + pub fn seek(&mut self, position_ms: u64) -> VideoResult<()> { + // format_ctx.seek() with stream_index=-1 expects AV_TIME_BASE units (microseconds) + let timestamp_us = (position_ms as i64) * 1000; + + // Seek to the nearest keyframe before the target + self.format_ctx + .seek(timestamp_us, ..timestamp_us) + .map_err(|e| VideoError::SeekError(e.to_string()))?; + + // Flush the decoder + self.decoder.flush(); + + debug!(position_ms, "seeked to position"); + Ok(()) + } + + /// Convert a PTS value to milliseconds. + fn pts_to_ms(&self, pts: i64) -> u64 { + if pts < 0 { + return 0; + } + + let num = self.time_base.numerator() as i64; + let den = self.time_base.denominator() as i64; + + if den == 0 { + return 0; + } + + // pts * (num / den) * 1000 = pts * num * 1000 / den + ((pts * num * 1000) / den) as u64 + } + + /// Convert milliseconds to PTS value. + fn _ms_to_pts(&self, ms: u64) -> i64 { + let num = self.time_base.numerator() as i64; + let den = self.time_base.denominator() as i64; + + if num == 0 { + return 0; + } + + // ms / 1000 * (den / num) = ms * den / (1000 * num) + ((ms as i64) * den) / (1000 * num) + } +} + +/// Try to find a software decoder for the given codec ID. +/// +/// For codecs like AV1 that often have broken hardware acceleration on Windows, +/// we prefer known software decoders (libdav1d, libaom-av1) over the default +/// decoder which may attempt hardware acceleration and fail. +fn find_software_decoder(codec_id: codec::Id) -> Option { + let software_names: &[&str] = match codec_id { + codec::Id::AV1 => &["libdav1d", "libaom-av1"], + codec::Id::H264 => &["h264"], // native software decoder + codec::Id::HEVC => &["hevc"], + _ => return None, + }; + + for name in software_names { + if let Some(dec) = decoder::find_by_name(name) { + debug!(codec = ?codec_id, decoder = name, "found software decoder"); + return Some(dec); + } + } + + warn!(codec = ?codec_id, "no software decoder found, will try default"); + None +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + + #[test] + fn test_pts_conversion() { + // This test verifies the PTS conversion logic + // With time_base = 1/1000, pts_to_ms should be identity + // We can't easily test this without a real decoder, but the math is straightforward + } +} diff --git a/core/itk-video/src/error.rs b/core/itk-video/src/error.rs new file mode 100644 index 0000000..a310c8c --- /dev/null +++ b/core/itk-video/src/error.rs @@ -0,0 +1,77 @@ +//! Error types for video operations. + +use thiserror::Error; + +/// Result type for video operations. +pub type VideoResult = Result; + +/// Errors that can occur during video operations. +#[derive(Debug, Error)] +pub enum VideoError { + /// Failed to open the video source. + #[error("failed to open video source: {0}")] + OpenFailed(String), + + /// Failed to find a video stream in the source. + #[error("no video stream found in source")] + NoVideoStream, + + /// Failed to find a suitable decoder. + #[error("no decoder found for codec: {0}")] + NoDecoder(String), + + /// Failed to decode a frame. + #[error("decode error: {0}")] + DecodeError(String), + + /// Failed to scale/convert a frame. + #[error("scaling error: {0}")] + ScaleError(String), + + /// Failed to seek in the video. + #[error("seek error: {0}")] + SeekError(String), + + /// End of stream reached. + #[error("end of stream")] + EndOfStream, + + /// Invalid frame data. + #[error("invalid frame data: {0}")] + InvalidFrame(String), + + /// Shared memory error. + #[error("shared memory error: {0}")] + SharedMemory(#[from] itk_shmem::ShmemError), + + /// I/O error. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// FFmpeg error. + #[error("ffmpeg error: {0}")] + Ffmpeg(String), + + /// YouTube extraction error (only with `youtube` feature). + #[cfg(feature = "youtube")] + #[error("YouTube extraction failed: {0}")] + YoutubeExtraction(String), + + /// YouTube feature not enabled. + #[error("YouTube support not enabled (build with --features youtube)")] + YoutubeNotEnabled, + + /// Invalid URL format. + #[error("invalid URL: {0}")] + InvalidUrl(String), + + /// Unsupported pixel format. + #[error("unsupported pixel format: {0:?}")] + UnsupportedPixelFormat(String), +} + +impl From for VideoError { + fn from(err: ffmpeg_next::Error) -> Self { + VideoError::Ffmpeg(err.to_string()) + } +} diff --git a/core/itk-video/src/frame_writer.rs b/core/itk-video/src/frame_writer.rs new file mode 100644 index 0000000..f54b265 --- /dev/null +++ b/core/itk-video/src/frame_writer.rs @@ -0,0 +1,234 @@ +//! Frame writer for shared memory output. + +use crate::decoder::DecodedFrame; +use crate::error::{VideoError, VideoResult}; +use crate::{DEFAULT_HEIGHT, DEFAULT_WIDTH}; +use itk_shmem::FrameBuffer; +use std::hash::{Hash, Hasher}; +use tracing::{debug, info, trace}; + +/// Default shared memory name for video frames. +pub const DEFAULT_SHMEM_NAME: &str = "itk_video_frames"; + +/// Writes decoded video frames to a shared memory buffer. +pub struct FrameWriter { + buffer: FrameBuffer, + last_pts_ms: u64, + content_id_hash: u64, +} + +impl FrameWriter { + /// Create a new frame writer with a new shared memory region. + /// + /// If the shared memory already exists (e.g., a previous daemon left it behind + /// while the injector still holds it open), opens the existing region instead. + /// This allows the daemon to restart and continue writing to the same shmem + /// the injector is reading from. + pub fn create(name: &str, width: u32, height: u32) -> VideoResult { + let buffer = match FrameBuffer::create(name, width, height) { + Ok(buf) => buf, + Err(itk_shmem::ShmemError::AlreadyExists) => { + // Fall back to opening existing region (injector may still hold it) + info!("Shared memory already exists, reusing existing region"); + FrameBuffer::open(name, width, height)? + }, + Err(e) => return Err(e.into()), + }; + Ok(Self { + buffer, + last_pts_ms: 0, + content_id_hash: 0, + }) + } + + /// Create a new frame writer with the default name and 720p resolution. + pub fn create_default() -> VideoResult { + Self::create(DEFAULT_SHMEM_NAME, DEFAULT_WIDTH, DEFAULT_HEIGHT) + } + + /// Open an existing shared memory region for writing. + pub fn open(name: &str, width: u32, height: u32) -> VideoResult { + let buffer = FrameBuffer::open(name, width, height)?; + Ok(Self { + buffer, + last_pts_ms: 0, + content_id_hash: 0, + }) + } + + /// Set the content ID for change detection. + /// + /// This should be called when loading a new video. + /// The hash is used by readers to detect content changes. + pub fn set_content_id(&mut self, content_id: &str) { + self.content_id_hash = hash_content_id(content_id); + debug!(content_id, hash = self.content_id_hash, "content ID set"); + } + + /// Set the total duration in milliseconds. + /// + /// This should be called when loading a new video so readers + /// can display seek bars and time remaining. + pub fn set_duration_ms(&self, duration_ms: u64) { + self.buffer.set_duration_ms(duration_ms); + } + + /// Write a decoded frame to shared memory. + /// + /// Uses frame-skip optimization: if the PTS hasn't changed (e.g., video is paused), + /// the write is skipped to avoid unnecessary memory copies. + /// + /// Returns `true` if the frame was written, `false` if skipped. + pub fn write_frame(&mut self, frame: &DecodedFrame) -> VideoResult { + // Frame-skip optimization: don't write if PTS hasn't changed + if frame.pts_ms == self.last_pts_ms && self.last_pts_ms > 0 { + trace!(pts_ms = frame.pts_ms, "skipping duplicate frame"); + return Ok(false); + } + + // Verify frame dimensions match buffer + let expected_size = (frame.width as usize) * (frame.height as usize) * 4; + if frame.data.len() != expected_size { + return Err(VideoError::InvalidFrame(format!( + "frame size mismatch: expected {} bytes, got {}", + expected_size, + frame.data.len() + ))); + } + + // Write to the shared memory buffer + unsafe { + self.buffer + .write_frame(&frame.data, frame.pts_ms, self.content_id_hash)?; + } + + self.last_pts_ms = frame.pts_ms; + trace!(pts_ms = frame.pts_ms, "wrote frame to shmem"); + Ok(true) + } + + /// Write raw frame data to shared memory. + /// + /// This is a lower-level API for when you have raw RGBA data. + pub fn write_raw(&mut self, data: &[u8], pts_ms: u64) -> VideoResult { + // Frame-skip optimization + if pts_ms == self.last_pts_ms && self.last_pts_ms > 0 { + return Ok(false); + } + + unsafe { + self.buffer + .write_frame(data, pts_ms, self.content_id_hash)?; + } + + self.last_pts_ms = pts_ms; + Ok(true) + } + + /// Get the last written PTS in milliseconds. + pub fn last_pts_ms(&self) -> u64 { + self.last_pts_ms + } + + /// Get the content ID hash. + pub fn content_id_hash(&self) -> u64 { + self.content_id_hash + } + + /// Get the frame buffer's width. + pub fn width(&self) -> u32 { + self.buffer.width() + } + + /// Get the frame buffer's height. + pub fn height(&self) -> u32 { + self.buffer.height() + } +} + +/// Hash a content ID string to a u64. +fn hash_content_id(content_id: &str) -> u64 { + use std::collections::hash_map::DefaultHasher; + let mut hasher = DefaultHasher::new(); + content_id.hash(&mut hasher); + hasher.finish() +} + +/// Frame reader for shared memory input. +/// +/// This is used by the overlay to read frames written by the daemon. +pub struct FrameReader { + buffer: FrameBuffer, + last_pts_ms: u64, + frame_data: Vec, +} + +impl FrameReader { + /// Open an existing shared memory region for reading. + pub fn open(name: &str, width: u32, height: u32) -> VideoResult { + let buffer = FrameBuffer::open(name, width, height)?; + let frame_size = (width as usize) * (height as usize) * 4; + Ok(Self { + buffer, + last_pts_ms: 0, + frame_data: vec![0u8; frame_size], + }) + } + + /// Open with the default name and 720p resolution. + pub fn open_default() -> VideoResult { + Self::open(DEFAULT_SHMEM_NAME, DEFAULT_WIDTH, DEFAULT_HEIGHT) + } + + /// Try to read the latest frame. + /// + /// Returns `Some((pts_ms, data))` if a new frame is available, + /// `None` if the frame hasn't changed since the last read. + pub fn read_frame(&mut self) -> VideoResult> { + let (pts_ms, changed) = self + .buffer + .read_frame(self.last_pts_ms, &mut self.frame_data)?; + + if changed { + self.last_pts_ms = pts_ms; + Ok(Some((pts_ms, &self.frame_data))) + } else { + Ok(None) + } + } + + /// Get the last read PTS in milliseconds. + pub fn last_pts_ms(&self) -> u64 { + self.last_pts_ms + } + + /// Get the frame buffer's width. + pub fn width(&self) -> u32 { + self.buffer.width() + } + + /// Get the frame buffer's height. + pub fn height(&self) -> u32 { + self.buffer.height() + } + + /// Get a reference to the current frame data. + pub fn current_frame(&self) -> &[u8] { + &self.frame_data + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_content_id_hash() { + let hash1 = hash_content_id("test_video.mp4"); + let hash2 = hash_content_id("test_video.mp4"); + let hash3 = hash_content_id("other_video.mp4"); + + assert_eq!(hash1, hash2); + assert_ne!(hash1, hash3); + } +} diff --git a/core/itk-video/src/hwaccel.rs b/core/itk-video/src/hwaccel.rs new file mode 100644 index 0000000..04ae89c --- /dev/null +++ b/core/itk-video/src/hwaccel.rs @@ -0,0 +1,144 @@ +//! Hardware-accelerated video decoding via D3D11VA. +//! +//! This module provides optional GPU-accelerated decoding using Direct3D 11 +//! Video Acceleration. When available (e.g., NVIDIA RTX series, Intel, AMD), +//! this offloads video decode from the CPU to the GPU's dedicated decoder. +//! +//! Falls back gracefully to software decoding if hardware acceleration +//! is not available. + +use ffmpeg_next::ffi; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; +use tracing::{debug, info, trace, warn}; + +/// Manages a D3D11VA hardware device context for accelerated decoding. +/// +/// When attached to a decoder's `AVCodecContext`, ffmpeg will attempt to +/// use D3D11VA hardware acceleration for supported codecs (H.264, HEVC, AV1, VP9). +pub struct HwDeviceContext { + device_ctx: *mut ffi::AVBufferRef, +} + +// Safety: The AVBufferRef is reference-counted and thread-safe in ffmpeg. +// The D3D11 device it wraps is created with default threading settings. +unsafe impl Send for HwDeviceContext {} + +impl HwDeviceContext { + /// Try to create a D3D11VA hardware device context. + /// + /// Returns `None` if D3D11VA is not available on this system. + /// This is expected on systems without a GPU or with incompatible drivers. + pub fn create_d3d11va() -> Option { + unsafe { + let mut device_ctx: *mut ffi::AVBufferRef = ptr::null_mut(); + let ret = ffi::av_hwdevice_ctx_create( + &mut device_ctx, + ffi::AVHWDeviceType::AV_HWDEVICE_TYPE_D3D11VA, + ptr::null(), // Use default device + ptr::null_mut(), // No options + 0, // No flags + ); + + if ret < 0 || device_ctx.is_null() { + debug!(error_code = ret, "D3D11VA device context creation failed"); + return None; + } + + info!("D3D11VA hardware device context created"); + Some(Self { device_ctx }) + } + } + + /// Create a new buffer reference to the device context. + /// + /// The returned pointer is a new reference that the caller owns. + /// It should be assigned to `AVCodecContext.hw_device_ctx`. + /// + /// # Safety + /// The caller must ensure `self` contains a valid device context pointer. + pub unsafe fn new_ref(&self) -> *mut ffi::AVBufferRef { + unsafe { ffi::av_buffer_ref(self.device_ctx) } + } + + /// Get the raw pixel format for D3D11 hardware frames. + pub fn hw_pix_fmt() -> ffi::AVPixelFormat { + ffi::AVPixelFormat::AV_PIX_FMT_D3D11 + } +} + +impl Drop for HwDeviceContext { + fn drop(&mut self) { + unsafe { + if !self.device_ctx.is_null() { + ffi::av_buffer_unref(&mut self.device_ctx); + } + } + } +} + +/// Transfer a hardware frame to a software frame. +/// +/// When the decoder outputs a frame in D3D11 format (GPU memory), +/// this function copies it to system memory (typically NV12 format). +/// +/// Returns `true` if transfer was performed, `false` if the frame +/// was already in software format. +/// +/// # Safety +/// The `frame` pointer must be a valid, non-null AVFrame from the decoder. +/// After this call, if transfer occurred, `frame` contains the software data. +pub unsafe fn transfer_hw_frame_if_needed(frame: *mut ffi::AVFrame) -> bool { + if frame.is_null() { + return false; + } + + // Safety: all operations below involve FFI calls and raw pointer derefs + // that require unsafe. The caller guarantees `frame` is valid and non-null. + unsafe { + let format = (*frame).format; + let hw_fmt = HwDeviceContext::hw_pix_fmt() as i32; + + if format != hw_fmt { + // Frame is already in software format, nothing to do + return false; + } + + // Allocate a software frame to receive the transfer + let sw_frame = ffi::av_frame_alloc(); + if sw_frame.is_null() { + warn!("Failed to allocate software frame for hw transfer"); + return false; + } + + // Transfer from GPU to CPU memory + let ret = ffi::av_hwframe_transfer_data(sw_frame, frame, 0); + if ret < 0 { + warn!( + error_code = ret, + "Failed to transfer hardware frame to software" + ); + ffi::av_frame_free(&mut (sw_frame as *mut _)); + return false; + } + + // Copy metadata (pts, etc.) from the hw frame + (*sw_frame).pts = (*frame).pts; + (*sw_frame).pkt_dts = (*frame).pkt_dts; + (*sw_frame).duration = (*frame).duration; + (*sw_frame).best_effort_timestamp = (*frame).best_effort_timestamp; + + // Move the software frame data into the original frame + ffi::av_frame_unref(frame); + ffi::av_frame_move_ref(frame, sw_frame); + ffi::av_frame_free(&mut (sw_frame as *mut _)); + } + + static FIRST_TRANSFER: AtomicBool = AtomicBool::new(true); + if FIRST_TRANSFER.swap(false, Ordering::Relaxed) { + info!("First hardware frame transferred to software successfully"); + } else { + trace!("Transferred hardware frame to software"); + } + true +} diff --git a/core/itk-video/src/lib.rs b/core/itk-video/src/lib.rs new file mode 100644 index 0000000..ee7dafb --- /dev/null +++ b/core/itk-video/src/lib.rs @@ -0,0 +1,67 @@ +//! Video decoding and frame management for the Injection Toolkit. +//! +//! This crate provides video decoding capabilities using ffmpeg-next, +//! with support for local files, HLS/DASH streams, and optionally YouTube +//! via yt-dlp. +//! +//! # Architecture +//! +//! ```text +//! StreamSource (file/URL) +//! │ +//! ▼ +//! VideoDecoder (ffmpeg-next) +//! │ +//! ▼ +//! Scaler (to 1280x720 RGBA) +//! │ +//! ▼ +//! FrameWriter (to shared memory) +//! ``` +//! +//! # Example +//! +//! ```ignore +//! use itk_video::{VideoDecoder, StreamSource, FrameWriter}; +//! use itk_shmem::FrameBuffer; +//! +//! let source = StreamSource::File("/path/to/video.mp4".into()); +//! let mut decoder = VideoDecoder::new(source)?; +//! let buffer = FrameBuffer::create("video_frames", 1280, 720)?; +//! let mut writer = FrameWriter::new(buffer); +//! +//! while let Some(frame) = decoder.next_frame()? { +//! writer.write_frame(&frame)?; +//! } +//! ``` + +pub mod decoder; +pub mod error; +pub mod frame_writer; +pub mod hwaccel; +pub mod scaler; +pub mod stream; + +#[cfg(feature = "youtube")] +pub mod youtube; + +pub use decoder::{DecodedFrame, VideoDecoder}; +pub use error::{VideoError, VideoResult}; +pub use frame_writer::FrameWriter; +pub use scaler::FrameScaler; +pub use stream::StreamSource; + +/// Default output width for video frames (720p). +pub const DEFAULT_WIDTH: u32 = 1280; + +/// Default output height for video frames (720p). +pub const DEFAULT_HEIGHT: u32 = 720; + +/// Bytes per pixel for RGBA format. +pub const BYTES_PER_PIXEL: usize = 4; + +/// Calculate the size of a single frame in bytes. +#[inline] +pub const fn frame_size(width: u32, height: u32) -> usize { + (width as usize) * (height as usize) * BYTES_PER_PIXEL +} diff --git a/core/itk-video/src/scaler.rs b/core/itk-video/src/scaler.rs new file mode 100644 index 0000000..1fe1874 --- /dev/null +++ b/core/itk-video/src/scaler.rs @@ -0,0 +1,173 @@ +//! Frame scaling and pixel format conversion. + +use crate::error::{VideoError, VideoResult}; +use crate::{BYTES_PER_PIXEL, DEFAULT_HEIGHT, DEFAULT_WIDTH}; +use ffmpeg_next::format::Pixel; +use ffmpeg_next::software::scaling::{Context as SwsContext, Flags}; +use ffmpeg_next::util::frame::video::Video as VideoFrame; + +/// Handles scaling and converting video frames to RGBA at a target resolution. +pub struct FrameScaler { + context: Option, + output_width: u32, + output_height: u32, + /// Reusable output buffer to avoid allocations. + output_buffer: Vec, + /// Reusable output frame. + output_frame: VideoFrame, + /// Whether the output frame buffer has been allocated. + frame_allocated: bool, + /// Last input dimensions/format for detecting changes. + last_input_width: u32, + last_input_height: u32, + last_input_format: Pixel, +} + +impl FrameScaler { + /// Create a new scaler with the default 720p output resolution. + pub fn new() -> Self { + Self::with_size(DEFAULT_WIDTH, DEFAULT_HEIGHT) + } + + /// Create a new scaler with a custom output resolution. + pub fn with_size(width: u32, height: u32) -> Self { + let buffer_size = (width as usize) * (height as usize) * BYTES_PER_PIXEL; + let mut output_frame = VideoFrame::empty(); + output_frame.set_format(Pixel::RGBA); + output_frame.set_width(width); + output_frame.set_height(height); + + Self { + context: None, + output_width: width, + output_height: height, + output_buffer: vec![0u8; buffer_size], + output_frame, + frame_allocated: false, + last_input_width: 0, + last_input_height: 0, + last_input_format: Pixel::None, + } + } + + /// Get the output width. + pub fn width(&self) -> u32 { + self.output_width + } + + /// Get the output height. + pub fn height(&self) -> u32 { + self.output_height + } + + /// Scale and convert a frame to RGBA. + /// + /// Returns a reference to the internal buffer containing the RGBA data. + /// The buffer is valid until the next call to `scale`. + pub fn scale(&mut self, input: &VideoFrame) -> VideoResult<&[u8]> { + let input_format = input.format(); + let input_width = input.width(); + let input_height = input.height(); + + // Create or recreate the scaling context if input parameters changed + let needs_new_context = self.context.is_some() + && (self.last_input_width != input_width + || self.last_input_height != input_height + || self.last_input_format != input_format); + + if needs_new_context || self.context.is_none() { + self.context = Some( + SwsContext::get( + input_format, + input_width, + input_height, + Pixel::RGBA, + self.output_width, + self.output_height, + Flags::BILINEAR, + ) + .map_err(|e| VideoError::ScaleError(e.to_string()))?, + ); + self.last_input_width = input_width; + self.last_input_height = input_height; + self.last_input_format = input_format; + } + + let ctx = self.context.as_mut().unwrap(); + + // Allocate output frame buffer if needed + if !self.frame_allocated { + unsafe { + let ret = ffmpeg_next::ffi::av_frame_get_buffer( + self.output_frame.as_mut_ptr(), + 32, // alignment + ); + if ret < 0 { + return Err(VideoError::ScaleError( + "failed to allocate output frame buffer".to_string(), + )); + } + } + self.frame_allocated = true; + } + + // Run the scaling operation + ctx.run(input, &mut self.output_frame) + .map_err(|e| VideoError::ScaleError(e.to_string()))?; + + // Copy the frame data to our output buffer + // RGBA frames have a single plane + let data = self.output_frame.data(0); + let linesize = self.output_frame.stride(0); + let row_bytes = (self.output_width as usize) * BYTES_PER_PIXEL; + + // Handle potential padding in frame rows + if linesize == row_bytes { + // No padding, direct copy + let copy_size = row_bytes * (self.output_height as usize); + self.output_buffer[..copy_size].copy_from_slice(&data[..copy_size]); + } else { + // Row-by-row copy to handle padding + for y in 0..self.output_height as usize { + let src_offset = y * linesize; + let dst_offset = y * row_bytes; + self.output_buffer[dst_offset..dst_offset + row_bytes] + .copy_from_slice(&data[src_offset..src_offset + row_bytes]); + } + } + + Ok(&self.output_buffer) + } + + /// Get the expected output buffer size in bytes. + pub fn buffer_size(&self) -> usize { + (self.output_width as usize) * (self.output_height as usize) * BYTES_PER_PIXEL + } +} + +impl Default for FrameScaler { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_scaler_dimensions() { + let scaler = FrameScaler::new(); + assert_eq!(scaler.width(), DEFAULT_WIDTH); + assert_eq!(scaler.height(), DEFAULT_HEIGHT); + assert_eq!(scaler.buffer_size(), 1280 * 720 * 4); + } + + #[test] + fn test_custom_dimensions() { + let scaler = FrameScaler::with_size(640, 480); + assert_eq!(scaler.width(), 640); + assert_eq!(scaler.height(), 480); + assert_eq!(scaler.buffer_size(), 640 * 480 * 4); + } +} diff --git a/core/itk-video/src/stream.rs b/core/itk-video/src/stream.rs new file mode 100644 index 0000000..5e0a3db --- /dev/null +++ b/core/itk-video/src/stream.rs @@ -0,0 +1,113 @@ +//! Video stream source types. + +use std::path::PathBuf; + +/// Represents a video source that can be decoded. +#[derive(Debug, Clone)] +pub enum StreamSource { + /// Local file path. + File(PathBuf), + + /// Remote URL (HTTP/HTTPS, including HLS/DASH manifests). + Url(String), + + /// Separate video and audio URLs (e.g., YouTube DASH streams). + UrlWithAudio { video: String, audio: String }, +} + +impl StreamSource { + /// Create a stream source from a string, auto-detecting the type. + /// + /// - Strings starting with `http://` or `https://` are treated as URLs. + /// - Strings starting with `file://` have the prefix stripped and are treated as files. + /// - Everything else is treated as a local file path. + pub fn from_string(s: &str) -> Self { + let trimmed = s.trim(); + + if trimmed.starts_with("http://") || trimmed.starts_with("https://") { + StreamSource::Url(trimmed.to_string()) + } else if let Some(path) = trimmed.strip_prefix("file://") { + StreamSource::File(PathBuf::from(path)) + } else { + StreamSource::File(PathBuf::from(trimmed)) + } + } + + /// Check if this source is a YouTube URL. + pub fn is_youtube(&self) -> bool { + match self { + StreamSource::Url(url) | StreamSource::UrlWithAudio { video: url, .. } => { + url.contains("youtube.com") || url.contains("youtu.be") + }, + StreamSource::File(_) => false, + } + } + + /// Get the path/URL as a string for ffmpeg (video stream). + pub fn as_ffmpeg_input(&self) -> &str { + match self { + StreamSource::File(path) => path.to_str().unwrap_or(""), + StreamSource::Url(url) | StreamSource::UrlWithAudio { video: url, .. } => url.as_str(), + } + } + + /// Get the audio URL if this is a split video+audio source. + pub fn audio_url(&self) -> Option<&str> { + match self { + StreamSource::UrlWithAudio { audio, .. } => Some(audio.as_str()), + _ => None, + } + } +} + +impl From for StreamSource { + fn from(path: PathBuf) -> Self { + StreamSource::File(path) + } +} + +impl From for StreamSource { + fn from(s: String) -> Self { + StreamSource::from_string(&s) + } +} + +impl From<&str> for StreamSource { + fn from(s: &str) -> Self { + StreamSource::from_string(s) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_auto_detect_file() { + let source = StreamSource::from_string("/path/to/video.mp4"); + assert!(matches!(source, StreamSource::File(_))); + } + + #[test] + fn test_auto_detect_url() { + let source = StreamSource::from_string("https://example.com/stream.m3u8"); + assert!(matches!(source, StreamSource::Url(_))); + } + + #[test] + fn test_auto_detect_file_uri() { + let source = StreamSource::from_string("file:///path/to/video.mp4"); + match source { + StreamSource::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/video.mp4"), + _ => panic!("expected File variant"), + } + } + + #[test] + fn test_is_youtube() { + assert!(StreamSource::from_string("https://www.youtube.com/watch?v=abc").is_youtube()); + assert!(StreamSource::from_string("https://youtu.be/abc").is_youtube()); + assert!(!StreamSource::from_string("https://example.com/video.mp4").is_youtube()); + assert!(!StreamSource::from_string("/path/to/video.mp4").is_youtube()); + } +} diff --git a/core/itk-video/src/youtube.rs b/core/itk-video/src/youtube.rs new file mode 100644 index 0000000..0c2c49c --- /dev/null +++ b/core/itk-video/src/youtube.rs @@ -0,0 +1,141 @@ +//! YouTube URL extraction using yt-dlp. +//! +//! This module is only available with the `youtube` feature flag. +//! +//! # Requirements +//! +//! - `yt-dlp` must be installed and available in PATH +//! - Works best with a recent version of yt-dlp +//! +//! # Example +//! +//! ```ignore +//! use itk_video::youtube; +//! +//! let direct_url = youtube::extract_url("https://www.youtube.com/watch?v=dQw4w9WgXcQ").await?; +//! // Now use direct_url with VideoDecoder +//! ``` + +use crate::error::{VideoError, VideoResult}; +use tokio::process::Command; +use tracing::{debug, info, warn}; + +/// Maximum resolution to request (720p). +const MAX_HEIGHT: u32 = 720; + +/// Extract a direct video URL from a YouTube link using yt-dlp. +/// +/// This function shells out to yt-dlp to get the actual video URL that +/// can be fed to ffmpeg. It requests the best quality up to 720p. +/// +/// # Arguments +/// +/// * `youtube_url` - A YouTube URL (youtube.com or youtu.be) +/// +/// # Returns +/// +/// The direct video URL that can be used with ffmpeg. +pub async fn extract_url(youtube_url: &str) -> VideoResult { + info!(url = %youtube_url, "extracting direct URL via yt-dlp"); + + // Build the yt-dlp command + let output = Command::new("yt-dlp") + .args([ + // Format selection: best video+audio up to 720p + "-f", + &format!( + "bestvideo[height<={}]+bestaudio/best[height<={}]", + MAX_HEIGHT, MAX_HEIGHT + ), + // Get URL only, don't download + "-g", + // No warnings (cleaner output) + "--no-warnings", + // No playlist (single video only) + "--no-playlist", + // The URL + youtube_url, + ]) + .output() + .await + .map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + VideoError::YoutubeExtraction( + "yt-dlp not found in PATH. Install it: pip install yt-dlp".to_string(), + ) + } else { + VideoError::YoutubeExtraction(format!("failed to run yt-dlp: {}", e)) + } + })?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(VideoError::YoutubeExtraction(format!( + "yt-dlp failed: {}", + stderr.trim() + ))); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let lines: Vec<&str> = stdout.trim().lines().collect(); + + // yt-dlp returns one URL per stream (video, audio) + // We need to return the video URL for ffmpeg + // If there are two lines, the first is video, second is audio + // ffmpeg can handle both via concat protocol + let url = if lines.is_empty() { + return Err(VideoError::YoutubeExtraction( + "yt-dlp returned no URLs".to_string(), + )); + } else if lines.len() == 1 { + // Single combined stream + lines[0].to_string() + } else { + // Multiple streams - use the first (video) for now + // TODO: Support merging video+audio streams + warn!( + "yt-dlp returned {} URLs, using first (video only)", + lines.len() + ); + lines[0].to_string() + }; + + debug!(extracted_url = %url, "extracted direct URL"); + Ok(url) +} + +/// Check if yt-dlp is available in PATH. +pub async fn is_available() -> bool { + Command::new("yt-dlp") + .arg("--version") + .output() + .await + .is_ok_and(|o| o.status.success()) +} + +/// Get the version of yt-dlp if available. +pub async fn version() -> Option { + let output = Command::new("yt-dlp") + .arg("--version") + .output() + .await + .ok()?; + + if output.status.success() { + Some(String::from_utf8_lossy(&output.stdout).trim().to_string()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_is_available() { + // This test just checks the function runs without panic + // Actual availability depends on the system + let _ = is_available().await; + } +} diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml new file mode 100644 index 0000000..d22db28 --- /dev/null +++ b/daemon/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "itk-daemon" +description = "Daemon template for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[dependencies] +itk-protocol = { path = "../core/itk-protocol" } +itk-shmem = { path = "../core/itk-shmem" } +itk-ipc = { path = "../core/itk-ipc" } +itk-sync = { path = "../core/itk-sync" } + +serde = { workspace = true } +serde_json = "1.0" +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } + +[lints] +workspace = true diff --git a/daemon/src/main.rs b/daemon/src/main.rs new file mode 100644 index 0000000..0520375 --- /dev/null +++ b/daemon/src/main.rs @@ -0,0 +1,595 @@ +//! # ITK Daemon Template +//! +//! Central coordinator daemon for an Injection Toolkit project. +//! +//! This template provides: +//! - IPC server for injected modules +//! - State aggregation and caching +//! - Optional multiplayer synchronization +//! - Configurable logging +//! +//! ## Customization Points +//! +//! 1. `StateHandler` trait - implement for your application-specific state +//! 2. `process_injector_message` - handle messages from injected code +//! 3. `process_client_message` - handle messages from overlay/MCP clients +//! +//! ## Security +//! +//! IMPORTANT: Data from the injector should be treated as UNTRUSTED. +//! A compromised or malicious injector could send crafted messages. +//! All incoming data is validated before use. + +use anyhow::{Context, Result, bail}; +use itk_ipc::{IpcChannel, IpcServer}; +use itk_protocol::{ + MessageType, ScreenRect, StateEvent, StateQuery, StateResponse, StateSnapshot, VideoLoad, + VideoPause, VideoPlay, VideoSeek, decode, encode, +}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::thread; +use tracing::{debug, error, info, warn}; + +mod video; +use video::VideoPlayer; + +// ============================================================================= +// Security: Input Validation +// ============================================================================= + +/// Maximum allowed string length for event types and keys +const MAX_STRING_LEN: usize = 256; + +/// Maximum allowed JSON data size +const MAX_DATA_SIZE: usize = 64 * 1024; // 64 KB + +/// Maximum screen dimension (sanity check) +const MAX_SCREEN_DIM: f32 = 16384.0; + +/// Validate a ScreenRect from untrusted source +fn validate_screen_rect(rect: &ScreenRect) -> Result<()> { + // Check for NaN/Inf which could cause issues + if !rect.x.is_finite() + || !rect.y.is_finite() + || !rect.width.is_finite() + || !rect.height.is_finite() + || !rect.rotation.is_finite() + { + bail!("ScreenRect contains non-finite values"); + } + + // Sanity check dimensions + if rect.x.abs() > MAX_SCREEN_DIM + || rect.y.abs() > MAX_SCREEN_DIM + || rect.width > MAX_SCREEN_DIM + || rect.height > MAX_SCREEN_DIM + { + bail!("ScreenRect dimensions out of bounds"); + } + + // Width/height should be non-negative + if rect.width < 0.0 || rect.height < 0.0 { + bail!("ScreenRect has negative dimensions"); + } + + // Check for coordinate overflow (could crash GPU/wgpu) + let right = rect.x + rect.width; + let bottom = rect.y + rect.height; + if !right.is_finite() + || !bottom.is_finite() + || right.abs() > MAX_SCREEN_DIM + || bottom.abs() > MAX_SCREEN_DIM + { + bail!( + "ScreenRect coordinate overflow: right={}, bottom={}", + right, + bottom + ); + } + + Ok(()) +} + +/// Validate a StateEvent from untrusted source +fn validate_state_event(event: &StateEvent) -> Result<()> { + if event.event_type.len() > MAX_STRING_LEN { + bail!( + "StateEvent event_type too long: {} bytes", + event.event_type.len() + ); + } + if event.data.len() > MAX_DATA_SIZE { + bail!("StateEvent data too large: {} bytes", event.data.len()); + } + if event.app_id.len() > MAX_STRING_LEN { + bail!("StateEvent app_id too long: {} bytes", event.app_id.len()); + } + Ok(()) +} + +/// Validate a StateSnapshot from untrusted source +fn validate_state_snapshot(snapshot: &StateSnapshot) -> Result<()> { + if snapshot.app_id.len() > MAX_STRING_LEN { + bail!( + "StateSnapshot app_id too long: {} bytes", + snapshot.app_id.len() + ); + } + if snapshot.data.len() > MAX_DATA_SIZE { + bail!( + "StateSnapshot data too large: {} bytes", + snapshot.data.len() + ); + } + Ok(()) +} + +/// Application state container +/// +/// Customize this for your specific application. +#[derive(Default)] +pub struct AppState { + /// Screen rect from injector (for overlay positioning) + pub screen_rect: Option, + + /// Custom state data (JSON format for flexibility) + pub custom_data: HashMap, + + /// Last update timestamp + pub last_update_ms: u64, + + /// Video player (lazy initialized) + pub video_player: Option, +} + +/// Daemon configuration +#[derive(Debug, Clone)] +pub struct DaemonConfig { + /// Application identifier (e.g., "nms", "vrchat") + pub app_id: String, + + /// IPC channel name for injector communication + pub injector_channel: String, + + /// IPC channel name for client (overlay/MCP) communication + pub client_channel: String, + + /// Enable multiplayer sync + pub enable_sync: bool, +} + +impl Default for DaemonConfig { + fn default() -> Self { + Self { + app_id: "itk_app".to_string(), + injector_channel: "itk_injector".to_string(), + client_channel: "itk_client".to_string(), + enable_sync: false, + } + } +} + +/// Main daemon struct +pub struct Daemon { + config: DaemonConfig, + state: Arc>, +} + +impl Daemon { + /// Create a new daemon instance + pub fn new(config: DaemonConfig) -> Self { + Self { + config, + state: Arc::new(RwLock::new(AppState::default())), + } + } + + /// Run the daemon + pub fn run(&self) -> Result<()> { + info!( + app_id = %self.config.app_id, + "Starting ITK daemon" + ); + + // Start injector listener thread + let injector_state = Arc::clone(&self.state); + let injector_channel = self.config.injector_channel.clone(); + let injector_handle = thread::spawn(move || { + if let Err(e) = run_injector_listener(&injector_channel, injector_state) { + error!(?e, "Injector listener failed"); + } + }); + + // Start client listener thread + let client_state = Arc::clone(&self.state); + let client_channel = self.config.client_channel.clone(); + let app_id = self.config.app_id.clone(); + let client_handle = thread::spawn(move || { + if let Err(e) = run_client_listener(&client_channel, client_state, &app_id) { + error!(?e, "Client listener failed"); + } + }); + + info!("Daemon running. Press Ctrl+C to stop."); + + // Wait for threads + let _ = injector_handle.join(); + let _ = client_handle.join(); + + Ok(()) + } +} + +/// Run the injector IPC listener +fn run_injector_listener(channel_name: &str, state: Arc>) -> Result<()> { + info!(channel = %channel_name, "Starting injector listener"); + + let server = itk_ipc::listen(channel_name).context("Failed to create injector IPC server")?; + + loop { + info!("Waiting for injector connection..."); + + match server.accept() { + Ok(channel) => { + info!("Injector connected"); + handle_injector_connection(channel, Arc::clone(&state)); + }, + Err(e) => { + warn!(?e, "Failed to accept injector connection"); + thread::sleep(std::time::Duration::from_secs(1)); + }, + } + } +} + +/// Handle a connected injector +fn handle_injector_connection(channel: impl IpcChannel, state: Arc>) { + loop { + match channel.recv() { + Ok(data) => { + if let Err(e) = process_injector_message(&data, &state) { + warn!(?e, "Failed to process injector message"); + } + }, + Err(itk_ipc::IpcError::ChannelClosed) => { + info!("Injector disconnected"); + break; + }, + Err(e) => { + warn!(?e, "Error receiving from injector"); + break; + }, + } + } +} + +/// Process a message from the injector +/// +/// SECURITY: All data from the injector is treated as UNTRUSTED and validated. +/// Customize this function for your application's specific message types. +fn process_injector_message(data: &[u8], state: &Arc>) -> Result<()> { + let header = itk_protocol::decode_header(data)?; + + match header.msg_type { + MessageType::ScreenRect => { + let (_, rect): (_, ScreenRect) = decode(data)?; + // SECURITY: Validate before use + validate_screen_rect(&rect)?; + let mut state = state.write().unwrap(); + state.screen_rect = Some(rect); + state.last_update_ms = itk_sync::now_ms(); + }, + + MessageType::StateEvent => { + let (_, event): (_, StateEvent) = decode(data)?; + // SECURITY: Validate before use + validate_state_event(&event)?; + let mut state = state.write().unwrap(); + state.custom_data.insert(event.event_type, event.data); + state.last_update_ms = event.timestamp_ms; + }, + + MessageType::StateSnapshot => { + let (_, snapshot): (_, StateSnapshot) = decode(data)?; + // SECURITY: Validate before use + validate_state_snapshot(&snapshot)?; + let mut state = state.write().unwrap(); + state + .custom_data + .insert("snapshot".to_string(), snapshot.data); + state.last_update_ms = snapshot.timestamp_ms; + }, + + other => { + warn!(?other, "Unexpected message type from injector"); + }, + } + + Ok(()) +} + +/// Run the client (overlay/MCP) IPC listener +fn run_client_listener( + channel_name: &str, + state: Arc>, + app_id: &str, +) -> Result<()> { + info!(channel = %channel_name, "Starting client listener"); + + let server = itk_ipc::listen(channel_name).context("Failed to create client IPC server")?; + + loop { + info!("Waiting for client connection..."); + + match server.accept() { + Ok(channel) => { + info!("Client connected"); + handle_client_connection(channel, Arc::clone(&state), app_id); + }, + Err(e) => { + warn!(?e, "Failed to accept client connection"); + thread::sleep(std::time::Duration::from_secs(1)); + }, + } + } +} + +/// Handle a connected client +fn handle_client_connection(channel: impl IpcChannel, state: Arc>, app_id: &str) { + loop { + match channel.recv() { + Ok(data) => { + if let Err(e) = process_client_message(&data, &state, &channel, app_id) { + warn!(?e, "Failed to process client message"); + } + }, + Err(itk_ipc::IpcError::ChannelClosed) => { + info!("Client disconnected"); + break; + }, + Err(e) => { + warn!(?e, "Error receiving from client"); + break; + }, + } + } +} + +/// Process a message from a client (overlay or MCP) +/// +/// Customize this function for your application's specific queries. +fn process_client_message( + data: &[u8], + state: &Arc>, + channel: &impl IpcChannel, + app_id: &str, +) -> Result<()> { + let header = itk_protocol::decode_header(data)?; + + match header.msg_type { + MessageType::Ping => { + // Respond with pong + let pong = encode(MessageType::Pong, &())?; + channel.send(&pong)?; + }, + + MessageType::StateQuery => { + let (_, query): (_, StateQuery) = decode(data)?; + let response = handle_state_query(&query, state, app_id)?; + let encoded = encode(MessageType::StateResponse, &response)?; + channel.send(&encoded)?; + }, + + // Video playback commands + MessageType::VideoLoad => { + let (_, cmd): (_, VideoLoad) = decode(data)?; + debug!(source = %cmd.source, "Video load command"); + handle_video_load(state, &cmd); + }, + + MessageType::VideoPlay => { + let (_, _cmd): (_, VideoPlay) = decode(data)?; + debug!("Video play command"); + handle_video_play(state); + }, + + MessageType::VideoPause => { + let (_, _cmd): (_, VideoPause) = decode(data)?; + debug!("Video pause command"); + handle_video_pause(state); + }, + + MessageType::VideoSeek => { + let (_, cmd): (_, VideoSeek) = decode(data)?; + debug!(position_ms = cmd.position_ms, "Video seek command"); + handle_video_seek(state, cmd.position_ms); + }, + + other => { + warn!(?other, "Unexpected message type from client"); + }, + } + + Ok(()) +} + +/// Handle a state query from a client +fn handle_state_query( + query: &StateQuery, + state: &Arc>, + app_id: &str, +) -> Result { + let state = state.read().unwrap(); + + let response = match query.query_type.as_str() { + "screen_rect" => { + if let Some(ref rect) = state.screen_rect { + StateResponse { + success: true, + data: Some(serde_json::to_string(rect)?), + error: None, + } + } else { + StateResponse { + success: false, + data: None, + error: Some("No screen rect available".to_string()), + } + } + }, + + "snapshot" => { + let snapshot = StateSnapshot { + app_id: app_id.to_string(), + timestamp_ms: state.last_update_ms, + data: serde_json::to_string(&state.custom_data)?, + }; + StateResponse { + success: true, + data: Some(serde_json::to_string(&snapshot)?), + error: None, + } + }, + + "custom" => { + // Query for specific custom data key + if let Some(value) = state.custom_data.get(&query.params) { + StateResponse { + success: true, + data: Some(value.clone()), + error: None, + } + } else { + StateResponse { + success: false, + data: None, + error: Some(format!("Key not found: {}", query.params)), + } + } + }, + + "video_state" => { + // Query for current video playback state + if let Some(ref player) = state.video_player { + if let Some(video_state) = player.get_video_state() { + StateResponse { + success: true, + data: Some(serde_json::to_string(&video_state)?), + error: None, + } + } else { + StateResponse { + success: false, + data: None, + error: Some("No video loaded".to_string()), + } + } + } else { + StateResponse { + success: false, + data: None, + error: Some("Video player not initialized".to_string()), + } + } + }, + + "video_metadata" => { + // Query for video metadata + if let Some(ref player) = state.video_player { + if let Some(metadata) = player.get_metadata() { + StateResponse { + success: true, + data: Some(serde_json::to_string(&metadata)?), + error: None, + } + } else { + StateResponse { + success: false, + data: None, + error: Some("No video loaded".to_string()), + } + } + } else { + StateResponse { + success: false, + data: None, + error: Some("Video player not initialized".to_string()), + } + } + }, + + _ => StateResponse { + success: false, + data: None, + error: Some(format!("Unknown query type: {}", query.query_type)), + }, + }; + + Ok(response) +} + +// ============================================================================= +// Video Playback Handlers +// ============================================================================= + +/// Ensure the video player is initialized +fn ensure_video_player(state: &Arc>) { + let mut state = state.write().unwrap(); + if state.video_player.is_none() { + info!("Initializing video player"); + state.video_player = Some(VideoPlayer::new()); + } +} + +/// Handle a video load command +fn handle_video_load(state: &Arc>, cmd: &VideoLoad) { + ensure_video_player(state); + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.load(&cmd.source, cmd.start_position_ms, cmd.autoplay); + } +} + +/// Handle a video play command +fn handle_video_play(state: &Arc>) { + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.play(); + } +} + +/// Handle a video pause command +fn handle_video_pause(state: &Arc>) { + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.pause(); + } +} + +/// Handle a video seek command +fn handle_video_seek(state: &Arc>, position_ms: u64) { + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.seek(position_ms); + } +} + +fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("itk_daemon=info".parse().unwrap()), + ) + .init(); + + // Load config (in a real app, this would come from CLI args or a config file) + let config = DaemonConfig::default(); + + // Create and run daemon + let daemon = Daemon::new(config); + daemon.run()?; + + Ok(()) +} diff --git a/daemon/src/video/mod.rs b/daemon/src/video/mod.rs new file mode 100644 index 0000000..83e2c6b --- /dev/null +++ b/daemon/src/video/mod.rs @@ -0,0 +1,13 @@ +//! Video playback subsystem for the daemon. +//! +//! This module handles video decoding, frame output to shared memory, +//! and playback state management. + +mod player; +mod state; + +pub use player::VideoPlayer; + +// Re-export state types for external use +#[allow(unused_imports)] +pub use state::{PlayerCommand, PlayerState, VideoInfo}; diff --git a/daemon/src/video/player.rs b/daemon/src/video/player.rs new file mode 100644 index 0000000..b31699c --- /dev/null +++ b/daemon/src/video/player.rs @@ -0,0 +1,388 @@ +//! Video player implementation. + +use super::state::{PlayerCommand, PlayerState, VideoInfo}; +use itk_protocol::{VideoMetadata, VideoState}; +use itk_shmem::FrameBuffer; +use std::sync::mpsc::{self, Receiver, Sender}; +use std::sync::{Arc, Mutex}; +use std::thread::{self, JoinHandle}; +use std::time::{Duration, Instant}; +use tracing::{debug, error, info, warn}; + +/// Default output width for video frames. +const DEFAULT_WIDTH: u32 = 1280; +/// Default output height for video frames. +const DEFAULT_HEIGHT: u32 = 720; +/// Shared memory name for video frames. +const SHMEM_NAME: &str = "itk_video_frames"; + +/// Video player that decodes video and writes frames to shared memory. +pub struct VideoPlayer { + /// Current player state. + state: Arc>, + /// Command sender for the decode thread. + command_tx: Sender, + /// Handle to the decode thread. + decode_thread: Option>, + /// Shared memory frame buffer (used when ffmpeg decoding is enabled). + #[allow(dead_code)] + frame_buffer: Option, +} + +impl VideoPlayer { + /// Create a new video player. + pub fn new() -> Self { + let (command_tx, command_rx) = mpsc::channel(); + let state = Arc::new(Mutex::new(PlayerState::Idle)); + + // Try to create the shared memory frame buffer + let frame_buffer = match FrameBuffer::create(SHMEM_NAME, DEFAULT_WIDTH, DEFAULT_HEIGHT) { + Ok(fb) => { + info!( + width = DEFAULT_WIDTH, + height = DEFAULT_HEIGHT, + "Created video frame buffer" + ); + Some(fb) + }, + Err(e) => { + warn!(?e, "Failed to create frame buffer, video output disabled"); + None + }, + }; + + // Start the decode thread + let state_clone = Arc::clone(&state); + let decode_thread = thread::spawn(move || { + decode_loop(state_clone, command_rx); + }); + + Self { + state, + command_tx, + decode_thread: Some(decode_thread), + frame_buffer, + } + } + + /// Send a command to the video player. + pub fn send_command(&self, cmd: PlayerCommand) { + if let Err(e) = self.command_tx.send(cmd) { + error!(?e, "Failed to send command to video player"); + } + } + + /// Load a video from a source. + pub fn load(&self, source: &str, start_position_ms: u64, autoplay: bool) { + self.send_command(PlayerCommand::Load { + source: source.to_string(), + start_position_ms, + autoplay, + }); + } + + /// Start or resume playback. + pub fn play(&self) { + self.send_command(PlayerCommand::Play); + } + + /// Pause playback. + pub fn pause(&self) { + self.send_command(PlayerCommand::Pause); + } + + /// Seek to a position. + pub fn seek(&self, position_ms: u64) { + self.send_command(PlayerCommand::Seek { position_ms }); + } + + /// Stop playback and unload. + pub fn stop(&self) { + self.send_command(PlayerCommand::Stop); + } + + /// Get the current player state. + pub fn state(&self) -> PlayerState { + self.state.lock().unwrap().clone() + } + + /// Get the current video state for protocol messages. + pub fn get_video_state(&self) -> Option { + let state = self.state.lock().unwrap(); + match &*state { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => Some(VideoState { + content_id: info.content_id.clone(), + position_ms: state.position_ms(), + duration_ms: info.duration_ms, + is_playing: state.is_playing(), + is_buffering: matches!(*state, PlayerState::Buffering { .. }), + playback_rate: info.playback_rate, + volume: info.volume, + }), + _ => None, + } + } + + /// Get video metadata for protocol messages. + pub fn get_metadata(&self) -> Option { + let state = self.state.lock().unwrap(); + state.video_info().map(|info| VideoMetadata { + content_id: info.content_id.clone(), + width: info.width, + height: info.height, + duration_ms: info.duration_ms, + fps: info.fps, + codec: info.codec.clone(), + is_live: info.is_live, + title: info.title.clone(), + }) + } +} + +impl Default for VideoPlayer { + fn default() -> Self { + Self::new() + } +} + +impl Drop for VideoPlayer { + fn drop(&mut self) { + // Signal the decode thread to stop + let _ = self.command_tx.send(PlayerCommand::Stop); + + // Wait for the thread to finish + if let Some(handle) = self.decode_thread.take() { + let _ = handle.join(); + } + } +} + +/// Main decode loop that runs in a separate thread. +fn decode_loop(state: Arc>, command_rx: Receiver) { + info!("Video decode thread started"); + + loop { + // Wait for a command with timeout to allow periodic state checks + match command_rx.recv_timeout(Duration::from_millis(16)) { + Ok(cmd) => { + debug!(?cmd, "Received video command"); + match cmd { + PlayerCommand::Load { + source, + start_position_ms, + autoplay, + } => { + handle_load(&state, &source, start_position_ms, autoplay); + }, + PlayerCommand::Play => { + handle_play(&state); + }, + PlayerCommand::Pause => { + handle_pause(&state); + }, + PlayerCommand::Seek { position_ms } => { + handle_seek(&state, position_ms); + }, + PlayerCommand::SetRate { rate } => { + handle_set_rate(&state, rate); + }, + PlayerCommand::SetVolume { volume } => { + handle_set_volume(&state, volume); + }, + PlayerCommand::Stop => { + info!("Video decode thread stopping"); + *state.lock().unwrap() = PlayerState::Idle; + break; + }, + } + }, + Err(mpsc::RecvTimeoutError::Timeout) => { + // Check if we need to decode more frames + let current_state = state.lock().unwrap().clone(); + if let PlayerState::Playing { .. } = current_state { + // In a real implementation, this would decode and output frames + // For now, just update the position based on elapsed time + } + }, + Err(mpsc::RecvTimeoutError::Disconnected) => { + info!("Command channel disconnected, stopping decode thread"); + break; + }, + } + } +} + +/// Handle a load command. +fn handle_load( + state: &Arc>, + source: &str, + start_position_ms: u64, + autoplay: bool, +) { + info!(source = %source, start_ms = start_position_ms, autoplay, "Loading video"); + + // Set loading state + *state.lock().unwrap() = PlayerState::Loading { + source: source.to_string(), + }; + + // In a real implementation, this would: + // 1. Initialize ffmpeg decoder for the source + // 2. Extract metadata (duration, codec, fps) + // 3. Seek to start_position_ms + // 4. Start decoding if autoplay is true + + // For now, create a mock video info + let content_id = format!("{:016x}", hash_string(source)); + let info = VideoInfo { + content_id, + width: DEFAULT_WIDTH, + height: DEFAULT_HEIGHT, + duration_ms: 0, // Unknown duration for mock + fps: 30.0, + codec: "mock".to_string(), + is_live: source.contains(".m3u8") || source.contains("/live/"), + title: None, + playback_rate: 1.0, + volume: 1.0, + }; + + if autoplay { + *state.lock().unwrap() = PlayerState::Playing { + info, + position_ms: start_position_ms, + started_at: Instant::now(), + }; + } else { + *state.lock().unwrap() = PlayerState::Paused { + info, + position_ms: start_position_ms, + }; + } +} + +/// Handle a play command. +fn handle_play(state: &Arc>) { + let mut state = state.lock().unwrap(); + if let PlayerState::Paused { info, position_ms } = state.clone() { + info!(position_ms, "Resuming playback"); + *state = PlayerState::Playing { + info, + position_ms, + started_at: Instant::now(), + }; + } +} + +/// Handle a pause command. +fn handle_pause(state: &Arc>) { + let mut state = state.lock().unwrap(); + if let PlayerState::Playing { + info, + position_ms, + started_at, + } = state.clone() + { + let current_pos = position_ms.saturating_add(started_at.elapsed().as_millis() as u64); + info!(position_ms = current_pos, "Pausing playback"); + *state = PlayerState::Paused { + info, + position_ms: current_pos, + }; + } +} + +/// Handle a seek command. +fn handle_seek(state: &Arc>, position_ms: u64) { + let mut state = state.lock().unwrap(); + match state.clone() { + PlayerState::Playing { info, .. } => { + info!(position_ms, "Seeking (playing)"); + *state = PlayerState::Playing { + info, + position_ms, + started_at: Instant::now(), + }; + }, + PlayerState::Paused { info, .. } => { + info!(position_ms, "Seeking (paused)"); + *state = PlayerState::Paused { info, position_ms }; + }, + _ => { + warn!("Seek ignored - no video loaded"); + }, + } +} + +/// Handle a set rate command. +fn handle_set_rate(state: &Arc>, rate: f64) { + let mut state = state.lock().unwrap(); + if let Some(info) = state.video_info().cloned() { + let mut new_info = info; + new_info.playback_rate = rate.clamp(0.25, 4.0); + debug!(rate = new_info.playback_rate, "Set playback rate"); + + // Update the info in the current state + match &mut *state { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => { + *info = new_info; + }, + _ => {}, + } + } +} + +/// Handle a set volume command. +fn handle_set_volume(state: &Arc>, volume: f32) { + let mut state = state.lock().unwrap(); + if let Some(info) = state.video_info().cloned() { + let mut new_info = info; + new_info.volume = volume.clamp(0.0, 1.0); + debug!(volume = new_info.volume, "Set volume"); + + // Update the info in the current state + match &mut *state { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => { + *info = new_info; + }, + _ => {}, + } + } +} + +/// Simple hash function for content IDs. +fn hash_string(s: &str) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + s.hash(&mut hasher); + hasher.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_player_creation() { + // This test verifies the player can be created + // Frame buffer creation may fail without proper permissions + let player = VideoPlayer::new(); + assert!(matches!(player.state(), PlayerState::Idle)); + } + + #[test] + fn test_hash_string() { + let hash1 = hash_string("test"); + let hash2 = hash_string("test"); + let hash3 = hash_string("other"); + assert_eq!(hash1, hash2); + assert_ne!(hash1, hash3); + } +} diff --git a/daemon/src/video/state.rs b/daemon/src/video/state.rs new file mode 100644 index 0000000..9d3b5d4 --- /dev/null +++ b/daemon/src/video/state.rs @@ -0,0 +1,144 @@ +//! Video player state and commands. + +use std::time::Instant; + +/// Commands that can be sent to the video player. +#[derive(Debug, Clone)] +pub enum PlayerCommand { + /// Load a video from a URL or file path. + Load { + source: String, + start_position_ms: u64, + autoplay: bool, + }, + /// Start or resume playback. + Play, + /// Pause playback. + Pause, + /// Seek to a position in milliseconds. + Seek { position_ms: u64 }, + /// Set the playback rate (1.0 = normal). + SetRate { rate: f64 }, + /// Set volume (0.0 - 1.0). + SetVolume { volume: f32 }, + /// Stop playback and unload the video. + Stop, +} + +/// Video player state. +#[derive(Debug, Clone)] +pub enum PlayerState { + /// No video loaded. + Idle, + /// Loading a video. + Loading { source: String }, + /// Video is playing. + Playing { + info: VideoInfo, + position_ms: u64, + started_at: Instant, + }, + /// Video is paused. + Paused { info: VideoInfo, position_ms: u64 }, + /// Buffering (waiting for data). + Buffering { + info: VideoInfo, + target_position_ms: u64, + }, + /// Playback error. + Error { message: String }, +} + +impl PlayerState { + /// Check if the player is currently playing. + pub fn is_playing(&self) -> bool { + matches!(self, PlayerState::Playing { .. }) + } + + /// Check if the player is paused. + pub fn is_paused(&self) -> bool { + matches!(self, PlayerState::Paused { .. }) + } + + /// Check if a video is loaded (playing, paused, or buffering). + pub fn has_video(&self) -> bool { + matches!( + self, + PlayerState::Playing { .. } + | PlayerState::Paused { .. } + | PlayerState::Buffering { .. } + ) + } + + /// Get the current video info, if available. + pub fn video_info(&self) -> Option<&VideoInfo> { + match self { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => Some(info), + _ => None, + } + } + + /// Get the current position in milliseconds. + pub fn position_ms(&self) -> u64 { + match self { + PlayerState::Playing { + position_ms, + started_at, + .. + } => { + // Calculate current position based on elapsed time + let elapsed_ms = started_at.elapsed().as_millis() as u64; + position_ms.saturating_add(elapsed_ms) + }, + PlayerState::Paused { position_ms, .. } => *position_ms, + PlayerState::Buffering { + target_position_ms, .. + } => *target_position_ms, + _ => 0, + } + } +} + +/// Information about the currently loaded video. +#[derive(Debug, Clone)] +pub struct VideoInfo { + /// Content identifier (URL or file path hash). + pub content_id: String, + /// Video width in pixels. + pub width: u32, + /// Video height in pixels. + pub height: u32, + /// Duration in milliseconds (0 if unknown/live). + pub duration_ms: u64, + /// Frames per second. + pub fps: f32, + /// Codec name. + pub codec: String, + /// Whether this is a live stream. + pub is_live: bool, + /// Title from metadata. + pub title: Option, + /// Playback rate (1.0 = normal). + pub playback_rate: f64, + /// Volume (0.0 - 1.0). + pub volume: f32, +} + +impl Default for VideoInfo { + fn default() -> Self { + Self { + content_id: String::new(), + width: 1280, + height: 720, + duration_ms: 0, + fps: 30.0, + codec: String::new(), + is_live: false, + title: None, + playback_rate: 1.0, + volume: 1.0, + } + } +} diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..4d781ab --- /dev/null +++ b/deny.toml @@ -0,0 +1,47 @@ +# cargo-deny configuration for game-mods +# See https://embarkstudios.github.io/cargo-deny/ + +[advisories] +unmaintained = "workspace" +ignore = [ + # bincode: maintainers ceased development (RUSTSEC-2025-0141) + # TODO: migrate to bincode2 or bitcode + "RUSTSEC-2025-0141", + # bytes: integer overflow in BytesMut::reserve (RUSTSEC-2026-0007) + # TODO: update bytes to patched version + "RUSTSEC-2026-0007", +] + +[licenses] +# Allow common permissive licenses +allow = [ + "MIT", + "Apache-2.0", + "Apache-2.0 WITH LLVM-exception", + "BSD-2-Clause", + "BSD-3-Clause", + "ISC", + "Zlib", + "CC0-1.0", + "Unlicense", + "MPL-2.0", + "Unicode-DFS-2016", + "Unicode-3.0", + "BSL-1.0", + "WTFPL", + "OFL-1.1", + "LicenseRef-UFL-1.0", +] +confidence-threshold = 0.8 +# Workspace crates inherit license from [workspace.package] +private = { ignore = true } + +[bans] +multiple-versions = "warn" +wildcards = "allow" +deny = [] + +[sources] +unknown-registry = "warn" +unknown-git = "warn" +allow-registry = ["https://github.com/rust-lang/crates.io-index"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..e251c33 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,153 @@ +services: + # -- CI Services -------------------------------------------------------- + rust-ci: + build: + context: . + dockerfile: docker/rust-ci.Dockerfile + user: "${USER_ID:-1000}:${GROUP_ID:-1000}" + volumes: + - .:/app + - cargo-registry-cache:/tmp/cargo + environment: + - CARGO_HOME=/tmp/cargo + - RUSTUP_HOME=/usr/local/rustup + - CARGO_INCREMENTAL=1 + - CI=true + working_dir: /app + command: ["/bin/bash"] + stdin_open: true + tty: true + profiles: + - ci + + # -- Agent MCP Services ------------------------------------------------- + # These services use pre-built Docker images. + # They are NOT buildable from this repo -- source code lives in a + # separate tools repository under tools/mcp/mcp_/. + # + # To build them (one-time, from the tools repo checkout): + # docker compose --profile services build + # + # Images follow the naming convention: game-mods-mcp-:latest + # The .mcp.json file in this repo configures Claude Code to launch + # these services via `docker compose --profile services run`. + + mcp-code-quality: + image: game-mods-mcp-code-quality:latest + volumes: + - ./:/app:ro + - /tmp:/tmp + environment: + - RUST_LOG=info + - MCP_CODE_QUALITY_ALLOWED_PATHS=/app,/workspace,/home + profiles: + - services + + mcp-content-creation: + image: game-mods-mcp-content-creation:latest + volumes: + - ./outputs/mcp-content:/output + - .:/app:ro + environment: + - RUST_LOG=info + - MCP_OUTPUT_DIR=/output + - MCP_PROJECT_ROOT=/app + - MCP_HOST_PROJECT_ROOT=${PWD} + profiles: + - services + + mcp-gemini: + image: game-mods-mcp-gemini:latest + volumes: + - ~/.gemini:/home/geminiuser/.gemini + - ./:/workspace:ro + environment: + - PYTHONUNBUFFERED=1 + - GEMINI_API_KEY=${GOOGLE_API_KEY:-} + - GOOGLE_API_KEY= + - GEMINI_TIMEOUT=${GEMINI_TIMEOUT:-300} + - GEMINI_USE_CONTAINER=false + - HOME=/home/geminiuser + profiles: + - services + + mcp-opencode: + image: game-mods-mcp-opencode:latest + volumes: + - ./:/app:ro + environment: + - PORT=8014 + - OPENROUTER_API_KEY=${OPENROUTER_API_KEY} + - OPENCODE_ENABLED=${OPENCODE_ENABLED:-true} + - OPENCODE_MODEL=${OPENCODE_MODEL:-qwen/qwen-2.5-coder-32b-instruct} + profiles: + - services + + mcp-crush: + image: game-mods-mcp-crush:latest + volumes: + - ./:/app:ro + environment: + - PORT=8015 + - OPENROUTER_API_KEY=${OPENROUTER_API_KEY} + - CRUSH_ENABLED=${CRUSH_ENABLED:-true} + profiles: + - services + + mcp-codex: + image: game-mods-mcp-codex:latest + user: "${USER_ID:-1000}:${GROUP_ID:-1000}" + volumes: + - ./:/workspace:ro + - ~/.codex:/home/user/.codex:rw + environment: + - MODE=mcp + - PORT=8021 + - CODEX_ENABLED=${CODEX_ENABLED:-true} + - CODEX_AUTH_PATH=/home/user/.codex/auth.json + profiles: + - services + + mcp-github-board: + image: game-mods-mcp-github-board:latest + volumes: + - ./:/app:ro + environment: + - RUST_LOG=info + - GITHUB_TOKEN=${GITHUB_TOKEN} + - GITHUB_REPOSITORY=${GITHUB_REPOSITORY} + - GITHUB_PROJECT_NUMBER=${GITHUB_PROJECT_NUMBER:-1} + profiles: + - services + + mcp-agentcore-memory: + image: game-mods-mcp-agentcore-memory:latest + user: "${USER_ID:-1000}:${GROUP_ID:-1000}" + volumes: + - .:/app:ro + - ${HOME}/.aws:/home/appuser/.aws:ro + environment: + - HOME=/home/appuser + - PYTHONUNBUFFERED=1 + - PORT=8023 + - MEMORY_PROVIDER=${MEMORY_PROVIDER:-chromadb} + - AWS_REGION=${AWS_REGION:-us-east-1} + - CHROMADB_HOST=chromadb + - CHROMADB_PORT=8000 + - CHROMADB_COLLECTION=${CHROMADB_COLLECTION:-agent_memory} + profiles: + - services + - memory + + mcp-reaction-search: + image: game-mods-mcp-reaction-search:latest + volumes: + - reaction-search-cache:/home/mcp/.cache + environment: + - RUST_LOG=info + profiles: + - services + +volumes: + cargo-registry-cache: + reaction-search-cache: diff --git a/docker/rust-ci.Dockerfile b/docker/rust-ci.Dockerfile new file mode 100644 index 0000000..5b20adf --- /dev/null +++ b/docker/rust-ci.Dockerfile @@ -0,0 +1,47 @@ +# syntax=docker/dockerfile:1.4 +# Rust CI image for game-mods +# Stable toolchain with system dependencies for injection toolkit crates + +FROM rust:1.93-slim + +# System dependencies +# - pkg-config + libclang: required for ffmpeg-next bindings (itk-video) +# - libasound2-dev: ALSA headers for cpal audio (itk-daemon) +# - libffmpeg-dev or ffmpeg: video decoding (itk-video) +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt-get update && apt-get install -y --no-install-recommends \ + pkg-config \ + git \ + clang \ + libclang-dev \ + libavcodec-dev \ + libavformat-dev \ + libavutil-dev \ + libswscale-dev \ + libswresample-dev \ + libavfilter-dev \ + libavdevice-dev \ + libasound2-dev \ + && rm -rf /var/lib/apt/lists/* + +# Rust components +RUN rustup component add rustfmt clippy + +# Install cargo-deny for license/advisory checks +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + cargo install cargo-deny --locked 2>/dev/null || true + +# Non-root user (overridden by docker-compose USER_ID/GROUP_ID) +RUN useradd -m -u 1000 ciuser \ + && mkdir -p /tmp/cargo && chmod 1777 /tmp/cargo + +WORKDIR /workspace + +ENV CARGO_HOME=/tmp/cargo +ENV RUSTUP_HOME=/usr/local/rustup +ENV CARGO_INCREMENTAL=1 \ + CARGO_NET_RETRY=10 \ + RUST_BACKTRACE=short + +CMD ["bash"] diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..03fce94 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,403 @@ +# Injection Toolkit Architecture + +## Design Philosophy + +The Injection Toolkit follows a **minimal injection, maximal external processing** philosophy: + +1. **Inject only what's necessary** - Code running inside the target application should be minimal +2. **Process externally** - Heavy lifting happens in the daemon and overlay +3. **Communicate via IPC** - Clean separation between components +4. **Fail gracefully** - Components should handle failures without crashing the target + +## Component Responsibilities + +### Injector (DLL/SO) + +**What it does:** +- Connects to daemon via IPC +- Hooks strategic functions in the target +- Extracts minimal state data +- Sends state updates to daemon + +**What it doesn't do:** +- Complex processing +- Rendering +- Network communication (beyond local IPC) +- Blocking operations + +### Daemon + +**What it does:** +- Receives state from injector +- Aggregates and caches state +- Serves queries from clients (overlay, MCP) +- Optionally handles multiplayer sync + +**What it doesn't do:** +- Inject code +- Render anything +- Interact directly with target application + +### Overlay + +**What it does:** +- Renders content on top of target application +- Handles click-through mode +- Provides interactive UI when enabled +- Reads frame data from shared memory (for video) + +**What it doesn't do:** +- Inject code +- Process video/audio +- Heavy computation + +### Shared Memory + +**What it does:** +- Transfer large data (video frames) between daemon and overlay +- Lock-free triple-buffered design +- Seqlock for consistency + +**What it doesn't do:** +- Small message passing (use IPC) +- Cross-machine communication + +## Data Flow + +### State Extraction Flow + +``` +Target App Function + │ + ▼ + [Hook triggers] + │ + ▼ +Injector extracts state + │ + ▼ + [IPC message] + │ + ▼ +Daemon receives & caches + │ + ▼ + [IPC query] + │ + ▼ +Client receives state +``` + +### Frame Data Flow (Video/Images) + +``` +Video Source (daemon) + │ + ▼ + [Decode frame] + │ + ▼ +Write to shared memory + (seqlock protected) + │ + ▼ +Overlay reads frame + │ + ▼ + [GPU texture upload] + │ + ▼ +Render to screen +``` + +## Protocol Design + +### Wire Format + +All messages use a common header: + +``` +┌─────────┬─────────┬──────────┬─────────────┬─────────┬───────────┐ +│ Magic │ Version │ MsgType │ PayloadLen │ CRC32 │ Payload │ +│ 4 bytes │ 4 bytes │ 4 bytes │ 4 bytes │ 4 bytes │ N bytes │ +└─────────┴─────────┴──────────┴─────────────┴─────────┴───────────┘ +``` + +- **Magic**: `"ITKP"` - identifies ITK protocol +- **Version**: Protocol version for compatibility +- **MsgType**: Enum identifying message type +- **PayloadLen**: Size of payload (max 1MB) +- **CRC32**: Checksum for validation +- **Payload**: bincode-serialized data + +### Message Types + +| Type | Direction | Purpose | +|------|-----------|---------| +| Ping/Pong | Any | Keepalive, latency measurement | +| ScreenRect | Injector → Daemon | Overlay positioning | +| WindowState | Injector → Daemon | Target window properties | +| StateSnapshot | Injector → Daemon | Full state dump | +| StateEvent | Injector → Daemon | Incremental state change | +| StateQuery | Client → Daemon | Request state | +| StateResponse | Daemon → Client | State data | + +## Platform Abstraction + +### IPC + +| Platform | Implementation | +|----------|----------------| +| Windows | Named Pipes (`\\.\pipe\itk_*`) | +| Linux | Unix Domain Sockets (`/tmp/itk_*.sock`) | + +### Shared Memory + +| Platform | Implementation | +|----------|----------------| +| Windows | `CreateFileMappingW` + `MapViewOfFile` | +| Linux | `shm_open` + `mmap` | + +### Overlay + +| Platform | Click-Through | Always-on-Top | +|----------|---------------|---------------| +| Windows | `WS_EX_TRANSPARENT` | `HWND_TOPMOST` | +| Linux/X11 | SHAPE extension | `_NET_WM_STATE_ABOVE` | +| Linux/Wayland | Layer-shell* | Layer-shell* | + +*Requires compositor support and additional dependencies. + +## Synchronization + +### Seqlock Algorithm + +Used for lock-free shared memory access: + +```rust +// Writer +seq.fetch_add(1, Acquire); // Odd = writing, prevents data writes from floating up +// ... write data (Relaxed is fine here) ... +seq.fetch_add(1, Release); // Even = done, makes writes visible + +// Reader +loop { + let s1 = seq.load(Acquire); // Synchronizes with writer's Release + if s1 & 1 != 0 { continue; } // Write in progress + // ... read data (Relaxed, bounded by fence below) ... + fence(Acquire); // Prevents data reads from sinking past seq2 check + let s2 = seq.load(Relaxed); // Fence provides ordering + if s1 == s2 { break; } // Consistent read +} +``` + +### Memory Ordering Strategy + +The seqlock uses carefully chosen orderings for ARM compatibility: + +**Writer:** +- **begin_write**: `fetch_add(1, Acquire)` - Prevents subsequent data writes from + being reordered before the odd-increment. Without this, readers could see + "even" sequence but read partially-written data. +- **end_write**: `fetch_add(1, Release)` - Ensures all data writes are visible + before the even sequence number. + +**Reader:** +- **First seq load**: `load(Acquire)` - Synchronizes with writer's Release, + ensuring we see data that was written before the sequence we observe. +- **Data reads**: `Relaxed` - Safe because bounded by the fence below. +- **Fence**: `fence(Acquire)` before second seq check - **Critical**: Prevents + data loads from being reordered past the validation. Without this fence, the + CPU could check seq2, find it valid, then execute data reads that see new/torn + data from a concurrent write. +- **Second seq load**: `Relaxed` - The fence provides the necessary ordering. + +This approach: +- **ARM compatible**: Correctly handles weak memory ordering +- **Performant**: Uses minimal barriers (no SeqCst) +- **Verified**: Tested with Loom concurrency checker + +### Single-Writer Requirement + +**CRITICAL**: The seqlock implementation assumes a **single writer**. Multiple concurrent +writers will corrupt the sequence counter and cause undefined behavior. This is an +intentional design choice for our use case: + +- **Daemon**: Single process, single write path for frame updates +- **Injector**: Single process, single write path for state updates + +If you need multiple writers in the future: +1. Wrap the `Seqlock::write()` call with an external `Mutex` or `RwLock` +2. Or use a different synchronization primitive (e.g., a channel-based approach) + +```rust +// SAFE: External mutex protects multi-threaded writer access +let writer_lock = Mutex::new(()); +{ + let _guard = writer_lock.lock().unwrap(); + seqlock.write(|state| { + // ... update state ... + }); +} + +// UNSAFE: Multiple threads calling write() without synchronization +// This WILL corrupt data - do not do this! +std::thread::spawn(|| seqlock.write(|s| s.pts_ms = 100)); // Thread 1 +std::thread::spawn(|| seqlock.write(|s| s.pts_ms = 200)); // Thread 2 - DATA RACE! +``` + +The seqlock is designed for **one writer, many readers** - this is the common pattern +for frame buffer synchronization where one producer (decoder) writes frames and +multiple consumers (overlay, MCP clients) read them. + +## Error Handling + +### Graceful Degradation + +| Failure | Detection | Behavior | +|---------|-----------|----------| +| Daemon unreachable | IPC error | Injector continues without state export | +| Injector disconnects | IPC timeout | Daemon serves stale state | +| Overlay crash | Process exit | Target unaffected | +| Target crash | Process exit | All components survive | + +### Recovery + +- IPC channels automatically reconnect with exponential backoff +- Shared memory handles are validated before each access +- Missing state returns explicit errors, not crashes + +## Security + +### Threat Model + +The Injection Toolkit operates in a hostile environment where the injected code +runs inside an untrusted process. The daemon and overlay must treat ALL data +from the injector as **potentially malicious**. + +| Component | Trust Level | Threat | +|-----------|-------------|--------| +| Injector | **UNTRUSTED** | Compromised target, malicious mods, memory corruption | +| Daemon | Trusted | Local process with validated inputs | +| Overlay | Trusted | Local process with validated inputs | +| Shared Memory | Untrusted data | Injector can write arbitrary bytes | + +### Input Validation + +The daemon validates all incoming data before use: + +```rust +// String length limits +const MAX_STRING_LEN: usize = 256; +const MAX_DATA_SIZE: usize = 64 * 1024; // 64 KB + +// Numeric bounds checking +const MAX_SCREEN_DIM: f32 = 16384.0; + +// Float validation (reject NaN/Inf) +if !value.is_finite() { + bail!("Non-finite value rejected"); +} + +// Dimension validation +if width < 0.0 || height < 0.0 { + bail!("Negative dimensions rejected"); +} +``` + +### IPC Security + +#### Windows Named Pipes + +Named pipes should use appropriate security descriptors: + +- Default: Local user access only (inherited from process token) +- Custom: Use `SECURITY_ATTRIBUTES` to restrict access further +- Never expose pipes to network without explicit intent + +```rust +// Recommended: Restrict to current user +let mut sa = SECURITY_ATTRIBUTES::default(); +// Set up DACL allowing only current user... +``` + +#### Linux Unix Sockets + +Unix domain sockets use filesystem permissions: + +- Socket created with `0600` permissions (owner only) +- Located in `/tmp` with sticky bit protection +- Consider `SO_PASSCRED` for peer authentication + +```rust +// Socket path: /tmp/itk_{name}.sock +// Permissions: -rw------- (0600) +``` + +### Shared Memory Security + +- Memory regions are created with restrictive permissions +- Size is fixed at creation to prevent overflow +- Triple-buffering prevents reader/writer corruption +- Seqlock provides consistency, not access control + +### Defense in Depth + +1. **Protocol validation**: Magic bytes, version, CRC32 +2. **Size limits**: Payload bounded to 1MB max +3. **Type validation**: All fields checked before use +4. **Fail-safe**: Invalid data logged and rejected, never crashes +5. **Isolation**: Components run in separate processes + +### What We Don't Protect Against + +The toolkit does not protect against: + +- Malicious overlay/daemon (these are trusted) +- Kernel-level attacks +- Physical access attacks +- Side-channel attacks + +These are out of scope for a userspace injection framework. + +## Performance Budgets + +### Memory + +| Component | Budget | Notes | +|-----------|--------|-------| +| Injector | < 5 MB | Minimal footprint | +| Daemon | < 30 MB | State caching | +| Overlay | < 20 MB | GPU resources | +| Shmem | ~10 MB | Triple-buffered 720p | + +### Latency + +| Operation | Target | Notes | +|-----------|--------|-------| +| State update (IPC) | < 1 ms | Local only | +| Frame copy (shmem) | < 1 ms | ~3.5 MB @ 720p | +| Overlay render | < 5 ms | Simple quad | + +## Extending the Toolkit + +### Adding a New Injector Platform + +1. Create new crate in `injectors/` +2. Implement IPC client connection +3. Implement platform-specific initialization +4. Export state using `itk-protocol` messages + +### Adding a New Message Type + +1. Add variant to `MessageType` enum in `itk-protocol` +2. Define payload struct with serde derives +3. Update daemon message handlers +4. Update clients as needed + +### Adding a New Platform + +1. Add platform module in `itk-shmem` and `itk-ipc` +2. Implement platform traits +3. Update `cfg_if!` blocks +4. Add platform module in overlay if needed diff --git a/docs/MIGRATION.md b/docs/MIGRATION.md new file mode 100644 index 0000000..a579f28 --- /dev/null +++ b/docs/MIGRATION.md @@ -0,0 +1,176 @@ +# Migration Plans + +## FlatBuffers Migration (Planned) + +The current `itk-protocol` crate uses bincode for payload serialization. This +works well for pure-Rust projects but limits cross-language interoperability. + +### Why Migrate? + +| Aspect | bincode (Current) | FlatBuffers (Proposed) | +|--------|-------------------|------------------------| +| Languages | Rust only | C, C++, C#, Go, Python, etc. | +| Schema evolution | Breaking changes | Backwards compatible | +| Zero-copy reads | No | Yes | +| Performance | Good | Excellent | +| Tooling | Simple | Requires flatc compiler | + +### Use Cases Enabled by FlatBuffers + +1. **C++ Injectors**: Direct injection without Rust compilation +2. **Python MCP Servers**: Read state without Rust bindings +3. **Unity/C# Overlays**: Native integration +4. **Version Compatibility**: Old injectors work with new daemons + +### Migration Strategy + +#### Phase 1: Schema Definition + +Create FlatBuffers schema files alongside existing Rust code: + +```flatbuffers +// itk-protocol/schemas/messages.fbs +namespace itk.protocol; + +enum MessageType : uint32 { + Ping = 0, + Pong = 1, + ScreenRect = 2, + WindowState = 3, + StateSnapshot = 4, + StateEvent = 5, + StateQuery = 6, + StateResponse = 7, +} + +table Header { + magic: uint32; + version: uint32; + msg_type: MessageType; + payload_len: uint32; + crc32: uint32; +} + +table ScreenRect { + x: float32; + y: float32; + width: float32; + height: float32; + rotation: float32; +} + +table StateEvent { + app_id: string; + event_type: string; + timestamp_ms: uint64; + data: string; +} + +// ... additional tables +``` + +#### Phase 2: Dual Support + +1. Add `flatbuffers` crate dependency +2. Generate Rust code from schemas +3. Support both formats with version detection +4. Header byte indicates format: bincode=0x00, flatbuffers=0x01 + +```rust +pub fn decode(data: &[u8]) -> Result<(Header, T)> { + let header = decode_header(data)?; + match header.version & 0xFF { + 0x00 => decode_bincode(data), + 0x01 => decode_flatbuffers(data), + _ => Err(ProtocolError::UnsupportedVersion), + } +} +``` + +#### Phase 3: Default Switch + +1. Update default encoding to FlatBuffers +2. Mark bincode as deprecated +3. Update all injector templates + +#### Phase 4: Removal (Major Version) + +1. Remove bincode support in next major version +2. Simplify codebase +3. Update documentation + +### Compatibility Matrix + +| Client Version | Server Version | Wire Format | +|----------------|----------------|-------------| +| 1.x | 1.x | bincode | +| 2.x | 1.x | bincode (fallback) | +| 1.x | 2.x | bincode (detected) | +| 2.x | 2.x | flatbuffers | + +### Build Changes + +FlatBuffers requires the `flatc` compiler: + +```bash +# Ubuntu/Debian +apt install flatbuffers-compiler + +# macOS +brew install flatbuffers + +# Windows +# Download from https://github.com/google/flatbuffers/releases +``` + +Build script integration: + +```rust +// build.rs +fn main() { + // Generate Rust from .fbs files + flatc_rust::run(flatc_rust::Args { + inputs: &["schemas/messages.fbs"], + out_dir: "src/generated/", + ..Default::default() + }).expect("flatc failed"); +} +``` + +### Alternatives Considered + +#### Protocol Buffers + +- Pros: Mature ecosystem, widely used +- Cons: Requires copy on read, larger runtime + +#### Cap'n Proto + +- Pros: Zero-copy like FlatBuffers +- Cons: Less portable, fewer language bindings + +#### MessagePack + +- Pros: Simple, schemaless +- Cons: No schema evolution guarantees + +### Decision + +FlatBuffers chosen for: +1. Zero-copy reads (important for high-frequency state updates) +2. Broad language support (C++ injectors, Python tools) +3. Schema evolution (version compatibility) +4. Google backing and active development + +### Timeline + +This migration is planned but not scheduled. Implementation will begin when: +- Cross-language injector support is needed +- Python MCP direct integration is prioritized +- Major version bump is planned + +### References + +- [FlatBuffers Documentation](https://flatbuffers.dev/) +- [FlatBuffers Rust Crate](https://crates.io/crates/flatbuffers) +- [Schema Evolution Best Practices](https://flatbuffers.dev/flatbuffers_guide_writing_schema.html) diff --git a/injectors/linux/ld-preload/Cargo.toml b/injectors/linux/ld-preload/Cargo.toml new file mode 100644 index 0000000..a6b72bf --- /dev/null +++ b/injectors/linux/ld-preload/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "itk-ld-preload" +description = "LD_PRELOAD injection template for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[lib] +crate-type = ["cdylib"] + +[dependencies] +itk-protocol = { path = "../../../core/itk-protocol" } +itk-ipc = { path = "../../../core/itk-ipc" } + +libc = { workspace = true } +tracing = { workspace = true } + +[lints] +workspace = true diff --git a/injectors/linux/ld-preload/src/lib.rs b/injectors/linux/ld-preload/src/lib.rs new file mode 100644 index 0000000..f4c7f20 --- /dev/null +++ b/injectors/linux/ld-preload/src/lib.rs @@ -0,0 +1,95 @@ +//! # ITK LD_PRELOAD Injector Template +//! +//! Template for creating injectable shared libraries on Linux using LD_PRELOAD. +//! +//! ## Usage +//! +//! ```bash +//! LD_PRELOAD=/path/to/libitk_preload.so ./target_application +//! ``` +//! +//! ## Customization +//! +//! 1. Implement `init()` to set up your hooks and IPC connection +//! 2. Use `dlsym(RTLD_NEXT, ...)` to hook functions +//! 3. Send state updates via the IPC channel + +use itk_ipc::{IpcChannel, UnixSocketClient}; +use std::sync::OnceLock; + +static IPC_CHANNEL: OnceLock = OnceLock::new(); + +/// Called when the library is loaded (via constructor attribute) +/// +/// This is where you should: +/// - Connect to the daemon via IPC +/// - Set up function hooks +/// - Initialize any state tracking +#[unsafe(no_mangle)] +pub extern "C" fn itk_init() { + // Connect to daemon + match itk_ipc::connect("itk_injector") { + Ok(channel) => { + let _ = IPC_CHANNEL.set(channel); + // Log success (can't use tracing easily in injected context) + eprintln!("[ITK] Connected to daemon"); + }, + Err(e) => { + eprintln!("[ITK] Failed to connect to daemon: {:?}", e); + }, + } +} + +// Example: Hook a function by defining it with the same signature +// +// ```rust,ignore +// #[unsafe(no_mangle)] +// pub extern "C" fn target_function(arg: i32) -> i32 { +// // Your pre-hook logic here +// +// // Call the original function +// type OrigFn = extern "C" fn(i32) -> i32; +// let orig: OrigFn = unsafe { +// std::mem::transmute(libc::dlsym(libc::RTLD_NEXT, b"target_function\0".as_ptr() as *const _)) +// }; +// let result = orig(arg); +// +// // Your post-hook logic here +// +// result +// } +// ``` + +/// Send a state update to the daemon +pub fn send_state_event(event_type: &str, data: &str) { + if let Some(channel) = IPC_CHANNEL.get() { + let event = itk_protocol::StateEvent { + app_id: "itk_app".to_string(), + event_type: event_type.to_string(), + timestamp_ms: now_ms(), + data: data.to_string(), + }; + + if let Ok(encoded) = itk_protocol::encode(itk_protocol::MessageType::StateEvent, &event) { + let _ = channel.send(&encoded); + } + } +} + +fn now_ms() -> u64 { + use std::time::{SystemTime, UNIX_EPOCH}; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64 +} + +// Constructor attribute to call itk_init on library load +#[used] +#[unsafe(link_section = ".init_array")] +static INIT: extern "C" fn() = { + extern "C" fn init() { + itk_init(); + } + init +}; diff --git a/injectors/windows/native-dll/Cargo.toml b/injectors/windows/native-dll/Cargo.toml new file mode 100644 index 0000000..ca755cf --- /dev/null +++ b/injectors/windows/native-dll/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "itk-native-dll" +description = "Native DLL injection template for the Injection Toolkit (Windows)" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[lib] +crate-type = ["cdylib"] + +[dependencies] +itk-protocol = { path = "../../../core/itk-protocol" } +itk-ipc = { path = "../../../core/itk-ipc" } + +windows = { workspace = true, features = ["Win32_System_SystemServices"] } +tracing = { workspace = true } + +[lints] +workspace = true diff --git a/injectors/windows/native-dll/src/lib.rs b/injectors/windows/native-dll/src/lib.rs new file mode 100644 index 0000000..695628f --- /dev/null +++ b/injectors/windows/native-dll/src/lib.rs @@ -0,0 +1,128 @@ +//! # ITK Native DLL Injector Template (Windows) +//! +//! Template for creating injectable DLLs on Windows. +//! +//! ## Injection Methods +//! +//! This DLL can be injected via: +//! - LoadLibrary injection +//! - Manual mapping +//! - Mod frameworks (Reloaded-II, MelonLoader, etc.) +//! +//! ## Customization +//! +//! 1. Implement `on_attach()` to set up your hooks and IPC connection +//! 2. Use a hooking library (detours, minhook) for function hooking +//! 3. Send state updates via the IPC channel + +use std::sync::OnceLock; +use windows::Win32::Foundation::{BOOL, HINSTANCE, TRUE}; +use windows::Win32::System::SystemServices::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH}; + +static IPC_CHANNEL: OnceLock = OnceLock::new(); + +/// DLL entry point +#[unsafe(no_mangle)] +pub extern "system" fn DllMain( + _hinst: HINSTANCE, + reason: u32, + _reserved: *mut std::ffi::c_void, +) -> BOOL { + match reason { + DLL_PROCESS_ATTACH => { + // Spawn a thread for initialization to avoid loader lock issues + std::thread::spawn(|| { + on_attach(); + }); + }, + DLL_PROCESS_DETACH => { + on_detach(); + }, + _ => {}, + } + TRUE +} + +/// Called when DLL is attached to process +fn on_attach() { + // Connect to daemon + match itk_ipc::connect("itk_injector") { + Ok(channel) => { + let _ = IPC_CHANNEL.set(channel); + log("[ITK] Connected to daemon"); + }, + Err(e) => { + log(&format!("[ITK] Failed to connect to daemon: {:?}", e)); + }, + } + + // TODO: Set up your hooks here + // Example with a hypothetical hooking library: + // unsafe { + // hooks::install_hook("kernel32.dll", "CreateFileW", my_create_file_hook); + // } +} + +/// Called when DLL is detached from process +fn on_detach() { + // TODO: Clean up hooks + log("[ITK] Detaching"); +} + +/// Send a state update to the daemon +pub fn send_state_event(event_type: &str, data: &str) { + if let Some(channel) = IPC_CHANNEL.get() { + let event = itk_protocol::StateEvent { + app_id: "itk_app".to_string(), + event_type: event_type.to_string(), + timestamp_ms: now_ms(), + data: data.to_string(), + }; + + if let Ok(encoded) = itk_protocol::encode(itk_protocol::MessageType::StateEvent, &event) { + let _ = channel.send(&encoded); + } + } +} + +/// Send a screen rect update +pub fn send_screen_rect(x: f32, y: f32, width: f32, height: f32) { + if let Some(channel) = IPC_CHANNEL.get() { + let rect = itk_protocol::ScreenRect { + x, + y, + width, + height, + rotation: 0.0, + visible: true, + }; + + if let Ok(encoded) = itk_protocol::encode(itk_protocol::MessageType::ScreenRect, &rect) { + let _ = channel.send(&encoded); + } + } +} + +fn now_ms() -> u64 { + use std::time::{SystemTime, UNIX_EPOCH}; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64 +} + +/// Simple logging (OutputDebugString on Windows) +fn log(msg: &str) { + #[cfg(debug_assertions)] + { + use std::ffi::CString; + if let Ok(c_msg) = CString::new(msg) { + unsafe { + windows::Win32::System::Diagnostics::Debug::OutputDebugStringA( + windows::core::PCSTR(c_msg.as_ptr() as *const _), + ); + } + } + } + let _ = msg; // Suppress unused warning in release +} diff --git a/overlay/Cargo.toml b/overlay/Cargo.toml new file mode 100644 index 0000000..9b44aaf --- /dev/null +++ b/overlay/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "itk-overlay" +description = "Overlay window template for the Injection Toolkit" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[dependencies] +itk-protocol = { path = "../core/itk-protocol" } +itk-shmem = { path = "../core/itk-shmem" } +itk-ipc = { path = "../core/itk-ipc" } + +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } + +# GPU rendering +wgpu = "22" +winit = "0.30" +pollster = "0.3" +bytemuck = { version = "1.16", features = ["derive"] } + +# Platform-specific +cfg-if = { workspace = true } + +[target.'cfg(windows)'.dependencies] +windows = { workspace = true, features = [ + "Win32_UI_WindowsAndMessaging", + "Win32_Graphics_Gdi", +] } + +[target.'cfg(unix)'.dependencies] +# Using x11rb (safe Rust bindings) instead of raw x11/xlib +x11rb = { version = "0.13", features = ["shape", "xfixes"] } + +[lints] +workspace = true diff --git a/overlay/src/lib.rs b/overlay/src/lib.rs new file mode 100644 index 0000000..7af79f0 --- /dev/null +++ b/overlay/src/lib.rs @@ -0,0 +1,100 @@ +//! # ITK Overlay +//! +//! Cross-platform overlay window for the Injection Toolkit. +//! +//! This library provides: +//! - Transparent, always-on-top overlay window +//! - Click-through mode (input passes to underlying window) +//! - Interactive mode (captures input for UI) +//! - wgpu-based rendering +//! +//! ## Platform Support +//! +//! - **Windows**: Uses `WS_EX_TRANSPARENT` and `WS_EX_NOACTIVATE` for click-through +//! - **Linux**: Uses X11 hints (Wayland support planned) + +pub mod platform; +pub mod render; +pub mod video; + +use itk_protocol::ScreenRect; +use thiserror::Error; + +/// Overlay errors +#[derive(Error, Debug)] +pub enum OverlayError { + #[error("failed to create window: {0}")] + WindowCreation(String), + + #[error("failed to initialize renderer: {0}")] + RendererInit(String), + + #[error("platform not supported: {0}")] + UnsupportedPlatform(String), + + #[error("lost GPU device")] + DeviceLost, + + #[error("IPC error: {0}")] + Ipc(#[from] itk_ipc::IpcError), +} + +/// Result type for overlay operations +pub type Result = std::result::Result; + +/// Overlay configuration +#[derive(Debug, Clone)] +pub struct OverlayConfig { + /// Window title (usually not visible) + pub title: String, + + /// Initial width + pub width: u32, + + /// Initial height + pub height: u32, + + /// Start in click-through mode + pub click_through: bool, + + /// Daemon IPC channel name + pub daemon_channel: String, +} + +impl Default for OverlayConfig { + fn default() -> Self { + Self { + title: "ITK Overlay".to_string(), + width: 1920, + height: 1080, + click_through: true, + daemon_channel: "itk_client".to_string(), + } + } +} + +/// Overlay state +#[derive(Debug)] +pub struct OverlayState { + /// Current screen rect for rendering + pub screen_rect: Option, + + /// Whether in click-through mode + pub click_through: bool, + + /// Whether overlay should be visible + pub visible: bool, +} + +impl Default for OverlayState { + fn default() -> Self { + Self { + screen_rect: None, + click_through: true, + visible: true, + } + } +} + +/// Toggle key for switching between click-through and interactive mode +pub const TOGGLE_KEY: &str = "F9"; diff --git a/overlay/src/main.rs b/overlay/src/main.rs new file mode 100644 index 0000000..8ea73f7 --- /dev/null +++ b/overlay/src/main.rs @@ -0,0 +1,212 @@ +//! # ITK Overlay Application +//! +//! Example overlay application using the ITK overlay library. + +use anyhow::Result; +use itk_overlay::{ + OverlayConfig, OverlayState, platform, render::Renderer, video::VideoFrameReader, +}; +use itk_protocol::ScreenRect; +use std::sync::Arc; +use tracing::{debug, error, info}; +use winit::{ + application::ApplicationHandler, + event::WindowEvent, + event_loop::{ActiveEventLoop, ControlFlow, EventLoop}, + window::{Window, WindowAttributes, WindowLevel}, +}; + +/// Default video rectangle: centered, 720p +static DEFAULT_VIDEO_RECT: ScreenRect = ScreenRect { + x: 320.0, + y: 180.0, + width: 1280.0, + height: 720.0, + rotation: 0.0, + visible: true, +}; + +struct App { + config: OverlayConfig, + state: OverlayState, + window: Option>, + renderer: Option, + /// Video frame reader for shared memory + frame_reader: VideoFrameReader, + /// Whether we've logged the connection status + logged_connection: bool, +} + +impl App { + fn new(config: OverlayConfig) -> Self { + Self { + config, + state: OverlayState::default(), + window: None, + renderer: None, + frame_reader: VideoFrameReader::new(), + logged_connection: false, + } + } +} + +impl ApplicationHandler for App { + fn resumed(&mut self, event_loop: &ActiveEventLoop) { + if self.window.is_some() { + return; + } + + // Create window + let window_attrs = WindowAttributes::default() + .with_title(&self.config.title) + .with_inner_size(winit::dpi::LogicalSize::new( + self.config.width, + self.config.height, + )) + .with_transparent(true) + .with_decorations(false) + .with_window_level(WindowLevel::AlwaysOnTop); + + let window = match event_loop.create_window(window_attrs) { + Ok(w) => Arc::new(w), + Err(e) => { + error!(?e, "Failed to create window"); + event_loop.exit(); + return; + }, + }; + + // Set platform-specific attributes + if let Err(e) = platform::set_transparent(&window) { + error!(?e, "Failed to set transparent"); + } + if let Err(e) = platform::set_always_on_top(&window, true) { + error!(?e, "Failed to set always-on-top"); + } + if self.state.click_through + && let Err(e) = platform::set_click_through(&window, true) + { + error!(?e, "Failed to set click-through"); + } + + // Create renderer + let renderer = pollster::block_on(Renderer::new(Arc::clone(&window))); + match renderer { + Ok(r) => { + info!("Renderer initialized"); + self.renderer = Some(r); + }, + Err(e) => { + error!(?e, "Failed to create renderer"); + event_loop.exit(); + return; + }, + } + + self.window = Some(window); + info!("Overlay window created"); + } + + fn window_event( + &mut self, + event_loop: &ActiveEventLoop, + _window_id: winit::window::WindowId, + event: WindowEvent, + ) { + match event { + WindowEvent::CloseRequested => { + event_loop.exit(); + }, + + WindowEvent::Resized(physical_size) => { + if let Some(renderer) = &mut self.renderer { + renderer.resize(physical_size); + } + }, + + WindowEvent::KeyboardInput { event, .. } => { + // F9 toggles click-through mode + if event.state.is_pressed() + && let winit::keyboard::PhysicalKey::Code(winit::keyboard::KeyCode::F9) = + event.physical_key + { + self.state.click_through = !self.state.click_through; + if let Some(window) = &self.window + && let Err(e) = + platform::set_click_through(window, self.state.click_through) + { + error!(?e, "Failed to toggle click-through"); + } + info!(click_through = %self.state.click_through, "Toggled click-through mode"); + } + }, + + WindowEvent::RedrawRequested => { + if let Some(renderer) = &mut self.renderer { + // Use provided screen rect, or default to center of screen if video is playing + let screen_rect = self.state.screen_rect.as_ref().or_else(|| { + if self.frame_reader.is_connected() && self.frame_reader.last_pts_ms() > 0 { + // Default rect: centered, 720p + Some(&DEFAULT_VIDEO_RECT) + } else { + None + } + }); + + if let Err(e) = renderer.render(screen_rect) { + error!(?e, "Render failed"); + } + } + }, + + _ => {}, + } + } + + fn about_to_wait(&mut self, _event_loop: &ActiveEventLoop) { + // Log connection status changes + if !self.logged_connection && self.frame_reader.is_connected() { + info!("Connected to video frame buffer"); + self.logged_connection = true; + } + + // Try to read a new video frame + if let Some(frame_data) = self.frame_reader.try_read_frame() + && let Some(renderer) = &self.renderer + { + renderer.update_texture(frame_data); + debug!( + pts_ms = self.frame_reader.last_pts_ms(), + "Updated texture with new frame" + ); + } + + if let Some(window) = &self.window { + window.request_redraw(); + } + } +} + +fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("itk_overlay=info".parse().unwrap()), + ) + .init(); + + info!("Starting ITK Overlay"); + + // Create config + let config = OverlayConfig::default(); + + // Create event loop and app + let event_loop = EventLoop::new()?; + event_loop.set_control_flow(ControlFlow::Poll); + + let mut app = App::new(config); + event_loop.run_app(&mut app)?; + + Ok(()) +} diff --git a/overlay/src/platform/linux.rs b/overlay/src/platform/linux.rs new file mode 100644 index 0000000..f39d700 --- /dev/null +++ b/overlay/src/platform/linux.rs @@ -0,0 +1,269 @@ +//! Linux-specific overlay functionality (X11) +//! +//! All X11 operations use the safe x11rb crate instead of raw xlib bindings. +//! +//! # Platform Support +//! +//! - **X11**: Full support for overlays, click-through, always-on-top +//! - **Wayland**: Not currently supported (requires layer-shell protocol) +//! - **Headless**: Not supported (requires a display server) + +use crate::{OverlayError, Result}; +use winit::raw_window_handle::{HasWindowHandle, RawWindowHandle}; +use x11rb::connection::Connection; +use x11rb::protocol::shape::{self, ConnectionExt as ShapeConnectionExt, SK}; +use x11rb::protocol::xproto::{ + AtomEnum, CLIENT_MESSAGE_EVENT, ClientMessageData, ClientMessageEvent, ConnectionExt, + EventMask, PropMode, +}; +use x11rb::rust_connection::RustConnection; +use x11rb::wrapper::ConnectionExt as WrapperConnectionExt; + +// Constants for _NET_WM_STATE actions +const NET_WM_STATE_REMOVE: u32 = 0; +const NET_WM_STATE_ADD: u32 = 1; + +/// Check if running in a headless environment (no display available) +fn is_headless() -> bool { + std::env::var("DISPLAY").is_err() && std::env::var("WAYLAND_DISPLAY").is_err() +} + +/// Helper to create a new x11rb connection with better error messages +fn connect_x11() -> Result<(RustConnection, usize)> { + // Check for headless environment first + if is_headless() { + return Err(OverlayError::UnsupportedPlatform( + "No display server available. Overlay requires X11 or Wayland. \ + Set DISPLAY environment variable or run with a display server." + .into(), + )); + } + + RustConnection::connect(None).map_err(|e| { + // Provide more helpful error messages based on common failure modes + let error_str = e.to_string(); + if error_str.contains("Connection refused") || error_str.contains("No such file") { + OverlayError::WindowCreation(format!( + "X11 connection failed: {}. Is the X server running? \ + Check that DISPLAY is set correctly (current: {:?})", + e, + std::env::var("DISPLAY").ok() + )) + } else { + OverlayError::WindowCreation(format!("X11 connection failed: {}", e)) + } + }) +} + +/// Set click-through mode using X11 shape extension +/// +/// On X11, we use the SHAPE extension to set an empty input region, +/// making all input pass through to the window below. +pub fn set_click_through_impl(window: &winit::window::Window, enabled: bool) -> Result<()> { + let handle = window + .window_handle() + .map_err(|e| OverlayError::WindowCreation(e.to_string()))?; + + match handle.as_raw() { + RawWindowHandle::Xlib(xlib_handle) => { + let window_id = xlib_handle.window as u32; + let (conn, _screen_num) = connect_x11()?; + + // Check if shape extension is available + let shape_ext = conn + .query_extension(shape::X11_EXTENSION_NAME.as_bytes()) + .map_err(|e| OverlayError::WindowCreation(format!("Shape query failed: {}", e)))? + .reply() + .map_err(|e| OverlayError::WindowCreation(format!("Shape reply failed: {}", e)))?; + + if !shape_ext.present { + return Err(OverlayError::UnsupportedPlatform( + "X11 SHAPE extension not available".into(), + )); + } + + if enabled { + // Set empty input shape - all input passes through + // Using ShapeInput (SK::INPUT) with an empty rectangle list + conn.shape_rectangles( + shape::SO::SET, + SK::INPUT, + x11rb::protocol::xproto::ClipOrdering::UNSORTED, + window_id, + 0, + 0, + &[], // Empty rectangle list = no input region + ) + .map_err(|e| { + OverlayError::WindowCreation(format!("Shape rectangles failed: {}", e)) + })?; + } else { + // Reset input shape to default (full window receives input) + conn.shape_mask(shape::SO::SET, SK::INPUT, window_id, 0, 0, x11rb::NONE) + .map_err(|e| { + OverlayError::WindowCreation(format!("Shape mask failed: {}", e)) + })?; + } + + conn.flush() + .map_err(|e| OverlayError::WindowCreation(format!("X11 flush failed: {}", e)))?; + + tracing::debug!( + "Click-through {} for window {}", + if enabled { "enabled" } else { "disabled" }, + window_id + ); + Ok(()) + }, + RawWindowHandle::Xcb(_) => { + // XCB support could be added here using the same x11rb connection + Err(OverlayError::UnsupportedPlatform( + "XCB not yet supported, use Xlib".into(), + )) + }, + RawWindowHandle::Wayland(_) => { + // Wayland doesn't support click-through in the same way + // Layer-shell protocol would be needed + Err(OverlayError::UnsupportedPlatform( + "Wayland is not yet supported for overlay windows. \ + Click-through requires the layer-shell protocol which is compositor-specific. \ + Consider running under XWayland (set GDK_BACKEND=x11 or QT_QPA_PLATFORM=xcb) \ + or use an X11 session." + .into(), + )) + }, + _ => Err(OverlayError::UnsupportedPlatform( + "Unknown window handle type".into(), + )), + } +} + +/// Set always-on-top using _NET_WM_STATE +/// +/// Sends a client message to the window manager to add/remove the +/// _NET_WM_STATE_ABOVE property. +pub fn set_always_on_top_impl(window: &winit::window::Window, enabled: bool) -> Result<()> { + let handle = window + .window_handle() + .map_err(|e| OverlayError::WindowCreation(e.to_string()))?; + + match handle.as_raw() { + RawWindowHandle::Xlib(xlib_handle) => { + let window_id = xlib_handle.window as u32; + let (conn, screen_num) = connect_x11()?; + + // Get the root window + let screen = &conn.setup().roots[screen_num]; + let root = screen.root; + + // Intern atoms for _NET_WM_STATE protocol + let wm_state = conn + .intern_atom(false, b"_NET_WM_STATE") + .map_err(|e| OverlayError::WindowCreation(format!("Intern atom failed: {}", e)))? + .reply() + .map_err(|e| OverlayError::WindowCreation(format!("Atom reply failed: {}", e)))? + .atom; + + let wm_state_above = conn + .intern_atom(false, b"_NET_WM_STATE_ABOVE") + .map_err(|e| OverlayError::WindowCreation(format!("Intern atom failed: {}", e)))? + .reply() + .map_err(|e| OverlayError::WindowCreation(format!("Atom reply failed: {}", e)))? + .atom; + + // Build client message data: + // data[0] = action (ADD/REMOVE) + // data[1] = first property atom + // data[2] = second property atom (0 if none) + // data[3] = source indication (1 = normal application) + let action = if enabled { + NET_WM_STATE_ADD + } else { + NET_WM_STATE_REMOVE + }; + let data = ClientMessageData::from([action, wm_state_above, 0u32, 1u32, 0u32]); + + // Create client message event + let event = ClientMessageEvent { + response_type: CLIENT_MESSAGE_EVENT, + format: 32, + sequence: 0, + window: window_id, + type_: wm_state, + data, + }; + + // Send to root window with substructure masks + conn.send_event( + false, + root, + EventMask::SUBSTRUCTURE_REDIRECT | EventMask::SUBSTRUCTURE_NOTIFY, + event, + ) + .map_err(|e| OverlayError::WindowCreation(format!("Send event failed: {}", e)))?; + + conn.flush() + .map_err(|e| OverlayError::WindowCreation(format!("X11 flush failed: {}", e)))?; + + tracing::debug!( + "Always-on-top {} for window {}", + if enabled { "enabled" } else { "disabled" }, + window_id + ); + Ok(()) + }, + _ => Err(OverlayError::UnsupportedPlatform( + "Only Xlib supported for always-on-top".into(), + )), + } +} + +/// Set window type hint for overlay behavior +/// +/// Sets _NET_WM_WINDOW_TYPE to DOCK, which tells the window manager +/// this is an overlay/dock window that should be treated specially. +pub fn set_transparent_impl(window: &winit::window::Window) -> Result<()> { + let handle = window + .window_handle() + .map_err(|e| OverlayError::WindowCreation(e.to_string()))?; + + match handle.as_raw() { + RawWindowHandle::Xlib(xlib_handle) => { + let window_id = xlib_handle.window as u32; + let (conn, _screen_num) = connect_x11()?; + + // Intern atoms for window type + let wm_window_type = conn + .intern_atom(false, b"_NET_WM_WINDOW_TYPE") + .map_err(|e| OverlayError::WindowCreation(format!("Intern atom failed: {}", e)))? + .reply() + .map_err(|e| OverlayError::WindowCreation(format!("Atom reply failed: {}", e)))? + .atom; + + let wm_window_type_dock = conn + .intern_atom(false, b"_NET_WM_WINDOW_TYPE_DOCK") + .map_err(|e| OverlayError::WindowCreation(format!("Intern atom failed: {}", e)))? + .reply() + .map_err(|e| OverlayError::WindowCreation(format!("Atom reply failed: {}", e)))? + .atom; + + // Set window type property + // Property format is 32-bit atoms + conn.change_property32( + PropMode::REPLACE, + window_id, + wm_window_type, + AtomEnum::ATOM, + &[wm_window_type_dock], + ) + .map_err(|e| OverlayError::WindowCreation(format!("Change property failed: {}", e)))?; + + conn.flush() + .map_err(|e| OverlayError::WindowCreation(format!("X11 flush failed: {}", e)))?; + + tracing::debug!("Set window type to DOCK for window {}", window_id); + Ok(()) + }, + _ => Ok(()), // Ignore for other platforms + } +} diff --git a/overlay/src/platform/mod.rs b/overlay/src/platform/mod.rs new file mode 100644 index 0000000..03babbf --- /dev/null +++ b/overlay/src/platform/mod.rs @@ -0,0 +1,30 @@ +//! Platform-specific overlay functionality + +cfg_if::cfg_if! { + if #[cfg(windows)] { + mod windows; + pub use windows::*; + } else if #[cfg(target_os = "linux")] { + mod linux; + pub use linux::*; + } else { + compile_error!("Unsupported platform for overlay"); + } +} + +/// Set click-through mode for a window +/// +/// When enabled, mouse input passes through the window to the one behind it. +pub fn set_click_through(window: &winit::window::Window, enabled: bool) -> crate::Result<()> { + set_click_through_impl(window, enabled) +} + +/// Set always-on-top for a window +pub fn set_always_on_top(window: &winit::window::Window, enabled: bool) -> crate::Result<()> { + set_always_on_top_impl(window, enabled) +} + +/// Make window transparent (for compositor) +pub fn set_transparent(window: &winit::window::Window) -> crate::Result<()> { + set_transparent_impl(window) +} diff --git a/overlay/src/platform/windows.rs b/overlay/src/platform/windows.rs new file mode 100644 index 0000000..33153b5 --- /dev/null +++ b/overlay/src/platform/windows.rs @@ -0,0 +1,73 @@ +//! Windows-specific overlay functionality + +use crate::{OverlayError, Result}; +use winit::raw_window_handle::{HasWindowHandle, RawWindowHandle}; + +use windows::Win32::Foundation::HWND; +use windows::Win32::UI::WindowsAndMessaging::{ + GWL_EXSTYLE, GetWindowLongW, HWND_TOPMOST, SWP_NOMOVE, SWP_NOSIZE, SetWindowLongW, + SetWindowPos, WS_EX_LAYERED, WS_EX_NOACTIVATE, WS_EX_TOOLWINDOW, WS_EX_TRANSPARENT, +}; + +/// Get the HWND from a winit window +fn get_hwnd(window: &winit::window::Window) -> Result { + match window + .window_handle() + .map_err(|e| OverlayError::WindowCreation(e.to_string()))? + .as_raw() + { + RawWindowHandle::Win32(handle) => Ok(HWND(handle.hwnd.get() as *mut _)), + _ => Err(OverlayError::UnsupportedPlatform( + "Expected Win32 window handle".into(), + )), + } +} + +pub fn set_click_through_impl(window: &winit::window::Window, enabled: bool) -> Result<()> { + let hwnd = get_hwnd(window)?; + + unsafe { + let mut ex_style = GetWindowLongW(hwnd, GWL_EXSTYLE) as u32; + + if enabled { + // Enable click-through + ex_style |= WS_EX_TRANSPARENT.0 | WS_EX_LAYERED.0 | WS_EX_NOACTIVATE.0; + } else { + // Disable click-through + ex_style &= !(WS_EX_TRANSPARENT.0 | WS_EX_NOACTIVATE.0); + } + + SetWindowLongW(hwnd, GWL_EXSTYLE, ex_style as i32); + } + + Ok(()) +} + +pub fn set_always_on_top_impl(window: &winit::window::Window, enabled: bool) -> Result<()> { + let hwnd = get_hwnd(window)?; + + unsafe { + let insert_after = if enabled { + HWND_TOPMOST + } else { + windows::Win32::UI::WindowsAndMessaging::HWND_NOTOPMOST + }; + + SetWindowPos(hwnd, insert_after, 0, 0, 0, 0, SWP_NOMOVE | SWP_NOSIZE) + .map_err(|e| OverlayError::WindowCreation(e.to_string()))?; + } + + Ok(()) +} + +pub fn set_transparent_impl(window: &winit::window::Window) -> Result<()> { + let hwnd = get_hwnd(window)?; + + unsafe { + let mut ex_style = GetWindowLongW(hwnd, GWL_EXSTYLE) as u32; + ex_style |= WS_EX_LAYERED.0 | WS_EX_TOOLWINDOW.0; + SetWindowLongW(hwnd, GWL_EXSTYLE, ex_style as i32); + } + + Ok(()) +} diff --git a/overlay/src/render.rs b/overlay/src/render.rs new file mode 100644 index 0000000..4987ecf --- /dev/null +++ b/overlay/src/render.rs @@ -0,0 +1,364 @@ +//! wgpu-based rendering for the overlay + +use crate::{OverlayError, Result}; +use itk_protocol::ScreenRect; +use std::sync::Arc; + +/// Renderer state +pub struct Renderer { + surface: wgpu::Surface<'static>, + device: wgpu::Device, + queue: wgpu::Queue, + config: wgpu::SurfaceConfiguration, + render_pipeline: wgpu::RenderPipeline, + vertex_buffer: wgpu::Buffer, + bind_group: wgpu::BindGroup, + texture: wgpu::Texture, + texture_size: (u32, u32), +} + +impl Renderer { + /// Create a new renderer for the given window + pub async fn new(window: Arc) -> Result { + let size = window.inner_size(); + + // Create wgpu instance + let instance = wgpu::Instance::new(wgpu::InstanceDescriptor { + backends: wgpu::Backends::all(), + ..Default::default() + }); + + // Create surface + let surface = instance + .create_surface(window) + .map_err(|e| OverlayError::RendererInit(e.to_string()))?; + + // Request adapter + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::LowPower, + compatible_surface: Some(&surface), + force_fallback_adapter: false, + }) + .await + .ok_or_else(|| OverlayError::RendererInit("No suitable GPU adapter found".into()))?; + + // Create device and queue + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("ITK Overlay Device"), + required_features: wgpu::Features::empty(), + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::default(), + }, + None, + ) + .await + .map_err(|e| OverlayError::RendererInit(e.to_string()))?; + + // Configure surface + let surface_caps = surface.get_capabilities(&adapter); + let surface_format = surface_caps + .formats + .iter() + .find(|f| f.is_srgb()) + .copied() + .unwrap_or(surface_caps.formats[0]); + + let config = wgpu::SurfaceConfiguration { + usage: wgpu::TextureUsages::RENDER_ATTACHMENT, + format: surface_format, + width: size.width.max(1), + height: size.height.max(1), + present_mode: wgpu::PresentMode::Fifo, + alpha_mode: wgpu::CompositeAlphaMode::PreMultiplied, + view_formats: vec![], + desired_maximum_frame_latency: 2, + }; + surface.configure(&device, &config); + + // Create shader + let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Overlay Shader"), + source: wgpu::ShaderSource::Wgsl(include_str!("shaders/overlay.wgsl").into()), + }); + + // Create texture for video/content (720p default) + let texture_size = (1280u32, 720u32); + let texture = device.create_texture(&wgpu::TextureDescriptor { + label: Some("Content Texture"), + size: wgpu::Extent3d { + width: texture_size.0, + height: texture_size.1, + depth_or_array_layers: 1, + }, + mip_level_count: 1, + sample_count: 1, + dimension: wgpu::TextureDimension::D2, + format: wgpu::TextureFormat::Rgba8UnormSrgb, + usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST, + view_formats: &[], + }); + + let texture_view = texture.create_view(&wgpu::TextureViewDescriptor::default()); + let sampler = device.create_sampler(&wgpu::SamplerDescriptor { + address_mode_u: wgpu::AddressMode::ClampToEdge, + address_mode_v: wgpu::AddressMode::ClampToEdge, + address_mode_w: wgpu::AddressMode::ClampToEdge, + mag_filter: wgpu::FilterMode::Linear, + min_filter: wgpu::FilterMode::Linear, + ..Default::default() + }); + + // Create bind group layout + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("Texture Bind Group Layout"), + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Texture { + multisampled: false, + view_dimension: wgpu::TextureViewDimension::D2, + sample_type: wgpu::TextureSampleType::Float { filterable: true }, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering), + count: None, + }, + ], + }); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("Texture Bind Group"), + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::TextureView(&texture_view), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Sampler(&sampler), + }, + ], + }); + + // Create pipeline layout + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Overlay Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + // Create render pipeline + let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { + label: Some("Overlay Render Pipeline"), + layout: Some(&pipeline_layout), + vertex: wgpu::VertexState { + module: &shader, + entry_point: "vs_main", + buffers: &[wgpu::VertexBufferLayout { + array_stride: std::mem::size_of::() as wgpu::BufferAddress, + step_mode: wgpu::VertexStepMode::Vertex, + attributes: &[ + wgpu::VertexAttribute { + offset: 0, + shader_location: 0, + format: wgpu::VertexFormat::Float32x2, + }, + wgpu::VertexAttribute { + offset: std::mem::size_of::<[f32; 2]>() as wgpu::BufferAddress, + shader_location: 1, + format: wgpu::VertexFormat::Float32x2, + }, + ], + }], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &shader, + entry_point: "fs_main", + targets: &[Some(wgpu::ColorTargetState { + format: config.format, + blend: Some(wgpu::BlendState::ALPHA_BLENDING), + write_mask: wgpu::ColorWrites::ALL, + })], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }), + primitive: wgpu::PrimitiveState { + topology: wgpu::PrimitiveTopology::TriangleList, + strip_index_format: None, + front_face: wgpu::FrontFace::Ccw, + cull_mode: None, + polygon_mode: wgpu::PolygonMode::Fill, + unclipped_depth: false, + conservative: false, + }, + depth_stencil: None, + multisample: wgpu::MultisampleState::default(), + multiview: None, + cache: None, + }); + + // Create vertex buffer (quad) + let vertex_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Vertex Buffer"), + size: std::mem::size_of::<[Vertex; 6]>() as wgpu::BufferAddress, + usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + Ok(Self { + surface, + device, + queue, + config, + render_pipeline, + vertex_buffer, + bind_group, + texture, + texture_size, + }) + } + + /// Resize the surface + pub fn resize(&mut self, new_size: winit::dpi::PhysicalSize) { + if new_size.width > 0 && new_size.height > 0 { + self.config.width = new_size.width; + self.config.height = new_size.height; + self.surface.configure(&self.device, &self.config); + } + } + + /// Update texture with new frame data + pub fn update_texture(&self, data: &[u8]) { + self.queue.write_texture( + wgpu::ImageCopyTexture { + texture: &self.texture, + mip_level: 0, + origin: wgpu::Origin3d::ZERO, + aspect: wgpu::TextureAspect::All, + }, + data, + wgpu::ImageDataLayout { + offset: 0, + bytes_per_row: Some(self.texture_size.0 * 4), + rows_per_image: Some(self.texture_size.1), + }, + wgpu::Extent3d { + width: self.texture_size.0, + height: self.texture_size.1, + depth_or_array_layers: 1, + }, + ); + } + + /// Render a frame + pub fn render(&mut self, screen_rect: Option<&ScreenRect>) -> Result<()> { + let output = self.surface.get_current_texture().map_err(|e| match e { + wgpu::SurfaceError::Lost => OverlayError::DeviceLost, + wgpu::SurfaceError::OutOfMemory => OverlayError::RendererInit("Out of memory".into()), + _ => OverlayError::RendererInit(e.to_string()), + })?; + + let view = output + .texture + .create_view(&wgpu::TextureViewDescriptor::default()); + + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Render Encoder"), + }); + + // Update vertex buffer if we have a screen rect + if let Some(rect) = screen_rect { + let vertices = self.create_vertices(rect); + self.queue + .write_buffer(&self.vertex_buffer, 0, bytemuck::cast_slice(&vertices)); + } + + { + let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: Some("Render Pass"), + color_attachments: &[Some(wgpu::RenderPassColorAttachment { + view: &view, + resolve_target: None, + ops: wgpu::Operations { + load: wgpu::LoadOp::Clear(wgpu::Color::TRANSPARENT), + store: wgpu::StoreOp::Store, + }, + })], + depth_stencil_attachment: None, + occlusion_query_set: None, + timestamp_writes: None, + }); + + if screen_rect.is_some() { + render_pass.set_pipeline(&self.render_pipeline); + render_pass.set_bind_group(0, &self.bind_group, &[]); + render_pass.set_vertex_buffer(0, self.vertex_buffer.slice(..)); + render_pass.draw(0..6, 0..1); + } + } + + self.queue.submit(std::iter::once(encoder.finish())); + output.present(); + + Ok(()) + } + + /// Create vertices for a screen rect quad + fn create_vertices(&self, rect: &ScreenRect) -> [Vertex; 6] { + // Convert screen coordinates to NDC (-1 to 1) + let screen_width = self.config.width as f32; + let screen_height = self.config.height as f32; + + let left = (rect.x / screen_width) * 2.0 - 1.0; + let right = ((rect.x + rect.width) / screen_width) * 2.0 - 1.0; + let top = 1.0 - (rect.y / screen_height) * 2.0; + let bottom = 1.0 - ((rect.y + rect.height) / screen_height) * 2.0; + + // Two triangles forming a quad + [ + Vertex { + position: [left, top], + tex_coord: [0.0, 0.0], + }, + Vertex { + position: [right, top], + tex_coord: [1.0, 0.0], + }, + Vertex { + position: [left, bottom], + tex_coord: [0.0, 1.0], + }, + Vertex { + position: [left, bottom], + tex_coord: [0.0, 1.0], + }, + Vertex { + position: [right, top], + tex_coord: [1.0, 0.0], + }, + Vertex { + position: [right, bottom], + tex_coord: [1.0, 1.0], + }, + ] + } +} + +/// Vertex data +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +struct Vertex { + position: [f32; 2], + tex_coord: [f32; 2], +} diff --git a/overlay/src/shaders/overlay.wgsl b/overlay/src/shaders/overlay.wgsl new file mode 100644 index 0000000..8c5ad06 --- /dev/null +++ b/overlay/src/shaders/overlay.wgsl @@ -0,0 +1,29 @@ +// Overlay shader for rendering textured quads + +struct VertexInput { + @location(0) position: vec2, + @location(1) tex_coord: vec2, +} + +struct VertexOutput { + @builtin(position) clip_position: vec4, + @location(0) tex_coord: vec2, +} + +@vertex +fn vs_main(in: VertexInput) -> VertexOutput { + var out: VertexOutput; + out.clip_position = vec4(in.position, 0.0, 1.0); + out.tex_coord = in.tex_coord; + return out; +} + +@group(0) @binding(0) +var t_diffuse: texture_2d; +@group(0) @binding(1) +var s_diffuse: sampler; + +@fragment +fn fs_main(in: VertexOutput) -> @location(0) vec4 { + return textureSample(t_diffuse, s_diffuse, in.tex_coord); +} diff --git a/overlay/src/video.rs b/overlay/src/video.rs new file mode 100644 index 0000000..76aed7b --- /dev/null +++ b/overlay/src/video.rs @@ -0,0 +1,123 @@ +//! Video frame reader for the overlay. +//! +//! Reads decoded video frames from shared memory written by the daemon. + +use itk_shmem::FrameBuffer; +use tracing::{debug, trace, warn}; + +/// Default shared memory name for video frames. +const SHMEM_NAME: &str = "itk_video_frames"; +/// Default video width. +const DEFAULT_WIDTH: u32 = 1280; +/// Default video height. +const DEFAULT_HEIGHT: u32 = 720; + +/// Reader for video frames from shared memory. +pub struct VideoFrameReader { + buffer: Option, + frame_data: Vec, + last_pts_ms: u64, + width: u32, + height: u32, +} + +impl VideoFrameReader { + /// Create a new video frame reader. + /// + /// Attempts to open the shared memory region created by the daemon. + /// If the region doesn't exist yet, the reader will retry on each read. + pub fn new() -> Self { + let width = DEFAULT_WIDTH; + let height = DEFAULT_HEIGHT; + let frame_size = (width as usize) * (height as usize) * 4; + + // Try to open the shared memory region + let buffer = match FrameBuffer::open(SHMEM_NAME, width, height) { + Ok(fb) => { + debug!("Opened video frame buffer"); + Some(fb) + }, + Err(e) => { + debug!( + ?e, + "Frame buffer not available yet (daemon may not be running)" + ); + None + }, + }; + + Self { + buffer, + frame_data: vec![0u8; frame_size], + last_pts_ms: 0, + width, + height, + } + } + + /// Try to read the latest frame. + /// + /// Returns `Some(data)` if a new frame is available, `None` otherwise. + /// The returned slice is valid until the next call to `try_read_frame`. + pub fn try_read_frame(&mut self) -> Option<&[u8]> { + // Try to open the buffer if we don't have it yet + if self.buffer.is_none() { + self.buffer = FrameBuffer::open(SHMEM_NAME, self.width, self.height).ok(); + if self.buffer.is_some() { + debug!("Connected to video frame buffer"); + } + } + + let buffer = self.buffer.as_ref()?; + + // Try to read a frame + match buffer.read_frame(self.last_pts_ms, &mut self.frame_data) { + Ok((pts_ms, changed)) => { + if changed { + self.last_pts_ms = pts_ms; + trace!(pts_ms, "Read new frame"); + Some(&self.frame_data) + } else { + None + } + }, + Err(itk_shmem::ShmemError::SeqlockContention) => { + // Writer may be slow or crashed, don't spam logs + trace!("Seqlock contention, will retry"); + None + }, + Err(e) => { + warn!(?e, "Failed to read frame"); + // Connection may have been lost, try to reconnect next time + self.buffer = None; + None + }, + } + } + + /// Get the last presentation timestamp in milliseconds. + pub fn last_pts_ms(&self) -> u64 { + self.last_pts_ms + } + + /// Check if connected to the frame buffer. + pub fn is_connected(&self) -> bool { + self.buffer.is_some() + } + + /// Get the frame dimensions. + pub fn dimensions(&self) -> (u32, u32) { + (self.width, self.height) + } + + /// Get a reference to the current frame data (may be stale). + pub fn current_frame(&self) -> &[u8] { + &self.frame_data + } +} + +impl Default for VideoFrameReader { + fn default() -> Self { + Self::new() + } +} diff --git a/projects/nms-cockpit-video/README.md b/projects/nms-cockpit-video/README.md new file mode 100644 index 0000000..42a8ab6 --- /dev/null +++ b/projects/nms-cockpit-video/README.md @@ -0,0 +1,254 @@ +# NMS Cockpit Video Player + +> A video player that renders inside your No Man's Sky spaceship cockpit, supporting both desktop and VR. + +## Architecture + +Two rendering paths are available: + +### Vulkan Injector (Desktop + VR) + +The injector DLL hooks NMS's Vulkan pipeline directly to render a textured quad in 3D space, visible to both desktop and VR users. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ NMS Process │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ nms-cockpit-injector.dll │ │ +│ │ • Hook vkCreateDevice/SwapchainKHR/QueuePresentKHR │ │ +│ │ • Read camera matrices from cGcCameraManager │ │ +│ │ • Read video frames from shared memory │ │ +│ │ • Render textured quad via Vulkan pipeline │ │ +│ │ • (VR) Hook IVRCompositor::Submit for per-eye render │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ▲ + │ Shared Memory (video frames) + │ +┌─────────────────────────────────────────────────────────────────┐ +│ nms-video-daemon │ +│ • Decode video (ffmpeg + yt-dlp) → shared memory │ +│ • Audio playback (cpal + ffmpeg resampler) │ +│ • Handle commands (play/pause/seek/load) │ +│ • P2P multiplayer sync (laminar) │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Overlay (Desktop Only, Legacy) + +A separate overlay window renders on top of NMS. Simpler but limited to desktop borderless/windowed mode. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ nms-video-daemon │ +│ • Decode video → shared memory │ +│ • Audio playback │ +└──────────────────────────────┬──────────────────────────────────┘ + │ Shared Memory (video frames) + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ nms-video-overlay │ +│ • Read frames from shared memory │ +│ • Render video at screen position │ +│ • egui controls (F9 toggles interactive mode) │ +│ • Click-through by default │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Components + +### 1. Launcher (`launcher/`) +Orchestrates the full workflow: starts daemon, launches NMS with `--disable-eac`, waits for the game process, injects the DLL, and shuts down the daemon when the game exits. + +### 2. Daemon (`daemon/`) +Video decode and audio playback. Decodes via ffmpeg with D3D11VA hardware acceleration (falls back to software), writes RGBA frames to shared memory via seqlock, plays audio through system output via cpal. Supports YouTube URLs (via yt-dlp) and local files. + +### 3. Injector (`injector/`) +Vulkan rendering hook DLL. Intercepts NMS's Vulkan calls to render a textured quad inside the game's 3D pipeline. Includes keyboard input handler (F5-F9) that sends IPC commands to the daemon. Works in desktop and VR, any display mode. Requires nightly Rust toolchain. + +### 4. Overlay (`overlay/`) +Desktop overlay window with egui controls. Simpler alternative to the injector for desktop-only use in borderless/windowed mode. + +## Building + +### Prerequisites +- Rust nightly MSVC toolchain (injector requires nightly; daemon/overlay work on stable) +- FFmpeg libraries (via vcpkg: `vcpkg install ffmpeg:x64-windows`) +- yt-dlp in PATH (for YouTube URL extraction, optional) +- GPU with D3D11VA support (optional, falls back to software decode) + +### Rust Components +```bash +# Set up environment (MSVC + vcpkg) +export FFMPEG_DIR="C:/vcpkg/installed/x64-windows" + +# Build all (daemon, launcher, injector) - requires nightly for injector +cargo +nightly-x86_64-pc-windows-msvc build --release \ + -p nms-video-daemon -p nms-video-launcher -p nms-cockpit-injector + +# Build overlay (desktop-only alternative, stable toolchain) +cargo +stable-x86_64-pc-windows-msvc build --release -p nms-video-overlay +``` + +YouTube support (yt-dlp) is enabled by default in the daemon. D3D11VA hardware-accelerated decoding is used automatically when available (NVIDIA/AMD/Intel GPUs). + +## Usage + +### With Launcher (Recommended) + +Place all binaries in the same directory with an optional video file: +``` +nms-video-launcher.exe +nms-video-daemon.exe +nms_cockpit_injector.dll +nms_video.mp4 (optional: auto-loads on startup) +nms_video.txt (optional: path/URL to load instead of .mp4) +``` + +Run the launcher: +```bash +nms-video-launcher.exe [nms-exe-path] [dll-path] +``` + +The launcher handles everything: +1. Starts the daemon (with `--load` if a video file is found) +2. Launches NMS with `--disable-eac` +3. Waits for the game process to spawn +4. Injects the DLL via CreateRemoteThread +5. When NMS exits, shuts down the daemon + +### Keyboard Shortcuts (Injector) +| Key | Action | +|-----|--------| +| F5 | Toggle video overlay on/off | +| F6 | Play/Pause | +| F7 | Seek backward 10s | +| F8 | Seek forward 10s | +| F9 | Load video URL from clipboard (YouTube supported) | + +### With Injector (Manual) + +1. Start the daemon: `nms-video-daemon.exe --load ` +2. Launch NMS with `--disable-eac` +3. Inject `nms_cockpit_injector.dll` via the launcher or LoadLibrary +4. Press F5 to show the video overlay + +### With Overlay (Desktop Only) + +1. Start the daemon: `nms-video-daemon.exe` +2. Launch NMS in **borderless** or **windowed** mode +3. Start the overlay: `nms-video-overlay.exe` +4. Press **F9** to open controls, enter a video URL +5. Press **F9** again to return to click-through mode + +### Keyboard Shortcuts (Overlay) +| Key | Action | +|-----|--------| +| F9 | Toggle interactive/click-through mode | +| Space | Play/Pause (when interactive) | +| Left/Right | Seek -10s/+10s (when interactive) | + +## Multiplayer (P2P) + +1. All players must have the daemon + injector/overlay +2. One player hosts (enters video URL first) +3. Other players join by entering the same URL +4. Sync happens automatically via P2P (laminar) + +## Limitations + +- **PC only** - Console modding not possible +- **Singleplayer/P2P coop** - EAC must be disabled for injection +- **Overlay: borderless/windowed only** - Exclusive fullscreen blocks the overlay window +- **Injector: NMS updates may shift memory offsets** - Camera RVA needs updating per patch + +## Troubleshooting + +### Injector Not Rendering +1. Check DebugView for `[NMS-VIDEO]` log messages +2. Ensure DLL is injected before NMS creates its Vulkan device (early load) +3. Verify `vulkan-1.dll` is loaded in the process +4. Check that shared memory frames are being written by the daemon + +### Overlay Not Visible +1. Ensure NMS is in **borderless** or **windowed** mode +2. Check that daemon and overlay processes are running +3. Look for "Connected to video frame buffer" in overlay logs + +### Video Not Playing +1. Check daemon logs for decode errors +2. Verify ffmpeg libraries are in PATH or vcpkg +3. For YouTube: ensure yt-dlp is installed and up to date +4. Try a local file to isolate network issues + +## Injector Architecture + +The injector DLL is organized into four modules that implement the full rendering pipeline: + +``` +injector/src/ +├── lib.rs DllMain, init thread, module declarations +├── log.rs OutputDebugString logging (view with DebugView) +├── input.rs Keyboard hotkeys (F5-F9), IPC to daemon, clipboard +├── shmem_reader.rs Lock-free shared memory frame polling (itk-shmem seqlock) +├── hooks/ +│ ├── mod.rs Hook install/remove orchestration +│ ├── vulkan.rs Vulkan function detours (retour static_detour) +│ └── openvr.rs IVRCompositor::Submit vtable hook +├── renderer/ +│ ├── mod.rs VulkanRenderer: pipeline, draw commands, VR rendering +│ ├── pipeline.rs Render pass, graphics pipeline, descriptor sets +│ ├── texture.rs VideoTexture: staging buffer upload, device-local image +│ └── geometry.rs Quad vertex buffer (6 vertices, pos3 + uv2) +├── camera/ +│ ├── mod.rs CameraReader: NMS process memory reads +│ └── projection.rs Perspective projection, cockpit MVP computation +└── shaders/ + ├── quad.vert.wgsl Vertex shader (push constant mat4 MVP) + └── quad.frag.wgsl Fragment shader (texture2d + sampler) +``` + +### Hook Points + +| Hook | Method | Purpose | +|------|--------|---------| +| `vkCreateInstance` | retour static_detour | Capture VkInstance for ash loader | +| `vkCreateDevice` | retour static_detour | Capture VkDevice, VkPhysicalDevice, queue family | +| `vkCreateSwapchainKHR` | retour + ICD RawDetour | Track format, extent, swapchain images | +| `vkQueuePresentKHR` | retour + ICD RawDetour | Render quad before desktop present | +| `IVRCompositor::Submit` | vtable swap | Render quad per VR eye before compositor | + +Extension functions (`*KHR`) bypass the Vulkan loader when obtained via `vkGetDeviceProcAddr`, so they are hooked at both the loader trampoline level (retour static_detour) and the ICD level (RawDetour on the address returned by `vkGetDeviceProcAddr`). + +### Rendering Flow + +**Desktop**: `vkQueuePresentKHR` hook -> read camera -> poll shmem -> compute MVP -> transition swapchain image -> render pass (LOAD) -> draw quad -> transition back -> submit + +**VR**: `IVRCompositor::Submit` hook -> read VRVulkanTextureData_t -> get VkImage per eye -> create temp framebuffer -> render pass (LOAD) -> draw quad -> cleanup + +### Shader Compilation + +Shaders are written in WGSL and compiled to SPIR-V at build time via naga (pure Rust, no Vulkan SDK required). The fragment shader uses separate `texture_2d` and `sampler` bindings (WGSL requirement). + +## Development + +### Memory Offsets (cGcCameraManager) + +See `docs/nms-reverse-engineering.md` for camera singleton details: +``` +Singleton pointer: NMS.exe + 0x56666B0 ++0x118 Camera mode (u32, cockpit = 0x10) ++0x130 View matrix (4x4 f32, row-major) ++0x1D0 FoV (f32, degrees) ++0x1D4 Aspect ratio (f32) +``` + +### Updating for NMS Patches +When NMS updates shift memory layouts: +1. Use `mem-scanner` tool or x64dbg to find new cGcCameraManager RVA +2. Update the offset in the injector's camera module +3. Verify camera mode detection still works + +## License + +Part of the game-mods project. See repository LICENSE file. diff --git a/projects/nms-cockpit-video/daemon/Cargo.toml b/projects/nms-cockpit-video/daemon/Cargo.toml new file mode 100644 index 0000000..40ed02e --- /dev/null +++ b/projects/nms-cockpit-video/daemon/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "nms-video-daemon" +description = "Video playback daemon for No Man's Sky Cockpit Video Player" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[[bin]] +name = "nms-video-daemon" +path = "src/main.rs" + +[features] +default = ["youtube"] +youtube = ["itk-video/youtube"] +# Build ffmpeg from source (slow, but works without system ffmpeg) +build-ffmpeg = ["itk-video/build-ffmpeg"] + +[dependencies] +# Core ITK libraries +itk-protocol = { path = "../../../core/itk-protocol" } +itk-shmem = { path = "../../../core/itk-shmem" } +itk-ipc = { path = "../../../core/itk-ipc" } +itk-sync = { path = "../../../core/itk-sync" } +itk-video = { path = "../../../core/itk-video" } + +# Async runtime +tokio = { workspace = true } + +# Serialization +serde = { workspace = true } +serde_json = "1.0" + +# Logging +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +# Error handling +thiserror = { workspace = true } +anyhow = { workspace = true } + +# CLI +clap = { version = "4", features = ["derive"] } + +# Audio output +cpal = { workspace = true } +ringbuf = { workspace = true } + +# FFmpeg (for audio decoder) +ffmpeg-next = { workspace = true } + +[lints] +workspace = true diff --git a/projects/nms-cockpit-video/daemon/src/main.rs b/projects/nms-cockpit-video/daemon/src/main.rs new file mode 100644 index 0000000..9477c10 --- /dev/null +++ b/projects/nms-cockpit-video/daemon/src/main.rs @@ -0,0 +1,410 @@ +//! NMS Cockpit Video Player - Daemon +//! +//! Central coordinator for the No Man's Sky cockpit video player. +//! Handles video decoding, screen rect from the mod, and multiplayer sync. +//! +//! NOTE: Input validation and IPC listener patterns are shared with the +//! generic daemon template in `daemon/`. +//! Future work: extract shared IPC/validation into `itk-daemon-core` library crate. + +use anyhow::{Context, Result, bail}; +use clap::Parser; +use itk_ipc::{IpcChannel, IpcServer}; +use itk_protocol::{ + MessageType, ScreenRect, StateQuery, StateResponse, VideoLoad, VideoPause, VideoPlay, + VideoSeek, decode, encode, +}; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::Duration; +use tracing::{debug, error, info, warn}; + +mod video; +use video::VideoPlayer; + +// ============================================================================= +// Input Validation (shared patterns from daemon template) +// ============================================================================= + +/// Maximum screen dimension (sanity check) +const MAX_SCREEN_DIM: f32 = 16384.0; + +/// Validate a ScreenRect from untrusted source (injected mod). +/// +/// Mirrors the validation from the ITK daemon template to prevent +/// crafted messages from causing crashes in GPU/rendering code. +fn validate_screen_rect(rect: &ScreenRect) -> Result<()> { + if !rect.x.is_finite() + || !rect.y.is_finite() + || !rect.width.is_finite() + || !rect.height.is_finite() + || !rect.rotation.is_finite() + { + bail!("ScreenRect contains non-finite values"); + } + + if rect.x.abs() > MAX_SCREEN_DIM + || rect.y.abs() > MAX_SCREEN_DIM + || rect.width > MAX_SCREEN_DIM + || rect.height > MAX_SCREEN_DIM + { + bail!("ScreenRect dimensions out of bounds"); + } + + if rect.width < 0.0 || rect.height < 0.0 { + bail!("ScreenRect has negative dimensions"); + } + + let right = rect.x + rect.width; + let bottom = rect.y + rect.height; + if !right.is_finite() + || !bottom.is_finite() + || right.abs() > MAX_SCREEN_DIM + || bottom.abs() > MAX_SCREEN_DIM + { + bail!("ScreenRect coordinate overflow"); + } + + Ok(()) +} + +/// NMS Cockpit Video Player Daemon +#[derive(Parser, Debug)] +#[command(name = "nms-video-daemon")] +#[command(about = "Video playback daemon for No Man's Sky Cockpit Video Player")] +struct Args { + /// IPC channel name for the NMS mod + #[arg(long, default_value = "nms_cockpit_injector")] + mod_channel: String, + + /// IPC channel name for clients (overlay, MCP) + #[arg(long, default_value = "nms_cockpit_client")] + client_channel: String, + + /// Enable multiplayer sync + #[arg(long)] + multiplayer: bool, + + /// Multiplayer sync port + #[arg(long, default_value = "7331")] + sync_port: u16, + + /// Log level (trace, debug, info, warn, error) + #[arg(long, default_value = "info")] + log_level: String, + + /// Auto-load a video URL or file path on startup (for testing) + #[arg(long)] + load: Option, +} + +/// Application state +#[derive(Default)] +struct AppState { + /// Screen rect from NMS mod + screen_rect: Option, + /// Last update timestamp + last_update_ms: u64, + /// Video player + video_player: Option, +} + +fn main() -> Result<()> { + let args = Args::parse(); + + // Initialize logging + let filter = format!("nms_video_daemon={},itk={}", args.log_level, args.log_level); + let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(&filter)); + tracing_subscriber::fmt().with_env_filter(env_filter).init(); + + info!("NMS Cockpit Video Player Daemon starting"); + info!(mod_channel = %args.mod_channel, client_channel = %args.client_channel); + + if args.multiplayer { + info!(port = args.sync_port, "Multiplayer sync enabled"); + } + + let state = Arc::new(RwLock::new(AppState::default())); + + // Initialize video player + { + let mut state = state.write().unwrap(); + state.video_player = Some(VideoPlayer::new()); + info!("Video player initialized"); + } + + // Auto-load if --load was provided + if let Some(ref url) = args.load { + info!(url = %url, "Auto-loading video"); + let state_read = state.read().unwrap(); + if let Some(ref player) = state_read.video_player { + player.load(url, 0, true); + } + } + + // Start mod listener thread (receives ScreenRect from NMS) + let mod_state = Arc::clone(&state); + let mod_channel = args.mod_channel.clone(); + let mod_handle = thread::spawn(move || { + if let Err(e) = run_mod_listener(&mod_channel, mod_state) { + error!(?e, "Mod listener failed"); + } + }); + + // Start client listener thread (receives commands from overlay) + let client_state = Arc::clone(&state); + let client_channel = args.client_channel.clone(); + let client_handle = thread::spawn(move || { + if let Err(e) = run_client_listener(&client_channel, client_state) { + error!(?e, "Client listener failed"); + } + }); + + info!("Daemon running. Press Ctrl+C to stop."); + + // Wait for threads + let _ = mod_handle.join(); + let _ = client_handle.join(); + + Ok(()) +} + +/// Run the mod IPC listener (receives ScreenRect from NMS mod) +fn run_mod_listener(channel_name: &str, state: Arc>) -> Result<()> { + info!(channel = %channel_name, "Starting mod listener"); + + let server = itk_ipc::listen(channel_name).context("Failed to create mod IPC server")?; + + loop { + info!("Waiting for NMS mod connection..."); + + match server.accept() { + Ok(channel) => { + info!("NMS mod connected"); + handle_mod_connection(channel, Arc::clone(&state)); + }, + Err(e) => { + warn!(?e, "Failed to accept mod connection"); + thread::sleep(Duration::from_secs(1)); + }, + } + } +} + +/// Handle a connected NMS mod +fn handle_mod_connection(channel: impl IpcChannel, state: Arc>) { + loop { + match channel.recv() { + Ok(data) => { + if let Err(e) = process_mod_message(&data, &state) { + warn!(?e, "Failed to process mod message"); + } + }, + Err(itk_ipc::IpcError::ChannelClosed) => { + info!("NMS mod disconnected"); + break; + }, + Err(e) => { + warn!(?e, "Error receiving from mod"); + break; + }, + } + } +} + +/// Process a message from the NMS mod. +/// +/// SECURITY: Data from the injected mod is treated as UNTRUSTED and validated. +fn process_mod_message(data: &[u8], state: &Arc>) -> Result<()> { + let header = itk_protocol::decode_header(data)?; + + match header.msg_type { + MessageType::ScreenRect => { + let (_, rect): (_, ScreenRect) = decode(data)?; + + // SECURITY: Full validation of untrusted data from mod + validate_screen_rect(&rect)?; + + debug!( + x = rect.x, + y = rect.y, + w = rect.width, + h = rect.height, + "Updated screen rect" + ); + + let mut state = state.write().unwrap(); + state.screen_rect = Some(rect); + state.last_update_ms = itk_sync::now_ms(); + }, + other => { + warn!(?other, "Unexpected message type from mod"); + }, + } + + Ok(()) +} + +/// Run the client IPC listener (receives commands from overlay) +fn run_client_listener(channel_name: &str, state: Arc>) -> Result<()> { + info!(channel = %channel_name, "Starting client listener"); + + let server = itk_ipc::listen(channel_name).context("Failed to create client IPC server")?; + + loop { + info!("Waiting for client connection..."); + + match server.accept() { + Ok(channel) => { + info!("Client connected"); + handle_client_connection(channel, Arc::clone(&state)); + }, + Err(e) => { + warn!(?e, "Failed to accept client connection"); + thread::sleep(Duration::from_secs(1)); + }, + } + } +} + +/// Handle a connected client (overlay) +fn handle_client_connection(channel: impl IpcChannel, state: Arc>) { + loop { + match channel.recv() { + Ok(data) => { + if let Err(e) = process_client_message(&data, &state, &channel) { + warn!(?e, "Failed to process client message"); + } + }, + Err(itk_ipc::IpcError::ChannelClosed) => { + info!("Client disconnected"); + break; + }, + Err(e) => { + warn!(?e, "Error receiving from client"); + break; + }, + } + } +} + +/// Process a message from a client +fn process_client_message( + data: &[u8], + state: &Arc>, + channel: &impl IpcChannel, +) -> Result<()> { + let header = itk_protocol::decode_header(data)?; + + match header.msg_type { + MessageType::Ping => { + let pong = encode(MessageType::Pong, &())?; + channel.send(&pong)?; + }, + + MessageType::StateQuery => { + let (_, query): (_, StateQuery) = decode(data)?; + let response = handle_state_query(&query, state)?; + let encoded = encode(MessageType::StateResponse, &response)?; + channel.send(&encoded)?; + }, + + MessageType::VideoLoad => { + let (_, cmd): (_, VideoLoad) = decode(data)?; + info!(source = %cmd.source, "Loading video"); + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.load(&cmd.source, cmd.start_position_ms, cmd.autoplay); + } + }, + + MessageType::VideoPlay => { + let (_, _cmd): (_, VideoPlay) = decode(data)?; + debug!("Play"); + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.play(); + } + }, + + MessageType::VideoPause => { + let (_, _cmd): (_, VideoPause) = decode(data)?; + debug!("Pause"); + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.pause(); + } + }, + + MessageType::VideoSeek => { + let (_, cmd): (_, VideoSeek) = decode(data)?; + debug!(position_ms = cmd.position_ms, "Seek"); + let state = state.read().unwrap(); + if let Some(ref player) = state.video_player { + player.seek(cmd.position_ms); + } + }, + + other => { + warn!(?other, "Unexpected message type from client"); + }, + } + + Ok(()) +} + +/// Handle a state query +fn handle_state_query(query: &StateQuery, state: &Arc>) -> Result { + let state = state.read().unwrap(); + + let response = match query.query_type.as_str() { + "screen_rect" => { + if let Some(ref rect) = state.screen_rect { + StateResponse { + success: true, + data: Some(serde_json::to_string(rect)?), + error: None, + } + } else { + StateResponse { + success: false, + data: None, + error: Some("No screen rect available (NMS mod not connected)".to_string()), + } + } + }, + + "video_state" => { + if let Some(ref player) = state.video_player { + if let Some(video_state) = player.get_video_state() { + StateResponse { + success: true, + data: Some(serde_json::to_string(&video_state)?), + error: None, + } + } else { + StateResponse { + success: false, + data: None, + error: Some("No video loaded".to_string()), + } + } + } else { + StateResponse { + success: false, + data: None, + error: Some("Video player not initialized".to_string()), + } + } + }, + + _ => StateResponse { + success: false, + data: None, + error: Some(format!("Unknown query type: {}", query.query_type)), + }, + }; + + Ok(response) +} diff --git a/projects/nms-cockpit-video/daemon/src/video/audio.rs b/projects/nms-cockpit-video/daemon/src/video/audio.rs new file mode 100644 index 0000000..4a0ef3b --- /dev/null +++ b/projects/nms-cockpit-video/daemon/src/video/audio.rs @@ -0,0 +1,467 @@ +//! Audio player using ffmpeg for decoding and cpal for output. +//! +//! Opens the same source as the video decoder, finds the audio stream, +//! decodes + resamples to the output device format, and plays through cpal. + +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use cpal::{Device, SampleRate, Stream, StreamConfig}; +use ffmpeg_next::media::Type as MediaType; +use ffmpeg_next::software::resampling::Context as Resampler; +use ffmpeg_next::util::frame::audio::Audio as AudioFrame; +use ffmpeg_next::{ChannelLayout, codec, decoder, format}; +use ringbuf::HeapRb; +use ringbuf::traits::{Consumer, Observer, Producer, Split}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::mpsc::{self, Receiver, Sender}; +use std::thread::{self, JoinHandle}; +use std::time::Duration; +use tracing::{debug, error, info, warn}; + +/// Commands for the audio decode thread. +#[derive(Debug)] +pub enum AudioCommand { + Play, + Pause, + Seek { position_ms: u64 }, + SetVolume { volume: f32 }, + Stop, +} + +/// Ring buffer size: ~500ms of stereo f32 audio at 48kHz. +const RING_BUFFER_SIZE: usize = 48000 * 2 * 2; // samples * channels * 0.5s worth + +/// Audio player that decodes audio from a source and plays through system audio. +pub struct AudioPlayer { + command_tx: Sender, + decode_thread: Option>, + _stream: Stream, +} + +impl AudioPlayer { + /// Create a new audio player for the given source. + /// + /// Opens the source with ffmpeg, finds the audio stream, and starts playback. + /// Returns None if no audio stream is found or audio output is unavailable. + pub fn new(source: &str, volume: f32, autoplay: bool) -> Option { + // Initialize cpal output device + let host = cpal::default_host(); + let device = match host.default_output_device() { + Some(d) => d, + None => { + warn!("No audio output device available"); + return None; + }, + }; + + let device_name = device.name().unwrap_or_else(|_| "unknown".to_string()); + info!(device = %device_name, "Using audio output device"); + + // Get output config - prefer 48kHz stereo f32 + let config = match get_output_config(&device) { + Some(c) => c, + None => { + warn!("No suitable audio output configuration"); + return None; + }, + }; + + let sample_rate = config.sample_rate.0; + let channels = config.channels as u32; + info!(sample_rate, channels, "Audio output config"); + + // Create ring buffer + let rb = HeapRb::::new(RING_BUFFER_SIZE); + let (producer, mut consumer) = rb.split(); + + // Volume control shared with output callback + let volume_atomic = Arc::new(AtomicU32::new(volume.to_bits())); + let paused = Arc::new(AtomicBool::new(!autoplay)); + // Flush flag: when set, consumer discards samples (used during seek) + let flush_flag = Arc::new(AtomicBool::new(false)); + + // Create cpal output stream + let vol_clone = Arc::clone(&volume_atomic); + let paused_clone = Arc::clone(&paused); + let flush_clone = Arc::clone(&flush_flag); + let stream = device + .build_output_stream( + &config, + move |data: &mut [f32], _: &cpal::OutputCallbackInfo| { + let vol = f32::from_bits(vol_clone.load(Ordering::Relaxed)); + let is_paused = paused_clone.load(Ordering::Relaxed); + + // If flush requested, discard buffered samples + if flush_clone.load(Ordering::Relaxed) { + consumer.skip(consumer.occupied_len()); + flush_clone.store(false, Ordering::Relaxed); + data.fill(0.0); + return; + } + + if is_paused { + data.fill(0.0); + return; + } + + let available = consumer.occupied_len(); + let to_read = data.len().min(available); + + // Read from ring buffer + if to_read > 0 { + consumer.pop_slice(&mut data[..to_read]); + // Apply volume + for sample in &mut data[..to_read] { + *sample *= vol; + } + } + + // Fill remainder with silence + if to_read < data.len() { + data[to_read..].fill(0.0); + } + }, + |err| { + error!(?err, "Audio output stream error"); + }, + None, + ) + .ok()?; + + stream.play().ok()?; + + // Create command channel + let (command_tx, command_rx) = mpsc::channel(); + + // Spawn audio decode thread + let source_owned = source.to_string(); + let vol_for_thread = Arc::clone(&volume_atomic); + let paused_for_thread = Arc::clone(&paused); + let flush_for_thread = Arc::clone(&flush_flag); + let decode_thread = thread::spawn(move || { + audio_decode_loop( + &source_owned, + sample_rate, + channels, + producer, + command_rx, + vol_for_thread, + paused_for_thread, + flush_for_thread, + autoplay, + ); + }); + + Some(Self { + command_tx, + decode_thread: Some(decode_thread), + _stream: stream, + }) + } + + /// Send a command to the audio player. + pub fn send_command(&self, cmd: AudioCommand) { + let _ = self.command_tx.send(cmd); + } + + /// Resume audio playback. + pub fn play(&self) { + self.send_command(AudioCommand::Play); + } + + /// Pause audio playback. + pub fn pause(&self) { + self.send_command(AudioCommand::Pause); + } + + /// Seek audio to a position. + pub fn seek(&self, position_ms: u64) { + self.send_command(AudioCommand::Seek { position_ms }); + } + + /// Set audio volume (0.0 - 1.0). + pub fn set_volume(&self, volume: f32) { + self.send_command(AudioCommand::SetVolume { volume }); + } +} + +impl Drop for AudioPlayer { + fn drop(&mut self) { + let _ = self.command_tx.send(AudioCommand::Stop); + if let Some(handle) = self.decode_thread.take() { + let _ = handle.join(); + } + } +} + +/// Find a suitable output config for the device. +fn get_output_config(device: &Device) -> Option { + if let Ok(configs) = device.supported_output_configs() { + let configs: Vec<_> = configs.collect(); + + // Prefer stereo f32 at 48kHz + for cfg in &configs { + if cfg.channels() == 2 && cfg.sample_format() == cpal::SampleFormat::F32 { + let rate = SampleRate(48000); + if cfg.min_sample_rate() <= rate && cfg.max_sample_rate() >= rate { + return Some(cfg.with_sample_rate(rate).into()); + } + } + } + + // Fall back to any stereo config + for cfg in &configs { + if cfg.channels() == 2 { + let rate = SampleRate(48000).clamp(cfg.min_sample_rate(), cfg.max_sample_rate()); + return Some(cfg.with_sample_rate(rate).into()); + } + } + + // Fall back to default config + device.default_output_config().ok().map(|c| c.into()) + } else { + device.default_output_config().ok().map(|c| c.into()) + } +} + +/// Type alias for the ring buffer producer. +type AudioProducer = ringbuf::HeapProd; + +/// Audio decode loop - runs in a separate thread. +#[allow(clippy::too_many_arguments)] +fn audio_decode_loop( + source: &str, + target_sample_rate: u32, + target_channels: u32, + mut producer: AudioProducer, + command_rx: Receiver, + volume_atomic: Arc, + paused: Arc, + flush_flag: Arc, + autoplay: bool, +) { + // Open the source + let mut format_ctx = match format::input(&source) { + Ok(ctx) => ctx, + Err(e) => { + error!(?e, "Failed to open audio source"); + return; + }, + }; + + // Find audio stream + let audio_stream = match format_ctx.streams().best(MediaType::Audio) { + Some(s) => s, + None => { + info!("No audio stream found in source"); + return; + }, + }; + + let audio_stream_index = audio_stream.index(); + + // Create audio decoder + let codec_params = audio_stream.parameters(); + let codec_id = codec_params.id(); + let codec = match decoder::find(codec_id) { + Some(c) => c, + None => { + error!(?codec_id, "No audio decoder found"); + return; + }, + }; + + let mut decoder_ctx = codec::context::Context::new_with_codec(codec); + if let Err(e) = decoder_ctx.set_parameters(codec_params) { + error!(?e, "Failed to set audio decoder parameters"); + return; + } + + let mut audio_decoder = match decoder_ctx.decoder().audio() { + Ok(d) => d, + Err(e) => { + error!(?e, "Failed to open audio decoder"); + return; + }, + }; + + info!( + sample_rate = audio_decoder.rate(), + channels = audio_decoder.channels(), + format = ?audio_decoder.format(), + "Audio decoder initialized" + ); + + // We'll create the resampler lazily after the first frame is decoded, + // since some codecs don't report format info until then. + let mut resampler: Option = None; + let mut audio_frame = AudioFrame::empty(); + let mut resampled_frame = AudioFrame::empty(); + + // Set initial state + if !autoplay { + paused.store(true, Ordering::Relaxed); + } + + let mut running = true; + + while running { + // Check for commands (non-blocking) + while let Ok(cmd) = command_rx.try_recv() { + match cmd { + AudioCommand::Play => { + debug!("Audio: play"); + paused.store(false, Ordering::Relaxed); + }, + AudioCommand::Pause => { + debug!("Audio: pause"); + paused.store(true, Ordering::Relaxed); + }, + AudioCommand::Seek { position_ms } => { + debug!(position_ms, "Audio: seek"); + let timestamp_us = (position_ms as i64) * 1000; + if let Err(e) = format_ctx.seek(timestamp_us, ..timestamp_us) { + warn!(?e, "Audio seek failed"); + } + audio_decoder.flush(); + // Signal consumer to discard buffered samples + flush_flag.store(true, Ordering::Release); + }, + AudioCommand::SetVolume { volume } => { + volume_atomic.store(volume.to_bits(), Ordering::Relaxed); + }, + AudioCommand::Stop => { + debug!("Audio: stop"); + running = false; + }, + } + } + + if !running { + break; + } + + // If paused, just sleep and poll commands + if paused.load(Ordering::Relaxed) { + thread::sleep(Duration::from_millis(50)); + continue; + } + + // If ring buffer is nearly full, wait a bit + if producer.vacant_len() < 4096 { + thread::sleep(Duration::from_millis(5)); + continue; + } + + // Try to receive a decoded frame + match audio_decoder.receive_frame(&mut audio_frame) { + Ok(()) => { + // Create resampler if needed + if resampler.is_none() { + let target_layout = if target_channels == 1 { + ChannelLayout::MONO + } else { + ChannelLayout::STEREO + }; + + match Resampler::get( + audio_frame.format(), + audio_frame.channel_layout(), + audio_frame.rate(), + ffmpeg_next::format::Sample::F32(ffmpeg_next::format::sample::Type::Packed), + target_layout, + target_sample_rate, + ) { + Ok(r) => { + info!( + src_rate = audio_frame.rate(), + dst_rate = target_sample_rate, + src_channels = audio_frame.channels(), + dst_channels = target_channels, + "Audio resampler created" + ); + resampler = Some(r); + }, + Err(e) => { + error!(?e, "Failed to create audio resampler"); + running = false; + continue; + }, + } + } + + // Resample the frame + if let Some(ref mut ctx) = resampler { + match ctx.run(&audio_frame, &mut resampled_frame) { + Ok(_delay) => { + // Extract f32 samples from resampled frame + let data = resampled_frame.data(0); + let sample_count = resampled_frame.samples() * target_channels as usize; + let samples: &[f32] = unsafe { + std::slice::from_raw_parts( + data.as_ptr() as *const f32, + sample_count, + ) + }; + + // Write to ring buffer, blocking briefly if full to + // maintain A/V sync (instead of dropping samples) + let mut offset = 0; + let mut retries = 0; + while offset < samples.len() && running { + let written = producer.push_slice(&samples[offset..]); + offset += written; + if offset < samples.len() { + retries += 1; + if retries > 50 { + // Timeout: consumer stalled, drop remaining + warn!( + "Audio ring buffer blocked too long, dropped {} samples", + samples.len() - offset + ); + break; + } + // Brief sleep to let consumer drain + std::thread::sleep(Duration::from_millis(1)); + } + } + }, + Err(e) => { + warn!(?e, "Audio resample failed"); + }, + } + } + }, + Err(ffmpeg_next::Error::Other { errno }) if errno == ffmpeg_next::error::EAGAIN => { + // Need more input packets - fall through to read next packet + }, + Err(ffmpeg_next::Error::Eof) => { + info!("Audio stream ended"); + running = false; + continue; + }, + Err(e) => { + warn!(?e, "Audio decode error"); + running = false; + continue; + }, + } + + // Read the next packet + match format_ctx.packets().next() { + Some((stream, packet)) => { + if stream.index() == audio_stream_index + && let Err(e) = audio_decoder.send_packet(&packet) + { + warn!(?e, "Failed to send audio packet"); + } + }, + None => { + // End of stream - flush decoder + let _ = audio_decoder.send_eof(); + }, + } + } + + info!("Audio decode thread exiting"); +} diff --git a/projects/nms-cockpit-video/daemon/src/video/mod.rs b/projects/nms-cockpit-video/daemon/src/video/mod.rs new file mode 100644 index 0000000..2c81680 --- /dev/null +++ b/projects/nms-cockpit-video/daemon/src/video/mod.rs @@ -0,0 +1,14 @@ +//! Video playback subsystem for the daemon. +//! +//! This module handles video decoding, frame output to shared memory, +//! and playback state management. + +mod audio; +mod player; +mod state; + +pub use player::VideoPlayer; + +// Re-export state types for external use +#[allow(unused_imports)] +pub use state::{PlayerCommand, PlayerState, VideoInfo}; diff --git a/projects/nms-cockpit-video/daemon/src/video/player.rs b/projects/nms-cockpit-video/daemon/src/video/player.rs new file mode 100644 index 0000000..c9f4e6e --- /dev/null +++ b/projects/nms-cockpit-video/daemon/src/video/player.rs @@ -0,0 +1,659 @@ +//! Video player implementation with real ffmpeg decoding. + +use super::audio::AudioPlayer; +use super::state::{PlayerCommand, PlayerState, VideoInfo}; +use itk_protocol::{VideoMetadata, VideoState}; +use itk_video::{DecodedFrame, FrameWriter, StreamSource, VideoDecoder}; +use std::sync::mpsc::{self, Receiver, Sender}; +use std::sync::{Arc, Mutex}; +use std::thread::{self, JoinHandle}; +use std::time::{Duration, Instant}; +use tracing::{debug, error, info, warn}; + +/// Lock the player state mutex, recovering from poisoning. +/// +/// A poisoned mutex means a thread panicked while holding the lock. +/// We recover by taking the inner data since partial state is better +/// than crashing the daemon. +fn lock_state(mutex: &Mutex) -> std::sync::MutexGuard<'_, PlayerState> { + mutex.lock().unwrap_or_else(|poisoned| { + warn!("Player state mutex was poisoned, recovering"); + poisoned.into_inner() + }) +} + +/// Default output width for video frames. +const DEFAULT_WIDTH: u32 = 1280; +/// Default output height for video frames. +const DEFAULT_HEIGHT: u32 = 720; +/// Shared memory name for video frames. +const SHMEM_NAME: &str = "itk_video_frames"; + +/// Video player that decodes video and writes frames to shared memory. +pub struct VideoPlayer { + /// Current player state. + state: Arc>, + /// Command sender for the decode thread. + command_tx: Sender, + /// Handle to the decode thread. + decode_thread: Option>, +} + +impl VideoPlayer { + /// Create a new video player. + pub fn new() -> Self { + let (command_tx, command_rx) = mpsc::channel(); + let state = Arc::new(Mutex::new(PlayerState::Idle)); + + // Start the decode thread + let state_clone = Arc::clone(&state); + let decode_thread = thread::spawn(move || { + decode_loop(state_clone, command_rx); + }); + + Self { + state, + command_tx, + decode_thread: Some(decode_thread), + } + } + + /// Send a command to the video player. + pub fn send_command(&self, cmd: PlayerCommand) { + if let Err(e) = self.command_tx.send(cmd) { + error!(?e, "Failed to send command to video player"); + } + } + + /// Load a video from a source. + pub fn load(&self, source: &str, start_position_ms: u64, autoplay: bool) { + self.send_command(PlayerCommand::Load { + source: source.to_string(), + start_position_ms, + autoplay, + }); + } + + /// Start or resume playback. + pub fn play(&self) { + self.send_command(PlayerCommand::Play); + } + + /// Pause playback. + pub fn pause(&self) { + self.send_command(PlayerCommand::Pause); + } + + /// Seek to a position. + pub fn seek(&self, position_ms: u64) { + self.send_command(PlayerCommand::Seek { position_ms }); + } + + /// Stop playback and unload. + #[allow(dead_code)] + pub fn stop(&self) { + self.send_command(PlayerCommand::Stop); + } + + /// Get the current player state. + #[allow(dead_code)] + pub fn state(&self) -> PlayerState { + lock_state(&self.state).clone() + } + + /// Get the current video state for protocol messages. + pub fn get_video_state(&self) -> Option { + let state = lock_state(&self.state); + match &*state { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => Some(VideoState { + content_id: info.content_id.clone(), + position_ms: state.position_ms(), + duration_ms: info.duration_ms, + is_playing: state.is_playing(), + is_buffering: matches!(*state, PlayerState::Buffering { .. }), + playback_rate: info.playback_rate, + volume: info.volume, + }), + _ => None, + } + } + + /// Get video metadata for protocol messages. + #[allow(dead_code)] + pub fn get_metadata(&self) -> Option { + let state = lock_state(&self.state); + state.video_info().map(|info| VideoMetadata { + content_id: info.content_id.clone(), + width: info.width, + height: info.height, + duration_ms: info.duration_ms, + fps: info.fps, + codec: info.codec.clone(), + is_live: info.is_live, + title: info.title.clone(), + }) + } +} + +impl Default for VideoPlayer { + fn default() -> Self { + Self::new() + } +} + +impl Drop for VideoPlayer { + fn drop(&mut self) { + // Signal the decode thread to stop + let _ = self.command_tx.send(PlayerCommand::Stop); + + // Wait for the thread to finish + if let Some(handle) = self.decode_thread.take() { + let _ = handle.join(); + } + } +} + +/// Resolve a source string to a StreamSource, handling YouTube URLs. +fn resolve_source(source: &str) -> Result { + let stream_source = StreamSource::from_string(source); + + if stream_source.is_youtube() { + #[cfg(feature = "youtube")] + { + info!(url = %source, "Extracting YouTube direct URL via yt-dlp"); + let output = std::process::Command::new("yt-dlp") + .args([ + "-f", + // Progressive formats (muxed audio+video) first, then DASH with audio. + // Format 22 = 720p progressive MP4 (h264+aac) + // Format 18 = 360p progressive MP4 (h264+aac) + // DASH fallback requests bestaudio paired, so audio player gets a URL. + "22/18/bestvideo[height<=720][vcodec^=avc1]+bestaudio/bestvideo[height<=720]+bestaudio", + "-g", + "--no-warnings", + "--no-playlist", + source, + ]) + .output() + .map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + "yt-dlp not found in PATH. Install: pip install yt-dlp".to_string() + } else { + format!("Failed to run yt-dlp: {e}") + } + })?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(format!("yt-dlp failed: {}", stderr.trim())); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let lines: Vec<&str> = stdout.trim().lines().collect(); + + if lines.is_empty() { + return Err("yt-dlp returned no URLs".to_string()); + } + + let video_url = lines[0].to_string(); + + if lines.len() >= 2 { + // DASH format selected: separate video + audio URLs. + // Return as AudioUrl variant so the audio player uses the correct stream. + info!("YouTube DASH streams: video + audio URLs extracted"); + return Ok(StreamSource::UrlWithAudio { + video: video_url, + audio: lines[1].to_string(), + }); + } + + info!("YouTube progressive stream extracted (muxed audio+video)"); + return Ok(StreamSource::Url(video_url)); + } + #[cfg(not(feature = "youtube"))] + { + return Err( + "YouTube support not enabled (compile with --features youtube)".to_string(), + ); + } + } + + Ok(stream_source) +} + +/// Decode thread state holding the active decoder and frame writer. +struct DecodeContext { + decoder: VideoDecoder, + writer: FrameWriter, + info: VideoInfo, + /// The PTS (ms) at which current playback segment started. + base_pts_ms: u64, + /// Wall-clock time when current playback segment started. + playback_start: Instant, + /// Audio player (None if source has no audio or output unavailable). + audio: Option, + /// Frame decoded too early (>100ms ahead), held until its scheduled time. + pending_frame: Option, +} + +/// Main decode loop that runs in a separate thread. +fn decode_loop(state: Arc>, command_rx: Receiver) { + info!("Video decode thread started"); + + let mut ctx: Option = None; + + loop { + let is_playing = lock_state(&state).is_playing(); + let timeout = if is_playing { + Duration::from_millis(1) // Fast polling during playback + } else { + Duration::from_millis(100) // Slower when idle/paused + }; + + match command_rx.recv_timeout(timeout) { + Ok(cmd) => { + debug!(?cmd, "Received video command"); + match cmd { + PlayerCommand::Load { + source, + start_position_ms, + autoplay, + } => { + ctx = handle_load(&state, &source, start_position_ms, autoplay); + }, + PlayerCommand::Play => { + handle_play(&state, &mut ctx); + }, + PlayerCommand::Pause => { + handle_pause(&state, &ctx); + }, + PlayerCommand::Seek { position_ms } => { + handle_seek(&state, &mut ctx, position_ms); + }, + PlayerCommand::SetRate { rate } => { + handle_set_rate(&state, rate); + }, + PlayerCommand::SetVolume { volume } => { + handle_set_volume(&state, &ctx, volume); + }, + PlayerCommand::Stop => { + info!("Video decode thread stopping"); + drop(ctx.take()); + *lock_state(&state) = PlayerState::Idle; + break; + }, + } + }, + Err(mpsc::RecvTimeoutError::Timeout) => { + // Decode next frame if playing + if is_playing && let Some(decode_ctx) = &mut ctx { + decode_next_frame(decode_ctx, &state); + } + }, + Err(mpsc::RecvTimeoutError::Disconnected) => { + info!("Command channel disconnected, stopping decode thread"); + break; + }, + } + } +} + +/// Decode and output the next frame with PTS-based pacing. +/// +/// If a previously-decoded frame was too early (>100ms ahead), it is held in +/// `pending_frame` and retried on the next call instead of being dropped. +fn decode_next_frame(ctx: &mut DecodeContext, state: &Arc>) { + // Use pending frame if available, otherwise decode a new one + let frame_result = if ctx.pending_frame.is_some() { + Ok(ctx.pending_frame.take()) + } else { + ctx.decoder.next_frame() + }; + + match frame_result { + Ok(Some(frame)) => { + // PTS-based frame pacing: sleep until the correct wall-clock time + let target_elapsed_ms = frame.pts_ms.saturating_sub(ctx.base_pts_ms); + let wall_elapsed = ctx.playback_start.elapsed(); + let target_wall = Duration::from_millis(target_elapsed_ms); + + if target_wall > wall_elapsed { + let sleep_dur = target_wall - wall_elapsed; + // Cap sleep to 100ms to stay responsive to commands + if sleep_dur < Duration::from_millis(100) { + thread::sleep(sleep_dur); + } else { + // Frame is too early - hold it for the next iteration + thread::sleep(Duration::from_millis(100)); + ctx.pending_frame = Some(frame); + return; + } + } + + // Write frame to shared memory + match ctx.writer.write_frame(&frame) { + Ok(_written) => { + // Update position in state + let mut s = lock_state(state); + if let PlayerState::Playing { position_ms, .. } = &mut *s { + *position_ms = frame.pts_ms; + } + }, + Err(e) => { + warn!(?e, "Failed to write frame to shared memory"); + }, + } + }, + Ok(None) => { + // End of stream + info!("Video playback complete (end of stream)"); + let mut s = lock_state(state); + if let PlayerState::Playing { info, .. } = s.clone() { + *s = PlayerState::Paused { + info, + position_ms: ctx.info.duration_ms, + }; + } + }, + Err(e) => { + error!(?e, "Decode error"); + *lock_state(state) = PlayerState::Error { + message: format!("Decode error: {e}"), + }; + }, + } +} + +/// Handle a load command - creates decoder and frame writer. +fn handle_load( + state: &Arc>, + source: &str, + start_position_ms: u64, + autoplay: bool, +) -> Option { + info!(source = %source, start_ms = start_position_ms, autoplay, "Loading video"); + + // Set loading state + *lock_state(state) = PlayerState::Loading { + source: source.to_string(), + }; + + // Resolve source (handles YouTube extraction) + let stream_source = match resolve_source(source) { + Ok(s) => s, + Err(e) => { + error!(error = %e, "Failed to resolve video source"); + *lock_state(state) = PlayerState::Error { + message: format!("Source resolution failed: {e}"), + }; + return None; + }, + }; + + // Get the audio source path: use separate audio URL for DASH streams, + // or the same video URL for progressive (muxed) streams. + let audio_source_path = stream_source + .audio_url() + .unwrap_or_else(|| stream_source.as_ffmpeg_input()) + .to_string(); + + // Create the decoder + let mut decoder = match VideoDecoder::with_size(stream_source, DEFAULT_WIDTH, DEFAULT_HEIGHT) { + Ok(d) => d, + Err(e) => { + error!(?e, "Failed to create video decoder"); + *lock_state(state) = PlayerState::Error { + message: format!("Decoder init failed: {e}"), + }; + return None; + }, + }; + + // Create the frame writer (shared memory) + let mut writer = match FrameWriter::create(SHMEM_NAME, DEFAULT_WIDTH, DEFAULT_HEIGHT) { + Ok(w) => w, + Err(e) => { + error!(?e, "Failed to create frame writer"); + *lock_state(state) = PlayerState::Error { + message: format!("Frame writer failed: {e}"), + }; + return None; + }, + }; + + // Build video info from decoder metadata + let content_id = format!("{:016x}", hash_string(source)); + let info = VideoInfo { + content_id: content_id.clone(), + width: decoder.output_width(), + height: decoder.output_height(), + duration_ms: decoder.duration_ms().unwrap_or(0), + fps: decoder.fps().unwrap_or(30.0) as f32, + codec: "h264".to_string(), // ffmpeg-next doesn't easily expose codec name string + is_live: source.contains(".m3u8") || source.contains("/live/"), + title: None, + playback_rate: 1.0, + volume: 1.0, + }; + + writer.set_content_id(&content_id); + writer.set_duration_ms(info.duration_ms); + + // Create audio player (opens the same source independently) + let audio = AudioPlayer::new(&audio_source_path, info.volume, autoplay); + if audio.is_some() { + info!("Audio player initialized"); + } else { + info!("No audio available for this source"); + } + + // Seek to start position if needed + if start_position_ms > 0 { + if let Err(e) = decoder.seek(start_position_ms) { + warn!( + ?e, + "Failed to seek to start position, starting from beginning" + ); + } + if let Some(audio_player) = &audio { + audio_player.seek(start_position_ms); + } + } + + info!( + width = info.width, + height = info.height, + duration_ms = info.duration_ms, + fps = info.fps, + "Video loaded successfully" + ); + + let ctx = DecodeContext { + decoder, + writer, + info: info.clone(), + base_pts_ms: start_position_ms, + playback_start: Instant::now(), + audio, + pending_frame: None, + }; + + if autoplay { + *lock_state(state) = PlayerState::Playing { + info, + position_ms: start_position_ms, + started_at: Instant::now(), + }; + } else { + *lock_state(state) = PlayerState::Paused { + info, + position_ms: start_position_ms, + }; + } + + Some(ctx) +} + +/// Handle a play command. +fn handle_play(state: &Arc>, ctx: &mut Option) { + let mut s = lock_state(state); + if let PlayerState::Paused { info, position_ms } = s.clone() { + info!(position_ms, "Resuming playback"); + + // Update decode context timing + if let Some(decode_ctx) = ctx { + decode_ctx.base_pts_ms = position_ms; + decode_ctx.playback_start = Instant::now(); + + // Resume audio + if let Some(audio) = &decode_ctx.audio { + audio.play(); + } + } + + *s = PlayerState::Playing { + info, + position_ms, + started_at: Instant::now(), + }; + } +} + +/// Handle a pause command. +fn handle_pause(state: &Arc>, ctx: &Option) { + let mut s = lock_state(state); + if let PlayerState::Playing { info, .. } = s.clone() { + // Use the frame writer's last PTS as the accurate position + let current_pos = ctx.as_ref().map(|c| c.writer.last_pts_ms()).unwrap_or(0); + info!(position_ms = current_pos, "Pausing playback"); + + // Pause audio + if let Some(ctx) = ctx + && let Some(audio) = &ctx.audio + { + audio.pause(); + } + + *s = PlayerState::Paused { + info, + position_ms: current_pos, + }; + } +} + +/// Handle a seek command. +fn handle_seek(state: &Arc>, ctx: &mut Option, position_ms: u64) { + // Seek the decoder + if let Some(decode_ctx) = ctx { + if let Err(e) = decode_ctx.decoder.seek(position_ms) { + warn!(?e, position_ms, "Seek failed"); + return; + } + // Reset timing for the new position + decode_ctx.base_pts_ms = position_ms; + decode_ctx.playback_start = Instant::now(); + decode_ctx.pending_frame = None; // Discard stale buffered frame + + // Seek audio + if let Some(audio) = &decode_ctx.audio { + audio.seek(position_ms); + } + } + + let mut s = lock_state(state); + match s.clone() { + PlayerState::Playing { info, .. } => { + info!(position_ms, "Seeking (playing)"); + *s = PlayerState::Playing { + info, + position_ms, + started_at: Instant::now(), + }; + }, + PlayerState::Paused { info, .. } => { + info!(position_ms, "Seeking (paused)"); + *s = PlayerState::Paused { info, position_ms }; + }, + _ => { + warn!("Seek ignored - no video loaded"); + }, + } +} + +/// Handle a set rate command. +fn handle_set_rate(state: &Arc>, rate: f64) { + let mut state = lock_state(state); + if let Some(info) = state.video_info().cloned() { + let mut new_info = info; + new_info.playback_rate = rate.clamp(0.25, 4.0); + debug!(rate = new_info.playback_rate, "Set playback rate"); + + match &mut *state { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => { + *info = new_info; + }, + _ => {}, + } + } +} + +/// Handle a set volume command. +fn handle_set_volume(state: &Arc>, ctx: &Option, volume: f32) { + let clamped = volume.clamp(0.0, 1.0); + + // Update audio player volume + if let Some(decode_ctx) = ctx + && let Some(audio) = &decode_ctx.audio + { + audio.set_volume(clamped); + } + + let mut state = lock_state(state); + if let Some(info) = state.video_info().cloned() { + let mut new_info = info; + new_info.volume = clamped; + debug!(volume = new_info.volume, "Set volume"); + + match &mut *state { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => { + *info = new_info; + }, + _ => {}, + } + } +} + +/// Simple hash function for content IDs. +fn hash_string(s: &str) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + s.hash(&mut hasher); + hasher.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_player_creation() { + let player = VideoPlayer::new(); + assert!(matches!(player.state(), PlayerState::Idle)); + } + + #[test] + fn test_hash_string() { + let hash1 = hash_string("test"); + let hash2 = hash_string("test"); + let hash3 = hash_string("other"); + assert_eq!(hash1, hash2); + assert_ne!(hash1, hash3); + } +} diff --git a/projects/nms-cockpit-video/daemon/src/video/state.rs b/projects/nms-cockpit-video/daemon/src/video/state.rs new file mode 100644 index 0000000..da4bf10 --- /dev/null +++ b/projects/nms-cockpit-video/daemon/src/video/state.rs @@ -0,0 +1,148 @@ +//! Video player state and commands. + +use std::time::Instant; + +/// Commands that can be sent to the video player. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub enum PlayerCommand { + /// Load a video from a URL or file path. + Load { + source: String, + start_position_ms: u64, + autoplay: bool, + }, + /// Start or resume playback. + Play, + /// Pause playback. + Pause, + /// Seek to a position in milliseconds. + Seek { position_ms: u64 }, + /// Set the playback rate (1.0 = normal). + SetRate { rate: f64 }, + /// Set volume (0.0 - 1.0). + SetVolume { volume: f32 }, + /// Stop playback and unload the video. + Stop, +} + +/// Video player state. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub enum PlayerState { + /// No video loaded. + Idle, + /// Loading a video. + Loading { source: String }, + /// Video is playing. + Playing { + info: VideoInfo, + position_ms: u64, + started_at: Instant, + }, + /// Video is paused. + Paused { info: VideoInfo, position_ms: u64 }, + /// Buffering (waiting for data). + Buffering { + info: VideoInfo, + target_position_ms: u64, + }, + /// Playback error. + Error { message: String }, +} + +#[allow(dead_code)] +impl PlayerState { + /// Check if the player is currently playing. + pub fn is_playing(&self) -> bool { + matches!(self, PlayerState::Playing { .. }) + } + + /// Check if the player is paused. + pub fn is_paused(&self) -> bool { + matches!(self, PlayerState::Paused { .. }) + } + + /// Check if a video is loaded (playing, paused, or buffering). + pub fn has_video(&self) -> bool { + matches!( + self, + PlayerState::Playing { .. } + | PlayerState::Paused { .. } + | PlayerState::Buffering { .. } + ) + } + + /// Get the current video info, if available. + pub fn video_info(&self) -> Option<&VideoInfo> { + match self { + PlayerState::Playing { info, .. } + | PlayerState::Paused { info, .. } + | PlayerState::Buffering { info, .. } => Some(info), + _ => None, + } + } + + /// Get the current position in milliseconds. + pub fn position_ms(&self) -> u64 { + match self { + PlayerState::Playing { + position_ms, + started_at, + .. + } => { + // Calculate current position based on elapsed time + let elapsed_ms = started_at.elapsed().as_millis() as u64; + position_ms.saturating_add(elapsed_ms) + }, + PlayerState::Paused { position_ms, .. } => *position_ms, + PlayerState::Buffering { + target_position_ms, .. + } => *target_position_ms, + _ => 0, + } + } +} + +/// Information about the currently loaded video. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct VideoInfo { + /// Content identifier (URL or file path hash). + pub content_id: String, + /// Video width in pixels. + pub width: u32, + /// Video height in pixels. + pub height: u32, + /// Duration in milliseconds (0 if unknown/live). + pub duration_ms: u64, + /// Frames per second. + pub fps: f32, + /// Codec name. + pub codec: String, + /// Whether this is a live stream. + pub is_live: bool, + /// Title from metadata. + pub title: Option, + /// Playback rate (1.0 = normal). + pub playback_rate: f64, + /// Volume (0.0 - 1.0). + pub volume: f32, +} + +impl Default for VideoInfo { + fn default() -> Self { + Self { + content_id: String::new(), + width: 1280, + height: 720, + duration_ms: 0, + fps: 30.0, + codec: String::new(), + is_live: false, + title: None, + playback_rate: 1.0, + volume: 1.0, + } + } +} diff --git a/projects/nms-cockpit-video/docs/nms-reverse-engineering.md b/projects/nms-cockpit-video/docs/nms-reverse-engineering.md new file mode 100644 index 0000000..a5cd9e3 --- /dev/null +++ b/projects/nms-cockpit-video/docs/nms-reverse-engineering.md @@ -0,0 +1,373 @@ +# NMS Reverse Engineering Findings + +## Session Info +- **NMS.exe PID**: 70572 +- **NMS.exe Base**: 0x7FF6DF260000 +- **NMS.exe Size**: ~112 MB +- **Game Version**: Current as of 2026-01-23 +- **Player State**: Stationary in ship cockpit + +## Key Structures + +### cGcCameraManager + +**RTTI Chain:** +- Type Descriptor: `NMS.exe+0x4B87380` (string: `.?AVcGcCameraManager@@`) +- Complete Object Locator: `NMS.exe+0x4628A70` +- Vtable: `NMS.exe+0x4528588` (RVA) +- Inherits from: `cTkCameraManager` + +**Singleton Access:** +- Global pointer at: `NMS.exe+0x56666B0` +- Points to heap-allocated instance (current session: 0x2106B20CE80) +- Object size: 0x3340 bytes (from vtable function analysis) + +**Pointer Chain (validated):** +``` +NMS.exe+0x56666B0 → [deref] → cGcCameraManager* + +0x130 → View Matrix (4x4 float, row-major) + +0x1D0 → FoV (float, degrees) + +0x40 → World Position (vec4) +``` + +### cTkCamera (Embedded Camera Object) + +The camera manager embeds `cTkCamera` instances (engine-level camera class). + +**RTTI:** `.?AVcTkCamera@@` +**Vtable:** `NMS.exe+0x45286C0` (RVA) + +**cTkCamera Layout (size: 0xC0 bytes):** +| Offset | Type | Description | Example | +|--------|------|-------------|---------| +| +0x00 | ptr | Vtable (cTkCamera) | | +| +0x08 | - | Padding (zeros) | | +| +0x10 | mat4x4 | View/World Matrix | see below | +| +0x50 | vec4 | World Position | | +| +0x60 | mat4x4 | View Matrix (copy) | | +| +0xA0 | vec4 | World Position (copy) | | +| +0xB0 | float | FoV (degrees) | 75.0 | +| +0xB4 | float | Aspect/multiplier | 1.5 or 1.0 | +| +0xB8 | u32 | Active flag | 1 | + +**Embedded cTkCamera instances in Camera Manager:** +- Camera Manager + 0x120 = First cTkCamera (ACTIVE, FoV=75, aspect=1.5) +- Camera Manager + 0x1E0 = Second cTkCamera (FoV=75, aspect=1.0) +- Camera Manager + 0x2A0 = Third cTkCamera (interpolated state?) + +### Camera Manager Object Layout + +| Offset | Type | Description | Example Value | +|--------|------|-------------|---------------| +| +0x00 | ptr | Vtable pointer | cGcCameraManager vtable | +| +0x08 | - | Base class padding (zeros) | 0 | +| +0x40 | vec4 | World position (x, y, z, sector=1024) | (-899072, -48128, -2243584, 1024) | +| +0x50 | ptr[25] | Camera behaviour pointers | 25 camera modes | +| +0x118 | u32 | **Camera mode enum** (NOT index!) | 0x10=cockpit, 0x40=on-foot | +| +0x11C | u32 | Camera behaviour count | 25 | +| +0x120 | cTkCamera | First embedded camera (ACTIVE) | | +| +0x130 | mat4x4 | **View/World Matrix** (= cTkCamera+0x10) | Current frame | +| +0x170 | vec4 | World position (= cTkCamera+0x50) | | +| +0x180 | mat4x4 | View Matrix copy (= cTkCamera+0x60) | | +| +0x1C0 | vec4 | World position copy (= cTkCamera+0xA0) | | +| +0x1D0 | float | **FoV degrees** (= cTkCamera+0xB0) | 75.0 | +| +0x1D4 | float | Aspect/multiplier (= cTkCamera+0xB4) | 1.5 | +| +0x1E0 | cTkCamera | Second embedded camera | | +| +0x2A0 | cTkCamera | Third embedded camera | | + +### View Matrix Format (at +0x130) + +Row-major 4x4 float matrix (world-to-local transform): +``` +Row 0: (right.x, right.y, right.z, 0.0) - Right vector +Row 1: (up.x, up.y, up.z, 0.0) - Up vector +Row 2: (forward.x, forward.y, forward.z, 0.0) - Forward vector +Row 3: (pos.x, pos.y, pos.z, 1.0) - Local camera position +``` + +**Current values (cockpit, stationary):** +``` +Right: (-0.2428, 0.4305, 0.8693, 0.0) +Up: ( 0.2465, 0.8941, -0.3740, 0.0) +Forward: (-0.9383, 0.1235, -0.3232, 0.0) +Position: (1425.42, -252.66, 803.52, 1.0) ← LOCAL coords +``` + +The position in row 3 is LOCAL coordinates (relative to sector/chunk origin). +The WORLD position at +0x40 gives absolute coordinates. + +### Camera Behaviour Sub-Objects (ptr[0..24] at +0x50) + +25 camera behaviour objects, indexed from 0. Active mode stored at +0x118. + +**When in cockpit: Mode = 0x10 (16)** + +| Index | Class | Notes | +|-------|-------|-------| +| 0 | `cGcCameraBehaviourFly` | Flight camera | +| 1 | `cTkCameraBehaviourInterpolate` | Camera transitions (engine-level) | +| 2 | `cGcCameraBehaviourOffset` | Offset camera | +| 3 | `cGcCameraBehaviourCharacter` | Character view | +| 4 | `cGcCameraBehaviourFirstPerson` | First person (on foot) | +| 5 | `cGcCameraBehaviourThirdPerson` | Generic 3rd person | +| 6 | `cGcCameraBehaviourPlayerThirdPerson` | Player 3rd person | +| 7 | `cGcCameraBehaviourGalacticTransition` | Galaxy map transition | +| 8 | `cGcCameraBehaviourGalacticNavigation` | Galaxy map navigation | +| 9 | `cGcCameraBehaviourGalacticLookAt` | Galaxy map look-at | +| 10 | `cGcCameraBehaviourInteraction` | NPC interaction | +| 11 | `cGcCameraBehaviourLookAt` | Generic look-at | +| 12 | `cGcCameraBehaviourAerialView` | Aerial/overhead view | +| 13 | `cGcCameraBehaviourScreenshot` | Screenshot mode | +| 14 | `cGcCameraBehaviourPhotoMode` | Photo mode | +| 15 | `cGcCameraBehaviourAmbient` | Ambient/idle camera | +| 16 | `cGcCameraBehaviourModelView` | **COCKPIT CAMERA** | +| 17 | `cGcCameraBehaviourAnimation` | Cutscene/animation | +| 18 | `cGcCameraBehaviourFollowTarget` | Follow target | +| 19 | `cGcCameraBehaviourShipWarp` | Ship warp effect | +| 20 | `cGcCameraBehaviourCockpitTransition` | Enter/exit cockpit | +| 21 | `cGcCameraBehaviourBuildingMode` | Building placement | +| 22 | `cGcCameraBehaviourFocusBuildingMode` | Building focus | +| 23 | `cGcCameraBehaviourOrbitBuildingMode` | Building orbit | +| 24 | `cGcCameraBehaviourFreighterWarp` | Freighter warp effect | + +**Key modes for mod:** +- Index 16 (ModelView) = cockpit camera, used when mode == 0x10 +- Index 19 (ShipWarp) = embedded at ModelView+0x90, used during warp +- Index 20 (CockpitTransition) = transition in/out of cockpit + +### cGcCameraBehaviourModelView (Cockpit Camera, ptr[16]) + +**Vtable:** 0x7FF6E37EAD58 (`.?AVcGcCameraBehaviourModelView@@`) + +**Layout:** +| Offset | Type | Description | Value | +|--------|------|-------------|-------| +| +0x00 | ptr | Vtable | | +| +0x28 | float | Unknown | 2.0 | +| +0x2C | float | Unknown | 3.0 | +| +0x30 | float | Unknown | 1.0 | +| +0x34 | float | Unknown | 2.0 | +| +0x36 | u8[2] | Flags | (1, 1) | +| +0x38 | u32[2] | Masks? | (0x7FFFF, 0x7FFFF) | +| +0x60 | float | Scale? | 1.0 | +| +0x68 | float | Scale? | 1.0 | +| +0x90 | ptr | Embedded vtable (ShipWarp) | | + +**Camera Configuration Parameters (at +0x120):** +| Offset | Value | Possible Meaning | +|--------|-------|------------------| +| +0x120 | -4.0 | Min distance offset? | +| +0x124 | 0.5 | Interpolation speed? | +| +0x128 | 3.0 | Default distance? | +| +0x12C | 20.0 | Max distance? | +| +0x130 | 5.0 | Look distance? | +| +0x138 | 1.5 | Aspect/scale? | +| +0x140 | 0.5 | Sensitivity? | +| +0x148 | 0.1 | Min speed? | +| +0x14C | 1.5 | Speed multiplier? | +| +0x168 | 30.0 | Max angle? | +| +0x170 | 5.0 | Transition speed? | +| +0x1C0 | 50.0 | Far distance? | +| +0x1C8 | 15.0 | Medium distance? | +| +0x1F0 | 30.0 | Look range? | +| +0x1F4 | -15.0 | Vertical offset? | + +### cGcApplication + +**RTTI:** `.?AVcGcApplication@@` +- Type Descriptor: `NMS.exe+0x4C08EC8` +- COL: `NMS.exe+0x0468BA30` +- Vtable: `NMS.exe+0x0518C858` +- **Static global object** at: `NMS.exe+0x068F7460` (NOT heap-allocated!) +- Does NOT directly contain camera manager pointer + +### Global Pointer Context + +The camera manager global pointer is part of a global manager registry: +``` +NMS.exe+0x56666A4: count = 4 +NMS.exe+0x56666A8: ptr to unknown manager A +NMS.exe+0x56666B0: ptr to cGcCameraManager ← TARGET +NMS.exe+0x56666B8: ptr to unknown manager C +NMS.exe+0x56666C0: [360 (0x168)] +``` + +Float data near the registry (at NMS.exe+0x5666690): +- 0.1169, 0.5649, 1.0, 0.9 (possibly rendering parameters) + +## Cross-Mode Comparison (VERIFIED LIVE) + +The cTkCamera at +0x120 ALWAYS reflects the active rendering camera: + +| Field | Cockpit (mode=0x10) | On-Foot (mode=0x40) | Notes | +|-------|---------------------|---------------------|-------| +| FoV (+0x1D0) | 75.0 | 70.0 | Mode-specific | +| Aspect (+0x1D4) | 1.5 | 1.0 | Cockpit uses wider FoV multiplier | +| View Matrix (+0x130) | Changes with look | Changes with look | ALWAYS live | +| Position w | 1.0 | ~0.989 | Ignore w, use xyz only | +| World Pos (+0x40) | Sector coords | Same sector | Only changes on sector boundary | +| Mode (+0x118) | 0x10 (16) | 0x40 (64) | Powers of 2 = bitmask/enum | + +**Key Insight**: For the mod, only check mode == 0x10 (cockpit) before rendering video overlay. +The view matrix and FoV at +0x130/+0x1D0 always give current frame data regardless of mode. + +## Cockpit-Related Strings + +**SpaceMap cockpit parameters (for hologram positioning):** +- `SpaceMapCockpitOffset` +- `SpaceMapCockpitScale` +- `SpaceMapCockpitScaleAdjustDropShip` +- `SpaceMapCockpitScaleAdjustFighter` +- `SpaceMapCockpitScaleAdjustScientific` +- `SpaceMapCockpitScaleAdjustShuttle` +- `SpaceMapCockpitScaleAdjustRoyal` + +**Camera strings:** +- `CameraLook`, `CameraLookX`, `CameraLookY` +- `CameraRollLeft`, `CameraRollRight` +- `CameraHeight`, `CameraDistanceFade`, `CameraRelative` + +## Module Info + +Notable loaded modules: +- Vulkan renderer (vulkan-1.dll) +- D3D12 (d3d12.dll, dxgi.dll) +- DLSS (nvngx_dlss.dll) +- OpenVR (openvr_api.dll) +- Steam (steam_api64.dll) +- PlayFab networking + +## Pattern Scanning Strategy for Mod + +### Approach 1: RTTI-Based (Most Robust, Cross-Version) +``` +1. Pattern scan NMS.exe for bytes: ".?AVcGcCameraManager@@" +2. Type Descriptor = match_addr - 0x10 (vtable+internal ptrs precede name) +3. type_desc_rva = type_desc_addr - nms_base +4. Scan .rdata for 4-byte value matching type_desc_rva (finds COL) +5. Verify COL: signature==1, pSelf matches +6. Vtable = COL_addr + sizeof(COL) = COL + 24 (COL ptr lives at vtable[-8]) +7. Scan ALL process memory for 8-byte vtable address → singleton instance +``` + +### Approach 2: Global Pointer (Fastest, Version-Specific) +``` +1. Read pointer at NMS.exe + 0x56666B0 +2. WARNING: This RVA changes with each game update! +``` + +### Approach 3: Code Signature (TODO - Medium Robustness) +- Find unique instruction sequence that accesses the global pointer +- Use wildcard bytes for the RIP-relative displacement +- More robust than Approach 2, less overhead than Approach 1 + +## Key Offsets for Mod Implementation + +```csharp +// C# offsets for Reloaded-II mod (cGcCameraManager) +const int OFFSET_WORLD_POS = 0x40; // vec4 (x, y, z, sector_size=1024) +const int OFFSET_CAM_BEHAVIOURS = 0x50; // ptr[25] camera behaviour objects +const int OFFSET_ACTIVE_CAM_IDX = 0x118; // u32, current=16 for cockpit +const int OFFSET_CAM_COUNT = 0x11C; // u32, always 25 +const int OFFSET_ACTIVE_CAMERA = 0x120; // cTkCamera embedded object +const int OFFSET_VIEW_MATRIX = 0x130; // mat4x4 row-major (= cTkCamera+0x10) +const int OFFSET_VIEW_MATRIX_PREV = 0x180;// mat4x4 (previous/copy) +const int OFFSET_FOV = 0x1D0; // float, degrees (currently 75.0) +const int OFFSET_ASPECT = 0x1D4; // float (currently 1.5) +``` + +## Mod Implementation Notes + +### Video Screen Placement +To place a video screen in the cockpit: +1. Read view matrix at +0x130 +2. Extract forward vector (row 2) and up vector (row 1) +3. Position screen at: camera_pos + forward * SCREEN_DISTANCE + up * SCREEN_HEIGHT_OFFSET +4. Orient screen to face camera (billboard or fixed orientation) +5. Scale based on desired apparent size / FoV + +### Projection Matrix +Not stored in camera manager. Near/far clip planes are NOT stored as named struct fields +(searched for "NearPlane"/"FarPlane" - only DOF and interaction-related matches found). +Near/far are likely hardcoded in the renderer or computed per-frame. + +For the overlay mod, projection is computed independently: +``` +fov_rad = FoV * PI / 180 // Read from +0x1D0 +aspect = screen_width / screen_height // Query from swapchain +near = 0.1 // Reasonable default for overlay +far = 1000.0 // Only need range for video screen quad + +proj[0][0] = 1.0 / (aspect * tan(fov_rad / 2)) +proj[1][1] = 1.0 / tan(fov_rad / 2) +proj[2][2] = far / (near - far) +proj[2][3] = -1.0 +proj[3][2] = (near * far) / (near - far) +``` + +### Ship Type Enum (eShipClass) + +String literals found at NMS.exe .rdata (contiguous at ~0x7FF6E22101A0): +``` +"Ship" // 0 - Base/generic +"Dropship" // 1 - Sentinel Interceptor +"Fighter" // 2 +"Shuttle" // 3 +"PlayerFreighter" // 4 - Capital ship +"Royal" // 5 - Exotic +"Sail" // 6 - Solar +``` + +Missing from this table: "Scientific" (Explorer) and "Hauler" - may use different names internally. + +**SpaceMapCockpitScaleAdjust parameters exist per type:** +- `SpaceMapCockpitScaleAdjustFighter` +- `SpaceMapCockpitScaleAdjustScientific` +- `SpaceMapCockpitScaleAdjustShuttle` +- `SpaceMapCockpitScaleAdjustRoyal` +- `SpaceMapCockpitScaleAdjustSail` +- `SpaceMapCockpitScaleAdjustDropShip` + +These confirm per-ship-type cockpit geometry differences. Finding the active ship type +at runtime would require tracing the player state → ship ownership reference chain, +which is complex. For basic mod functionality, dynamic FoV reading is sufficient. + +### Rendering Hook +NMS uses **Vulkan** as primary renderer (vulkan-1.dll loaded). +Hook target: `vkQueuePresentKHR` or equivalent. +Alternative: DXGI hook if D3D12 mode is selected. + +## Tools Built + +### mem-scanner (Rust binary) +Location: `tools/mem-scanner/` +```bash +mem-scanner.exe [--min-addr ] [--max-addr ] [--max-results ] +``` +- Scans ALL committed readable memory regions (heap + stack + mapped) +- Fast pattern matching with wildcard support (??) +- JSON output for MCP integration +- Scanned 2.87 GB in seconds to find the singleton + +### Memory Explorer MCP Server (Python) +Location: `tools/mcp/mcp_memory_explorer/` +- STDIO MCP server for interactive memory exploration +- Uses pymem for process access +- Calls mem-scanner for heap scanning +- Tools: attach, read, dump, scan, watch, resolve pointers + +## TODO + +### Completed +- [x] Verify view matrix changes when camera moves (CONFIRMED: live every frame) +- [x] Identify all 25 camera behaviour types (DONE: full RTTI class map) +- [x] Find near/far clip plane values (NOT stored in camera; overlay computes its own) + +### Remaining +- [ ] Find cockpit 3D model transform (for precise screen placement) +- [ ] Create stable code signature for cross-version pattern scanning +- [ ] Test offsets across game restarts (verify consistency) +- [ ] Hook Vulkan vkQueuePresentKHR for rendering overlay +- [ ] Find ship type at runtime (trace player state → ship → eShipClass enum) +- [ ] Map mode enum values to camera behaviour indices (0x10→16?, 0x40→?) +- [ ] Find render resolution / swapchain extent for overlay sizing diff --git a/projects/nms-cockpit-video/injector/Cargo.toml b/projects/nms-cockpit-video/injector/Cargo.toml new file mode 100644 index 0000000..738d1d0 --- /dev/null +++ b/projects/nms-cockpit-video/injector/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "nms-cockpit-injector" +description = "Vulkan texture injection DLL for NMS Cockpit Video Player" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[lib] +crate-type = ["cdylib"] + +[dependencies] +ash = { workspace = true } +retour = { workspace = true } +windows = { workspace = true } +itk-shmem = { path = "../../../core/itk-shmem" } +itk-ipc = { path = "../../../core/itk-ipc" } +itk-protocol = { path = "../../../core/itk-protocol" } +serde = { workspace = true } +tracing = { workspace = true } +once_cell = "1.19" + +[build-dependencies] +naga = { version = "24", features = ["wgsl-in", "spv-out"] } + +[lints] +workspace = true diff --git a/projects/nms-cockpit-video/injector/build.rs b/projects/nms-cockpit-video/injector/build.rs new file mode 100644 index 0000000..68ad402 --- /dev/null +++ b/projects/nms-cockpit-video/injector/build.rs @@ -0,0 +1,66 @@ +//! Build script: compile WGSL shaders to SPIR-V using naga. + +use naga::back::spv; +use naga::valid::{Capabilities, ValidationFlags, Validator}; +use std::fs; +use std::path::Path; + +fn main() { + println!("cargo:rerun-if-changed=shaders/"); + + let out_dir = std::env::var("OUT_DIR").unwrap(); + + compile_wgsl( + "shaders/quad.vert.wgsl", + &format!("{}/quad.vert.spv", out_dir), + naga::ShaderStage::Vertex, + ); + + compile_wgsl( + "shaders/quad.frag.wgsl", + &format!("{}/quad.frag.spv", out_dir), + naga::ShaderStage::Fragment, + ); +} + +fn compile_wgsl(input: &str, output: &str, stage: naga::ShaderStage) { + let source = + fs::read_to_string(input).unwrap_or_else(|e| panic!("Failed to read {}: {}", input, e)); + + // Parse WGSL + let module = naga::front::wgsl::parse_str(&source) + .unwrap_or_else(|e| panic!("Failed to parse {}: {}", input, e)); + + // Validate + let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); + let info = validator + .validate(&module) + .unwrap_or_else(|e| panic!("Validation failed for {}: {:?}", input, e)); + + // Generate SPIR-V + let options = spv::Options { + lang_version: (1, 0), + flags: spv::WriterFlags::empty(), + ..Default::default() + }; + + let pipeline_options = spv::PipelineOptions { + shader_stage: stage, + entry_point: "main".to_string(), + }; + + let spv = spv::write_vec(&module, &info, &options, Some(&pipeline_options)) + .unwrap_or_else(|e| panic!("SPIR-V generation failed for {}: {}", input, e)); + + // Write as raw bytes (u32 words -> u8 bytes) + let bytes: Vec = spv.iter().flat_map(|word| word.to_le_bytes()).collect(); + let out_path = Path::new(output); + fs::write(out_path, &bytes).unwrap_or_else(|e| panic!("Failed to write {}: {}", output, e)); + + println!( + "cargo:warning=Compiled {} -> {} ({} bytes)", + input, + output, + bytes.len() + ); +} diff --git a/projects/nms-cockpit-video/injector/shaders/quad.frag.wgsl b/projects/nms-cockpit-video/injector/shaders/quad.frag.wgsl new file mode 100644 index 0000000..549784f --- /dev/null +++ b/projects/nms-cockpit-video/injector/shaders/quad.frag.wgsl @@ -0,0 +1,7 @@ +@group(0) @binding(0) var t_video: texture_2d; +@group(0) @binding(1) var s_video: sampler; + +@fragment +fn main(@location(0) uv: vec2) -> @location(0) vec4 { + return textureSample(t_video, s_video, uv); +} diff --git a/projects/nms-cockpit-video/injector/shaders/quad.vert.wgsl b/projects/nms-cockpit-video/injector/shaders/quad.vert.wgsl new file mode 100644 index 0000000..42694c0 --- /dev/null +++ b/projects/nms-cockpit-video/injector/shaders/quad.vert.wgsl @@ -0,0 +1,23 @@ +struct PushConstants { + mvp: mat4x4, +} + +var pc: PushConstants; + +struct VertexInput { + @location(0) pos: vec3, + @location(1) uv: vec2, +} + +struct VertexOutput { + @builtin(position) clip_pos: vec4, + @location(0) frag_uv: vec2, +} + +@vertex +fn main(input: VertexInput) -> VertexOutput { + var out: VertexOutput; + out.clip_pos = pc.mvp * vec4(input.pos, 1.0); + out.frag_uv = input.uv; + return out; +} diff --git a/projects/nms-cockpit-video/injector/src/camera/mod.rs b/projects/nms-cockpit-video/injector/src/camera/mod.rs new file mode 100644 index 0000000..9b744bc --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/camera/mod.rs @@ -0,0 +1,187 @@ +//! Camera state reader for NMS's cGcCameraManager. +//! +//! Reads the camera singleton from NMS process memory to get: +//! - Camera mode (cockpit, on-foot, photo, etc.) +//! - Field of view (degrees) +//! - Aspect ratio +//! +//! Memory layout (NMS 5.x): +//! ```text +//! NMS.exe + 0x56666B0 → pointer to cGcCameraManager singleton +//! +0x118 Camera mode (u32, cockpit = 0x10) +//! +0x130 View matrix (4x4 f32, row-major) +//! +0x1D0 FoV (f32, degrees) +//! +0x1D4 Aspect ratio (f32) +//! ``` +//! +//! The singleton address is located via pattern scanning at runtime to +//! support multiple game versions. Falls back to a known RVA if the +//! scan fails. + +pub mod pattern_scan; +pub mod projection; + +use crate::log::vlog; +use std::sync::atomic::{AtomicU64, Ordering}; +use windows::Win32::System::LibraryLoader::GetModuleHandleA; + +/// Fallback RVA offset to the cGcCameraManager singleton pointer (NMS 5.x). +/// Used only when pattern scanning fails to locate the pointer dynamically. +const CAMERA_MANAGER_RVA_FALLBACK: usize = 0x56666B0; + +/// Offset to camera mode field. +const OFFSET_CAMERA_MODE: usize = 0x118; + +/// Offset to view matrix (4x4 f32, row-major). +const OFFSET_VIEW_MATRIX: usize = 0x130; + +/// Offset to FoV in degrees. +const OFFSET_FOV: usize = 0x1D0; + +/// Offset to aspect ratio. +const OFFSET_ASPECT: usize = 0x1D4; + +/// Camera mode value for cockpit view. +pub const CAMERA_MODE_COCKPIT: u32 = 0x10; + +/// Camera state read from NMS memory. +#[derive(Clone, Debug)] +pub struct CameraState { + /// Camera mode (0x10 = cockpit). + pub mode: u32, + /// Vertical field of view in degrees. + pub fov_deg: f32, + /// Aspect ratio (width / height). + pub aspect: f32, + /// View matrix (4x4, row-major as stored by NMS). + pub view_matrix: [f32; 16], +} + +/// Reader for NMS camera state from process memory. +pub struct CameraReader { + /// Address of the singleton pointer (NMS.exe base + RVA). + singleton_ptr_addr: usize, + /// Frame counter for periodic logging. + log_counter: AtomicU64, +} + +impl CameraReader { + /// Create a new camera reader. + /// + /// Uses pattern scanning to locate cGcCameraManager across game versions, + /// falling back to the known RVA if pattern scanning fails. + /// + /// # Safety + /// NMS.exe must be loaded in the current process. + pub unsafe fn new() -> Result { + let base = get_nms_base()?; + + // Try pattern scanning first for cross-version robustness + let singleton_ptr_addr = match pattern_scan::find_camera_manager_ptr() { + Some(addr) => { + vlog!("CameraReader: pattern scan found singleton at 0x{:X}", addr); + addr + }, + None => { + let fallback = base + CAMERA_MANAGER_RVA_FALLBACK; + vlog!( + "CameraReader: using fallback RVA 0x{:X} (pattern scan failed)", + fallback + ); + fallback + }, + }; + + vlog!( + "CameraReader: NMS.exe base=0x{:X} singleton_ptr=0x{:X}", + base, + singleton_ptr_addr + ); + + // Check if singleton is already valid (may be null during loading) + let ptr = *(singleton_ptr_addr as *const usize); + if ptr != 0 { + vlog!("CameraReader: singleton at 0x{:X}", ptr); + } else { + vlog!("CameraReader: singleton not yet initialized (will poll in read())"); + } + + Ok(Self { + singleton_ptr_addr, + log_counter: AtomicU64::new(0), + }) + } + + /// Read the current camera state from NMS memory. + /// + /// Returns `None` if the singleton pointer is null (e.g., during loading screens). + /// + /// # Safety + /// The singleton pointer must point to valid cGcCameraManager memory. + pub unsafe fn read(&self) -> Option { + let singleton = *(self.singleton_ptr_addr as *const usize); + if singleton == 0 { + return None; + } + + let mode = *((singleton + OFFSET_CAMERA_MODE) as *const u32); + let fov_deg = *((singleton + OFFSET_FOV) as *const f32); + let aspect = *((singleton + OFFSET_ASPECT) as *const f32); + + // Read view matrix (16 contiguous f32s) + let matrix_ptr = (singleton + OFFSET_VIEW_MATRIX) as *const [f32; 16]; + let view_matrix = *matrix_ptr; + + // Sanity check: FoV should be reasonable (1-179 degrees) + if !(1.0..=179.0).contains(&fov_deg) { + let count = self.log_counter.fetch_add(1, Ordering::Relaxed); + if count.is_multiple_of(300) { + vlog!("CameraReader: suspicious FoV={}, skipping", fov_deg); + } + return None; + } + + // Sanity check: aspect should be reasonable (0.1 - 10.0) + if !(0.1..=10.0).contains(&aspect) { + let count = self.log_counter.fetch_add(1, Ordering::Relaxed); + if count.is_multiple_of(300) { + vlog!("CameraReader: suspicious aspect={}, skipping", aspect); + } + return None; + } + + let count = self.log_counter.fetch_add(1, Ordering::Relaxed); + if count.is_multiple_of(600) { + vlog!( + "CameraReader: mode=0x{:X} fov={:.1} aspect={:.3}", + mode, + fov_deg, + aspect + ); + } + + Some(CameraState { + mode, + fov_deg, + aspect, + view_matrix, + }) + } + + /// Check if the singleton is currently valid. + /// + /// # Safety + /// Must be called from a thread that can safely read NMS memory. + pub unsafe fn is_valid(&self) -> bool { + let singleton = *(self.singleton_ptr_addr as *const usize); + singleton != 0 + } +} + +/// Get the base address of NMS.exe in the current process. +unsafe fn get_nms_base() -> Result { + // GetModuleHandleA(NULL) returns the base of the main executable + let handle = GetModuleHandleA(windows::core::PCSTR::null()) + .map_err(|e| format!("GetModuleHandle(NULL) failed: {}", e))?; + Ok(handle.0 as usize) +} diff --git a/projects/nms-cockpit-video/injector/src/camera/pattern_scan.rs b/projects/nms-cockpit-video/injector/src/camera/pattern_scan.rs new file mode 100644 index 0000000..0758e24 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/camera/pattern_scan.rs @@ -0,0 +1,166 @@ +//! Pattern scanner for locating game structures by byte signatures. +//! +//! Searches the NMS.exe `.text` section for instruction patterns that +//! reference the cGcCameraManager singleton, making the injector robust +//! across game updates where the RVA may change. + +use crate::log::vlog; +use windows::Win32::System::LibraryLoader::GetModuleHandleA; + +/// A byte pattern entry: either a specific byte or a wildcard. +#[derive(Clone, Copy)] +enum PatternByte { + Exact(u8), + Wildcard, +} + +/// Parse an IDA-style pattern string (e.g., "48 8B 05 ?? ?? ?? ?? 48 85 C0"). +fn parse_pattern(pattern: &str) -> Vec { + pattern + .split_whitespace() + .map(|b| { + if b == "??" || b == "?" { + PatternByte::Wildcard + } else { + PatternByte::Exact(u8::from_str_radix(b, 16).unwrap_or(0)) + } + }) + .collect() +} + +/// Scan a memory region for a byte pattern. Returns the offset from `base` where the +/// pattern starts, or `None` if not found. +unsafe fn scan_region(base: *const u8, size: usize, pattern: &[PatternByte]) -> Option { + if pattern.is_empty() || size < pattern.len() { + return None; + } + + let end = size - pattern.len(); + for i in 0..=end { + let mut matched = true; + for (j, pb) in pattern.iter().enumerate() { + match pb { + PatternByte::Exact(expected) => { + if *base.add(i + j) != *expected { + matched = false; + break; + } + }, + PatternByte::Wildcard => {}, + } + } + if matched { + return Some(i); + } + } + None +} + +/// Get the NMS.exe module base and size from PE headers. +unsafe fn get_module_text_section() -> Option<(*const u8, usize)> { + let handle = GetModuleHandleA(windows::core::PCSTR::null()).ok()?; + let base = handle.0 as *const u8; + + // Parse PE headers to find .text section + let dos_header = base as *const u16; + if *dos_header != 0x5A4D { + // Not a valid MZ header + return None; + } + + let e_lfanew = *(base.add(0x3C) as *const u32) as usize; + let pe_header = base.add(e_lfanew); + + // PE signature check + if *(pe_header as *const u32) != 0x4550 { + return None; + } + + // Optional header starts at PE + 24 + let optional_header = pe_header.add(24); + let size_of_code = *(optional_header.add(4) as *const u32) as usize; + let base_of_code = *(optional_header.add(20) as *const u32) as usize; + + let text_start = base.add(base_of_code); + Some((text_start, size_of_code)) +} + +/// Known patterns for cGcCameraManager singleton access. +/// +/// These patterns match the x64 instruction sequence that loads the global pointer: +/// mov rax/rcx, [rip + disp32] ; Load cGcCameraManager* +/// test rax/rcx, rax/rcx ; Null check +/// jz/je ... ; Skip if null (loading screen) +/// +/// The RIP-relative displacement at offset 3 resolves to the global pointer address. +const CAMERA_PATTERNS: &[(&str, usize)] = &[ + // mov rax, [rip+disp32]; test rax, rax; jz + ("48 8B 05 ?? ?? ?? ?? 48 85 C0 0F 84", 3), + // mov rcx, [rip+disp32]; test rcx, rcx; jz + ("48 8B 0D ?? ?? ?? ?? 48 85 C9 0F 84", 3), + // mov rax, [rip+disp32]; test rax, rax; je (short) + ("48 8B 05 ?? ?? ?? ?? 48 85 C0 74", 3), + // mov rcx, [rip+disp32]; test rcx, rcx; je (short) + ("48 8B 0D ?? ?? ?? ?? 48 85 C9 74", 3), +]; + +/// Scan NMS.exe for the cGcCameraManager singleton pointer address. +/// +/// Returns the absolute address of the global pointer, or None if not found. +/// The returned address, when dereferenced, gives the cGcCameraManager instance. +pub unsafe fn find_camera_manager_ptr() -> Option { + let (text_base, text_size) = get_module_text_section()?; + let module_base = GetModuleHandleA(windows::core::PCSTR::null()).ok()?.0 as usize; + + vlog!( + "Pattern scan: text section at 0x{:X}, size=0x{:X}", + text_base as usize, + text_size + ); + + for (pattern_str, disp_offset) in CAMERA_PATTERNS { + let pattern = parse_pattern(pattern_str); + + if let Some(match_offset) = scan_region(text_base, text_size, &pattern) { + // Calculate the RIP-relative target address + let instr_addr = text_base.add(match_offset) as usize; + let disp_addr = instr_addr + disp_offset; + let disp = *(disp_addr as *const i32); + + // RIP-relative: target = next_instruction + displacement + // The instruction is 7 bytes (48 8B 05/0D + 4-byte disp) + let next_instr = instr_addr + 7; + let target = (next_instr as isize + disp as isize) as usize; + + // Validate: target should be within the module's data sections + let rva = target.wrapping_sub(module_base); + if rva > 0x1000 && rva < 0x10000000 { + // Try to dereference and check for a valid vtable pointer + let singleton_ptr = *(target as *const usize); + if singleton_ptr != 0 { + let vtable = *(singleton_ptr as *const usize); + // Vtable should point within the module + let vtable_rva = vtable.wrapping_sub(module_base); + if vtable_rva > 0x1000 && vtable_rva < 0x10000000 { + vlog!( + "Pattern scan: found cGcCameraManager at RVA 0x{:X} (pattern match at 0x{:X})", + rva, + instr_addr - module_base + ); + return Some(target); + } + } else { + // Singleton is null (loading screen) - still valid, just not initialized yet + vlog!( + "Pattern scan: found likely cGcCameraManager at RVA 0x{:X} (singleton null, loading screen?)", + rva + ); + return Some(target); + } + } + } + } + + vlog!("Pattern scan: no cGcCameraManager pattern matched, using fallback RVA"); + None +} diff --git a/projects/nms-cockpit-video/injector/src/camera/projection.rs b/projects/nms-cockpit-video/injector/src/camera/projection.rs new file mode 100644 index 0000000..d55763f --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/camera/projection.rs @@ -0,0 +1,183 @@ +//! Projection and MVP computation for cockpit quad placement. +//! +//! The quad is placed at a fixed position in camera-local (view) space, +//! so it stays fixed relative to the cockpit regardless of where the +//! player is looking. The MVP is: Projection * Model. +//! +//! Cockpit quad corners in view space (right-handed, Y-up, -Z forward): +//! ```text +//! TL(-0.25, 0.15, -0.5) TR(0.25, 0.15, -0.5) +//! BL(-0.25, -0.05, -0.5) BR(0.25, -0.05, -0.5) +//! ``` + +/// Near clip plane for perspective projection. +const NEAR: f32 = 0.01; +/// Far clip plane for perspective projection. +const FAR: f32 = 100.0; + +/// Cockpit quad placement parameters (view space). +/// These define where the video quad appears in the cockpit. +const QUAD_HALF_WIDTH: f32 = 0.25; +const QUAD_CENTER_Y: f32 = 0.05; +const QUAD_HALF_HEIGHT: f32 = 0.10; +const QUAD_DEPTH: f32 = -0.5; // 0.5m in front of camera (-Z is forward) + +/// Compute the MVP matrix for the cockpit video quad. +/// +/// The model matrix maps the unit quad ([-1,1] x [-1,1] at Z=0) to the +/// cockpit screen position in view space. The projection matrix then +/// projects to Vulkan clip space. +/// +/// Returns a column-major 4x4 matrix suitable for push constants. +pub fn compute_cockpit_mvp(fov_deg: f32, aspect: f32) -> [f32; 16] { + let model = cockpit_model_matrix(); + let projection = perspective_vulkan(fov_deg, aspect, NEAR, FAR); + mat4_multiply(&projection, &model) +} + +/// Build the model matrix that positions the unit quad at cockpit coordinates. +/// +/// Maps: +/// - X: [-1,1] → [-QUAD_HALF_WIDTH, QUAD_HALF_WIDTH] +/// - Y: [-1,1] → [QUAD_CENTER_Y + QUAD_HALF_HEIGHT, QUAD_CENTER_Y - QUAD_HALF_HEIGHT] +/// (note: Y=-1 in model space is top in NDC after Vulkan Y-flip, +/// so Y=-1 maps to the higher Y in view space) +/// - Z: 0 → QUAD_DEPTH +fn cockpit_model_matrix() -> [f32; 16] { + let sx = QUAD_HALF_WIDTH; // 0.25 + let sy = -QUAD_HALF_HEIGHT; // -0.10 (negative because model Y=-1 = top = higher view Y) + let ty = QUAD_CENTER_Y; // 0.05 + let tz = QUAD_DEPTH; // -0.5 + + // Column-major 4x4: scale + translate + [ + sx, 0.0, 0.0, 0.0, // column 0 + 0.0, sy, 0.0, 0.0, // column 1 + 0.0, 0.0, 1.0, 0.0, // column 2 + 0.0, ty, tz, 1.0, // column 3 + ] +} + +/// Build a Vulkan perspective projection matrix. +/// +/// Right-handed, Y-down in clip space, depth range [0, 1]. +/// +/// Parameters: +/// - `fov_deg`: Vertical field of view in degrees +/// - `aspect`: Width / height +/// - `near`: Near clip plane distance (positive) +/// - `far`: Far clip plane distance (positive) +/// +/// Returns column-major 4x4 matrix. +fn perspective_vulkan(fov_deg: f32, aspect: f32, near: f32, far: f32) -> [f32; 16] { + let fov_rad = fov_deg * std::f32::consts::PI / 180.0; + let f = 1.0 / (fov_rad / 2.0).tan(); + + let range_inv = 1.0 / (near - far); + + // Column-major layout + // P[0][0] = f / aspect + // P[1][1] = -f (Vulkan Y-flip: positive Y in view → negative Y in clip) + // P[2][2] = far / (near - far) = far * range_inv + // P[2][3] = -1.0 (perspective divide: w = -z) + // P[3][2] = (near * far) / (near - far) = near * far * range_inv + [ + f / aspect, + 0.0, + 0.0, + 0.0, // column 0 + 0.0, + -f, + 0.0, + 0.0, // column 1 + 0.0, + 0.0, + far * range_inv, + -1.0, // column 2 + 0.0, + 0.0, + near * far * range_inv, + 0.0, // column 3 + ] +} + +/// Multiply two 4x4 matrices (column-major): result = a * b. +fn mat4_multiply(a: &[f32; 16], b: &[f32; 16]) -> [f32; 16] { + let mut result = [0.0f32; 16]; + + for col in 0..4 { + for row in 0..4 { + let mut sum = 0.0f32; + for k in 0..4 { + // a[row, k] * b[k, col] + // Column-major: element (row, col) is at index col*4 + row + sum += a[k * 4 + row] * b[col * 4 + k]; + } + result[col * 4 + row] = sum; + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_identity_multiply() { + let identity: [f32; 16] = [ + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + ]; + let result = mat4_multiply(&identity, &identity); + for i in 0..16 { + assert!( + (result[i] - identity[i]).abs() < 1e-6, + "Mismatch at index {}: {} vs {}", + i, + result[i], + identity[i] + ); + } + } + + #[test] + fn test_perspective_basic() { + let proj = perspective_vulkan(90.0, 16.0 / 9.0, 0.01, 100.0); + // At 90 degree FoV, f = 1/tan(45) = 1.0 + let f = 1.0f32; + let aspect = 16.0 / 9.0; + assert!((proj[0] - f / aspect).abs() < 1e-5); // P[0][0] + assert!((proj[5] - (-f)).abs() < 1e-5); // P[1][1] (Y-flip) + assert!((proj[11] - (-1.0)).abs() < 1e-5); // P[2][3] (perspective divide) + } + + #[test] + fn test_cockpit_mvp_produces_valid_result() { + let mvp = compute_cockpit_mvp(75.0, 16.0 / 9.0); + // The result should be finite (no NaN/Inf) + for (i, &val) in mvp.iter().enumerate() { + assert!(val.is_finite(), "MVP[{}] is not finite: {}", i, val); + } + } + + #[test] + fn test_quad_center_projects_to_screen_center_area() { + let mvp = compute_cockpit_mvp(75.0, 16.0 / 9.0); + + // Transform the quad center (0, 0, 0) through MVP + // clip = MVP * (0, 0, 0, 1) + // clip.x = mvp[12], clip.y = mvp[13], clip.z = mvp[14], clip.w = mvp[15] + let clip_x = mvp[12]; + let clip_y = mvp[13]; + let clip_w = mvp[15]; + + // After perspective divide: ndc = clip / clip.w + let ndc_x = clip_x / clip_w; + let ndc_y = clip_y / clip_w; + + // The quad center should be near screen center (within [-0.5, 0.5]) + assert!(ndc_x.abs() < 0.5, "NDC X {} too far from center", ndc_x); + assert!(ndc_y.abs() < 0.5, "NDC Y {} too far from center", ndc_y); + } +} diff --git a/projects/nms-cockpit-video/injector/src/hooks/mod.rs b/projects/nms-cockpit-video/injector/src/hooks/mod.rs new file mode 100644 index 0000000..22a5d80 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/hooks/mod.rs @@ -0,0 +1,41 @@ +//! Hook installation and management. +//! +//! Installs detours on Vulkan functions and optionally hooks +//! OpenVR's IVRCompositor::Submit for VR rendering. + +pub mod openvr; +pub mod vulkan; + +use crate::log::vlog; + +/// Install all hooks. +/// +/// # Safety +/// Must be called from a thread where the target modules are loaded. +pub unsafe fn install() -> Result<(), String> { + vlog!("Installing hooks..."); + + // Vulkan hooks (required) + vulkan::install()?; + + // OpenVR hooks (optional - only if VR is active) + match openvr::try_install() { + Ok(true) => vlog!("VR hooks installed"), + Ok(false) => vlog!("VR not active, skipping VR hooks"), + Err(e) => vlog!("VR hook failed (non-fatal): {}", e), + } + + vlog!("All hooks installed"); + Ok(()) +} + +/// Remove all hooks. +/// +/// # Safety +/// Must be called during DLL detach. +pub unsafe fn remove() { + vlog!("Removing hooks..."); + openvr::remove(); + vulkan::remove(); + vlog!("All hooks removed"); +} diff --git a/projects/nms-cockpit-video/injector/src/hooks/openvr.rs b/projects/nms-cockpit-video/injector/src/hooks/openvr.rs new file mode 100644 index 0000000..c7eced0 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/hooks/openvr.rs @@ -0,0 +1,371 @@ +//! OpenVR compositor hook for VR rendering. +//! +//! Hooks IVRCompositor::Submit to render the video quad on each eye's +//! texture before it's submitted to the headset. +//! +//! Vtable layout (IVRCompositor_028, Windows x64 MSVC): +//! ```text +//! [0] SetTrackingSpace +//! [1] GetTrackingSpace +//! [2] WaitGetPoses +//! [3] GetLastPoses +//! [4] GetLastPoseForTrackedDeviceIndex +//! [5] Submit ← hooked +//! ``` + +use crate::log::vlog; +use ash::vk::Handle; +use std::ffi::{CString, c_void}; +use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; +use windows::Win32::System::LibraryLoader::{GetModuleHandleA, GetProcAddress}; +use windows::Win32::System::Memory::{PAGE_PROTECTION_FLAGS, PAGE_READWRITE, VirtualProtect}; + +/// Whether VR hooks are active. +static VR_ACTIVE: AtomicBool = AtomicBool::new(false); + +/// Original Submit function pointer (saved before hook). +static ORIGINAL_SUBMIT: AtomicPtr = AtomicPtr::new(std::ptr::null_mut()); + +/// The vtable slot address (for unhooking). +static VTABLE_SLOT: AtomicPtr<*const c_void> = AtomicPtr::new(std::ptr::null_mut()); + +// --- OpenVR type definitions --- + +/// OpenVR eye enum. +#[repr(i32)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[allow(dead_code)] +pub enum EVREye { + Left = 0, + Right = 1, +} + +/// OpenVR texture type enum. +#[repr(i32)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[allow(dead_code)] +enum ETextureType { + DirectX = 0, + OpenGL = 1, + Vulkan = 3, +} + +/// OpenVR compositor error enum. +#[repr(i32)] +#[derive(Clone, Copy, Debug)] +#[allow(dead_code)] +pub enum EVRCompositorError { + None = 0, + RequestFailed = 1, + IncompatibleVersion = 100, + DoNotHaveFocus = 101, + InvalidTexture = 102, + IsNotSceneApplication = 103, + TextureIsOnWrongDevice = 104, + TextureUsesUnsupportedFormat = 105, + SharedTexturesNotSupported = 106, + IndexOutOfRange = 107, + AlreadySubmitted = 108, + InvalidBounds = 109, + AlreadySet = 110, +} + +/// OpenVR Texture_t struct. +#[repr(C)] +#[derive(Clone, Copy)] +struct Texture_t { + handle: *const c_void, + e_type: i32, + e_color_space: i32, +} + +/// OpenVR VRTextureBounds_t struct. +#[repr(C)] +#[derive(Clone, Copy)] +#[allow(dead_code)] +struct VRTextureBounds_t { + u_min: f32, + v_min: f32, + u_max: f32, + v_max: f32, +} + +/// OpenVR submit flags. +#[repr(i32)] +#[derive(Clone, Copy)] +#[allow(dead_code)] +enum EVRSubmitFlags { + Default = 0, +} + +/// Vulkan texture data passed through Texture_t.handle for Vulkan textures. +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct VRVulkanTextureData_t { + pub image: u64, // VkImage (uint64_t handle) + pub device: *const c_void, // VkDevice + pub physical_device: *const c_void, // VkPhysicalDevice + pub instance: *const c_void, // VkInstance + pub queue: *const c_void, // VkQueue + pub queue_family_index: u32, + pub width: u32, + pub height: u32, + pub format: u32, // VkFormat + pub sample_count: u32, +} + +/// IVRCompositor::Submit function signature. +type FnSubmit = unsafe extern "system" fn( + this: *const c_void, + eye: EVREye, + texture: *const Texture_t, + bounds: *const VRTextureBounds_t, + flags: i32, +) -> EVRCompositorError; + +/// VR_GetGenericInterface function signature. +type FnGetGenericInterface = + unsafe extern "system" fn(interface_version: *const u8, error: *mut i32) -> *const c_void; + +/// Submit vtable index in IVRCompositor_028. +const SUBMIT_VTABLE_INDEX: usize = 5; + +/// Frame counter for VR. +static VR_FRAME_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); + +// --- Public API --- + +/// Try to install OpenVR hooks. +/// +/// Returns Ok(true) if VR is active and hooks were installed, +/// Ok(false) if VR is not active (openvr_api.dll not loaded), +/// Err if VR is active but hooking failed. +/// +/// # Safety +/// Must be called after Vulkan hooks are installed. +pub unsafe fn try_install() -> Result { + // Check if openvr_api.dll is loaded + let openvr_module = { + let name = CString::new("openvr_api.dll").unwrap(); + GetModuleHandleA(windows::core::PCSTR(name.as_ptr() as *const _)) + }; + + let openvr_module = match openvr_module { + Ok(h) => h, + Err(_) => { + vlog!("openvr_api.dll not loaded - VR not active"); + return Ok(false); + }, + }; + + vlog!("openvr_api.dll found - attempting VR hook"); + + // Get VR_GetGenericInterface + let proc_name = CString::new("VR_GetGenericInterface").unwrap(); + let get_interface_addr = GetProcAddress( + openvr_module, + windows::core::PCSTR(proc_name.as_ptr() as *const _), + ); + + let get_interface: FnGetGenericInterface = match get_interface_addr { + Some(f) => std::mem::transmute::< + unsafe extern "system" fn() -> isize, + unsafe extern "system" fn(*const u8, *mut i32) -> *const c_void, + >(f), + None => return Err("VR_GetGenericInterface not found in openvr_api.dll".into()), + }; + + // Get IVRCompositor + let version = b"IVRCompositor_028\0"; + let mut error: i32 = 0; + let compositor = get_interface(version.as_ptr(), &mut error); + + if compositor.is_null() || error != 0 { + return Err(format!("Failed to get IVRCompositor_028: error={}", error)); + } + + vlog!("IVRCompositor_028 at {:p}", compositor); + + // Read vtable (first pointer in the object) + let vtable = *(compositor as *const *const *const c_void); + let submit_slot = vtable.add(SUBMIT_VTABLE_INDEX) as *mut *const c_void; + + vlog!( + "Submit vtable slot at {:p}, current value {:p}", + submit_slot, + *submit_slot + ); + + // Save original function pointer + let original = *submit_slot; + ORIGINAL_SUBMIT.store(original as *mut c_void, Ordering::Release); + VTABLE_SLOT.store(submit_slot, Ordering::Release); + + // Make vtable writable, swap pointer, restore protection + let mut old_protect = PAGE_PROTECTION_FLAGS(0); + let slot_size = std::mem::size_of::<*const c_void>(); + + let protect_result = VirtualProtect( + submit_slot as *const c_void, + slot_size, + PAGE_READWRITE, + &mut old_protect, + ); + + if protect_result.is_err() { + return Err("VirtualProtect failed on vtable".into()); + } + + // Write our hook function + *submit_slot = hooked_submit as *const c_void; + + // Restore protection + let mut dummy = PAGE_PROTECTION_FLAGS(0); + let _ = VirtualProtect( + submit_slot as *const c_void, + slot_size, + old_protect, + &mut dummy, + ); + + VR_ACTIVE.store(true, Ordering::Release); + vlog!("IVRCompositor::Submit hooked successfully"); + + Ok(true) +} + +/// Remove VR hooks. +/// +/// # Safety +/// Must be called during DLL detach. +pub unsafe fn remove() { + if !VR_ACTIVE.load(Ordering::Acquire) { + return; + } + + let slot = VTABLE_SLOT.load(Ordering::Acquire); + let original = ORIGINAL_SUBMIT.load(Ordering::Acquire); + + if !slot.is_null() && !original.is_null() { + let mut old_protect = PAGE_PROTECTION_FLAGS(0); + let slot_size = std::mem::size_of::<*const c_void>(); + + let protect_result = VirtualProtect( + slot as *const c_void, + slot_size, + PAGE_READWRITE, + &mut old_protect, + ); + + if protect_result.is_ok() { + *slot = original as *const c_void; + let mut dummy = PAGE_PROTECTION_FLAGS(0); + let _ = VirtualProtect(slot as *const c_void, slot_size, old_protect, &mut dummy); + vlog!("IVRCompositor::Submit hook removed"); + } + } + + VR_ACTIVE.store(false, Ordering::Release); +} + +/// Check if VR is currently active. +#[allow(dead_code)] +pub fn is_active() -> bool { + VR_ACTIVE.load(Ordering::Relaxed) +} + +// --- Hook implementation --- + +/// Hooked IVRCompositor::Submit. +/// +/// Renders the video quad to the eye texture before calling the original Submit. +unsafe extern "system" fn hooked_submit( + this: *const c_void, + eye: EVREye, + texture: *const Texture_t, + bounds: *const VRTextureBounds_t, + flags: i32, +) -> EVRCompositorError { + let frame = VR_FRAME_COUNT.fetch_add(1, Ordering::Relaxed); + + // Try to render our overlay on the VR texture + if !texture.is_null() { + let tex = &*texture; + + // Only process Vulkan textures + if tex.e_type == ETextureType::Vulkan as i32 && !tex.handle.is_null() { + let vk_data = &*(tex.handle as *const VRVulkanTextureData_t); + + if frame.is_multiple_of(600) { + vlog!( + "VR Submit: eye={:?} {}x{} format={} image=0x{:X}", + eye, + vk_data.width, + vk_data.height, + vk_data.format, + vk_data.image + ); + } + + render_to_vr_eye(eye, vk_data, frame); + } + } + + // Call original Submit + let original: FnSubmit = std::mem::transmute(ORIGINAL_SUBMIT.load(Ordering::Acquire)); + original(this, eye, texture, bounds, flags) +} + +/// Render the video quad to a VR eye texture. +/// +/// Uses zero-copy frame passing: the SHMEM_READER lock is held across the render +/// call so `poll_frame()` can return a direct `&[u8]` into shared memory. +unsafe fn render_to_vr_eye(eye: EVREye, vk_data: &VRVulkanTextureData_t, frame: u64) { + use super::vulkan::{SHMEM_FAILED, SHMEM_READER, get_mvp, get_renderer}; + use crate::shmem_reader::ShmemFrameReader; + use ash::vk; + use std::sync::Mutex; + use std::sync::atomic::Ordering; + + // Get the renderer (shared with desktop rendering) + let renderer_mutex = get_renderer(); + + if let Ok(mut guard) = renderer_mutex.lock() { + let Some(renderer) = guard.as_mut() else { + return; + }; + + let mvp = get_mvp(frame); + + // Zero-copy shmem access: hold lock across render (same pattern as desktop path) + let reader_result = SHMEM_READER.get_or_try_init(|| match ShmemFrameReader::open() { + Ok(reader) => { + SHMEM_FAILED.store(false, Ordering::Relaxed); + Ok(Mutex::new(reader)) + }, + Err(e) => { + if !SHMEM_FAILED.swap(true, Ordering::Relaxed) { + vlog!("VR: ShmemFrameReader open failed: {}", e); + } + Err(e) + }, + }); + let mut shmem_guard = reader_result.ok().and_then(|m| m.lock().ok()); + let new_frame = match shmem_guard.as_mut() { + Some(reader) => reader.poll_frame(), + None => None, + }; + + let vr_image = vk::Image::from_raw(vk_data.image); + let extent = vk::Extent2D { + width: vk_data.width, + height: vk_data.height, + }; + + if let Err(e) = renderer.render_to_vr_image(vr_image, extent, &mvp, new_frame) { + if frame.is_multiple_of(600) { + vlog!("VR render error (eye={:?}): {}", eye, e); + } + } + } +} diff --git a/projects/nms-cockpit-video/injector/src/hooks/vulkan.rs b/projects/nms-cockpit-video/injector/src/hooks/vulkan.rs new file mode 100644 index 0000000..e62df04 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/hooks/vulkan.rs @@ -0,0 +1,832 @@ +//! Vulkan function hooks for rendering injection. +//! +//! Hooks Vulkan at two levels: +//! - Loader trampolines (vkCreateInstance, vkCreateDevice) via retour static_detour +//! - ICD-level functions (vkQueuePresentKHR, vkCreateSwapchainKHR) via RawDetour +//! +//! Extension functions (KHR) bypass the loader when obtained via vkGetDeviceProcAddr, +//! so we hook the actual ICD implementation addresses after device creation. + +use crate::camera::projection::compute_cockpit_mvp; +use crate::camera::{CAMERA_MODE_COCKPIT, CameraReader}; +use crate::log::vlog; +use crate::renderer::VulkanRenderer; +use crate::shmem_reader::ShmemFrameReader; +use ash::vk; +use once_cell::sync::OnceCell; +use retour::{RawDetour, static_detour}; +use std::ffi::{CString, c_char, c_void}; +use std::sync::Mutex; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use windows::Win32::System::LibraryLoader::{GetModuleHandleA, GetProcAddress}; + +// --- Type aliases for ICD function pointers --- + +type PfnQueuePresent = unsafe extern "system" fn(vk::Queue, *const c_void) -> vk::Result; +type PfnCreateSwapchain = unsafe extern "system" fn( + vk::Device, + *const c_void, + *const c_void, + *mut vk::SwapchainKHR, +) -> vk::Result; + +// --- Captured state --- + +/// The VkInstance. +static INSTANCE: OnceCell = OnceCell::new(); + +/// The ash Instance loader (created from vulkan-1.dll exports, no VkInstance needed). +static ASH_INSTANCE: OnceCell = OnceCell::new(); + +/// The VkPhysicalDevice used to create the device. +static PHYSICAL_DEVICE: OnceCell = OnceCell::new(); + +/// The VkDevice created by the game. +static DEVICE: OnceCell = OnceCell::new(); + +/// The VkQueue used for presentation. +static PRESENT_QUEUE: OnceCell = OnceCell::new(); + +/// Queue family index (captured from device creation). +static QUEUE_FAMILY_INDEX: AtomicU64 = AtomicU64::new(0); + +/// Current swapchain handle. +static SWAPCHAIN: OnceCell = OnceCell::new(); + +/// Cached ash::Device dispatch table (avoids per-frame function pointer resolution). +static ASH_DEVICE: OnceCell = OnceCell::new(); + +/// Cached swapchain extension function table. +static ASH_SWAPCHAIN_FN: OnceCell = OnceCell::new(); + +/// Cached swapchain images (cleared on swapchain recreation). +static SWAPCHAIN_IMAGES: Mutex>> = Mutex::new(None); + +/// Swapchain extent (packed: width << 32 | height). +static SWAPCHAIN_EXTENT: AtomicU64 = AtomicU64::new(0); + +/// Swapchain format (raw i32). +static SWAPCHAIN_FORMAT: AtomicU64 = AtomicU64::new(0); + +/// Frame counter for diagnostics. +static FRAME_COUNT: AtomicU64 = AtomicU64::new(0); + +/// The renderer (lazy-initialized on first present, resettable on swapchain recreation). +static RENDERER: Mutex> = Mutex::new(None); + +/// Camera reader (lazy-initialized on first present). +static CAMERA: OnceCell = OnceCell::new(); + +/// Shared memory frame reader (retries periodically if daemon isn't running). +pub(super) static SHMEM_READER: OnceCell> = OnceCell::new(); + +/// Whether we've already failed to open shmem (avoids log spam). +pub(super) static SHMEM_FAILED: AtomicBool = AtomicBool::new(false); + +// --- ICD-level hook state --- + +/// Trampoline to call the original ICD vkQueuePresentKHR. +static ICD_PRESENT_TRAMPOLINE: OnceCell = OnceCell::new(); + +/// Trampoline to call the original ICD vkCreateSwapchainKHR. +static ICD_SWAPCHAIN_TRAMPOLINE: OnceCell = OnceCell::new(); + +/// Wrapper for RawDetour to make it Send+Sync (detours are set-and-forget). +struct DetourHolder(RawDetour); +unsafe impl Send for DetourHolder {} +unsafe impl Sync for DetourHolder {} + +/// Keep RawDetours alive (dropping disables them). +static ICD_PRESENT_DETOUR: OnceCell = OnceCell::new(); +static ICD_SWAPCHAIN_DETOUR: OnceCell = OnceCell::new(); + +/// Set when ICD-level present hook is active (loader hook skips rendering). +static ICD_PRESENT_ACTIVE: AtomicBool = AtomicBool::new(false); + +// --- Hook definitions --- + +static_detour! { + static Hook_vkCreateInstance: unsafe extern "system" fn( + *const c_void, // pCreateInfo (VkInstanceCreateInfo*) + *const c_void, // pAllocator + *mut vk::Instance // pInstance + ) -> vk::Result; + + static Hook_vkCreateDevice: unsafe extern "system" fn( + vk::PhysicalDevice, + *const c_void, // pCreateInfo (VkDeviceCreateInfo*) + *const c_void, // pAllocator + *mut vk::Device + ) -> vk::Result; + + static Hook_vkCreateSwapchainKHR: unsafe extern "system" fn( + vk::Device, + *const c_void, // pCreateInfo (VkSwapchainCreateInfoKHR*) + *const c_void, // pAllocator + *mut vk::SwapchainKHR + ) -> vk::Result; + + static Hook_vkQueuePresentKHR: unsafe extern "system" fn( + vk::Queue, + *const c_void // pPresentInfo (VkPresentInfoKHR*) + ) -> vk::Result; +} + +// --- Hook implementations --- + +/// Hooked vkCreateInstance: captures instance for ash loader. +fn hooked_create_instance( + create_info: *const c_void, + allocator: *const c_void, + instance: *mut vk::Instance, +) -> vk::Result { + vlog!("vkCreateInstance called"); + + let result = unsafe { Hook_vkCreateInstance.call(create_info, allocator, instance) }; + + if result == vk::Result::SUCCESS { + let inst = unsafe { *instance }; + let _ = INSTANCE.set(inst); + vlog!("VkInstance captured: {:?}", inst); + + // Create ash::Instance loader + unsafe { + let get_instance_proc_addr = get_instance_proc_addr_fn(); + if let Some(gipa) = get_instance_proc_addr { + let static_fn = ash::StaticFn { + get_instance_proc_addr: gipa, + }; + let entry = ash::Entry::from_static_fn(static_fn); + let ash_instance = ash::Instance::load(entry.static_fn(), inst); + let _ = ASH_INSTANCE.set(ash_instance); + vlog!("ash::Instance created"); + } + } + } + + result +} + +/// Hooked vkCreateDevice: captures device and physical device, sets up ICD hooks. +fn hooked_create_device( + physical_device: vk::PhysicalDevice, + create_info: *const c_void, + allocator: *const c_void, + device: *mut vk::Device, +) -> vk::Result { + vlog!("vkCreateDevice called"); + + // Extract queue family index from create info + if !create_info.is_null() { + let info = unsafe { &*(create_info as *const vk::DeviceCreateInfo<'_>) }; + if info.queue_create_info_count > 0 && !info.p_queue_create_infos.is_null() { + let queue_info = unsafe { &*info.p_queue_create_infos }; + QUEUE_FAMILY_INDEX.store(queue_info.queue_family_index as u64, Ordering::Relaxed); + vlog!("Queue family index: {}", queue_info.queue_family_index); + } + } + + let result = + unsafe { Hook_vkCreateDevice.call(physical_device, create_info, allocator, device) }; + + if result == vk::Result::SUCCESS { + let dev = unsafe { *device }; + let _ = PHYSICAL_DEVICE.set(physical_device); + let _ = DEVICE.set(dev); + vlog!("VkDevice captured: {:?}", dev); + + // If vkCreateInstance was missed (late injection), create ash::Instance + // using vulkan-1.dll exports directly + if ASH_INSTANCE.get().is_none() { + vlog!("vkCreateInstance was missed - creating ash::Instance from loader exports"); + unsafe { + create_ash_instance_from_exports(); + } + } + + // Hook vkQueuePresentKHR and vkCreateSwapchainKHR at the ICD level. + // Extension functions bypass the loader when obtained via vkGetDeviceProcAddr, + // so we must hook the actual ICD addresses. + unsafe { + setup_icd_hooks(dev); + } + + // Cache the device dispatch table and swapchain extension functions + // to avoid per-frame function pointer resolution in try_render_overlay. + if let Some(instance) = ASH_INSTANCE.get() { + let device_fns = ash::Device::load(instance.fp_v1_0(), dev); + let swapchain_fn = ash::khr::swapchain::Device::new(instance, &device_fns); + let _ = ASH_SWAPCHAIN_FN.set(swapchain_fn); + let _ = ASH_DEVICE.set(device_fns); + vlog!("ash::Device and swapchain extension cached"); + } + } + + result +} + +/// Hooked vkCreateSwapchainKHR: tracks swapchain properties. +fn hooked_create_swapchain( + device: vk::Device, + create_info: *const c_void, + allocator: *const c_void, + swapchain: *mut vk::SwapchainKHR, +) -> vk::Result { + // Destroy existing renderer and invalidate cached images on swapchain recreation + // so they reinitialize with new swapchain properties (resolution, format, images) + if let Ok(mut guard) = RENDERER.lock() { + if guard.is_some() { + vlog!("Swapchain recreated - destroying old renderer for reinit"); + *guard = None; + } + } + if let Ok(mut images) = SWAPCHAIN_IMAGES.lock() { + *images = None; + } + + let result = + unsafe { Hook_vkCreateSwapchainKHR.call(device, create_info, allocator, swapchain) }; + + if result == vk::Result::SUCCESS && !create_info.is_null() { + let info = unsafe { &*(create_info as *const vk::SwapchainCreateInfoKHR<'_>) }; + let extent = info.image_extent; + let format = info.image_format; + + SWAPCHAIN_EXTENT.store( + ((extent.width as u64) << 32) | (extent.height as u64), + Ordering::Relaxed, + ); + SWAPCHAIN_FORMAT.store(format.as_raw() as u64, Ordering::Relaxed); + + let sc = unsafe { *swapchain }; + // Store swapchain (first one only for now) + let _ = SWAPCHAIN.set(sc); + + vlog!( + "Swapchain created: {}x{} format={:?} handle={:?}", + extent.width, + extent.height, + format, + sc + ); + } + + result +} + +/// Hooked vkQueuePresentKHR: injection point for rendering. +/// +/// When the ICD-level hook is active, this loader-level hook becomes a passthrough +/// to avoid double-rendering (both hooks fire on the same present call). +fn hooked_queue_present(queue: vk::Queue, present_info_ptr: *const c_void) -> vk::Result { + // Skip rendering if ICD hook handles it (avoids double submissions) + if !ICD_PRESENT_ACTIVE.load(Ordering::Relaxed) { + let count = FRAME_COUNT.fetch_add(1, Ordering::Relaxed); + + // Store the present queue on first call + if PRESENT_QUEUE.get().is_none() { + let _ = PRESENT_QUEUE.set(queue); + vlog!("Present queue captured: {:?}", queue); + } + + // Log every 300 frames (~5 seconds at 60fps) + if count.is_multiple_of(300) { + vlog!("vkQueuePresentKHR frame={}", count); + } + + // Try to render our quad overlay + if !present_info_ptr.is_null() { + unsafe { + try_render_overlay(queue, present_info_ptr, count); + } + } + } + + unsafe { Hook_vkQueuePresentKHR.call(queue, present_info_ptr) } +} + +// --- ICD-level hook implementations --- + +/// Create ash::Instance from vulkan-1.dll exports (no VkInstance handle needed). +/// +/// Provides a custom vkGetInstanceProcAddr that resolves functions directly from +/// vulkan-1.dll's exports via GetProcAddress. The loader's exported trampolines +/// dispatch correctly using the handle's internal dispatch table. +unsafe fn create_ash_instance_from_exports() { + let static_fn = ash::StaticFn { + get_instance_proc_addr: loader_get_instance_proc_addr, + }; + let entry = ash::Entry::from_static_fn(static_fn); + let ash_instance = ash::Instance::load(entry.static_fn(), vk::Instance::null()); + let _ = ASH_INSTANCE.set(ash_instance); + vlog!("ash::Instance created from loader exports (late-init)"); +} + +/// Custom vkGetInstanceProcAddr that resolves functions from vulkan-1.dll exports. +/// +/// Ignores the VkInstance parameter and uses GetProcAddress on vulkan-1.dll directly. +/// The loader's exported trampolines use the dispatch table embedded in dispatchable +/// handles (VkDevice, VkPhysicalDevice, VkQueue) to route to the correct ICD. +unsafe extern "system" fn loader_get_instance_proc_addr( + _instance: vk::Instance, + p_name: *const c_char, +) -> vk::PFN_vkVoidFunction { + if p_name.is_null() { + return None; + } + let module = match get_vulkan_module() { + Ok(m) => m, + Err(_) => return None, + }; + GetProcAddress(module, windows::core::PCSTR(p_name as *const u8)) + .map(|f| std::mem::transmute::<_, unsafe extern "system" fn()>(f)) +} + +/// Set up ICD-level hooks for extension functions that bypass the loader. +/// +/// Gets the actual ICD addresses via vkGetDeviceProcAddr and detours them. +unsafe fn setup_icd_hooks(device: vk::Device) { + let vulkan_module = match get_vulkan_module() { + Ok(m) => m, + Err(e) => { + vlog!("Failed to get vulkan module for ICD hooks: {}", e); + return; + }, + }; + + // Get vkGetDeviceProcAddr from vulkan-1.dll + let gdpa_addr = match get_proc(vulkan_module, "vkGetDeviceProcAddr") { + Ok(addr) => addr, + Err(e) => { + vlog!("Failed to get vkGetDeviceProcAddr: {}", e); + return; + }, + }; + let gdpa: vk::PFN_vkGetDeviceProcAddr = std::mem::transmute(gdpa_addr); + + // Hook vkQueuePresentKHR at ICD level + let present_name = CString::new("vkQueuePresentKHR").unwrap(); + let icd_present = (gdpa)(device, present_name.as_ptr()); + if let Some(present_fn) = icd_present { + let target = present_fn as *const (); + vlog!("ICD vkQueuePresentKHR at {:p}", target); + + match RawDetour::new(target, icd_hooked_queue_present as *const ()) { + Ok(detour) => { + let trampoline: PfnQueuePresent = std::mem::transmute(detour.trampoline()); + match detour.enable() { + Ok(()) => { + let _ = ICD_PRESENT_TRAMPOLINE.set(trampoline); + let _ = ICD_PRESENT_DETOUR.set(DetourHolder(detour)); + ICD_PRESENT_ACTIVE.store(true, Ordering::Relaxed); + vlog!("ICD vkQueuePresentKHR hooked successfully (loader hook disabled)"); + }, + Err(e) => vlog!("Failed to enable ICD present hook: {}", e), + } + }, + Err(e) => vlog!("Failed to create ICD present detour: {}", e), + } + } else { + vlog!("vkGetDeviceProcAddr returned null for vkQueuePresentKHR"); + } + + // Hook vkCreateSwapchainKHR at ICD level + let swapchain_name = CString::new("vkCreateSwapchainKHR").unwrap(); + let icd_swapchain = (gdpa)(device, swapchain_name.as_ptr()); + if let Some(swapchain_fn) = icd_swapchain { + let target = swapchain_fn as *const (); + vlog!("ICD vkCreateSwapchainKHR at {:p}", target); + + match RawDetour::new(target, icd_hooked_create_swapchain as *const ()) { + Ok(detour) => { + let trampoline: PfnCreateSwapchain = std::mem::transmute(detour.trampoline()); + match detour.enable() { + Ok(()) => { + let _ = ICD_SWAPCHAIN_TRAMPOLINE.set(trampoline); + let _ = ICD_SWAPCHAIN_DETOUR.set(DetourHolder(detour)); + vlog!("ICD vkCreateSwapchainKHR hooked successfully"); + }, + Err(e) => vlog!("Failed to enable ICD swapchain hook: {}", e), + } + }, + Err(e) => vlog!("Failed to create ICD swapchain detour: {}", e), + } + } else { + vlog!("vkGetDeviceProcAddr returned null for vkCreateSwapchainKHR"); + } +} + +/// ICD-level hook for vkQueuePresentKHR (called when NMS presents a frame). +unsafe extern "system" fn icd_hooked_queue_present( + queue: vk::Queue, + present_info_ptr: *const c_void, +) -> vk::Result { + let count = FRAME_COUNT.fetch_add(1, Ordering::Relaxed); + + // Store the present queue on first call + if PRESENT_QUEUE.get().is_none() { + let _ = PRESENT_QUEUE.set(queue); + vlog!("Present queue captured (ICD hook): {:?}", queue); + } + + // Log every 300 frames (~5 seconds at 60fps) + if count.is_multiple_of(300) { + vlog!("vkQueuePresentKHR (ICD) frame={}", count); + } + + // Try to render our quad overlay + if !present_info_ptr.is_null() { + try_render_overlay(queue, present_info_ptr, count); + } + + // Call the original ICD function + if let Some(trampoline) = ICD_PRESENT_TRAMPOLINE.get() { + (trampoline)(queue, present_info_ptr) + } else { + vk::Result::SUCCESS + } +} + +/// ICD-level hook for vkCreateSwapchainKHR (captures format/extent). +unsafe extern "system" fn icd_hooked_create_swapchain( + device: vk::Device, + create_info: *const c_void, + allocator: *const c_void, + swapchain: *mut vk::SwapchainKHR, +) -> vk::Result { + // Destroy renderer and invalidate cached images before recreation + // (mirrors the loader-level hook behavior) + if let Ok(mut guard) = RENDERER.lock() { + if guard.is_some() { + vlog!("Swapchain recreated (ICD) - destroying renderer for reinit"); + *guard = None; + } + } + if let Ok(mut images) = SWAPCHAIN_IMAGES.lock() { + *images = None; + } + + // Call the original ICD function first + let result = if let Some(trampoline) = ICD_SWAPCHAIN_TRAMPOLINE.get() { + (trampoline)(device, create_info, allocator, swapchain) + } else { + vk::Result::ERROR_UNKNOWN + }; + + if result == vk::Result::SUCCESS && !create_info.is_null() { + let info = &*(create_info as *const vk::SwapchainCreateInfoKHR<'_>); + let extent = info.image_extent; + let format = info.image_format; + + SWAPCHAIN_EXTENT.store( + ((extent.width as u64) << 32) | (extent.height as u64), + Ordering::Relaxed, + ); + SWAPCHAIN_FORMAT.store(format.as_raw() as u64, Ordering::Relaxed); + + let sc = *swapchain; + let _ = SWAPCHAIN.set(sc); + + vlog!( + "Swapchain created (ICD hook): {}x{} format={:?} handle={:?}", + extent.width, + extent.height, + format, + sc + ); + } + + result +} + +/// Attempt to render the overlay quad. +unsafe fn try_render_overlay(queue: vk::Queue, present_info_ptr: *const c_void, frame: u64) { + // Skip rendering if overlay is toggled off (F5) + if !crate::OVERLAY_VISIBLE.load(Ordering::Relaxed) { + return; + } + + let present_info = &*(present_info_ptr as *const vk::PresentInfoKHR<'_>); + + // We need at least one swapchain and one image index + if present_info.swapchain_count == 0 || present_info.p_swapchains.is_null() { + return; + } + + let swapchain = *present_info.p_swapchains; + let image_index = *present_info.p_image_indices; + + // Compute MVP from camera state (or fallback) + let mvp = compute_frame_mvp(frame); + + // Lock order: RENDERER → SHMEM_READER (must match VR path in openvr.rs to avoid deadlock) + if let Ok(mut guard) = RENDERER.lock() { + // Lazy-initialize the renderer (or reinitialize after swapchain recreation) + if guard.is_none() { + match init_renderer(queue, swapchain) { + Ok(renderer) => { + *guard = Some(renderer); + }, + Err(e) => { + if frame.is_multiple_of(300) { + vlog!("Renderer not ready: {}", e); + } + return; + }, + } + } + + if let Some(renderer) = guard.as_ref() { + // Poll shared memory for a new video frame (zero-copy: hold lock while rendering + // so we can pass the reference directly without allocating a Vec) + let reader_result = SHMEM_READER.get_or_try_init(|| match ShmemFrameReader::open() { + Ok(reader) => { + SHMEM_FAILED.store(false, Ordering::Relaxed); + Ok(Mutex::new(reader)) + }, + Err(e) => { + if !SHMEM_FAILED.swap(true, Ordering::Relaxed) { + vlog!("ShmemFrameReader open failed: {} (daemon not running?)", e); + } + Err(e) + }, + }); + let mut shmem_guard = reader_result.ok().and_then(|m| m.lock().ok()); + let new_frame = match shmem_guard.as_mut() { + Some(reader) => { + let result = reader.poll_frame(); + if frame.is_multiple_of(300) { + if let Some(data) = &result { + vlog!("Got video frame: {} bytes", data.len()); + } else { + vlog!( + "poll_frame: None (last_pts={}, shmem_pts={})", + reader.last_pts(), + reader.shmem_pts() + ); + } + } + result + }, + None => { + if frame.is_multiple_of(600) { + vlog!("Waiting for video daemon (shmem not available)"); + } + None + }, + }; + + // Use cached swapchain images (fetched once per swapchain lifetime) + let image = { + let mut img_guard = match SWAPCHAIN_IMAGES.lock() { + Ok(g) => g, + Err(_) => return, + }; + if img_guard.is_none() { + if let Some(swapchain_fn) = ASH_SWAPCHAIN_FN.get() { + if let Ok(images) = swapchain_fn.get_swapchain_images(swapchain) { + *img_guard = Some(images); + } + } + } + img_guard + .as_ref() + .and_then(|imgs| imgs.get(image_index as usize).copied()) + }; + + if let Some(image) = image { + if let Err(e) = renderer.render_frame(image_index, image, &mvp, new_frame) { + if frame.is_multiple_of(300) { + vlog!("Render error: {}", e); + } + } + } + } + } +} + +/// Compute the MVP for this frame using camera state. +/// +/// Returns a fallback MVP if the camera can't be read, or skips rendering +/// (via a fixed off-screen MVP) if not in cockpit mode. +unsafe fn compute_frame_mvp(frame: u64) -> [f32; 16] { + // Try to initialize/use the camera reader + let camera_result = CAMERA.get_or_try_init(|| unsafe { CameraReader::new() }); + + if let Ok(camera) = camera_result { + if let Some(state) = camera.read() { + // Only render in cockpit mode - return off-screen MVP otherwise + if state.mode != CAMERA_MODE_COCKPIT { + return offscreen_mvp(); + } + return compute_cockpit_mvp(state.fov_deg, state.aspect); + } + } + + // Fallback: camera not available, use default values + if frame.is_multiple_of(300) { + vlog!("Camera unavailable, using fallback MVP"); + } + let extent_packed = SWAPCHAIN_EXTENT.load(Ordering::Relaxed); + let w = (extent_packed >> 32) as f32; + let h = (extent_packed & 0xFFFFFFFF) as f32; + let aspect = if h > 0.0 { w / h } else { 16.0 / 9.0 }; + compute_cockpit_mvp(75.0, aspect) +} + +/// MVP that places the quad entirely off-screen (used to hide when not in cockpit). +fn offscreen_mvp() -> [f32; 16] { + [ + 0.0, 0.0, 0.0, 0.0, // column 0: zero scale + 0.0, 0.0, 0.0, 0.0, // column 1: zero scale + 0.0, 0.0, 0.0, 0.0, // column 2 + 0.0, 0.0, 0.0, 1.0, // column 3: w=1 but everything else zero → degenerate + ] +} + +/// Initialize the renderer with current state. +fn init_renderer(queue: vk::Queue, swapchain: vk::SwapchainKHR) -> Result { + let instance = ASH_INSTANCE + .get() + .ok_or_else(|| "No ash::Instance".to_string())?; + let &physical_device = PHYSICAL_DEVICE + .get() + .ok_or_else(|| "No physical device".to_string())?; + let &device = DEVICE.get().ok_or_else(|| "No device".to_string())?; + + let extent_packed = SWAPCHAIN_EXTENT.load(Ordering::Relaxed); + let extent = vk::Extent2D { + width: (extent_packed >> 32) as u32, + height: (extent_packed & 0xFFFFFFFF) as u32, + }; + + let format_raw = SWAPCHAIN_FORMAT.load(Ordering::Relaxed) as i32; + let format = vk::Format::from_raw(format_raw); + + let queue_family = QUEUE_FAMILY_INDEX.load(Ordering::Relaxed) as u32; + + vlog!( + "Initializing renderer: extent={}x{} format={:?} queue_family={}", + extent.width, + extent.height, + format, + queue_family + ); + + let renderer = unsafe { + VulkanRenderer::new( + device, + physical_device, + queue, + queue_family, + swapchain, + format, + extent, + instance, + )? + }; + + Ok(renderer) +} + +// --- Installation --- + +/// Install Vulkan hooks. +/// +/// # Safety +/// vulkan-1.dll must be loaded in the process. +pub unsafe fn install() -> Result<(), String> { + let vulkan_module = get_vulkan_module()?; + + // Hook vkCreateInstance + let addr = get_proc(vulkan_module, "vkCreateInstance")?; + vlog!("vkCreateInstance at {:p}", addr); + let pfn: unsafe extern "system" fn( + *const c_void, + *const c_void, + *mut vk::Instance, + ) -> vk::Result = std::mem::transmute(addr); + Hook_vkCreateInstance + .initialize(pfn, hooked_create_instance) + .map_err(|e| format!("Failed to init vkCreateInstance hook: {}", e))?; + Hook_vkCreateInstance + .enable() + .map_err(|e| format!("Failed to enable vkCreateInstance hook: {}", e))?; + vlog!("vkCreateInstance hooked"); + + // Hook vkCreateDevice + let addr = get_proc(vulkan_module, "vkCreateDevice")?; + vlog!("vkCreateDevice at {:p}", addr); + let pfn: unsafe extern "system" fn( + vk::PhysicalDevice, + *const c_void, + *const c_void, + *mut vk::Device, + ) -> vk::Result = std::mem::transmute(addr); + Hook_vkCreateDevice + .initialize(pfn, hooked_create_device) + .map_err(|e| format!("Failed to init vkCreateDevice hook: {}", e))?; + Hook_vkCreateDevice + .enable() + .map_err(|e| format!("Failed to enable vkCreateDevice hook: {}", e))?; + vlog!("vkCreateDevice hooked"); + + // Hook vkCreateSwapchainKHR + let addr = get_proc(vulkan_module, "vkCreateSwapchainKHR")?; + vlog!("vkCreateSwapchainKHR at {:p}", addr); + let pfn: unsafe extern "system" fn( + vk::Device, + *const c_void, + *const c_void, + *mut vk::SwapchainKHR, + ) -> vk::Result = std::mem::transmute(addr); + Hook_vkCreateSwapchainKHR + .initialize(pfn, hooked_create_swapchain) + .map_err(|e| format!("Failed to init vkCreateSwapchainKHR hook: {}", e))?; + Hook_vkCreateSwapchainKHR + .enable() + .map_err(|e| format!("Failed to enable vkCreateSwapchainKHR hook: {}", e))?; + vlog!("vkCreateSwapchainKHR hooked"); + + // Hook vkQueuePresentKHR + let addr = get_proc(vulkan_module, "vkQueuePresentKHR")?; + vlog!("vkQueuePresentKHR at {:p}", addr); + let pfn: unsafe extern "system" fn(vk::Queue, *const c_void) -> vk::Result = + std::mem::transmute(addr); + Hook_vkQueuePresentKHR + .initialize(pfn, hooked_queue_present) + .map_err(|e| format!("Failed to init vkQueuePresentKHR hook: {}", e))?; + Hook_vkQueuePresentKHR + .enable() + .map_err(|e| format!("Failed to enable vkQueuePresentKHR hook: {}", e))?; + vlog!("vkQueuePresentKHR hooked"); + + Ok(()) +} + +/// Remove all Vulkan hooks. +/// +/// # Safety +/// Must be called during DLL detach. +pub unsafe fn remove() { + // Disable ICD-level hooks first (these are the active ones) + if let Some(holder) = ICD_PRESENT_DETOUR.get() { + let _ = holder.0.disable(); + } + if let Some(holder) = ICD_SWAPCHAIN_DETOUR.get() { + let _ = holder.0.disable(); + } + // Disable loader trampoline hooks + let _ = Hook_vkQueuePresentKHR.disable(); + let _ = Hook_vkCreateSwapchainKHR.disable(); + let _ = Hook_vkCreateDevice.disable(); + let _ = Hook_vkCreateInstance.disable(); + vlog!("Vulkan hooks removed"); +} + +// --- Public accessors for VR hook --- + +/// Get the renderer mutex (used by OpenVR hook). +pub fn get_renderer() -> &'static Mutex> { + &RENDERER +} + +/// Compute the current MVP matrix (used by OpenVR hook). +/// +/// # Safety +/// Reads camera memory. +pub unsafe fn get_mvp(frame: u64) -> [f32; 16] { + compute_frame_mvp(frame) +} + +// --- Helpers --- + +/// Get the vulkan-1.dll module handle. +unsafe fn get_vulkan_module() -> Result { + let name = CString::new("vulkan-1.dll").unwrap(); + GetModuleHandleA(windows::core::PCSTR(name.as_ptr() as *const _)) + .map_err(|e| format!("GetModuleHandle(vulkan-1.dll) failed: {}", e)) +} + +/// Get a function address from a module. +unsafe fn get_proc( + module: windows::Win32::Foundation::HMODULE, + name: &str, +) -> Result<*const c_void, String> { + let cname = CString::new(name).unwrap(); + let addr = GetProcAddress(module, windows::core::PCSTR(cname.as_ptr() as *const _)); + match addr { + Some(f) => Ok(f as *const c_void), + None => Err(format!("GetProcAddress({}) failed", name)), + } +} + +/// Get vkGetInstanceProcAddr function pointer from vulkan-1.dll. +unsafe fn get_instance_proc_addr_fn() -> Option { + let module = get_vulkan_module().ok()?; + let addr = get_proc(module, "vkGetInstanceProcAddr").ok()?; + Some(std::mem::transmute::< + *const std::ffi::c_void, + unsafe extern "system" fn(vk::Instance, *const i8) -> Option, + >(addr)) +} diff --git a/projects/nms-cockpit-video/injector/src/input.rs b/projects/nms-cockpit-video/injector/src/input.rs new file mode 100644 index 0000000..3a93b7d --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/input.rs @@ -0,0 +1,233 @@ +//! In-game keyboard input handler. +//! +//! Polls for hotkey presses and sends IPC commands to the video daemon. +//! Runs in a background thread spawned during DLL initialization. + +use crate::log::vlog; +use crate::{OVERLAY_VISIBLE, SHUTDOWN}; +use itk_ipc::IpcChannel; +use itk_protocol::{MessageType, VideoLoad, VideoPause, VideoPlay, VideoSeek, encode}; +use std::sync::atomic::Ordering; +use std::thread; +use std::time::Duration; +use windows::Win32::UI::Input::KeyboardAndMouse::GetAsyncKeyState; + +/// IPC channel name to connect to the daemon. +const CLIENT_CHANNEL: &str = "nms_cockpit_client"; + +/// Polling interval for keyboard state. +const POLL_INTERVAL: Duration = Duration::from_millis(50); + +/// Seek step in milliseconds. +const SEEK_STEP_MS: u64 = 10_000; + +/// Virtual key codes for hotkeys. +const VK_F5: i32 = 0x74; +const VK_F6: i32 = 0x75; +const VK_F7: i32 = 0x76; +const VK_F8: i32 = 0x77; +const VK_F9: i32 = 0x78; + +/// Start the input handler thread. +pub fn start() { + thread::spawn(|| { + vlog!("Input handler starting"); + input_loop(); + vlog!("Input handler stopped"); + }); +} + +/// Main input polling loop. +fn input_loop() { + let mut prev_f5 = false; + let mut prev_f6 = false; + let mut prev_f7 = false; + let mut prev_f8 = false; + let mut prev_f9 = false; + let mut playing = false; + let mut position_ms: u64 = 0; + + // Lazily connected IPC channel + let mut ipc: Option> = None; + + loop { + if SHUTDOWN.load(Ordering::Acquire) { + break; + } + + thread::sleep(POLL_INTERVAL); + + // Read current key states (bit 15 = currently pressed) + let f5 = is_key_down(VK_F5); + let f6 = is_key_down(VK_F6); + let f7 = is_key_down(VK_F7); + let f8 = is_key_down(VK_F8); + let f9 = is_key_down(VK_F9); + + // Detect key-down edges + if f5 && !prev_f5 { + let was_visible = OVERLAY_VISIBLE.fetch_xor(true, Ordering::Relaxed); + let now_visible = !was_visible; + vlog!("F5: Overlay {}", if now_visible { "ON" } else { "OFF" }); + } + + if f6 && !prev_f6 { + if let Some(channel) = ensure_connected(&mut ipc) { + if playing { + vlog!("F6: Pause"); + let cmd = VideoPause {}; + send_command(channel, MessageType::VideoPause, &cmd); + playing = false; + } else { + vlog!("F6: Play"); + let cmd = VideoPlay { + from_position_ms: None, + }; + send_command(channel, MessageType::VideoPlay, &cmd); + playing = true; + } + } + } + + if f7 && !prev_f7 { + if let Some(channel) = ensure_connected(&mut ipc) { + position_ms = position_ms.saturating_sub(SEEK_STEP_MS); + vlog!("F7: Seek back to {}ms", position_ms); + let cmd = VideoSeek { position_ms }; + send_command(channel, MessageType::VideoSeek, &cmd); + } + } + + if f8 && !prev_f8 { + if let Some(channel) = ensure_connected(&mut ipc) { + position_ms += SEEK_STEP_MS; + vlog!("F8: Seek forward to {}ms", position_ms); + let cmd = VideoSeek { position_ms }; + send_command(channel, MessageType::VideoSeek, &cmd); + } + } + + if f9 && !prev_f9 { + vlog!("F9: Load from clipboard"); + if let Some(url) = read_clipboard_text() { + let url = url.trim().to_string(); + if !url.is_empty() { + vlog!("Loading URL: {}", url); + if let Some(channel) = ensure_connected(&mut ipc) { + let cmd = VideoLoad { + source: url, + start_position_ms: 0, + autoplay: true, + }; + send_command(channel, MessageType::VideoLoad, &cmd); + playing = true; + position_ms = 0; + } + } else { + vlog!("Clipboard is empty"); + } + } else { + vlog!("Failed to read clipboard"); + } + } + + prev_f5 = f5; + prev_f6 = f6; + prev_f7 = f7; + prev_f8 = f8; + prev_f9 = f9; + } +} + +/// Check if a key is currently pressed. +fn is_key_down(vk: i32) -> bool { + unsafe { GetAsyncKeyState(vk) & (0x8000u16 as i16) != 0 } +} + +/// Read unicode text from the Windows clipboard. +fn read_clipboard_text() -> Option { + use windows::Win32::System::DataExchange::{CloseClipboard, GetClipboardData, OpenClipboard}; + use windows::Win32::System::Memory::{GlobalLock, GlobalUnlock}; + + const CF_UNICODETEXT: u32 = 13; + + unsafe { + // Open the clipboard (no window owner) + if OpenClipboard(None).is_err() { + return None; + } + + let result = (|| -> Option { + // Get the clipboard data as unicode text + let handle = GetClipboardData(CF_UNICODETEXT).ok()?; + + // The HANDLE from GetClipboardData is actually an HGLOBAL + let hglobal = windows::Win32::Foundation::HGLOBAL(handle.0 as _); + + // Lock the global memory to get a pointer + let ptr = GlobalLock(hglobal); + if ptr.is_null() { + return None; + } + + // Read the null-terminated wide string + let wstr = ptr as *const u16; + let mut len = 0usize; + while *wstr.add(len) != 0 { + len += 1; + // Safety bound + if len > 65536 { + break; + } + } + + let slice = std::slice::from_raw_parts(wstr, len); + let text = String::from_utf16_lossy(slice); + + let _ = GlobalUnlock(hglobal); + Some(text) + })(); + + let _ = CloseClipboard(); + result + } +} + +/// Ensure IPC connection is established. Reconnects on failure. +fn ensure_connected(ipc: &mut Option>) -> Option<&dyn IpcChannel> { + // Check if existing connection is still good + if let Some(ref channel) = ipc { + if channel.is_connected() { + return ipc.as_deref(); + } + vlog!("IPC disconnected, reconnecting..."); + } + + // Try to connect + match itk_ipc::connect(CLIENT_CHANNEL) { + Ok(channel) => { + vlog!("Connected to daemon IPC ({})", CLIENT_CHANNEL); + *ipc = Some(Box::new(channel)); + ipc.as_deref() + }, + Err(e) => { + vlog!("IPC connect failed: {} (is daemon running?)", e); + *ipc = None; + None + }, + } +} + +/// Send an encoded ITK protocol command. +fn send_command(channel: &dyn IpcChannel, msg_type: MessageType, payload: &T) { + match encode(msg_type, payload) { + Ok(data) => { + if let Err(e) = channel.send(&data) { + vlog!("IPC send failed: {}", e); + } + }, + Err(e) => { + vlog!("Protocol encode failed: {}", e); + }, + } +} diff --git a/projects/nms-cockpit-video/injector/src/lib.rs b/projects/nms-cockpit-video/injector/src/lib.rs new file mode 100644 index 0000000..48abae3 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/lib.rs @@ -0,0 +1,112 @@ +//! NMS Cockpit Video Player - Vulkan Texture Injector +//! +//! Injectable DLL that hooks NMS's Vulkan pipeline to render video frames +//! as a textured quad in the cockpit, visible to both desktop and VR users. + +pub mod camera; +mod hooks; +pub mod input; +mod log; +pub mod renderer; +pub mod shmem_reader; + +use log::vlog; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::thread; +use std::time::Duration; +use windows::Win32::Foundation::{BOOL, HINSTANCE, TRUE}; +use windows::Win32::System::LibraryLoader::GetModuleHandleA; +use windows::Win32::System::SystemServices::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH}; + +/// Whether the DLL has been initialized. +static INITIALIZED: AtomicBool = AtomicBool::new(false); + +/// Whether shutdown has been requested. +static SHUTDOWN: AtomicBool = AtomicBool::new(false); + +/// Whether the video overlay is visible (toggled by F5). +pub static OVERLAY_VISIBLE: AtomicBool = AtomicBool::new(false); + +/// DLL entry point. +#[unsafe(no_mangle)] +pub extern "system" fn DllMain( + _hinst: HINSTANCE, + reason: u32, + _reserved: *mut std::ffi::c_void, +) -> BOOL { + match reason { + DLL_PROCESS_ATTACH => { + // Spawn init thread to avoid loader lock deadlock + thread::spawn(|| { + if let Err(e) = init() { + vlog!("Init failed: {}", e); + } + }); + }, + DLL_PROCESS_DETACH => { + shutdown(); + }, + _ => {}, + } + TRUE +} + +/// Initialize the injector: wait for Vulkan, install hooks. +fn init() -> Result<(), String> { + log::init_file_log(); + vlog!("Initializing..."); + + // Wait for vulkan-1.dll to be loaded by the game + wait_for_module("vulkan-1.dll", Duration::from_secs(30))?; + vlog!("vulkan-1.dll found"); + + // Install Vulkan hooks + unsafe { + hooks::install()?; + } + + // Start keyboard input handler (sends IPC commands to daemon) + input::start(); + + INITIALIZED.store(true, Ordering::Release); + vlog!("Initialization complete"); + Ok(()) +} + +/// Wait for a DLL module to be loaded in the current process. +fn wait_for_module(name: &str, timeout: Duration) -> Result<(), String> { + let start = std::time::Instant::now(); + let name_cstr = std::ffi::CString::new(name).map_err(|e| e.to_string())?; + + loop { + let handle = + unsafe { GetModuleHandleA(windows::core::PCSTR(name_cstr.as_ptr() as *const _)) }; + + if handle.is_ok() { + return Ok(()); + } + + if start.elapsed() > timeout { + return Err(format!("Timeout waiting for {}", name)); + } + + if SHUTDOWN.load(Ordering::Acquire) { + return Err("Shutdown requested during init".to_string()); + } + + thread::sleep(Duration::from_millis(10)); + } +} + +/// Clean shutdown: remove hooks. +fn shutdown() { + SHUTDOWN.store(true, Ordering::Release); + + if INITIALIZED.load(Ordering::Acquire) { + vlog!("Shutting down..."); + unsafe { + hooks::remove(); + } + vlog!("Shutdown complete"); + } +} diff --git a/projects/nms-cockpit-video/injector/src/log.rs b/projects/nms-cockpit-video/injector/src/log.rs new file mode 100644 index 0000000..2fe70d5 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/log.rs @@ -0,0 +1,53 @@ +//! Logging via OutputDebugString + file for DLL injection context. +//! +//! Uses Windows OutputDebugStringA which can be captured by DebugView, +//! and also writes to a log file for reliable post-mortem analysis. + +use std::fs::OpenOptions; +use std::io::Write; +use std::sync::Mutex; + +static LOG_FILE: Mutex> = Mutex::new(None); + +/// Initialize the file log path (call once during init). +pub fn init_file_log() { + if let Ok(mut path) = LOG_FILE.lock() { + if path.is_none() { + // Write log next to the DLL or in temp + let log_path = std::env::temp_dir().join("nms_video_injector.log"); + *path = Some(log_path.to_string_lossy().into_owned()); + // Truncate on init + if let Ok(mut f) = std::fs::File::create(&log_path) { + let _ = writeln!(f, "[NMS-VIDEO] Log initialized"); + } + } + } +} + +/// Log a message via OutputDebugStringA and to file. +pub fn debug_log(msg: &str) { + let prefixed = format!("[NMS-VIDEO] {}\0", msg); + unsafe { + windows::Win32::System::Diagnostics::Debug::OutputDebugStringA(windows::core::PCSTR( + prefixed.as_ptr(), + )); + } + + // Also write to file + if let Ok(guard) = LOG_FILE.lock() { + if let Some(path) = guard.as_ref() { + if let Ok(mut f) = OpenOptions::new().create(true).append(true).open(path) { + let _ = writeln!(f, "[NMS-VIDEO] {}", msg); + } + } + } +} + +/// Formatted logging macro. +macro_rules! vlog { + ($($arg:tt)*) => { + $crate::log::debug_log(&format!($($arg)*)) + }; +} + +pub(crate) use vlog; diff --git a/projects/nms-cockpit-video/injector/src/renderer/geometry.rs b/projects/nms-cockpit-video/injector/src/renderer/geometry.rs new file mode 100644 index 0000000..8076125 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/renderer/geometry.rs @@ -0,0 +1,77 @@ +//! Quad geometry for the video overlay. + +use ash::vk; + +/// Vertex with 3D position and 2D UV coordinates. +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct Vertex { + pub pos: [f32; 3], + pub uv: [f32; 2], +} + +impl Vertex { + /// Vertex input binding description (per-vertex, stride = 20 bytes). + pub fn binding_description() -> vk::VertexInputBindingDescription { + vk::VertexInputBindingDescription { + binding: 0, + stride: std::mem::size_of::() as u32, + input_rate: vk::VertexInputRate::VERTEX, + } + } + + /// Vertex input attribute descriptions (position at location 0, uv at location 1). + pub fn attribute_descriptions() -> [vk::VertexInputAttributeDescription; 2] { + [ + // location 0: vec3 position + vk::VertexInputAttributeDescription { + location: 0, + binding: 0, + format: vk::Format::R32G32B32_SFLOAT, + offset: 0, + }, + // location 1: vec2 uv + vk::VertexInputAttributeDescription { + location: 1, + binding: 0, + format: vk::Format::R32G32_SFLOAT, + offset: 12, // after vec3 (3 * 4 bytes) + }, + ] + } +} + +/// Unit quad vertices (2 triangles, CCW winding). +/// Spans [-1, 1] in X and Y at Z=0, with UVs [0,1]. +/// In Phase 3, the MVP push constant will transform this to cockpit screen position. +pub const QUAD_VERTICES: [Vertex; 6] = [ + // Triangle 1 (top-left, bottom-left, bottom-right) + Vertex { + pos: [-1.0, -1.0, 0.0], + uv: [0.0, 0.0], + }, // TL + Vertex { + pos: [-1.0, 1.0, 0.0], + uv: [0.0, 1.0], + }, // BL + Vertex { + pos: [1.0, 1.0, 0.0], + uv: [1.0, 1.0], + }, // BR + // Triangle 2 (top-left, bottom-right, top-right) + Vertex { + pos: [-1.0, -1.0, 0.0], + uv: [0.0, 0.0], + }, // TL + Vertex { + pos: [1.0, 1.0, 0.0], + uv: [1.0, 1.0], + }, // BR + Vertex { + pos: [1.0, -1.0, 0.0], + uv: [1.0, 0.0], + }, // TR +]; + +/// Size in bytes of the quad vertex data. +pub const QUAD_VERTICES_SIZE: u64 = std::mem::size_of::<[Vertex; 6]>() as u64; diff --git a/projects/nms-cockpit-video/injector/src/renderer/mod.rs b/projects/nms-cockpit-video/injector/src/renderer/mod.rs new file mode 100644 index 0000000..599e271 --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/renderer/mod.rs @@ -0,0 +1,1000 @@ +//! Vulkan renderer for the video quad overlay. +//! +//! Creates and manages all GPU resources needed to render a textured quad +//! on top of the game's swapchain images. + +pub mod geometry; +pub mod texture; + +use crate::log::vlog; +use ash::vk; +use geometry::{QUAD_VERTICES, QUAD_VERTICES_SIZE, Vertex}; +use std::collections::HashMap; +use texture::VideoTexture; + +/// Embedded SPIR-V shaders (compiled by build.rs). +const VERT_SPV: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/quad.vert.spv")); +const FRAG_SPV: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/quad.frag.spv")); + +/// Per-frame rendering resources. +struct FrameResources { + framebuffer: vk::Framebuffer, + command_buffer: vk::CommandBuffer, + fence: vk::Fence, + image_view: vk::ImageView, +} + +/// Cached VR rendering resources for a specific eye texture. +struct VrFrameCache { + image_view: vk::ImageView, + framebuffer: vk::Framebuffer, + fence: vk::Fence, + command_buffer: vk::CommandBuffer, +} + +/// Default video frame dimensions (must match daemon). +const VIDEO_WIDTH: u32 = 1280; +const VIDEO_HEIGHT: u32 = 720; + +/// The Vulkan renderer that draws a quad over the game's output. +pub struct VulkanRenderer { + device: ash::Device, + queue: vk::Queue, + render_pass: vk::RenderPass, + pipeline_layout: vk::PipelineLayout, + pipeline: vk::Pipeline, + descriptor_set_layout: vk::DescriptorSetLayout, + video_texture: VideoTexture, + command_pool: vk::CommandPool, + vertex_buffer: vk::Buffer, + vertex_memory: vk::DeviceMemory, + frames: Vec, + extent: vk::Extent2D, + format: vk::Format, + /// Cached VR resources per eye image handle (avoids re-creation every frame). + vr_cache: HashMap, +} + +impl VulkanRenderer { + /// Initialize the renderer with the game's Vulkan device and swapchain info. + /// + /// # Safety + /// - `raw_device` must be a valid VkDevice + /// - `queue` must be a valid VkQueue from that device + /// - `swapchain` must be a valid VkSwapchainKHR + /// - `physical_device` must be the physical device used to create `raw_device` + #[allow(clippy::too_many_arguments)] + pub unsafe fn new( + raw_device: vk::Device, + physical_device: vk::PhysicalDevice, + queue: vk::Queue, + queue_family_index: u32, + swapchain: vk::SwapchainKHR, + format: vk::Format, + extent: vk::Extent2D, + instance: &ash::Instance, + ) -> Result { + // Load device functions + let device = ash::Device::load(instance.fp_v1_0(), raw_device); + + // Get swapchain images + let swapchain_fn = ash::khr::swapchain::Device::new(instance, &device); + let images = swapchain_fn + .get_swapchain_images(swapchain) + .map_err(|e| format!("Failed to get swapchain images: {:?}", e))?; + + vlog!("Swapchain has {} images", images.len()); + + // Create render pass + let render_pass = create_render_pass(&device, format)?; + + // Create descriptor set layout for video texture + let descriptor_set_layout = + texture::create_descriptor_set_layout(&device).map_err(|e| { + device.destroy_render_pass(render_pass, None); + e + })?; + + // Create pipeline layout (descriptor set + push constant: mat4 = 64 bytes) + let push_constant_range = vk::PushConstantRange { + stage_flags: vk::ShaderStageFlags::VERTEX, + offset: 0, + size: 64, // mat4 + }; + + let layout_info = vk::PipelineLayoutCreateInfo::default() + .set_layouts(std::slice::from_ref(&descriptor_set_layout)) + .push_constant_ranges(std::slice::from_ref(&push_constant_range)); + + let pipeline_layout = device + .create_pipeline_layout(&layout_info, None) + .map_err(|e| { + device.destroy_descriptor_set_layout(descriptor_set_layout, None); + device.destroy_render_pass(render_pass, None); + format!("Failed to create pipeline layout: {:?}", e) + })?; + + // Create graphics pipeline + let pipeline = + create_pipeline(&device, render_pass, pipeline_layout, extent).map_err(|e| { + device.destroy_pipeline_layout(pipeline_layout, None); + device.destroy_descriptor_set_layout(descriptor_set_layout, None); + device.destroy_render_pass(render_pass, None); + e + })?; + + // Create command pool + let pool_info = vk::CommandPoolCreateInfo::default() + .queue_family_index(queue_family_index) + .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER); + + let command_pool = device.create_command_pool(&pool_info, None).map_err(|e| { + device.destroy_pipeline(pipeline, None); + device.destroy_pipeline_layout(pipeline_layout, None); + device.destroy_descriptor_set_layout(descriptor_set_layout, None); + device.destroy_render_pass(render_pass, None); + format!("Failed to create command pool: {:?}", e) + })?; + + // Create vertex buffer + let (vertex_buffer, vertex_memory) = + create_vertex_buffer(&device, instance, physical_device).map_err(|e| { + device.destroy_command_pool(command_pool, None); + device.destroy_pipeline(pipeline, None); + device.destroy_pipeline_layout(pipeline_layout, None); + device.destroy_descriptor_set_layout(descriptor_set_layout, None); + device.destroy_render_pass(render_pass, None); + e + })?; + + // Create video texture + let video_texture = VideoTexture::new( + &device, + instance, + physical_device, + descriptor_set_layout, + VIDEO_WIDTH, + VIDEO_HEIGHT, + queue, + command_pool, + ) + .map_err(|e| { + device.free_memory(vertex_memory, None); + device.destroy_buffer(vertex_buffer, None); + device.destroy_command_pool(command_pool, None); + device.destroy_pipeline(pipeline, None); + device.destroy_pipeline_layout(pipeline_layout, None); + device.destroy_descriptor_set_layout(descriptor_set_layout, None); + device.destroy_render_pass(render_pass, None); + e + })?; + + // Create per-frame resources + let frames = + create_frame_resources(&device, &images, format, extent, render_pass, command_pool) + .map_err(|e| { + video_texture.destroy(&device); + device.free_memory(vertex_memory, None); + device.destroy_buffer(vertex_buffer, None); + device.destroy_command_pool(command_pool, None); + device.destroy_pipeline(pipeline, None); + device.destroy_pipeline_layout(pipeline_layout, None); + device.destroy_descriptor_set_layout(descriptor_set_layout, None); + device.destroy_render_pass(render_pass, None); + e + })?; + + vlog!( + "Renderer initialized: {}x{} format={:?} frames={}", + extent.width, + extent.height, + format, + frames.len() + ); + + Ok(Self { + device, + queue, + render_pass, + pipeline_layout, + pipeline, + descriptor_set_layout, + video_texture, + command_pool, + vertex_buffer, + vertex_memory, + frames, + extent, + format, + vr_cache: HashMap::new(), + }) + } + + /// Render the quad overlay for the given swapchain image index. + /// + /// `mvp` is a column-major 4x4 matrix that transforms the unit quad to clip space. + /// `new_frame` is optional RGBA pixel data to upload to the video texture. + /// + /// # Safety + /// Must be called from the render thread with a valid image_index. + pub unsafe fn render_frame( + &self, + image_index: u32, + image: vk::Image, + mvp: &[f32; 16], + new_frame: Option<&[u8]>, + ) -> Result<(), String> { + let idx = image_index as usize; + if idx >= self.frames.len() { + return Err(format!("image_index {} out of range", image_index)); + } + + let frame = &self.frames[idx]; + + // Wait for previous use of this frame's resources + self.device + .wait_for_fences(&[frame.fence], true, u64::MAX) + .map_err(|e| format!("Wait fence failed: {:?}", e))?; + self.device + .reset_fences(&[frame.fence]) + .map_err(|e| format!("Reset fence failed: {:?}", e))?; + + // Reset and record command buffer + self.device + .reset_command_buffer(frame.command_buffer, vk::CommandBufferResetFlags::empty()) + .map_err(|e| format!("Reset cmd buf failed: {:?}", e))?; + + let begin_info = vk::CommandBufferBeginInfo::default() + .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + + self.device + .begin_command_buffer(frame.command_buffer, &begin_info) + .map_err(|e| format!("Begin cmd buf failed: {:?}", e))?; + + // Upload new video frame if available + if let Some(frame_data) = new_frame { + self.video_texture + .upload_frame(&self.device, frame.command_buffer, frame_data); + } + + // Transition: PRESENT_SRC -> COLOR_ATTACHMENT + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::MEMORY_READ) + .dst_access_mask(vk::AccessFlags::COLOR_ATTACHMENT_WRITE) + .old_layout(vk::ImageLayout::PRESENT_SRC_KHR) + .new_layout(vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL) + .image(image) + .subresource_range(vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + }); + + self.device.cmd_pipeline_barrier( + frame.command_buffer, + vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT, + vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + + // Begin render pass (LOAD existing content) + let clear_values = []; // No clear - we use LOAD_OP_LOAD + let render_pass_info = vk::RenderPassBeginInfo::default() + .render_pass(self.render_pass) + .framebuffer(frame.framebuffer) + .render_area(vk::Rect2D { + offset: vk::Offset2D { x: 0, y: 0 }, + extent: self.extent, + }) + .clear_values(&clear_values); + + self.device.cmd_begin_render_pass( + frame.command_buffer, + &render_pass_info, + vk::SubpassContents::INLINE, + ); + + // Bind pipeline + self.device.cmd_bind_pipeline( + frame.command_buffer, + vk::PipelineBindPoint::GRAPHICS, + self.pipeline, + ); + + // Set dynamic viewport and scissor + let viewport = vk::Viewport { + x: 0.0, + y: 0.0, + width: self.extent.width as f32, + height: self.extent.height as f32, + min_depth: 0.0, + max_depth: 1.0, + }; + self.device + .cmd_set_viewport(frame.command_buffer, 0, &[viewport]); + + let scissor = vk::Rect2D { + offset: vk::Offset2D { x: 0, y: 0 }, + extent: self.extent, + }; + self.device + .cmd_set_scissor(frame.command_buffer, 0, &[scissor]); + + // Bind video texture descriptor set + self.device.cmd_bind_descriptor_sets( + frame.command_buffer, + vk::PipelineBindPoint::GRAPHICS, + self.pipeline_layout, + 0, + &[self.video_texture.descriptor_set], + &[], + ); + + // Push MVP matrix + let mvp_bytes: &[u8] = std::slice::from_raw_parts(mvp.as_ptr() as *const u8, 64); + self.device.cmd_push_constants( + frame.command_buffer, + self.pipeline_layout, + vk::ShaderStageFlags::VERTEX, + 0, + mvp_bytes, + ); + + // Bind vertex buffer and draw + self.device + .cmd_bind_vertex_buffers(frame.command_buffer, 0, &[self.vertex_buffer], &[0]); + self.device.cmd_draw(frame.command_buffer, 6, 1, 0, 0); + + // End render pass + self.device.cmd_end_render_pass(frame.command_buffer); + + // Transition: COLOR_ATTACHMENT -> PRESENT_SRC + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::COLOR_ATTACHMENT_WRITE) + .dst_access_mask(vk::AccessFlags::MEMORY_READ) + .old_layout(vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL) + .new_layout(vk::ImageLayout::PRESENT_SRC_KHR) + .image(image) + .subresource_range(vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + }); + + self.device.cmd_pipeline_barrier( + frame.command_buffer, + vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT, + vk::PipelineStageFlags::TOP_OF_PIPE, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + + self.device + .end_command_buffer(frame.command_buffer) + .map_err(|e| format!("End cmd buf failed: {:?}", e))?; + + // Submit + let cmd_bufs = [frame.command_buffer]; + let submit_info = vk::SubmitInfo::default().command_buffers(&cmd_bufs); + + self.device + .queue_submit(self.queue, &[submit_info], frame.fence) + .map_err(|e| format!("Queue submit failed: {:?}", e))?; + + // Wait for our rendering to complete before the game presents + self.device + .wait_for_fences(&[frame.fence], true, u64::MAX) + .map_err(|e| format!("Wait after submit failed: {:?}", e))?; + + Ok(()) + } + + /// Render the quad overlay to a VR eye image. + /// + /// Uses cached resources per VR image handle to avoid re-creating + /// VkImageView, VkFramebuffer, VkFence, and command buffers every frame. + /// + /// # Safety + /// - `vr_image` must be a valid VkImage (from OpenVR's VRVulkanTextureData_t) + /// - The image is expected to be in TRANSFER_SRC_OPTIMAL layout (ready for compositor) + /// - Must be called from the render thread + pub unsafe fn render_to_vr_image( + &mut self, + vr_image: vk::Image, + extent: vk::Extent2D, + mvp: &[f32; 16], + new_frame: Option<&[u8]>, + ) -> Result<(), String> { + // Get or create cached resources for this VR image + if !self.vr_cache.contains_key(&vr_image) { + let cache = self.create_vr_frame_cache(vr_image, extent)?; + self.vr_cache.insert(vr_image, cache); + } + let cached = self.vr_cache.get(&vr_image).unwrap(); + let cmd = cached.command_buffer; + let fence = cached.fence; + let framebuffer = cached.framebuffer; + + // Wait for previous use of this cached fence + self.device + .wait_for_fences(&[fence], true, u64::MAX) + .map_err(|e| format!("VR fence wait failed: {:?}", e))?; + self.device + .reset_fences(&[fence]) + .map_err(|e| format!("VR fence reset failed: {:?}", e))?; + + // Reset and re-record command buffer + self.device + .reset_command_buffer(cmd, vk::CommandBufferResetFlags::empty()) + .map_err(|e| format!("VR cmd reset failed: {:?}", e))?; + + let begin_info = vk::CommandBufferBeginInfo::default() + .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + + self.device + .begin_command_buffer(cmd, &begin_info) + .map_err(|e| format!("VR begin cmd failed: {:?}", e))?; + + // Upload new video frame if available + if let Some(frame_data) = new_frame { + self.video_texture + .upload_frame(&self.device, cmd, frame_data); + } + + // Transition VR image: TRANSFER_SRC_OPTIMAL -> COLOR_ATTACHMENT_OPTIMAL + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::TRANSFER_READ) + .dst_access_mask(vk::AccessFlags::COLOR_ATTACHMENT_WRITE) + .old_layout(vk::ImageLayout::TRANSFER_SRC_OPTIMAL) + .new_layout(vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL) + .image(vr_image) + .subresource_range(vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + }); + + self.device.cmd_pipeline_barrier( + cmd, + vk::PipelineStageFlags::TRANSFER, + vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + + // Begin render pass (LOAD existing game content) + let clear_values = []; + let render_pass_info = vk::RenderPassBeginInfo::default() + .render_pass(self.render_pass) + .framebuffer(framebuffer) + .render_area(vk::Rect2D { + offset: vk::Offset2D { x: 0, y: 0 }, + extent, + }) + .clear_values(&clear_values); + + self.device + .cmd_begin_render_pass(cmd, &render_pass_info, vk::SubpassContents::INLINE); + + // Bind pipeline and set dynamic state + self.device + .cmd_bind_pipeline(cmd, vk::PipelineBindPoint::GRAPHICS, self.pipeline); + + let viewport = vk::Viewport { + x: 0.0, + y: 0.0, + width: extent.width as f32, + height: extent.height as f32, + min_depth: 0.0, + max_depth: 1.0, + }; + self.device.cmd_set_viewport(cmd, 0, &[viewport]); + + let scissor = vk::Rect2D { + offset: vk::Offset2D { x: 0, y: 0 }, + extent, + }; + self.device.cmd_set_scissor(cmd, 0, &[scissor]); + + // Bind descriptor set and push MVP + self.device.cmd_bind_descriptor_sets( + cmd, + vk::PipelineBindPoint::GRAPHICS, + self.pipeline_layout, + 0, + &[self.video_texture.descriptor_set], + &[], + ); + + let mvp_bytes: &[u8] = std::slice::from_raw_parts(mvp.as_ptr() as *const u8, 64); + self.device.cmd_push_constants( + cmd, + self.pipeline_layout, + vk::ShaderStageFlags::VERTEX, + 0, + mvp_bytes, + ); + + // Draw quad + self.device + .cmd_bind_vertex_buffers(cmd, 0, &[self.vertex_buffer], &[0]); + self.device.cmd_draw(cmd, 6, 1, 0, 0); + + self.device.cmd_end_render_pass(cmd); + + // Transition VR image back: COLOR_ATTACHMENT_OPTIMAL -> TRANSFER_SRC_OPTIMAL + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::COLOR_ATTACHMENT_WRITE) + .dst_access_mask(vk::AccessFlags::TRANSFER_READ) + .old_layout(vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL) + .new_layout(vk::ImageLayout::TRANSFER_SRC_OPTIMAL) + .image(vr_image) + .subresource_range(vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + }); + + self.device.cmd_pipeline_barrier( + cmd, + vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT, + vk::PipelineStageFlags::TRANSFER, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + + self.device + .end_command_buffer(cmd) + .map_err(|e| format!("VR end cmd failed: {:?}", e))?; + + // Submit and wait + let cmd_bufs_submit = [cmd]; + let submit_info = vk::SubmitInfo::default().command_buffers(&cmd_bufs_submit); + + self.device + .queue_submit(self.queue, &[submit_info], fence) + .map_err(|e| format!("VR queue submit failed: {:?}", e))?; + + self.device + .wait_for_fences(&[fence], true, u64::MAX) + .map_err(|e| format!("VR fence wait failed: {:?}", e))?; + + Ok(()) + } + + /// Create cached VR frame resources for a specific eye image. + unsafe fn create_vr_frame_cache( + &self, + vr_image: vk::Image, + extent: vk::Extent2D, + ) -> Result { + let view_info = vk::ImageViewCreateInfo::default() + .image(vr_image) + .view_type(vk::ImageViewType::TYPE_2D) + .format(self.format) + .subresource_range(vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + }); + + let image_view = self + .device + .create_image_view(&view_info, None) + .map_err(|e| format!("VR cache image view failed: {:?}", e))?; + + let fb_info = vk::FramebufferCreateInfo::default() + .render_pass(self.render_pass) + .attachments(std::slice::from_ref(&image_view)) + .width(extent.width) + .height(extent.height) + .layers(1); + + let framebuffer = self + .device + .create_framebuffer(&fb_info, None) + .map_err(|e| { + self.device.destroy_image_view(image_view, None); + format!("VR cache framebuffer failed: {:?}", e) + })?; + + let alloc_info = vk::CommandBufferAllocateInfo::default() + .command_pool(self.command_pool) + .level(vk::CommandBufferLevel::PRIMARY) + .command_buffer_count(1); + + let cmd_bufs = self + .device + .allocate_command_buffers(&alloc_info) + .map_err(|e| { + self.device.destroy_framebuffer(framebuffer, None); + self.device.destroy_image_view(image_view, None); + format!("VR cache cmd buf failed: {:?}", e) + })?; + + let fence_info = vk::FenceCreateInfo::default().flags(vk::FenceCreateFlags::SIGNALED); + let fence = self.device.create_fence(&fence_info, None).map_err(|e| { + self.device + .free_command_buffers(self.command_pool, &cmd_bufs); + self.device.destroy_framebuffer(framebuffer, None); + self.device.destroy_image_view(image_view, None); + format!("VR cache fence failed: {:?}", e) + })?; + + Ok(VrFrameCache { + image_view, + framebuffer, + fence, + command_buffer: cmd_bufs[0], + }) + } + + /// Get the swapchain extent. + pub fn extent(&self) -> vk::Extent2D { + self.extent + } +} + +impl Drop for VulkanRenderer { + fn drop(&mut self) { + unsafe { + let _ = self.device.device_wait_idle(); + + // Clean up cached VR resources + for cached in self.vr_cache.values() { + self.device.destroy_fence(cached.fence, None); + self.device.destroy_framebuffer(cached.framebuffer, None); + self.device.destroy_image_view(cached.image_view, None); + // Command buffers freed with command pool below + } + + for frame in &self.frames { + self.device.destroy_framebuffer(frame.framebuffer, None); + self.device.destroy_image_view(frame.image_view, None); + self.device.destroy_fence(frame.fence, None); + } + + self.video_texture.destroy(&self.device); + self.device.destroy_command_pool(self.command_pool, None); + self.device.free_memory(self.vertex_memory, None); + self.device.destroy_buffer(self.vertex_buffer, None); + self.device.destroy_pipeline(self.pipeline, None); + self.device + .destroy_pipeline_layout(self.pipeline_layout, None); + self.device + .destroy_descriptor_set_layout(self.descriptor_set_layout, None); + self.device.destroy_render_pass(self.render_pass, None); + + vlog!("Renderer destroyed"); + } + } +} + +// --- Helper functions --- + +/// Create render pass with LOAD_OP_LOAD (preserves game frame). +unsafe fn create_render_pass( + device: &ash::Device, + format: vk::Format, +) -> Result { + let attachment = vk::AttachmentDescription { + format, + samples: vk::SampleCountFlags::TYPE_1, + load_op: vk::AttachmentLoadOp::LOAD, + store_op: vk::AttachmentStoreOp::STORE, + stencil_load_op: vk::AttachmentLoadOp::DONT_CARE, + stencil_store_op: vk::AttachmentStoreOp::DONT_CARE, + initial_layout: vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL, + final_layout: vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL, + ..Default::default() + }; + + let color_ref = vk::AttachmentReference { + attachment: 0, + layout: vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL, + }; + + let subpass = vk::SubpassDescription::default() + .pipeline_bind_point(vk::PipelineBindPoint::GRAPHICS) + .color_attachments(std::slice::from_ref(&color_ref)); + + let dependency = vk::SubpassDependency { + src_subpass: vk::SUBPASS_EXTERNAL, + dst_subpass: 0, + src_stage_mask: vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT, + dst_stage_mask: vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT, + src_access_mask: vk::AccessFlags::empty(), + dst_access_mask: vk::AccessFlags::COLOR_ATTACHMENT_WRITE, + ..Default::default() + }; + + let info = vk::RenderPassCreateInfo::default() + .attachments(std::slice::from_ref(&attachment)) + .subpasses(std::slice::from_ref(&subpass)) + .dependencies(std::slice::from_ref(&dependency)); + + device + .create_render_pass(&info, None) + .map_err(|e| format!("Failed to create render pass: {:?}", e)) +} + +/// Create the graphics pipeline. +unsafe fn create_pipeline( + device: &ash::Device, + render_pass: vk::RenderPass, + layout: vk::PipelineLayout, + extent: vk::Extent2D, +) -> Result { + // Create shader modules + let vert_module = create_shader_module(device, VERT_SPV)?; + let frag_module = create_shader_module(device, FRAG_SPV)?; + + let entry_name = c"main"; + + let stages = [ + vk::PipelineShaderStageCreateInfo::default() + .stage(vk::ShaderStageFlags::VERTEX) + .module(vert_module) + .name(entry_name), + vk::PipelineShaderStageCreateInfo::default() + .stage(vk::ShaderStageFlags::FRAGMENT) + .module(frag_module) + .name(entry_name), + ]; + + // Vertex input + let binding = Vertex::binding_description(); + let attributes = Vertex::attribute_descriptions(); + + let vertex_input = vk::PipelineVertexInputStateCreateInfo::default() + .vertex_binding_descriptions(std::slice::from_ref(&binding)) + .vertex_attribute_descriptions(&attributes); + + let input_assembly = vk::PipelineInputAssemblyStateCreateInfo::default() + .topology(vk::PrimitiveTopology::TRIANGLE_LIST); + + // Viewport/scissor (dynamic state) + let viewport = vk::Viewport { + x: 0.0, + y: 0.0, + width: extent.width as f32, + height: extent.height as f32, + min_depth: 0.0, + max_depth: 1.0, + }; + let scissor = vk::Rect2D { + offset: vk::Offset2D { x: 0, y: 0 }, + extent, + }; + + let viewport_state = vk::PipelineViewportStateCreateInfo::default() + .viewports(std::slice::from_ref(&viewport)) + .scissors(std::slice::from_ref(&scissor)); + + let rasterizer = vk::PipelineRasterizationStateCreateInfo::default() + .polygon_mode(vk::PolygonMode::FILL) + .line_width(1.0) + .cull_mode(vk::CullModeFlags::NONE) // No culling for overlay + .front_face(vk::FrontFace::COUNTER_CLOCKWISE); + + let multisampling = vk::PipelineMultisampleStateCreateInfo::default() + .rasterization_samples(vk::SampleCountFlags::TYPE_1); + + // Alpha blending + let blend_attachment = vk::PipelineColorBlendAttachmentState { + blend_enable: vk::TRUE, + src_color_blend_factor: vk::BlendFactor::SRC_ALPHA, + dst_color_blend_factor: vk::BlendFactor::ONE_MINUS_SRC_ALPHA, + color_blend_op: vk::BlendOp::ADD, + src_alpha_blend_factor: vk::BlendFactor::ONE, + dst_alpha_blend_factor: vk::BlendFactor::ZERO, + alpha_blend_op: vk::BlendOp::ADD, + color_write_mask: vk::ColorComponentFlags::RGBA, + }; + + let color_blending = vk::PipelineColorBlendStateCreateInfo::default() + .attachments(std::slice::from_ref(&blend_attachment)); + + // Dynamic state + let dynamic_states = [vk::DynamicState::VIEWPORT, vk::DynamicState::SCISSOR]; + let dynamic_state = + vk::PipelineDynamicStateCreateInfo::default().dynamic_states(&dynamic_states); + + let pipeline_info = vk::GraphicsPipelineCreateInfo::default() + .stages(&stages) + .vertex_input_state(&vertex_input) + .input_assembly_state(&input_assembly) + .viewport_state(&viewport_state) + .rasterization_state(&rasterizer) + .multisample_state(&multisampling) + .color_blend_state(&color_blending) + .dynamic_state(&dynamic_state) + .layout(layout) + .render_pass(render_pass) + .subpass(0); + + let pipelines = device + .create_graphics_pipelines(vk::PipelineCache::null(), &[pipeline_info], None) + .map_err(|(_pipelines, e)| format!("Failed to create pipeline: {:?}", e))?; + + // Clean up shader modules (no longer needed after pipeline creation) + device.destroy_shader_module(vert_module, None); + device.destroy_shader_module(frag_module, None); + + Ok(pipelines[0]) +} + +/// Create a shader module from SPIR-V bytes. +unsafe fn create_shader_module( + device: &ash::Device, + spv_bytes: &[u8], +) -> Result { + // SPIR-V must be aligned to 4 bytes and length must be multiple of 4 + if !spv_bytes.len().is_multiple_of(4) { + return Err("SPIR-V not aligned to 4 bytes".to_string()); + } + + let code: &[u32] = + std::slice::from_raw_parts(spv_bytes.as_ptr() as *const u32, spv_bytes.len() / 4); + + let info = vk::ShaderModuleCreateInfo::default().code(code); + + device + .create_shader_module(&info, None) + .map_err(|e| format!("Failed to create shader module: {:?}", e)) +} + +/// Create vertex buffer with quad data. +unsafe fn create_vertex_buffer( + device: &ash::Device, + instance: &ash::Instance, + physical_device: vk::PhysicalDevice, +) -> Result<(vk::Buffer, vk::DeviceMemory), String> { + let buffer_info = vk::BufferCreateInfo::default() + .size(QUAD_VERTICES_SIZE) + .usage(vk::BufferUsageFlags::VERTEX_BUFFER) + .sharing_mode(vk::SharingMode::EXCLUSIVE); + + let buffer = device + .create_buffer(&buffer_info, None) + .map_err(|e| format!("Failed to create vertex buffer: {:?}", e))?; + + let mem_reqs = device.get_buffer_memory_requirements(buffer); + + let mem_props = instance.get_physical_device_memory_properties(physical_device); + let memory_type = find_memory_type( + &mem_props, + mem_reqs.memory_type_bits, + vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT, + ) + .ok_or_else(|| "No suitable memory type for vertex buffer".to_string())?; + + let alloc_info = vk::MemoryAllocateInfo::default() + .allocation_size(mem_reqs.size) + .memory_type_index(memory_type); + + let memory = device + .allocate_memory(&alloc_info, None) + .map_err(|e| format!("Failed to allocate vertex memory: {:?}", e))?; + + device + .bind_buffer_memory(buffer, memory, 0) + .map_err(|e| format!("Failed to bind vertex memory: {:?}", e))?; + + // Map and copy vertex data + let data_ptr = device + .map_memory(memory, 0, QUAD_VERTICES_SIZE, vk::MemoryMapFlags::empty()) + .map_err(|e| format!("Failed to map vertex memory: {:?}", e))?; + + std::ptr::copy_nonoverlapping( + QUAD_VERTICES.as_ptr() as *const u8, + data_ptr as *mut u8, + QUAD_VERTICES_SIZE as usize, + ); + + device.unmap_memory(memory); + + Ok((buffer, memory)) +} + +/// Find a suitable memory type index. +fn find_memory_type( + props: &vk::PhysicalDeviceMemoryProperties, + type_bits: u32, + required: vk::MemoryPropertyFlags, +) -> Option { + (0..props.memory_type_count).find(|&i| { + (type_bits & (1 << i)) != 0 + && props.memory_types[i as usize] + .property_flags + .contains(required) + }) +} + +/// Create per-frame resources (image views, framebuffers, command buffers, fences). +unsafe fn create_frame_resources( + device: &ash::Device, + images: &[vk::Image], + format: vk::Format, + extent: vk::Extent2D, + render_pass: vk::RenderPass, + command_pool: vk::CommandPool, +) -> Result, String> { + // Allocate command buffers for all frames at once + let alloc_info = vk::CommandBufferAllocateInfo::default() + .command_pool(command_pool) + .level(vk::CommandBufferLevel::PRIMARY) + .command_buffer_count(images.len() as u32); + + let command_buffers = device + .allocate_command_buffers(&alloc_info) + .map_err(|e| format!("Failed to allocate command buffers: {:?}", e))?; + + let mut frames = Vec::with_capacity(images.len()); + + for (i, &image) in images.iter().enumerate() { + // Create image view + let view_info = vk::ImageViewCreateInfo::default() + .image(image) + .view_type(vk::ImageViewType::TYPE_2D) + .format(format) + .subresource_range(vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + }); + + let image_view = device + .create_image_view(&view_info, None) + .map_err(|e| format!("Failed to create image view {}: {:?}", i, e))?; + + // Create framebuffer + let fb_info = vk::FramebufferCreateInfo::default() + .render_pass(render_pass) + .attachments(std::slice::from_ref(&image_view)) + .width(extent.width) + .height(extent.height) + .layers(1); + + let framebuffer = device + .create_framebuffer(&fb_info, None) + .map_err(|e| format!("Failed to create framebuffer {}: {:?}", i, e))?; + + // Create fence (start signaled so first wait doesn't block) + let fence_info = vk::FenceCreateInfo::default().flags(vk::FenceCreateFlags::SIGNALED); + + let fence = device + .create_fence(&fence_info, None) + .map_err(|e| format!("Failed to create fence {}: {:?}", i, e))?; + + frames.push(FrameResources { + framebuffer, + command_buffer: command_buffers[i], + fence, + image_view, + }); + } + + Ok(frames) +} diff --git a/projects/nms-cockpit-video/injector/src/renderer/texture.rs b/projects/nms-cockpit-video/injector/src/renderer/texture.rs new file mode 100644 index 0000000..4ecf49f --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/renderer/texture.rs @@ -0,0 +1,519 @@ +//! Video texture management. +//! +//! Manages a device-local VkImage for the video texture, a staging buffer +//! for CPU→GPU uploads, and the descriptor set that binds the texture +//! to the fragment shader. + +use crate::log::vlog; +use ash::vk; + +// Safety: The staging_ptr is only accessed from the render thread, +// protected by the Mutex in the hooks module. +unsafe impl Send for VideoTexture {} +unsafe impl Sync for VideoTexture {} + +/// Video texture with staging buffer and descriptor resources. +pub struct VideoTexture { + /// Device-local image (TRANSFER_DST | SAMPLED). + pub image: vk::Image, + pub image_memory: vk::DeviceMemory, + pub image_view: vk::ImageView, + pub sampler: vk::Sampler, + + /// Staging buffer (HOST_VISIBLE | HOST_COHERENT, persistently mapped). + pub staging_buffer: vk::Buffer, + pub staging_memory: vk::DeviceMemory, + pub staging_ptr: *mut u8, + + /// Descriptor resources. + pub descriptor_pool: vk::DescriptorPool, + pub descriptor_set: vk::DescriptorSet, + + /// Texture dimensions. + pub width: u32, + pub height: u32, + + /// Whether the image has been transitioned to SHADER_READ_ONLY at least once. + pub initialized: bool, +} + +impl VideoTexture { + /// Create the video texture and all associated resources. + /// + /// # Safety + /// All Vulkan handles must be valid. + #[allow(clippy::too_many_arguments)] + pub unsafe fn new( + device: &ash::Device, + instance: &ash::Instance, + physical_device: vk::PhysicalDevice, + descriptor_set_layout: vk::DescriptorSetLayout, + width: u32, + height: u32, + queue: vk::Queue, + command_pool: vk::CommandPool, + ) -> Result { + let frame_size = (width * height * 4) as u64; + + // --- Create device-local image --- + let image_info = vk::ImageCreateInfo::default() + .image_type(vk::ImageType::TYPE_2D) + .format(vk::Format::R8G8B8A8_UNORM) + .extent(vk::Extent3D { + width, + height, + depth: 1, + }) + .mip_levels(1) + .array_layers(1) + .samples(vk::SampleCountFlags::TYPE_1) + .tiling(vk::ImageTiling::OPTIMAL) + .usage(vk::ImageUsageFlags::TRANSFER_DST | vk::ImageUsageFlags::SAMPLED) + .sharing_mode(vk::SharingMode::EXCLUSIVE) + .initial_layout(vk::ImageLayout::UNDEFINED); + + let image = device + .create_image(&image_info, None) + .map_err(|e| format!("Create image failed: {:?}", e))?; + + let mem_reqs = device.get_image_memory_requirements(image); + let mem_props = instance.get_physical_device_memory_properties(physical_device); + + let image_mem_type = find_memory_type( + &mem_props, + mem_reqs.memory_type_bits, + vk::MemoryPropertyFlags::DEVICE_LOCAL, + ) + .ok_or_else(|| "No device-local memory for texture".to_string())?; + + let alloc_info = vk::MemoryAllocateInfo::default() + .allocation_size(mem_reqs.size) + .memory_type_index(image_mem_type); + + let image_memory = device + .allocate_memory(&alloc_info, None) + .map_err(|e| format!("Allocate image memory failed: {:?}", e))?; + + device + .bind_image_memory(image, image_memory, 0) + .map_err(|e| format!("Bind image memory failed: {:?}", e))?; + + // --- Create image view --- + let view_info = vk::ImageViewCreateInfo::default() + .image(image) + .view_type(vk::ImageViewType::TYPE_2D) + .format(vk::Format::R8G8B8A8_UNORM) + .subresource_range(vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + }); + + let image_view = device + .create_image_view(&view_info, None) + .map_err(|e| format!("Create image view failed: {:?}", e))?; + + // --- Create sampler --- + let sampler_info = vk::SamplerCreateInfo::default() + .mag_filter(vk::Filter::LINEAR) + .min_filter(vk::Filter::LINEAR) + .address_mode_u(vk::SamplerAddressMode::CLAMP_TO_EDGE) + .address_mode_v(vk::SamplerAddressMode::CLAMP_TO_EDGE) + .address_mode_w(vk::SamplerAddressMode::CLAMP_TO_EDGE) + .mipmap_mode(vk::SamplerMipmapMode::LINEAR) + .min_lod(0.0) + .max_lod(0.0); + + let sampler = device + .create_sampler(&sampler_info, None) + .map_err(|e| format!("Create sampler failed: {:?}", e))?; + + // --- Create staging buffer (persistently mapped) --- + let buffer_info = vk::BufferCreateInfo::default() + .size(frame_size) + .usage(vk::BufferUsageFlags::TRANSFER_SRC) + .sharing_mode(vk::SharingMode::EXCLUSIVE); + + let staging_buffer = device + .create_buffer(&buffer_info, None) + .map_err(|e| format!("Create staging buffer failed: {:?}", e))?; + + let buf_mem_reqs = device.get_buffer_memory_requirements(staging_buffer); + + let staging_mem_type = find_memory_type( + &mem_props, + buf_mem_reqs.memory_type_bits, + vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT, + ) + .ok_or_else(|| "No host-visible memory for staging buffer".to_string())?; + + let staging_alloc = vk::MemoryAllocateInfo::default() + .allocation_size(buf_mem_reqs.size) + .memory_type_index(staging_mem_type); + + let staging_memory = device + .allocate_memory(&staging_alloc, None) + .map_err(|e| format!("Allocate staging memory failed: {:?}", e))?; + + device + .bind_buffer_memory(staging_buffer, staging_memory, 0) + .map_err(|e| format!("Bind staging memory failed: {:?}", e))?; + + // Persistently map the staging buffer + let staging_ptr = device + .map_memory(staging_memory, 0, frame_size, vk::MemoryMapFlags::empty()) + .map_err(|e| format!("Map staging memory failed: {:?}", e))? + as *mut u8; + + // --- Create descriptor pool and set --- + let pool_sizes = [ + vk::DescriptorPoolSize { + ty: vk::DescriptorType::SAMPLED_IMAGE, + descriptor_count: 1, + }, + vk::DescriptorPoolSize { + ty: vk::DescriptorType::SAMPLER, + descriptor_count: 1, + }, + ]; + + let pool_info = vk::DescriptorPoolCreateInfo::default() + .max_sets(1) + .pool_sizes(&pool_sizes); + + let descriptor_pool = device + .create_descriptor_pool(&pool_info, None) + .map_err(|e| format!("Create descriptor pool failed: {:?}", e))?; + + let alloc_info = vk::DescriptorSetAllocateInfo::default() + .descriptor_pool(descriptor_pool) + .set_layouts(std::slice::from_ref(&descriptor_set_layout)); + + let descriptor_sets = device + .allocate_descriptor_sets(&alloc_info) + .map_err(|e| format!("Allocate descriptor set failed: {:?}", e))?; + + let descriptor_set = descriptor_sets[0]; + + // --- Transition image to SHADER_READ_ONLY and clear to black --- + transition_and_clear(device, queue, command_pool, image, width, height)?; + + // --- Update descriptor set: binding 0 = image, binding 1 = sampler --- + let image_info_desc = vk::DescriptorImageInfo { + sampler: vk::Sampler::null(), + image_view, + image_layout: vk::ImageLayout::SHADER_READ_ONLY_OPTIMAL, + }; + + let sampler_info_desc = vk::DescriptorImageInfo { + sampler, + image_view: vk::ImageView::null(), + image_layout: vk::ImageLayout::UNDEFINED, + }; + + let writes = [ + vk::WriteDescriptorSet::default() + .dst_set(descriptor_set) + .dst_binding(0) + .descriptor_type(vk::DescriptorType::SAMPLED_IMAGE) + .image_info(std::slice::from_ref(&image_info_desc)), + vk::WriteDescriptorSet::default() + .dst_set(descriptor_set) + .dst_binding(1) + .descriptor_type(vk::DescriptorType::SAMPLER) + .image_info(std::slice::from_ref(&sampler_info_desc)), + ]; + + device.update_descriptor_sets(&writes, &[]); + + vlog!( + "VideoTexture created: {}x{} staging={}KB", + width, + height, + frame_size / 1024 + ); + + Ok(Self { + image, + image_memory, + image_view, + sampler, + staging_buffer, + staging_memory, + staging_ptr, + descriptor_pool, + descriptor_set, + width, + height, + initialized: true, + }) + } + + /// Upload frame data to the texture. + /// + /// Copies RGBA data to the staging buffer, then records commands to + /// transfer it to the device-local image. + /// + /// # Safety + /// - `cmd` must be a command buffer in the recording state + /// - `frame_data` must be exactly width * height * 4 bytes + pub unsafe fn upload_frame( + &self, + device: &ash::Device, + cmd: vk::CommandBuffer, + frame_data: &[u8], + ) { + let frame_size = (self.width * self.height * 4) as usize; + debug_assert_eq!(frame_data.len(), frame_size); + + // Copy to staging buffer (persistently mapped, HOST_COHERENT = no flush needed) + std::ptr::copy_nonoverlapping(frame_data.as_ptr(), self.staging_ptr, frame_size); + + // Transition image: SHADER_READ_ONLY → TRANSFER_DST + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::SHADER_READ) + .dst_access_mask(vk::AccessFlags::TRANSFER_WRITE) + .old_layout(vk::ImageLayout::SHADER_READ_ONLY_OPTIMAL) + .new_layout(vk::ImageLayout::TRANSFER_DST_OPTIMAL) + .image(self.image) + .subresource_range(color_subresource_range()); + + device.cmd_pipeline_barrier( + cmd, + vk::PipelineStageFlags::FRAGMENT_SHADER, + vk::PipelineStageFlags::TRANSFER, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + + // Copy staging buffer → image + let region = vk::BufferImageCopy { + buffer_offset: 0, + buffer_row_length: 0, // tightly packed + buffer_image_height: 0, // tightly packed + image_subresource: vk::ImageSubresourceLayers { + aspect_mask: vk::ImageAspectFlags::COLOR, + mip_level: 0, + base_array_layer: 0, + layer_count: 1, + }, + image_offset: vk::Offset3D { x: 0, y: 0, z: 0 }, + image_extent: vk::Extent3D { + width: self.width, + height: self.height, + depth: 1, + }, + }; + + device.cmd_copy_buffer_to_image( + cmd, + self.staging_buffer, + self.image, + vk::ImageLayout::TRANSFER_DST_OPTIMAL, + &[region], + ); + + // Transition image: TRANSFER_DST → SHADER_READ_ONLY + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::TRANSFER_WRITE) + .dst_access_mask(vk::AccessFlags::SHADER_READ) + .old_layout(vk::ImageLayout::TRANSFER_DST_OPTIMAL) + .new_layout(vk::ImageLayout::SHADER_READ_ONLY_OPTIMAL) + .image(self.image) + .subresource_range(color_subresource_range()); + + device.cmd_pipeline_barrier( + cmd, + vk::PipelineStageFlags::TRANSFER, + vk::PipelineStageFlags::FRAGMENT_SHADER, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + } + + /// Destroy all resources. + /// + /// # Safety + /// Device must be idle. All handles must be valid. + pub unsafe fn destroy(&self, device: &ash::Device) { + device.unmap_memory(self.staging_memory); + device.destroy_buffer(self.staging_buffer, None); + device.free_memory(self.staging_memory, None); + device.destroy_sampler(self.sampler, None); + device.destroy_image_view(self.image_view, None); + device.destroy_image(self.image, None); + device.free_memory(self.image_memory, None); + device.destroy_descriptor_pool(self.descriptor_pool, None); + vlog!("VideoTexture destroyed"); + } +} + +/// Create the descriptor set layout for the video texture binding. +/// +/// Layout: +/// - set=0, binding=0: sampled image (texture_2d), fragment stage +/// - set=0, binding=1: sampler, fragment stage +/// +/// # Safety +/// Device must be valid. +pub unsafe fn create_descriptor_set_layout( + device: &ash::Device, +) -> Result { + let bindings = [ + vk::DescriptorSetLayoutBinding { + binding: 0, + descriptor_type: vk::DescriptorType::SAMPLED_IMAGE, + descriptor_count: 1, + stage_flags: vk::ShaderStageFlags::FRAGMENT, + p_immutable_samplers: std::ptr::null(), + ..Default::default() + }, + vk::DescriptorSetLayoutBinding { + binding: 1, + descriptor_type: vk::DescriptorType::SAMPLER, + descriptor_count: 1, + stage_flags: vk::ShaderStageFlags::FRAGMENT, + p_immutable_samplers: std::ptr::null(), + ..Default::default() + }, + ]; + + let layout_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings); + + device + .create_descriptor_set_layout(&layout_info, None) + .map_err(|e| format!("Create descriptor set layout failed: {:?}", e)) +} + +// --- Helpers --- + +/// Standard color subresource range. +fn color_subresource_range() -> vk::ImageSubresourceRange { + vk::ImageSubresourceRange { + aspect_mask: vk::ImageAspectFlags::COLOR, + base_mip_level: 0, + level_count: 1, + base_array_layer: 0, + layer_count: 1, + } +} + +/// Find a suitable memory type index. +fn find_memory_type( + props: &vk::PhysicalDeviceMemoryProperties, + type_bits: u32, + required: vk::MemoryPropertyFlags, +) -> Option { + (0..props.memory_type_count).find(|&i| { + (type_bits & (1 << i)) != 0 + && props.memory_types[i as usize] + .property_flags + .contains(required) + }) +} + +/// Transition image to SHADER_READ_ONLY and clear to black. +/// +/// Uses a one-shot command buffer. +unsafe fn transition_and_clear( + device: &ash::Device, + queue: vk::Queue, + command_pool: vk::CommandPool, + image: vk::Image, + _width: u32, + _height: u32, +) -> Result<(), String> { + // Allocate a one-shot command buffer + let alloc_info = vk::CommandBufferAllocateInfo::default() + .command_pool(command_pool) + .level(vk::CommandBufferLevel::PRIMARY) + .command_buffer_count(1); + + let cmd_bufs = device + .allocate_command_buffers(&alloc_info) + .map_err(|e| format!("Allocate init cmd buf failed: {:?}", e))?; + let cmd = cmd_bufs[0]; + + let begin_info = + vk::CommandBufferBeginInfo::default().flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + device + .begin_command_buffer(cmd, &begin_info) + .map_err(|e| format!("Begin init cmd buf failed: {:?}", e))?; + + // Transition: UNDEFINED → TRANSFER_DST + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::empty()) + .dst_access_mask(vk::AccessFlags::TRANSFER_WRITE) + .old_layout(vk::ImageLayout::UNDEFINED) + .new_layout(vk::ImageLayout::TRANSFER_DST_OPTIMAL) + .image(image) + .subresource_range(color_subresource_range()); + + device.cmd_pipeline_barrier( + cmd, + vk::PipelineStageFlags::TOP_OF_PIPE, + vk::PipelineStageFlags::TRANSFER, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + + // Clear to black + let clear_color = vk::ClearColorValue { + float32: [0.0, 0.0, 0.0, 1.0], + }; + device.cmd_clear_color_image( + cmd, + image, + vk::ImageLayout::TRANSFER_DST_OPTIMAL, + &clear_color, + &[color_subresource_range()], + ); + + // Transition: TRANSFER_DST → SHADER_READ_ONLY + let barrier = vk::ImageMemoryBarrier::default() + .src_access_mask(vk::AccessFlags::TRANSFER_WRITE) + .dst_access_mask(vk::AccessFlags::SHADER_READ) + .old_layout(vk::ImageLayout::TRANSFER_DST_OPTIMAL) + .new_layout(vk::ImageLayout::SHADER_READ_ONLY_OPTIMAL) + .image(image) + .subresource_range(color_subresource_range()); + + device.cmd_pipeline_barrier( + cmd, + vk::PipelineStageFlags::TRANSFER, + vk::PipelineStageFlags::FRAGMENT_SHADER, + vk::DependencyFlags::empty(), + &[], + &[], + &[barrier], + ); + + device + .end_command_buffer(cmd) + .map_err(|e| format!("End init cmd buf failed: {:?}", e))?; + + // Submit and wait + let cmd_bufs_ref = [cmd]; + let submit_info = vk::SubmitInfo::default().command_buffers(&cmd_bufs_ref); + + device + .queue_submit(queue, &[submit_info], vk::Fence::null()) + .map_err(|e| format!("Submit init cmd buf failed: {:?}", e))?; + + device + .queue_wait_idle(queue) + .map_err(|e| format!("Queue wait idle failed: {:?}", e))?; + + // Free the one-shot command buffer + device.free_command_buffers(command_pool, &[cmd]); + + Ok(()) +} diff --git a/projects/nms-cockpit-video/injector/src/shmem_reader.rs b/projects/nms-cockpit-video/injector/src/shmem_reader.rs new file mode 100644 index 0000000..153e43c --- /dev/null +++ b/projects/nms-cockpit-video/injector/src/shmem_reader.rs @@ -0,0 +1,106 @@ +//! Shared memory frame reader. +//! +//! Opens the daemon's shared memory frame buffer and polls for new video frames. +//! Uses the itk-shmem seqlock for lock-free reads from the triple-buffered region. + +use crate::log::vlog; +use itk_shmem::FrameBuffer; + +/// Default shared memory name (daemon must create with the same name). +const SHMEM_NAME: &str = "itk_video_frames"; + +/// Default video dimensions (must match daemon's frame buffer). +const DEFAULT_WIDTH: u32 = 1280; +const DEFAULT_HEIGHT: u32 = 720; + +/// Reader that polls shared memory for new video frames. +pub struct ShmemFrameReader { + fb: FrameBuffer, + /// Buffer for frame data (reused across reads). + frame_buf: Vec, + /// Last seen PTS (to detect new frames). + last_pts: u64, + /// Frame dimensions. + width: u32, + height: u32, +} + +impl ShmemFrameReader { + /// Try to open the shared memory frame buffer. + /// + /// Returns `Err` if the daemon hasn't created the shared memory yet. + pub fn open() -> Result { + Self::open_with_dimensions(DEFAULT_WIDTH, DEFAULT_HEIGHT) + } + + /// Open with specific dimensions. + pub fn open_with_dimensions(width: u32, height: u32) -> Result { + let fb = FrameBuffer::open(SHMEM_NAME, width, height) + .map_err(|e| format!("Failed to open shmem '{}': {}", SHMEM_NAME, e))?; + + let frame_size = fb.frame_size(); + vlog!( + "ShmemFrameReader: opened '{}' {}x{} frame_size={}", + SHMEM_NAME, + width, + height, + frame_size + ); + + Ok(Self { + fb, + frame_buf: vec![0u8; frame_size], + last_pts: u64::MAX, // Sentinel: ensures first frame (PTS=0) is detected as new + width, + height, + }) + } + + /// Poll for a new frame. Returns the RGBA data if a new frame is available. + /// + /// This is lock-free and non-blocking. Returns `None` if: + /// - No new frame since last poll + /// - Writer is in the middle of writing + /// - Shared memory is in an inconsistent state + pub fn poll_frame(&mut self) -> Option<&[u8]> { + match self.fb.read_frame(self.last_pts, &mut self.frame_buf) { + Ok((pts, changed)) => { + if changed { + self.last_pts = pts; + Some(&self.frame_buf) + } else { + None + } + }, + Err(_) => None, + } + } + + /// Get the frame width. + pub fn width(&self) -> u32 { + self.width + } + + /// Get the frame height. + pub fn height(&self) -> u32 { + self.height + } + + /// Get frame size in bytes (width * height * 4). + pub fn frame_size(&self) -> usize { + self.fb.frame_size() + } + + /// Get the last PTS seen by this reader. + pub fn last_pts(&self) -> u64 { + self.last_pts + } + + /// Read the current PTS from shared memory header (for diagnostics). + pub fn shmem_pts(&self) -> u64 { + self.fb + .header() + .pts_ms + .load(std::sync::atomic::Ordering::Relaxed) + } +} diff --git a/projects/nms-cockpit-video/launcher/Cargo.toml b/projects/nms-cockpit-video/launcher/Cargo.toml new file mode 100644 index 0000000..1816329 --- /dev/null +++ b/projects/nms-cockpit-video/launcher/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "nms-video-launcher" +version.workspace = true +edition.workspace = true +license.workspace = true + +[[bin]] +name = "nms-video-launcher" +path = "src/main.rs" + +[dependencies] +windows = { workspace = true } + +[lints] +workspace = true diff --git a/projects/nms-cockpit-video/launcher/src/main.rs b/projects/nms-cockpit-video/launcher/src/main.rs new file mode 100644 index 0000000..5551f32 --- /dev/null +++ b/projects/nms-cockpit-video/launcher/src/main.rs @@ -0,0 +1,490 @@ +//! NMS Video Launcher +//! +//! Launches NMS with --disable-eac, waits for the game process to start, +//! then injects the cockpit video DLL. Handles the case where NMS re-spawns +//! itself during startup by polling for the real game process. +//! +//! Defaults: +//! NMS: D:\SteamLibrary\steamapps\common\No Man's Sky\Binaries\NMS.exe +//! DLL: nms_cockpit_injector.dll (next to this launcher) +//! +//! Usage: nms-video-launcher.exe [nms-exe-path] [dll-path] + +use std::env; +use std::ffi::c_void; +use std::path::{Path, PathBuf}; +use std::process::{self, Command, Stdio}; +use std::thread; +use std::time::{Duration, Instant}; + +use windows::Win32::Foundation::{CloseHandle, WAIT_OBJECT_0}; +use windows::Win32::System::Diagnostics::Debug::WriteProcessMemory; +use windows::Win32::System::LibraryLoader::{GetModuleHandleA, GetProcAddress}; +use windows::Win32::System::Memory::{ + MEM_COMMIT, MEM_RELEASE, MEM_RESERVE, PAGE_READWRITE, VirtualAllocEx, VirtualFreeEx, +}; +use windows::Win32::System::Threading::{ + CreateRemoteThread, INFINITE, OpenProcess, PROCESS_CREATE_THREAD, PROCESS_QUERY_INFORMATION, + PROCESS_SYNCHRONIZE, PROCESS_TERMINATE, PROCESS_VM_OPERATION, PROCESS_VM_WRITE, + TerminateProcess, WaitForSingleObject, +}; +use windows::core::PCSTR; + +const DEFAULT_NMS: &str = r"D:\SteamLibrary\steamapps\common\No Man's Sky\Binaries\NMS.exe"; +const DEFAULT_DLL: &str = "nms_cockpit_injector.dll"; +const DEFAULT_DAEMON: &str = "nms-video-daemon.exe"; + +/// How long to wait for NMS to start before giving up. +const WAIT_TIMEOUT: Duration = Duration::from_secs(30); + +/// How often to poll for the NMS process. +const POLL_INTERVAL: Duration = Duration::from_millis(500); + +/// Delay after finding the process before injecting. +/// Must be short - the DLL waits internally for vulkan-1.dll, and hooks must be +/// installed BEFORE NMS calls vkCreateDevice (which sets up ICD hooks). +const INJECT_DELAY: Duration = Duration::from_millis(500); + +fn main() { + let args: Vec = env::args().collect(); + + // Resolve NMS path + let nms_path = if args.len() > 1 { + PathBuf::from(&args[1]) + } else { + PathBuf::from(DEFAULT_NMS) + }; + + // Resolve DLL path (default: next to this exe) + let dll_path = if args.len() > 2 { + PathBuf::from(&args[2]) + } else { + let exe_dir = env::current_exe() + .ok() + .and_then(|p| p.parent().map(|d| d.to_path_buf())) + .unwrap_or_else(|| PathBuf::from(".")); + exe_dir.join(DEFAULT_DLL) + }; + + // Validate paths + if !nms_path.exists() { + eprintln!("Error: NMS executable not found: {}", nms_path.display()); + eprintln!("Pass the correct path as the first argument."); + process::exit(1); + } + if !dll_path.exists() { + eprintln!("Error: DLL not found: {}", dll_path.display()); + eprintln!( + "Place {} next to this launcher, or pass the path as the second argument.", + DEFAULT_DLL + ); + process::exit(1); + } + + // Get absolute path for the DLL (needed for remote process context) + let dll_abs = match dll_path.canonicalize() { + Ok(p) => p, + Err(e) => { + eprintln!("Error: cannot resolve DLL path: {}", e); + process::exit(1); + }, + }; + + // Strip \\?\ prefix from canonicalized path (LoadLibraryW handles regular paths fine) + let dll_str = dll_abs + .to_str() + .unwrap_or("") + .strip_prefix(r"\\?\") + .unwrap_or(dll_abs.to_str().unwrap_or("")); + let dll_clean = PathBuf::from(dll_str); + + println!("NMS: {}", nms_path.display()); + println!("DLL: {}", dll_clean.display()); + println!("Args: --disable-eac"); + println!(); + + // Start the daemon as a separate process + let daemon_pid = start_daemon(&dll_clean); + + // Launch NMS + println!("Launching NMS with --disable-eac..."); + let nms_dir = nms_path.parent().map(|p| p.to_path_buf()); + + let mut cmd = Command::new(&nms_path); + cmd.arg("--disable-eac") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + if let Some(dir) = &nms_dir { + cmd.current_dir(dir); + } + + let initial_pid = match cmd.spawn() { + Ok(child) => { + let pid = child.id(); + println!("NMS launched: initial PID {}", pid); + pid + }, + Err(e) => { + eprintln!("Error: failed to launch NMS: {}", e); + process::exit(1); + }, + }; + + // Wait for the real NMS process to appear (skip the initial stub PID) + println!( + "Waiting for NMS game process (skipping initial PID {})...", + initial_pid + ); + let pid = match wait_for_nms(initial_pid) { + Some(pid) => pid, + None => { + eprintln!("Error: NMS process not found within {:?}", WAIT_TIMEOUT); + process::exit(1); + }, + }; + + println!("Found NMS game process: PID {}", pid); + println!( + "Waiting {}ms for process to initialize...", + INJECT_DELAY.as_millis() + ); + thread::sleep(INJECT_DELAY); + + // Inject into the running process + unsafe { + match inject_into_pid(pid, &dll_clean) { + Ok(()) => println!("Success: DLL injected into NMS (PID {})", pid), + Err(e) => { + eprintln!("Error: {}", e); + process::exit(1); + }, + } + } + + // Wait for NMS to exit, then shut down the daemon + println!(); + println!("Waiting for NMS to exit..."); + wait_for_process_exit(pid); + println!("NMS exited."); + + if let Some(dpid) = daemon_pid { + println!("Shutting down daemon (PID {})...", dpid); + kill_process(dpid); + } + println!("Done."); +} + +/// Start the video daemon as a detached background process. +/// Returns the daemon's PID if successfully started. +fn start_daemon(dll_path: &Path) -> Option { + let daemon_path = dll_path + .parent() + .map(|d| d.join(DEFAULT_DAEMON)) + .unwrap_or_else(|| PathBuf::from(DEFAULT_DAEMON)); + + if !daemon_path.exists() { + let exe_dir = env::current_exe() + .ok() + .and_then(|p| p.parent().map(|d| d.to_path_buf())); + let alt_path = exe_dir.map(|d| d.join(DEFAULT_DAEMON)); + + if let Some(alt) = &alt_path { + if alt.exists() { + return spawn_daemon(alt); + } + } + + println!("Note: {} not found, skipping daemon launch", DEFAULT_DAEMON); + println!(); + return None; + } + + spawn_daemon(&daemon_path) +} + +/// Find a video file to auto-load (checks next to the launcher/DLL). +fn find_video_file() -> Option { + let exe_dir = env::current_exe() + .ok() + .and_then(|p| p.parent().map(|d| d.to_path_buf()))?; + + // Check for nms_video.txt config first + let config_path = exe_dir.join("nms_video.txt"); + if config_path.exists() { + if let Ok(content) = std::fs::read_to_string(&config_path) { + let path = content.trim().to_string(); + if !path.is_empty() { + let p = PathBuf::from(&path); + if p.exists() { + return Some(p); + } + } + } + } + + // Check for nms_video.mp4 + let video_path = exe_dir.join("nms_video.mp4"); + if video_path.exists() { + return Some(video_path); + } + + None +} + +/// Spawn the daemon process detached, optionally with --load. +/// Returns the daemon's PID if successfully started. +fn spawn_daemon(path: &Path) -> Option { + // Write daemon logs to a file so we can diagnose issues + let log_path = env::temp_dir().join("nms_video_daemon.log"); + println!("Daemon log: {}", log_path.display()); + + let log_file = std::fs::File::create(&log_path).ok(); + let stderr_cfg = match &log_file { + Some(f) => Stdio::from( + f.try_clone() + .unwrap_or_else(|_| std::fs::File::create(&log_path).expect("create log")), + ), + None => Stdio::null(), + }; + + let mut cmd = Command::new(path); + cmd.stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(stderr_cfg); + + // Enable debug logging for diagnostics + cmd.arg("--log-level").arg("debug"); + + if let Some(video) = find_video_file() { + println!("Video: {}", video.display()); + cmd.arg("--load").arg(&video); + } + + let pid = match cmd.spawn() { + Ok(child) => { + let pid = child.id(); + println!("Daemon started: PID {}", pid); + Some(pid) + }, + Err(e) => { + println!("Warning: failed to start daemon: {}", e); + None + }, + }; + println!(); + pid +} + +/// Poll for an NMS.exe process. Returns the PID when found. +/// +/// NMS may re-spawn itself during startup. Strategy: +/// 1. First look for a re-spawned process (different PID) for up to 10 seconds +/// 2. If not found, accept the initial PID (NMS didn't re-spawn) +fn wait_for_nms(initial_pid: u32) -> Option { + let start = Instant::now(); + let respawn_timeout = Duration::from_secs(10); + + // Phase 1: Look for a re-spawned process (skip initial PID) + while start.elapsed() < respawn_timeout { + if let Some(pid) = find_nms_process(initial_pid) { + return Some(pid); + } + thread::sleep(POLL_INTERVAL); + } + + // Phase 2: NMS didn't re-spawn, accept initial PID if still running + println!("No re-spawn detected, checking initial PID..."); + find_any_nms_process() +} + +/// Find any NMS.exe process (no PID skipping). +fn find_any_nms_process() -> Option { + find_nms_process(0) +} + +/// Find an NMS.exe process by scanning the process list, skipping `skip_pid`. +fn find_nms_process(skip_pid: u32) -> Option { + use windows::Win32::System::Diagnostics::ToolHelp::{ + CreateToolhelp32Snapshot, PROCESSENTRY32W, Process32FirstW, Process32NextW, + TH32CS_SNAPPROCESS, + }; + + unsafe { + let snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0).ok()?; + + let mut entry = PROCESSENTRY32W { + dwSize: std::mem::size_of::() as u32, + ..Default::default() + }; + + if Process32FirstW(snapshot, &mut entry).is_err() { + let _ = CloseHandle(snapshot); + return None; + } + + let mut found_pid = None; + + loop { + let name: String = entry + .szExeFile + .iter() + .take_while(|&&c| c != 0) + .map(|&c| char::from_u32(c as u32).unwrap_or('?')) + .collect(); + + if name.eq_ignore_ascii_case("NMS.exe") && entry.th32ProcessID != skip_pid { + found_pid = Some(entry.th32ProcessID); + } + + if Process32NextW(snapshot, &mut entry).is_err() { + break; + } + } + + let _ = CloseHandle(snapshot); + found_pid + } +} + +/// Inject the DLL into a running process by PID. +unsafe fn inject_into_pid(pid: u32, dll_path: &Path) -> Result<(), String> { + let dll_wide = to_wide(dll_path.to_str().unwrap_or("")); + let dll_path_bytes = dll_wide.len() * 2; + + // Open the target process + let access = + PROCESS_CREATE_THREAD | PROCESS_VM_OPERATION | PROCESS_VM_WRITE | PROCESS_QUERY_INFORMATION; + + let process = OpenProcess(access, false, pid) + .map_err(|e| format!("OpenProcess({}) failed: {} (try running as admin)", pid, e))?; + + let result = do_inject(process, &dll_wide, dll_path_bytes); + + let _ = CloseHandle(process); + result +} + +/// Core injection logic: allocate, write, CreateRemoteThread(LoadLibraryW). +unsafe fn do_inject( + process: windows::Win32::Foundation::HANDLE, + dll_wide: &[u16], + dll_path_bytes: usize, +) -> Result<(), String> { + // Allocate memory in target process for the DLL path + let remote_buf = VirtualAllocEx( + process, + None, + dll_path_bytes, + MEM_COMMIT | MEM_RESERVE, + PAGE_READWRITE, + ); + + if remote_buf.is_null() { + return Err("VirtualAllocEx failed: could not allocate in target process".into()); + } + + println!("Allocated {} bytes at {:p}", dll_path_bytes, remote_buf); + + // Write DLL path into target memory + let write_result = WriteProcessMemory( + process, + remote_buf, + dll_wide.as_ptr() as *const _, + dll_path_bytes, + None, + ); + + if write_result.is_err() { + let _ = VirtualFreeEx(process, remote_buf, 0, MEM_RELEASE); + return Err("WriteProcessMemory failed".into()); + } + + // Get LoadLibraryW address + let kernel32_name = b"kernel32.dll\0"; + let kernel32 = GetModuleHandleA(PCSTR(kernel32_name.as_ptr())) + .map_err(|e| format!("GetModuleHandleA(kernel32.dll) failed: {}", e))?; + + let load_library_name = b"LoadLibraryW\0"; + let load_library_addr = GetProcAddress(kernel32, PCSTR(load_library_name.as_ptr())); + + let load_library_addr = match load_library_addr { + Some(addr) => addr, + None => { + let _ = VirtualFreeEx(process, remote_buf, 0, MEM_RELEASE); + return Err("GetProcAddress(LoadLibraryW) failed".into()); + }, + }; + + println!("LoadLibraryW at {:p}", load_library_addr as *const ()); + + // Create remote thread to call LoadLibraryW(dll_path) + let thread = CreateRemoteThread( + process, + None, + 0, + Some(std::mem::transmute::< + unsafe extern "system" fn() -> isize, + unsafe extern "system" fn(*mut c_void) -> u32, + >(load_library_addr)), + Some(remote_buf), + 0, + None, + ) + .map_err(|e| format!("CreateRemoteThread failed: {}", e))?; + + println!("Remote thread created, waiting for DLL load..."); + + // Wait for DLL load to complete + let wait_result = WaitForSingleObject(thread, INFINITE); + if wait_result != WAIT_OBJECT_0 { + let _ = CloseHandle(thread); + let _ = VirtualFreeEx(process, remote_buf, 0, MEM_RELEASE); + return Err(format!("WaitForSingleObject returned {:?}", wait_result)); + } + + let _ = CloseHandle(thread); + let _ = VirtualFreeEx(process, remote_buf, 0, MEM_RELEASE); + + println!("DLL loaded successfully"); + Ok(()) +} + +/// Wait for a process to exit. +fn wait_for_process_exit(pid: u32) { + unsafe { + let handle = OpenProcess(PROCESS_SYNCHRONIZE, false, pid); + match handle { + Ok(h) => { + // Wait indefinitely for the process to exit + WaitForSingleObject(h, INFINITE); + let _ = CloseHandle(h); + }, + Err(_) => { + // Process already exited or can't be opened + }, + } + } +} + +/// Terminate a process by PID. +fn kill_process(pid: u32) { + unsafe { + let handle = OpenProcess(PROCESS_TERMINATE, false, pid); + match handle { + Ok(h) => { + let _ = TerminateProcess(h, 0); + let _ = CloseHandle(h); + }, + Err(_) => { + // Process already exited + }, + } + } +} + +/// Convert a string to a null-terminated wide (UTF-16) string. +fn to_wide(s: &str) -> Vec { + s.encode_utf16().chain(std::iter::once(0)).collect() +} diff --git a/projects/nms-cockpit-video/mod/NmsCockpitOverlay.sln b/projects/nms-cockpit-video/mod/NmsCockpitOverlay.sln new file mode 100644 index 0000000..6c26e40 --- /dev/null +++ b/projects/nms-cockpit-video/mod/NmsCockpitOverlay.sln @@ -0,0 +1,18 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.31903.59 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NmsCockpitOverlay", "NmsCockpitOverlay\NmsCockpitOverlay.csproj", "{A1B2C3D4-E5F6-7890-ABCD-EF1234567890}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Release|x64 = Release|x64 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Debug|x64.ActiveCfg = Debug|x64 + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Debug|x64.Build.0 = Debug|x64 + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Release|x64.ActiveCfg = Release|x64 + {A1B2C3D4-E5F6-7890-ABCD-EF1234567890}.Release|x64.Build.0 = Release|x64 + EndGlobalSection +EndGlobal diff --git a/projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/MatrixReader.cs b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/MatrixReader.cs new file mode 100644 index 0000000..5fc9ccf --- /dev/null +++ b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/MatrixReader.cs @@ -0,0 +1,222 @@ +using System.Diagnostics; +using System.Numerics; +using System.Runtime.InteropServices; +using Reloaded.Memory.SigScan; +using Reloaded.Memory.SigScan.Definitions.Structs; +using Reloaded.Mod.Interfaces; + +namespace NmsCockpitOverlay.CockpitTracker; + +/// +/// Reads view and projection matrices from No Man's Sky memory using signature scanning. +/// +/// IMPORTANT: These signatures are specific to particular NMS versions. +/// After game updates, the signatures may need to be updated. +/// +/// Finding new signatures: +/// 1. Use x64dbg or IDA to find the camera update function +/// 2. Look for matrix multiplication sequences near player tick +/// 3. The view matrix is typically near the camera/player controller +/// 4. The projection matrix is updated when FOV changes +/// +/// Pattern format: "XX XX XX ?? ?? XX" where ?? is wildcard +/// +public class MatrixReader +{ + private readonly ILogger? _logger; + private IntPtr _viewMatrixPtr; + private IntPtr _projMatrixPtr; + private bool _isInitialized; + + // Known signatures for different NMS versions + // Format: (exeHash, viewPattern, viewOffset, projPattern, projOffset) + private static readonly Dictionary KnownSignatures = new() + { + // NMS 4.x - Update these patterns after finding them in a specific version + // The patterns here are PLACEHOLDERS - real patterns must be discovered + // by reverse engineering each NMS version + ["placeholder"] = new SignatureSet + { + ViewPattern = "48 8B ?? ?? ?? ?? ?? 48 85 C0 74 ?? F3 0F 10", + ViewOffset = 3, + ProjPattern = "F3 0F 10 ?? ?? ?? ?? ?? F3 0F 10 ?? ?? ?? ?? ?? 0F 28", + ProjOffset = 4, + } + }; + + /// + /// Signature patterns and offsets for a specific NMS version. + /// + private record SignatureSet + { + public required string ViewPattern { get; init; } + public required int ViewOffset { get; init; } + public required string ProjPattern { get; init; } + public required int ProjOffset { get; init; } + } + + public MatrixReader(ILogger? logger) + { + _logger = logger; + } + + /// + /// Whether initialization was successful. + /// + public bool IsInitialized => _isInitialized; + + /// + /// Initialize the matrix reader by scanning for signatures. + /// + public bool Initialize() + { + try + { + var process = Process.GetCurrentProcess(); + var mainModule = process.MainModule; + if (mainModule == null) + { + _logger?.WriteLine("[MatrixReader] Could not get main module"); + return false; + } + + var baseAddress = mainModule.BaseAddress; + var moduleSize = mainModule.ModuleMemorySize; + + _logger?.WriteLine($"[MatrixReader] Scanning module: {mainModule.ModuleName}"); + _logger?.WriteLine($"[MatrixReader] Base: 0x{baseAddress.ToInt64():X}, Size: {moduleSize}"); + + // Calculate EXE hash for version detection + var exeHash = CalculateExeHash(mainModule.FileName); + _logger?.WriteLine($"[MatrixReader] EXE hash: {exeHash}"); + + // Try to find matching signatures + if (!TryFindSignatures(baseAddress, moduleSize)) + { + _logger?.WriteLine("[MatrixReader] No matching signatures found"); + _logger?.WriteLine("[MatrixReader] This game version may require new signatures"); + return false; + } + + _isInitialized = true; + return true; + } + catch (Exception ex) + { + _logger?.WriteLine($"[MatrixReader] Initialization failed: {ex.Message}"); + return false; + } + } + + /// + /// Try to find camera matrix signatures in memory. + /// + private bool TryFindSignatures(IntPtr baseAddress, int moduleSize) + { + // For now, return false since we don't have real signatures + // Real implementation would: + // 1. Create a Scanner with the module memory region + // 2. Search for the view matrix pattern + // 3. Search for the projection matrix pattern + // 4. Resolve relative addresses to absolute pointers + + _logger?.WriteLine("[MatrixReader] WARNING: Using placeholder signatures"); + _logger?.WriteLine("[MatrixReader] Real signatures must be discovered for each NMS version"); + + // Placeholder: In a real implementation, you would: + // + // using var scanner = new Scanner(process, baseAddress, moduleSize); + // + // var viewResult = scanner.FindPattern(signatureSet.ViewPattern); + // if (!viewResult.Found) return false; + // + // var viewInstructionAddr = baseAddress + viewResult.Offset; + // var viewRelativeOffset = Marshal.ReadInt32(viewInstructionAddr + signatureSet.ViewOffset); + // _viewMatrixPtr = viewInstructionAddr + signatureSet.ViewOffset + 4 + viewRelativeOffset; + // + // Similar for projection matrix... + + return false; + } + + /// + /// Try to read the current view and projection matrices. + /// + public bool TryReadMatrices(out Matrix4x4 view, out Matrix4x4 proj) + { + view = Matrix4x4.Identity; + proj = Matrix4x4.Identity; + + if (!_isInitialized || _viewMatrixPtr == IntPtr.Zero || _projMatrixPtr == IntPtr.Zero) + { + return false; + } + + try + { + // Read raw matrix data from memory + unsafe + { + var viewPtr = (float*)_viewMatrixPtr.ToPointer(); + var projPtr = (float*)_projMatrixPtr.ToPointer(); + + view = new Matrix4x4( + viewPtr[0], viewPtr[1], viewPtr[2], viewPtr[3], + viewPtr[4], viewPtr[5], viewPtr[6], viewPtr[7], + viewPtr[8], viewPtr[9], viewPtr[10], viewPtr[11], + viewPtr[12], viewPtr[13], viewPtr[14], viewPtr[15] + ); + + proj = new Matrix4x4( + projPtr[0], projPtr[1], projPtr[2], projPtr[3], + projPtr[4], projPtr[5], projPtr[6], projPtr[7], + projPtr[8], projPtr[9], projPtr[10], projPtr[11], + projPtr[12], projPtr[13], projPtr[14], projPtr[15] + ); + } + + // Validate matrices (basic sanity check) + if (!IsValidMatrix(view) || !IsValidMatrix(proj)) + { + return false; + } + + return true; + } + catch + { + return false; + } + } + + /// + /// Basic validation that a matrix contains finite values. + /// + private static bool IsValidMatrix(Matrix4x4 m) + { + return float.IsFinite(m.M11) && float.IsFinite(m.M22) && float.IsFinite(m.M33) && + float.IsFinite(m.M44) && Math.Abs(m.M44) > 0.0001f; + } + + /// + /// Calculate a hash of the EXE for version detection. + /// + private static string CalculateExeHash(string? path) + { + if (string.IsNullOrEmpty(path) || !File.Exists(path)) + { + return "unknown"; + } + + try + { + // Use file size + timestamp as a quick hash + var info = new FileInfo(path); + return $"{info.Length:X}_{info.LastWriteTimeUtc.Ticks:X}"; + } + catch + { + return "error"; + } + } +} diff --git a/projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/ScreenProjection.cs b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/ScreenProjection.cs new file mode 100644 index 0000000..90c0d03 --- /dev/null +++ b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/CockpitTracker/ScreenProjection.cs @@ -0,0 +1,178 @@ +using System.Numerics; + +namespace NmsCockpitOverlay.CockpitTracker; + +/// +/// Computes the screen-space bounding rectangle for the cockpit display screen. +/// +/// The cockpit screen is defined as a quad in 3D space relative to the player's +/// cockpit. This class projects those world-space corners through the view and +/// projection matrices to get 2D screen coordinates. +/// +public class ScreenProjection +{ + // Cockpit screen corners in local cockpit space (meters) + // These values define where the "screen" is located relative to the pilot's view + // Adjust these based on the specific cockpit geometry in NMS + private static readonly Vector3[] CockpitScreenCorners = + [ + new Vector3(-0.25f, 0.15f, 0.5f), // Top-left + new Vector3(0.25f, 0.15f, 0.5f), // Top-right + new Vector3(-0.25f, -0.05f, 0.5f), // Bottom-left + new Vector3(0.25f, -0.05f, 0.5f), // Bottom-right + ]; + + // Minimum screen size threshold (pixels) + private const float MinScreenSize = 50f; + + // Maximum screen size as fraction of viewport + private const float MaxScreenFraction = 0.8f; + + /// + /// Result of projecting the cockpit screen to 2D. + /// + public readonly struct ScreenRect + { + public float X { get; init; } + public float Y { get; init; } + public float Width { get; init; } + public float Height { get; init; } + public float Rotation { get; init; } + public bool Visible { get; init; } + } + + /// + /// Compute the screen-space rectangle for the cockpit display. + /// + /// View matrix from the game camera. + /// Projection matrix from the game camera. + /// Viewport width in pixels (default 1920). + /// Viewport height in pixels (default 1080). + /// Screen rect if visible, null if not visible. + public ScreenRect? ComputeCockpitRect( + Matrix4x4 view, + Matrix4x4 proj, + float viewportWidth = 1920f, + float viewportHeight = 1080f) + { + // Compute view-projection matrix + var viewProj = view * proj; + + // Project each corner to screen space + var screenPoints = new Vector2[4]; + var allVisible = true; + + for (int i = 0; i < 4; i++) + { + var worldPos = CockpitScreenCorners[i]; + var projected = ProjectToScreen(worldPos, viewProj, viewportWidth, viewportHeight, out var visible); + + if (!visible) + { + allVisible = false; + break; + } + + screenPoints[i] = projected; + } + + if (!allVisible) + { + return null; + } + + // Compute axis-aligned bounding box + var minX = float.MaxValue; + var minY = float.MaxValue; + var maxX = float.MinValue; + var maxY = float.MinValue; + + foreach (var point in screenPoints) + { + minX = Math.Min(minX, point.X); + minY = Math.Min(minY, point.Y); + maxX = Math.Max(maxX, point.X); + maxY = Math.Max(maxY, point.Y); + } + + var width = maxX - minX; + var height = maxY - minY; + + // Validate size + if (width < MinScreenSize || height < MinScreenSize) + { + return null; + } + + // Clamp to viewport bounds + minX = Math.Max(0, minX); + minY = Math.Max(0, minY); + maxX = Math.Min(viewportWidth, maxX); + maxY = Math.Min(viewportHeight, maxY); + + width = maxX - minX; + height = maxY - minY; + + // Check if too large (likely incorrect projection) + if (width > viewportWidth * MaxScreenFraction || height > viewportHeight * MaxScreenFraction) + { + return null; + } + + // Compute rotation from top edge (for perspective correction) + var topLeft = screenPoints[0]; + var topRight = screenPoints[1]; + var rotation = MathF.Atan2(topRight.Y - topLeft.Y, topRight.X - topLeft.X); + + return new ScreenRect + { + X = minX, + Y = minY, + Width = width, + Height = height, + Rotation = rotation, + Visible = true + }; + } + + /// + /// Project a 3D point to 2D screen coordinates. + /// + private static Vector2 ProjectToScreen( + Vector3 worldPos, + Matrix4x4 viewProj, + float viewportWidth, + float viewportHeight, + out bool visible) + { + // Transform to clip space + var clipPos = Vector4.Transform(new Vector4(worldPos, 1.0f), viewProj); + + // Check if behind camera + if (clipPos.W <= 0.0001f) + { + visible = false; + return Vector2.Zero; + } + + // Perspective divide to NDC (-1 to 1) + var ndcX = clipPos.X / clipPos.W; + var ndcY = clipPos.Y / clipPos.W; + var ndcZ = clipPos.Z / clipPos.W; + + // Check if outside view frustum + if (ndcX < -1 || ndcX > 1 || ndcY < -1 || ndcY > 1 || ndcZ < 0 || ndcZ > 1) + { + visible = false; + return Vector2.Zero; + } + + // Convert to screen coordinates + // Note: Y is inverted (NDC y=1 is top, screen y=0 is top) + var screenX = (ndcX + 1) * 0.5f * viewportWidth; + var screenY = (1 - ndcY) * 0.5f * viewportHeight; + + visible = true; + return new Vector2(screenX, screenY); + } +} diff --git a/projects/nms-cockpit-video/mod/NmsCockpitOverlay/Ipc/PipeClient.cs b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/Ipc/PipeClient.cs new file mode 100644 index 0000000..c0946e5 --- /dev/null +++ b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/Ipc/PipeClient.cs @@ -0,0 +1,190 @@ +using System.IO.Pipes; +using Reloaded.Mod.Interfaces; + +namespace NmsCockpitOverlay.Ipc; + +/// +/// Named pipe client for sending screen rect data to the video overlay daemon. +/// +/// Wire format (ITK Protocol): +/// - 4 bytes: Magic "ITKP" +/// - 4 bytes: Version (1) +/// - 4 bytes: Message type (10 = ScreenRect) +/// - 4 bytes: Payload length +/// - 4 bytes: CRC32 of payload +/// - N bytes: Payload (bincode-encoded ScreenRect) +/// +public class PipeClient : IDisposable +{ + private const string MagicBytes = "ITKP"; + private const uint ProtocolVersion = 1; + private const uint MessageTypeScreenRect = 10; + + private readonly string _pipeName; + private readonly ILogger? _logger; + private NamedPipeClientStream? _pipe; + private bool _isConnected; + private readonly object _lock = new(); + + public PipeClient(string pipeName, ILogger? logger) + { + _pipeName = pipeName; + _logger = logger; + } + + /// + /// Try to connect to the daemon's named pipe. + /// + public bool TryConnect() + { + lock (_lock) + { + if (_isConnected && _pipe?.IsConnected == true) + { + return true; + } + + try + { + _pipe?.Dispose(); + _pipe = new NamedPipeClientStream( + ".", + _pipeName, + PipeDirection.Out, + PipeOptions.Asynchronous); + + // Non-blocking connect attempt + _pipe.Connect(100); + _isConnected = true; + _logger?.WriteLine($"[PipeClient] Connected to {_pipeName}"); + return true; + } + catch (TimeoutException) + { + // Daemon not available yet + return false; + } + catch (Exception ex) + { + _logger?.WriteLine($"[PipeClient] Connection error: {ex.Message}"); + return false; + } + } + } + + /// + /// Disconnect from the pipe. + /// + public void Disconnect() + { + lock (_lock) + { + _pipe?.Dispose(); + _pipe = null; + _isConnected = false; + } + } + + /// + /// Send a screen rect update to the daemon. + /// + public async Task SendScreenRectAsync(CockpitTracker.ScreenProjection.ScreenRect rect) + { + if (!_isConnected) + { + TryConnect(); + if (!_isConnected) return; + } + + try + { + var message = EncodeScreenRect(rect); + await _pipe!.WriteAsync(message, 0, message.Length); + await _pipe.FlushAsync(); + } + catch (IOException) + { + // Pipe broken, try to reconnect next time + _isConnected = false; + } + catch (Exception ex) + { + _logger?.WriteLine($"[PipeClient] Send error: {ex.Message}"); + _isConnected = false; + } + } + + /// + /// Encode a ScreenRect to ITK protocol wire format. + /// + private static byte[] EncodeScreenRect(CockpitTracker.ScreenProjection.ScreenRect rect) + { + // Payload: x, y, width, height, rotation, visible (f32, f32, f32, f32, f32, bool) + // Using manual binary encoding to match Rust bincode format + using var payloadStream = new MemoryStream(); + using var payloadWriter = new BinaryWriter(payloadStream); + + // Write floats as little-endian + payloadWriter.Write(rect.X); + payloadWriter.Write(rect.Y); + payloadWriter.Write(rect.Width); + payloadWriter.Write(rect.Height); + payloadWriter.Write(rect.Rotation); + payloadWriter.Write(rect.Visible ? (byte)1 : (byte)0); + + var payload = payloadStream.ToArray(); + + // Calculate CRC32 + var crc = CalculateCrc32(payload); + + // Build full message + using var messageStream = new MemoryStream(); + using var messageWriter = new BinaryWriter(messageStream); + + // Header + messageWriter.Write(System.Text.Encoding.ASCII.GetBytes(MagicBytes)); // 4 bytes + messageWriter.Write(ProtocolVersion); // 4 bytes (little-endian) + messageWriter.Write(MessageTypeScreenRect); // 4 bytes + messageWriter.Write((uint)payload.Length); // 4 bytes + messageWriter.Write(crc); // 4 bytes + + // Payload + messageWriter.Write(payload); + + return messageStream.ToArray(); + } + + /// + /// Calculate CRC32 of data (matches crc32fast crate). + /// + private static uint CalculateCrc32(byte[] data) + { + // CRC32 with IEEE polynomial (same as crc32fast) + const uint polynomial = 0xEDB88320; + var table = new uint[256]; + + for (uint i = 0; i < 256; i++) + { + var crc = i; + for (int j = 0; j < 8; j++) + { + crc = (crc & 1) == 1 ? (crc >> 1) ^ polynomial : crc >> 1; + } + table[i] = crc; + } + + uint result = 0xFFFFFFFF; + foreach (var b in data) + { + result = table[(result ^ b) & 0xFF] ^ (result >> 8); + } + + return result ^ 0xFFFFFFFF; + } + + public void Dispose() + { + Disconnect(); + GC.SuppressFinalize(this); + } +} diff --git a/projects/nms-cockpit-video/mod/NmsCockpitOverlay/Mod.cs b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/Mod.cs new file mode 100644 index 0000000..555720d --- /dev/null +++ b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/Mod.cs @@ -0,0 +1,196 @@ +using Reloaded.Mod.Interfaces; +using NmsCockpitOverlay.CockpitTracker; +using NmsCockpitOverlay.Ipc; + +namespace NmsCockpitOverlay; + +/// +/// Main entry point for the NMS Cockpit Overlay mod. +/// Extracts camera matrices and cockpit screen coordinates from No Man's Sky +/// and sends them to the video overlay daemon via named pipe. +/// +public class Mod : IMod +{ + /// + /// Used for logging to the Reloaded-II console. + /// + private ILogger? _logger; + + /// + /// Mod configuration (user-adjustable settings). + /// + private IModConfig? _config; + + /// + /// Signature scanner for finding camera matrices. + /// + private MatrixReader? _matrixReader; + + /// + /// Computes screen-space projection of cockpit screen. + /// + private ScreenProjection? _screenProjection; + + /// + /// Named pipe client for sending data to daemon. + /// + private PipeClient? _pipeClient; + + /// + /// Whether the mod is currently active. + /// + private bool _isActive; + + /// + /// Token for cancelling the update loop. + /// + private CancellationTokenSource? _cts; + + /// + /// Background task for the update loop. + /// + private Task? _updateTask; + + /// + /// Called when the mod is first loaded. + /// + public void Start(IModLoaderV1 loader) + { + _config = loader.GetModConfig(); + _logger = loader.GetLogger(); + + _logger.WriteLine("[NMS Cockpit Overlay] Starting..."); + + try + { + // Initialize components + _pipeClient = new PipeClient("nms_cockpit_injector", _logger); + _matrixReader = new MatrixReader(_logger); + _screenProjection = new ScreenProjection(); + + // Try to find signatures + if (_matrixReader.Initialize()) + { + _logger.WriteLine("[NMS Cockpit Overlay] Signature scan successful"); + StartUpdateLoop(); + } + else + { + _logger.WriteLine("[NMS Cockpit Overlay] WARNING: Signature scan failed - mod disabled"); + _logger.WriteLine("[NMS Cockpit Overlay] This game version may not be supported"); + } + } + catch (Exception ex) + { + _logger.WriteLine($"[NMS Cockpit Overlay] ERROR: {ex.Message}"); + } + } + + /// + /// Start the background update loop. + /// + private void StartUpdateLoop() + { + _isActive = true; + _cts = new CancellationTokenSource(); + _updateTask = Task.Run(UpdateLoopAsync, _cts.Token); + _logger?.WriteLine("[NMS Cockpit Overlay] Update loop started"); + } + + /// + /// Background loop that reads camera matrices and sends updates. + /// + private async Task UpdateLoopAsync() + { + const int targetFps = 60; + const int frameTimeMs = 1000 / targetFps; + + while (!_cts!.Token.IsCancellationRequested && _isActive) + { + try + { + var startTime = Environment.TickCount64; + + // Read camera matrices + if (_matrixReader!.TryReadMatrices(out var view, out var proj)) + { + // Compute screen rect + var rect = _screenProjection!.ComputeCockpitRect(view, proj); + + // Send to daemon + if (rect.HasValue) + { + await _pipeClient!.SendScreenRectAsync(rect.Value); + } + } + + // Sleep to maintain target FPS + var elapsed = Environment.TickCount64 - startTime; + var sleepTime = (int)(frameTimeMs - elapsed); + if (sleepTime > 0) + { + await Task.Delay(sleepTime, _cts.Token); + } + } + catch (OperationCanceledException) + { + break; + } + catch (Exception ex) + { + _logger?.WriteLine($"[NMS Cockpit Overlay] Update error: {ex.Message}"); + await Task.Delay(1000, _cts.Token); // Back off on error + } + } + } + + /// + /// Called when the mod can be unloaded. + /// + public void Suspend() + { + _isActive = false; + _cts?.Cancel(); + _updateTask?.Wait(1000); + _pipeClient?.Disconnect(); + _logger?.WriteLine("[NMS Cockpit Overlay] Suspended"); + } + + /// + /// Called when the mod is being reloaded. + /// + public void Resume() + { + if (_matrixReader?.IsInitialized == true) + { + _pipeClient?.TryConnect(); + StartUpdateLoop(); + _logger?.WriteLine("[NMS Cockpit Overlay] Resumed"); + } + } + + /// + /// Called when the mod is being unloaded permanently. + /// + public void Unload() + { + Suspend(); + _pipeClient?.Dispose(); + _logger?.WriteLine("[NMS Cockpit Overlay] Unloaded"); + } + + /// + /// Whether this mod can be unloaded. + /// + public bool CanUnload() => true; + + /// + /// Whether this mod can be suspended. + /// + public bool CanSuspend() => true; + + /// + /// Mod action (unused). + /// + public Action? Disposing { get; } +} diff --git a/projects/nms-cockpit-video/mod/NmsCockpitOverlay/ModConfig.json b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/ModConfig.json new file mode 100644 index 0000000..c03a79c --- /dev/null +++ b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/ModConfig.json @@ -0,0 +1,33 @@ +{ + "$schema": "https://raw.githubusercontent.com/Reloaded-Project/Reloaded-II/master/docs/Schemas/Mod/1.0.0.json", + "ModId": "nms.cockpit.overlay", + "ModName": "NMS Cockpit Video Overlay", + "ModAuthor": "ITK", + "ModVersion": "1.0.0", + "ModDescription": "Extracts cockpit screen coordinates from No Man's Sky for video overlay rendering", + "ModDll": "NmsCockpitOverlay.dll", + "ModIcon": "", + "ModR2RManagedDll32": "", + "ModR2RManagedDll64": "", + "ModNativeDll32": "", + "ModNativeDll64": "", + "Tags": [ + "Video", + "Overlay", + "Cockpit" + ], + "CanUnload": true, + "HasExports": false, + "IsLibrary": false, + "ReleaseMetadataFileName": "", + "ProjectUrl": "https://github.com/AndrewAltimit/game-mods", + "PluginData": "", + "IsUniversalMod": false, + "ModDependencies": [], + "OptionalDependencies": [], + "SupportedAppId": [ + "nms.exe", + "NMS.exe" + ], + "Author": "ITK" +} diff --git a/projects/nms-cockpit-video/mod/NmsCockpitOverlay/NmsCockpitOverlay.csproj b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/NmsCockpitOverlay.csproj new file mode 100644 index 0000000..c0181cb --- /dev/null +++ b/projects/nms-cockpit-video/mod/NmsCockpitOverlay/NmsCockpitOverlay.csproj @@ -0,0 +1,23 @@ + + + + net8.0-windows + enable + enable + true + x64 + Library + win-x64 + + + + + + + + + + + + + diff --git a/projects/nms-cockpit-video/overlay/Cargo.toml b/projects/nms-cockpit-video/overlay/Cargo.toml new file mode 100644 index 0000000..2baeca7 --- /dev/null +++ b/projects/nms-cockpit-video/overlay/Cargo.toml @@ -0,0 +1,62 @@ +[package] +name = "nms-video-overlay" +description = "Video overlay for No Man's Sky Cockpit Video Player" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true + +[[bin]] +name = "nms-video-overlay" +path = "src/main.rs" + +[features] +default = [] + +[dependencies] +# Core ITK libraries +itk-protocol = { path = "../../../core/itk-protocol" } +itk-shmem = { path = "../../../core/itk-shmem" } +itk-ipc = { path = "../../../core/itk-ipc" } + +# Windows platform APIs +[target.'cfg(windows)'.dependencies] +windows = { version = "0.58", features = [ + "Win32_Foundation", + "Win32_UI_WindowsAndMessaging", + "Win32_UI_Input_KeyboardAndMouse", +] } + +# GUI +egui = "0.28" +egui-wgpu = "0.28" +egui-winit = "0.28" + +# Async runtime +tokio = { workspace = true } + +# Serialization +serde = { workspace = true } +serde_json = "1.0" + +# Logging +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +# Error handling +thiserror = { workspace = true } +anyhow = { workspace = true } + +# Graphics - must match egui-wgpu 0.28's wgpu version +wgpu = "0.20" +winit = { version = "0.29", features = ["rwh_06"] } +pollster = "0.3" +bytemuck = { version = "1.14", features = ["derive"] } +raw-window-handle = "0.6" + +# CLI +clap = { version = "4", features = ["derive"] } + +[lints] +workspace = true diff --git a/projects/nms-cockpit-video/overlay/src/main.rs b/projects/nms-cockpit-video/overlay/src/main.rs new file mode 100644 index 0000000..cedeadd --- /dev/null +++ b/projects/nms-cockpit-video/overlay/src/main.rs @@ -0,0 +1,435 @@ +//! NMS Cockpit Video Overlay +//! +//! Overlay application for the No Man's Sky Cockpit Video Player. +//! Displays video on cockpit screen with egui controls. + +use anyhow::{Context, Result}; +use clap::Parser; +use itk_ipc::IpcChannel; +use itk_protocol::{MessageType, ScreenRect, VideoLoad, VideoPause, VideoPlay, VideoSeek, encode}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tracing::{debug, error, info, warn}; +use winit::{ + event::{Event, WindowEvent}, + event_loop::{ControlFlow, EventLoop}, + window::{WindowBuilder, WindowLevel}, +}; + +mod platform; +mod render; +mod ui; +mod video; + +use render::NmsRenderer; +use ui::VideoControls; +use video::VideoFrameReader; + +/// NMS Cockpit Video Overlay +#[derive(Parser, Debug)] +#[command(name = "nms-video-overlay")] +#[command(about = "Video overlay for No Man's Sky Cockpit Video Player")] +struct Args { + /// Daemon IPC channel name + #[arg(long, default_value = "nms_cockpit_client")] + daemon_channel: String, + + /// Window width + #[arg(long, default_value = "1920")] + width: u32, + + /// Window height + #[arg(long, default_value = "1080")] + height: u32, + + /// Video rectangle position and size: x,y,w,h (screen pixels) + /// Default: centered 1280x720 + #[arg(long, value_parser = parse_video_rect)] + video_rect: Option, + + /// Log level + #[arg(long, default_value = "info")] + log_level: String, +} + +/// Parse a video rect string "x,y,w,h" into a ScreenRect. +fn parse_video_rect(s: &str) -> Result { + let parts: Vec<&str> = s.split(',').collect(); + if parts.len() != 4 { + return Err("Expected format: x,y,w,h (e.g., 320,180,1280,720)".to_string()); + } + let x: f32 = parts[0] + .trim() + .parse() + .map_err(|e| format!("Invalid x: {e}"))?; + let y: f32 = parts[1] + .trim() + .parse() + .map_err(|e| format!("Invalid y: {e}"))?; + let w: f32 = parts[2] + .trim() + .parse() + .map_err(|e| format!("Invalid width: {e}"))?; + let h: f32 = parts[3] + .trim() + .parse() + .map_err(|e| format!("Invalid height: {e}"))?; + Ok(ScreenRect { + x, + y, + width: w, + height: h, + rotation: 0.0, + visible: true, + }) +} + +/// Default video rectangle: centered, 720p on a 1920x1080 screen. +const DEFAULT_VIDEO_RECT: ScreenRect = ScreenRect { + x: 320.0, + y: 180.0, + width: 1280.0, + height: 720.0, + rotation: 0.0, + visible: true, +}; + +/// Overlay state +struct OverlayState { + /// Current screen rect for rendering + screen_rect: Option, + /// Whether in click-through mode + click_through: bool, +} + +impl Default for OverlayState { + fn default() -> Self { + Self { + screen_rect: None, + click_through: true, + } + } +} + +/// Manages the IPC connection to the daemon with auto-reconnect. +struct DaemonConnection { + channel_name: String, + channel: Option>, + last_connect_attempt: Instant, + reconnect_interval: Duration, +} + +impl DaemonConnection { + fn new(channel_name: &str) -> Self { + Self { + channel_name: channel_name.to_string(), + channel: None, + last_connect_attempt: Instant::now() - Duration::from_secs(10), // Allow immediate first attempt + reconnect_interval: Duration::from_secs(2), + } + } + + /// Try to connect if not already connected (rate-limited). + fn ensure_connected(&mut self) { + if self.channel.as_ref().is_some_and(|c| c.is_connected()) { + return; + } + + // Rate-limit reconnection attempts + if self.last_connect_attempt.elapsed() < self.reconnect_interval { + return; + } + self.last_connect_attempt = Instant::now(); + + match itk_ipc::connect(&self.channel_name) { + Ok(ch) => { + info!(channel = %self.channel_name, "Connected to daemon"); + self.channel = Some(Box::new(ch)); + }, + Err(e) => { + debug!(?e, "Daemon not available, will retry"); + self.channel = None; + }, + } + } + + /// Send a protocol message to the daemon. + fn send_message(&mut self, msg_type: MessageType, payload: &T) -> bool { + self.ensure_connected(); + let Some(ref channel) = self.channel else { + return false; + }; + + match encode(msg_type, payload) { + Ok(data) => match channel.send(&data) { + Ok(()) => true, + Err(e) => { + warn!(?e, "Failed to send to daemon, disconnecting"); + self.channel = None; + false + }, + }, + Err(e) => { + error!(?e, "Failed to encode message"); + false + }, + } + } + + fn is_connected(&self) -> bool { + self.channel.as_ref().is_some_and(|c| c.is_connected()) + } +} + +fn main() -> Result<()> { + let args = Args::parse(); + + // Initialize logging + let filter = format!( + "nms_video_overlay={},itk={}", + args.log_level, args.log_level + ); + let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(&filter)); + tracing_subscriber::fmt().with_env_filter(env_filter).init(); + + info!("NMS Cockpit Video Overlay starting"); + + let video_rect = args.video_rect.unwrap_or(DEFAULT_VIDEO_RECT); + + // Create event loop + let event_loop = EventLoop::new().context("Failed to create event loop")?; + event_loop.set_control_flow(ControlFlow::Poll); + + // Create window sized and positioned to match the video rect + let window = WindowBuilder::new() + .with_title("NMS Cockpit Video") + .with_inner_size(winit::dpi::PhysicalSize::new( + video_rect.width as u32, + video_rect.height as u32, + )) + .with_position(winit::dpi::PhysicalPosition::new( + video_rect.x as i32, + video_rect.y as i32, + )) + .with_decorations(false) + .with_window_level(WindowLevel::AlwaysOnTop) + .build(&event_loop) + .context("Failed to create window")?; + + let window = Arc::new(window); + + // Set platform-specific attributes + if let Err(e) = platform::set_transparent(&window) { + error!(?e, "Failed to set transparent"); + } + if let Err(e) = platform::set_always_on_top(&window, true) { + error!(?e, "Failed to set always-on-top"); + } + if let Err(e) = platform::set_click_through(&window, true) { + error!(?e, "Failed to set click-through"); + } + + // Create renderer with egui + let mut renderer = pollster::block_on(NmsRenderer::new(Arc::clone(&window))) + .context("Failed to create renderer")?; + info!("Renderer initialized"); + + // Application state + let mut state = OverlayState::default(); + let mut frame_reader = VideoFrameReader::new(); + let mut controls = VideoControls::new(); + let mut daemon = DaemonConnection::new(&args.daemon_channel); + let mut logged_frame_connection = false; + let mut f9_was_down = false; + + info!("NMS Cockpit Video Overlay started. Press F9 to toggle controls."); + + // Run event loop + event_loop + .run(move |event, elwt| { + match event { + Event::WindowEvent { event, .. } => { + // Pass events to egui first when interactive + if !state.click_through { + let _ = renderer.handle_event(&event); + } + + match event { + WindowEvent::CloseRequested => { + elwt.exit(); + }, + + WindowEvent::Resized(physical_size) => { + renderer.resize(physical_size); + }, + + WindowEvent::KeyboardInput { event, .. } => { + if event.state.is_pressed() { + match event.physical_key { + // F9 toggles click-through/interactive mode + winit::keyboard::PhysicalKey::Code( + winit::keyboard::KeyCode::F9, + ) => { + state.click_through = !state.click_through; + if let Err(e) = platform::set_click_through( + &window, + state.click_through, + ) { + error!(?e, "Failed to toggle click-through"); + } + info!( + mode = if state.click_through { + "click-through" + } else { + "interactive" + }, + "Mode toggled" + ); + }, + // Space toggles play/pause in interactive mode + winit::keyboard::PhysicalKey::Code( + winit::keyboard::KeyCode::Space, + ) => { + if !state.click_through { + controls.toggle_play_pause(); + } + }, + // Left arrow seeks back 10s + winit::keyboard::PhysicalKey::Code( + winit::keyboard::KeyCode::ArrowLeft, + ) => { + if !state.click_through { + controls.seek_relative(-10_000); + } + }, + // Right arrow seeks forward 10s + winit::keyboard::PhysicalKey::Code( + winit::keyboard::KeyCode::ArrowRight, + ) => { + if !state.click_through { + controls.seek_relative(10_000); + } + }, + _ => {}, + } + } + }, + + WindowEvent::RedrawRequested => { + // Get screen rect: prefer NMS mod rect, fall back to CLI/default + let screen_rect = state.screen_rect.as_ref().or_else(|| { + if frame_reader.is_connected() && frame_reader.last_pts_ms() > 0 { + Some(&video_rect) + } else { + None + } + }); + + // Render video and UI + let show_ui = !state.click_through; + if let Err(e) = renderer.render(screen_rect, show_ui, &mut controls) { + error!(?e, "Render failed"); + } + }, + + _ => {}, + } + }, + + Event::AboutToWait => { + // Poll global F9 hotkey (works even in click-through mode) + if platform::is_key_just_pressed(platform::VK_F9, &mut f9_was_down) { + state.click_through = !state.click_through; + if let Err(e) = platform::set_click_through(&window, state.click_through) { + error!(?e, "Failed to toggle click-through"); + } + info!( + mode = if state.click_through { + "click-through" + } else { + "interactive" + }, + "Mode toggled" + ); + } + + // Update connection status + controls.set_daemon_connected(daemon.is_connected()); + + // Log frame buffer connection status + if !logged_frame_connection && frame_reader.is_connected() { + info!("Connected to video frame buffer"); + logged_frame_connection = true; + } + + // Read new video frames + if let Some(frame_data) = frame_reader.try_read_frame() { + renderer.update_texture(frame_data); + // Update controls with current position + controls.set_position(frame_reader.last_pts_ms()); + debug!(pts_ms = frame_reader.last_pts_ms(), "Updated frame"); + } + + // Update duration from shared memory (set once on load) + let dur = frame_reader.duration_ms(); + if dur > 0 { + controls.set_duration(dur); + } + + // Process any UI actions + process_ui_actions(&mut controls, &mut daemon); + + // Request redraw + window.request_redraw(); + }, + + _ => {}, + } + }) + .context("Event loop failed")?; + + Ok(()) +} + +/// Process UI actions and send commands to the daemon via IPC. +fn process_ui_actions(controls: &mut VideoControls, daemon: &mut DaemonConnection) { + // Periodically try to connect + daemon.ensure_connected(); + + // Load video + if let Some(url) = controls.take_load_request() { + info!(url = %url, "Loading video"); + let cmd = VideoLoad { + source: url, + start_position_ms: 0, + autoplay: true, + }; + if !daemon.send_message(MessageType::VideoLoad, &cmd) { + warn!("Failed to send load command (daemon not connected)"); + } + } + + // Play + if controls.take_play_request() { + debug!("Play"); + let cmd = VideoPlay { + from_position_ms: None, + }; + daemon.send_message(MessageType::VideoPlay, &cmd); + } + + // Pause + if controls.take_pause_request() { + debug!("Pause"); + let cmd = VideoPause {}; + daemon.send_message(MessageType::VideoPause, &cmd); + } + + // Seek + if let Some(position_ms) = controls.take_seek_request() { + debug!(position_ms, "Seek"); + let cmd = VideoSeek { position_ms }; + daemon.send_message(MessageType::VideoSeek, &cmd); + } +} diff --git a/projects/nms-cockpit-video/overlay/src/platform.rs b/projects/nms-cockpit-video/overlay/src/platform.rs new file mode 100644 index 0000000..602e26f --- /dev/null +++ b/projects/nms-cockpit-video/overlay/src/platform.rs @@ -0,0 +1,119 @@ +//! Platform-specific overlay functionality for NMS overlay +//! +//! This module handles Windows-specific window attributes for overlay behavior. + +use anyhow::{Result, anyhow}; +use winit::raw_window_handle::{HasWindowHandle, RawWindowHandle}; + +#[cfg(windows)] +use windows::Win32::Foundation::HWND; +#[cfg(windows)] +use windows::Win32::UI::WindowsAndMessaging::{ + GWL_EXSTYLE, GetWindowLongW, HWND_TOPMOST, SWP_NOMOVE, SWP_NOSIZE, SetWindowLongW, + SetWindowPos, WS_EX_LAYERED, WS_EX_NOACTIVATE, WS_EX_TOOLWINDOW, WS_EX_TRANSPARENT, +}; + +/// Get the HWND from a winit window +#[cfg(windows)] +fn get_hwnd(window: &winit::window::Window) -> Result { + match window + .window_handle() + .map_err(|e| anyhow!("Failed to get window handle: {}", e))? + .as_raw() + { + RawWindowHandle::Win32(handle) => Ok(HWND(handle.hwnd.get() as *mut _)), + _ => Err(anyhow!("Expected Win32 window handle")), + } +} + +/// Set click-through mode for a window +/// +/// When enabled, mouse input passes through the window to the one behind it. +#[cfg(windows)] +pub fn set_click_through(window: &winit::window::Window, enabled: bool) -> Result<()> { + let hwnd = get_hwnd(window)?; + + unsafe { + let mut ex_style = GetWindowLongW(hwnd, GWL_EXSTYLE) as u32; + + if enabled { + // Enable click-through + ex_style |= WS_EX_TRANSPARENT.0 | WS_EX_LAYERED.0 | WS_EX_NOACTIVATE.0; + } else { + // Disable click-through + ex_style &= !(WS_EX_TRANSPARENT.0 | WS_EX_NOACTIVATE.0); + } + + SetWindowLongW(hwnd, GWL_EXSTYLE, ex_style as i32); + } + + Ok(()) +} + +/// Set always-on-top for a window +#[cfg(windows)] +pub fn set_always_on_top(window: &winit::window::Window, enabled: bool) -> Result<()> { + let hwnd = get_hwnd(window)?; + + unsafe { + let insert_after = if enabled { + HWND_TOPMOST + } else { + windows::Win32::UI::WindowsAndMessaging::HWND_NOTOPMOST + }; + + SetWindowPos(hwnd, insert_after, 0, 0, 0, 0, SWP_NOMOVE | SWP_NOSIZE) + .map_err(|e| anyhow!("SetWindowPos failed: {}", e))?; + } + + Ok(()) +} + +/// Make window transparent (for compositor) +#[cfg(windows)] +pub fn set_transparent(window: &winit::window::Window) -> Result<()> { + let hwnd = get_hwnd(window)?; + + unsafe { + let mut ex_style = GetWindowLongW(hwnd, GWL_EXSTYLE) as u32; + ex_style |= WS_EX_LAYERED.0 | WS_EX_TOOLWINDOW.0; + SetWindowLongW(hwnd, GWL_EXSTYLE, ex_style as i32); + } + + Ok(()) +} + +/// Check if a virtual key was just pressed (edge-triggered). +/// Call this every frame; it returns true on the frame the key transitions from up to down. +#[cfg(windows)] +pub fn is_key_just_pressed(vk: i32, was_down: &mut bool) -> bool { + let state = unsafe { windows::Win32::UI::Input::KeyboardAndMouse::GetAsyncKeyState(vk) }; + let is_down = (state as u16 & 0x8000) != 0; + let just_pressed = is_down && !*was_down; + *was_down = is_down; + just_pressed +} + +#[cfg(not(windows))] +pub fn is_key_just_pressed(_vk: i32, _was_down: &mut bool) -> bool { + false +} + +/// Virtual key code for F9 +pub const VK_F9: i32 = 0x78; + +// Stub implementations for non-Windows platforms +#[cfg(not(windows))] +pub fn set_click_through(_window: &winit::window::Window, _enabled: bool) -> Result<()> { + Ok(()) +} + +#[cfg(not(windows))] +pub fn set_always_on_top(_window: &winit::window::Window, _enabled: bool) -> Result<()> { + Ok(()) +} + +#[cfg(not(windows))] +pub fn set_transparent(_window: &winit::window::Window) -> Result<()> { + Ok(()) +} diff --git a/projects/nms-cockpit-video/overlay/src/render.rs b/projects/nms-cockpit-video/overlay/src/render.rs new file mode 100644 index 0000000..f267318 --- /dev/null +++ b/projects/nms-cockpit-video/overlay/src/render.rs @@ -0,0 +1,503 @@ +//! NMS Overlay Renderer +//! +//! Combines video rendering with egui UI. + +use crate::ui::VideoControls; +use anyhow::{Context, Result}; +use egui_wgpu::ScreenDescriptor; +use itk_protocol::ScreenRect; +use std::sync::Arc; +use winit::event::WindowEvent; +use winit::window::Window; + +/// Vertex data for video quad +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +struct Vertex { + position: [f32; 2], + tex_coord: [f32; 2], +} + +/// NMS Overlay Renderer with egui integration +pub struct NmsRenderer { + surface: wgpu::Surface<'static>, + device: wgpu::Device, + queue: wgpu::Queue, + config: wgpu::SurfaceConfiguration, + render_pipeline: wgpu::RenderPipeline, + vertex_buffer: wgpu::Buffer, + bind_group: wgpu::BindGroup, + texture: wgpu::Texture, + texture_size: (u32, u32), + // egui integration + egui_ctx: egui::Context, + egui_state: egui_winit::State, + egui_renderer: egui_wgpu::Renderer, + window: Arc, +} + +impl NmsRenderer { + /// Create a new renderer + pub async fn new(window: Arc) -> Result { + let size = window.inner_size(); + + // Create wgpu instance + let instance = wgpu::Instance::new(wgpu::InstanceDescriptor { + backends: wgpu::Backends::all(), + ..Default::default() + }); + + // Create surface + let surface = instance + .create_surface(Arc::clone(&window)) + .context("Failed to create surface")?; + + // Request adapter + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::LowPower, + compatible_surface: Some(&surface), + force_fallback_adapter: false, + }) + .await + .context("No suitable GPU adapter found")?; + + // Create device and queue + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("NMS Overlay Device"), + required_features: wgpu::Features::empty(), + required_limits: wgpu::Limits::default(), + }, + None, + ) + .await + .context("Failed to create device")?; + + // Configure surface + let surface_caps = surface.get_capabilities(&adapter); + let surface_format = surface_caps + .formats + .iter() + .find(|f| f.is_srgb()) + .copied() + .unwrap_or(surface_caps.formats[0]); + + let alpha_mode = if surface_caps + .alpha_modes + .contains(&wgpu::CompositeAlphaMode::PreMultiplied) + { + wgpu::CompositeAlphaMode::PreMultiplied + } else { + surface_caps.alpha_modes[0] + }; + + let config = wgpu::SurfaceConfiguration { + usage: wgpu::TextureUsages::RENDER_ATTACHMENT, + format: surface_format, + width: size.width.max(1), + height: size.height.max(1), + present_mode: wgpu::PresentMode::Fifo, + alpha_mode, + view_formats: vec![], + desired_maximum_frame_latency: 2, + }; + surface.configure(&device, &config); + + // Create video shader + let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Video Shader"), + source: wgpu::ShaderSource::Wgsl(VIDEO_SHADER.into()), + }); + + // Create texture for video (720p default) + let texture_size = (1280u32, 720u32); + let texture = device.create_texture(&wgpu::TextureDescriptor { + label: Some("Video Texture"), + size: wgpu::Extent3d { + width: texture_size.0, + height: texture_size.1, + depth_or_array_layers: 1, + }, + mip_level_count: 1, + sample_count: 1, + dimension: wgpu::TextureDimension::D2, + format: wgpu::TextureFormat::Rgba8UnormSrgb, + usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST, + view_formats: &[], + }); + + let texture_view = texture.create_view(&wgpu::TextureViewDescriptor::default()); + let sampler = device.create_sampler(&wgpu::SamplerDescriptor { + address_mode_u: wgpu::AddressMode::ClampToEdge, + address_mode_v: wgpu::AddressMode::ClampToEdge, + address_mode_w: wgpu::AddressMode::ClampToEdge, + mag_filter: wgpu::FilterMode::Linear, + min_filter: wgpu::FilterMode::Linear, + ..Default::default() + }); + + // Create bind group layout + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("Video Bind Group Layout"), + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Texture { + multisampled: false, + view_dimension: wgpu::TextureViewDimension::D2, + sample_type: wgpu::TextureSampleType::Float { filterable: true }, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::FRAGMENT, + ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering), + count: None, + }, + ], + }); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("Video Bind Group"), + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: wgpu::BindingResource::TextureView(&texture_view), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: wgpu::BindingResource::Sampler(&sampler), + }, + ], + }); + + // Create pipeline layout + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Video Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + // Create render pipeline + let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { + label: Some("Video Render Pipeline"), + layout: Some(&pipeline_layout), + vertex: wgpu::VertexState { + module: &shader, + entry_point: "vs_main", + buffers: &[wgpu::VertexBufferLayout { + array_stride: std::mem::size_of::() as wgpu::BufferAddress, + step_mode: wgpu::VertexStepMode::Vertex, + attributes: &[ + wgpu::VertexAttribute { + offset: 0, + shader_location: 0, + format: wgpu::VertexFormat::Float32x2, + }, + wgpu::VertexAttribute { + offset: std::mem::size_of::<[f32; 2]>() as wgpu::BufferAddress, + shader_location: 1, + format: wgpu::VertexFormat::Float32x2, + }, + ], + }], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &shader, + entry_point: "fs_main", + targets: &[Some(wgpu::ColorTargetState { + format: config.format, + blend: Some(wgpu::BlendState::ALPHA_BLENDING), + write_mask: wgpu::ColorWrites::ALL, + })], + compilation_options: wgpu::PipelineCompilationOptions::default(), + }), + primitive: wgpu::PrimitiveState { + topology: wgpu::PrimitiveTopology::TriangleList, + strip_index_format: None, + front_face: wgpu::FrontFace::Ccw, + cull_mode: None, + polygon_mode: wgpu::PolygonMode::Fill, + unclipped_depth: false, + conservative: false, + }, + depth_stencil: None, + multisample: wgpu::MultisampleState::default(), + multiview: None, + }); + + // Create vertex buffer + let vertex_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Video Vertex Buffer"), + size: std::mem::size_of::<[Vertex; 6]>() as wgpu::BufferAddress, + usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + // Initialize egui + let egui_ctx = egui::Context::default(); + egui_ctx.set_visuals(egui::Visuals::dark()); + + let egui_state = egui_winit::State::new( + egui_ctx.clone(), + egui::ViewportId::ROOT, + &window, + Some(window.scale_factor() as f32), + None, + ); + + let egui_renderer = egui_wgpu::Renderer::new(&device, config.format, None, 1); + + Ok(Self { + surface, + device, + queue, + config, + render_pipeline, + vertex_buffer, + bind_group, + texture, + texture_size, + egui_ctx, + egui_state, + egui_renderer, + window, + }) + } + + /// Handle window events for egui + pub fn handle_event(&mut self, event: &WindowEvent) -> egui_winit::EventResponse { + self.egui_state.on_window_event(&self.window, event) + } + + /// Resize the surface + pub fn resize(&mut self, new_size: winit::dpi::PhysicalSize) { + if new_size.width > 0 && new_size.height > 0 { + self.config.width = new_size.width; + self.config.height = new_size.height; + self.surface.configure(&self.device, &self.config); + } + } + + /// Update texture with new frame data + pub fn update_texture(&self, data: &[u8]) { + self.queue.write_texture( + wgpu::ImageCopyTexture { + texture: &self.texture, + mip_level: 0, + origin: wgpu::Origin3d::ZERO, + aspect: wgpu::TextureAspect::All, + }, + data, + wgpu::ImageDataLayout { + offset: 0, + bytes_per_row: Some(self.texture_size.0 * 4), + rows_per_image: Some(self.texture_size.1), + }, + wgpu::Extent3d { + width: self.texture_size.0, + height: self.texture_size.1, + depth_or_array_layers: 1, + }, + ); + } + + /// Render a frame with optional UI + pub fn render( + &mut self, + screen_rect: Option<&ScreenRect>, + show_ui: bool, + controls: &mut VideoControls, + ) -> Result<()> { + let output = self + .surface + .get_current_texture() + .context("Failed to get surface texture")?; + + let view = output + .texture + .create_view(&wgpu::TextureViewDescriptor::default()); + + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Render Encoder"), + }); + + // Update vertex buffer if we have a screen rect + if let Some(rect) = screen_rect { + let vertices = self.create_vertices(rect); + self.queue + .write_buffer(&self.vertex_buffer, 0, bytemuck::cast_slice(&vertices)); + } + + // Render video + { + let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: Some("Video Render Pass"), + color_attachments: &[Some(wgpu::RenderPassColorAttachment { + view: &view, + resolve_target: None, + ops: wgpu::Operations { + load: wgpu::LoadOp::Clear(wgpu::Color::TRANSPARENT), + store: wgpu::StoreOp::Store, + }, + })], + depth_stencil_attachment: None, + occlusion_query_set: None, + timestamp_writes: None, + }); + + if screen_rect.is_some() { + render_pass.set_pipeline(&self.render_pipeline); + render_pass.set_bind_group(0, &self.bind_group, &[]); + render_pass.set_vertex_buffer(0, self.vertex_buffer.slice(..)); + render_pass.draw(0..6, 0..1); + } + } + + // Render egui UI if interactive mode + if show_ui { + let raw_input = self.egui_state.take_egui_input(&self.window); + let full_output = self.egui_ctx.run(raw_input, |ctx| { + controls.ui(ctx); + }); + + self.egui_state + .handle_platform_output(&self.window, full_output.platform_output); + + let clipped_primitives = self + .egui_ctx + .tessellate(full_output.shapes, full_output.pixels_per_point); + + let screen_descriptor = ScreenDescriptor { + size_in_pixels: [self.config.width, self.config.height], + pixels_per_point: full_output.pixels_per_point, + }; + + // Update egui textures + for (id, delta) in &full_output.textures_delta.set { + self.egui_renderer + .update_texture(&self.device, &self.queue, *id, delta); + } + + // Upload egui primitives + self.egui_renderer.update_buffers( + &self.device, + &self.queue, + &mut encoder, + &clipped_primitives, + &screen_descriptor, + ); + + // Render egui + { + let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: Some("egui Render Pass"), + color_attachments: &[Some(wgpu::RenderPassColorAttachment { + view: &view, + resolve_target: None, + ops: wgpu::Operations { + load: wgpu::LoadOp::Load, // Preserve video + store: wgpu::StoreOp::Store, + }, + })], + depth_stencil_attachment: None, + occlusion_query_set: None, + timestamp_writes: None, + }); + + self.egui_renderer.render( + &mut render_pass, + &clipped_primitives, + &screen_descriptor, + ); + } + + // Free egui textures + for id in &full_output.textures_delta.free { + self.egui_renderer.free_texture(id); + } + } + + self.queue.submit(std::iter::once(encoder.finish())); + output.present(); + + Ok(()) + } + + /// Create vertices for a screen rect quad + fn create_vertices(&self, rect: &ScreenRect) -> [Vertex; 6] { + // Convert screen coordinates to NDC (-1 to 1) + let screen_width = self.config.width as f32; + let screen_height = self.config.height as f32; + + let left = (rect.x / screen_width) * 2.0 - 1.0; + let right = ((rect.x + rect.width) / screen_width) * 2.0 - 1.0; + let top = 1.0 - (rect.y / screen_height) * 2.0; + let bottom = 1.0 - ((rect.y + rect.height) / screen_height) * 2.0; + + // Two triangles forming a quad + [ + Vertex { + position: [left, top], + tex_coord: [0.0, 0.0], + }, + Vertex { + position: [right, top], + tex_coord: [1.0, 0.0], + }, + Vertex { + position: [left, bottom], + tex_coord: [0.0, 1.0], + }, + Vertex { + position: [left, bottom], + tex_coord: [0.0, 1.0], + }, + Vertex { + position: [right, top], + tex_coord: [1.0, 0.0], + }, + Vertex { + position: [right, bottom], + tex_coord: [1.0, 1.0], + }, + ] + } +} + +/// WGSL shader for video rendering +const VIDEO_SHADER: &str = r#" +struct VertexInput { + @location(0) position: vec2, + @location(1) tex_coord: vec2, +} + +struct VertexOutput { + @builtin(position) clip_position: vec4, + @location(0) tex_coord: vec2, +} + +@vertex +fn vs_main(in: VertexInput) -> VertexOutput { + var out: VertexOutput; + out.clip_position = vec4(in.position, 0.0, 1.0); + out.tex_coord = in.tex_coord; + return out; +} + +@group(0) @binding(0) var t_texture: texture_2d; +@group(0) @binding(1) var s_sampler: sampler; + +@fragment +fn fs_main(in: VertexOutput) -> @location(0) vec4 { + return textureSample(t_texture, s_sampler, in.tex_coord); +} +"#; diff --git a/projects/nms-cockpit-video/overlay/src/ui.rs b/projects/nms-cockpit-video/overlay/src/ui.rs new file mode 100644 index 0000000..73539ae --- /dev/null +++ b/projects/nms-cockpit-video/overlay/src/ui.rs @@ -0,0 +1,258 @@ +//! Video Controls UI +//! +//! egui-based video player controls for the NMS overlay. + +use egui::{Align2, Color32, FontId, RichText, Rounding, Stroke, Vec2}; + +/// Video player controls state +pub struct VideoControls { + /// URL input text + url_input: String, + /// Whether video is playing + is_playing: bool, + /// Current position in milliseconds + position_ms: u64, + /// Total duration in milliseconds + duration_ms: u64, + /// Pending load request + pending_load: Option, + /// Pending play request + pending_play: bool, + /// Pending pause request + pending_pause: bool, + /// Pending seek request (position in ms) + pending_seek: Option, + /// Whether the seek bar is being dragged + seeking: bool, + /// Seek position while dragging + seek_position: f32, + /// Whether connected to daemon + daemon_connected: bool, +} + +impl VideoControls { + pub fn new() -> Self { + Self { + url_input: String::new(), + is_playing: false, + position_ms: 0, + duration_ms: 0, + pending_load: None, + pending_play: false, + pending_pause: false, + pending_seek: None, + seeking: false, + seek_position: 0.0, + daemon_connected: false, + } + } + + /// Set the current playback position + pub fn set_position(&mut self, position_ms: u64) { + if !self.seeking { + self.position_ms = position_ms; + } + } + + /// Set the total duration + pub fn set_duration(&mut self, duration_ms: u64) { + self.duration_ms = duration_ms; + } + + /// Set playing state + #[allow(dead_code)] + pub fn set_playing(&mut self, playing: bool) { + self.is_playing = playing; + } + + /// Set daemon connection status + pub fn set_daemon_connected(&mut self, connected: bool) { + self.daemon_connected = connected; + } + + /// Toggle play/pause + pub fn toggle_play_pause(&mut self) { + if self.is_playing { + self.pending_pause = true; + } else { + self.pending_play = true; + } + } + + /// Seek relative to current position + pub fn seek_relative(&mut self, delta_ms: i64) { + let new_pos = (self.position_ms as i64 + delta_ms).max(0) as u64; + let clamped = new_pos.min(self.duration_ms); + self.pending_seek = Some(clamped); + } + + /// Take pending load request + pub fn take_load_request(&mut self) -> Option { + self.pending_load.take() + } + + /// Take pending play request + pub fn take_play_request(&mut self) -> bool { + std::mem::take(&mut self.pending_play) + } + + /// Take pending pause request + pub fn take_pause_request(&mut self) -> bool { + std::mem::take(&mut self.pending_pause) + } + + /// Take pending seek request + pub fn take_seek_request(&mut self) -> Option { + self.pending_seek.take() + } + + /// Render the controls UI + pub fn ui(&mut self, ctx: &egui::Context) { + // Semi-transparent panel at the bottom + egui::Area::new(egui::Id::new("video_controls")) + .anchor(Align2::CENTER_BOTTOM, Vec2::new(0.0, -20.0)) + .show(ctx, |ui| { + egui::Frame::none() + .fill(Color32::from_rgba_unmultiplied(20, 20, 30, 220)) + .rounding(Rounding::same(8.0)) + .stroke(Stroke::new(1.0, Color32::from_rgb(60, 60, 80))) + .inner_margin(16.0) + .show(ui, |ui| { + ui.set_min_width(600.0); + self.render_controls(ui); + }); + }); + } + + fn render_controls(&mut self, ui: &mut egui::Ui) { + ui.vertical(|ui| { + // URL input row + ui.horizontal(|ui| { + // Connection status dot + let color = if self.daemon_connected { + Color32::from_rgb(80, 200, 80) + } else { + Color32::from_rgb(200, 80, 80) + }; + let (rect, response) = + ui.allocate_exact_size(Vec2::new(10.0, 10.0), egui::Sense::hover()); + ui.painter().circle_filled(rect.center(), 4.0, color); + response.on_hover_text(if self.daemon_connected { + "Connected to daemon" + } else { + "Daemon not connected" + }); + ui.add_space(4.0); + + ui.label(RichText::new("URL:").color(Color32::WHITE)); + let response = ui.add( + egui::TextEdit::singleline(&mut self.url_input) + .desired_width(400.0) + .hint_text("Enter video URL or file path...") + .text_color(Color32::WHITE), + ); + + if (ui + .button(RichText::new("Load").color(Color32::WHITE)) + .clicked() + || (response.lost_focus() && ui.input(|i| i.key_pressed(egui::Key::Enter)))) + && !self.url_input.is_empty() + { + self.pending_load = Some(self.url_input.clone()); + } + }); + + ui.add_space(8.0); + + // Playback controls row + ui.horizontal(|ui| { + // Play/Pause button + let play_pause_text = if self.is_playing { "||" } else { ">" }; + if ui + .button( + RichText::new(play_pause_text) + .font(FontId::proportional(20.0)) + .color(Color32::WHITE), + ) + .clicked() + { + self.toggle_play_pause(); + } + + ui.add_space(8.0); + + // Time display + let current_time = format_time(self.position_ms); + let total_time = format_time(self.duration_ms); + ui.label( + RichText::new(format!("{} / {}", current_time, total_time)) + .color(Color32::LIGHT_GRAY) + .font(FontId::monospace(14.0)), + ); + + ui.add_space(8.0); + + // Seek bar + let _progress = if self.duration_ms > 0 { + if self.seeking { + self.seek_position + } else { + self.position_ms as f32 / self.duration_ms as f32 + } + } else { + 0.0 + }; + + let slider_response = ui.add( + egui::Slider::new(&mut self.seek_position, 0.0..=1.0) + .show_value(false) + .trailing_fill(true), + ); + + // Update seek position from actual position when not dragging + if !self.seeking && self.duration_ms > 0 { + self.seek_position = self.position_ms as f32 / self.duration_ms as f32; + } + + // Handle dragging + if slider_response.drag_started() { + self.seeking = true; + } + if slider_response.drag_stopped() { + self.seeking = false; + let seek_ms = (self.seek_position * self.duration_ms as f32) as u64; + self.pending_seek = Some(seek_ms); + } + + ui.add_space(8.0); + + // Keyboard shortcuts hint + ui.label( + RichText::new("[Space] Play/Pause [] Seek [F9] Hide") + .color(Color32::GRAY) + .small(), + ); + }); + }); + } +} + +impl Default for VideoControls { + fn default() -> Self { + Self::new() + } +} + +/// Format milliseconds as MM:SS or HH:MM:SS +fn format_time(ms: u64) -> String { + let total_secs = ms / 1000; + let hours = total_secs / 3600; + let mins = (total_secs % 3600) / 60; + let secs = total_secs % 60; + + if hours > 0 { + format!("{:02}:{:02}:{:02}", hours, mins, secs) + } else { + format!("{:02}:{:02}", mins, secs) + } +} diff --git a/projects/nms-cockpit-video/overlay/src/video.rs b/projects/nms-cockpit-video/overlay/src/video.rs new file mode 100644 index 0000000..e9d62b0 --- /dev/null +++ b/projects/nms-cockpit-video/overlay/src/video.rs @@ -0,0 +1,116 @@ +//! Video Frame Reader +//! +//! Reads video frames from shared memory written by the daemon. + +use itk_shmem::FrameBuffer; +use tracing::debug; + +/// Name of the shared memory buffer for video frames +const VIDEO_BUFFER_NAME: &str = "itk_video_frames"; + +/// Default video dimensions (720p) +const DEFAULT_WIDTH: u32 = 1280; +const DEFAULT_HEIGHT: u32 = 720; + +/// Video frame reader - reads frames from shared memory +pub struct VideoFrameReader { + /// Frame buffer connection + buffer: Option, + /// Last frame data (kept for returning reference) + last_frame: Vec, + /// Last presentation timestamp + last_pts: u64, + /// Cached duration from header + duration_ms: u64, + /// Whether we've attempted connection + connection_attempted: bool, +} + +impl VideoFrameReader { + /// Create a new video frame reader + pub fn new() -> Self { + Self { + buffer: None, + last_frame: Vec::new(), + last_pts: 0, + duration_ms: 0, + connection_attempted: false, + } + } + + /// Check if connected to the frame buffer + pub fn is_connected(&self) -> bool { + self.buffer.is_some() + } + + /// Get the last presentation timestamp in milliseconds + pub fn last_pts_ms(&self) -> u64 { + self.last_pts + } + + /// Get the total duration in milliseconds (0 if unknown) + pub fn duration_ms(&self) -> u64 { + self.duration_ms + } + + /// Try to connect to the shared memory buffer + fn try_connect(&mut self) { + if self.buffer.is_some() { + return; + } + + match FrameBuffer::open(VIDEO_BUFFER_NAME, DEFAULT_WIDTH, DEFAULT_HEIGHT) { + Ok(buffer) => { + debug!("Connected to video frame buffer"); + // Pre-allocate frame buffer + self.last_frame = vec![0u8; buffer.frame_size()]; + self.buffer = Some(buffer); + }, + Err(e) => { + if !self.connection_attempted { + debug!(?e, "Frame buffer not available yet"); + self.connection_attempted = true; + } + }, + } + } + + /// Try to read a new frame from shared memory + /// + /// Returns Some(&[u8]) if a new frame is available, None otherwise. + pub fn try_read_frame(&mut self) -> Option<&[u8]> { + // Try to connect if not connected + if self.buffer.is_none() { + self.try_connect(); + } + + // Read frame if connected + if let Some(ref buffer) = self.buffer { + // Update cached duration (cheap atomic load) + let dur = buffer.duration_ms(); + if dur > 0 { + self.duration_ms = dur; + } + + match buffer.read_frame(self.last_pts, &mut self.last_frame) { + Ok((pts_ms, data_changed)) => { + if data_changed { + self.last_pts = pts_ms; + return Some(&self.last_frame); + } + }, + Err(e) => { + debug!(?e, "Failed to read frame"); + }, + } + } + + None + } +} + +impl Default for VideoFrameReader { + fn default() -> Self { + Self::new() + } +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..2f9d456 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,66 @@ +# Rust Formatting Configuration +# This file defines explicit formatting rules for the project. +# Run with: cargo fmt +# Check with: cargo fmt --check +# +# Note: We only use stable rustfmt options here. Unstable options would require +# nightly toolchain and are commented out below for reference. + +# Edition to use for parsing (matches Cargo.toml) +edition = "2024" + +# Maximum line width +max_width = 100 + +# Use spaces for indentation +hard_tabs = false +tab_spaces = 4 + +# Import ordering +reorder_imports = true +reorder_modules = true + +# Keep newlines for readability +newline_style = "Unix" + +# Function signatures +fn_params_layout = "Tall" + +# Match arm formatting +match_block_trailing_comma = true + +# Use field init shorthand where possible +use_field_init_shorthand = true + +# Merge derives into a single attribute +merge_derives = true + +# Use try! shorthand (?) +use_try_shorthand = true + +# Chain formatting +chain_width = 60 + +# Array formatting +array_width = 60 + +# ---------------------------------------------------------------------------- +# Unstable options (require nightly toolchain) +# Uncomment these if using nightly rustfmt: +# ---------------------------------------------------------------------------- +# imports_granularity = "Module" +# group_imports = "StdExternalCrate" +# wrap_comments = false +# normalize_comments = false +# normalize_doc_attributes = true +# format_code_in_doc_comments = true +# struct_lit_single_line = true +# fn_single_line = false +# where_single_line = false +# overflow_delimited_expr = true +# enum_discrim_align_threshold = 20 +# match_arm_blocks = true +# brace_style = "SameLineWhere" +# control_brace_style = "AlwaysSameLine" +# blank_lines_upper_bound = 2 +# blank_lines_lower_bound = 0 diff --git a/tools/cli/agents/run_claude.bat b/tools/cli/agents/run_claude.bat new file mode 100644 index 0000000..b19ee40 --- /dev/null +++ b/tools/cli/agents/run_claude.bat @@ -0,0 +1,53 @@ +@echo off +REM run_claude.bat - Start Claude Code with Node.js 22.16.0 + +setlocal enabledelayedexpansion + +echo Starting Claude Code with Node.js 22.16.0 + +REM Check if nvm-windows is installed +where nvm >nul 2>&1 +if %ERRORLEVEL% neq 0 ( + echo NVM for Windows not found. Please install it first. + echo Visit: https://github.com/coreybutler/nvm-windows + exit /b 1 +) + +REM Switch to Node.js 22.16.0 +echo Switching to Node.js 22.16.0... +call nvm use 22.16.0 +if %ERRORLEVEL% neq 0 ( + echo Node.js 22.16.0 not installed. Installing... + call nvm install 22.16.0 + call nvm use 22.16.0 +) + +REM Verify Node version +for /f "tokens=*" %%i in ('node --version') do set NODE_VERSION=%%i +echo Using Node.js: %NODE_VERSION% + +REM Note: Security validation is handled by gh-validator binary +REM via PATH shadowing. No explicit hook initialization needed. + +REM Ask about unattended mode +echo. +echo Claude Code Configuration +echo. +echo Would you like to run Claude Code in unattended mode? +echo This will allow Claude to execute commands without asking for approval. +echo. +choice /c YN /n /m "Use unattended mode? (Y/N): " +if %ERRORLEVEL% equ 1 ( + echo. + echo Starting Claude Code in UNATTENDED mode --dangerously-skip-permissions... + echo WARNING: Claude will execute commands without asking for approval! + echo. + claude --dangerously-skip-permissions +) else ( + echo. + echo Starting Claude Code in NORMAL mode with approval prompts... + echo. + claude +) + +endlocal diff --git a/tools/cli/agents/run_claude.sh b/tools/cli/agents/run_claude.sh new file mode 100755 index 0000000..5a391e1 --- /dev/null +++ b/tools/cli/agents/run_claude.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# run_claude.sh - Start Claude Code with Node.js 22.16.0 + +set -e + +echo "🚀 Starting Claude Code with Node.js 22.16.0" + +# Load NVM if it exists +if [ -s "$HOME/.nvm/nvm.sh" ]; then + echo "📦 Loading NVM..." + # shellcheck source=/dev/null + source "$HOME/.nvm/nvm.sh" +elif [ -s "/usr/local/share/nvm/nvm.sh" ]; then + echo "📦 Loading NVM (system-wide)..." + # shellcheck source=/dev/null + source "/usr/local/share/nvm/nvm.sh" +else + echo "❌ NVM not found. Please install NVM first." + echo "Visit: https://github.com/nvm-sh/nvm#installation-and-update" + exit 1 +fi + +# Use Node.js 22.16.0 +echo "🔧 Switching to Node.js 22.16.0..." +nvm use 22.16.0 + +# Verify Node version +NODE_VERSION=$(node --version) +echo "✅ Using Node.js: $NODE_VERSION" + +# Note: Security validation is handled by gh-validator binary at ~/.local/bin/gh +# via PATH shadowing. No explicit hook initialization needed. + +# Ask about unattended mode +echo "🤖 Claude Code Configuration" +echo "" +echo "Would you like to run Claude Code in unattended mode?" +echo "This will allow Claude to execute commands without asking for approval." +echo "" +read -p "Use unattended mode? (y/N): " -n 1 -r +echo "" + +# Start Claude Code with appropriate flags +if [[ $REPLY =~ ^[Yy]$ ]]; then + echo "⚡ Starting Claude Code in UNATTENDED mode (--dangerously-skip-permissions)..." + echo "⚠️ Claude will execute commands without asking for approval!" + echo "" + claude --dangerously-skip-permissions +else + echo "🔒 Starting Claude Code in NORMAL mode (with approval prompts)..." + echo "" + claude +fi diff --git a/tools/cli/agents/run_codex.sh b/tools/cli/agents/run_codex.sh new file mode 100755 index 0000000..a2f1417 --- /dev/null +++ b/tools/cli/agents/run_codex.sh @@ -0,0 +1,215 @@ +#!/bin/bash +# run_codex.sh - Start Codex CLI for AI-powered code generation + +set -e + +echo "🚀 Starting Codex CLI" + +# Check if codex CLI is available +if ! command -v codex &> /dev/null; then + echo "❌ codex CLI not found. Installing..." + echo "" + echo "Please install Codex with:" + echo " npm install -g @openai/codex" + echo "" + echo "Or in the container version which has it pre-installed:" + echo " ./tools/cli/containers/run_codex_container.sh" + exit 1 +fi + +# Check for auth file +AUTH_FILE="$HOME/.codex/auth.json" +if [ ! -f "$AUTH_FILE" ]; then + echo "❌ Codex authentication not found at $AUTH_FILE" + echo "" + echo "Please authenticate with Codex first:" + echo " codex login" + echo "" + echo "Or run the container version with mounted auth:" + echo " ./tools/cli/containers/run_codex_container.sh" + exit 1 +fi + +echo "✅ Codex CLI found and authenticated" + +# Note: Security validation is handled by gh-validator binary at ~/.local/bin/gh +# via PATH shadowing. No explicit hook initialization needed. + +# Parse command line arguments +MODE="interactive" +QUERY="" +CONTEXT="" +USE_EXEC=false +BYPASS_SANDBOX=false +AUTO_MODE=false + +while [[ $# -gt 0 ]]; do + case $1 in + -q|--query) + QUERY="$2" + MODE="exec" + USE_EXEC=true + shift 2 + ;; + -c|--context) + CONTEXT="$2" + shift 2 + ;; + --auto) + AUTO_MODE=true + shift + ;; + --bypass-sandbox) + BYPASS_SANDBOX=true + shift + ;; + -h|--help) + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " -q, --query Execute non-interactively with specified prompt" + echo " -c, --context Add context from file" + echo " --auto Auto-approve mode (uses --full-auto for safer execution)" + echo " --bypass-sandbox Use --dangerously-bypass-approvals-and-sandbox (DANGEROUS!)" + echo " -h, --help Show this help message" + echo "" + echo "Interactive Mode (default):" + echo " Start an interactive session with Codex" + echo "" + echo "Non-Interactive Execution Mode:" + echo " $0 -q 'Write a Python function to calculate fibonacci'" + echo " $0 -q 'Refactor this code' -c existing_code.py" + echo "" + echo "Safe Auto Mode (workspace-write sandbox):" + echo " $0 -q 'Build a web server' --auto" + echo "" + echo "Dangerous Mode (no sandbox - USE WITH CAUTION!):" + echo " $0 -q 'System task' --bypass-sandbox" + echo "" + echo "Note: Codex requires authentication via 'codex login' first." + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use -h or --help for usage information" + exit 1 + ;; + esac +done + +# Execute based on mode +if [ "$USE_EXEC" = true ] && [ -n "$QUERY" ]; then + echo "📝 Running non-interactive execution..." + + # Build the prompt with context if provided + FULL_PROMPT="$QUERY" + if [ -n "$CONTEXT" ] && [ -f "$CONTEXT" ]; then + echo "📄 Including context from: $CONTEXT" + CONTEXT_CONTENT=$(cat "$CONTEXT") + FULL_PROMPT="Context from $CONTEXT: +\`\`\` +$CONTEXT_CONTENT +\`\`\` + +Task: $QUERY" + fi + + # Determine execution mode + if [ "$BYPASS_SANDBOX" = true ]; then + # Ask for confirmation unless explicitly bypassed + if [ "$AUTO_MODE" != true ]; then + echo "" + echo "⚠️ WARNING: --dangerously-bypass-approvals-and-sandbox mode" + echo "This will execute commands WITHOUT ANY SANDBOXING or approval!" + echo "Only use this in already-sandboxed environments." + echo "" + read -r -p "Are you ABSOLUTELY SURE you want to continue? (yes/no): " confirm + if [ "$confirm" != "yes" ]; then + echo "❌ Aborted for safety." + exit 1 + fi + fi + + echo "⚡ Executing with --dangerously-bypass-approvals-and-sandbox..." + echo "" + echo "$FULL_PROMPT" | codex exec --dangerously-bypass-approvals-and-sandbox - + + elif [ "$AUTO_MODE" = true ]; then + echo "🔐 Executing with --full-auto (sandboxed workspace-write)..." + echo "" + echo "$FULL_PROMPT" | codex exec --full-auto - + + else + # Default: interactive approval mode with workspace-write sandbox + echo "🔒 Executing with workspace-write sandbox (approval required)..." + echo "" + echo "$FULL_PROMPT" | codex exec --sandbox workspace-write - + fi + +elif [ "$MODE" = "interactive" ]; then + # Only show note if no arguments provided + if [ $# -eq 0 ]; then + echo "🤖 Codex Configuration" + echo "" + echo "ℹ️ Note: Codex is an AI-powered code generation tool by OpenAI." + echo "It can help with code completion, generation, and refactoring." + echo "" + + # Ask about sandbox preference for interactive mode + echo "Choose sandbox mode for this session:" + echo "1) Standard (with approvals and sandbox)" + echo "2) Auto mode (--full-auto: workspace-write sandbox, no approvals)" + echo "3) Dangerous (--dangerously-bypass-approvals-and-sandbox)" + echo "" + read -r -p "Enter choice (1-3) [default: 1]: " choice + + case "$choice" in + 2) + echo "🔐 Starting with --full-auto mode..." + codex --full-auto + ;; + 3) + echo "" + echo "⚠️ WARNING: This disables ALL safety features!" + read -r -p "Are you sure? (yes/no): " confirm + if [ "$confirm" = "yes" ]; then + echo "⚡ Starting with --dangerously-bypass-approvals-and-sandbox..." + codex --dangerously-bypass-approvals-and-sandbox + else + echo "✅ Starting standard interactive mode..." + codex + fi + ;; + *) + echo "✅ Starting standard interactive mode..." + codex + ;; + esac + else + # Arguments were provided but no query - apply flags to interactive mode + echo "🔄 Starting interactive session with provided flags..." + echo "💡 Tips:" + echo " - Use 'help' to see available commands" + echo " - Use 'exit' or Ctrl+C to quit" + echo "" + + # Build command with any flags that were provided + CODEX_CMD="codex" + if [ "$AUTO_MODE" = true ]; then + echo " - Running with --full-auto mode" + CODEX_CMD="$CODEX_CMD --full-auto" + fi + if [ "$BYPASS_SANDBOX" = true ]; then + echo " - ⚠️ Running with --dangerously-bypass-approvals-and-sandbox" + CODEX_CMD="$CODEX_CMD --dangerously-bypass-approvals-and-sandbox" + fi + echo "" + + # Execute with the built command + $CODEX_CMD + fi +else + echo "❌ Error: Query is required for exec mode" + echo "Use -h or --help for usage information" + exit 1 +fi diff --git a/tools/cli/agents/run_crush.sh b/tools/cli/agents/run_crush.sh new file mode 100755 index 0000000..b2fd96f --- /dev/null +++ b/tools/cli/agents/run_crush.sh @@ -0,0 +1,217 @@ +#!/bin/bash +# run_crush.sh - Start Crush CLI for fast code generation + +set -e + +echo "⚡ Starting Crush CLI (Fast Code Generation)" + +# Auto-load .env file if it exists and OPENROUTER_API_KEY is not set +if [ -z "$OPENROUTER_API_KEY" ] && [ -f ".env" ]; then + echo "📄 Loading environment from .env file..." + set -a # Enable auto-export + source .env + set +a # Disable auto-export +fi + +# Check for API key +if [ -z "$OPENROUTER_API_KEY" ]; then + echo "❌ OPENROUTER_API_KEY not set. Please export your API key:" + echo " export OPENROUTER_API_KEY='your-key-here'" + exit 1 +fi + +echo "✅ Using OpenRouter API key: ****${OPENROUTER_API_KEY: -4}" + +# Note: Security validation is handled by gh-validator binary at ~/.local/bin/gh +# via PATH shadowing. No explicit hook initialization needed. + +# Check if crush CLI is available +if ! command -v crush &> /dev/null; then + echo "⚠️ crush CLI not found. Installing..." + echo "" + + # Install crush from Charm Bracelet + if command -v go &> /dev/null; then + echo "📦 Installing crush via go..." + go install github.com/charmbracelet/crush@latest + elif command -v brew &> /dev/null; then + echo "📦 Installing crush via brew..." + brew install charmbracelet/tap/crush + else + echo "❌ Neither go nor brew found. Please install crush manually:" + echo " https://github.com/charmbracelet/crush" + exit 1 + fi +fi + +# Ask about unattended mode for interactive sessions +UNATTENDED_FLAG="" +if [ $# -eq 0 ]; then + # Only ask if no arguments provided (interactive mode) + echo "🤖 Crush Configuration" + echo "" + echo "Would you like to run Crush in unattended mode?" + echo "This will allow Crush to execute commands without asking for approval." + echo "" + read -p "Use unattended mode? (y/N): " -n 1 -r + echo "" + + if [[ $REPLY =~ ^[Yy]$ ]]; then + UNATTENDED_FLAG="-y" + echo "⚡ Will run Crush in UNATTENDED mode (--yolo)..." + echo "⚠️ Crush will execute commands without asking for approval!" + echo "" + else + echo "🔒 Will run Crush in NORMAL mode (with approval prompts)..." + echo "" + fi +fi + +# Parse command line arguments +MODE="interactive" +QUERY="" +STYLE="concise" +CONVERT_TO="" +CODE_FILE="" + +while [[ $# -gt 0 ]]; do + case $1 in + -q|--query) + QUERY="$2" + MODE="single" + shift 2 + ;; + -s|--style) + STYLE="$2" + shift 2 + ;; + -e|--explain) + MODE="explain" + CODE_FILE="$2" + shift 2 + ;; + -c|--convert) + MODE="convert" + CODE_FILE="$2" + shift 2 + ;; + -t|--to) + CONVERT_TO="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " -q, --query Single query mode with specified prompt" + echo " -s, --style