From f889e36a426a7690b63bdacdd01f0bc5bde097a5 Mon Sep 17 00:00:00 2001 From: anish1206 Date: Fri, 30 Jan 2026 11:29:56 +0530 Subject: [PATCH 01/30] feat: implement self-correcting CHIME model and federated learning --- .gitignore | 4 + dreamsApp/app/dashboard/main.py | 68 +++- dreamsApp/app/fl_worker.py | 222 +++++++++++++ .../app/templates/dashboard/profile.html | 102 +++++- dreamsApp/app/utils/logger.py | 56 ++++ dreamsApp/app/utils/sentiment.py | 19 +- dreamsApp/docs/federated-learning.md | 314 ++++++++++++++++++ tests/test_fl.py | 129 +++++++ 8 files changed, 905 insertions(+), 9 deletions(-) create mode 100644 dreamsApp/app/fl_worker.py create mode 100644 dreamsApp/app/utils/logger.py create mode 100644 dreamsApp/docs/federated-learning.md create mode 100644 tests/test_fl.py diff --git a/.gitignore b/.gitignore index 48b37bf..b3c3516 100644 --- a/.gitignore +++ b/.gitignore @@ -184,3 +184,7 @@ cython_debug/ # Virtual environments venv310/ venv/ + +# Federated Learning Models +dreamsApp/app/models/production_chime_model/ +dreamsApp/app/models/temp_training_artifact/ diff --git a/dreamsApp/app/dashboard/main.py b/dreamsApp/app/dashboard/main.py index 0a5d007..e43d7ea 100644 --- a/dreamsApp/app/dashboard/main.py +++ b/dreamsApp/app/dashboard/main.py @@ -10,6 +10,8 @@ from wordcloud import WordCloud from ..utils.llms import generate from flask import jsonify +from bson.objectid import ObjectId +import datetime def generate_wordcloud_b64(keywords, colormap): """Refactor: Helper to generate base64 encoded word cloud image.""" @@ -113,9 +115,13 @@ def profile(target): chime_lookup = {k.lower(): k for k in chime_counts} for post in user_posts: - if post.get('chime_analysis'): - label = post['chime_analysis'].get('label', '').lower() - original_key = chime_lookup.get(label) + # Prioritize user correction if available + label_to_use = post.get('corrected_label') + if not label_to_use and post.get('chime_analysis'): + label_to_use = post['chime_analysis'].get('label', 
@bp.route('/correct_chime', methods=['POST'])
@login_required
def correct_chime():
    """Store a user's manual correction of a post's CHIME label.

    Expects a JSON body with ``post_id`` (MongoDB ObjectId as a string) and
    ``corrected_label``. On success the post gets ``corrected_label``,
    ``is_fl_processed=False`` and a ``correction_timestamp``. When at least
    50 unprocessed corrections are pending, a background training round is
    started.

    Returns:
        JSON ``{'success': True}`` with status 200, a 400 error for missing
        or malformed fields, or a 404 error when no post was modified.
    """
    # Number of pending corrections that triggers a training round.
    # NOTE(review): presumably meant to mirror BATCH_SIZE in fl_worker - confirm.
    fl_trigger_threshold = 50
    fl_thread_name = 'fl-round'

    # silent=True: a missing or unparsable JSON body yields None instead of
    # an unhandled error; a non-object payload (list, string, ...) is
    # treated the same as an empty one.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    post_id = data.get('post_id')
    corrected_label = data.get('corrected_label')

    if not all([post_id, corrected_label]):
        return jsonify({'success': False, 'error': 'Missing fields'}), 400

    # Reject bad input up front: ObjectId() raises on malformed ids (which
    # would surface as a 500), and the dashboard later calls .lower() on the
    # stored label, so it must be a string.
    if not isinstance(corrected_label, str) or not ObjectId.is_valid(post_id):
        return jsonify({'success': False, 'error': 'Invalid fields'}), 400

    # NOTE(review): the filter matches on _id only, so any logged-in user can
    # correct any post - confirm whether an ownership check is required.
    mongo = current_app.mongo['posts']

    # Update the post using $set to add correction data
    result = mongo.update_one(
        {'_id': ObjectId(post_id)},
        {
            '$set': {
                'corrected_label': corrected_label,
                'is_fl_processed': False,
                'correction_timestamp': datetime.datetime.now()
            }
        }
    )

    if result.modified_count == 0:
        return jsonify({'success': False, 'error': 'Post not found or no change'}), 404

    # Check for FL Trigger
    pending_count = mongo.count_documents(
        {'corrected_label': {'$exists': True}, 'is_fl_processed': False}
    )

    if pending_count >= fl_trigger_threshold:
        import threading

        # Best-effort guard: without it every correction submitted while a
        # round is still training would start another round that races on
        # the same temp/production model directories.
        round_in_progress = any(
            t.name == fl_thread_name for t in threading.enumerate()
        )
        if not round_in_progress:
            # Imported lazily so the worker module is only loaded when a
            # round is actually started.
            from dreamsApp.app.fl_worker import run_federated_round
            # Trigger FL training in background thread (user doesn't wait)
            thread = threading.Thread(
                target=run_federated_round, name=fl_thread_name, daemon=True
            )
            thread.start()

    return jsonify({'success': True})
