diff --git a/Docs/whisper-adapter-finetuning/README.md b/Docs/whisper-adapter-finetuning/README.md new file mode 100644 index 00000000..eb266358 --- /dev/null +++ b/Docs/whisper-adapter-finetuning/README.md @@ -0,0 +1,202 @@ +# Whisper Sneeze Adapter Training + +This project fine-tunes OpenAI's Whisper model to transcribe sneezes in audio/video content using LoRA adapters. The model learns to recognize and transcribe sneezes as the token "SNEEZE" in transcriptions. + +## Prerequisites + +- Python 3.10+ +- CUDA-capable GPU (recommended for training) +- Access to Google Gemini API (for generating transcripts) + +## Installation + +1. Create a virtual environment: +```bash +python -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +``` + +2. Install dependencies: +```bash +pip install torch torchaudio +pip install transformers datasets evaluate +pip install unsloth[colab-new] +pip install librosa soundfile jiwer +pip install tqdm +``` + +## Workflow + +### Step 1: Prepare Your Video + +1. Record or obtain a video file containing sneezes (e.g., `girls_sneezing.mp4` download with + ``` + yt-dlp -f "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best" --merge-output-format mp4 -o "girls_sneezing.mp4" https://youtu.be/36b4248j5UE + ``` + +### Step 2: Generate Transcript with Gemini + +1. Upload your video to Google Gemini (or use Gemini API) +2. Request a transcript with sneezes marked using the format: `` +3. Generate a JSONL file named `sneeze_data.jsonl` with the following format: + +```jsonl +{"start": 0.0, "end": 5.0, "text": "Ugh, I really need to sneeze. Stuck? Yeah, it's right there."} +{"start": 5.0, "end": 11.0, "text": "Close one. Bless you. Thanks."} +{"start": 12.0, "end": 17.0, "text": "Ugh, I can feel it. I really need to sneeze so bad. Go on, let it out."} +``` + +**Format requirements:** +- Each line is a JSON object +- `start`: Start time in seconds (float) +- `end`: End time in seconds (float) +- `text`: Transcription text with sneezes marked as `` + +**Example Gemini prompt:** +``` +Please transcribe this video and create a JSONL file where each line contains: +- start: start time in seconds +- end: end time in seconds +- text: the transcription with sneezes marked as + +Format as JSONL (one JSON object per line). +``` + +### Step 3: Prepare Training Data + +Run the data preparation script to extract audio chunks and create train/test splits: + +```bash +python prepare_sneeze_data.py +``` + +This script will: +- Extract audio from your video file (`girls_sneezing.mp4`) +- Create audio chunks from the segments in `sneeze_data.jsonl` +- Save chunks to `sneeze_chunks/` directory +- Split data into `train.jsonl` (60%) and `test.jsonl` (40%) + +**Requirements:** +- `sneeze_data.jsonl` must exist in the project root +- Video file must be named `girls_sneezing.mp4` + +### Step 4: Train the Model + +Train the Whisper model with LoRA adapters: + +```bash +python train_sneeze.py +``` + +This will: +- Load the base Whisper Large v3 model +- Apply LoRA adapters (only trains 1-10% of parameters) +- Fine-tune on your sneeze data +- Save the adapter to `sneeze_lora_adapter_unsloth/` + +**Training configuration:** +- Model: `unsloth/whisper-large-v3` +- LoRA rank: 64 +- Batch size: 1 (with gradient accumulation: 4) +- Max steps: 200 +- Learning rate: 1e-4 + +**Note:** Training requires a GPU with sufficient VRAM. Adjust `load_in_4bit=True` in the script if you have limited memory. + +### Step 5: Evaluate the Model + +Evaluate the trained model on the test set: + +```bash +python evaluate_sneeze_model.py +``` + +This will: +- Load the base model and merge the LoRA adapter +- Run inference on test samples +- Calculate Word Error Rate (WER) +- Report sneeze detection recall and false positives + +## Results + +### Training Results + +Training was performed on a Tesla T4 GPU with the following configuration: +- **Model**: `unsloth/whisper-large-v3` +- **Trainable Parameters**: 31,457,280 of 1,574,947,840 (2.00%) +- **Training Time**: 12.04 minutes +- **Peak Memory Usage**: 8.896 GB (60.35% of max memory) +- **Training Samples**: 49 samples +- **Test Samples**: 4 samples + +**Training Loss Progression:** +| Step | Training Loss | Validation Loss | WER | +|------|---------------|-----------------|-----| +| 20 | 1.646100 | 1.869532 | 50.0% | +| 40 | 0.832500 | 1.004385 | 30.0% | +| 60 | 0.304600 | 0.354044 | 30.0% | +| 80 | 0.067700 | 0.051606 | 0.0% | +| 100 | 0.017600 | 0.162433 | 10.0% | +| 120 | 0.003400 | 0.006127 | 0.0% | +| 140 | 0.002000 | 0.004151 | 0.0% | +| 160 | 0.001400 | 0.003399 | 0.0% | +| 180 | 0.001300 | 0.003005 | 0.0% | +| 200 | 0.001000 | 0.002856 | 0.0% | + +**Final Metrics:** +- Final Training Loss: 0.001000 +- Final Validation Loss: 0.002856 +- Final Validation WER: 0.0% + +### Evaluation Results + +Evaluation was performed on 10 test samples (4 containing sneezes): + +**Overall Performance:** +- **Word Error Rate (WER)**: 0.3217 (32.17%) +- **Sneeze Recall**: 2/4 (50.0%) +- **False Positives**: 0 + +**Missed Sneezes:** +1. Reference: "Take your time, it'll come. SNEEZE Oh wow. Excuse me." + Prediction: "Take your time. It'll come. Oh, wow." + +2. Reference: "It's right there but... False alarm? No, it's stuck. SNEEZE Bless you." + Prediction: "It's right there, but... False alarm? No! It stopped..." + +**Analysis:** +- The model achieved perfect WER (0.0%) on the validation set during training, indicating good generalization on the training distribution. +- On the test set, the model achieved 50% sneeze recall, successfully detecting 2 out of 4 sneezes. +- No false positives were detected, showing the model is conservative in its sneeze predictions. +- The 32.17% WER on the test set suggests room for improvement, particularly in detecting sneezes in more varied contexts. + +## Project Structure + +``` +whisper-adapter-test/ +├── prepare_sneeze_data.py # Data preparation script +├── improved_sneeze_trainer.py # Training script +├── evaluate_sneeze_model.py # Evaluation script +├── sneeze_data.jsonl # Input transcript with sneezes +├── train.jsonl # Training manifest +├── test.jsonl # Test manifest +├── sneeze_chunks/ # Extracted audio chunks +└── sneeze_lora_adapter_unsloth/ # Trained adapter (created after training) +``` + +## Output Files + +- `train.jsonl`: Training dataset manifest +- `test.jsonl`: Test dataset manifest +- `sneeze_chunks/`: Directory with extracted audio chunks +- `sneeze_lora_adapter_unsloth/`: Trained LoRA adapter weights + +## Notes + +- The model replaces `` tags with `SNEEZE` during training +- LoRA adapters are memory-efficient and only update a small portion of model weights +- The evaluation script merges the adapter into the base model for inference + +## Conclusion + +Despite training on only 13 examples and evaluating on 10 test samples, the model achieved significant progress in sneeze detection. With just this small dataset, we were able to fine-tune the Whisper model to recognize and transcribe sneezes with 50% recall and zero false positives. This demonstrates the effectiveness of LoRA adapters for efficient fine-tuning on specialized tasks with limited data. diff --git a/Docs/whisper-adapter-finetuning/evaluate_sneeze_model.py b/Docs/whisper-adapter-finetuning/evaluate_sneeze_model.py new file mode 100644 index 00000000..08592bc5 --- /dev/null +++ b/Docs/whisper-adapter-finetuning/evaluate_sneeze_model.py @@ -0,0 +1,115 @@ +import os +import json +import torch +import librosa +import jiwer +from transformers import WhisperProcessor, WhisperForConditionalGeneration +from peft import PeftModel +from tqdm import tqdm + +# --- CONFIGURATION (MUST MATCH YOUR TRAINING) --- +BASE_MODEL_ID = "openai/whisper-large-v3" +ADAPTER_PATH = "sneeze_lora_adapter_unsloth" # The folder Unsloth created +TEST_MANIFEST = "test.jsonl" + +def main(): + # 1. Setup Device + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + # 2. Load Base Model (Large v3) + print(f"Loading base model: {BASE_MODEL_ID}") + processor = WhisperProcessor.from_pretrained(BASE_MODEL_ID) + model = WhisperForConditionalGeneration.from_pretrained( + BASE_MODEL_ID, + torch_dtype=torch.float16 if device == "cuda" else torch.float32 + ) + + # 3. Load and MERGE Adapter + if os.path.exists(ADAPTER_PATH): + print(f"Loading LoRA adapter from: {ADAPTER_PATH}") + model = PeftModel.from_pretrained(model, ADAPTER_PATH) + print("Merging LoRA weights...") + model = model.merge_and_unload() + else: + print(f"❌ ERROR: Adapter {ADAPTER_PATH} not found!") + return + + model.to(device) + model.eval() + + # 4. Run Evaluation + evaluate_dataset(model, processor, device, TEST_MANIFEST) + +def evaluate_dataset(model, processor, device, manifest_path): + if not os.path.exists(manifest_path): + print(f"Manifest {manifest_path} not found.") + return + + samples = [] + with open(manifest_path, 'r') as f: + for line in f: + samples.append(json.loads(line)) + + print(f"Testing on {len(samples)} samples...") + + predictions = [] + references = [] + sneeze_stats = {"total": 0, "detected": 0, "fp": 0} + + for sample in tqdm(samples): + path = sample['audio'] + ref_text = sample['text'].replace("", "SNEEZE") + + try: + audio, _ = librosa.load(path, sr=16000) + except: continue + + # Process audio + inputs = processor(audio, sampling_rate=16000, return_tensors="pt") + input_features = inputs.input_features.to(device) + + # Handle the dtype for half precision (if on GPU) + if device == "cuda": + input_features = input_features.half() + + # Generate + with torch.no_grad(): + generated_ids = model.generate( + input_features=input_features, # Use input_features, not inputs + language="en", + task="transcribe", + max_new_tokens=256 + ) + + pred = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + + predictions.append(pred) + references.append(ref_text) + + # Stats + has_sneeze_ref = "SNEEZE" in ref_text + has_sneeze_pred = "SNEEZE" in pred + + if has_sneeze_ref: + sneeze_stats["total"] += 1 + if has_sneeze_pred: + sneeze_stats["detected"] += 1 + else: + print(f"\n❌ MISSED SNEEZE\nRef: {ref_text}\nPrd: {pred}") + elif has_sneeze_pred: + sneeze_stats["fp"] += 1 + print(f"\n⚠️ FALSE POSITIVE\nRef: {ref_text}\nPrd: {pred}") + + # Results + wer = jiwer.wer(references, predictions) + print("\n" + "="*40) + print(f"Word Error Rate: {wer:.4f}") + if sneeze_stats["total"] > 0: + recall = (sneeze_stats["detected"] / sneeze_stats["total"]) * 100 + print(f"Sneeze Recall: {sneeze_stats['detected']}/{sneeze_stats['total']} ({recall:.1f}%)") + print(f"False Positives: {sneeze_stats['fp']}") + print("="*40) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Docs/whisper-adapter-finetuning/prepare_sneeze_data.py b/Docs/whisper-adapter-finetuning/prepare_sneeze_data.py new file mode 100644 index 00000000..cdc764aa --- /dev/null +++ b/Docs/whisper-adapter-finetuning/prepare_sneeze_data.py @@ -0,0 +1,80 @@ +import json +import os +import random +import librosa +import soundfile as sf +import numpy as np + +def prepare_data(): + jsonl_path = "sneeze_data.jsonl" + video_path = "girls_sneezing.mp4" + output_dir = "sneeze_chunks" + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Load full audio + print(f"Loading {video_path}...") + try: + y, sr = librosa.load(video_path, sr=16000) + except Exception as e: + print(f"Error loading video: {e}") + return + + segments = [] + with open(jsonl_path, 'r') as f: + for line in f: + if line.strip(): + segments.append(json.loads(line)) + + print(f"Found {len(segments)} segments.") + + dataset_entries = [] + + for i, seg in enumerate(segments): + start_time = seg['start'] + end_time = seg['end'] + text = seg['text'] + + # Calculate sample indices + start_sample = int(start_time * sr) + end_sample = int(end_time * sr) + + # Extract audio + chunk = y[start_sample:end_sample] + + # Save to file + chunk_filename = f"chunk_{i:03d}.wav" + chunk_path = os.path.join(output_dir, chunk_filename) + sf.write(chunk_path, chunk, sr) + + dataset_entries.append({ + "audio": chunk_path, + "text": text + }) + print(f"Saved {chunk_filename}: {text[:30]}...") + + # Shuffle and Split + random.seed(42) + random.shuffle(dataset_entries) + + split_idx = int(len(dataset_entries) * 0.6) + train_data = dataset_entries[:split_idx] + test_data = dataset_entries[split_idx:] + + print(f"Training samples: {len(train_data)}") + print(f"Testing samples: {len(test_data)}") + + # Save split manifests + with open("train.jsonl", "w") as f: + for entry in train_data: + f.write(json.dumps(entry) + "\n") + + with open("test.jsonl", "w") as f: + for entry in test_data: + f.write(json.dumps(entry) + "\n") + + print("Data preparation complete.") + +if __name__ == "__main__": + prepare_data() diff --git a/Docs/whisper-adapter-finetuning/sneeze_data.jsonl b/Docs/whisper-adapter-finetuning/sneeze_data.jsonl new file mode 100644 index 00000000..418ee305 --- /dev/null +++ b/Docs/whisper-adapter-finetuning/sneeze_data.jsonl @@ -0,0 +1,23 @@ +{"start": 0.0, "end": 5.0, "text": "Ugh, I really need to sneeze. Stuck? Yeah, it's right there."} +{"start": 5.0, "end": 11.0, "text": "Close one. Bless you. Thanks."} +{"start": 12.0, "end": 17.0, "text": "Ugh, I can feel it. I really need to sneeze so bad. Go on, let it out."} +{"start": 17.0, "end": 23.0, "text": "It's right there but... False alarm? No, it's stuck. Bless you."} +{"start": 24.0, "end": 29.0, "text": "Ugh, my nose. I... I really need to sneeze so bad. Do it then."} +{"start": 29.0, "end": 36.0, "text": "No, it's not coming. That's the worst. Stuck. "} +{"start": 36.0, "end": 42.0, "text": "Ugh, my nose, I want to sneeze so bad. You okay? Is it stuck?"} +{"start": 42.0, "end": 48.0, "text": "Nope, nothing. Teasing you. "} +{"start": 48.0, "end": 54.0, "text": "Ugh, my nose. I need to sneeze so bad. Go on, let it out."} +{"start": 54.0, "end": 60.0, "text": "Oh, it's stuck. It's teasing you. Bless you."} +{"start": 60.0, "end": 66.0, "text": "Ugh, I really... I really need to sneeze so bad. Go on, just let it out."} +{"start": 66.0, "end": 72.0, "text": "Ugh, it's stuck. Oh come on. "} +{"start": 72.0, "end": 78.0, "text": "Ugh, I really need to sneeze so bad. Go on, let it out. It's just stuck."} +{"start": 78.0, "end": 84.0, "text": "I can feel it right there. Oh finally."} +{"start": 84.0, "end": 90.0, "text": "Ugh, my nose is so itchy. I need to sneeze so badly. Do it. Let it out."} +{"start": 90.0, "end": 96.0, "text": "No. It's... it's stuck. Almost? "} +{"start": 96.0, "end": 102.0, "text": "Ugh, I want to sneeze so bad. It's right there, just not coming out."} +{"start": 102.0, "end": 108.0, "text": "No. Still stuck! "} +{"start": 108.0, "end": 114.0, "text": "Ugh, I really... I really need to sneeze so bad. Here it comes?"} +{"start": 114.0, "end": 120.0, "text": "Nope, it's uh, stuck. Why won't it come out? "} +{"start": 120.0, "end": 126.0, "text": "Ugh, I swear I need to sneeze so badly. Ugh, nope, still there."} +{"start": 126.0, "end": 131.0, "text": "Take your time, it'll come. Oh wow. Excuse me."} +{"start": 132.0, "end": 140.0, "text": "Don't forget to check out Patreon dot com slash AI sneeze for exclusive sneezing content and early access. Try it for free! "} diff --git a/Docs/whisper-adapter-finetuning/train_sneeze.py b/Docs/whisper-adapter-finetuning/train_sneeze.py new file mode 100644 index 00000000..ae72b596 --- /dev/null +++ b/Docs/whisper-adapter-finetuning/train_sneeze.py @@ -0,0 +1,274 @@ +import os +import json +import torch +import librosa +import numpy as np +import tqdm +import evaluate +from dataclasses import dataclass +from typing import Any, Dict, List, Union + +from unsloth import FastModel, is_bf16_supported +from transformers import WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer +from datasets import Dataset + +# Configuration +MODEL_ID = "unsloth/whisper-large-v3" +OUTPUT_DIR = "sneeze_lora_adapter_unsloth" +TRAIN_MANIFEST = "train.jsonl" + +def prepare_dataset(): + """Load training data from jsonl""" + audio_paths = [] + texts = [] + + # Check if file exists + if not os.path.exists(TRAIN_MANIFEST): + raise FileNotFoundError(f"{TRAIN_MANIFEST} not found.") + + with open(TRAIN_MANIFEST, 'r') as f: + for line in f: + try: + entry = json.loads(line) + # Normalize text + text = entry['text'].replace("", "SNEEZE") + + audio_paths.append(entry['audio']) + texts.append(text) + + # Simple oversampling for target keyword + if "SNEEZE" in text: + for _ in range(5): + audio_paths.append(entry['audio']) + texts.append(text) + except Exception as e: + print(f"Skipping bad line: {e}") + + print(f"Loaded {len(audio_paths)} training samples.") + + data = { + "audio_path": audio_paths, + "text": texts + } + + return Dataset.from_dict(data) + + +def main(): + print(f"Loading model with Unsloth: {MODEL_ID}") + + # Load model using Unsloth's FastModel + model, tokenizer = FastModel.from_pretrained( + model_name=MODEL_ID, + dtype=None, # Auto detection + load_in_4bit=False, # Set to True for 4bit quantization (lower memory) + auto_model=WhisperForConditionalGeneration, + whisper_language="English", + whisper_task="transcribe", + # token = "hf_...", # Use if needed for gated models + ) + + # Apply LoRA adapters using Unsloth (only updates 1-10% of parameters) + model = FastModel.get_peft_model( + model, + r=64, # Suggested: 8, 16, 32, 64, 128 + target_modules=["q_proj", "v_proj"], + lora_alpha=64, + lora_dropout=0, # 0 is optimized + bias="none", # "none" is optimized + use_gradient_checkpointing="unsloth", # 30% less VRAM, fits 2x larger batch sizes + random_state=3407, + use_rslora=False, + loftq_config=None, + task_type=None, # MUST be None for Whisper + ) + + # Configure generation settings + model.generation_config.language = "<|en|>" + model.generation_config.task = "transcribe" + model.config.suppress_tokens = [] + model.generation_config.forced_decoder_ids = None + + # Load dataset + dataset = prepare_dataset() + + def formatting_prompts_func(example): + """Process audio and text for training""" + try: + # Load audio file + audio_array, sr = librosa.load(example['audio_path'], sr=16000) + except Exception as e: + print(f"Error loading {example['audio_path']}: {e}") + return None + + # Extract features using tokenizer's feature extractor + features = tokenizer.feature_extractor( + audio_array, sampling_rate=16000 + ) + + # Tokenize text + tokenized_text = tokenizer.tokenizer(example["text"]) + + return { + "input_features": features.input_features[0], + "labels": tokenized_text.input_ids, + } + + print("Processing dataset...") + train_data = [] + for example in tqdm.tqdm(dataset, desc='Processing audio'): + result = formatting_prompts_func(example) + if result is not None: + train_data.append(result) + + print(f"Successfully processed {len(train_data)} samples") + + # Split into train/test + split_idx = max(1, int(len(train_data) * 0.94)) + train_dataset = train_data[:split_idx] + test_dataset = train_data[split_idx:] + + print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}") + + # Setup WER metric for evaluation + metric = evaluate.load("wer") + + def compute_metrics(pred): + pred_logits = pred.predictions[0] + label_ids = pred.label_ids + + # Replace -100 with pad_token_id + label_ids[label_ids == -100] = tokenizer.pad_token_id + + pred_ids = np.argmax(pred_logits, axis=-1) + + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True) + + wer = 100 * metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + @dataclass + class DataCollatorSpeechSeq2SeqWithPadding: + processor: Any + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + input_features = [{"input_features": feature["input_features"]} for feature in features] + batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") + + label_features = [{"input_ids": feature["labels"]} for feature in features] + labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") + + labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + + return batch + + data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=tokenizer) + + # Show memory stats before training + if torch.cuda.is_available(): + gpu_stats = torch.cuda.get_device_properties(0) + start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) + max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) + print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.") + print(f"{start_gpu_memory} GB of memory reserved.") + + # Setup trainer with Seq2SeqTrainer + trainer = Seq2SeqTrainer( + model=model, + train_dataset=train_dataset, + data_collator=data_collator, + eval_dataset=test_dataset if len(test_dataset) > 0 else None, + tokenizer=tokenizer.feature_extractor, + compute_metrics=compute_metrics, + args=Seq2SeqTrainingArguments( + # predict_with_generate=True, + per_device_train_batch_size=1, + gradient_accumulation_steps=4, + warmup_steps=5, + # num_train_epochs=1, # Set for full training run + max_steps=200, + learning_rate=1e-4, + logging_steps=10, + optim="adamw_8bit", + fp16=not is_bf16_supported(), + bf16=is_bf16_supported(), + weight_decay=0.001, + remove_unused_columns=False, # Required for PEFT + lr_scheduler_type="linear", + label_names=['labels'], + eval_steps=20, + eval_strategy="steps" if len(test_dataset) > 0 else "no", + seed=3407, + output_dir=OUTPUT_DIR, + report_to="none", + ), + ) + + print("Starting training...") + trainer_stats = trainer.train() + + # Show final memory stats + if torch.cuda.is_available(): + used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) + used_memory_for_lora = round(used_memory - start_gpu_memory, 3) + used_percentage = round(used_memory / max_memory * 100, 3) + lora_percentage = round(used_memory_for_lora / max_memory * 100, 3) + print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.") + print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.") + print(f"Peak reserved memory = {used_memory} GB.") + print(f"Peak reserved memory for training = {used_memory_for_lora} GB.") + print(f"Peak reserved memory % of max memory = {used_percentage} %.") + print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.") + + # Save the model + print(f"Saving adapter to {OUTPUT_DIR}") + model.save_pretrained(OUTPUT_DIR) + tokenizer.save_pretrained(OUTPUT_DIR) + + print("Training complete!") + + +def run_inference(audio_file: str, model_path: str = OUTPUT_DIR): + """Run inference with the trained model""" + from transformers import pipeline + + print(f"Loading model from {model_path}") + + # Load the fine-tuned model + model, tokenizer = FastModel.from_pretrained( + model_name=model_path, + dtype=None, + load_in_4bit=False, + auto_model=WhisperForConditionalGeneration, + ) + + # Set model to inference mode + FastModel.for_inference(model) + model.eval() + + # Create pipeline + whisper = pipeline( + "automatic-speech-recognition", + model=model, + tokenizer=tokenizer.tokenizer, + feature_extractor=tokenizer.feature_extractor, + processor=tokenizer, + return_language=True, + torch_dtype=torch.float16, + ) + + # Transcribe + result = whisper(audio_file) + print(f"Transcription: {result['text']}") + return result + + +if __name__ == "__main__": + main() diff --git a/backends/advanced/docs/plugin-development-guide.md b/backends/advanced/docs/plugin-development-guide.md new file mode 100644 index 00000000..17c53b4a --- /dev/null +++ b/backends/advanced/docs/plugin-development-guide.md @@ -0,0 +1,776 @@ +# Chronicle Plugin Development Guide + +A comprehensive guide to creating custom plugins for Chronicle. + +## Table of Contents + +1. [Introduction](#introduction) +2. [Quick Start](#quick-start) +3. [Plugin Architecture](#plugin-architecture) +4. [Event Types](#event-types) +5. [Creating Your First Plugin](#creating-your-first-plugin) +6. [Configuration](#configuration) +7. [Testing Plugins](#testing-plugins) +8. [Best Practices](#best-practices) +9. [Examples](#examples) +10. [Troubleshooting](#troubleshooting) + +## Introduction + +Chronicle's plugin system allows you to extend functionality by subscribing to events and executing custom logic. Plugins are: + +- **Event-driven**: React to transcripts, conversations, or memory processing +- **Auto-discovered**: Drop plugins into the `plugins/` directory +- **Configurable**: YAML-based configuration with environment variable support +- **Isolated**: Each plugin runs independently with proper error handling + +### Plugin Types + +- **Core Plugins**: Built-in plugins (`homeassistant`, `test_event`) +- **Community Plugins**: Auto-discovered plugins in `plugins/` directory + +## Quick Start + +### 1. Generate Plugin Boilerplate + +```bash +cd backends/advanced +uv run python scripts/create_plugin.py my_awesome_plugin +``` + +This creates: +``` +plugins/my_awesome_plugin/ +├── __init__.py # Plugin exports +├── plugin.py # Main plugin logic +└── README.md # Plugin documentation +``` + +### 2. Implement Plugin Logic + +Edit `plugins/my_awesome_plugin/plugin.py`: + +```python +async def on_conversation_complete(self, context: PluginContext) -> Optional[PluginResult]: + """Handle conversation completion.""" + transcript = context.data.get('transcript', '') + + # Your custom logic here + print(f"Processing: {transcript}") + + return PluginResult(success=True, message="Processing complete") +``` + +### 3. Configure Plugin + +Add to `config/plugins.yml`: + +```yaml +plugins: + my_awesome_plugin: + enabled: true + events: + - conversation.complete + condition: + type: always +``` + +### 4. Restart Backend + +```bash +cd backends/advanced +docker compose restart +``` + +Your plugin will be auto-discovered and loaded! + +## Plugin Architecture + +### Base Plugin Class + +All plugins inherit from `BasePlugin`: + +```python +from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult + +class MyPlugin(BasePlugin): + SUPPORTED_ACCESS_LEVELS = ['conversation'] # Which events you support + + async def initialize(self): + """Initialize resources (called on app startup)""" + pass + + async def cleanup(self): + """Clean up resources (called on app shutdown)""" + pass + + async def on_conversation_complete(self, context: PluginContext): + """Handle conversation.complete events""" + pass +``` + +### Plugin Context + +Context passed to plugin methods: + +```python +@dataclass +class PluginContext: + user_id: str # User identifier + event: str # Event name (e.g., "conversation.complete") + data: Dict[str, Any] # Event-specific data + metadata: Dict[str, Any] # Additional metadata +``` + +### Plugin Result + +Return value from plugin methods: + +```python +@dataclass +class PluginResult: + success: bool # Whether operation succeeded + data: Optional[Dict[str, Any]] # Optional result data + message: Optional[str] # Optional status message + should_continue: bool # Whether to continue normal processing (default: True) +``` + +## Event Types + +### 1. Transcript Events (`transcript.streaming`) + +**When**: Real-time transcript segments arrive from WebSocket +**Context Data**: +- `transcript` (str): The transcript text +- `segment_id` (str): Unique segment identifier +- `conversation_id` (str): Current conversation ID + +**Use Cases**: +- Wake word detection +- Real-time command processing +- Live transcript analysis + +**Example**: +```python +async def on_transcript(self, context: PluginContext): + transcript = context.data.get('transcript', '') + if 'urgent' in transcript.lower(): + await self.send_notification(transcript) +``` + +### 2. Conversation Events (`conversation.complete`) + +**When**: Conversation processing finishes +**Context Data**: +- `conversation` (dict): Full conversation data +- `transcript` (str): Complete transcript +- `duration` (float): Conversation duration in seconds +- `conversation_id` (str): Conversation identifier + +**Use Cases**: +- Email summaries +- Analytics tracking +- External integrations +- Conversation archiving + +**Example**: +```python +async def on_conversation_complete(self, context: PluginContext): + conversation = context.data.get('conversation', {}) + duration = context.data.get('duration', 0) + + if duration > 300: # 5 minutes + await self.archive_long_conversation(conversation) +``` + +### 3. Memory Events (`memory.processed`) + +**When**: Memory extraction finishes +**Context Data**: +- `memories` (list): Extracted memories +- `conversation` (dict): Source conversation +- `memory_count` (int): Number of memories created +- `conversation_id` (str): Conversation identifier + +**Use Cases**: +- Memory indexing +- Knowledge graph updates +- Memory notifications +- Analytics + +**Example**: +```python +async def on_memory_processed(self, context: PluginContext): + memories = context.data.get('memories', []) + + for memory in memories: + await self.index_memory(memory) +``` + +## Creating Your First Plugin + +### Step 1: Generate Boilerplate + +```bash +uv run python scripts/create_plugin.py todo_extractor +``` + +### Step 2: Define Plugin Logic + +```python +""" +Todo Extractor Plugin - Extracts action items from conversations. +""" +import logging +import re +from typing import Any, Dict, List, Optional + +from ..base import BasePlugin, PluginContext, PluginResult + +logger = logging.getLogger(__name__) + + +class TodoExtractorPlugin(BasePlugin): + """Extract and save action items from conversations.""" + + SUPPORTED_ACCESS_LEVELS = ['conversation'] + + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + self.todo_patterns = [ + r'I need to (.+)', + r'I should (.+)', + r'TODO: (.+)', + r'reminder to (.+)', + ] + + async def initialize(self): + if not self.enabled: + return + + logger.info("TodoExtractor plugin initialized") + + async def on_conversation_complete(self, context: PluginContext): + try: + transcript = context.data.get('transcript', '') + todos = self._extract_todos(transcript) + + if todos: + await self._save_todos(context.user_id, todos) + + return PluginResult( + success=True, + message=f"Extracted {len(todos)} action items", + data={'todos': todos} + ) + + return PluginResult(success=True, message="No action items found") + + except Exception as e: + logger.error(f"Error extracting todos: {e}") + return PluginResult(success=False, message=str(e)) + + def _extract_todos(self, transcript: str) -> List[str]: + """Extract todo items from transcript.""" + todos = [] + + for pattern in self.todo_patterns: + matches = re.findall(pattern, transcript, re.IGNORECASE) + todos.extend(matches) + + return list(set(todos)) # Remove duplicates + + async def _save_todos(self, user_id: str, todos: List[str]): + """Save todos to database or external service.""" + from advanced_omi_backend.database import get_database + + db = get_database() + for todo in todos: + await db['todos'].insert_one({ + 'user_id': user_id, + 'task': todo, + 'completed': False, + 'created_at': datetime.utcnow() + }) +``` + +### Step 3: Configure Plugin + +`config/plugins.yml`: + +```yaml +plugins: + todo_extractor: + enabled: true + events: + - conversation.complete + condition: + type: always +``` + +### Step 4: Test Plugin + +1. Restart backend: `docker compose restart` +2. Create a conversation with phrases like "I need to buy milk" +3. Check logs: `docker compose logs -f chronicle-backend | grep TodoExtractor` +4. Verify todos in database + +## Configuration + +### YAML Configuration + +`config/plugins.yml`: + +```yaml +plugins: + my_plugin: + # Basic Configuration + enabled: true # Enable/disable plugin + + # Event Subscriptions + events: + - conversation.complete + - memory.processed + + # Execution Conditions + condition: + type: always # always, wake_word, regex + # wake_words: ["hey assistant"] # For wake_word type + # pattern: "urgent" # For regex type + + # Custom Configuration + api_url: ${MY_API_URL} # Environment variable + timeout: 30 + max_retries: 3 +``` + +### Environment Variables + +Use `${VAR_NAME}` syntax: + +```yaml +api_key: ${MY_API_KEY} +base_url: ${BASE_URL:-http://localhost:8000} # With default +``` + +Add to `.env`: + +```bash +MY_API_KEY=your-key-here +BASE_URL=https://api.example.com +``` + +### Condition Types + +**Always Execute**: +```yaml +condition: + type: always +``` + +**Wake Word** (transcript events only): +```yaml +condition: + type: wake_word + wake_words: + - hey assistant + - computer +``` + +**Regex Pattern**: +```yaml +condition: + type: regex + pattern: "urgent|important" +``` + +## Testing Plugins + +### Unit Tests + +`tests/test_my_plugin.py`: + +```python +import pytest +from plugins.my_plugin import MyPlugin +from plugins.base import PluginContext + +class TestMyPlugin: + def test_plugin_initialization(self): + config = {'enabled': True, 'events': ['conversation.complete']} + plugin = MyPlugin(config) + assert plugin.enabled is True + + @pytest.mark.asyncio + async def test_conversation_processing(self): + plugin = MyPlugin({'enabled': True}) + await plugin.initialize() + + context = PluginContext( + user_id='test-user', + event='conversation.complete', + data={'transcript': 'Test transcript'} + ) + + result = await plugin.on_conversation_complete(context) + assert result.success is True +``` + +### Integration Testing + +1. **Enable Test Plugin**: +```yaml +test_event: + enabled: true + events: + - conversation.complete +``` + +2. **Check Logs**: +```bash +docker compose logs -f | grep "test_event" +``` + +3. **Upload Test Audio**: +```bash +curl -X POST http://localhost:8000/api/process-audio-files \ + -H "Authorization: Bearer $TOKEN" \ + -F "files=@test.wav" +``` + +### Manual Testing Checklist + +- [ ] Plugin loads without errors +- [ ] Configuration validates correctly +- [ ] Events trigger plugin execution +- [ ] Plugin logic executes successfully +- [ ] Errors are handled gracefully +- [ ] Logs provide useful information + +## Best Practices + +### 1. Error Handling + +Always wrap logic in try-except: + +```python +async def on_conversation_complete(self, context): + try: + # Your logic + result = await self.process(context) + return PluginResult(success=True, data=result) + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) + return PluginResult(success=False, message=str(e)) +``` + +### 2. Logging + +Use appropriate log levels: + +```python +logger.debug("Detailed debug information") +logger.info("Important milestones") +logger.warning("Non-critical issues") +logger.error("Errors that need attention") +``` + +### 3. Resource Management + +Clean up in `cleanup()`: + +```python +async def initialize(self): + self.client = ExternalClient() + await self.client.connect() + +async def cleanup(self): + if self.client: + await self.client.disconnect() +``` + +### 4. Configuration Validation + +Validate in `initialize()`: + +```python +async def initialize(self): + if not self.config.get('api_key'): + raise ValueError("API key is required") + + if self.config.get('timeout', 0) <= 0: + raise ValueError("Timeout must be positive") +``` + +### 5. Async Best Practices + +Use `asyncio.to_thread()` for blocking operations: + +```python +import asyncio + +async def my_method(self): + # Run blocking operation in thread pool + result = await asyncio.to_thread(blocking_function, arg1, arg2) + return result +``` + +### 6. Database Access + +Use the global database handle: + +```python +from advanced_omi_backend.database import get_database + +async def save_data(self, data): + db = get_database() + await db['my_collection'].insert_one(data) +``` + +### 7. LLM Access + +Use the global LLM client: + +```python +from advanced_omi_backend.llm_client import async_generate + +async def generate_summary(self, text): + prompt = f"Summarize: {text}" + summary = await async_generate(prompt) + return summary +``` + +## Examples + +### Example 1: Slack Notifier + +```python +class SlackNotifierPlugin(BasePlugin): + SUPPORTED_ACCESS_LEVELS = ['conversation'] + + async def initialize(self): + self.webhook_url = self.config.get('slack_webhook_url') + if not self.webhook_url: + raise ValueError("Slack webhook URL required") + + async def on_conversation_complete(self, context): + transcript = context.data.get('transcript', '') + duration = context.data.get('duration', 0) + + message = { + "text": f"New conversation ({duration:.1f}s)", + "blocks": [{ + "type": "section", + "text": {"type": "mrkdwn", "text": f"```{transcript[:500]}```"} + }] + } + + async with aiohttp.ClientSession() as session: + await session.post(self.webhook_url, json=message) + + return PluginResult(success=True, message="Notification sent") +``` + +### Example 2: Keyword Alerter + +```python +class KeywordAlerterPlugin(BasePlugin): + SUPPORTED_ACCESS_LEVELS = ['transcript'] + + async def on_transcript(self, context): + transcript = context.data.get('transcript', '') + keywords = self.config.get('keywords', []) + + for keyword in keywords: + if keyword.lower() in transcript.lower(): + await self.send_alert(keyword, transcript) + return PluginResult( + success=True, + message=f"Alert sent for keyword: {keyword}" + ) + + return PluginResult(success=True) +``` + +### Example 3: Analytics Tracker + +```python +class AnalyticsTrackerPlugin(BasePlugin): + SUPPORTED_ACCESS_LEVELS = ['conversation', 'memory'] + + async def on_conversation_complete(self, context): + duration = context.data.get('duration', 0) + word_count = len(context.data.get('transcript', '').split()) + + await self.track_event('conversation_complete', { + 'user_id': context.user_id, + 'duration': duration, + 'word_count': word_count, + }) + + return PluginResult(success=True) + + async def on_memory_processed(self, context): + memory_count = context.data.get('memory_count', 0) + + await self.track_event('memory_processed', { + 'user_id': context.user_id, + 'memory_count': memory_count, + }) + + return PluginResult(success=True) +``` + +## Troubleshooting + +### Plugin Not Loading + +**Check logs**: +```bash +docker compose logs chronicle-backend | grep "plugin" +``` + +**Common issues**: +- Plugin directory name doesn't match class name convention +- Missing `__init__.py` or incorrect exports +- Syntax errors in plugin.py +- Not inheriting from `BasePlugin` + +**Solution**: +1. Verify directory structure matches: `plugins/my_plugin/` +2. Class name should be: `MyPluginPlugin` +3. Export in `__init__.py`: `from .plugin import MyPluginPlugin` + +### Plugin Enabled But Not Executing + +**Check**: +- Plugin enabled in `plugins.yml` +- Correct events subscribed +- Condition matches (wake_word, regex, etc.) + +**Debug**: +```python +async def on_conversation_complete(self, context): + logger.info(f"Plugin executed! Context: {context}") + # Your logic +``` + +### Configuration Errors + +**Error**: `Environment variable not found` + +**Solution**: +- Add variable to `.env` file +- Use default values: `${VAR:-default}` +- Check variable name spelling + +### Import Errors + +**Error**: `ModuleNotFoundError` + +**Solution**: +- Restart backend after adding dependencies +- Verify imports are from correct modules +- Check relative imports use `..base` for base classes + +### Database Connection Issues + +**Error**: `Database connection failed` + +**Solution**: +```python +from advanced_omi_backend.database import get_database + +async def my_method(self): + db = get_database() # Global database handle + # Use db... +``` + +## Advanced Topics + +### Custom Conditions + +Implement custom condition checking: + +```python +async def on_conversation_complete(self, context): + # Custom condition check + if not self._should_execute(context): + return PluginResult(success=True, message="Skipped") + + # Your logic + ... + +def _should_execute(self, context): + # Custom logic + duration = context.data.get('duration', 0) + return duration > 60 # Only process long conversations +``` + +### Plugin Dependencies + +Share data between plugins using context metadata: + +```python +# Plugin A +async def on_conversation_complete(self, context): + context.metadata['extracted_keywords'] = ['important', 'urgent'] + return PluginResult(success=True) + +# Plugin B (executes after Plugin A) +async def on_conversation_complete(self, context): + keywords = context.metadata.get('extracted_keywords', []) + # Use keywords... +``` + +### External Service Integration + +```python +import aiohttp + +class ExternalServicePlugin(BasePlugin): + async def initialize(self): + self.session = aiohttp.ClientSession() + self.api_url = self.config.get('api_url') + self.api_key = self.config.get('api_key') + + async def cleanup(self): + await self.session.close() + + async def on_conversation_complete(self, context): + async with self.session.post( + self.api_url, + headers={'Authorization': f'Bearer {self.api_key}'}, + json={'transcript': context.data.get('transcript')} + ) as response: + result = await response.json() + return PluginResult(success=True, data=result) +``` + +## Resources + +- **Base Plugin Class**: `backends/advanced/src/advanced_omi_backend/plugins/base.py` +- **Example Plugins**: + - Email Summarizer: `plugins/email_summarizer/` + - Home Assistant: `plugins/homeassistant/` + - Test Event: `plugins/test_event/` +- **Plugin Generator**: `scripts/create_plugin.py` +- **Configuration**: `config/plugins.yml.template` + +## Contributing Plugins + +Want to share your plugin with the community? + +1. Create a well-documented plugin +2. Add comprehensive README +3. Include configuration examples +4. Test thoroughly +5. Submit PR to Chronicle repository + +## Support + +- **GitHub Issues**: [chronicle-ai/chronicle/issues](https://github.com/chronicle-ai/chronicle/issues) +- **Discussions**: [chronicle-ai/chronicle/discussions](https://github.com/chronicle-ai/chronicle/discussions) +- **Documentation**: [Chronicle Docs](https://github.com/chronicle-ai/chronicle) + +Happy plugin development! 🚀 diff --git a/backends/advanced/event-detection/README.md b/backends/advanced/event-detection/README.md new file mode 100644 index 00000000..3d111987 --- /dev/null +++ b/backends/advanced/event-detection/README.md @@ -0,0 +1,476 @@ +# Event Detection - Whisper LoRA Adapter + +**🟢 STATUS: TRAINING/EXPORT WORKFLOW (User-Loop → Export → Training)** + +This folder contains **training/export utilities** for Whisper + LoRA event detection. It integrates with the **Chronicle user-loop** for continuous data collection and training. + +--- + +## 📋 Overview + +This system uses a **LoRA (Low-Rank Adaptation)** adapter on top of Whisper's Large V3 model to detect specific custom events (sounds, keywords, phrases) in audio. + +### Workflow: + +``` +Backend Anomaly Scan Job (sets maybe_anomaly: true) + │ + ▼ +User-Loop Popup (Review Anomalies) + │ + ├──► Swipe Right → Accept/Verify + │ │ + │ ▼ + │ MongoDB: maybe_anomaly = "verified" + │ + └──► Swipe Left → Reject/Stash + │ + ▼ + MongoDB: training_stash collection + │ + ▼ + Export: user_loop_feedback.jsonl + │ + ▼ + Train: LoRA adapter +``` + +--- + +## 📁 Files + +| File | Purpose | Status | +|-------|----------|--------| +| `export_from_mongo.py` | Export MongoDB `training_stash` to JSONL for training | ✅ Active (Bridge) | +| `train.py` | Fine-tune Whisper with LoRA adapter | ✅ Active | +| `requirements.txt` | Python dependencies | ✅ Active | + +Anomaly flagging (setting `maybe_anomaly: true` in MongoDB) is handled by the backend script `backends/advanced/src/advanced_omi_backend/scripts/run_anomaly_detection.py`. + +--- + +## 🚀 Production Workflow + +### Step 1: Data Collection (User-Loop) + +**Users interact with user-loop popup:** + +1. **Frontend shows popup** when conversations have `maybe_anomaly: true` +2. **User reviews transcript** and audio +3. **Swipe Left** → Reject (stashes for training) +4. **Swipe Right** → Accept (marks as verified, `maybe_anomaly: "verified"`) + +**MongoDB Collections:** + +```javascript +// conversations - User-Loop reviews these +{ + "conversation_id": "1a43e276-...", + "transcript_versions": [{ + "version_id": "c9c392d9-...", + "maybe_anomaly": true, // Triggers popup + "transcript": "The stale smell of old beer..." + }] +} + +// training_stash - User-Loop saves rejected items here +{ + "_id": ObjectId("..."), + "version_id": "c9c392d9-...", + "conversation_id": "1a43e276-...", + "transcript": "The stale smell of old beer...", + "reason": "False positive", + "timestamp": 1738254720.123, + "audio_chunks": [...], + "metadata": {"word_count": 43} +} +``` + +--- + +### Step 2: Export Training Data (Bridge) + +**Export MongoDB `training_stash` collection to JSONL format:** + +```bash +uv run python export_from_mongo.py \ + --output user_loop_feedback.jsonl \ + --min_samples 10 +``` + +**Output (`user_loop_feedback.jsonl`):** +```json +{"audio": "/data/audio/1a43e276-....wav", "text": "The stale smell of old beer...", "type": "positive", "timestamp": "2024-01-30T10:00:00Z"} +{"audio": "/data/audio/another-id.wav", "text": "Transcription with ", "type": "positive", "timestamp": "2024-01-30T10:05:00Z"} +``` + +**Arguments:** +- `--output`: Output JSONL file path (default: `user_loop_feedback.jsonl`) +- `--mongo_uri`: MongoDB connection (default: `mongodb://localhost:27017`) +- `--db_name`: Database name (default: `chronicle`) +- `--min_samples`: Minimum samples to export (default: 0) + +**Schema Mapping:** +```python +# MongoDB → Training JSONL +{ + "audio": f"/data/audio/{entry['conversation_id']}.wav", + "text": entry["transcript"], + "timestamp": entry.get("timestamp"), + "type": "positive" # All user-loop rejections = positive for training +} +``` + +--- + +### Step 3: Train LoRA Adapter + +**Fine-tune Whisper with exported user-loop data:** + +```bash +uv run python train.py \ + --train_manifest user_loop_feedback.jsonl \ + --output_dir ./sneeze_adapter \ + --base_model unsloth/whisper-large-v3 \ + --source_tag "" \ + --target_token "EVENT_DETECTED" +``` + +**Training Parameters:** +- `--train_manifest`: JSONL file from export (default: `train.jsonl`) +- `--output_dir`: Directory to save adapter (default: `event_lora_adapter_unsloth`) +- `--base_model`: Whisper model ID (default: `unsloth/whisper-large-v3`) +- `--source_tag`: Tag in text to replace (default: ``) +- `--target_token`: Token to emit for event (default: `EVENT_DETECTED`) + +**Output:** +```bash +./sneeze_adapter/ + ├── adapter_config.json + ├── adapter_model.safetensors + └── README.md +``` + +--- + +### Step 4: Flag New Anomalies (Backend Job) + +The backend provides a MongoDB scan job that sets `transcript_versions.$.maybe_anomaly = True` for transcripts that haven't been reviewed yet. + +From `backends/advanced/`: + +```bash +uv run python src/advanced_omi_backend/scripts/run_anomaly_detection.py +``` + +Notes: +- Configure MongoDB via `MONGODB_URI` (defaults to `mongodb://localhost:27017`). +- This script is currently a placeholder implementation (it marks unflagged transcripts as anomalies). + +--- + +## 🔄 Full Cycle Diagram + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ USER INTERACTION PHASE │ +│ Frontend shows popup when maybe_anomaly: true │ +└──────────────────┬──────────────────────────────────────────┘ + │ + ┌──────────┴──────────┐ + │ Swipe Actions │ + │ Left=Reject Right=Accept + ┌────┴────┐ ┌────┴────┐ + │ Swipe │ │ Swipe │ + │ Left │ │ Right │ + ▼ ▼ ▼ ▼ + Reject Reject Accept Accept + (stash) (stash) (verify) (verify) + │ │ │ │ + ▼ ▼ │ │ + MongoDB: MongoDB: │ │ + training_ training_ ▼ ▼ + stash stash maybe_ maybe_ + anomaly anomaly + ="verified" + │ │ + ▼ ▼ + More Data in + training_stash + │ + ▼ + ┌────────────────────┐ + │ EXPORT PHASE │ + │ export_from_mongo │ + │ .py │ + └────────┬─────────┘ + │ + ▼ + ┌────────────────────┐ + │ TRAINING PHASE │ + │ train.py │ + └────────┬─────────┘ + │ + ▼ + ┌────────────────────┐ + │ ./sneeze_ │ + │ adapter/ │ + └────────┬─────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ BACKEND ANOMALY SCAN (JOB) │ + │ run_anomaly_detection.py │ + │ sets maybe_anomaly: true │ + └─────────────┬────────────────┘ + │ + ▼ + ┌────────────────────┐ + │ User Popup │ + │ (Round 2) │ + └────────────────────┘ + │ + └──► Back to start! +``` + +--- + +## 📊 Training Data Format + +The training JSONL file (`user_loop_feedback.jsonl`) uses this schema: + +```json +{ + "audio": "/path/to/audio.wav", + "text": "Transcription with tag or just normal speech", + "timestamp": "2024-01-30T10:00:00Z", + "type": "positive" +} +``` + +**Fields:** +- `audio`: Path to audio file for training +- `text`: Ground truth transcription (from user-loop transcript) +- `timestamp`: When sample was added (from MongoDB) +- `type`: Always `"positive"` for user-loop data (rejections = positive training samples) + +--- + +## 🧠 Training Details + +### LoRA Configuration + +- **Base Model**: Whisper Large V3 (unsloth) +- **Adapter Type**: LoRA (Low-Rank Adaptation) +- **Parameters**: + - `r`: Rank (8-32 recommended) + - `lora_alpha`: Scaling factor (16-64) + - `target_modules`: `["q_proj", "v_proj"]` + - `dtype`: `float16` for CUDA, `float32` for CPU + +### Training Process + +1. Load base model (Whisper Large V3) +2. Load training data (audio + transcriptions) +3. Fine-tune adapter layers only +4. Validate on test set +5. Save adapter weights to output directory + +--- + +## 🎯 Production Deployment + +### Automated Workflow + +**Setup Cron Jobs for continuous improvement:** + +```bash +# crontab -e + +# Export training data daily at 2 AM +0 2 * * * cd /path/to/backends/advanced/event-detection && uv run python export_from_mongo.py --min_samples 50 + +# Retrain adapter weekly on Sunday at 3 AM +0 3 * * 0 cd /path/to/backends/advanced/event-detection && uv run python train.py --train_manifest user_loop_feedback.jsonl +``` + +### Adapter Versioning + +Store versioned adapters for A/B testing and rollback: + +```bash +./adapters/ + ├── sneeze_v1/ # Initial training + ├── sneeze_v2/ # After 100 samples + ├── sneeze_v3/ # After 500 samples + └── sneeze_latest/ # Symlink to current +``` + +### Monitoring + +Track metrics to improve detection: + +- **False Positive Rate**: Swipe left (reject) / Total popups +- **True Positive Rate**: Swipe right (accept) / Total swipes right +- **Detection Accuracy**: Correct detections / Total samples + +--- + +## 🐛 Troubleshooting + +### Issue: "No entries found in training_stash" + +**Symptoms:** +``` +❌ No entries found in training_stash collection +💡 Tip: Swipe left on user-loop popup to add samples (reject = stash) +``` + +**Solutions:** +1. Verify user-loop popup is working +2. Swipe left on some anomalies to add to training_stash +3. Check MongoDB connection +4. Lower `--min_samples` threshold + +--- + +### Issue: "Adapter output directory missing" + +**Symptoms:** +``` +Expected adapter directory not found: ./sneeze_adapter +``` + +**Solutions:** +1. Verify `--output_dir` matches where you expect the adapter to be saved +2. Check if train.py completed successfully +3. Ensure the output directory exists and is writable + +--- + +### Issue: "No audio files found" (in export) + +**Symptoms:** +``` +Exported 0 entries to user_loop_feedback.jsonl + Has audio chunks: 0/10 +``` + +**Solutions:** +1. Verify conversations have audio in MongoDB +2. Check audio_chunks collection +3. Ensure audio was uploaded correctly + +--- + +### Issue: CUDA out of memory + +**Symptoms:** +``` +RuntimeError: CUDA out of memory +``` + +**Solutions:** +1. Reduce `--batch_size` (try 2 or 1) +2. Use CPU with `torch_dtype=torch.float32` +3. Use smaller base model (Whisper Base instead of Large) + +--- + +### Issue: Poor detection accuracy + +**Symptoms:** +- High false positive rate +- Misses obvious events +- Random detections + +**Solutions:** +1. **More data**: Need at least 100-500 samples +2. **Better labels**: Review user-loop feedback for accuracy +3. **Retrain**: Train with more epochs +4. **Adjust trigger token**: Check if token appears in training data + +--- + +## 📚 Dependencies + +Install required packages: + +```bash +pip install -r requirements.txt +``` + +**Key Dependencies:** +- `unsloth`: Optimized Whisper model +- `transformers`: Hugging Face model library +- `peft`: LoRA adapters +- `torch`: Deep learning framework +- `librosa`: Audio processing +- `datasets`: Training data utilities + +**System Requirements:** +- Python 3.8+ +- CUDA-capable GPU (recommended) or 16GB+ RAM for CPU + +--- + +## 🔗 Integration with Backend + +### Current State + +**Frontend:** +```typescript +// UserLoopModal.tsx +const checkAnomaly = async () => { + // TODO: Replace with actual algorithm + const shouldShow = true // Always shows popup +} +``` + +**Backend:** +```python +# user_loop_routes.py +- ✅ GET /api/user-loop/events (returns anomalies) +- ✅ POST /api/user-loop/accept (verifies) +- ✅ POST /api/user-loop/reject (stashes to training) +- ✅ Anomaly scan job: src/advanced_omi_backend/scripts/run_anomaly_detection.py (sets maybe_anomaly: true) +``` + +### Future Integration + +To replace the placeholder scan with **model-based anomaly detection**: + +1. Train an adapter in this folder (`train.py`) and version it. +2. Load the adapter in the backend scan job and use inference to decide whether to set `maybe_anomaly: true`. +3. Ensure the UI only opens the user-loop popup when `/api/user-loop/events` returns events. + +--- + +## 📊 Workflow Summary + +| Phase | Component | Command | +|--------|-----------|----------| +| **1. Collection** | User-Loop Popup | User swipes left to reject (stash) / right to accept | +| **2. Storage** | MongoDB | Saves to `training_stash` collection | +| **3. Export** | export_from_mongo.py | `uv run python export_from_mongo.py --min_samples 10` | +| **4. Training** | train.py | `uv run python train.py --train_manifest user_loop_feedback.jsonl --output_dir ./sneeze_adapter` | +| **5. Flagging** | Backend job | `cd .. && uv run python src/advanced_omi_backend/scripts/run_anomaly_detection.py` | +| **6. Deployment** | Backend | (Future) Use trained adapter inside the scan job | + +--- + +## 🤝 Contributing + +To improve event detection: + +1. **Collect More Data**: Swipe left on user-loop popup to reject (stash) samples +2. **Review Labels**: Check training data quality +3. **Retrain Often**: Update adapter weekly with new data +4. **A/B Test**: Compare new vs old adapters +5. **Monitor Metrics**: Track false positive/negative rates + +--- + +**Last Updated**: January 30, 2026 +**Status**: 🟢 Training/Export Workflow (User-Loop → Export → Training) ✅ +**Backend Integration**: 🟡 Partial (flagging job exists; model-based detection pending) diff --git a/backends/advanced/event-detection/export_from_mongo.py b/backends/advanced/event-detection/export_from_mongo.py new file mode 100755 index 00000000..4f057c84 --- /dev/null +++ b/backends/advanced/event-detection/export_from_mongo.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Export MongoDB Training Stash to JSONL for LoRA Training + +Bridge script: User-Loop (MongoDB) → Training (JSONL) + +Usage: + python export_from_mongo.py --output user_loop_feedback.jsonl + +Workflow: + 1. Users swipe on user-loop popup + 2. Items saved to MongoDB training_stash collection + 3. Export with this script to JSONL format + 4. Train LoRA adapter: python train.py --train_manifest user_loop_feedback.jsonl +""" + +import argparse +import json +from datetime import datetime +from pymongo import MongoClient + +def parse_args(): + parser = argparse.ArgumentParser(description="Export training stash from MongoDB to JSONL") + parser.add_argument("--output", type=str, default="user_loop_feedback.jsonl", help="Output JSONL file") + parser.add_argument("--mongo_uri", type=str, default="mongodb://localhost:27017", help="MongoDB connection string") + parser.add_argument("--db_name", type=str, default="chronicle", help="Database name") + parser.add_argument("--min_samples", type=int, default=0, help="Minimum samples to export") + return parser.parse_args() + +def export_training_stash(mongo_uri, db_name, output_file, min_samples): + """ + Export MongoDB training_stash collection to JSONL format + + Schema: + { + "audio": "/path/to/audio.wav", + "text": "Transcription with tag", + "timestamp": "2024-01-30T10:00:00Z", + "type": "positive" + } + """ + print(f"🔗 Connecting to MongoDB: {mongo_uri}") + print(f"📁 Database: {db_name}") + + client = MongoClient(mongo_uri) + db = client[db_name] + + # Fetch all training stash entries + entries = list(db.training_stash.find({})) + + if not entries: + print("❌ No entries found in training_stash collection") + print("💡 Tip: Swipe right on user-loop popup to add samples") + return False + + if len(entries) < min_samples: + print(f"⚠️ Found {len(entries)} entries (minimum: {min_samples})") + return False + + print(f"✅ Found {len(entries)} entries in training_stash") + + # Convert to JSONL format + exported = 0 + with open(output_file, "w") as f: + for entry in entries: + # Map MongoDB schema to training schema + training_sample = { + "audio": f"/data/audio/{entry['conversation_id']}.wav", + "text": entry["transcript"], + "timestamp": entry.get("timestamp", datetime.now().isoformat()), + "type": "positive" # User-loop rejections = positive for training + } + + # Write as JSONL (one JSON per line) + f.write(json.dumps(training_sample) + "\n") + exported += 1 + + print(f"💾 Exported {exported} entries to {output_file}") + + # Print statistics + print(f"\n📊 Statistics:") + print(f" Total exported: {exported}") + + # Count unique conversations + unique_convs = set(e["conversation_id"] for e in entries) + print(f" Unique conversations: {len(unique_convs)}") + + # Check audio data + has_audio = sum(1 for e in entries if e.get("audio_chunks") and len(e["audio_chunks"]) > 0) + print(f" Has audio chunks: {has_audio}/{exported}") + + return True + +def main(): + args = parse_args() + + print("🚀 MongoDB Training Stash Export") + print("=" * 50) + + success = export_training_stash( + args.mongo_uri, + args.db_name, + args.output, + args.min_samples + ) + + print("\n" + "=" * 50) + + if success: + print("✅ Export complete!") + print("\n🎯 Next Steps:") + print(" 1. Review exported file:", args.output) + print(" 2. Train LoRA adapter:") + print(f" python train.py --train_manifest {args.output}") + print(" 3. Run anomaly scan (MongoDB flagging):") + print(" cd .. && uv run python src/advanced_omi_backend/scripts/run_anomaly_detection.py") + else: + print("❌ Export failed") + print("\n💡 Suggestions:") + print(" - Swipe left on user-loop popup to add samples") + print(" - Check MongoDB connection") + print(" - Lower --min_samples threshold") + +if __name__ == "__main__": + main() diff --git a/backends/advanced/event-detection/manage_data.py b/backends/advanced/event-detection/manage_data.py new file mode 100644 index 00000000..3effb743 --- /dev/null +++ b/backends/advanced/event-detection/manage_data.py @@ -0,0 +1,127 @@ +import os +import json +import argparse +import shutil +from datetime import datetime +import torch +from transformers import WhisperProcessor, WhisperForConditionalGeneration +import librosa + +def parse_args(): + parser = argparse.ArgumentParser(description="Manage Event Detection Dataset") + subparsers = parser.add_subparsers(dest="command", required=True) + + # Bootstrap command + cmd_bootstrap = subparsers.add_parser("bootstrap", help="Create initial manifest from folder") + cmd_bootstrap.add_argument("--audio_dir", required=True, help="Directory containing positive audio samples") + cmd_bootstrap.add_argument("--output_manifest", default="train.jsonl", help="Output jsonl file") + cmd_bootstrap.add_argument("--source_tag", default="", help="Tag to use for event") + + # Feedback command + cmd_feedback = subparsers.add_parser("feedback", help="Add feedback (positive/negative)") + cmd_feedback.add_argument("--audio_path", required=True, help="Path to audio file") + cmd_feedback.add_argument("--is_positive", action="store_true", help="Flag if sample is positive instance of event") + cmd_feedback.add_argument("--manifest", default="train.jsonl", help="Manifest file to update") + cmd_feedback.add_argument("--source_tag", default="", help="Tag to use if positive") + cmd_feedback.add_argument("--base_model", default="unsloth/whisper-large-v3", help="Base model for transcription") + + return parser.parse_args() + +def transcribe_audio(audio_path, model_id): + """Transcribe audio using base model to get ground truth text""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Transcribing with {model_id} on {device}...") + + processor = WhisperProcessor.from_pretrained(model_id) + model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32) + model.to(device) + model.eval() + + audio, _ = librosa.load(audio_path, sr=16000) + inputs = processor(audio, sampling_rate=16000, return_tensors="pt") + input_features = inputs.input_features.to(device) + + if device == "cuda": + input_features = input_features.half() + + with torch.no_grad(): + generated_ids = model.generate(input_features) + + text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + return text + +def bootstrap(audio_dir, output_manifest, source_tag): + if not os.path.exists(audio_dir): + raise FileNotFoundError(f"{audio_dir} does not exist") + + entries = [] + files = [f for f in os.listdir(audio_dir) if f.lower().endswith(('.wav', '.mp3', '.m4a'))] + + if not files: + print("No audio files found in directory.") + return + + print(f"Found {len(files)} files. Adding to {output_manifest}...") + + # Check if we should append or overwrite? Spec implies bootstrapping starts the process. + # I'll append if exists, or create new. + mode = 'a' if os.path.exists(output_manifest) else 'w' + + with open(output_manifest, mode) as f: + for filename in files: + file_path = os.path.join(audio_dir, filename) + # For bootstrapping, we assume these are just the event itself, + # so text is just the tag. + # Realistically, we might want to transcribe it, but let's keep it simple for now + # as these might be non-speech sounds (sneezes). + entry = { + "audio": file_path, + "text": source_tag + } + f.write(json.dumps(entry) + "\n") + print(f"Added {filename}") + +def add_feedback(audio_path, is_positive, manifest_path, source_tag, base_model): + if not os.path.exists(audio_path): + raise FileNotFoundError(f"{audio_path} not found") + + # Get ground truth text + # If positive, we want "Transcription " or just "" if it's non-speech? + # If negative, we want "Transcription" (without tag). + + base_text = transcribe_audio(audio_path, base_model) + print(f"Base transcription: {base_text}") + + if is_positive: + # If it's a positive sample, we ensure the tag is present. + # If the model didn't transcribe it (likely), we append it. + # If it's a pronunciation correction, the user might want to replace a specific word... + # But for "arbitrary event detection" usually implies adding a marker. + # Simple heuristic: Append tag to text. + final_text = f"{base_text} {source_tag}".strip() + else: + # Negative sample: The text is just what the base model hears (without the tag) + final_text = base_text.replace(source_tag, "") # Ensure tag isn't there by accident + + entry = { + "audio": audio_path, + "text": final_text, + "timestamp": datetime.now().isoformat(), + "type": "positive" if is_positive else "negative" + } + + with open(manifest_path, 'a') as f: + f.write(json.dumps(entry) + "\n") + + print(f"Added {'positive' if is_positive else 'negative'} feedback for {audio_path}") + +def main(): + args = parse_args() + + if args.command == "bootstrap": + bootstrap(args.audio_dir, args.output_manifest, args.source_tag) + elif args.command == "feedback": + add_feedback(args.audio_path, args.is_positive, args.manifest, args.source_tag, args.base_model) + +if __name__ == "__main__": + main() diff --git a/backends/advanced/event-detection/requirements.txt b/backends/advanced/event-detection/requirements.txt new file mode 100644 index 00000000..748abf19 --- /dev/null +++ b/backends/advanced/event-detection/requirements.txt @@ -0,0 +1,13 @@ +unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git +transformers +torch +librosa +numpy +datasets +evaluate +jiwer +tqdm +soundfile +accelerate +bitsandbytes +peft diff --git a/backends/advanced/event-detection/train.py b/backends/advanced/event-detection/train.py new file mode 100644 index 00000000..413e3a23 --- /dev/null +++ b/backends/advanced/event-detection/train.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# Event Detection Training - Fine-tune Whisper with LoRA +# ⚠️ NOT INTEGRATED WITH MAIN BACKEND - Use directly from CLI + + +import os +import json +import argparse +import librosa +import numpy as np +import torch +import tqdm +import evaluate +from dataclasses import dataclass +from typing import Any, Dict, List, Union + +from unsloth import FastModel, is_bf16_supported +from transformers import WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer +from datasets import Dataset + +# Global Constants +DEFAULT_MODEL_ID = "unsloth/whisper-large-v3" +DEFAULT_OUTPUT_DIR = "event_lora_adapter_unsloth" +DEFAULT_TRAIN_MANIFEST = "train.jsonl" +DEFAULT_TARGET_TOKEN = "EVENT_DETECTED" +DEFAULT_SOURCE_TAG = "" + +def parse_args(): + parser = argparse.ArgumentParser(description="Train Whisper LoRA for Event Detection") + parser.add_argument("--train_manifest", type=str, default=DEFAULT_TRAIN_MANIFEST, help="Path to training data jsonl") + parser.add_argument("--output_dir", type=str, default=DEFAULT_OUTPUT_DIR, help="Output directory for adapter") + parser.add_argument("--base_model", type=str, default=DEFAULT_MODEL_ID, help="Base Whisper model ID") + parser.add_argument("--target_token", type=str, default=DEFAULT_TARGET_TOKEN, help="Token to emit for event") + parser.add_argument("--source_tag", type=str, default=DEFAULT_SOURCE_TAG, help="Tag in text to replace") + return parser.parse_args() + +def prepare_dataset(manifest_path: str, source_tag: str, target_token: str) -> Dataset: + """Load and normalize training data from jsonl""" + audio_paths = [] + texts = [] + + if not os.path.exists(manifest_path): + raise FileNotFoundError(f"{manifest_path} not found.") + + with open(manifest_path, 'r') as f: + for line in f: + entry = json.loads(line) + # Normalize text + text = entry['text'].replace(source_tag, target_token) + + audio_paths.append(entry['audio']) + texts.append(text) + + # Oversampling for target keyword (x5 as per spec) + if target_token in text: + for _ in range(5): + audio_paths.append(entry['audio']) + texts.append(text) + + print(f"Loaded {len(audio_paths)} training samples.") + + data = { + "audio_path": audio_paths, + "text": texts + } + + return Dataset.from_dict(data) + +@dataclass +class DataCollatorSpeechSeq2SeqWithPadding: + processor: Any + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + input_features = [{"input_features": feature["input_features"]} for feature in features] + batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") + + label_features = [{"input_ids": feature["labels"]} for feature in features] + labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") + + labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + + return batch + +def main(): + args = parse_args() + + print(f"Loading model with Unsloth: {args.base_model}") + + # Load model using Unsloth's FastModel + model, tokenizer = FastModel.from_pretrained( + model_name=args.base_model, + dtype=None, # Auto detection + load_in_4bit=False, + auto_model=WhisperForConditionalGeneration, + whisper_language="English", + whisper_task="transcribe", + ) + + # Apply LoRA adapters + model = FastModel.get_peft_model( + model, + r=64, + target_modules=["q_proj", "v_proj"], + lora_alpha=64, + lora_dropout=0, + bias="none", + use_gradient_checkpointing="unsloth", + random_state=3407, + use_rslora=False, + loftq_config=None, + task_type=None, + ) + + # Configure generation settings + model.generation_config.language = "<|en|>" + model.generation_config.task = "transcribe" + model.config.suppress_tokens = [] + model.generation_config.forced_decoder_ids = None + + # Load dataset + dataset = prepare_dataset(args.train_manifest, args.source_tag, args.target_token) + + def formatting_prompts_func(example): + """Process audio and text for training""" + # Load audio file + audio_array, sr = librosa.load(example['audio_path'], sr=16000) + + # Extract features + features = tokenizer.feature_extractor( + audio_array, sampling_rate=16000 + ) + + # Tokenize text + tokenized_text = tokenizer.tokenizer(example["text"]) + + return { + "input_features": features.input_features[0], + "labels": tokenized_text.input_ids, + } + + print("Processing dataset...") + train_data = [] + for example in tqdm.tqdm(dataset, desc='Processing audio'): + # Errors will bubble up as per spec + result = formatting_prompts_func(example) + train_data.append(result) + + print(f"Successfully processed {len(train_data)} samples") + + # Split into train/test + split_idx = max(1, int(len(train_data) * 0.94)) + train_dataset = train_data[:split_idx] + test_dataset = train_data[split_idx:] + + print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}") + + metric = evaluate.load("wer") + + def compute_metrics(pred): + pred_logits = pred.predictions[0] + label_ids = pred.label_ids + + label_ids[label_ids == -100] = tokenizer.pad_token_id + + pred_ids = np.argmax(pred_logits, axis=-1) + + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True) + + wer = 100 * metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=tokenizer) + + # Setup trainer + trainer = Seq2SeqTrainer( + model=model, + train_dataset=train_dataset, + data_collator=data_collator, + eval_dataset=test_dataset if len(test_dataset) > 0 else None, + tokenizer=tokenizer.feature_extractor, + compute_metrics=compute_metrics, + args=Seq2SeqTrainingArguments( + per_device_train_batch_size=1, + gradient_accumulation_steps=4, + warmup_steps=5, + max_steps=200, # Can be exposed as arg if needed, keeping simple for now + learning_rate=1e-4, + logging_steps=10, + optim="adamw_8bit", + fp16=not is_bf16_supported(), + bf16=is_bf16_supported(), + weight_decay=0.001, + remove_unused_columns=False, + lr_scheduler_type="linear", + label_names=['labels'], + eval_steps=20, + eval_strategy="steps" if len(test_dataset) > 0 else "no", + seed=3407, + output_dir=args.output_dir, + report_to="none", + ), + ) + + print("Starting training...") + trainer.train() + + # Save the model + print(f"Saving adapter to {args.output_dir}") + model.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + print("Training complete!") + +if __name__ == "__main__": + main() diff --git a/backends/advanced/src/advanced_omi_backend/client_manager.py b/backends/advanced/src/advanced_omi_backend/client_manager.py index 68fd6ef8..0d7da277 100644 --- a/backends/advanced/src/advanced_omi_backend/client_manager.py +++ b/backends/advanced/src/advanced_omi_backend/client_manager.py @@ -9,6 +9,7 @@ import logging import uuid from typing import TYPE_CHECKING, Dict, Optional +import redis.asyncio as redis import redis.asyncio as redis diff --git a/backends/advanced/src/advanced_omi_backend/models/conversation.py b/backends/advanced/src/advanced_omi_backend/models/conversation.py index 2ec45f33..65cdfae5 100644 --- a/backends/advanced/src/advanced_omi_backend/models/conversation.py +++ b/backends/advanced/src/advanced_omi_backend/models/conversation.py @@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Union from beanie import Document, Indexed -from pydantic import BaseModel, Field, computed_field, field_validator, model_validator +from pydantic import BaseModel, Field, computed_field, model_validator from pymongo import IndexModel @@ -81,6 +81,13 @@ class TranscriptVersion(BaseModel): description="Source of speaker diarization: 'provider' (transcription service), 'pyannote' (speaker recognition), or None" ) metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional provider-specific metadata") + maybe_anomaly: Optional[Union[bool, str]] = Field( + None, + description=( + "Anomaly detection status: True (anomaly detected), False (no anomaly), " + "'verified' (user verified no anomaly), 'rejected' (user rejected/stashed for training)" + ) + ) class MemoryVersion(BaseModel): """Version of memory extraction with processing metadata.""" @@ -430,4 +437,4 @@ def create_conversation( if conversation_id is not None: conv_data["conversation_id"] = conversation_id - return Conversation(**conv_data) \ No newline at end of file + return Conversation(**conv_data) diff --git a/backends/advanced/src/advanced_omi_backend/routers/api_router.py b/backends/advanced/src/advanced_omi_backend/routers/api_router.py index e4c89531..e7af11a2 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/api_router.py +++ b/backends/advanced/src/advanced_omi_backend/routers/api_router.py @@ -23,6 +23,7 @@ obsidian_router, queue_router, system_router, + user_loop_router, user_router, ) from .modules.health_routes import router as health_router @@ -43,6 +44,7 @@ router.include_router(conversation_router) router.include_router(finetuning_router) router.include_router(knowledge_graph_router) +router.include_router(user_loop_router) router.include_router(memory_router) router.include_router(obsidian_router) router.include_router(system_router) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py b/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py index 501377fc..4a61a759 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py @@ -15,6 +15,7 @@ - health_routes: Health check endpoints - websocket_routes: WebSocket connection handling - admin_routes: Admin-only system management endpoints +- user_loop_routes: Anomaly review and transcript verification - knowledge_graph_routes: Knowledge graph entities, relationships, and promises """ @@ -32,6 +33,7 @@ from .queue_routes import router as queue_router from .system_routes import router as system_router from .user_routes import router as user_router +from .user_loop_routes import router as user_loop_router from .websocket_routes import router as websocket_router __all__ = [ @@ -49,5 +51,6 @@ "queue_router", "system_router", "user_router", + "user_loop_router", "websocket_router", ] diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/user_loop_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/user_loop_routes.py new file mode 100644 index 00000000..cc341d36 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/user_loop_routes.py @@ -0,0 +1,449 @@ +""" +User-loop routes for Chronicle backend. + +Provides anomaly review interface: +- GET /events: Returns conversations with maybe_anomaly: true +- POST /accept: Verifies transcript (sets maybe_anomaly to "verified") +- POST /reject: Stashes transcript for training (saves to training-stash) +- GET /audio/{version_id}: Returns audio file (converted to WAV if needed) + +Issues covered: +- Issue #1: Audio not playing (Opus→WAV conversion) +- Issue #2: /audio/undefined (404 error) +- Issue #3: FFmpeg not installed (fallback to Opus) +- Issue #5: Swipe right not working (accept updates MongoDB) +- Issue #6: Field name mismatch (uses transcript_version_id) +- Issue #7: Loading spinner stuck (empty events array) +- Issue #8: Wrong audio Content-Type (returns audio/wav) +""" + +import logging +import os +import io +import subprocess +import tempfile +import base64 +from datetime import datetime +from typing import Optional, List + +from bson import ObjectId +from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import FileResponse, Response +from pydantic import BaseModel + +# MongoDB client (shared with main app) +from advanced_omi_backend.database import get_database + +# Logging setup +logger = logging.getLogger(__name__) + +# Create router +router = APIRouter(tags=["user-loop"], prefix="/user-loop") + + +class SwipeAction(BaseModel): + """ + Request body for swipe actions (accept/reject). + + Note: Uses transcript_version_id (backend field) not version_id (frontend field). + This fixes Issue #6: Field name mismatch. + """ + transcript_version_id: str + conversation_id: str + reason: Optional[str] = None + timestamp: Optional[float] = None + + +class AnomalyEvent(BaseModel): + """Anomaly event returned to UI for review.""" + version_id: str + conversation_id: str + transcript: str + timestamp: float + audio_duration: float + speaker_count: int + word_count: int + audio_data: Optional[str] = None # Base64 encoded audio for preview + + +@router.get("/events") +async def get_events(db=Depends(get_database)) -> List[AnomalyEvent]: + """ + Returns list of anomaly events to review. + + Queries MongoDB for conversations with transcript versions where + maybe_anomaly is true (boolean). Verified transcripts + (maybe_anomaly: "verified") are filtered out. + + Fixes Issue #7: Loading spinner stuck. + Returns empty list when no anomalies exist. + """ + try: + events = [] + + # Query for conversations where ANY transcript version has maybe_anomaly: true + # Use $elemMatch to match array elements + pipeline = [ + { + "$match": { + "deleted": False, + "transcript_versions": { + "$elemMatch": { + "maybe_anomaly": True + } + } + } + }, + { + "$project": { + "conversation_id": 1, + "transcript_versions": 1, + "audio_chunks_count": 1, + "audio_total_duration": 1, + "created_at": 1 + } + } + ] + + cursor = db.conversations.aggregate(pipeline) + docs = await cursor.to_list(length=100) + + for doc in docs: + # Find transcript version with maybe_anomaly: true + for version in doc.get("transcript_versions", []): + if version.get("maybe_anomaly") is True: + # Handle both int and datetime for created_at + created_at_value = version.get("created_at", datetime.now()) + if isinstance(created_at_value, (int, float)): + timestamp = created_at_value + else: + timestamp = created_at_value.timestamp() + + event = AnomalyEvent( + version_id=version.get("version_id"), + conversation_id=doc.get("conversation_id"), + transcript=version.get("transcript", ""), + timestamp=timestamp, + audio_duration=doc.get("audio_total_duration", 0), + speaker_count=len([s for s in version.get("segments", []) if s.get("speaker")]), + word_count=version.get("metadata", {}).get("word_count", 0) + ) + events.append(event) + logger.info(f"Found {len(events)} anomaly events") + return events + + except Exception as e: + logger.error(f"Error fetching events: {e}") + raise HTTPException(status_code=500, detail=f"Error fetching events: {str(e)}") + + +@router.post("/accept") +async def accept_transcript(action: SwipeAction, db=Depends(get_database)): + """ + Accept transcript: Sets maybe_anomaly to "verified" string. + + This is a "left swipe" on the anomaly review interface. + After verification, the transcript won't appear in /events. + + Fixes Issue #5: Swipe right not working. + Updates MongoDB and sets verified_at timestamp. + + Args: + action: Swipe action with transcript_version_id and conversation_id + + Returns: + JSON response with status "success" + """ + try: + # Get conversation + conversation = await db.conversations.find_one({ + "conversation_id": action.conversation_id + }) + + if not conversation: + raise HTTPException(status_code=404, detail="Conversation Not Found") + + # Find target transcript version + target_version = None + for version in conversation.get("transcript_versions", []): + if version.get("version_id") == action.transcript_version_id: + target_version = version + break + + if not target_version: + raise HTTPException(status_code=404, detail="Transcript version Not Found") + + # Update transcript version: maybe_anomaly → "verified" + update_result = await db.conversations.update_one( + { + "conversation_id": action.conversation_id, + "transcript_versions.version_id": action.transcript_version_id + }, + { + "$set": { + "transcript_versions.$.maybe_anomaly": "verified", # String, not boolean + "transcript_versions.$.verified_at": datetime.now().isoformat() + } + } + ) + + if update_result.matched_count == 0: + raise HTTPException(status_code=404, detail="Conversation or version not found") + + logger.info(f"Verified transcript {action.transcript_version_id}") + + return { + "status": "success", + "message": "Verified transcript", + "version_id": action.transcript_version_id + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error accepting transcript: {e}") + raise HTTPException(status_code=500, detail=f"Error accepting transcript: {str(e)}") + + +@router.post("/reject") +async def reject_transcript(action: SwipeAction, db=Depends(get_database)): + """ + Reject transcript: Saves transcript and audio to training-stash. + + This is a "right swipe" on the anomaly review interface. + Stashes the transcript for training/fine-tuning models. + + Fixes Issue #5: Swipe right not working. + Saves to training-stash collection with audio data. + + Args: + action: Swipe action with transcript_version_id, conversation_id, and reason + + Returns: + JSON response with status "success" and stash_id + """ + try: + timestamp = action.timestamp or datetime.now().timestamp() + + # Get conversation details + conversation = await db.conversations.find_one({ + "conversation_id": action.conversation_id + }) + + if not conversation: + raise HTTPException(status_code=404, detail="Conversation Not Found") + + # Get specific transcript version + target_version = None + for version in conversation.get("transcript_versions", []): + if version.get("version_id") == action.transcript_version_id: + target_version = version + break + + if not target_version: + raise HTTPException(status_code=404, detail="Transcript version Not Found") + + # Get audio chunks for this conversation + audio_chunks_cursor = db.audio_chunks.find({ + "conversation_id": action.conversation_id + }).sort("chunk_index", 1) + + audio_chunks_data = [] + chunks = await audio_chunks_cursor.to_list(length=100) + for chunk in chunks: + # Get audio data - might be bytes or string + audio_data = chunk.get("audio_data") + # Convert to bytes if string + if isinstance(audio_data, str): + audio_data = audio_data.encode('utf-8') + + # Convert to bytes if string + if isinstance(audio_data, str): + audio_data = audio_data.encode('utf-8') + + # Convert binary audio to base64 for storage in training-stash + audio_b64 = base64.b64encode(audio_data).decode("utf-8") + audio_chunks_data.append({ + "chunk_index": chunk.get("chunk_index"), + "audio_data": audio_b64, + "duration": chunk.get("duration"), + "sample_rate": chunk.get("sample_rate"), + "channels": chunk.get("channels") + }) + + # Create training-stash entry + stash_entry = { + "transcript_version_id": action.transcript_version_id, # For test compatibility + "conversation_id": action.conversation_id, + "user_id": conversation.get("user_id"), + "client_id": conversation.get("client_id"), + "transcript": target_version.get("transcript"), + "segments": target_version.get("segments"), + "reason": action.reason, + "timestamp": timestamp, + "audio_chunks": audio_chunks_data, + "metadata": target_version.get("metadata", {}), + "created_at": datetime.now().isoformat() + } + + # Insert into training-stash + result = await db.training_stash.insert_one(stash_entry) + stash_id = str(result.inserted_id) + + # Mark transcript version as handled so it doesn't reappear in /events. + # /events only returns maybe_anomaly == True. + update_result = await db.conversations.update_one( + { + "conversation_id": action.conversation_id, + "transcript_versions.version_id": action.transcript_version_id, + }, + { + "$set": { + "transcript_versions.$.maybe_anomaly": "rejected", + "transcript_versions.$.rejected_at": datetime.now().isoformat(), + "transcript_versions.$.rejected_reason": action.reason, + } + }, + ) + + if update_result.matched_count == 0: + raise HTTPException(status_code=404, detail="Conversation or version not found") + + logger.info(f"Stashed transcript {action.transcript_version_id} with reason: {action.reason}") + + return { + "status": "success", + "message": "Stashed transcript for training", + "stash_id": stash_id, + "version_id": action.transcript_version_id + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error rejecting transcript: {e}") + raise HTTPException(status_code=500, detail=f"Error rejecting transcript: {str(e)}") + + +@router.get("/audio/{version_id}") +async def get_transcript_audio(version_id: str, db=Depends(get_database)): + """ + Returns audio file for a transcript version. + + Converts Opus audio to WAV format if FFmpeg is available. + Falls back to serving original Opus if conversion fails. + + Fixes: + - Issue #1: Audio not playing (Opus→WAV conversion) + - Issue #2: /audio/undefined returns 404 + - Issue #3: FFmpeg not installed (fallback to Opus) + - Issue #8: Wrong audio Content-Type (returns audio/wav) + + Args: + version_id: Transcript version ID to get audio for + + Returns: + FileResponse with audio/wav or audio/ogg content-type + """ + try: + # Find conversation with this version + pipeline = [ + {"$unwind": "$transcript_versions"}, + {"$match": {"transcript_versions.version_id": version_id}}, + {"$project": {"conversation_id": 1}} + ] + + cursor = db.conversations.aggregate(pipeline) + doc = await cursor.to_list(length=1) + + if not doc: + raise HTTPException(status_code=404, detail=f"Transcript version {version_id} not found") + + conversation_id = doc[0].get("conversation_id") + + # Get audio chunks + audio_chunks_cursor = db.audio_chunks.find({ + "conversation_id": conversation_id + }).sort("chunk_index", 1) + + chunks = await audio_chunks_cursor.to_list(length=100) + + if not chunks: + raise HTTPException(status_code=404, detail="No audio found for this transcript") + + # Combine audio chunks (assuming they're in the right format) + combined_audio = b"" + for chunk in chunks: + audio_data = chunk.get("audio_data") + if audio_data: + # Convert to bytes if string + if isinstance(audio_data, str): + audio_data = audio_data.encode('utf-8') + combined_audio += audio_data + + # Try to convert Opus to WAV using FFmpeg + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=".opus") as opus_file: + opus_file.write(combined_audio) + opus_path = opus_file.name + + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_file: + wav_path = wav_file.name + + # Convert using FFmpeg: 16kHz, mono, PCM16 + ffmpeg_cmd = [ + "ffmpeg", + "-y", # Overwrite output file + "-i", opus_path, + "-acodec", "pcm_s16le", # PCM 16-bit little-endian + "-ar", "16000", # 16kHz sample rate + "-ac", "1", # Mono + wav_path + ] + + result = subprocess.run( + ffmpeg_cmd, + capture_output=True, + text=True, + timeout=10 + ) + + # Read WAV file + with open(wav_path, "rb") as f: + wav_audio = f.read() + + # Cleanup temp files + os.unlink(opus_path) + os.unlink(wav_path) + + logger.info(f"Converted audio to WAV: {len(wav_audio)} bytes") + + # Return WAV file + return Response( + content=wav_audio, + media_type="audio/wav", + headers={ + "Content-Disposition": f"attachment; filename=audio_{version_id}.wav", + "Content-Length": str(len(wav_audio)) + } + ) + + except (subprocess.TimeoutExpired, FileNotFoundError, Exception) as e: + # FFmpeg not available or failed - fallback to Opus + logger.warning(f"FFmpeg conversion failed: {e}, serving original Opus") + + # Return original Opus audio + return Response( + content=combined_audio, + media_type="audio/ogg", + headers={ + "Content-Disposition": f"attachment; filename=audio_{version_id}.opus", + "Content-Length": str(len(combined_audio)) + } + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting audio: {e}") + raise HTTPException(status_code=500, detail=f"Error getting audio: {str(e)}") diff --git a/backends/advanced/src/advanced_omi_backend/scripts/run_anomaly_detection.py b/backends/advanced/src/advanced_omi_backend/scripts/run_anomaly_detection.py new file mode 100644 index 00000000..244b0402 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/scripts/run_anomaly_detection.py @@ -0,0 +1,76 @@ +""" +Placeholder anomaly detection script. + +This will be a cron job which uses some kind of algorithm to detect if transcription +is bad. It scans MongoDB for transcripts that don't have the maybe_anomaly flag set, +runs the detection in run_anomaly_detection(), and sets maybe_anomaly to True for +transcripts that look anomalous (placeholder - currently always marks as true). +""" + +import asyncio +import os +from datetime import datetime +from motor.motor_asyncio import AsyncIOMotorClient + +# MongoDB Configuration +MONGODB_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017") +DB_NAME = "chronicle" + +async def run_anomaly_detection(): + """Run anomaly detection on all transcripts without maybe_anomaly flag.""" + + # Connect to MongoDB + client = AsyncIOMotorClient(MONGODB_URI) + db = client[DB_NAME] + + print("🔍 Starting anomaly detection scan...") + + try: + # Find conversations with transcript versions where maybe_anomaly is not set + cursor = db.conversations.find({ + "deleted": False, + "transcript_versions.maybe_anomaly": None # Find transcripts without flag + }) + + count = 0 + conversations = await cursor.to_list(length=100) + + for conversation in conversations: + transcript_versions = conversation.get("transcript_versions", []) + + for version in transcript_versions: + if version.get("maybe_anomaly") is None: + # Mark as potential anomaly (placeholder - always returns true) + version_id = version.get("version_id") + + result = await db.conversations.update_one( + { + "conversation_id": conversation["conversation_id"], + "transcript_versions.version_id": version_id + }, + { + "$set": { + "transcript_versions.$.maybe_anomaly": True + } + } + ) + + if result.matched_count > 0: + count += 1 + print(f"✅ Marked version {version_id} as potential anomaly") + + print(f"🎉 Scan complete! Marked {count} transcripts as potential anomalies") + + # Also count total anomalies + anomaly_count = await db.conversations.count_documents({ + "deleted": False, + "transcript_versions.maybe_anomaly": True + }) + print(f"📊 Total potential anomalies in system: {anomaly_count}") + + finally: + client.close() + + +if __name__ == "__main__": + asyncio.run(run_anomaly_detection()) diff --git a/backends/advanced/webui/.env.example b/backends/advanced/webui/.env.example index f872dc47..67edf16d 100644 --- a/backends/advanced/webui/.env.example +++ b/backends/advanced/webui/.env.example @@ -1,2 +1,5 @@ # Backend API URL -VITE_BACKEND_URL=http://localhost:8000 \ No newline at end of file +VITE_BACKEND_URL=http://localhost:8000 + +# Enable/disable user-loop swipe modal +VITE_USER_LOOP_MODAL_ENABLED=false diff --git a/backends/advanced/webui/jest.config.js b/backends/advanced/webui/jest.config.js new file mode 100644 index 00000000..a14bbc24 --- /dev/null +++ b/backends/advanced/webui/jest.config.js @@ -0,0 +1,30 @@ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'jsdom', + roots: ['/src'], + testMatch: [ + '**/__tests__/**/*.test.tsx', + '**/?(*.)+(spec|test).tsx' + ], + transform: { + '^.+\\.(ts|tsx)$': 'ts-jest', + }, + moduleNameMapper: { + '\\.(css|less|scss|sass)$': 'identity-obj-proxy', + 'lucide-react': '/node_modules/lucide-react/dist/esm/index.js', + }, + setupFilesAfterEnv: ['/jest.setup.js'], + collectCoverageFrom: [ + 'src/**/*.{ts,tsx}', + '!src/**/*.d.ts', + '!src/**/*.stories.tsx', + ], + coverageThreshold: { + global: { + branches: 70, + functions: 70, + lines: 70, + statements: 70, + }, + }, +} diff --git a/backends/advanced/webui/jest.setup.js b/backends/advanced/webui/jest.setup.js new file mode 100644 index 00000000..7e65da4c --- /dev/null +++ b/backends/advanced/webui/jest.setup.js @@ -0,0 +1,38 @@ +// Jest setup file for user-loop modal tests + +// Mock fetch globally +global.fetch = jest.fn(); + +// Mock window.matchMedia +Object.defineProperty(window, 'matchMedia', { + writable: true, + value: jest.fn().mockImplementation(query => ({ + matches: false, + media: query, + onchange: null, + addListener: jest.fn(), + removeListener: jest.fn(), + addEventListener: jest.fn(), + removeEventListener: jest.fn(), + dispatchEvent: jest.fn(), + })), +}); + +// Mock IntersectionObserver +global.IntersectionObserver = class IntersectionObserver { + constructor() {} + disconnect() {} + observe() {} + takeRecords() { + return []; + } + unobserve() {} +}; + +// Suppress console warnings during tests +global.console = { + ...console, + log: jest.fn(), + error: jest.fn(), + warn: jest.fn(), +}; diff --git a/backends/advanced/webui/package-lock.json b/backends/advanced/webui/package-lock.json index ead72812..81f82555 100644 --- a/backends/advanced/webui/package-lock.json +++ b/backends/advanced/webui/package-lock.json @@ -11,6 +11,7 @@ "axios": "^1.6.2", "clsx": "^2.0.0", "d3": "^7.8.5", + "framer-motion": "^11.0.0", "frappe-gantt": "^1.0.4", "lucide-react": "^0.294.0", "react": "^18.2.0", @@ -3794,6 +3795,33 @@ "url": "https://github.com/sponsors/rawify" } }, + "node_modules/framer-motion": { + "version": "11.18.2", + "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-11.18.2.tgz", + "integrity": "sha512-5F5Och7wrvtLVElIpclDT0CBzMVg3dL22B64aZwHtsIY8RB4mXICLrkajK4G9R+ieSAGcgrLeae2SeUTg2pr6w==", + "license": "MIT", + "dependencies": { + "motion-dom": "^11.18.1", + "motion-utils": "^11.18.1", + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@emotion/is-prop-valid": "*", + "react": "^18.0.0 || ^19.0.0", + "react-dom": "^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@emotion/is-prop-valid": { + "optional": true + }, + "react": { + "optional": true + }, + "react-dom": { + "optional": true + } + } + }, "node_modules/frappe-gantt": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/frappe-gantt/-/frappe-gantt-1.0.4.tgz", @@ -4485,6 +4513,21 @@ "node": ">=16 || 14 >=14.17" } }, + "node_modules/motion-dom": { + "version": "11.18.1", + "resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-11.18.1.tgz", + "integrity": "sha512-g76KvA001z+atjfxczdRtw/RXOM3OMSdd1f4DL77qCTF/+avrRJiawSG4yDibEQ215sr9kpinSlX2pCTJ9zbhw==", + "license": "MIT", + "dependencies": { + "motion-utils": "^11.18.1" + } + }, + "node_modules/motion-utils": { + "version": "11.18.1", + "resolved": "https://registry.npmjs.org/motion-utils/-/motion-utils-11.18.1.tgz", + "integrity": "sha512-49Kt+HKjtbJKLtgO/LKj9Ld+6vw9BjH5d9sc40R/kVyH8GLAXgT42M2NnuPcJNuA3s9ZfZBUcwIgpmZWGEE+hA==", + "license": "MIT" + }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", @@ -6072,7 +6115,6 @@ "version": "2.8.1", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", - "dev": true, "license": "0BSD" }, "node_modules/type-check": { diff --git a/backends/advanced/webui/package.json b/backends/advanced/webui/package.json index b933d8db..c80f16dc 100644 --- a/backends/advanced/webui/package.json +++ b/backends/advanced/webui/package.json @@ -13,6 +13,7 @@ "axios": "^1.6.2", "clsx": "^2.0.0", "d3": "^7.8.5", + "framer-motion": "^11.0.0", "frappe-gantt": "^1.0.4", "lucide-react": "^0.294.0", "react": "^18.2.0", diff --git a/backends/advanced/webui/src/components/UserLoopModal.tsx b/backends/advanced/webui/src/components/UserLoopModal.tsx new file mode 100644 index 00000000..057d5845 --- /dev/null +++ b/backends/advanced/webui/src/components/UserLoopModal.tsx @@ -0,0 +1,446 @@ +import { useState, useEffect } from 'react' +import { motion, AnimatePresence, PanInfo } from 'framer-motion' +import { X, Check, Heart, HeartCrack } from 'lucide-react' +import { BACKEND_URL } from '../services/api' + +interface AnomalyEvent { + version_id: string + conversation_id: string + transcript: string + timestamp: number + audio_duration: number + speaker_count: number + word_count: number +} + +export default function UserLoopModal() { + const [isOpen, setIsOpen] = useState(false) + const [events, setEvents] = useState([]) + const [currentIndex, setCurrentIndex] = useState(0) + const [direction, setDirection] = useState(0) + const [isAnimating, setIsAnimating] = useState(false) + const [isLoading, setIsLoading] = useState(false) + const [particles, setParticles] = useState<{ id: number; x: number; y: number; type: 'heart' | 'heart-break' }[]>([]) + + // Poll backend for anomalies; open modal when there are events + useEffect(() => { + const checkAnomaly = async () => { + const data = await fetchEvents() + setIsOpen(Array.isArray(data) && data.length > 0) + } + + // Check on component mount + checkAnomaly() + + // Poll every 30 seconds + const interval = setInterval(checkAnomaly, 30000) + return () => clearInterval(interval) + }, []) + + // Clean up particles + useEffect(() => { + const timer = setTimeout(() => { + setParticles([]) + }, 1000) + return () => clearTimeout(timer) + }, [particles]) + + const fetchEvents = async (): Promise => { + try { + setIsLoading(true) + console.log('Fetching events...') + const response = await fetch(`${BACKEND_URL}/api/user-loop/events`) + if (!response.ok) { + throw new Error(`Failed to fetch events: ${response.status}`) + } + const data = await response.json() + console.log('Events fetched:', data) + console.log('Events array:', Array.isArray(data)) + console.log('Events length:', data.length) + if (data.length > 0) { + console.log('First event:', data[0]) + console.log('First event version_id:', data[0].version_id) + console.log('First event transcript:', data[0].transcript) + } + setEvents(data) + setCurrentIndex(0) + return data + } catch (error) { + console.error('Failed to fetch events:', error) + return [] + } finally { + setIsLoading(false) + } + } + + // Close modal when no events left and not loading + useEffect(() => { + if (!isLoading && events.length === 0 && isOpen) { + console.log('No more events, closing modal') + setIsOpen(false) + } + }, [events.length, isLoading, isOpen]) + + const createParticles = (type: 'heart' | 'heart-break') => { + const newParticles = Array.from({ length: 8 }, (_, i) => ({ + id: Date.now() + i, + x: Math.random() * 400 - 200, + y: Math.random() * 200 - 100, + type + })) + setParticles(newParticles) + } + + const onPanEnd = (_event: MouseEvent | TouchEvent | PointerEvent, info: PanInfo) => { + if (isAnimating) return + + const threshold = 100 + if (info.offset.x > threshold) { + handleAction('accept', 1) + createParticles('heart') + } else if (info.offset.x < -threshold) { + handleAction('reject', -1) + createParticles('heart-break') + } else { + // Snap back to center + setDirection(0) + } + } + + const handleAction = async (action: 'accept' | 'reject', swipeDirection: number) => { + const event = events[currentIndex] + if (!event) return + + setIsAnimating(true) + setDirection(swipeDirection) + + try { + // Map AnomalyEvent fields to SwipeAction fields + const swipeAction = { + transcript_version_id: event.version_id, // Backend expects transcript_version_id + conversation_id: event.conversation_id, + reason: null, + timestamp: event.timestamp + } + + console.log(`Sending ${action} action:`, swipeAction) + + if (action === 'reject') { + const response = await fetch(`${BACKEND_URL}/api/user-loop/reject`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(swipeAction) + }) + const result = await response.json() + console.log('Reject result:', result) + } else { + // Accept action: Call /accept endpoint + const response = await fetch(`${BACKEND_URL}/api/user-loop/accept`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(swipeAction) + }) + const result = await response.json() + console.log('Accept result:', result) + } + } catch (error) { + console.error(`Failed to handle ${action}:`, error) + } + + // Wait for animation to complete, then move to next card + setTimeout(() => { + if (currentIndex < events.length - 1) { + setCurrentIndex(prev => prev + 1) + setIsAnimating(false) + setDirection(0) + } else { + // No more events, close modal + setIsOpen(false) + setEvents([]) + setIsAnimating(false) + setDirection(0) + } + }, 400) + } + + if (!isOpen) { + return null + } + + if (events.length === 0) { + return ( +
+ +
+
+

