diff --git a/.github/workflows/advanced-backend-unit-tests.yml b/.github/workflows/advanced-backend-unit-tests.yml new file mode 100644 index 00000000..0b6e1e3e --- /dev/null +++ b/.github/workflows/advanced-backend-unit-tests.yml @@ -0,0 +1,59 @@ +name: Advanced Backend Unit Tests + +on: + pull_request: + paths: + - 'backends/advanced/**' + - '.github/workflows/advanced-backend-unit-tests.yml' + - 'Makefile.unittests' + push: + branches: + - dev + - main + paths: + - 'backends/advanced/**' + - '.github/workflows/advanced-backend-unit-tests.yml' + - 'Makefile.unittests' + workflow_dispatch: + +permissions: + contents: read + +jobs: + advanced-backend-unit-tests: + name: Run advanced backend unit tests + runs-on: ubuntu-latest + timeout-minutes: 20 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Mock config files for unit tests + run: | + cat > config.env << 'EOF' + DEPLOYMENT_MODE=docker-compose + DOMAIN=localhost + CONTAINER_REGISTRY=local + SPEAKER_NODE=localhost + INFRASTRUCTURE_NAMESPACE=infrastructure + APPLICATION_NAMESPACE=application + EOF + + mkdir -p backends/advanced + cat > backends/advanced/.env << 'EOF' + AUTH_SECRET_KEY=test-auth-secret + ADMIN_PASSWORD=test-admin-password + ADMIN_EMAIL=test-admin@example.com + EOF + + - name: Run advanced backend unit tests + run: make -f Makefile.unittests test-unit diff --git a/Makefile.unittests b/Makefile.unittests new file mode 100644 index 00000000..edfded4c --- /dev/null +++ b/Makefile.unittests @@ -0,0 +1,27 @@ +.PHONY: help test-unit robot-test robot-all + +# Minimal auth env required during import-time test collection. +UNIT_TEST_ENV := AUTH_SECRET_KEY=test-auth-secret ADMIN_PASSWORD=test-admin-password ADMIN_EMAIL=test-admin@example.com + +help: + @echo "Repository-level unit/robot test targets" + @echo " make -f Makefile.unittests test-unit Run advanced backend Python unit tests" + @echo "" + @echo "Robot test passthrough targets (runs via tests/)" + @echo " make -f Makefile.unittests robot-test [CONFIG=deepgram-openai.yml]" + @echo " make -f Makefile.unittests robot-all [CONFIG=deepgram-openai.yml]" + +test-unit: + @echo "Running advanced backend Python unit tests..." + @cd backends/advanced && $(UNIT_TEST_ENV) uv run --group test pytest tests/unit tests/test_init_llm_setup.py + @echo "Advanced backend Python unit tests completed" + +robot-test: + @echo "Starting/rebuilding Robot test containers from tests/..." + @$(MAKE) -C tests start-rebuild $(if $(CONFIG),CONFIG=$(CONFIG),) + @echo "Running Robot test workflow from tests/..." + @$(MAKE) -C tests test $(if $(CONFIG),CONFIG=$(CONFIG),) + +robot-all: + @echo "Running all Robot suites from tests/..." + @$(MAKE) -C tests all $(if $(CONFIG),CONFIG=$(CONFIG),) diff --git a/README.md b/README.md index 7e342210..e06938ea 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,22 @@ cd app npm start ``` +### Unit + Robot Tests +```bash +# Run advanced backend Python unit tests +make -f Makefile.unittests test-unit + +# Run Robot test workflow (includes start-rebuild automatically) +make -f Makefile.unittests robot-test + +# Run both unit + Robot tests +make -f Makefile.unittests test-unit && make -f Makefile.unittests robot-test + +# Optional Robot config override +make -f Makefile.unittests robot-test CONFIG=deepgram-openai.yml +``` + + ### Health Checks ```bash # Backend health diff --git a/backends/advanced/init.py b/backends/advanced/init.py index aad7ff0e..f5010d47 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -164,6 +164,92 @@ def prompt_with_existing_masked(self, prompt_text: str, env_key: str, placeholde default=default ) + def _get_model_def(self, config: Dict[str, Any], model_name: str) -> Dict[str, Any]: + """Get a model definition by name from config.yml.""" + models = config.get("models", []) + if not isinstance(models, list): + return {} + return next((m for m in models if m.get("name") == model_name), {}) + + def _infer_embedding_dimensions(self, model_name: str, fallback: int = 1536) -> int: + """Infer embedding dimensions for common models.""" + known_dimensions = { + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + "nomic-embed-text-v1.5": 768, + "nomic-embed-text:latest": 768, + } + return known_dimensions.get(model_name, fallback) + + def _upsert_openai_models( + self, + api_key: str, + base_url: str, + llm_model_name: str, + embedding_model_name: str, + ) -> None: + """Update or create openai-llm/openai-embed in config.yml and set defaults.""" + config = self.config_manager.get_full_config() + models = config.get("models", []) + if not isinstance(models, list): + models = [] + + openai_llm = self._get_model_def(config, "openai-llm") + openai_embed = self._get_model_def(config, "openai-embed") + + llm_params = openai_llm.get("model_params", {}) + if not isinstance(llm_params, dict): + llm_params = {} + llm_params.setdefault("temperature", 0.2) + llm_params.setdefault("max_tokens", 2000) + + embedding_dimensions = openai_embed.get("embedding_dimensions") + if not isinstance(embedding_dimensions, int) or embedding_dimensions <= 0: + embedding_dimensions = self._infer_embedding_dimensions(embedding_model_name) + + llm_payload = { + "name": "openai-llm", + "description": "OpenAI/OpenAI-compatible LLM", + "model_type": "llm", + "model_provider": "openai", + "api_family": "openai", + "model_name": llm_model_name, + "model_url": base_url, + "api_key": api_key, + "model_params": llm_params, + "model_output": "json", + } + embed_payload = { + "name": "openai-embed", + "description": "OpenAI/OpenAI-compatible embeddings", + "model_type": "embedding", + "model_provider": "openai", + "api_family": "openai", + "model_name": embedding_model_name, + "model_url": base_url, + "api_key": api_key, + "embedding_dimensions": embedding_dimensions, + "model_output": "vector", + } + + def upsert_model(payload: Dict[str, Any]): + for idx, model in enumerate(models): + if model.get("name") == payload["name"]: + models[idx] = {**model, **payload} + return + models.append(payload) + + upsert_model(llm_payload) + upsert_model(embed_payload) + + config["models"] = models + if "defaults" not in config or not isinstance(config["defaults"], dict): + config["defaults"] = {} + config["defaults"]["llm"] = "openai-llm" + config["defaults"]["embedding"] = "openai-embed" + + self.config_manager.save_full_config(config) def setup_authentication(self): """Configure authentication settings""" @@ -307,7 +393,7 @@ def setup_llm(self): self.console.print() choices = { - "1": "OpenAI (GPT-4, GPT-3.5 - requires API key)", + "1": "OpenAI / OpenAI-compatible (custom base URL, API key, model names)", "2": "Ollama (local models - runs locally)", "3": "Skip (no memory extraction)" } @@ -315,9 +401,21 @@ def setup_llm(self): choice = self.prompt_choice("Which LLM provider will you use?", choices, "1") if choice == "1": - self.console.print("[blue][INFO][/blue] OpenAI selected") + self.console.print("[blue][INFO][/blue] OpenAI/OpenAI-compatible selected") self.console.print("Get your API key from: https://platform.openai.com/api-keys") + existing_cfg = self.config_manager.get_full_config() + openai_llm = self._get_model_def(existing_cfg, "openai-llm") + openai_embed = self._get_model_def(existing_cfg, "openai-embed") + + default_base_url = openai_llm.get("model_url") or openai_embed.get("model_url") or "https://api.openai.com/v1" + default_llm_model = openai_llm.get("model_name") or "gpt-4o-mini" + default_embedding_model = openai_embed.get("model_name") or "text-embedding-3-small" + + base_url = self.prompt_value("OpenAI-compatible base URL", default_base_url) + llm_model_name = self.prompt_value("LLM model name", default_llm_model) + embedding_model_name = self.prompt_value("Embedding model name", default_embedding_model) + # Use the new masked prompt function api_key = self.prompt_with_existing_masked( prompt_text="OpenAI API key (leave empty to skip)", @@ -329,11 +427,19 @@ def setup_llm(self): if api_key: self.config["OPENAI_API_KEY"] = api_key - # Update config.yml to use OpenAI models - self.config_manager.update_config_defaults({"llm": "openai-llm", "embedding": "openai-embed"}) + # Update config.yml openai model definitions and defaults + self._upsert_openai_models( + api_key=api_key, + base_url=base_url, + llm_model_name=llm_model_name, + embedding_model_name=embedding_model_name, + ) self.console.print("[green][SUCCESS][/green] OpenAI configured in config.yml") self.console.print("[blue][INFO][/blue] Set defaults.llm: openai-llm") self.console.print("[blue][INFO][/blue] Set defaults.embedding: openai-embed") + self.console.print(f"[blue][INFO][/blue] Set openai-llm.model_url: {base_url}") + self.console.print(f"[blue][INFO][/blue] Set openai-llm.model_name: {llm_model_name}") + self.console.print(f"[blue][INFO][/blue] Set openai-embed.model_name: {embedding_model_name}") else: self.console.print("[yellow][WARNING][/yellow] No API key provided - memory extraction will not work") diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_worker.py b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_worker.py new file mode 100644 index 00000000..ae6578eb --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_worker.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +""" +Deepgram audio stream worker. + +Starts a consumer that reads from audio:stream:deepgram and transcribes audio. +""" + +import os + +from advanced_omi_backend.services.transcription.streaming_consumer import ( + StreamingTranscriptionConsumer, +) +from advanced_omi_backend.workers.base_audio_worker import BaseStreamWorker + + +class DeepgramStreamWorker(BaseStreamWorker): + """Deepgram audio stream worker implementation.""" + + def __init__(self): + super().__init__(service_name="Deepgram audio stream worker") + + def validate_config(self): + """Check that config.yml has Deepgram configured.""" + # The registry provider will load configuration from config.yml + api_key = os.getenv("DEEPGRAM_API_KEY") + if not api_key: + self.logger.warning("DEEPGRAM_API_KEY environment variable not set") + self.logger.warning("Ensure config.yml has a default 'stt' model configured for Deepgram") + self.logger.warning("Audio transcription will use alternative providers if configured in config.yml") + + def get_consumer(self, redis_client): + """Create streaming transcription consumer.""" + return StreamingTranscriptionConsumer(redis_client=redis_client) + + +if __name__ == "__main__": + DeepgramStreamWorker.start() diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_parakeet_worker.py b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_parakeet_worker.py new file mode 100644 index 00000000..548cbb5d --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_parakeet_worker.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +""" +Parakeet audio stream worker. + +Starts a consumer that reads from audio:stream:* and transcribes audio using Parakeet. +""" + +import os + +from advanced_omi_backend.services.transcription.streaming_consumer import ( + StreamingTranscriptionConsumer, +) +from advanced_omi_backend.workers.base_audio_worker import BaseStreamWorker + + +class ParakeetStreamWorker(BaseStreamWorker): + """Parakeet audio stream worker implementation.""" + + def __init__(self): + super().__init__(service_name="Parakeet audio stream worker") + + def validate_config(self): + """Check that config.yml has Parakeet configured.""" + # The registry provider will load configuration from config.yml + service_url = os.getenv("PARAKEET_ASR_URL") + if not service_url: + self.logger.warning("PARAKEET_ASR_URL environment variable not set") + self.logger.warning("Ensure config.yml has a default 'stt' model configured for Parakeet") + self.logger.warning("Audio transcription will use alternative providers if configured in config.yml") + + def get_consumer(self, redis_client): + """Create streaming transcription consumer.""" + return StreamingTranscriptionConsumer(redis_client=redis_client) + + +if __name__ == "__main__": + ParakeetStreamWorker.start() + diff --git a/backends/advanced/src/advanced_omi_backend/workers/base_audio_worker.py b/backends/advanced/src/advanced_omi_backend/workers/base_audio_worker.py new file mode 100644 index 00000000..da3fa89c --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/base_audio_worker.py @@ -0,0 +1,156 @@ +""" +Base audio stream worker. + +Provides a template for stream workers with consistent Redis connection, +signal handling, and error management. +""" + +import asyncio +import logging +import os +import signal +import sys +from abc import ABC, abstractmethod + +import redis.asyncio as redis + +# Configure basic logging if not already configured +if not logging.getLogger().handlers: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" + ) + +class BaseStreamWorker(ABC): + """ + Base class for audio stream workers using the Template Method pattern. + + Subclasses must implement: + - validate_config(): Check environment/config requirements + - get_consumer(redis_client): Return the specific consumer instance + """ + + def __init__(self, service_name: str): + self.service_name = service_name + self.logger = logging.getLogger(self.__class__.__name__) + self.redis_client = None + self.consumer = None + + @abstractmethod + def validate_config(self): + """ + Check required environment variables or configuration. + Should log warnings/errors if configuration is missing. + """ + pass + + @abstractmethod + def get_consumer(self, redis_client): + """ + Create and return the consumer instance. + + Args: + redis_client: Initialized Redis client + + Returns: + An instance complying with the BaseAudioStreamConsumer interface + """ + pass + + async def run(self): + """Main execution loop.""" + self.logger.info(f"πŸš€ Starting {self.service_name}") + + self.validate_config() + + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + + try: + self.redis_client = await redis.from_url( + redis_url, + encoding="utf-8", + decode_responses=False + ) + self.logger.info("Connected to Redis") + except Exception as e: + self.logger.error(f"Failed to connect to Redis: {e}") + sys.exit(1) + + try: + self.consumer = self.get_consumer(self.redis_client) + except Exception as e: + self.logger.error(f"Failed to initialize consumer: {e}") + await self.redis_client.aclose() + sys.exit(1) + + # Setup graceful shutdown + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def signal_handler(): + self.logger.info("Received stop signal, shutting down...") + stop_event.set() + + # Register signal handlers + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, signal_handler) + except NotImplementedError: + # Fallback for environments where add_signal_handler is not supported + # (e.g. some Windows environments or custom loops) + self.logger.warning(f"Could not add signal handler for {sig}") + + try: + self.logger.info(f"βœ… {self.service_name} ready") + + # Run consumer as a task + consumer_task = asyncio.create_task(self.consumer.start_consuming()) + stop_wait_task = asyncio.create_task(stop_event.wait()) + + # Wait for either the consumer to finish (error/done) or stop signal + done, pending = await asyncio.wait( + [consumer_task, stop_wait_task], + return_when=asyncio.FIRST_COMPLETED + ) + + # Check if consumer failed + if consumer_task in done: + try: + await consumer_task + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.error(f"Consumer task failed: {e}", exc_info=True) + # We continue to cleanup + + # Trigger stop on consumer + self.logger.info("Stopping consumer...") + await self.consumer.stop() + + # Ensure consumer task finishes if it was running + if consumer_task in pending: + try: + await consumer_task + except asyncio.CancelledError: + pass + except Exception as e: + # Ignore expected errors during shutdown if any + self.logger.debug(f"Consumer shutdown exception: {e}") + + except Exception as e: + self.logger.error(f"Worker runtime error: {e}", exc_info=True) + sys.exit(1) + finally: + if self.redis_client: + await self.redis_client.aclose() + self.logger.info(f"πŸ‘‹ {self.service_name} stopped") + + @classmethod + def start(cls): + """Entry point for script execution.""" + instance = cls() + try: + asyncio.run(instance.run()) + except KeyboardInterrupt: + # Handle keyboard interrupt outside the loop if it propagates + pass diff --git a/backends/advanced/tests/test_init_llm_setup.py b/backends/advanced/tests/test_init_llm_setup.py new file mode 100644 index 00000000..a5063114 --- /dev/null +++ b/backends/advanced/tests/test_init_llm_setup.py @@ -0,0 +1,136 @@ +"""Unit tests for OpenAI custom API setup/initialization flow in init.py. + +These tests verify that wizard setup can initialize OpenAI-compatible providers +with custom API endpoints (base URL), API keys, and model names. +""" + +import importlib.util +from pathlib import Path +from unittest.mock import MagicMock, call + + +_INIT_PATH = Path(__file__).resolve().parents[1] / "init.py" +_SPEC = importlib.util.spec_from_file_location("advanced_init", _INIT_PATH) +_MODULE = importlib.util.module_from_spec(_SPEC) +assert _SPEC and _SPEC.loader +_SPEC.loader.exec_module(_MODULE) +ChronicleSetup = _MODULE.ChronicleSetup + + +def _build_setup_with_mocks() -> ChronicleSetup: + """Create ChronicleSetup instance without running __init__ side effects.""" + setup = ChronicleSetup.__new__(ChronicleSetup) + setup.console = MagicMock() + setup.config = {} + setup.config_manager = MagicMock() + setup.print_section = MagicMock() + return setup + + +def test_upsert_openai_models_updates_existing_defs_and_defaults(): + """Checks OpenAI custom API config upsert in init flow. + + Verifies that init setup updates OpenAI model definitions with custom API + settings and switches defaults to those OpenAI entries. + """ + setup = _build_setup_with_mocks() + setup.config_manager.get_full_config.return_value = { + "defaults": {"llm": "local-llm", "embedding": "local-embed"}, + "models": [ + { + "name": "openai-llm", + "model_type": "llm", + "model_provider": "openai", + "model_name": "gpt-4o-mini", + "model_url": "https://api.openai.com/v1", + "api_key": "old-key", + "model_params": {"temperature": 0.3}, + }, + { + "name": "openai-embed", + "model_type": "embedding", + "model_provider": "openai", + "model_name": "text-embedding-3-small", + "model_url": "https://api.openai.com/v1", + "api_key": "old-key", + "embedding_dimensions": 1536, + }, + ], + } + + setup._upsert_openai_models( + api_key="new-key", + base_url="http://custom.example/v1", + llm_model_name="gpt-oss-20b", + embedding_model_name="text-embedding-3-large", + ) + + saved_config = setup.config_manager.save_full_config.call_args[0][0] + saved_models = {m["name"]: m for m in saved_config["models"]} + + assert saved_config["defaults"]["llm"] == "openai-llm" + assert saved_config["defaults"]["embedding"] == "openai-embed" + + assert saved_models["openai-llm"]["model_url"] == "http://custom.example/v1" + assert saved_models["openai-llm"]["api_key"] == "new-key" + assert saved_models["openai-llm"]["model_name"] == "gpt-oss-20b" + # Existing params are preserved and missing defaults are filled. + assert saved_models["openai-llm"]["model_params"]["temperature"] == 0.3 + assert saved_models["openai-llm"]["model_params"]["max_tokens"] == 2000 + + assert saved_models["openai-embed"]["model_url"] == "http://custom.example/v1" + assert saved_models["openai-embed"]["api_key"] == "new-key" + assert saved_models["openai-embed"]["model_name"] == "text-embedding-3-large" + # Existing embedding dimensions are preserved. + assert saved_models["openai-embed"]["embedding_dimensions"] == 1536 + + +def test_setup_llm_openai_prompts_for_custom_values_and_updates_models(): + """Checks init OpenAI setup prompts for custom API initialization values.""" + setup = _build_setup_with_mocks() + setup.prompt_choice = MagicMock(return_value="1") + setup.prompt_value = MagicMock( + side_effect=["http://my-openai-compatible/v1", "my-chat-model", "my-embed-model"] + ) + setup.prompt_with_existing_masked = MagicMock(return_value="test-api-key") + setup._upsert_openai_models = MagicMock() + setup.config_manager.get_full_config.return_value = { + "models": [ + {"name": "openai-llm", "model_url": "https://api.openai.com/v1", "model_name": "gpt-4o-mini"}, + {"name": "openai-embed", "model_url": "https://api.openai.com/v1", "model_name": "text-embedding-3-small"}, + ] + } + + setup.setup_llm() + + setup.prompt_value.assert_has_calls( + [ + call("OpenAI-compatible base URL", "https://api.openai.com/v1"), + call("LLM model name", "gpt-4o-mini"), + call("Embedding model name", "text-embedding-3-small"), + ] + ) + setup._upsert_openai_models.assert_called_once_with( + api_key="test-api-key", + base_url="http://my-openai-compatible/v1", + llm_model_name="my-chat-model", + embedding_model_name="my-embed-model", + ) + assert setup.config["OPENAI_API_KEY"] == "test-api-key" + + +def test_setup_llm_openai_skips_upsert_when_api_key_missing(): + """Checks init OpenAI custom API setup guards against missing API key.""" + setup = _build_setup_with_mocks() + setup.prompt_choice = MagicMock(return_value="1") + setup.prompt_value = MagicMock( + side_effect=["https://api.openai.com/v1", "gpt-4o-mini", "text-embedding-3-small"] + ) + setup.prompt_with_existing_masked = MagicMock(return_value="") + setup._upsert_openai_models = MagicMock() + setup.config_manager.get_full_config.return_value = {"models": []} + + setup.setup_llm() + + setup._upsert_openai_models.assert_not_called() + assert "OPENAI_API_KEY" not in setup.config diff --git a/backends/advanced/tests/unit/workers/__init__.py b/backends/advanced/tests/unit/workers/__init__.py new file mode 100644 index 00000000..8a802dcb --- /dev/null +++ b/backends/advanced/tests/unit/workers/__init__.py @@ -0,0 +1 @@ +"""Tests for audio stream workers.""" diff --git a/backends/advanced/tests/unit/workers/test_audio_stream_workers.py b/backends/advanced/tests/unit/workers/test_audio_stream_workers.py new file mode 100644 index 00000000..93a72b1e --- /dev/null +++ b/backends/advanced/tests/unit/workers/test_audio_stream_workers.py @@ -0,0 +1,345 @@ +"""Unit tests for audio stream workers using the Template Method pattern.""" + +import asyncio +import os +import signal +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +import redis.asyncio as redis + +from advanced_omi_backend.services.transcription.streaming_consumer import ( + StreamingTranscriptionConsumer, +) +from advanced_omi_backend.workers.audio_stream_deepgram_worker import DeepgramStreamWorker +from advanced_omi_backend.workers.audio_stream_parakeet_worker import ParakeetStreamWorker +from advanced_omi_backend.workers.base_audio_worker import BaseStreamWorker + + +@pytest.mark.unit +class TestBaseStreamWorker: + """Test the BaseStreamWorker template class.""" + + def test_abstract_methods_must_be_implemented(self): + """Test that BaseStreamWorker cannot be instantiated without implementing abstract methods.""" + with pytest.raises(TypeError, match="abstract methods"): + BaseStreamWorker("test-service") + + def test_service_name_initialization(self): + """Test that service name is properly set during initialization.""" + + class ConcreteWorker(BaseStreamWorker): + def validate_config(self): + pass + + def get_consumer(self, redis_client): + pass + + worker = ConcreteWorker("test-worker") + assert worker.service_name == "test-worker" + assert worker.redis_client is None + assert worker.consumer is None + + @pytest.mark.asyncio + async def test_redis_connection_failure_exits(self): + """Test that worker exits gracefully when Redis connection fails.""" + + class ConcreteWorker(BaseStreamWorker): + def validate_config(self): + pass + + def get_consumer(self, redis_client): + pass + + worker = ConcreteWorker("test-worker") + + async def raise_connection_error(*args, **kwargs): + raise Exception("Connection failed") + + with patch("redis.asyncio.from_url", side_effect=raise_connection_error): + with pytest.raises(SystemExit) as exc_info: + await worker.run() + assert exc_info.value.code == 1 + + @pytest.mark.asyncio + async def test_consumer_initialization_failure_exits(self): + """Test that worker exits gracefully when consumer initialization fails.""" + + class ConcreteWorker(BaseStreamWorker): + def validate_config(self): + pass + + def get_consumer(self, redis_client): + raise ValueError("Consumer init failed") + + worker = ConcreteWorker("test-worker") + + mock_redis = AsyncMock() + + async def mock_from_url(*args, **kwargs): + return mock_redis + + with patch("redis.asyncio.from_url", side_effect=mock_from_url): + with pytest.raises(SystemExit) as exc_info: + await worker.run() + assert exc_info.value.code == 1 + mock_redis.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_successful_worker_lifecycle(self): + """Test complete worker lifecycle with successful execution.""" + + class ConcreteWorker(BaseStreamWorker): + def validate_config(self): + pass + + def get_consumer(self, redis_client): + consumer = AsyncMock() + consumer.start_consuming = AsyncMock() + consumer.stop = AsyncMock() + return consumer + + worker = ConcreteWorker("test-worker") + mock_redis = AsyncMock() + + # Simulate quick consumer completion + async def quick_consume(): + await asyncio.sleep(0.01) + + async def mock_from_url(*args, **kwargs): + return mock_redis + + with patch("redis.asyncio.from_url", side_effect=mock_from_url): + with patch.object(worker.__class__, "get_consumer") as mock_get_consumer: + mock_consumer = AsyncMock() + mock_consumer.start_consuming = quick_consume + mock_consumer.stop = AsyncMock() + mock_get_consumer.return_value = mock_consumer + + await worker.run() + + mock_redis.aclose.assert_called_once() + mock_consumer.stop.assert_called_once() + + +@pytest.mark.unit +class TestDeepgramStreamWorker: + """Test DeepgramStreamWorker implementation.""" + + def test_initialization(self): + """Test that DeepgramStreamWorker initializes correctly.""" + worker = DeepgramStreamWorker() + assert worker.service_name == "Deepgram audio stream worker" + assert hasattr(worker, "logger") + + def test_validate_config_with_api_key(self): + """Test config validation when DEEPGRAM_API_KEY is set.""" + worker = DeepgramStreamWorker() + + with patch.dict(os.environ, {"DEEPGRAM_API_KEY": "test-key-123"}): + # Should not raise any exceptions or warnings + worker.validate_config() + + def test_validate_config_without_api_key(self): + """Test config validation when DEEPGRAM_API_KEY is missing.""" + worker = DeepgramStreamWorker() + + with patch.dict(os.environ, {}, clear=True): + with patch.object(worker.logger, "warning") as mock_warning: + worker.validate_config() + # Should log 3 warnings about missing API key + assert mock_warning.call_count == 3 + + def test_get_consumer_creates_streaming_consumer(self): + """Test that get_consumer returns a StreamingTranscriptionConsumer instance.""" + worker = DeepgramStreamWorker() + mock_redis = Mock() + + # Mock the config/registry system that StreamingTranscriptionConsumer uses + with patch( + "advanced_omi_backend.services.transcription.streaming_consumer.get_transcription_provider" + ) as mock_get_provider: + mock_provider = Mock() + mock_get_provider.return_value = mock_provider + + consumer = worker.get_consumer(mock_redis) + + assert isinstance(consumer, StreamingTranscriptionConsumer) + # Verify consumer has required async methods + assert hasattr(consumer, "start_consuming") + assert hasattr(consumer, "stop") + assert callable(consumer.start_consuming) + assert callable(consumer.stop) + + def test_start_method_runs_worker(self): + """Test that start() class method creates instance and schedules run() via asyncio.run.""" + captured_coro = None + + with patch.object(DeepgramStreamWorker, "run", new_callable=AsyncMock) as mock_run: + with patch("asyncio.run") as mock_asyncio_run: + def capture_and_close(coro): + nonlocal captured_coro + captured_coro = coro + coro.close() + + mock_asyncio_run.side_effect = capture_and_close + + DeepgramStreamWorker.start() + + mock_run.assert_called_once_with() + mock_asyncio_run.assert_called_once_with(captured_coro) + assert captured_coro is not None + assert asyncio.iscoroutine(captured_coro) + + +@pytest.mark.unit +class TestParakeetStreamWorker: + """Test ParakeetStreamWorker implementation.""" + + def test_initialization(self): + """Test that ParakeetStreamWorker initializes correctly.""" + worker = ParakeetStreamWorker() + assert worker.service_name == "Parakeet audio stream worker" + assert hasattr(worker, "logger") + + def test_validate_config_with_service_url(self): + """Test config validation when PARAKEET_ASR_URL is set.""" + worker = ParakeetStreamWorker() + + with patch.dict(os.environ, {"PARAKEET_ASR_URL": "http://localhost:8767"}): + # Should not raise any exceptions or warnings + worker.validate_config() + + def test_validate_config_without_service_url(self): + """Test config validation when PARAKEET_ASR_URL is missing.""" + worker = ParakeetStreamWorker() + + with patch.dict(os.environ, {}, clear=True): + with patch.object(worker.logger, "warning") as mock_warning: + worker.validate_config() + # Should log 3 warnings about missing service URL + assert mock_warning.call_count == 3 + + def test_get_consumer_creates_streaming_consumer(self): + """Test that get_consumer returns a StreamingTranscriptionConsumer instance.""" + worker = ParakeetStreamWorker() + mock_redis = Mock() + + # Mock the config/registry system that StreamingTranscriptionConsumer uses + with patch( + "advanced_omi_backend.services.transcription.streaming_consumer.get_transcription_provider" + ) as mock_get_provider: + mock_provider = Mock() + mock_get_provider.return_value = mock_provider + + consumer = worker.get_consumer(mock_redis) + + assert isinstance(consumer, StreamingTranscriptionConsumer) + # Verify consumer has required async methods + assert hasattr(consumer, "start_consuming") + assert hasattr(consumer, "stop") + assert callable(consumer.start_consuming) + assert callable(consumer.stop) + + @pytest.mark.asyncio + async def test_start_method_runs_worker(self): + """Test that start() class method creates instance and runs it.""" + with patch.object(ParakeetStreamWorker, "run", new_callable=AsyncMock) as mock_run: + with patch("asyncio.run") as mock_asyncio_run: + # Simulate script execution + mock_asyncio_run.side_effect = lambda coro: asyncio.new_event_loop().run_until_complete( + coro + ) + + worker_instance = ParakeetStreamWorker() + await worker_instance.run() + + mock_run.assert_called_once() + + +@pytest.mark.unit +class TestWorkerIntegration: + """Integration tests for worker components.""" + + @pytest.mark.asyncio + async def test_deepgram_worker_handles_shutdown_signal(self): + """Test that DeepgramStreamWorker handles shutdown signals gracefully.""" + worker = DeepgramStreamWorker() + mock_redis = AsyncMock() + mock_consumer = AsyncMock() + + # Consumer runs for a short time then completes + async def simulate_work(): + await asyncio.sleep(0.05) + + mock_consumer.start_consuming = simulate_work + mock_consumer.stop = AsyncMock() + + async def mock_from_url(*args, **kwargs): + return mock_redis + + with patch("redis.asyncio.from_url", side_effect=mock_from_url): + with patch.object(worker, "get_consumer", return_value=mock_consumer): + # Run worker in background + task = asyncio.create_task(worker.run()) + + # Let it start + await asyncio.sleep(0.01) + + # Should complete naturally + await task + + mock_consumer.stop.assert_called_once() + mock_redis.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_parakeet_worker_handles_shutdown_signal(self): + """Test that ParakeetStreamWorker handles shutdown signals gracefully.""" + worker = ParakeetStreamWorker() + mock_redis = AsyncMock() + mock_consumer = AsyncMock() + + # Consumer runs for a short time then completes + async def simulate_work(): + await asyncio.sleep(0.05) + + mock_consumer.start_consuming = simulate_work + mock_consumer.stop = AsyncMock() + + async def mock_from_url(*args, **kwargs): + return mock_redis + + with patch("redis.asyncio.from_url", side_effect=mock_from_url): + with patch.object(worker, "get_consumer", return_value=mock_consumer): + # Run worker in background + task = asyncio.create_task(worker.run()) + + # Let it start + await asyncio.sleep(0.01) + + # Should complete naturally + await task + + mock_consumer.stop.assert_called_once() + mock_redis.aclose.assert_called_once() + + def test_workers_share_consistent_behavior(self): + """Test that both workers use consistent shutdown and error handling.""" + deepgram_worker = DeepgramStreamWorker() + parakeet_worker = ParakeetStreamWorker() + + # Both should have same base class + assert isinstance(deepgram_worker, BaseStreamWorker) + assert isinstance(parakeet_worker, BaseStreamWorker) + + # Both should implement required methods + assert callable(deepgram_worker.validate_config) + assert callable(deepgram_worker.get_consumer) + assert callable(parakeet_worker.validate_config) + assert callable(parakeet_worker.get_consumer) + + # Both should inherit run method from BaseStreamWorker + assert hasattr(deepgram_worker, "run") + assert hasattr(parakeet_worker, "run") + # Verify they use the same base implementation + assert type(deepgram_worker).run == type(parakeet_worker).run diff --git a/wizard.py b/wizard.py index e3beb37a..e36bf1d2 100755 --- a/wizard.py +++ b/wizard.py @@ -6,728 +6,581 @@ import shutil import subprocess +import sys from datetime import datetime from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union import yaml +from dotenv import get_key +from rich import print as rprint from rich.console import Console -from rich.prompt import Confirm, Prompt - -# Import shared setup utilities -from setup_utils import ( - detect_tailscale_info, - is_placeholder, - mask_value, - prompt_with_existing_masked, - read_env_value, -) +from rich.prompt import Confirm console = Console() -SERVICES = { - 'backend': { - 'advanced': { - 'path': 'backends/advanced', - 'cmd': ['uv', 'run', '--with-requirements', '../../setup-requirements.txt', 'python', 'init.py'], - 'description': 'Advanced AI backend with full feature set', - 'required': True +# Type definitions +ServiceConfig = Dict[str, Any] +ServiceGroup = Dict[str, ServiceConfig] +ServicesData = Dict[str, ServiceGroup] + +SERVICES: ServicesData = { + "backend": { + "advanced": { + "path": "backends/advanced", + "cmd": [ + "uv", + "run", + "--with-requirements", + "../../setup-requirements.txt", + "python", + "init.py", + ], + "description": "Advanced Backend with full feature set", + "required": True, } }, - 'extras': { - 'speaker-recognition': { - 'path': 'extras/speaker-recognition', - 'cmd': ['uv', 'run', '--with-requirements', '../../setup-requirements.txt', 'python', 'init.py'], - 'description': 'Speaker identification and enrollment' + "extras": { + "speaker-recognition": { + "path": "extras/speaker-recognition", + "cmd": [ + "uv", + "run", + "--with-requirements", + "../../setup-requirements.txt", + "python", + "init.py", + ], + "description": "Speaker identification and enrollment", }, - 'asr-services': { - 'path': 'extras/asr-services', - 'cmd': ['uv', 'run', '--with-requirements', '../../setup-requirements.txt', 'python', 'init.py'], - 'description': 'Offline speech-to-text' + "asr-services": { + "path": "extras/asr-services", + "cmd": [ + "uv", + "run", + "--with-requirements", + "../../setup-requirements.txt", + "python", + "init.py", + ], + "description": "Offline speech-to-text (Parakeet)", }, - 'openmemory-mcp': { - 'path': 'extras/openmemory-mcp', - 'cmd': ['./setup.sh'], - 'description': 'OpenMemory MCP server' - } - } + "openmemory-mcp": { + "path": "extras/openmemory-mcp", + "cmd": ["./setup.sh"], + "description": "OpenMemory MCP server", + }, + }, } -def discover_available_plugins(): - """ - Discover plugins by scanning plugins directory. - - Returns: - Dictionary mapping plugin_id to plugin metadata: - { - 'plugin_id': { - 'has_setup': bool, - 'setup_path': Path or None, - 'dir': Path - } - } - """ - plugins_dir = Path("backends/advanced/src/advanced_omi_backend/plugins") - if not plugins_dir.exists(): - console.print(f"[yellow]Warning: Plugins directory not found: {plugins_dir}[/yellow]") - return {} +def read_env_value(env_file_path: Union[str, Path], key: str) -> Optional[str]: + """Read a value from an .env file using python-dotenv""" + env_path = Path(env_file_path) + if not env_path.exists(): + return None - discovered = {} - skip_dirs = {'__pycache__', '__init__.py', 'base.py', 'router.py'} + value = get_key(str(env_path), key) + return value if value else None - for plugin_dir in plugins_dir.iterdir(): - if not plugin_dir.is_dir() or plugin_dir.name in skip_dirs: - continue - plugin_id = plugin_dir.name - setup_script = plugin_dir / "setup.py" +def is_placeholder(value: Optional[str], *placeholder_variants: str) -> bool: + """ + Check if a value is a placeholder or empty. + """ + if not value: + return True + + normalized_value = value.replace("-", "_").lower() - discovered[plugin_id] = { - 'has_setup': setup_script.exists(), - 'setup_path': setup_script if setup_script.exists() else None, - 'dir': plugin_dir - } + for placeholder in placeholder_variants: + normalized_placeholder = placeholder.replace("-", "_").lower() + if normalized_value == normalized_placeholder: + return True - return discovered + return False -def check_service_exists(service_name, service_config): + +def check_service_exists( + service_name: str, service_config: ServiceConfig +) -> Tuple[bool, str]: """Check if service directory and script exist""" - service_path = Path(service_config['path']) + service_path = Path(service_config["path"]) if not service_path.exists(): return False, f"Directory {service_path} does not exist" - # For services with Python init scripts, check if init.py exists - if service_name in ['advanced', 'speaker-recognition', 'asr-services']: - script_path = service_path / 'init.py' + # For services with Python init scripts + if service_name in ["advanced", "speaker-recognition", "asr-services"]: + script_path = service_path / "init.py" if not script_path.exists(): return False, f"Script {script_path} does not exist" else: - # For other extras, check if setup.sh exists - script_path = service_path / 'setup.sh' + # For other extras (shell scripts) + script_path = service_path / "setup.sh" if not script_path.exists(): - return False, f"Script {script_path} does not exist (will be created in Phase 2)" + return ( + False, + f"Script {script_path} does not exist (will be created in Phase 2)", + ) return True, "OK" -def select_services(transcription_provider=None): + +def _ensure_hf_token() -> Optional[str]: + """Ensure Hugging Face token is available for speaker-recognition""" + speaker_env_path = "extras/speaker-recognition/.env" + hf_token = read_env_value(speaker_env_path, "HF_TOKEN") + + if not hf_token or is_placeholder( + hf_token, + "your_huggingface_token_here", + "your-huggingface-token-here", + "hf_xxxxx", + ): + console.print( + "\n[red][ERROR][/red] HF_TOKEN is required for speaker-recognition service" + ) + console.print( + "[yellow]Speaker recognition requires a Hugging Face token to download models[/yellow]" + ) + console.print("Get your token from: https://huggingface.co/settings/tokens") + console.print() + + try: + hf_token_input = console.input("[cyan]Enter your HF_TOKEN[/cyan]: ").strip() + if not hf_token_input or is_placeholder( + hf_token_input, "your_huggingface_token_here", "hf_xxxxx" + ): + console.print("[red][ERROR][/red] Invalid HF_TOKEN provided.") + return None + return hf_token_input + except EOFError: + return None + + return hf_token + + +def _configure_advanced_backend( + cmd: List[str], + selected_services: List[str], + https_enabled: bool, + server_ip: Optional[str], + obsidian_enabled: bool, + neo4j_password: Optional[str], +) -> List[str]: + """Configure arguments for advanced backend""" + new_cmd = cmd.copy() + if "speaker-recognition" in selected_services: + new_cmd.extend(["--speaker-service-url", "http://speaker-service:8085"]) + if "asr-services" in selected_services: + new_cmd.extend(["--parakeet-asr-url", "http://host.docker.internal:8767"]) + + if https_enabled and server_ip: + new_cmd.extend(["--enable-https", "--server-ip", server_ip]) + + if obsidian_enabled and neo4j_password: + new_cmd.extend(["--enable-obsidian", "--neo4j-password", neo4j_password]) + + return new_cmd + + +def _configure_speaker_recognition( + cmd: List[str], https_enabled: bool, server_ip: Optional[str] +) -> Optional[List[str]]: + """Configure arguments for speaker recognition""" + new_cmd = cmd.copy() + + if https_enabled and server_ip: + new_cmd.extend(["--enable-https", "--server-ip", server_ip]) + + # HF Token + hf_token = _ensure_hf_token() + if not hf_token: + return None + new_cmd.extend(["--hf-token", hf_token]) + console.print("[green][SUCCESS][/green] HF_TOKEN configured") + + # Deepgram Key Reuse + backend_env = "backends/advanced/.env" + deepgram_key = read_env_value(backend_env, "DEEPGRAM_API_KEY") + if deepgram_key and not is_placeholder( + deepgram_key, "your_deepgram_api_key_here" + ): + new_cmd.extend(["--deepgram-api-key", deepgram_key]) + console.print( + "[blue][INFO][/blue] Found existing DEEPGRAM_API_KEY from backend config, reusing" + ) + + # Compute Mode Reuse + speaker_env = "extras/speaker-recognition/.env" + compute_mode = read_env_value(speaker_env, "COMPUTE_MODE") + if compute_mode in ["cpu", "gpu"]: + new_cmd.extend(["--compute-mode", compute_mode]) + console.print( + f"[blue][INFO][/blue] Found existing COMPUTE_MODE ({compute_mode}), reusing" + ) + + return new_cmd + + +def _configure_asr_services(cmd: List[str]) -> List[str]: + """Configure arguments for ASR services""" + new_cmd = cmd.copy() + speaker_env = "extras/speaker-recognition/.env" + cuda_version = read_env_value(speaker_env, "PYTORCH_CUDA_VERSION") + if cuda_version and cuda_version in ["cu121", "cu126", "cu128"]: + new_cmd.extend(["--pytorch-cuda-version", cuda_version]) + console.print( + f"[blue][INFO][/blue] Found existing PYTORCH_CUDA_VERSION ({cuda_version}) from speaker-recognition, reusing" + ) + return new_cmd + + +def _configure_openmemory_mcp(cmd: List[str]) -> List[str]: + """Configure arguments for OpenMemory MCP""" + new_cmd = cmd.copy() + backend_env = "backends/advanced/.env" + openai_key = read_env_value(backend_env, "OPENAI_API_KEY") + if openai_key and not is_placeholder( + openai_key, "your_openai_api_key_here", "your_openai_key_here" + ): + new_cmd.extend(["--openai-api-key", openai_key]) + console.print( + "[blue][INFO][/blue] Found existing OPENAI_API_KEY from backend config, reusing" + ) + return new_cmd + + +def run_service_setup( + service_name: str, + selected_services: List[str], + https_enabled: bool = False, + server_ip: Optional[str] = None, + obsidian_enabled: bool = False, + neo4j_password: Optional[str] = None, +) -> bool: + """Execute individual service setup script""" + console.print(f"\nπŸ”§ [bold]Setting up {service_name}...[/bold]") + + # Identify service config + if service_name == "advanced": + service = SERVICES["backend"][service_name] + cmd = _configure_advanced_backend( + service["cmd"], + selected_services, + https_enabled, + server_ip, + obsidian_enabled, + neo4j_password, + ) + else: + service = SERVICES["extras"][service_name] + cmd = service["cmd"] + + if service_name == "speaker-recognition": + result_cmd = _configure_speaker_recognition(cmd, https_enabled, server_ip) + if result_cmd is None: + return False + cmd = result_cmd + elif service_name == "asr-services": + cmd = _configure_asr_services(cmd) + elif service_name == "openmemory-mcp": + cmd = _configure_openmemory_mcp(cmd) + + exists, msg = check_service_exists(service_name, service) + if not exists: + console.print(f"❌ {service_name} setup failed: {msg}") + return False + + try: + subprocess.run(cmd, cwd=service["path"], check=True, timeout=300) + console.print(f"βœ… {service_name} setup completed") + return True + except ( + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + FileNotFoundError, + ) as e: + console.print(f"❌ {service_name} setup failed: {e}") + return False + except Exception as e: + console.print(f"❌ {service_name} setup failed (unexpected): {e}") + return False + + +def select_services() -> List[str]: """Let user select which services to setup""" console.print("πŸš€ [bold cyan]Chronicle Service Setup[/bold cyan]") console.print("Select which services to configure:\n") selected = [] - # Backend is required + # Backend console.print("πŸ“± [bold]Backend (Required):[/bold]") console.print(" βœ… Advanced Backend - Full AI features") - selected.append('advanced') + selected.append("advanced") - # Services that will be auto-added based on transcription provider choice - auto_added = set() - if transcription_provider in ("parakeet", "vibevoice"): - auto_added.add('asr-services') - - # Optional extras + # Extras console.print("\nπŸ”§ [bold]Optional Services:[/bold]") - for service_name, service_config in SERVICES['extras'].items(): - # Skip services that will be auto-added based on earlier choices - if service_name in auto_added: - provider_label = {"vibevoice": "VibeVoice", "parakeet": "Parakeet"}.get(transcription_provider, transcription_provider) - console.print(f" βœ… {service_config['description']} ({provider_label}) [dim](auto-selected)[/dim]") - continue - - # Check if service exists + for service_name, service_config in SERVICES["extras"].items(): exists, msg = check_service_exists(service_name, service_config) if not exists: console.print(f" ⏸️ {service_config['description']} - [dim]{msg}[/dim]") continue try: - enable_service = Confirm.ask(f" Setup {service_config['description']}?", default=False) + if Confirm.ask(f" Setup {service_config['description']}?", default=False): + selected.append(service_name) except EOFError: - console.print("Using default: No") - enable_service = False - - if enable_service: - selected.append(service_name) + pass return selected -def cleanup_unselected_services(selected_services): + +def cleanup_unselected_services(selected_services: List[str]) -> None: """Backup and remove .env files from services that weren't selected""" - - all_services = list(SERVICES['backend'].keys()) + list(SERVICES['extras'].keys()) - + all_services = list(SERVICES["backend"].keys()) + list(SERVICES["extras"].keys()) + for service_name in all_services: if service_name not in selected_services: - if service_name == 'advanced': - service_path = Path(SERVICES['backend'][service_name]['path']) + if service_name == "advanced": + service_path = Path(SERVICES["backend"][service_name]["path"]) else: - service_path = Path(SERVICES['extras'][service_name]['path']) - - env_file = service_path / '.env' + service_path = Path(SERVICES["extras"][service_name]["path"]) + + env_file = service_path / ".env" if env_file.exists(): - # Create backup with timestamp timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_file = service_path / f'.env.backup.{timestamp}.unselected' + backup_file = service_path / f".env.backup.{timestamp}.unselected" env_file.rename(backup_file) - console.print(f"🧹 [dim]Backed up {service_name} configuration to {backup_file.name} (service not selected)[/dim]") - -def run_service_setup(service_name, selected_services, https_enabled=False, server_ip=None, - obsidian_enabled=False, neo4j_password=None, hf_token=None, - transcription_provider='deepgram'): - """Execute individual service setup script""" - if service_name == 'advanced': - service = SERVICES['backend'][service_name] + console.print( + f"🧹 [dim]Backed up {service_name} config to {backup_file.name}[/dim]" + ) - # For advanced backend, pass URLs of other selected services and HTTPS config - cmd = service['cmd'].copy() - if 'speaker-recognition' in selected_services: - cmd.extend(['--speaker-service-url', 'http://speaker-service:8085']) - if 'asr-services' in selected_services: - cmd.extend(['--parakeet-asr-url', 'http://host.docker.internal:8767']) - # Pass transcription provider choice from wizard - if transcription_provider: - cmd.extend(['--transcription-provider', transcription_provider]) - - # Add HTTPS configuration - if https_enabled and server_ip: - cmd.extend(['--enable-https', '--server-ip', server_ip]) +def setup_https(selected_services: List[str]) -> Tuple[bool, Optional[str]]: + """Prompt and configure HTTPS settings""" + # Check if we have services that benefit from HTTPS + https_services = {"advanced", "speaker-recognition"} + needs_https = bool(https_services.intersection(selected_services)) - # Add Obsidian configuration - if obsidian_enabled and neo4j_password: - cmd.extend(['--enable-obsidian', '--neo4j-password', neo4j_password]) + if not needs_https: + return False, None - else: - service = SERVICES['extras'][service_name] - cmd = service['cmd'].copy() - - # Add HTTPS configuration for services that support it - if service_name == 'speaker-recognition' and https_enabled and server_ip: - cmd.extend(['--enable-https', '--server-ip', server_ip]) - - # For speaker-recognition, pass HF_TOKEN from centralized configuration - if service_name == 'speaker-recognition': - # Define the speaker env path - speaker_env_path = 'extras/speaker-recognition/.env' - - # HF Token should have been provided via setup_hf_token_if_needed() - if hf_token: - cmd.extend(['--hf-token', hf_token]) - else: - console.print("[yellow][WARNING][/yellow] No HF_TOKEN provided - speaker recognition may fail to download models") - - # Pass Deepgram API key from backend if available - backend_env_path = 'backends/advanced/.env' - deepgram_key = read_env_value(backend_env_path, 'DEEPGRAM_API_KEY') - if deepgram_key and not is_placeholder(deepgram_key, 'your_deepgram_api_key_here', 'your-deepgram-api-key-here'): - cmd.extend(['--deepgram-api-key', deepgram_key]) - console.print("[blue][INFO][/blue] Found existing DEEPGRAM_API_KEY from backend config, reusing") - - # Pass compute mode from existing .env if available - compute_mode = read_env_value(speaker_env_path, 'COMPUTE_MODE') - if compute_mode in ['cpu', 'gpu']: - cmd.extend(['--compute-mode', compute_mode]) - console.print(f"[blue][INFO][/blue] Found existing COMPUTE_MODE ({compute_mode}), reusing") - - # For asr-services, pass provider from wizard's transcription choice and reuse CUDA version - if service_name == 'asr-services': - # Map wizard transcription provider to asr-services provider name - wizard_to_asr_provider = { - 'vibevoice': 'vibevoice', - 'parakeet': 'nemo', - } - asr_provider = wizard_to_asr_provider.get(transcription_provider) - if asr_provider: - cmd.extend(['--provider', asr_provider]) - console.print(f"[blue][INFO][/blue] Pre-selecting ASR provider: {asr_provider} (from wizard choice: {transcription_provider})") - - speaker_env_path = 'extras/speaker-recognition/.env' - cuda_version = read_env_value(speaker_env_path, 'PYTORCH_CUDA_VERSION') - if cuda_version and cuda_version in ['cu121', 'cu126', 'cu128']: - cmd.extend(['--pytorch-cuda-version', cuda_version]) - console.print(f"[blue][INFO][/blue] Found existing PYTORCH_CUDA_VERSION ({cuda_version}) from speaker-recognition, reusing") - - # For openmemory-mcp, try to pass OpenAI API key from backend if available - if service_name == 'openmemory-mcp': - backend_env_path = 'backends/advanced/.env' - openai_key = read_env_value(backend_env_path, 'OPENAI_API_KEY') - if openai_key and not is_placeholder(openai_key, 'your_openai_api_key_here', 'your-openai-api-key-here', 'your_openai_key_here', 'your-openai-key-here'): - cmd.extend(['--openai-api-key', openai_key]) - console.print("[blue][INFO][/blue] Found existing OPENAI_API_KEY from backend config, reusing") - - console.print(f"\nπŸ”§ [bold]Setting up {service_name}...[/bold]") - - # Check if service exists before running - exists, msg = check_service_exists(service_name, service) - if not exists: - console.print(f"❌ {service_name} setup failed: {msg}") - return False - + console.print("\nπŸ”’ [bold cyan]HTTPS Configuration[/bold cyan]") try: - result = subprocess.run( - cmd, - cwd=service['path'], - check=True, - timeout=300 # 5 minute timeout for service setup - ) - - console.print(f"βœ… {service_name} setup completed") - return True - - except FileNotFoundError as e: - console.print(f"❌ {service_name} setup failed: {e}") - console.print(f"[yellow] Check that the service directory exists: {service['path']}[/yellow]") - console.print(f"[yellow] And that 'uv' is installed and on your PATH[/yellow]") - return False - except subprocess.TimeoutExpired as e: - console.print(f"❌ {service_name} setup timed out after {e.timeout}s") - console.print(f"[yellow] Configuration may be partially written.[/yellow]") - console.print(f"[yellow] To retry just this service:[/yellow]") - console.print(f"[yellow] cd {service['path']} && {' '.join(service['cmd'])}[/yellow]") - return False - except subprocess.CalledProcessError as e: - console.print(f"❌ {service_name} setup failed with exit code {e.returncode}") - console.print(f"[yellow] Check the error output above for details.[/yellow]") - console.print(f"[yellow] To retry just this service:[/yellow]") - console.print(f"[yellow] cd {service['path']} && {' '.join(service['cmd'])}[/yellow]") - return False - except Exception as e: - console.print(f"❌ {service_name} setup failed: {e}") - return False - -def show_service_status(): - """Show which services are available""" - console.print("\nπŸ“‹ [bold]Service Status:[/bold]") - - # Check backend - exists, msg = check_service_exists('advanced', SERVICES['backend']['advanced']) - status = "βœ…" if exists else "❌" - console.print(f" {status} Advanced Backend - {msg}") - - # Check extras - for service_name, service_config in SERVICES['extras'].items(): - exists, msg = check_service_exists(service_name, service_config) - status = "βœ…" if exists else "⏸️" - console.print(f" {status} {service_config['description']} - {msg}") + if not Confirm.ask("Enable HTTPS for selected services?", default=False): + return False, None + except EOFError: + return False, None -def run_plugin_setup(plugin_id, plugin_info): - """Run a plugin's setup.py script""" - setup_path = plugin_info['setup_path'] + console.print("\n[blue][INFO][/blue] For distributed deployments, use your Tailscale IP") + console.print("Examples: localhost, 100.64.1.2, your-domain.com") - try: - # Run plugin setup script interactively (don't capture output) - # This allows the plugin to prompt for user input - result = subprocess.run( - ['uv', 'run', '--with-requirements', 'setup-requirements.txt', 'python', str(setup_path)], - cwd=str(Path.cwd()) - ) + backend_env_path = "backends/advanced/.env" + existing_ip = read_env_value(backend_env_path, "SERVER_IP") + default_value = ( + existing_ip + if existing_ip and existing_ip not in ["localhost", "your-server-ip-here"] + else "localhost" + ) - if result.returncode == 0: - console.print(f"\n[green]βœ… {plugin_id} configured successfully[/green]") - return True - else: - console.print(f"\n[red]❌ {plugin_id} setup failed with exit code {result.returncode}[/red]") - return False + prompt_text = f"Server IP/Domain [{default_value}]" - except Exception as e: - console.print(f"[red]❌ Error running {plugin_id} setup: {e}[/red]") - return False + while True: + try: + server_ip = console.input(f"{prompt_text}: ").strip() + if not server_ip: + server_ip = default_value + break + except EOFError: + server_ip = default_value + break -def setup_plugins(): - """Discover and setup plugins via delegation""" - console.print("\nπŸ”Œ [bold cyan]Plugin Configuration[/bold cyan]") - console.print("Chronicle supports community plugins for extended functionality.\n") + console.print(f"[green]βœ…[/green] HTTPS configured for: {server_ip}") + return True, server_ip - # Discover available plugins - available_plugins = discover_available_plugins() - if not available_plugins: - console.print("[dim]No plugins found[/dim]") - return +def setup_obsidian(selected_services: List[str]) -> Tuple[bool, Optional[str]]: + """Prompt and configure Obsidian/Neo4j settings""" + if "advanced" not in selected_services: + return False, None - # Ask about enabling community plugins + console.print("\nπŸ—‚οΈ [bold cyan]Obsidian/Neo4j Integration[/bold cyan]") try: - enable_plugins = Confirm.ask( - "Enable community plugins?", - default=True - ) + if not Confirm.ask("Enable Obsidian/Neo4j integration?", default=False): + return False, None except EOFError: - console.print("Using default: Yes") - enable_plugins = True - - if not enable_plugins: - console.print("[dim]Skipping plugin configuration[/dim]") - return + return False, None - # For each plugin with setup script - configured_count = 0 - for plugin_id, plugin_info in available_plugins.items(): - if not plugin_info['has_setup']: - console.print(f"[dim] {plugin_id}: No setup wizard available (configure manually)[/dim]") - continue + console.print("[blue][INFO][/blue] Neo4j will be configured for graph-based memory storage\n") - # Ask if user wants to configure this plugin + while True: try: - configure = Confirm.ask( - f" Configure {plugin_id} plugin?", - default=False + password = ( + console.input("Neo4j password (min 8 chars) [default: neo4jpassword]: ").strip() + or "neo4jpassword" ) + if len(password) >= 8: + return True, password + console.print("[yellow][WARNING][/yellow] Password must be at least 8 characters") except EOFError: - configure = False + return True, "neo4jpassword" + + +def show_service_status() -> None: + """Show which services are available""" + console.print("\nπŸ“‹ [bold]Service Status:[/bold]") + + # Check backend + exists, msg = check_service_exists("advanced", SERVICES["backend"]["advanced"]) + status = "βœ…" if exists else "❌" + console.print(f" {status} Advanced Backend - {msg}") - if configure: - # Delegate to plugin's setup script - console.print(f"\n[cyan]Running {plugin_id} setup wizard...[/cyan]") - success = run_plugin_setup(plugin_id, plugin_info) - if success: - configured_count += 1 + # Check extras + for service_name, service_config in SERVICES["extras"].items(): + exists, msg = check_service_exists(service_name, service_config) + status = "βœ…" if exists else "⏸️" + console.print(f" {status} {service_config['description']} - {msg}") - console.print(f"\n[green]βœ… Configured {configured_count} plugin(s)[/green]") -def setup_git_hooks(): +def setup_git_hooks() -> None: """Setup pre-commit hooks for development""" console.print("\nπŸ”§ [bold]Setting up development environment...[/bold]") try: - # Install pre-commit if not already installed - subprocess.run(['pip', 'install', 'pre-commit'], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - check=False) + subprocess.run( + ["pip", "install", "pre-commit"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) - # Install git hooks - result = subprocess.run(['pre-commit', 'install', '--hook-type', 'pre-push'], - capture_output=True, - text=True) + result = subprocess.run( + ["pre-commit", "install", "--hook-type", "pre-push"], + capture_output=True, + text=True, + ) if result.returncode == 0: - console.print("βœ… [green]Git hooks installed (tests will run before push)[/green]") + console.print( + "βœ… [green]Git hooks installed (tests will run before push)[/green]" + ) else: console.print("⚠️ [yellow]Could not install git hooks (optional)[/yellow]") - # Also install pre-commit hook - subprocess.run(['pre-commit', 'install', '--hook-type', 'pre-commit'], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - check=False) + subprocess.run( + ["pre-commit", "install", "--hook-type", "pre-commit"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) except Exception as e: console.print(f"⚠️ [yellow]Could not setup git hooks: {e} (optional)[/yellow]") -def setup_hf_token_if_needed(selected_services): - """Prompt for Hugging Face token if needed by selected services. - - Args: - selected_services: List of service names selected by user - - Returns: - HF_TOKEN string if provided, None otherwise - """ - # Check if any selected services need HF_TOKEN - needs_hf_token = 'speaker-recognition' in selected_services - - if not needs_hf_token: - return None - - console.print("\nπŸ€— [bold cyan]Hugging Face Token Configuration[/bold cyan]") - console.print("Required for speaker recognition (PyAnnote models)") - console.print("\n[blue][INFO][/blue] Get yours from: https://huggingface.co/settings/tokens\n") - - # Check for existing token from speaker-recognition service - speaker_env_path = 'extras/speaker-recognition/.env' - existing_token = read_env_value(speaker_env_path, 'HF_TOKEN') - - # Use the masked prompt function - hf_token = prompt_with_existing_masked( - prompt_text="Hugging Face Token", - existing_value=existing_token, - placeholders=['your_huggingface_token_here', 'your-huggingface-token-here', 'hf_xxxxx'], - is_password=True, - default="" - ) - - if hf_token: - masked = mask_value(hf_token) - console.print(f"[green]βœ… HF_TOKEN configured: {masked}[/green]\n") - return hf_token - else: - console.print("[yellow]⚠️ No HF_TOKEN provided - speaker recognition may fail[/yellow]\n") - return None -def setup_config_file(): +def setup_config_file() -> None: """Setup config/config.yml from template if it doesn't exist""" config_file = Path("config/config.yml") config_template = Path("config/config.yml.template") if not config_file.exists(): if config_template.exists(): - # Ensure config/ directory exists config_file.parent.mkdir(parents=True, exist_ok=True) shutil.copy(config_template, config_file) console.print("βœ… [green]Created config/config.yml from template[/green]") else: - console.print("⚠️ [yellow]config/config.yml.template not found, skipping config setup[/yellow]") + console.print( + "⚠️ [yellow]config/config.yml.template not found, skipping config setup[/yellow]" + ) else: - console.print("ℹ️ [blue]config/config.yml already exists, keeping existing configuration[/blue]") - -def select_transcription_provider(): - """Ask user which transcription provider they want""" - console.print("\n🎀 [bold cyan]Transcription Provider[/bold cyan]") - console.print("Choose your speech-to-text provider:") - console.print() - - choices = { - "1": "Deepgram (cloud-based, high quality, requires API key)", - "2": "Parakeet ASR (offline, runs locally, requires GPU)", - "3": "VibeVoice ASR (offline, built-in speaker diarization, requires GPU)", - "4": "None (skip transcription setup)" - } - - for key, desc in choices.items(): - console.print(f" {key}) {desc}") - console.print() + console.print( + "ℹ️ [blue]config/config.yml already exists, keeping existing configuration[/blue]" + ) - while True: - try: - choice = Prompt.ask("Enter choice", default="1") - if choice in choices: - if choice == "1": - return "deepgram" - elif choice == "2": - return "parakeet" - elif choice == "3": - return "vibevoice" - elif choice == "4": - return "none" - console.print(f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]") - except EOFError: - console.print("Using default: Deepgram") - return "deepgram" -def main(): +def main() -> None: """Main orchestration logic""" console.print("πŸŽ‰ [bold green]Welcome to Chronicle![/bold green]\n") - console.print("[dim]This wizard is safe to run as many times as you like.[/dim]") - console.print("[dim]It backs up your existing config and preserves previously entered values.[/dim]") - console.print("[dim]When unsure, just press Enter β€” the defaults will work.[/dim]\n") - # Setup config file from template setup_config_file() - - # Setup git hooks first setup_git_hooks() - - # Show what's available show_service_status() - # Ask about transcription provider FIRST (determines which services are needed) - transcription_provider = select_transcription_provider() - - # Service Selection (pass transcription_provider so we skip asking about ASR when already chosen) - selected_services = select_services(transcription_provider) - - # Auto-add asr-services if local ASR was chosen (Parakeet or VibeVoice) - if transcription_provider in ("parakeet", "vibevoice") and 'asr-services' not in selected_services: - console.print(f"[blue][INFO][/blue] Auto-adding ASR services for {transcription_provider.capitalize()} transcription") - selected_services.append('asr-services') - + selected_services = select_services() if not selected_services: - console.print("\n[yellow]No services selected. Exiting.[/yellow]") + console.print("[yellow]No services selected. Exiting.[/yellow]") return - # HF Token Configuration (if services require it) - hf_token = setup_hf_token_if_needed(selected_services) - - # HTTPS Configuration (for services that need it) - https_enabled = False - server_ip = None - - # Check if we have services that benefit from HTTPS - https_services = {'advanced', 'speaker-recognition'} # advanced will always need https then - needs_https = bool(https_services.intersection(selected_services)) - - if needs_https: - console.print("\nπŸ”’ [bold cyan]HTTPS Configuration[/bold cyan]") - console.print("HTTPS enables microphone access in browsers and secure connections") - - try: - https_enabled = Confirm.ask("Enable HTTPS for selected services?", default=False) - except EOFError: - console.print("Using default: No") - https_enabled = False - - if https_enabled: - # Try to auto-detect Tailscale address - ts_dns, ts_ip = detect_tailscale_info() - - if ts_dns: - console.print(f"\n[green][AUTO-DETECTED][/green] Tailscale DNS: {ts_dns}") - if ts_ip: - console.print(f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}") - default_address = ts_dns - elif ts_ip: - console.print(f"\n[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}") - default_address = ts_ip - else: - console.print("\n[blue][INFO][/blue] Tailscale not detected") - console.print("[blue][INFO][/blue] To find your Tailscale address: tailscale status --json | jq -r '.Self.DNSName'") - default_address = None - - console.print("[blue][INFO][/blue] For local-only access, use 'localhost'") - console.print("Examples: localhost, myhost.tail1234.ts.net, 100.64.1.2") - - # Check for existing SERVER_IP from backend .env - backend_env_path = 'backends/advanced/.env' - existing_ip = read_env_value(backend_env_path, 'SERVER_IP') - - # Use existing value, or auto-detected address, or localhost as default - effective_default = default_address or "localhost" - - server_ip = prompt_with_existing_masked( - prompt_text="Server IP/Domain for SSL certificates", - existing_value=existing_ip, - placeholders=['localhost', 'your-server-ip-here'], - is_password=False, - default=effective_default - ) - - console.print(f"[green]βœ…[/green] HTTPS configured for: {server_ip}") + https_enabled, server_ip = setup_https(selected_services) + obsidian_enabled, neo4j_password = setup_obsidian(selected_services) - # Obsidian/Neo4j Integration - obsidian_enabled = False - neo4j_password = None - - # Check if advanced backend is selected - if 'advanced' in selected_services: - console.print("\nπŸ—‚οΈ [bold cyan]Obsidian/Neo4j Integration[/bold cyan]") - console.print("Enable graph-based knowledge management for Obsidian vault notes") - console.print() - - try: - obsidian_enabled = Confirm.ask("Enable Obsidian/Neo4j integration?", default=False) - except EOFError: - console.print("Using default: No") - obsidian_enabled = False - - if obsidian_enabled: - console.print("[blue][INFO][/blue] Neo4j will be configured for graph-based memory storage") - console.print() - - # Prompt for Neo4j password - while True: - try: - neo4j_password = console.input("Neo4j password (min 8 chars) [default: neo4jpassword]: ").strip() - if not neo4j_password: - neo4j_password = "neo4jpassword" - if len(neo4j_password) >= 8: - break - console.print("[yellow][WARNING][/yellow] Password must be at least 8 characters") - except EOFError: - neo4j_password = "neo4jpassword" - console.print(f"Using default password") - break - - console.print("[green]βœ…[/green] Obsidian/Neo4j integration will be configured") - - # Pure Delegation - Run Each Service Setup console.print(f"\nπŸ“‹ [bold]Setting up {len(selected_services)} services...[/bold]") - - # Clean up .env files from unselected services (creates backups) cleanup_unselected_services(selected_services) - + success_count = 0 failed_services = [] for service in selected_services: - if run_service_setup(service, selected_services, https_enabled, server_ip, - obsidian_enabled, neo4j_password, hf_token, transcription_provider): + if run_service_setup( + service, + selected_services, + https_enabled, + server_ip, + obsidian_enabled, + neo4j_password, + ): success_count += 1 else: failed_services.append(service) - # Plugin Configuration (AFTER backend .env is created) - # This ensures plugins can add their secrets to the existing .env file - # without the backend init overwriting them - setup_plugins() - - # Check for Obsidian/Neo4j configuration (read from config.yml) - obsidian_enabled = False - if 'advanced' in selected_services and 'advanced' not in failed_services: - config_yml_path = Path('config/config.yml') + # Check for Obsidian configuration via config.yml for final messaging + config_obsidian_enabled = False + if "advanced" in selected_services and "advanced" not in failed_services: + config_yml_path = Path("config/config.yml") if config_yml_path.exists(): try: - with open(config_yml_path, 'r') as f: + with open(config_yml_path, "r") as f: config_data = yaml.safe_load(f) - obsidian_config = config_data.get('memory', {}).get('obsidian', {}) - obsidian_enabled = obsidian_config.get('enabled', False) + obsidian_config = config_data.get("memory", {}).get("obsidian", {}) + config_obsidian_enabled = obsidian_config.get("enabled", False) except Exception as e: console.print(f"[yellow]Warning: Could not read config.yml: {e}[/yellow]") - # Final Summary console.print(f"\n🎊 [bold green]Setup Complete![/bold green]") - console.print(f"βœ… {success_count}/{len(selected_services)} services configured successfully") + console.print( + f"βœ… {success_count}/{len(selected_services)} services configured successfully" + ) if failed_services: console.print(f"❌ Failed services: {', '.join(failed_services)}") - # Inform about Obsidian/Neo4j if configured - if obsidian_enabled: + if config_obsidian_enabled or obsidian_enabled: console.print(f"\nπŸ“š [bold cyan]Obsidian Integration Detected[/bold cyan]") - console.print(" Neo4j will be automatically started with the 'obsidian' profile") + console.print( + " Neo4j will be automatically started with the 'obsidian' profile" + ) console.print(" when you start the backend service.") - - # Next Steps - console.print("\nπŸ“– [bold]Next Steps:[/bold]") - # Configuration info + # Next Steps messaging + console.print("\nπŸ“– [bold]Next Steps:[/bold]") console.print("") console.print("πŸ“ [bold cyan]Configuration Files Updated:[/bold cyan]") console.print(" β€’ [green].env files[/green] - API keys and service URLs") - console.print(" β€’ [green]config.yml[/green] - Model definitions and memory provider settings") + console.print( + " β€’ [green]config.yml[/green] - Model definitions and memory provider settings" + ) console.print("") - - # Development Environment Setup console.print("1. Setup development environment (git hooks, testing):") console.print(" [cyan]make setup-dev[/cyan]") - console.print(" [dim]This installs pre-commit hooks to run tests before pushing[/dim]") console.print("") - - # Service Management Commands console.print("2. Start all configured services:") - console.print(" [cyan]uv run --with-requirements setup-requirements.txt python services.py start --all --build[/cyan]") - console.print("") - console.print("3. Or start individual services:") - - configured_services = [] - if 'advanced' in selected_services and 'advanced' not in failed_services: - configured_services.append("backend") - if 'speaker-recognition' in selected_services and 'speaker-recognition' not in failed_services: - configured_services.append("speaker-recognition") - if 'asr-services' in selected_services and 'asr-services' not in failed_services: - configured_services.append("asr-services") - if 'openmemory-mcp' in selected_services and 'openmemory-mcp' not in failed_services: - configured_services.append("openmemory-mcp") - - if configured_services: - service_list = " ".join(configured_services) - console.print(f" [cyan]uv run --with-requirements setup-requirements.txt python services.py start {service_list}[/cyan]") - - console.print("") - console.print("3. Check service status:") - console.print(" [cyan]uv run --with-requirements setup-requirements.txt python services.py status[/cyan]") - + console.print( + " [cyan]uv run --with-requirements setup-requirements.txt python services.py start --all --build[/cyan]" + ) console.print("") - console.print("4. Stop services when done:") - console.print(" [cyan]uv run --with-requirements setup-requirements.txt python services.py stop --all[/cyan]") - - console.print(f"\nπŸš€ [bold]Enjoy Chronicle![/bold]") - - # Show individual service usage - console.print(f"\nπŸ’‘ [dim]Tip: You can also setup services individually:[/dim]") - console.print(f"[dim] cd backends/advanced && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]") - console.print(f"[dim] cd extras/speaker-recognition && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]") - console.print(f"[dim] cd extras/asr-services && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]") + if __name__ == "__main__": - main() \ No newline at end of file + main()