def validate_model(model, tokenizer, training_samples, label2id):
    """Decide whether a freshly trained model may be promoted.

    Two checks must both pass:

    1. Anchor check (safety): at least 3 of ``ANCHOR_EXAMPLES`` must be
       predicted as their expected label.
    2. Improvement check: at least 50% of ``training_samples`` must be
       predicted as their corrected label. This is measured on the same
       samples the model was just trained on.

    Args:
        model: sequence-classification model; called as ``model(**inputs)``
            and read through ``.logits``.
        tokenizer: callable producing the model inputs for a text.
        training_samples: list of ``(text, label_idx)`` pairs.
        label2id: mapping of label name to class index.

    Returns:
        True if the model passes both checks, False otherwise. An empty
        ``training_samples`` list is rejected, since nothing can be verified.
    """
    model.eval()
    logger.info("Running Validation Gate...")

    # Built once; used only to print a readable label for failed anchors.
    id2label = {idx: name for name, idx in label2id.items()}

    def predict(text):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        outputs = model(**inputs)
        return torch.argmax(outputs.logits).item()

    # 1. ANCHOR CHECK (Safety)
    correct_anchors = 0
    with torch.no_grad():
        for example in ANCHOR_EXAMPLES:
            pred_id = predict(example["text"])
            target_str = example["label"]
            target_id = label2id.get(target_str)

            # Counts only if the label exists in the map and matches the prediction
            if target_id is not None and pred_id == target_id:
                correct_anchors += 1
            else:
                pred_str = id2label.get(pred_id, "Unknown")
                logger.debug(f"[Anchor Fail] Text: '{example['text'][:30]}...' Expected: {target_str}, Got: {pred_str}")

    logger.info(f"[Safety Check] Anchor Accuracy: {correct_anchors}/{len(ANCHOR_EXAMPLES)}")
    if correct_anchors < 3:  # Relaxed slightly for small batch variance
        logger.error("FAIL: Model has forgotten basic concepts (Catastrophic Forgetting).")
        return False

    # 2. IMPROVEMENT CHECK (Did it learn?)
    total_new = len(training_samples)
    if total_new == 0:
        # Avoid dividing by zero below; with no samples nothing was learned.
        logger.error("FAIL: No training samples supplied; cannot verify improvement.")
        return False

    with torch.no_grad():
        correct_new = sum(
            1 for text, label_idx in training_samples if predict(text) == label_idx
        )

    logger.info(f"[Improvement Check] Training Set Accuracy: {correct_new}/{total_new}")

    if correct_new / total_new < 0.5:
        logger.error("FAIL: Model failed to learn the new corrections.")
        return False

    return True