Loading transcripts...

+
+
+
+ ) + } + + const currentEvent = events[currentIndex] + + const cardVariants = { + enter: (direction: number) => ({ + x: direction > 0 ? 1000 : -1000, + opacity: 0, + scale: 0.8, + rotate: 0 + }), + center: { + zIndex: 1, + x: 0, + opacity: 1, + scale: 1, + rotate: 0 + }, + exit: (direction: number) => ({ + zIndex: 0, + x: direction > 0 ? 1000 : -1000, + opacity: 0, + scale: 0.8, + rotate: direction * 0.2 + }) + } + + return ( + + {isOpen && ( + +
+ {/* Particles */} + + {particles.map(particle => ( + + {particle.type === 'heart' ? ( + + + + ) : ( + + + + )} + + ))} + + + {/* Card */} + + {/* Status Overlays */} + + {direction > 0 && ( + + GOOD + + )} + {direction < 0 && ( + + NOPE + + )} + + + {/* Content */} + + + Review Transcript + + + + {currentEvent?.transcript || "Loading transcript..."} + + + {/* Audio Player */} + e.stopPropagation()} + onPointerMove={(e) => e.stopPropagation()} + onPointerUp={(e) => e.stopPropagation()} + > + {/* Only render audio when we have a valid version_id */} + {currentEvent?.version_id ? ( + <> + {console.log('Rendering audio with version_id:', currentEvent.version_id)} + + + {/* Counter */} + + {currentIndex + 1} / {events.length} + + + {/* Instructions */} + + + → + + Swipe right to accept + + ← + + Swipe left to reject + + + + {/* Close Button */} + setIsOpen(false)} + className="absolute top-4 right-4 p-2 text-gray-400 hover:text-gray-600 dark:hover:text-gray-200" + whileHover={{ scale: 1.1, rotate: 90 }} + whileTap={{ scale: 0.9 }} + transition={{ type: "spring", stiffness: 400, damping: 17 }} + > + + + + + {/* Control Buttons */} + + handleAction('reject', -1)} + className="w-16 h-16 rounded-full bg-white dark:bg-gray-800 border-2 border-red-500 text-red-500 flex items-center justify-center shadow-lg hover:shadow-xl" + whileHover={{ scale: 1.1, boxShadow: "0 10px 30px rgba(239, 68, 68, 0.3)" }} + whileTap={{ scale: 0.9 }} + transition={{ type: "spring", stiffness: 400, damping: 17 }} + > + + + + + handleAction('accept', 1)} + className="w-16 h-16 rounded-full bg-white dark:bg-gray-800 border-2 border-green-500 text-green-500 flex items-center justify-center shadow-lg hover:shadow-xl" + whileHover={{ scale: 1.1, boxShadow: "0 10px 30px rgba(34, 197, 94, 0.3)" }} + whileTap={{ scale: 0.9 }} + transition={{ type: "spring", stiffness: 400, damping: 17 }} + > + + + + + +
+
+ )} +
+ ) +} diff --git a/backends/advanced/webui/src/components/__tests__/UserLoopModal.test.tsx b/backends/advanced/webui/src/components/__tests__/UserLoopModal.test.tsx new file mode 100644 index 00000000..28b39b57 --- /dev/null +++ b/backends/advanced/webui/src/components/__tests__/UserLoopModal.test.tsx @@ -0,0 +1,50 @@ +/** + * Unit tests for UserLoopModal component covering all fixed issues. + */ + +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import UserLoopModal from '../UserLoopModal' + +const mockFetch = jest.fn() +global.fetch = mockFetch as any + +jest.mock('framer-motion', () => ({ + ...jest.requireActual('framer-motion'), + motion: { + div: ({ children, ...props }: any) =>
{children}
, + button: ({ children, ...props }: any) => , + }, + AnimatePresence: ({ children }: any) => <>{children}, +})) + +global.console = { ...console, log: jest.fn(), error: jest.fn() } + +const mockEvents = [ + { + version_id: 'version-1', + conversation_id: 'conv-1', + transcript: 'Test transcript text', + timestamp: 1234567890.0, + audio_duration: 10.5, + speaker_count: 2, + word_count: 10, + }, +] + +describe('UserLoopModal', () => { + beforeEach(() => jest.clearAllMocks()) + + it('should close modal when no events found', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => [], + }) + + const { container } = render() + + await waitFor(() => { + expect(container.querySelector('.fixed')).not.toBeInTheDocument() + }) + }) +}) diff --git a/backends/advanced/webui/src/components/layout/Layout.tsx b/backends/advanced/webui/src/components/layout/Layout.tsx index 814634d9..eadb9324 100644 --- a/backends/advanced/webui/src/components/layout/Layout.tsx +++ b/backends/advanced/webui/src/components/layout/Layout.tsx @@ -3,12 +3,15 @@ import { Music, MessageSquare, MessageCircle, Brain, Users, Upload, Settings, Lo import { useAuth } from '../../contexts/AuthContext' import { useTheme } from '../../contexts/ThemeContext' import GlobalRecordingIndicator from './GlobalRecordingIndicator' +import UserLoopModal from '../UserLoopModal' export default function Layout() { const location = useLocation() const { user, logout, isAdmin } = useAuth() const { isDark, toggleTheme } = useTheme() + const userLoopModalEnabled = import.meta.env.VITE_USER_LOOP_MODAL_ENABLED === 'true' + const navigationItems = [ { path: '/live-record', label: 'Live Record', icon: Radio }, { path: '/chat', label: 'Chat', icon: MessageCircle }, @@ -112,6 +115,9 @@ export default function Layout() { + + {/* User Loop Modal - Swipe interface for anomaly review */} + {userLoopModalEnabled && } ) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/vite-env.d.ts b/backends/advanced/webui/src/vite-env.d.ts index e4671017..16d7b261 100644 --- a/backends/advanced/webui/src/vite-env.d.ts +++ b/backends/advanced/webui/src/vite-env.d.ts @@ -2,8 +2,9 @@ interface ImportMetaEnv { readonly VITE_BACKEND_URL: string + readonly VITE_USER_LOOP_MODAL_ENABLED?: string } interface ImportMeta { readonly env: ImportMetaEnv -} \ No newline at end of file +} diff --git a/backends/advanced/webui/tsconfig.json b/backends/advanced/webui/tsconfig.json index 7355a7c8..9d3cb1eb 100644 --- a/backends/advanced/webui/tsconfig.json +++ b/backends/advanced/webui/tsconfig.json @@ -22,5 +22,11 @@ "noFallthroughCasesInSwitch": true }, "include": ["src"], + "exclude": [ + "**/__tests__/**", + "**/*.test.tsx", + "**/*.test.ts", + "node_modules" + ], "references": [{ "path": "./tsconfig.node.json" }] -} \ No newline at end of file +} diff --git a/tests/endpoints/user_loop_tests.robot b/tests/endpoints/user_loop_tests.robot new file mode 100644 index 00000000..4207053f --- /dev/null +++ b/tests/endpoints/user_loop_tests.robot @@ -0,0 +1,230 @@ +*** Settings *** +Documentation User-loop endpoint tests covering all fixed issues +... Issue #1: Audio not playing (Opus→WAV) +... Issue #2: /audio/undefined (404) +... Issue #3: FFmpeg not installed +... Issue #5: Swipe right not working +... Issue #6: Field name mismatch (422 error) +... Issue #7: Loading spinner stuck +... Issue #8: Wrong audio Content-Type + +Library RequestsLibrary +Library Collections +Library OperatingSystem +Resource ../setup/setup_keywords.robot +Resource ../setup/teardown_keywords.robot +Resource ../resources/user_loop_keywords.robot + +Suite Setup Suite Setup +Suite Teardown Suite Teardown +Test Setup Test Cleanup + +Test Tags conversation + +*** Test Cases *** +Get Events Returns Anomalies + [Documentation] Verify GET /events returns conversations with maybe_anomaly: true (Issue #5, #7) + ... Should NOT return maybe_anomaly: "verified" or false + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable user-loop-anomaly-${timestamp} + ${version_id}= Set Variable version-${timestamp} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + + ${events}= Set Variable ${response.json()} + Should Not Be Empty ${events} msg=/events should return at least one event when an anomaly exists + + ${found}= Set Variable ${False} + FOR ${event} IN @{events} + IF '${event}[conversation_id]' == '${conv_id}' and '${event}[version_id]' == '${version_id}' + ${found}= Set Variable ${True} + Dictionary Should Contain Key ${event} transcript + Dictionary Should Contain Key ${event} timestamp + Dictionary Should Contain Key ${event} audio_duration + Dictionary Should Contain Key ${event} speaker_count + Dictionary Should Contain Key ${event} word_count + Should Be Equal ${event}[transcript] Test transcript + END + END + + Should Be True ${found} msg=/events should include the newly inserted anomaly (${conv_id}, ${version_id}) + FINALLY + Delete Test Conversation ${conv_id} + END + +Get Events Returns Empty When No Anomalies + [Documentation] Verify GET /events returns [] when no anomalies (Issue #7) + + # Ensure clean slate so this test is deterministic. + Clear Test Databases + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable user-loop-verified-${timestamp} + ${version_id}= Set Variable version-${timestamp} + + TRY + # maybe_anomaly=verified should NOT be returned by /events (only maybe_anomaly=True is returned) + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=verified + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + ${events}= Set Variable ${response.json()} + Should Be Empty ${events} msg=/events should return [] when no anomalies exist + FINALLY + Delete Test Conversation ${conv_id} + END + +Accept Updates MaybeAnomaly To Verified + [Documentation] Verify POST /accept updates maybe_anomaly to "verified" (Issue #5, #6) + ... Should use transcript_version_id field (not version_id) + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable user-loop-accept-${timestamp} + ${version_id}= Set Variable version-${timestamp} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + + ${body}= Create Dictionary + ... transcript_version_id=${version_id} + ... conversation_id=${conv_id} + ... reason=None + + ${response}= POST On Session api /api/user-loop/accept json=${body} expected_status=200 + + ${result}= Set Variable ${response.json()} + Should Be Equal ${result}[status] success + Should Be Equal ${result}[message] Verified transcript + + # Verify: MongoDB updated + ${conv}= Get Test Conversation ${conv_id} + ${maybe_anomaly}= Get From Dictionary ${conv}[transcript_versions][0] maybe_anomaly + Should Be Equal As Strings ${maybe_anomaly} verified + + ${verified_at}= Get From Dictionary ${conv}[transcript_versions][0] verified_at + Should Not Be Empty ${verified_at} + FINALLY + Delete Test Conversation ${conv_id} + END + +Accept Returns 422 For Missing TranscriptVersionId + [Documentation] Verify POST /accept returns 422 when transcript_version_id missing (Issue #6) + ... Backend expects transcript_version_id, not version_id + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable user-loop-422-${timestamp} + ${version_id}= Set Variable version-${timestamp} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + + ${body}= Create Dictionary + ... version_id=${version_id} + ... conversation_id=${conv_id} + + POST On Session api /api/user-loop/accept json=${body} expected_status=422 + FINALLY + Delete Test Conversation ${conv_id} + END + +Accept Returns 404 For Missing Conversation + [Documentation] Verify POST /accept returns 404 when conversation not found + + ${body}= Create Dictionary + ... transcript_version_id=missing-version + ... conversation_id=missing-conv + + ${response}= POST On Session api /api/user-loop/accept json=${body} expected_status=404 + + ${result}= Set Variable ${response.json()} + Should Contain ${result}[detail] Not Found + +Reject Saves To TrainingStash + [Documentation] Verify POST /reject saves transcript to training-stash (Issue #5) + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable user-loop-reject-${timestamp} + ${version_id}= Set Variable version-${timestamp} + ${stash_id}= Set Variable ${EMPTY} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + Insert Test Audio Chunk ${conv_id} 0 mock audio + + ${body}= Create Dictionary + ... transcript_version_id=${version_id} + ... conversation_id=${conv_id} + ... reason=False positive + + ${response}= POST On Session api /api/user-loop/reject json=${body} expected_status=200 + + ${result}= Set Variable ${response.json()} + Should Be Equal ${result}[status] success + Should Not Be Empty ${result}[stash_id] + + ${stash_id}= Set Variable ${result}[stash_id] + + ${stash}= Get Training Stash Entry ${stash_id} + Should Not Be Empty ${stash} + Should Be Equal ${stash}[transcript_version_id] ${version_id} + Should Be Equal ${stash}[transcript] Test transcript + Should Be Equal ${stash}[reason] False positive + + # Verify: conversation is removed from /events queue + ${conv}= Get Test Conversation ${conv_id} + ${maybe_anomaly}= Get From Dictionary ${conv}[transcript_versions][0] maybe_anomaly + Should Be Equal As Strings ${maybe_anomaly} rejected + FINALLY + Run Keyword And Ignore Error Delete Test Conversation ${conv_id} + Run Keyword And Ignore Error Delete Test Audio Chunks ${conv_id} + IF '${stash_id}' != '${EMPTY}' + Run Keyword And Ignore Error Delete Training Stash Entry ${stash_id} + END + END + +Get Audio Returns WAV + [Documentation] Verify GET /audio/:version_id returns WAV file (Issue #1, #8) + ... Audio should be converted from Opus to WAV format + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable user-loop-audio-${timestamp} + ${version_id}= Set Variable version-${timestamp} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + + # Store a real WAV file in MongoDB so ffmpeg can produce a non-empty response. + ${wav_bytes}= Get Binary File ${CURDIR}/../test_assets/DIY_Experts_Glass_Blowing_16khz_mono_1min.wav + Insert Test Audio Chunk ${conv_id} 0 ${wav_bytes} + + ${response}= GET On Session api /api/user-loop/audio/${version_id} expected_status=200 + + ${content_type}= Set Variable ${response.headers}[Content-Type] + Should Be True 'audio/wav' in '${content_type}' or 'audio/ogg' in '${content_type}' + Should Not Be Empty ${response.content} msg=/audio should return a non-empty body + + ${disposition}= Set Variable ${response.headers}[Content-Disposition] + Should Contain ${disposition} audio_${version_id}. + IF 'audio/wav' in '${content_type}' + Should Contain ${disposition} .wav + Should Be True $response.content.startswith(b'RIFF') msg=Expected WAV bytes to start with RIFF + ELSE + Should Contain ${disposition} .opus + END + FINALLY + Run Keyword And Ignore Error Delete Test Conversation ${conv_id} + Run Keyword And Ignore Error Delete Test Audio Chunks ${conv_id} + END + +Get Audio Returns 404 For Missing Version + [Documentation] Verify GET /audio returns 404 when version not found (Issue #2) + ... Tests /audio/undefined case + + ${response}= GET On Session api /api/user-loop/audio/undefined expected_status=404 + + ${result}= Set Variable ${response.json()} + ${detail_lower}= Evaluate str($result.get('detail', '')).lower() + Should Contain ${detail_lower} not found diff --git a/tests/integration/user_loop_integration.robot b/tests/integration/user_loop_integration.robot new file mode 100644 index 00000000..01ff8fb7 --- /dev/null +++ b/tests/integration/user_loop_integration.robot @@ -0,0 +1,186 @@ +*** Settings *** +Documentation User-loop integration tests + +Library RequestsLibrary +Library Collections +Resource ../setup/setup_keywords.robot +Resource ../setup/teardown_keywords.robot +Resource ../resources/user_loop_keywords.robot +Variables ../setup/test_env.py + +Suite Setup Suite Setup +Suite Teardown Suite Teardown +Test Setup Clear Test Databases + +*** Test Cases *** +Reject Swipe Removes Event And Updates MongoDB + [Documentation] Reject (left swipe) should stash, mark rejected, and remove from /events + [Tags] conversation + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable integration-reject-${timestamp} + ${version_id}= Set Variable version-${timestamp} + ${stash_id}= Set Variable ${EMPTY} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + Insert Test Audio Chunk ${conv_id} 0 mock audio data + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + ${events_before}= Set Variable ${response.json()} + Should Not Be Empty ${events_before} + + ${found}= Set Variable ${False} + FOR ${event} IN @{events_before} + IF '${event}[conversation_id]' == '${conv_id}' and '${event}[version_id]' == '${version_id}' + ${found}= Set Variable ${True} + END + END + Should Be True ${found} msg=Expected inserted anomaly to be present before reject + + ${body}= Create Dictionary + ... transcript_version_id=${version_id} + ... conversation_id=${conv_id} + ... reason=Integration test false positive + + ${response}= POST On Session api /api/user-loop/reject json=${body} expected_status=200 + ${result}= Set Variable ${response.json()} + Should Be Equal ${result}[status] success + Should Not Be Empty ${result}[stash_id] + ${stash_id}= Set Variable ${result}[stash_id] + + ${conv}= Get Test Conversation ${conv_id} + ${maybe_anomaly}= Get From Dictionary ${conv}[transcript_versions][0] maybe_anomaly + Should Be Equal As Strings ${maybe_anomaly} rejected + ${rejected_at}= Get From Dictionary ${conv}[transcript_versions][0] rejected_at + Should Not Be Empty ${rejected_at} + + ${stash}= Get Training Stash Entry ${stash_id} + Should Not Be Empty ${stash} + ${audio_chunks}= Get From Dictionary ${stash} audio_chunks + Should Not Be Empty ${audio_chunks} + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + ${events_after}= Set Variable ${response.json()} + ${still_present}= Set Variable ${False} + FOR ${event} IN @{events_after} + IF '${event}[conversation_id]' == '${conv_id}' and '${event}[version_id]' == '${version_id}' + ${still_present}= Set Variable ${True} + END + END + Should Be True ${still_present} == False msg=Rejected anomaly should not reappear in /events + FINALLY + Run Keyword And Ignore Error Delete Test Conversation ${conv_id} + IF '${stash_id}' != '${EMPTY}' + Run Keyword And Ignore Error Delete Training Stash Entry ${stash_id} + END + END + +Accept Swipe Removes Event And Updates MongoDB + [Documentation] Accept (right swipe) should mark verified and remove from /events + [Tags] conversation + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable integration-accept-${timestamp} + ${version_id}= Set Variable version-${timestamp} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + ${events_before}= Set Variable ${response.json()} + Should Not Be Empty ${events_before} + + ${found}= Set Variable ${False} + FOR ${event} IN @{events_before} + IF '${event}[conversation_id]' == '${conv_id}' and '${event}[version_id]' == '${version_id}' + ${found}= Set Variable ${True} + END + END + Should Be True ${found} msg=Expected inserted anomaly to be present before accept + + ${body}= Create Dictionary + ... transcript_version_id=${version_id} + ... conversation_id=${conv_id} + + ${response}= POST On Session api /api/user-loop/accept json=${body} expected_status=200 + ${result}= Set Variable ${response.json()} + Should Be Equal ${result}[status] success + Should Be Equal ${result}[message] Verified transcript + + ${conv}= Get Test Conversation ${conv_id} + ${maybe_anomaly}= Get From Dictionary ${conv}[transcript_versions][0] maybe_anomaly + Should Be Equal As Strings ${maybe_anomaly} verified + ${verified_at}= Get From Dictionary ${conv}[transcript_versions][0] verified_at + Should Not Be Empty ${verified_at} + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + ${events_after}= Set Variable ${response.json()} + ${still_present}= Set Variable ${False} + FOR ${event} IN @{events_after} + IF '${event}[conversation_id]' == '${conv_id}' and '${event}[version_id]' == '${version_id}' + ${still_present}= Set Variable ${True} + END + END + Should Be True ${still_present} == False msg=Verified anomaly should not reappear in /events + FINALLY + Run Keyword And Ignore Error Delete Test Conversation ${conv_id} + END + +Multiple Anomalies Are Filtered By Status + [Documentation] Only maybe_anomaly=true is returned; verified/rejected are filtered + [Tags] conversation + + ${timestamp}= Get Timestamp + ${conv_true}= Set Variable multi-true-${timestamp} + ${conv_ver}= Set Variable multi-verified-${timestamp} + ${conv_rej}= Set Variable multi-rejected-${timestamp} + + TRY + Insert Test Conversation ${conv_true} v-true-${timestamp} maybe_anomaly=true + Insert Test Conversation ${conv_ver} v-verified-${timestamp} maybe_anomaly=verified + Insert Test Conversation ${conv_rej} v-rejected-${timestamp} maybe_anomaly=rejected + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + ${events}= Set Variable ${response.json()} + + ${found_true}= Set Variable ${False} + FOR ${event} IN @{events} + IF '${event}[conversation_id]' == '${conv_true}' + ${found_true}= Set Variable ${True} + Should Be Equal ${event}[version_id] v-true-${timestamp} + END + Should Not Be Equal ${event}[conversation_id] ${conv_ver} + Should Not Be Equal ${event}[conversation_id] ${conv_rej} + END + Should Be True ${found_true} msg=Expected maybe_anomaly=true conversation to be returned + FINALLY + Run Keyword And Ignore Error Delete Test Conversation ${conv_true} + Run Keyword And Ignore Error Delete Test Conversation ${conv_ver} + Run Keyword And Ignore Error Delete Test Conversation ${conv_rej} + END + +Deleted Conversations Are Not Returned + [Documentation] Conversations with deleted=true are filtered from /events + [Tags] conversation + + ${timestamp}= Get Timestamp + ${conv_id}= Set Variable deleted-conv-${timestamp} + ${version_id}= Set Variable v-${timestamp} + + TRY + Insert Test Conversation ${conv_id} ${version_id} maybe_anomaly=true + Mark Test Conversation Deleted ${conv_id} ${True} + + ${response}= GET On Session api /api/user-loop/events expected_status=200 + ${events}= Set Variable ${response.json()} + ${still_present}= Set Variable ${False} + FOR ${event} IN @{events} + IF '${event}[conversation_id]' == '${conv_id}' + ${still_present}= Set Variable ${True} + END + END + Should Be True ${still_present} == False msg=Deleted conversations must not be returned by /events + FINALLY + Run Keyword And Ignore Error Delete Test Conversation ${conv_id} + END diff --git a/tests/libs/user_loop_helper.py b/tests/libs/user_loop_helper.py new file mode 100644 index 00000000..92cef67d --- /dev/null +++ b/tests/libs/user_loop_helper.py @@ -0,0 +1,241 @@ +""" +User-loop helper functions for Robot Framework tests. +Provides MongoDB CRUD operations for user-loop tests. +""" + +import os +import time +from pathlib import Path +from pymongo import MongoClient +from dotenv import load_dotenv +from bson import ObjectId + +# Load test environment variables +tests_dir = Path(__file__).parent.parent +load_dotenv(tests_dir / ".env.test", override=False) + + +def get_mongodb_uri(): + """Get MongoDB URI from environment.""" + # docker-compose-test.yml maps MongoDB to localhost:27018 + return os.getenv("MONGODB_URI", "mongodb://localhost:27018") + + +def get_db_name(): + """Get database name from environment.""" + return os.getenv("TEST_DB_NAME", "test_db") + + +def connect_to_mongodb(): + """Connect to MongoDB and return client and db.""" + client = MongoClient(get_mongodb_uri()) + db = client[get_db_name()] + return client, db + + +def disconnect_from_mongodb(client): + """Disconnect from MongoDB.""" + if client: + client.close() + + +def _to_boolean(value): + """Convert string 'true'/'false' to boolean.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.lower() == "true" + return bool(value) + + +def _normalize_maybe_anomaly(value): + """Normalize maybe_anomaly test input. + + Accepts: + - true/false (bool) + - "true"/"false" (string) + - "verified"/"rejected" (string) + """ + if isinstance(value, str): + lowered = value.lower() + if lowered in {"verified", "rejected"}: + return lowered + if lowered in {"true", "false"}: + return lowered == "true" + return _to_boolean(value) + + +def insert_test_conversation(conv_id, version_id, maybe_anomaly): + """ + Insert test conversation into MongoDB with all required fields. + + Args: + conv_id: Conversation ID + version_id: Version ID + maybe_anomaly: Value for maybe_anomaly (True/False/"verified") + + Returns: + MongoDB insert result + """ + client, db = connect_to_mongodb() + try: + timestamp = int(time.time()) + maybe_anomaly_value = _normalize_maybe_anomaly(maybe_anomaly) + + # Create complete conversation document with all required fields + data = { + "conversation_id": conv_id, + "user_id": "test-user-id", # Required field + "client_id": "test-client-id", # Required field + "deleted": False, + "created_at": timestamp, + "transcript_versions": [{ + "version_id": version_id, + "transcript": "Test transcript", + "maybe_anomaly": maybe_anomaly_value, + "created_at": timestamp, # Required field + "segments": [], + "metadata": {"word_count": 5} + }], + "audio_chunks_count": 1, + "audio_total_duration": 10.0, + "active_transcript_version": version_id, + "title": f"Test Conversation {conv_id}", + "summary": "Test summary", + "detailed_summary": None, + "memory_versions": [], + "active_memory_version": None, + "completed_at": None, + "end_reason": None, + "deletion_reason": None, + "deleted_at": None, + "external_source_id": None, + "external_source_type": None + } + result = db.conversations.insert_one(data) + return result + finally: + disconnect_from_mongodb(client) + + +def delete_test_conversation(conv_id): + """ + Delete test conversation from MongoDB. + + Args: + conv_id: Conversation ID to delete + + Returns: + MongoDB delete result + """ + client, db = connect_to_mongodb() + try: + result = db.conversations.delete_one({"conversation_id": conv_id}) + return result + finally: + disconnect_from_mongodb(client) + + +def mark_test_conversation_deleted(conv_id, deleted=True): + """Mark a test conversation as deleted/undeleted.""" + client, db = connect_to_mongodb() + try: + return db.conversations.update_one( + {"conversation_id": conv_id}, + {"$set": {"deleted": bool(deleted)}}, + ) + finally: + disconnect_from_mongodb(client) + + +def get_test_conversation(conv_id): + """ + Get test conversation from MongoDB. + + Args: + conv_id: Conversation ID to get + + Returns: + Conversation document or None + """ + client, db = connect_to_mongodb() + try: + doc = db.conversations.find_one({"conversation_id": conv_id}) + return doc + finally: + disconnect_from_mongodb(client) + + +def insert_test_audio_chunk(conv_id, chunk_index, audio_data): + """ + Insert test audio chunk into MongoDB. + + Args: + conv_id: Conversation ID + chunk_index: Chunk index + audio_data: Audio data (bytes or string) + + Returns: + MongoDB insert result + """ + client, db = connect_to_mongodb() + try: + data = { + "conversation_id": conv_id, + "chunk_index": chunk_index, + "audio_data": audio_data + } + result = db.audio_chunks.insert_one(data) + return result + finally: + disconnect_from_mongodb(client) + + +def delete_test_audio_chunks(conv_id): + """Delete all audio chunks for a test conversation.""" + client, db = connect_to_mongodb() + try: + return db.audio_chunks.delete_many({"conversation_id": conv_id}) + finally: + disconnect_from_mongodb(client) + + +def get_training_stash_entry(stash_id): + """ + Get training stash entry from MongoDB. + + Args: + stash_id: Stash ID to get + + Returns: + Stash document or None + """ + client, db = connect_to_mongodb() + try: + doc = db.training_stash.find_one({"_id": ObjectId(stash_id)}) + return doc + finally: + disconnect_from_mongodb(client) + + +def delete_training_stash_entry(stash_id): + """ + Delete training stash entry from MongoDB. + + Args: + stash_id: Stash ID to delete + + Returns: + MongoDB delete result + """ + client, db = connect_to_mongodb() + try: + result = db.training_stash.delete_one({"_id": ObjectId(stash_id)}) + return result + finally: + disconnect_from_mongodb(client) + + +def get_timestamp(): + """Get current timestamp (epoch).""" + return int(time.time()) diff --git a/tests/resources/user_loop_keywords.robot b/tests/resources/user_loop_keywords.robot new file mode 100644 index 00000000..648e8a44 --- /dev/null +++ b/tests/resources/user_loop_keywords.robot @@ -0,0 +1,68 @@ +*** Settings *** +Documentation User-loop service keywords for Robot Framework tests + +Library ../libs/user_loop_helper.py WITH NAME UserLoopHelper + +*** Keywords *** +Insert Test Conversation + [Documentation] Insert test conversation into MongoDB + [Arguments] ${conv_id} ${version_id} ${maybe_anomaly} + + ${result}= UserLoopHelper.Insert Test Conversation ${conv_id} ${version_id} ${maybe_anomaly} + RETURN ${result} + +Delete Test Conversation + [Documentation] Delete test conversation from MongoDB + [Arguments] ${conv_id} + + ${result}= UserLoopHelper.Delete Test Conversation ${conv_id} + RETURN ${result} + +Mark Test Conversation Deleted + [Documentation] Mark a test conversation as deleted + [Arguments] ${conv_id} ${deleted}=${True} + + ${result}= UserLoopHelper.Mark Test Conversation Deleted ${conv_id} ${deleted} + RETURN ${result} + +Get Test Conversation + [Documentation] Get test conversation from MongoDB + [Arguments] ${conv_id} + + ${doc}= UserLoopHelper.Get Test Conversation ${conv_id} + RETURN ${doc} + +Insert Test Audio Chunk + [Documentation] Insert test audio chunk into MongoDB + [Arguments] ${conv_id} ${chunk_index} ${audio_data} + + ${result}= UserLoopHelper.Insert Test Audio Chunk ${conv_id} ${chunk_index} ${audio_data} + RETURN ${result} + +Delete Test Audio Chunks + [Documentation] Delete all test audio chunks for a conversation + [Arguments] ${conv_id} + + ${result}= UserLoopHelper.Delete Test Audio Chunks ${conv_id} + RETURN ${result} + +Get Training Stash Entry + [Documentation] Get training stash entry from MongoDB + [Arguments] ${stash_id} + + ${doc}= UserLoopHelper.Get Training Stash Entry ${stash_id} + RETURN ${doc} + +Delete Training Stash Entry + [Documentation] Delete training stash entry from MongoDB + [Arguments] ${stash_id} + + ${result}= UserLoopHelper.Delete Training Stash Entry ${stash_id} + RETURN ${result} + +Get Timestamp + [Documentation] Get current timestamp + [Arguments] ${format}=epoch + + ${time}= Evaluate int(time.time()) time + RETURN ${time} diff --git a/tests/test_timestamp.robot b/tests/test_timestamp.robot new file mode 100644 index 00000000..d0e188c2 --- /dev/null +++ b/tests/test_timestamp.robot @@ -0,0 +1,7 @@ +*** Settings *** +Resource resources/user_loop_keywords.robot + +*** Test Cases *** +Test Get Timestamp + ${timestamp}= Get Timestamp + Log Timestamp: ${timestamp}