From f49b09dcee9253905ac718099ee79dfebcf87b10 Mon Sep 17 00:00:00 2001
From: Nexus CLI Developer
Date: Wed, 14 Jan 2026 15:59:21 -0600
Subject: [PATCH] feat: Add data enrichment tools (scoring, prioritization, balancing) and CLI command

---
 .env.example               |   6 +
 CLI_README.md              |  76 ++++++++
 README.md                  |  66 +++++--
 api/main.py                |  79 +++++++++
 cli.py                     | 324 ++++++++++++++++++++++++++++++++++
 tunekit/__init__.py        |   2 +
 tunekit/data/models.json   | 351 +++++++++++++++++++++++++++++++++++++
 tunekit/tools/__init__.py  |   2 +
 tunekit/tools/enrich.py    | 249 ++++++++++++++++++++++++++
 tunekit/tools/model_rec.py | 305 +++++++++++---------------------
 10 files changed, 1247 insertions(+), 213 deletions(-)
 create mode 100644 .env.example
 create mode 100644 CLI_README.md
 create mode 100644 cli.py
 create mode 100644 tunekit/data/models.json
 create mode 100644 tunekit/tools/enrich.py

diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..ad37844
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,6 @@
+# GitHub Token for Gist creation (Required for "Open in Colab")
+# Create one at: https://github.com/settings/tokens (Scope: gist)
+GITHUB_TOKEN=your_github_token_here
+
+# Allowed CORS origins (comma separated)
+ALLOWED_ORIGINS=http://localhost:3000,http://localhost:8000,http://127.0.0.1:8000
diff --git a/CLI_README.md b/CLI_README.md
new file mode 100644
index 0000000..dcf91b1
--- /dev/null
+++ b/CLI_README.md
@@ -0,0 +1,76 @@
+# TuneKit CLI Manager
+
+Command-line tool for managing the models supported by TuneKit. It lets you add, list, and remove models from the configuration without editing code.
+
+## Basic Usage
+
+The script lives at `TuneKit/cli.py`. Run it from the project root.
+
+### List Models
+
+Shows all currently configured models.
+
+```bash
+python TuneKit/cli.py list
+```
+
+### Add a Model
+
+Starts an interactive, "AI"-assisted wizard that walks you through each step.
+
+```bash
+python TuneKit/cli.py add
+```
+
+**Smart features:**
+- **Autocompletion from Hugging Face**: when you enter a model ID (e.g. `google/gemma-3-270m-it`), the CLI automatically queries the Hugging Face API to fetch:
+  - the context window, and
+  - a suggested display name.
+- **Validation**: prevents duplicate keys and empty fields.
+
+The wizard asks for:
+1. **Metadata**: ID, name, size, recommended GPU.
+2. **Scoring**: how well the model performs on each task (0-100).
+3. **Reasoning**: key features to show to the user.
+
+### Remove a Model
+
+Removes an existing model by its unique key.
+
+```bash
+python TuneKit/cli.py remove <model_key>
+```
+
+Example:
+```bash
+python TuneKit/cli.py remove deepseek-1.3b
+```
+
+## Data Tools
+
+### Data Enrichment (Enrich)
+
+Automatically improves the quality of your dataset through quality metrics, filtering, and class balancing.
+
+```bash
+python TuneKit/cli.py enrich <file> [options]
+```
+
+**Options:**
+- `--top_n <N>`: keep only the N best examples according to their quality score.
+- `--no-balance`: disable automatic class balancing (useful when the task is not classification).
+- `-o <file>`: set the output file (default: `<name>_enriched.jsonl`).
+
+**Example:**
+```bash
+# Enrich and keep only the 100 best examples
+python TuneKit/cli.py enrich data.jsonl --top_n 100
+```
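+
+Under the hood, the `enrich` command loads the file and passes it to `enrich_dataset` from `tunekit/tools/enrich.py`, so the same step can be driven from Python. A minimal sketch (the sample conversation is invented for illustration):
+
+```python
+from tunekit.tools.enrich import enrich_dataset
+
+state = {
+    "raw_data": [
+        {"messages": [
+            {"role": "user", "content": "Classify the sentiment: 'Great product, fast shipping.'"},
+            {"role": "assistant", "content": "positive"},
+        ]},
+        # ... more examples ...
+    ],
+    "enrich_config": {"top_n": 100, "balance": True},
+}
+
+result = enrich_dataset(state)
+print(result["enrichment_stats"])   # original_count, enriched_count, avg_quality_score
+enriched = result["enriched_data"]  # scored, filtered and (optionally) class-balanced examples
+```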
+
+## Configuration Files
+
+The configuration is stored in `TuneKit/tunekit/data/models.json`. This file is generated and managed automatically by the CLI, but it can also be edited by hand if needed.
+
+## Development
+
+If you add new fields to the recommendation logic in `model_rec.py`, make sure to update the CLI to support them.
diff --git a/README.md b/README.md
index 1328c0e..330d380 100644
--- a/README.md
+++ b/README.md
@@ -116,21 +116,65 @@ TuneKit uses the standard conversation format:
 
 ---
 
-## Run Locally
+## Data Enrichment
 
-```bash
-# Clone the repo
-git clone https://github.com/riyanshibohra/TuneKit.git
-cd TuneKit
+TuneKit now includes tools to automatically improve your dataset quality before training:
 
-# Install dependencies
-pip install -r requirements.txt
+- **Quality Scoring**: Evaluates every conversation on complexity, lexical diversity, and dialogue balance.
+- **Smart Prioritization**: Automatically ranks examples and filters out low-quality ones.
+- **Class Balancing**: Detects underrepresented classes in classification datasets and automatically balances them.
 
-# Start the server
-uvicorn api.main:app --reload
-```
+---
 
-Open [http://localhost:8000](http://localhost:8000) in your browser.
+## Development Setup
+
+### Prerequisites
+- Python 3.10+
+- Git
+
+### Installation
+
+1. **Clone the repository**
+   ```bash
+   git clone https://github.com/riyanshibohra/TuneKit.git
+   cd TuneKit
+   ```
+
+2. **Create a virtual environment (Recommended)**
+   ```bash
+   # Windows
+   python -m venv venv
+   .\venv\Scripts\activate
+
+   # macOS/Linux
+   python3 -m venv venv
+   source venv/bin/activate
+   ```
+
+3. **Install dependencies**
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+4. **Configuration**
+   Copy `.env.example` to `.env` and configure your tokens:
+   ```bash
+   # Windows (PowerShell)
+   cp .env.example .env
+
+   # macOS/Linux
+   cp .env.example .env
+   ```
+
+   > **Note:** A `GITHUB_TOKEN` is required to automatically create private Gists for the Colab notebooks.
+   > [Generate a token here](https://github.com/settings/tokens) (Scope: `gist`).
+
+5. **Start the server**
+   ```bash
+   uvicorn api.main:app --reload
+   ```
+
+   The app will be available at [http://localhost:8000](http://localhost:8000).
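+
+### Trying the Enrichment API
+
+Once the server is running, the new `/enrich` endpoint can be called directly. A rough sketch (assumes you already have a `session_id` from a previously uploaded dataset; the id below is a placeholder):
+
+```python
+import requests
+
+payload = {"session_id": "YOUR_SESSION_ID", "top_n": 500, "balance": True}
+resp = requests.post("http://localhost:8000/enrich", json=payload, timeout=120)
+resp.raise_for_status()
+
+data = resp.json()
+print(data["status"], data["quality_score"])
+print(data["stats"])  # original_count, enriched_count, avg_quality_score
+```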
--- diff --git a/api/main.py b/api/main.py index 2f0b9f5..5fd5f67 100644 --- a/api/main.py +++ b/api/main.py @@ -31,6 +31,7 @@ analyze_dataset, generate_package, recommend_model, + enrich_dataset, ) from tunekit.training import ( generate_training_notebook, @@ -188,6 +189,12 @@ class GenerateRequest(BaseModel): session_id: str +class EnrichRequest(BaseModel): + session_id: str + top_n: Optional[int] = None + balance: bool = True + + class SessionResponse(BaseModel): session_id: str status: str @@ -748,6 +755,78 @@ async def generate(request: GenerateRequest): ) +@app.post("/enrich") +async def enrich(request: EnrichRequest): + """Enrich the dataset with metrics, prioritization and balancing.""" + session_id = request.session_id + + if session_id not in sessions: + raise HTTPException(status_code=404, detail="Session not found") + + session = sessions[session_id] + state = session.get("state") + + if not state: + raise HTTPException(status_code=400, detail="No data found") + + # Reload raw_data if needed + if not reload_raw_data_if_needed(session): + raise HTTPException(status_code=400, detail="Could not load data") + + # Set config + state["enrich_config"] = { + "top_n": request.top_n, + "balance": request.balance + } + + result = enrich_dataset(state) + + # Update state with enriched data + # IMPORTANT: We replace raw_data so downstream tools use the improved version + if result.get("enriched_data"): + state["raw_data"] = result["enriched_data"] + state["num_rows"] = len(result["enriched_data"]) + + # Update file on disk with enriched data? + # Maybe we should save a new file version. + # For now, let's just update memory state and maybe save to a temp file if we want to download it. + + # Save enriched data to a new file for persistence/download + original_path = session["file_path"] + name, ext = os.path.splitext(original_path) + enriched_path = f"{name}_enriched{ext}" + + try: + with open(enriched_path, 'w', encoding='utf-8') as f: + for entry in result["enriched_data"]: + f.write(json.dumps(entry) + '\n') + + # Update session to point to new file + session["file_path"] = enriched_path + state["file_path"] = enriched_path + + except Exception as e: + print(f"Warning: Failed to save enriched file: {e}") + + state.update(result) + + # Re-run validation/analysis on new data + val_res = validate_quality(state) + state.update(val_res) + + # Clear raw_data to save memory + state["raw_data"] = None + sessions[session_id]["state"] = state + + return { + "session_id": session_id, + "status": "success", + "stats": result.get("enrichment_stats", {}), + "quality_score": state.get("quality_score"), + "quality_issues": state.get("quality_issues") + } + + @app.get("/download/{session_id}") async def download(session_id: str): """Download the generated training package as a ZIP file.""" diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..b2a6ad5 --- /dev/null +++ b/cli.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +""" +TuneKit CLI +=========== +Manage models and configuration for TuneKit. 
+""" + +import argparse +import json +import os +import sys +from typing import Dict, Any + +# Add local path to import tunekit modules +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) +from tunekit.tools.enrich import enrich_dataset +from tunekit.tools.ingest import ingest_data + +# Path to configuration +DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tunekit', 'data') +CONFIG_PATH = os.path.join(DATA_DIR, 'models.json') + +def load_config() -> Dict[str, Any]: + """Load configuration from JSON file.""" + try: + with open(CONFIG_PATH, 'r') as f: + return json.load(f) + except FileNotFoundError: + print(f"❌ Error: Configuration file not found at {CONFIG_PATH}") + sys.exit(1) + +def save_config(config: Dict[str, Any]): + """Save configuration to JSON file.""" + with open(CONFIG_PATH, 'w') as f: + json.dump(config, f, indent=2) + print(f"💾 Configuration saved to {CONFIG_PATH}") + +def list_models(args): + """List all configured models.""" + config = load_config() + models = config.get('models', {}) + + print(f"\n📦 Configured Models ({len(models)}):") + print("=" * 60) + print(f"{'Key':<20} | {'Name':<25} | {'Size':<10}") + print("-" * 60) + + for key, data in models.items(): + meta = data.get('metadata', {}) + print(f"{key:<20} | {meta.get('name', 'N/A'):<25} | {meta.get('size', 'N/A'):<10}") + print("=" * 60) + +def input_default(prompt: str, default: Any) -> Any: + """Get user input with a default value.""" + user_input = input(f"{prompt} [{default}]: ").strip() + if not user_input: + return default + return user_input + +import requests + +def get_hf_metadata(model_id: str) -> Dict[str, Any]: + """Fetch model metadata from HuggingFace API.""" + print(f" 🔍 Querying HuggingFace API for '{model_id}'...") + try: + # Get basic info + url = f"https://huggingface.co/api/models/{model_id}" + response = requests.get(url, timeout=5) + + if response.status_code != 200: + print(" ⚠️ Could not fetch metadata (Model not found or API error)") + return {} + + data = response.json() + + # Extract config to find context window + config_url = f"https://huggingface.co/{model_id}/raw/main/config.json" + config_resp = requests.get(config_url, timeout=5) + config_data = config_resp.json() if config_resp.status_code == 200 else {} + + metadata = { + "id": data.get("modelId", model_id), + "tags": data.get("tags", []), + "downloads": data.get("downloads", 0), + } + + # Try to guess size from tags or safetensors + for tag in metadata["tags"]: + if tag.endswith("b") or tag.endswith("m"): + # Rough heuristic + pass + + # Try to find context window + ctx = config_data.get("max_position_embeddings") or config_data.get("seq_length") or config_data.get("n_positions") + if ctx: + metadata["context_window"] = ctx + + # Try to guess params count from safetensors index if available + # (This is complex, so we'll stick to manual entry if not obvious) + + print(" ✅ Metadata found!") + return metadata + + except Exception as e: + print(f" ⚠️ API Error: {e}") + return {} + +def add_model(args): + """Interactive wizard to add a new model.""" + config = load_config() + defaults = config.get('defaults', {}) + + print("\n🤖 TuneKit AI Assistant") + print("=======================") + print("I'll help you add a new model to the ecosystem.") + print("Please provide the details below.\n") + + # 1. 
Metadata + print("📝 Step 1: Model Metadata") + + while True: + key = input(" Unique Key (e.g., deepseek-1.3b): ").strip() + if not key: + print(" Key cannot be empty.") + continue + if key in config['models']: + print(f" ❌ Model '{key}' already exists. Choose another key.") + continue + break + + hf_id = input(" HuggingFace ID (e.g., deepseek-ai/deepseek-coder-1.3b-instruct): ").strip() + + # Auto-fetch metadata + hf_meta = {} + if hf_id: + hf_meta = get_hf_metadata(hf_id) + + name_default = hf_id.split('/')[-1].replace('-', ' ').title() if hf_id else "" + name = input_default(" Display Name", name_default) + + size = input(" Size (e.g., 1.3B): ").strip() + + ctx_default = hf_meta.get("context_window", 8192) + context_window = int(input_default(" Context Window (tokens)", ctx_default)) + gpu_tier = input_default(" Recommended GPU", "T4") + + # 2. Scores + print("\n📊 Step 2: Scoring Configuration") + print(" Rate the model's performance (0-100) for each task.") + + task_scores = {} + for task in ['classify', 'qa', 'conversation', 'generation', 'extraction']: + task_scores[task] = int(input_default(f" Score for '{task}'", defaults.get('task_score', 25))) + + print("\n Set bonus points for special capabilities.") + multi_turn = int(input_default(" Multi-turn Bonus (0-10)", 0)) + json_score = int(input_default(" JSON Output Score (0-20)", 10)) + long_output = int(input_default(" Long Output Score (0-20)", 10)) + + # 3. Reasons + print("\n🗣️ Step 3: Assistant Reasoning") + print(" What makes this model special? (Comma separated features)") + features_input = input(" Features: ").strip() + features = [f.strip() for f in features_input.split(',')] if features_input else [] + + # Construct model object + new_model = { + "metadata": { + "id": hf_id, + "name": name, + "size": size, + "context_window": context_window, + "training_time_base": 3, # Default + "cost_base": 0.0, + "gpu_tier": gpu_tier, + "memory_gb": 8, # Conservative default + "accuracy_baseline": 80 + }, + "scores": { + "task": task_scores, + "size": { + "small": 20, + "medium": 20, + "large": 20 + }, + "output": { + "long": long_output, + "json": json_score + }, + "multi_turn_bonus": multi_turn + }, + "deployment": ["cloud_api", "desktop_app", "not_sure"], + "reasons": { + "features": features + } + } + + # Save + config['models'][key] = new_model + + # Update tiebreakers (append to end) + tiebreaker = config.get('tiebreaker', {}) + for cat in tiebreaker: + if key not in tiebreaker[cat]: + tiebreaker[cat].append(key) + + save_config(config) + print(f"\n✅ Model '{name}' ({key}) successfully added to TuneKit!") + +def remove_model(args): + """Remove a model.""" + config = load_config() + key = args.key + + if key not in config['models']: + print(f"❌ Error: Model '{key}' not found.") + return + + confirm = input(f"⚠️ Are you sure you want to remove '{key}'? (y/N): ").lower() + if confirm == 'y': + del config['models'][key] + + # Remove from tiebreakers + tiebreaker = config.get('tiebreaker', {}) + for cat in tiebreaker: + if key in tiebreaker[cat]: + tiebreaker[cat].remove(key) + + save_config(config) + print(f"🗑️ Model '{key}' removed.") + else: + print("Cancelled.") + +def enrich_cmd(args): + """Run data enrichment on a file.""" + file_path = args.file + + if not os.path.exists(file_path): + print(f"❌ Error: File not found at {file_path}") + return + + print(f"\n🚀 Enriching dataset: {file_path}") + print(" Analysis started...") + + # 1. 
Ingest + state = {"file_path": file_path} + ingest_res = ingest_data(state) + + if ingest_res.get("error_msg"): + print(f"❌ Ingestion Error: {ingest_res['error_msg']}") + return + + state.update(ingest_res) + raw_count = len(state["raw_data"]) + print(f" ✓ Loaded {raw_count} examples.") + + # 2. Enrich + print(f" Applying metrics (Top N: {args.top_n}, Balance: {args.balance})...") + state["enrich_config"] = { + "top_n": args.top_n, + "balance": args.balance + } + + enrich_res = enrich_dataset(state) + enriched_data = enrich_res.get("enriched_data", []) + stats = enrich_res.get("enrichment_stats", {}) + + # 3. Save Output + if args.output: + out_path = args.output + else: + name, ext = os.path.splitext(file_path) + out_path = f"{name}_enriched{ext}" + + try: + with open(out_path, 'w', encoding='utf-8') as f: + for entry in enriched_data: + f.write(json.dumps(entry) + '\n') + + print("\n✅ Enrichment Complete!") + print(f" Input: {stats.get('original_count')} examples") + print(f" Output: {stats.get('enriched_count')} examples") + print(f" Score: {stats.get('avg_quality_score', 0):.4f} (Avg Quality)") + print(f" Saved: {out_path}") + + except Exception as e: + print(f"❌ Save Error: {e}") + +def main(): + parser = argparse.ArgumentParser(description="TuneKit CLI Manager") + subparsers = parser.add_subparsers(dest="command", help="Command to run") + + # List command + subparsers.add_parser("list", help="List configured models") + + # Add command + subparsers.add_parser("add", help="Add a new model via AI Assistant") + + # Remove command + remove_parser = subparsers.add_parser("remove", help="Remove a model") + remove_parser.add_argument("key", help="Model key to remove") + + # Enrich command + enrich_parser = subparsers.add_parser("enrich", help="Enrich a dataset") + enrich_parser.add_argument("file", help="Path to JSONL file") + enrich_parser.add_argument("--top_n", type=int, default=None, help="Keep only top N examples") + enrich_parser.add_argument("--no-balance", action="store_false", dest="balance", help="Disable class balancing") + enrich_parser.add_argument("-o", "--output", help="Output file path") + + args = parser.parse_args() + + if args.command == "list": + list_models(args) + elif args.command == "add": + add_model(args) + elif args.command == "remove": + remove_model(args) + elif args.command == "enrich": + enrich_cmd(args) + else: + parser.print_help() + +if __name__ == "__main__": + main() diff --git a/tunekit/__init__.py b/tunekit/__init__.py index a3e1796..39afb37 100644 --- a/tunekit/__init__.py +++ b/tunekit/__init__.py @@ -17,6 +17,7 @@ analyze_dataset, recommend_model, generate_package, + enrich_dataset, ) __version__ = "0.1.0" @@ -28,4 +29,5 @@ "analyze_dataset", "recommend_model", "generate_package", + "enrich_dataset", ] diff --git a/tunekit/data/models.json b/tunekit/data/models.json new file mode 100644 index 0000000..6394c89 --- /dev/null +++ b/tunekit/data/models.json @@ -0,0 +1,351 @@ +{ + "models": { + "phi-4-mini": { + "metadata": { + "id": "microsoft/Phi-4-mini-instruct", + "name": "Phi-4 Mini", + "size": "3.8B", + "context_window": 128000, + "training_time_base": 3, + "cost_base": 0.036, + "gpu_tier": "A10G", + "memory_gb": 12, + "accuracy_baseline": 87 + }, + "scores": { + "task": { + "classify": 50, + "qa": 35, + "conversation": 25, + "generation": 20, + "extraction": 50 + }, + "size": { + "small": 20, + "medium": 25, + "large": 10 + }, + "output": { + "long": 5, + "json": 20 + }, + "multi_turn_bonus": 3 + }, + "deployment": [ + "cloud_api", + "desktop_app", 
+ "web_browser", + "not_sure" + ], + "reasons": { + "classify": "Best for classification tasks", + "extraction": "Excellent at structured extraction", + "qa": "Excellent reasoning capabilities", + "features": [ + "Supports function calling", + "Works well with limited data" + ] + } + }, + "gemma-3-2b": { + "metadata": { + "id": "google/gemma-2-2b-it", + "name": "Gemma 2 2B", + "size": "2B", + "context_window": 8192, + "training_time_base": 2, + "cost_base": 0.012, + "gpu_tier": "T4", + "memory_gb": 6, + "accuracy_baseline": 82 + }, + "scores": { + "task": { + "classify": 35, + "qa": 20, + "conversation": 20, + "generation": 20, + "extraction": 25 + }, + "size": { + "small": 30, + "medium": 15, + "large": 5 + }, + "output": { + "long": 5, + "json": 5 + }, + "multi_turn_bonus": 2 + }, + "deployment": [ + "cloud_api", + "desktop_app", + "mobile_app", + "ios_app", + "android_app", + "web_browser", + "edge_device", + "not_sure" + ], + "reasons": { + "classify": "Fast and efficient for classification", + "features": [ + "Optimized for on-device deployment", + "Ideal for small datasets" + ] + } + }, + "llama-3.2-3b": { + "metadata": { + "id": "meta-llama/Llama-3.2-3B-Instruct", + "name": "Llama 3.2 3B", + "size": "3B", + "context_window": 128000, + "training_time_base": 4, + "cost_base": 0.048, + "gpu_tier": "A10G", + "memory_gb": 10, + "accuracy_baseline": 89 + }, + "scores": { + "task": { + "classify": 30, + "qa": 50, + "conversation": 50, + "generation": 45, + "extraction": 40 + }, + "size": { + "small": 10, + "medium": 30, + "large": 30 + }, + "output": { + "long": 15, + "json": 15 + }, + "multi_turn_bonus": 10 + }, + "deployment": [ + "cloud_api", + "desktop_app", + "mobile_app", + "ios_app", + "android_app", + "web_browser", + "not_sure" + ], + "reasons": { + "qa": "Top performer for Q&A tasks", + "conversation": "Best for conversational AI", + "generation": "Creative and coherent text generation", + "features": [ + "Optimized for on-device deployment", + "Excellent at multi-turn context tracking" + ] + } + }, + "qwen-2.5-3b": { + "metadata": { + "id": "Qwen/Qwen2.5-3B-Instruct", + "name": "Qwen 2.5 3B", + "size": "3B", + "context_window": 32768, + "training_time_base": 4, + "cost_base": 0.048, + "gpu_tier": "A10G", + "memory_gb": 10, + "accuracy_baseline": 88 + }, + "scores": { + "task": { + "classify": 25, + "qa": 30, + "conversation": 35, + "generation": 30, + "extraction": 30 + }, + "size": { + "small": 10, + "medium": 20, + "large": 10 + }, + "output": { + "long": 10, + "json": 10 + }, + "multi_turn_bonus": 6 + }, + "deployment": [ + "cloud_api", + "desktop_app", + "mobile_app", + "ios_app", + "android_app", + "web_browser", + "edge_device", + "not_sure" + ], + "reasons": { + "extraction": "Strong JSON/structured output", + "conversation": "Multilingual conversation support", + "generation": "Strong multilingual generation", + "features": [ + "Multilingual (29 languages)" + ] + } + }, + "mistral-7b": { + "metadata": { + "id": "mistralai/Mistral-7B-Instruct-v0.3", + "name": "Mistral 7B", + "size": "7B", + "context_window": 8192, + "training_time_base": 6, + "cost_base": 0.3, + "gpu_tier": "A100-40GB", + "memory_gb": 18, + "accuracy_baseline": 91 + }, + "scores": { + "task": { + "classify": 20, + "qa": 45, + "conversation": 45, + "generation": 50, + "extraction": 35 + }, + "size": { + "small": 5, + "medium": 20, + "large": 25 + }, + "output": { + "long": 20, + "json": 10 + }, + "multi_turn_bonus": 8 + }, + "deployment": [ + "cloud_api", + "desktop_app", + "not_sure" + ], + "reasons": { + 
"qa": "Strong reasoning and knowledge base", + "conversation": "Natural dialogue and context tracking", + "generation": "Best for long-form generation", + "features": [ + "Supports function calling", + "Optimized for longer outputs" + ] + } + }, + "gemma-3-270m": { + "metadata": { + "id": "google/gemma-3-270m-it", + "name": "Gemma 3 270M", + "size": "270M", + "context_window": 32768, + "training_time_base": 3, + "cost_base": 0.0, + "gpu_tier": "T4", + "memory_gb": 8, + "accuracy_baseline": 80 + }, + "scores": { + "task": { + "classify": 100, + "qa": 20, + "conversation": 20, + "generation": 15, + "extraction": 40 + }, + "size": { + "small": 50, + "medium": 20, + "large": 20 + }, + "output": { + "long": 0, + "json": 10 + }, + "multi_turn_bonus": 0 + }, + "deployment": [ + "cloud_api", + "desktop_app", + "mobile_app", + "ios_app", + "android_app", + "web_browser", + "edge_device", + "not_sure" + ], + "reasons": { + "features": [ + "Ultra-lightweight", + "Mobile-ready", + "32K context" + ] + } + } + }, + "defaults": { + "task_score": 25, + "size_score": 15, + "output_score": 10, + "multi_turn_bonus": 0 + }, + "tiebreaker": { + "classify": [ + "phi-4-mini", + "gemma-3-2b", + "llama-3.2-3b", + "mistral-7b", + "qwen-2.5-3b", + "gemma-3-270m" + ], + "extraction": [ + "phi-4-mini", + "llama-3.2-3b", + "mistral-7b", + "qwen-2.5-3b", + "gemma-3-2b", + "gemma-3-270m" + ], + "qa": [ + "llama-3.2-3b", + "mistral-7b", + "phi-4-mini", + "qwen-2.5-3b", + "gemma-3-2b", + "gemma-3-270m" + ], + "conversation": [ + "llama-3.2-3b", + "mistral-7b", + "qwen-2.5-3b", + "phi-4-mini", + "gemma-3-2b", + "gemma-3-270m" + ], + "generation": [ + "mistral-7b", + "llama-3.2-3b", + "qwen-2.5-3b", + "phi-4-mini", + "gemma-3-2b", + "gemma-3-270m" + ], + "default": [ + "mistral-7b", + "llama-3.2-3b", + "phi-4-mini", + "qwen-2.5-3b", + "gemma-3-2b", + "gemma-3-270m" + ] + } +} \ No newline at end of file diff --git a/tunekit/tools/__init__.py b/tunekit/tools/__init__.py index 104a554..06d9ca4 100644 --- a/tunekit/tools/__init__.py +++ b/tunekit/tools/__init__.py @@ -9,6 +9,7 @@ from .analyze import analyze_dataset from .model_rec import recommend_model from .package import generate_package +from .enrich import enrich_dataset __all__ = [ "ingest_data", @@ -16,4 +17,5 @@ "analyze_dataset", "recommend_model", "generate_package", + "enrich_dataset", ] diff --git a/tunekit/tools/enrich.py b/tunekit/tools/enrich.py new file mode 100644 index 0000000..d20bc8e --- /dev/null +++ b/tunekit/tools/enrich.py @@ -0,0 +1,249 @@ +""" +Enrich Data Tool +================ +Utilities for improving fine-tuning examples: +1. Quality Metrics (Complexity, Diversity, etc.) +2. Prioritization (Scoring examples) +3. Augmentation (Class balancing, Synthetic generation placeholders) +""" + +from typing import TYPE_CHECKING, List, Dict, Any, Tuple +import random +import re +import math +import json +from collections import Counter + +if TYPE_CHECKING: + from tunekit.state import TuneKitState + +# ============================================================================ +# METRICS +# ============================================================================ + +def _calculate_lexical_diversity(text: str) -> float: + """Calculate Type-Token Ratio (TTR) as a measure of lexical diversity.""" + words = re.findall(r'\w+', text.lower()) + if not words: + return 0.0 + return len(set(words)) / len(words) + +def _calculate_instruction_complexity(text: str) -> float: + """ + Estimate instruction complexity based on length and structure. 
+ Returns 0.0 to 1.0 + """ + words = text.split() + word_count = len(words) + + # Heuristic: Instructions between 10 and 50 words are often "good" complexity + # Too short = simple command + # Too long = might be confusing or few-shot + + score = 0.0 + if word_count < 5: + score = 0.2 + elif word_count < 10: + score = 0.5 + elif word_count < 50: + score = 0.9 + else: + score = 0.7 # Still good, but maybe verbose + + # Bonus for punctuation indicating structure + if "?" in text: score += 0.1 + if "\n" in text: score += 0.1 # Lists or formatting + + return min(1.0, score) + +def _score_conversation_quality(messages: List[Dict]) -> float: + """ + Score a conversation based on multiple heuristics. + Returns 0.0 to 1.0 + """ + if not messages: + return 0.0 + + score = 0.5 # Base score + + user_msgs = [m["content"] for m in messages if m["role"] == "user"] + asst_msgs = [m["content"] for m in messages if m["role"] == "assistant"] + + if not user_msgs or not asst_msgs: + return 0.0 + + # 1. Balance check (User vs Assistant length) + avg_user_len = sum(len(m) for m in user_msgs) / len(user_msgs) + avg_asst_len = sum(len(m) for m in asst_msgs) / len(asst_msgs) + + # Assistant should generally be helpful (not too short) but not hallucinatingly long + if avg_asst_len < 10: + score -= 0.2 + elif avg_asst_len > 2000: + score -= 0.1 # Very long might be okay, but slight penalty for potential rambling + else: + score += 0.1 + + # 2. Lexical Diversity of Assistant + # We want varied vocabulary + diversity = _calculate_lexical_diversity(" ".join(asst_msgs)) + score += (diversity * 0.2) + + # 3. Instruction Complexity + # We want complex/interesting user queries + complexity = _calculate_instruction_complexity(" ".join(user_msgs)) + score += (complexity * 0.2) + + return max(0.0, min(1.0, score)) + +# ============================================================================ +# PRIORITIZATION & FILTERING +# ============================================================================ + +def prioritize_examples(raw_data: List[Dict], top_n: int = None, threshold: float = 0.3) -> List[Dict]: + """ + Sort examples by quality score and filter out low-quality ones. + """ + scored_data = [] + for entry in raw_data: + quality = _score_conversation_quality(entry.get("messages", [])) + # Store score in the entry for debugging/analysis + entry["_quality_score"] = quality + scored_data.append(entry) + + # Sort by score descending + scored_data.sort(key=lambda x: x["_quality_score"], reverse=True) + + # Filter by threshold + filtered_data = [x for x in scored_data if x["_quality_score"] >= threshold] + + # Return top N + if top_n and top_n < len(filtered_data): + return filtered_data[:top_n] + + return filtered_data + +# ============================================================================ +# AUGMENTATION +# ============================================================================ + +def balance_classes(raw_data: List[Dict], target_col: str = None) -> List[Dict]: + """ + Simple augmentation: Oversample underrepresented classes for classification tasks. + (Requires identifying a 'label' which is tricky in pure chat, + but we can try to infer from short assistant responses if they look like classes) + """ + # Heuristic: If assistant responses are short (<50 chars) and repetitive, + # treat them as classification labels. + + # 1. 
Extract potential labels + labels = [] + for entry in raw_data: + # Check last assistant message + msgs = entry.get("messages", []) + if not msgs: + labels.append(None) + continue + last_msg = msgs[-1] + if last_msg["role"] == "assistant": + content = last_msg["content"].strip().lower() + if len(content) < 50: + labels.append(content) + else: + labels.append(None) # Not a classification-style response + else: + labels.append(None) + + # If too many None, skip balancing + if not labels or labels.count(None) / len(labels) > 0.5: + return raw_data # Probably not a classification dataset + + # 2. Count classes + valid_labels = [l for l in labels if l is not None] + if not valid_labels: + return raw_data + + counts = Counter(valid_labels) + max_count = max(counts.values()) + + # 3. Oversample + augmented_data = list(raw_data) + + for label, count in counts.items(): + if count < max_count: + # Find examples with this label + examples = [ + entry for entry, l in zip(raw_data, labels) + if l == label + ] + + if not examples: continue + + # Calculate how many to add + needed = max_count - count + + # Add random copies (simple oversampling) + # In a real scenario, we'd use an LLM to rephrase here + for _ in range(needed): + # Deep copy to avoid reference issues + copy_entry = json.loads(json.dumps(random.choice(examples))) + augmented_data.append(copy_entry) + + return augmented_data + +def generate_variations(state: "TuneKitState") -> dict: + """ + Placeholder for LLM-based data generation. + """ + # TODO: Implement LLM integration for true synthetic data generation + return { + "warning": "LLM-based generation requires API configuration. Using heuristic balancing instead." + } + +# ============================================================================ +# MAIN TOOL +# ============================================================================ + +def enrich_dataset(state: "TuneKitState") -> dict: + """ + Enrich the dataset with scores, prioritization, and basic augmentation. + + Inputs (from state): + - raw_data: List[Dict] + - enrich_config: Dict (optional) + - top_n: int + - balance: bool + + Outputs (to state): + - enriched_data: List[Dict] + - enrichment_stats: Dict + """ + raw_data = state.get("raw_data", []) + config = state.get("enrich_config", {}) + + if not raw_data: + return {"enriched_data": [], "enrichment_stats": {"error": "No data"}} + + # 1. Scoring & Prioritization + top_n = config.get("top_n", None) + + # Use 0.0 threshold to keep everything by default unless very bad + prioritized_data = prioritize_examples(raw_data, top_n=top_n, threshold=0.1) + + # 2. 
Balancing (Augmentation) + if config.get("balance", True): + final_data = balance_classes(prioritized_data) + else: + final_data = prioritized_data + + # Stats + stats = { + "original_count": len(raw_data), + "enriched_count": len(final_data), + "avg_quality_score": sum(x.get("_quality_score", 0) for x in final_data) / len(final_data) if final_data else 0 + } + + return { + "enriched_data": final_data, + "enrichment_stats": stats + } diff --git a/tunekit/tools/model_rec.py b/tunekit/tools/model_rec.py index b605e02..32a30c0 100644 --- a/tunekit/tools/model_rec.py +++ b/tunekit/tools/model_rec.py @@ -4,142 +4,37 @@ Simple 3-factor scoring: Task (50) + Size (30) + Output (20) = 100 points """ +import os +import json from typing import Dict, List # ════════════════════════════════════════════════════════════════════════════ -# MODEL METADATA +# CONFIGURATION # ════════════════════════════════════════════════════════════════════════════ -MODELS = { - 'phi-4-mini': { - 'id': 'microsoft/Phi-4-mini-instruct', - 'name': 'Phi-4 Mini', - 'size': '3.8B', - 'context_window': 128000, - 'training_time_base': 3, - 'cost_base': 0.036, - 'gpu_tier': 'A10G', - 'memory_gb': 12, - 'accuracy_baseline': 87 - }, - 'gemma-3-2b': { - 'id': 'google/gemma-2-2b-it', - 'name': 'Gemma 2 2B', - 'size': '2B', - 'context_window': 8192, - 'training_time_base': 2, - 'cost_base': 0.012, - 'gpu_tier': 'T4', - 'memory_gb': 6, - 'accuracy_baseline': 82 - }, - 'llama-3.2-3b': { - 'id': 'meta-llama/Llama-3.2-3B-Instruct', - 'name': 'Llama 3.2 3B', - 'size': '3B', - 'context_window': 128000, - 'training_time_base': 4, - 'cost_base': 0.048, - 'gpu_tier': 'A10G', - 'memory_gb': 10, - 'accuracy_baseline': 89 - }, - 'qwen-2.5-3b': { - 'id': 'Qwen/Qwen2.5-3B-Instruct', - 'name': 'Qwen 2.5 3B', - 'size': '3B', - 'context_window': 32768, - 'training_time_base': 4, - 'cost_base': 0.048, - 'gpu_tier': 'A10G', - 'memory_gb': 10, - 'accuracy_baseline': 88 - }, - 'mistral-7b': { - 'id': 'mistralai/Mistral-7B-Instruct-v0.3', - 'name': 'Mistral 7B', - 'size': '7B', - 'context_window': 8192, - 'training_time_base': 6, - 'cost_base': 0.30, - 'gpu_tier': 'A100-40GB', - 'memory_gb': 18, - 'accuracy_baseline': 91 - } -} - -# ════════════════════════════════════════════════════════════════════════════ -# DEPLOYMENT FILTERS -# ════════════════════════════════════════════════════════════════════════════ - -DEPLOYMENT_FILTERS = { - 'cloud_api': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b', 'qwen-2.5-3b', 'mistral-7b'], - 'desktop_app': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b', 'qwen-2.5-3b', 'mistral-7b'], - 'mobile_app': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b'], - 'ios_app': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b'], - 'android_app': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b'], - 'web_browser': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b'], - 'edge_device': ['gemma-3-2b', 'phi-4-mini'], - 'not_sure': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b', 'qwen-2.5-3b', 'mistral-7b'] -} - -# ════════════════════════════════════════════════════════════════════════════ -# SCORING TABLES -# ════════════════════════════════════════════════════════════════════════════ - -# Factor 1: TASK TYPE (50 points max) -TASK_SCORES = { - 'classify': { - 'phi-4-mini': 50, 'gemma-3-2b': 35, 'llama-3.2-3b': 30, 'qwen-2.5-3b': 25, 'mistral-7b': 20 - }, - 'qa': { - 'llama-3.2-3b': 50, 'mistral-7b': 45, 'phi-4-mini': 35, 'qwen-2.5-3b': 30, 'gemma-3-2b': 20 - }, - 'conversation': { - 'llama-3.2-3b': 50, 'mistral-7b': 45, 'qwen-2.5-3b': 35, 'phi-4-mini': 25, 'gemma-3-2b': 20 - 
}, - 'generation': { - 'mistral-7b': 50, 'llama-3.2-3b': 45, 'qwen-2.5-3b': 30, 'phi-4-mini': 20, 'gemma-3-2b': 20 - }, - 'extraction': { - 'phi-4-mini': 50, 'llama-3.2-3b': 40, 'mistral-7b': 35, 'qwen-2.5-3b': 30, 'gemma-3-2b': 25 - } -} - -# Factor 2: DATASET SIZE (30 points max) -SIZE_SCORES_SMALL = { # < 500 examples - 'gemma-3-2b': 30, 'phi-4-mini': 20, 'llama-3.2-3b': 10, 'qwen-2.5-3b': 10, 'mistral-7b': 5 -} -SIZE_SCORES_MEDIUM = { # 500-2000 examples - 'llama-3.2-3b': 30, 'phi-4-mini': 25, 'mistral-7b': 20, 'qwen-2.5-3b': 20, 'gemma-3-2b': 15 -} -SIZE_SCORES_LARGE = { # > 2000 examples - 'llama-3.2-3b': 30, 'mistral-7b': 25, 'phi-4-mini': 10, 'qwen-2.5-3b': 10, 'gemma-3-2b': 5 -} +DATA_DIR = os.path.dirname(os.path.abspath(__file__)).replace('tools', 'data') +CONFIG_PATH = os.path.join(DATA_DIR, 'models.json') -# Factor 3: OUTPUT CHARACTERISTICS (20 points max) -OUTPUT_SCORES_LONG = { # avg_response_length > 200 - 'mistral-7b': 20, 'llama-3.2-3b': 15, 'qwen-2.5-3b': 10, 'phi-4-mini': 5, 'gemma-3-2b': 5 -} -OUTPUT_SCORES_JSON = { # JSON output detected - 'phi-4-mini': 20, 'llama-3.2-3b': 15, 'mistral-7b': 10, 'qwen-2.5-3b': 10, 'gemma-3-2b': 5 -} +def load_config() -> dict: + """Load configuration from JSON file.""" + try: + with open(CONFIG_PATH, 'r') as f: + return json.load(f) + except FileNotFoundError: + # Fallback empty config if file is missing (should not happen in prod) + return {"models": {}, "defaults": {}, "tiebreaker": {}} -# Multi-turn bonus (+10 points max) -MULTI_TURN_BONUS = { - 'llama-3.2-3b': 10, 'mistral-7b': 8, 'qwen-2.5-3b': 6, 'phi-4-mini': 3, 'gemma-3-2b': 2 -} - -# Tie-breaker priorities per task -TIEBREAKER = { - 'classify': ['phi-4-mini', 'gemma-3-2b', 'llama-3.2-3b', 'mistral-7b', 'qwen-2.5-3b'], - 'extraction': ['phi-4-mini', 'llama-3.2-3b', 'mistral-7b', 'qwen-2.5-3b', 'gemma-3-2b'], - 'qa': ['llama-3.2-3b', 'mistral-7b', 'phi-4-mini', 'qwen-2.5-3b', 'gemma-3-2b'], - 'conversation': ['llama-3.2-3b', 'mistral-7b', 'qwen-2.5-3b', 'phi-4-mini', 'gemma-3-2b'], - 'generation': ['mistral-7b', 'llama-3.2-3b', 'qwen-2.5-3b', 'phi-4-mini', 'gemma-3-2b'], - 'default': ['mistral-7b', 'llama-3.2-3b', 'phi-4-mini', 'qwen-2.5-3b', 'gemma-3-2b'] -} +CONFIG = load_config() +# Helper to access metadata easily +MODELS = {k: v['metadata'] for k, v in CONFIG['models'].items()} +DEFAULTS = CONFIG.get('defaults', { + "task_score": 25, + "size_score": 15, + "output_score": 10, + "multi_turn_bonus": 0 +}) +TIEBREAKER = CONFIG.get('tiebreaker', {}) # ════════════════════════════════════════════════════════════════════════════ # MAIN RECOMMENDATION FUNCTION @@ -180,47 +75,54 @@ def recommend_model( if deployment_target == 'edge_device': return build_response( - primary_key='gemma-3-2b', + primary_key='gemma-3-270m', score=100, - all_scores={'gemma-3-2b': 100, 'phi-4-mini': 80}, - reasons=['Smallest model (2B params)', 'Optimized for low-power devices'], + all_scores={'gemma-3-270m': 100, 'gemma-3-2b': 90, 'phi-4-mini': 80}, + reasons=['Smallest model (270M params)', 'Optimized for low-power devices'], num_examples=num_examples, - alternatives=[{'model': 'phi-4-mini', 'score': 80}] + alternatives=[{'model': 'gemma-3-2b', 'score': 90}] ) - # Filter models by deployment target - allowed_models = DEPLOYMENT_FILTERS.get(deployment_target, DEPLOYMENT_FILTERS['not_sure']) - - # Score each model + # Score each model based on JSON config scores = {} - for model_key in allowed_models: - # Factor 1: Task type (50 points) - task_scores = TASK_SCORES.get(user_task, TASK_SCORES['classify']) - 
task_score = task_scores.get(model_key, 25) + + for model_key, model_data in CONFIG['models'].items(): + # 1. Filter by deployment target + allowed_deployments = model_data.get('deployment', []) + if deployment_target != 'not_sure' and deployment_target not in allowed_deployments: + continue + + model_scores = model_data.get('scores', {}) - # Factor 2: Dataset size (30 points) + # Factor 1: Task type + task_scores_map = model_scores.get('task', {}) + task_score = task_scores_map.get(user_task, DEFAULTS['task_score']) + + # Factor 2: Dataset size + size_scores_map = model_scores.get('size', {}) if num_examples < 500: - size_score = SIZE_SCORES_SMALL.get(model_key, 15) + size_score = size_scores_map.get('small', DEFAULTS['size_score']) elif num_examples >= 2000: - size_score = SIZE_SCORES_LARGE.get(model_key, 15) + size_score = size_scores_map.get('large', DEFAULTS['size_score']) else: - size_score = SIZE_SCORES_MEDIUM.get(model_key, 15) + size_score = size_scores_map.get('medium', DEFAULTS['size_score']) - # Factor 3: Output characteristics (20 points) + # Factor 3: Output characteristics + output_scores_map = model_scores.get('output', {}) if avg_response_length > 200: - output_score = OUTPUT_SCORES_LONG.get(model_key, 10) + output_score = output_scores_map.get('long', DEFAULTS['output_score']) elif looks_like_json: - output_score = OUTPUT_SCORES_JSON.get(model_key, 10) + output_score = output_scores_map.get('json', DEFAULTS['output_score']) else: - output_score = 10 + output_score = DEFAULTS['output_score'] - # Bonus: Multi-turn conversations (+10 points) - multi_turn_bonus = MULTI_TURN_BONUS.get(model_key, 0) if is_multi_turn else 0 + # Bonus: Multi-turn + multi_turn_bonus = model_scores.get('multi_turn_bonus', DEFAULTS['multi_turn_bonus']) if is_multi_turn else 0 scores[model_key] = task_score + size_score + output_score + multi_turn_bonus # Pick winner with tie-breaking - priority_order = TIEBREAKER.get(user_task, TIEBREAKER['default']) + priority_order = TIEBREAKER.get(user_task, TIEBREAKER.get('default', [])) def get_priority(model_key): try: @@ -228,8 +130,19 @@ def get_priority(model_key): except ValueError: return 99 + # Sort by score DESC, then by priority ASC sorted_models = sorted(scores.items(), key=lambda x: (-x[1], get_priority(x[0]))) + if not sorted_models: + # Fallback if no models match filters + return build_response( + primary_key='phi-4-mini', # Fallback safe default + score=50, + all_scores={'phi-4-mini': 50}, + reasons=['Fallback recommendation (no models matched specific criteria)'], + num_examples=num_examples + ) + primary_key = sorted_models[0][0] primary_score = sorted_models[0][1] @@ -238,27 +151,26 @@ def get_priority(model_key): # Build alternatives alternatives = [] for model_key, score in sorted_models[1:3]: - model = MODELS[model_key] + model_meta = MODELS[model_key] alt_reasons = [] if score >= primary_score - 10: alt_reasons.append(f"Close match ({score}/100)") - if model_key in ['llama-3.2-3b', 'gemma-3-2b']: - if model_key == 'llama-3.2-3b': - alt_reasons.append("Requires HuggingFace approval (gated)") - elif model_key == 'gemma-3-2b': - alt_reasons.append("Requires terms agreement (gated)") + # Check gated status (simplified logic based on ID pattern) + if "meta-llama" in model_meta['id'] or "gemma" in model_meta['id']: + alt_reasons.append("Requires HuggingFace approval (gated)") - if model['training_time_base'] < MODELS[primary_key]['training_time_base']: - time_diff = MODELS[primary_key]['training_time_base'] - model['training_time_base'] + 
primary_meta = MODELS[primary_key] + if model_meta.get('training_time_base', 0) < primary_meta.get('training_time_base', 0): + time_diff = primary_meta['training_time_base'] - model_meta['training_time_base'] alt_reasons.append(f"~{time_diff} min faster training") - if model['memory_gb'] < MODELS[primary_key]['memory_gb']: - alt_reasons.append(f"Lower VRAM ({model['memory_gb']}GB vs {MODELS[primary_key]['memory_gb']}GB)") + if model_meta.get('memory_gb', 0) < primary_meta.get('memory_gb', 0): + alt_reasons.append(f"Lower VRAM ({model_meta['memory_gb']}GB vs {primary_meta['memory_gb']}GB)") - if model['context_window'] > MODELS[primary_key]['context_window']: - alt_reasons.append(f"Larger context ({model['context_window']//1000}K tokens)") + if model_meta.get('context_window', 0) > primary_meta.get('context_window', 0): + alt_reasons.append(f"Larger context ({model_meta['context_window']//1000}K tokens)") alternatives.append({ 'model': model_key, @@ -279,54 +191,35 @@ def get_priority(model_key): def generate_reasons(model_key: str, task: str, num_examples: int, avg_response: int, is_json: bool, is_multi_turn: bool, deployment: str) -> List[str]: """Generate human-readable reasons for the recommendation.""" reasons = [] - model = MODELS[model_key] + model_data = CONFIG['models'].get(model_key, {}) + model_meta = model_data.get('metadata', {}) + reasons_config = model_data.get('reasons', {}) # Task-based reason - task_reasons = { - ('phi-4-mini', 'classify'): 'Best for classification tasks', - ('gemma-3-2b', 'classify'): 'Fast and efficient for classification', - ('phi-4-mini', 'extraction'): 'Excellent at structured extraction', - ('qwen-2.5-3b', 'extraction'): 'Strong JSON/structured output', - ('llama-3.2-3b', 'qa'): 'Top performer for Q&A tasks', - ('mistral-7b', 'qa'): 'Strong reasoning and knowledge base', - ('phi-4-mini', 'qa'): 'Excellent reasoning capabilities', - ('llama-3.2-3b', 'conversation'): 'Best for conversational AI', - ('mistral-7b', 'conversation'): 'Natural dialogue and context tracking', - ('qwen-2.5-3b', 'conversation'): 'Multilingual conversation support', - ('mistral-7b', 'generation'): 'Best for long-form generation', - ('llama-3.2-3b', 'generation'): 'Creative and coherent text generation', - ('qwen-2.5-3b', 'generation'): 'Strong multilingual generation', - } - reason = task_reasons.get((model_key, task)) - if reason: - reasons.append(reason) + if task in reasons_config: + reasons.append(reasons_config[task]) else: reasons.append(f"Strong performance for {task} tasks") - # Model characteristics - if model_key in ['phi-4-mini', 'mistral-7b']: - reasons.append('Supports function calling') - - if model_key == 'qwen-2.5-3b': - reasons.append('Multilingual (29 languages)') - - if model_key in ['gemma-3-2b', 'llama-3.2-3b']: - reasons.append('Optimized for on-device deployment') + # Generic features + reasons.extend(reasons_config.get('features', [])[:2]) - if model_key in ['phi-4-mini', 'llama-3.2-3b', 'qwen-2.5-3b']: - context_window = model.get('context_window', 0) - if context_window >= 128000: - reasons.append('128K token context window') - elif context_window >= 32000: - reasons.append('32K token context window') + # Context window logic + context_window = model_meta.get('context_window', 0) + if context_window >= 128000: + reasons.append('128K token context window') + elif context_window >= 32000: + reasons.append('32K token context window') - if is_multi_turn and model_key in ['llama-3.2-3b', 'mistral-7b']: + if is_multi_turn and model_data.get('scores', 
{}).get('multi_turn_bonus', 0) > 5: reasons.append('Excellent at multi-turn context tracking') # Size-based reason if num_examples < 500: if model_key == 'gemma-3-2b': reasons.append('Ideal for small datasets') + elif model_key == 'gemma-3-270m': + reasons.append('Perfect for tiny datasets') elif model_key == 'phi-4-mini': reasons.append('Works well with limited data') elif num_examples > 2000: @@ -340,14 +233,22 @@ def generate_reasons(model_key: str, task: str, num_examples: int, avg_response: reasons.append('Excellent JSON/structured output') # Deployment reason - if deployment in ['mobile_app', 'ios_app', 'android_app'] and model_key in ['phi-4-mini', 'gemma-3-2b']: + if deployment in ['mobile_app', 'ios_app', 'android_app'] and model_key in ['phi-4-mini', 'gemma-3-2b', 'gemma-3-270m']: reasons.append('Optimized for mobile deployment') - elif deployment == 'web_browser' and model_key in ['gemma-3-2b', 'phi-4-mini']: + elif deployment == 'web_browser' and model_key in ['gemma-3-2b', 'phi-4-mini', 'gemma-3-270m']: reasons.append('Runs efficiently in browser') - elif deployment == 'edge_device' and model_key == 'gemma-3-2b': + elif deployment == 'edge_device' and model_key in ['gemma-3-2b', 'gemma-3-270m']: reasons.append('Designed for edge devices') - return reasons[:5] + # Deduplicate and limit + unique_reasons = [] + seen = set() + for r in reasons: + if r not in seen: + unique_reasons.append(r) + seen.add(r) + + return unique_reasons[:5] def build_response(