def _split_pending_posts(pending_posts, label2id):
    """Sort pending posts into trainable, 'None'-labelled and unusable ones.

    Labels are matched case-insensitively against ``label2id``, the same way
    the dashboard matches them. A post is trainable only if its label is
    known and its caption is a non-blank string.

    Returns:
        ``(training_data, valid_ids, none_ids, invalid_ids)`` where
        ``training_data`` is a list of ``(caption, label_idx)`` pairs and
        the other three are lists of post ``_id`` values.
    """
    label_lookup = {str(name).lower(): idx for name, idx in label2id.items()}
    training_data, valid_ids, none_ids, invalid_ids = [], [], [], []

    for post in pending_posts:
        label = post.get('corrected_label')
        caption = post.get('caption')
        label_idx = label_lookup.get(label.lower()) if isinstance(label, str) else None
        has_text = isinstance(caption, str) and bool(caption.strip())

        if label_idx is not None and has_text:
            training_data.append((caption, label_idx))
            valid_ids.append(post['_id'])
        elif label == 'None':
            none_ids.append(post['_id'])
        else:
            invalid_ids.append(post['_id'])

    return training_data, valid_ids, none_ids, invalid_ids


def _promote_model(src_dir, dest_dir):
    """Move ``src_dir`` into place as ``dest_dir`` without losing the old copy.

    Any existing ``dest_dir`` is first renamed to ``dest_dir + '.bak'``; it
    is deleted only after the new directory is in place, and restored if the
    move fails. ``src_dir`` no longer exists after a successful call.
    """
    os.makedirs(os.path.dirname(dest_dir), exist_ok=True)
    backup_dir = dest_dir + ".bak"

    # A leftover backup can only come from an earlier interrupted promotion.
    if os.path.exists(backup_dir):
        shutil.rmtree(backup_dir)

    had_previous = os.path.exists(dest_dir)
    if had_previous:
        os.rename(dest_dir, backup_dir)
    try:
        os.rename(src_dir, dest_dir)
    except OSError:
        if had_previous:
            os.rename(backup_dir, dest_dir)  # roll back to the old model
        raise
    if had_previous:
        shutil.rmtree(backup_dir, ignore_errors=True)


def run_federated_round():
    """Run one training round on pending user corrections.

    Steps: fetch up to ``BATCH_SIZE`` unprocessed corrected posts (returning
    early if fewer are available), mark unusable ones as skipped, fine-tune
    the classifier head for 3 epochs, save to the temp directory, run
    ``validate_model`` and promote to production only if it passes, then
    mark the trained posts as processed.

    Raises:
        Exception: any failure during the round is logged and re-raised
            after the temp directory has been cleaned up.
    """
    app = create_app()
    with app.app_context():
        posts = app.mongo['posts']
        logger.info("FL WORKER: Waking up...")

        try:
            # 1. Fetch Pending Data
            query = {
                'corrected_label': {'$exists': True},
                'is_fl_processed': False
            }

            # Limit to batch size
            pending_posts = list(posts.find(query).limit(BATCH_SIZE))

            if len(pending_posts) < BATCH_SIZE:
                logger.info(f"Only {len(pending_posts)} corrections available. Waiting for {BATCH_SIZE}.")
                return

            # The label map comes from the base model's configuration.
            try:
                config = AutoConfig.from_pretrained(BASE_MODEL_ID)
                label2id = config.label2id
            except Exception as e:
                # Fallback if config fetch fails
                logger.warning(f"Could not load config from HuggingFace: {e}. Using fallback label map.")
                label2id = {"Connectedness": 0, "Hope": 1, "Identity": 2, "Meaning": 3, "Empowerment": 4}

            training_data, valid_ids, none_ids, invalid_ids = _split_pending_posts(pending_posts, label2id)

            # Mark 'None' as processed but don't train
            if none_ids:
                posts.update_many(
                    {'_id': {'$in': none_ids}},
                    {'$set': {'is_fl_processed': True, 'fl_status': 'skipped'}}
                )
                logger.debug(f"Skipped {len(none_ids)} post(s) labelled 'None'.")

            # Unknown labels or unusable captions must also leave the queue;
            # otherwise the same posts are fetched again on every round.
            # They get a distinct status so they can be found and re-queued.
            if invalid_ids:
                posts.update_many(
                    {'_id': {'$in': invalid_ids}},
                    {'$set': {'is_fl_processed': True, 'fl_status': 'skipped_invalid'}}
                )
                logger.warning(f"Skipped {len(invalid_ids)} post(s) with an unknown label or empty caption.")

            if not training_data:
                logger.info("No valid labels found (mostly 'None'). Marking processed and exiting.")
                return

            logger.info(f"Starting Training Round with {len(training_data)} samples.")

            # 2. Load Model (CONTINUOUS LEARNING)
            if os.path.exists(PRODUCTION_MODEL_DIR):
                logger.info(f"Loading existing Production Model from {PRODUCTION_MODEL_DIR}...")
                load_path = PRODUCTION_MODEL_DIR
            else:
                logger.info("First run: Loading Base Model from Hugging Face...")
                load_path = BASE_MODEL_ID

            tokenizer = AutoTokenizer.from_pretrained(load_path)
            model = AutoModelForSequenceClassification.from_pretrained(load_path, num_labels=len(label2id))

            # Freeze BERT Base, Train Head
            if hasattr(model, 'bert'):
                for param in model.bert.parameters():
                    param.requires_grad = False
            elif hasattr(model, 'base_model'):
                for param in model.base_model.parameters():
                    param.requires_grad = False

            logger.debug("Base layers frozen. Training classifier head only.")

            # 3. Training Loop (the whole batch is one optimisation step per epoch)
            model.train()
            optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=LEARNING_RATE)

            texts = [text for text, _ in training_data]
            labels_tensor = torch.tensor([label_idx for _, label_idx in training_data])
            inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

            EPOCHS = 3
            for epoch in range(EPOCHS):
                optimizer.zero_grad()
                outputs = model(**inputs, labels=labels_tensor)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                logger.info(f"[Epoch {epoch+1}/{EPOCHS}] Loss: {loss.item():.4f}")

            # 4. Save to TEMP
            if os.path.exists(TEMP_MODEL_DIR):
                shutil.rmtree(TEMP_MODEL_DIR)  # Clean start
            model.save_pretrained(TEMP_MODEL_DIR)
            tokenizer.save_pretrained(TEMP_MODEL_DIR)
            logger.debug(f"Model saved to temp directory: {TEMP_MODEL_DIR}")

            # 5. Validation Gate
            passed = validate_model(model, tokenizer, training_data, label2id)

            if passed:
                logger.info("Update Accepted! Promoting to Production...")
                # Swap via rename so a failed promotion cannot leave the
                # application without a production model.
                _promote_model(TEMP_MODEL_DIR, PRODUCTION_MODEL_DIR)
                logger.info(f"SUCCESS: Central Model updated at {PRODUCTION_MODEL_DIR}")
            else:
                logger.warning("Update Rejected. Discarding changes.")

            # Cleanup Temp (already gone after a successful promotion)
            if os.path.exists(TEMP_MODEL_DIR):
                shutil.rmtree(TEMP_MODEL_DIR)

            # 6. Finish - the batch is marked processed whether or not the
            # update was accepted, so a rejected batch is not retried.
            logger.info("Updating database records...")
            posts.update_many(
                {'_id': {'$in': valid_ids}},
                {'$set': {
                    'is_fl_processed': True,
                    'fl_round_date': datetime.datetime.now()
                }}
            )
            logger.info(f"Round Successfully Completed. Processed {len(valid_ids)} items.")

        except Exception as e:
            logger.error(f"CRITICAL FAILURE during FL round: {str(e)}", exc_info=True)
            # Cleanup temp if it exists after a failure; ignore_errors so a
            # cleanup problem cannot mask the original exception.
            shutil.rmtree(TEMP_MODEL_DIR, ignore_errors=True)
            raise  # Re-raise so caller knows it failed


if __name__ == "__main__":
    run_federated_round()

Challenging Themes

+ + {% if latest_post %} +
+

Latest Entry Analysis

+ +
+
+
+
{{ latest_post.timestamp.strftime('%Y-%m-%d %H:%M') }}
+ Sentiment: {{ latest_post.sentiment.label }} +
+

"{{ latest_post.caption }}"

+ +
+
+ AI Classification: + {% set current_label = latest_post.corrected_label if latest_post.corrected_label else latest_post.chime_analysis.label %} + + {{ current_label }} + + {% if latest_post.corrected_label %} + Verified ✓ + {% endif %} +
+ + {% if not latest_post.corrected_label %} +
+ + +
+ + + + {% endif %} +
+
+
+
+ {% endif %} +