diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fce2b0e --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +BLIP_MRI/project/hf_results +BLIP_MRI/project/hf_logs +BLIP_MRI/logs +BLIP_MRI/project/wandb +BLIP_MRI/project/dataset/__pychache__ +BLIP_MRI/project/model/__pychache__ +BLIP_MRI/project/utils/__pychache__ +BLIP_MRI/project/*.json +BLIP_MRI/project/data/*.json +BLIP_MRI/project/check_metadata.py +BLIP_MRI/project/data/abcd +BLIP_MRI/project/data/gard +BLIP_MRI/project/data/ukb +__pycache__/ +*.pyc diff --git a/BLIP_MRI/README.md b/BLIP_MRI/README.md index 8b13789..271d5b6 100644 --- a/BLIP_MRI/README.md +++ b/BLIP_MRI/README.md @@ -1 +1,128 @@ +# LLaVA-NeXT-Interleave for Brain MRI Comparison Tasks + +Multi-image comparison framework using LLaVA-NeXT-Interleave for brain MRI analysis. + +--- + +## Overview + +**Architecture:** LLaVA-NeXT-Interleave (Qwen-0.5b) +**Task:** Reference-augmented comparison +**Format:** Multi-turn conversation with interleaved images + +**Example:** +``` +Turn 1 (Reference): +User: "Here is a brain scan from a male participant. " +Assistant: "Understood. I've analyzed the reference scan." + +Turn 2 (Query): +User: "Compare this scan with the reference. What is the sex?" +Assistant: "Based on the comparison, the subject is male." +``` + +--- + +## 1. Data Preparation + +### Generate Comparison JSON + +#### Sex(or any categorical attributes) Comparison (Categorical) + +```bash +python generate_json_comparison_split_general.py \ + --study_sample ABCD \ + --meta_path /path/to/phenotype.csv \ + --img_dir /path/to/images \ + --target_col sex # (can be other categorical attributes)\ + --task_type categorical \ + --output_dir ./ \ + --num_pairs 3 \ + --seed 1234 +``` + +**Output:** +- `data/ABCD_sex_comparison_tasks_train.json` +- `data/ABCD_sex_comparison_tasks_val.json` +- `data/ABCD_sex_comparison_tasks_test.json` + +#### BMI_sds(or any numerical attributes) Regression (Numerical) + +```bash +python generate_json_comparison_split_general.py \ + --study_sample ABCD \ + --meta_path /path/to/phenotype.csv \ + --img_dir /path/to/images \ + --target_col BMI_sds # (can be other numerical attributes) \ + --task_type numerical \ + --output_dir ./ \ + --num_pairs 3 \ + --seed 1234 +``` + +**Output:** +- `data/ABCD_BMI_sds_comparison_tasks_train.json` +- `data/ABCD_BMI_sds_comparison_tasks_val.json` +- `data/ABCD_BMI_sds_comparison_tasks_test.json` + +**Key Parameters:** +- `--num_pairs`: Number of references per query subject + +--- + +## 2. Data Split Logic + +**Complete Separation:** +- **Inter-split:** Train/Val/Test subjects do NOT overlap +- **Intra-split:** Query and Reference pools do NOT overlap within each split + +**Example (1000 subjects, 70/15/15 split):** +``` +Train: 700 subjects + ├─ Query: 350 subjects + └─ Reference: 350 subjects (different from query!) + +Val: 150 subjects + ├─ Query: 75 subjects + └─ Reference: 75 subjects + +Test: 150 subjects + ├─ Query: 75 subjects + └─ Reference: 75 subjects +``` + +**Why?** Test subjects NEVER appear in training (even as references) for true generalization test + +--- + +## 3. Training + +### Configure + +Edit `config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml`: + +```yaml +dataset: + target_col: "sex" # or "age", "diagnosis", etc. + + train_json: + - "./data/ABCD_sex_comparison_tasks_train.json" + val_json: + - "./data/ABCD_sex_comparison_tasks_val.json" + test_json: + - "./data/ABCD_sex_comparison_tasks_test.json" + +``` + +## 4. Troubleshooting + +**Error: File not found** +- Check image paths in JSON match actual files + +**Error: Image token mismatch** +- Ensure using updated `Trainer_LLaVA_Next_interleave.py` + +**Low metrics** +- Check data split: Are train/val/test balanced? +- Review generation logs for prediction quality diff --git a/BLIP_MRI/environment_llava.yaml b/BLIP_MRI/environment_llava.yaml index 688c737..2496f24 100644 --- a/BLIP_MRI/environment_llava.yaml +++ b/BLIP_MRI/environment_llava.yaml @@ -1,4 +1,4 @@ -name: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava +name: BLIP_MRI_llava channels: - conda-forge dependencies: @@ -125,7 +125,7 @@ dependencies: - sympy==1.14.0 - threadpoolctl==3.6.0 - timm==0.4.12 - - tokenizers==0.13.3 + - tokenizers>=0.20 - torch==2.8.0 - torchvision==0.23.0 - tqdm==4.67.1 @@ -138,4 +138,4 @@ dependencies: - wandb==0.17.0 - xxhash==3.5.0 - yarl==1.20.1 -prefix: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava +prefix: /YOUR_DIRECTORY \ No newline at end of file diff --git a/BLIP_MRI/project/config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml b/BLIP_MRI/project/config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml new file mode 100644 index 0000000..499d177 --- /dev/null +++ b/BLIP_MRI/project/config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml @@ -0,0 +1,40 @@ +wandb: + API_KEY: "YOUR_API_KEY" + +seed: 1234 + +dataset: + # Target column + # Single-task: + target_col: "sex" # Options: 'sex', 'age', 'diagnosis', 'bmi', etc. + + # Multi-task (future support): + # target_col: ["sex", "age"] # List for multiple targets + + train_json: + - "./data/ABCD_sex_comparison_tasks_train.json" + val_json: + - "./data/ABCD_sex_comparison_tasks_val.json" + test_json: + - "./data/ABCD_sex_comparison_tasks_test.json" + + # Image size + img_size: [120, 120, 120] + + +model: + hf_name: "llava-hf/llava-interleave-qwen-0.5b-hf" + patch_size: [10, 10, 10] + + +trainer: + max_epochs: 50 + learning_rate: 0.00005 + warmup_steps: 500 + weight_decay: 0.01 + per_device_batch_size: 1 # Multi-image (reference + query) requires more memory + gradient_accumulation_steps: 4 + gradient_checkpointing: True + logging_steps: 1 + ckpt_dir: "./hf_results/{}/last.ckpt" + resume_training: False diff --git a/BLIP_MRI/project/data/generate_json_general_comparison_split.py b/BLIP_MRI/project/data/generate_json_general_comparison_split.py new file mode 100644 index 0000000..b1331f8 --- /dev/null +++ b/BLIP_MRI/project/data/generate_json_general_comparison_split.py @@ -0,0 +1,907 @@ +""" +Generate JSON files for Comparison Tasks + +Supports: +- Categorical tasks: sex, diagnosis, etc. +- Numerical tasks: age, BMI, glucose, etc. + +This script ensures complete separation: +- Inter-split: Train/Val/Test subjects do NOT overlap +- Intra-split: Within each split, Query and Reference pools do NOT overlap +""" + +import os +import json +import pandas as pd +import glob +import numpy as np +from pathlib import Path +import random + +def load_subjects_and_images(meta_path, img_dir, subject_id_col, target_col, study_sample='ABCD', max_subjects=None): + """Load metadata and available images""" + + # Load metadata + meta = pd.read_csv(meta_path) + meta = meta[[subject_id_col, target_col]].dropna() + + # Load available images + image_files = glob.glob(os.path.join(img_dir, '*.nii.gz')) + image_dict = {} + + suffix_len = -7 # Remove '.nii.gz' + + for img_path in image_files: + filename = os.path.basename(img_path) + subject_id = filename[:suffix_len] + image_dict[subject_id] = img_path + + if subject_id_col == 'subject_id': # GARD + # CSV의 subject_id 타입 확인 + if pd.api.types.is_integer_dtype(meta[subject_id_col]): + # subject_id가 int면 image_dict의 key도 int로 변환 + # _brain 같은 suffix 제거 + image_dict_converted = {} + for k, v in image_dict.items(): + # 숫자만 추출 (_brain 제거) + k_clean = k.replace('_brain', '') + try: + image_dict_converted[int(k_clean)] = v + except ValueError: + continue # 변환 실패하면 스킵 + image_dict = image_dict_converted + + # Filter subjects with both metadata and images + meta = meta[meta[subject_id_col].isin(image_dict.keys())].reset_index(drop=True) + + # Limit number of subjects if specified + if max_subjects is not None and len(meta) > max_subjects: + print(f"Limiting to {max_subjects} subjects (from {len(meta)})") + # Stratified sampling to maintain class balance + if pd.api.types.is_numeric_dtype(meta[target_col]) and meta[target_col].nunique() <= 10: + # Categorical: stratified by class + samples_per_class = max_subjects // meta[target_col].nunique() + meta = meta.groupby(target_col, group_keys=False).apply( + lambda x: x.sample(min(len(x), samples_per_class), random_state=1234) + ).reset_index(drop=True) + else: + # Numerical or large categorical: random sample + meta = meta.sample(n=max_subjects, random_state=1234).reset_index(drop=True) + + # Remap sex values to 0/1 if target_col is 'sex' + if 'sex' in target_col.lower(): + unique_values = meta[target_col].unique() + if set(unique_values).issubset({1, 2}): + print(f"Sex values are 1/2 format. Remapping: 1->0 (male), 2->1 (female)") + meta[target_col] = meta[target_col] - 1 + elif set(unique_values).issubset({'M', 'F', 'Male', 'Female', 'male', 'female'}): + print(f"Sex values are string format. Remapping: M/Male/male->0, F/Female/female->1") + meta[target_col] = meta[target_col].map({ + 'M': 0, 'Male': 0, 'male': 0, + 'F': 1, 'Female': 1, 'female': 1 + }) + elif not set(unique_values).issubset({0, 1}): + print(f"[WARN] Sex values are unexpected format: {unique_values}") + print(f" Expected: 0/1, 1/2, or M/F variants") + + return meta, image_dict + + +def detect_task_type(meta, target_col): + """ + Automatically detect if task is categorical or numerical + + Returns: + 'categorical' or 'numerical' + """ + unique_values = meta[target_col].unique() + + # Check if all values are numeric + if pd.api.types.is_numeric_dtype(meta[target_col]): + # If small number of unique values (< 10), likely categorical + if len(unique_values) <= 10: + return 'categorical' + else: + return 'numerical' + else: + return 'categorical' + + +def parse_categorical_mapping(meta, target_col, mapping_str=None): + """ + Parse categorical mapping from string or auto-detect + + Args: + meta: DataFrame + target_col: Target column name + mapping_str: Optional mapping string like "male=0,female=1" or "1=male,2=female" + + Returns: + value_to_label: dict mapping numeric values to string labels + label_to_value: dict mapping string labels to numeric values + """ + + if mapping_str: + # Parse user-provided mapping + # Supports: "male=0,female=1" or "0=male,1=female" + pairs = mapping_str.split(',') + value_to_label = {} + label_to_value = {} + + for pair in pairs: + parts = pair.strip().split('=') + if len(parts) == 2: + key, val = parts[0].strip(), parts[1].strip() + # Try to determine which is numeric + try: + num_val = int(key) + str_label = val + except ValueError: + num_val = int(val) + str_label = key + + value_to_label[num_val] = str_label + label_to_value[str_label] = num_val + else: + # Auto-detect from data + unique_values = sorted(meta[target_col].unique()) + + # Special handling for 'sex' column + if 'sex' in target_col.lower(): + # Common sex encoding: 0/1 or 1/2 + if set(unique_values) == {0, 1}: + value_to_label = {0: "male", 1: "female"} + label_to_value = {"male": 0, "female": 1} + print(" Detected sex column with 0/1 encoding (0=male, 1=female)") + elif set(unique_values) == {1, 2}: + # Will be remapped to 0/1 later + value_to_label = {1: "male", 2: "female"} + label_to_value = {"male": 1, "female": 2} + print(" Detected sex column with 1/2 encoding (1=male, 2=female)") + else: + # Fallback for unexpected values + value_to_label = {val: str(val) for val in unique_values} + label_to_value = {str(val): val for val in unique_values} + # Check if string values (use as-is) + elif not pd.api.types.is_numeric_dtype(meta[target_col]): + # String categorical values - use original values as labels + value_to_label = {val: str(val) for val in unique_values} + label_to_value = {str(val): val for val in unique_values} + print(f" Using original string values as labels: {list(unique_values)}") + # Check if already 0-indexed integers + elif set(unique_values) == set(range(len(unique_values))): + # Use generic labels + value_to_label = {i: f"class_{i}" for i in unique_values} + label_to_value = {f"class_{i}": i for i in unique_values} + # Check if 1-indexed + elif set(unique_values) == set(range(1, len(unique_values) + 1)): + # Remap to 0-indexed with generic labels + value_to_label = {i: f"class_{i-1}" for i in unique_values} + label_to_value = {f"class_{i-1}": i for i in unique_values} + else: + # Mixed values, use as-is + value_to_label = {val: str(val) for val in unique_values} + label_to_value = {str(val): val for val in unique_values} + + return value_to_label, label_to_value + + +def remap_categorical_values(meta, target_col, value_to_label): + """ + Remap categorical values to 0-indexed if needed + + Returns: + meta: DataFrame with remapped values + value_to_label: Updated mapping + """ + + # Check if values need remapping + unique_values = sorted(meta[target_col].unique()) + + if set(unique_values) == set(value_to_label.keys()): + # Already correct, but might need 0-indexing + if min(unique_values) == 1: + # 1-indexed → 0-indexed + print(f"Remapping {target_col} from 1-indexed to 0-indexed:") + for old_val in unique_values: + new_val = old_val - 1 + label = value_to_label[old_val] + print(f" {old_val} ({label}) → {new_val}") + meta[target_col] = meta[target_col].replace(old_val, new_val) + + # Update mapping + new_value_to_label = {old_val - 1: label for old_val, label in value_to_label.items()} + return meta, new_value_to_label + + return meta, value_to_label + + +def split_subjects_categorical(meta, subject_id_col, target_col, value_to_label, + train_ratio=0.7, val_ratio=0.15, seed=1234): + """ + Split subjects for categorical tasks (stratified by class) with COMPLETE SEPARATION + + Returns a dictionary with 6 subject pools: + - train_query, train_ref + - val_query, val_ref + - test_query, test_ref + + This ensures: + 1. Inter-split: Train/Val/Test subjects don't overlap + 2. Intra-split: Query and Reference pools don't overlap within each split + """ + + random.seed(seed) + np.random.seed(seed) + + train_subjects = [] + val_subjects = [] + test_subjects = [] + + # First split by class (stratified) + for value, label in value_to_label.items(): + class_subjects = meta[meta[target_col] == value][subject_id_col].values.tolist() + random.shuffle(class_subjects) + + n = len(class_subjects) + n_train = int(n * train_ratio) + n_val = int(n * val_ratio) + + train_subjects.extend(class_subjects[:n_train]) + val_subjects.extend(class_subjects[n_train:n_train+n_val]) + test_subjects.extend(class_subjects[n_train+n_val:]) + + print(f" {label} (value={value}): {n} subjects") + print(f" Train: {n_train}, Val: {n_val}, Test: {n - n_train - n_val}") + + # Further split each set into query and reference (50/50) + def split_query_ref(subjects_list): + """Split subjects into query and reference pools""" + random.shuffle(subjects_list) + n = len(subjects_list) + query = subjects_list[:n//2] + ref = subjects_list[n//2:] + return query, ref + + train_query, train_ref = split_query_ref(train_subjects) + val_query, val_ref = split_query_ref(val_subjects) + test_query, test_ref = split_query_ref(test_subjects) + + print(f"\n Query/Reference split:") + print(f" Train: Query={len(train_query)}, Ref={len(train_ref)}") + print(f" Val: Query={len(val_query)}, Ref={len(val_ref)}") + print(f" Test: Query={len(test_query)}, Ref={len(test_ref)}") + + return { + 'train_query': train_query, + 'train_ref': train_ref, + 'val_query': val_query, + 'val_ref': val_ref, + 'test_query': test_query, + 'test_ref': test_ref + } + + +def split_subjects_numerical(meta, subject_id_col, target_col, + train_ratio=0.7, val_ratio=0.15, seed=1234): + """ + Split subjects for numerical tasks (stratified by value bins) with COMPLETE SEPARATION + + Returns a dictionary with 6 subject pools: + - train_query, train_ref + - val_query, val_ref + - test_query, test_ref + + This ensures: + 1. Inter-split: Train/Val/Test subjects don't overlap + 2. Intra-split: Query and Reference pools don't overlap within each split + """ + + random.seed(seed) + np.random.seed(seed) + + # Bin values into quartiles for stratification + meta['_bin'] = pd.qcut(meta[target_col], q=4, labels=False, duplicates='drop') + + train_subjects = [] + val_subjects = [] + test_subjects = [] + + # First split each bin into train/val/test + for bin_idx in sorted(meta['_bin'].unique()): + bin_subjects = meta[meta['_bin'] == bin_idx][subject_id_col].values.tolist() + random.shuffle(bin_subjects) + + n = len(bin_subjects) + n_train = int(n * train_ratio) + n_val = int(n * val_ratio) + + train_subjects.extend(bin_subjects[:n_train]) + val_subjects.extend(bin_subjects[n_train:n_train+n_val]) + test_subjects.extend(bin_subjects[n_train+n_val:]) + + bin_meta = meta[meta['_bin'] == bin_idx] + print(f" Bin {bin_idx} ({bin_meta[target_col].min():.1f}-{bin_meta[target_col].max():.1f}): {n} subjects") + print(f" Train: {n_train}, Val: {n_val}, Test: {n - n_train - n_val}") + + meta = meta.drop('_bin', axis=1) + + # Further split each set into query and reference (50/50) + def split_query_ref(subjects_list): + """Split subjects into query and reference pools""" + random.shuffle(subjects_list) + n = len(subjects_list) + query = subjects_list[:n//2] + ref = subjects_list[n//2:] + return query, ref + + train_query, train_ref = split_query_ref(train_subjects) + val_query, val_ref = split_query_ref(val_subjects) + test_query, test_ref = split_query_ref(test_subjects) + + print(f"\n Query/Reference split:") + print(f" Train: Query={len(train_query)}, Ref={len(train_ref)}") + print(f" Val: Query={len(val_query)}, Ref={len(val_ref)}") + print(f" Test: Query={len(test_query)}, Ref={len(test_ref)}") + + return { + 'train_query': train_query, + 'train_ref': train_ref, + 'val_query': val_query, + 'val_ref': val_ref, + 'test_query': test_query, + 'test_ref': test_ref + } + + +def generate_comparison_tasks_categorical( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + target_col, + value_to_label, + num_pairs_per_subject=5, + same_class_ratio=0.5, + seed=1234 +): + """Generate comparison tasks for categorical target""" + + random.seed(seed) + + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + + # Group reference subjects by class + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + ref_by_class = {} + for value in value_to_label.keys(): + ref_by_class[value] = ref_meta[ref_meta[target_col] == value][subject_id_col].values.tolist() + + print(f"\nReference pool: {len(reference_subjects)} subjects") + for value, label in value_to_label.items(): + print(f" {label}: {len(ref_by_class[value])}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + query_value = int(row[target_col]) + query_label = value_to_label[query_value] + query_img_path = image_dict[query_id] + + # Determine same-class vs different-class pairs + num_same = int(num_pairs_per_subject * same_class_ratio) + num_diff = num_pairs_per_subject - num_same + + # Sample same-class references + same_pool = [s for s in ref_by_class[query_value] if s != query_id] + if len(same_pool) >= num_same: + same_refs = random.sample(same_pool, num_same) + else: + same_refs = same_pool + + # Sample different-class references + diff_pool = [] + for value in value_to_label.keys(): + if value != query_value: + diff_pool.extend(ref_by_class[value]) + + if len(diff_pool) >= num_diff: + diff_refs = random.sample(diff_pool, num_diff) + else: + diff_refs = diff_pool + + # Create tasks for same-class + for ref_id in same_refs: + ref_value = query_value + ref_label = query_label + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='same', + target_name=target_col + ) + all_tasks.append(task) + + # Create tasks for different-class + for ref_id in diff_refs: + ref_value = int(meta[meta[subject_id_col] == ref_id][target_col].values[0]) + ref_label = value_to_label[ref_value] + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='different', + target_name=target_col + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + return all_tasks + +def generate_comparison_tasks_categorical( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + target_col, + value_to_label, + num_pairs_per_subject=5, + same_class_ratio=0.5, + seed=1234 + ): + """Generate comparison tasks for categorical target""" + + random.seed(seed) + + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + + # Group reference subjects by class + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + ref_by_class = {} + for value in value_to_label.keys(): + ref_by_class[value] = ref_meta[ref_meta[target_col] == value][subject_id_col].values.tolist() + + print(f"\nReference pool: {len(reference_subjects)} subjects") + for value, label in value_to_label.items(): + print(f" {label}: {len(ref_by_class[value])}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + # Convert to Python native types + if isinstance(query_id, (np.integer, np.int64)): + query_id = int(query_id) + + query_value = int(row[target_col]) + query_label = value_to_label[query_value] + query_img_path = image_dict[query_id] + + # Determine same-class vs different-class pairs + num_same = int(num_pairs_per_subject * same_class_ratio) + num_diff = num_pairs_per_subject - num_same + + # Sample same-class references + same_pool = [s for s in ref_by_class[query_value] if s != query_id] + if len(same_pool) >= num_same: + same_refs = random.sample(same_pool, num_same) + else: + same_refs = same_pool + + # Sample different-class references + diff_pool = [] + for value in value_to_label.keys(): + if value != query_value: + diff_pool.extend(ref_by_class[value]) + + if len(diff_pool) >= num_diff: + diff_refs = random.sample(diff_pool, num_diff) + else: + diff_refs = diff_pool + + # Create tasks for same-class + for ref_id in same_refs: + # Convert to Python native types + if isinstance(ref_id, (np.integer, np.int64)): + ref_id = int(ref_id) + + ref_value = query_value + ref_label = query_label + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='same', + target_name=target_col + ) + all_tasks.append(task) + + # Create tasks for different-class + for ref_id in diff_refs: + # Convert to Python native types + if isinstance(ref_id, (np.integer, np.int64)): + ref_id = int(ref_id) + + ref_value = int(meta[meta[subject_id_col] == ref_id][target_col].values[0]) + ref_label = value_to_label[ref_value] + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='different', + target_name=target_col + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + return all_tasks + +def generate_comparison_tasks_numerical( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + target_col, + num_pairs_per_subject=6, + seed=1234 + ): + """Generate comparison tasks for numerical target (e.g., age)""" + + random.seed(seed) + + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + + print(f"\nReference pool: {len(reference_subjects)} subjects") + print(f" {target_col} range: {ref_meta[target_col].min():.1f} - {ref_meta[target_col].max():.1f}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + # Convert to Python native types + if isinstance(query_id, (np.integer, np.int64)): + query_id = int(query_id) + + query_value = float(row[target_col]) + query_img_path = image_dict[query_id] + + # Sample references across different value ranges + ref_pool = [s for s in reference_subjects if s != query_id] + + if len(ref_pool) >= num_pairs_per_subject: + selected_refs = random.sample(ref_pool, num_pairs_per_subject) + else: + selected_refs = ref_pool + + for ref_id in selected_refs: + # Convert to Python native types + if isinstance(ref_id, (np.integer, np.int64)): + ref_id = int(ref_id) + + ref_value = float(meta[meta[subject_id_col] == ref_id][target_col].values[0]) + ref_img_path = image_dict[ref_id] + + task = create_task_numerical( + query_id, query_value, query_img_path, + ref_id, ref_value, ref_img_path, + target_name=target_col + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + return all_tasks + + +def create_task_categorical(query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type, target_name): + """Create task for categorical target""" + + task_id = f"{query_id}_{comparison_type}_{target_name}_comparison" + assistant_reasoning = ( + f"Based on comparison with the reference scan, this appears to be {query_label}." + ) + + task = { + "task_id": task_id, + "task_type": "T1", + "subject_ids": [ref_id, query_id], + "modalities": ["sMRI", "sMRI"], + "images": [ + {"path": ref_img_path, "token": "", "modality": "sMRI"}, + {"path": query_img_path, "token": "", "modality": "sMRI"} + ], + "conversations": [ + { + "role": "user", + "content": [ + {"type": "text", "text": f"Here is a T1-weighted brain MRI from a {ref_label} participant. This will serve as your reference scan."}, + {"type": "image", "modality": "sMRI", "image_path": ref_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": f"Understood. I've analyzed the reference {ref_label} brain scan."}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": f"Compare this brain scan with the reference. What is the {target_name}?"}, + {"type": "image", "modality": "sMRI", "image_path": query_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": assistant_reasoning}] + } + ], + "metadata": { + "subject_id": query_id, + "subject_label": query_label, + "subject_label_numeric": int(query_value), + "reference_id": ref_id, + "reference_label": ref_label, + "reference_label_numeric": int(ref_value), + "comparison_type": comparison_type, + "task": f"{target_name}_classification_via_comparison", + "target_name": target_name, + "task_type": "categorical" + } + } + + return task + + +def create_task_numerical(query_id, query_value, query_img_path, + ref_id, ref_value, ref_img_path, target_name): + """Create task for numerical target""" + + task_id = f"{query_id}_{target_name}_comparison" + + assistant_reasoning = ( + f"I estimate this subject's {target_name} to be approximately {query_value:.1f}." + ) + + task = { + "task_id": task_id, + "task_type": "T1", + "subject_ids": [ref_id, query_id], + "modalities": ["sMRI", "sMRI"], + "images": [ + {"path": ref_img_path, "token": "", "modality": "sMRI"}, + {"path": query_img_path, "token": "", "modality": "sMRI"} + ], + "conversations": [ + { + "role": "user", + "content": [ + {"type": "text", "text": f"Here is a T1-weighted brain MRI from a participant with {target_name}: {ref_value:.1f}. This will serve as your reference scan."}, + {"type": "image", "modality": "sMRI", "image_path": ref_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": f"Understood. I've analyzed the reference brain scan ({target_name}: {ref_value:.1f})."}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": f"Compare this brain scan with the reference. What is the {target_name}?"}, + {"type": "image", "modality": "sMRI", "image_path": query_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": assistant_reasoning}] + } + ], + "metadata": { + "subject_id": query_id, + "subject_value": float(query_value), + "reference_id": ref_id, + "reference_value": float(ref_value), + "task": f"{target_name}_regression_via_comparison", + "target_name": target_name, + "task_type": "numerical" + } + } + + return task + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description='Generate comparison task JSON with proper split (GENERAL: categorical or numerical)' + ) + parser.add_argument('--study_sample', type=str, default='ABCD', choices=['ABCD', 'UKB', 'GARD'], + help='Study sample name') + parser.add_argument('--meta_path', type=str, required=True, + help='Path to phenotype CSV file') + parser.add_argument('--img_dir', type=str, required=True, + help='Directory containing MRI images') + parser.add_argument('--target_col', type=str, required=True, + help='Target column name (e.g., sex, age, BMI)') + parser.add_argument('--task_type', type=str, default='auto', choices=['auto', 'categorical', 'numerical'], + help='Task type (auto-detect if not specified)') + parser.add_argument('--categorical_mapping', type=str, default=None, + help='Categorical mapping (e.g., "male=0,female=1" or "1=male,2=female")') + parser.add_argument('--output_dir', type=str, default='./data', + help='Output directory for JSON files') + parser.add_argument('--output_prefix', type=str, default=None, + help='Output file prefix (default: {study_sample}_{target_col}_comparison_tasks)') + parser.add_argument('--subject_id_col', type=str, default=None, + help='Subject ID column name (default: subjectkey for ABCD, eid for UKB)') + parser.add_argument('--num_pairs', type=int, default=5, + help='Number of comparison pairs per query subject') + parser.add_argument('--same_class_ratio', type=float, default=0.5, + help='Ratio of same-class comparisons (categorical only)') + parser.add_argument('--train_ratio', type=float, default=0.7, + help='Train set ratio') + parser.add_argument('--val_ratio', type=float, default=0.15, + help='Validation set ratio') + parser.add_argument('--seed', type=int, default=1234, + help='Random seed') + parser.add_argument('--max_subjects', type=int, default=None, help='Maximum number of subjects to use (for quick testing)') + + args = parser.parse_args() + + # Set defaults + # if args.subject_id_col is None: + # args.subject_id_col = 'subjectkey' if args.study_sample == 'ABCD' else 'eid' + if args.subject_id_col is None: + if args.study_sample == 'ABCD': + args.subject_id_col = 'subjectkey' + elif args.study_sample == 'UKB': + args.subject_id_col = 'eid' + elif args.study_sample == 'GARD': + args.subject_id_col = 'subject_id' # ← 이 부분 추가 + + else: + print("[WARN] Unknown study_sample. Please specify--subject_id_col manually.") + args.subject_id_col = 'subject_id' + + if args.output_prefix is None: + args.output_prefix = f"{args.study_sample}_{args.target_col}_comparison_tasks" + + print("=" * 70) + print(f"GENERATING {args.target_col.upper()} COMPARISON TASKS WITH PROPER SPLIT") + print("=" * 70) + print(f"Study: {args.study_sample}") + print(f"Target: {args.target_col}") + print(f"Metadata: {args.meta_path}") + print(f"Images: {args.img_dir}") + print("=" * 70) + + # Load data + print("\n[Step 1] Loading subjects and images...") + meta, image_dict = load_subjects_and_images( + args.meta_path, args.img_dir, args.subject_id_col, args.target_col, args.study_sample, args.max_subjects + ) + print(f"Loaded {len(meta)} subjects with images") + + # Detect task type + if args.task_type == 'auto': + task_type = detect_task_type(meta, args.target_col) + print(f"\n[Step 2] Auto-detected task type: {task_type}") + else: + task_type = args.task_type + print(f"\n[Step 2] Task type: {task_type}") + + # Split subjects + if task_type == 'categorical': + value_to_label, label_to_value = parse_categorical_mapping(meta, args.target_col, args.categorical_mapping) + print(f"\nCategorical mapping:") + for value, label in sorted(value_to_label.items()): + print(f" {value} → {label}") + + meta, value_to_label = remap_categorical_values(meta, args.target_col, value_to_label) + + print("\n[Step 3] Splitting subjects (stratified by class) with COMPLETE SEPARATION...") + splits = split_subjects_categorical( + meta, args.subject_id_col, args.target_col, value_to_label, + args.train_ratio, args.val_ratio, args.seed + ) + else: + print(f"\n{args.target_col} range: {meta[args.target_col].min():.1f} - {meta[args.target_col].max():.1f}") + print("\n[Step 3] Splitting subjects (stratified by value bins) with COMPLETE SEPARATION...") + splits = split_subjects_numerical( + meta, args.subject_id_col, args.target_col, + args.train_ratio, args.val_ratio, args.seed + ) + + print(f"\nTotal subjects:") + print(f" Train: {len(splits['train_query']) + len(splits['train_ref'])}") + print(f" Val: {len(splits['val_query']) + len(splits['val_ref'])}") + print(f" Test: {len(splits['test_query']) + len(splits['test_ref'])}") + + # Generate tasks + os.makedirs(args.output_dir, exist_ok=True) + + if task_type == 'categorical': + # Train: query from train_query, reference from train_ref (COMPLETE SEPARATION) + print("\nGenerating TRAIN tasks (categorical)...") + train_tasks = generate_comparison_tasks_categorical( + splits['train_query'], splits['train_ref'], meta, image_dict, + args.subject_id_col, args.target_col, value_to_label, + args.num_pairs, args.same_class_ratio, args.seed + ) + # Val: query from val_query, reference from val_ref (COMPLETE SEPARATION) + print("\nGenerating VAL tasks (categorical)...") + val_tasks = generate_comparison_tasks_categorical( + splits['val_query'], splits['val_ref'], meta, image_dict, + args.subject_id_col, args.target_col, value_to_label, + args.num_pairs, args.same_class_ratio, args.seed + 1 + ) + # Test: query from test_query, reference from test_ref (COMPLETE SEPARATION) + print("\nGenerating TEST tasks (categorical)...") + test_tasks = generate_comparison_tasks_categorical( + splits['test_query'], splits['test_ref'], meta, image_dict, + args.subject_id_col, args.target_col, value_to_label, + args.num_pairs, args.same_class_ratio, args.seed + 2 + ) + else: + # Train: query from train_query, reference from train_ref (COMPLETE SEPARATION) + print("\nGenerating TRAIN tasks (numerical)...") + train_tasks = generate_comparison_tasks_numerical( + splits['train_query'], splits['train_ref'], meta, image_dict, + args.subject_id_col, args.target_col, args.num_pairs, args.seed + ) + # Val: query from val_query, reference from val_ref (COMPLETE SEPARATION) + print("\nGenerating VAL tasks (numerical)...") + val_tasks = generate_comparison_tasks_numerical( + splits['val_query'], splits['val_ref'], meta, image_dict, + args.subject_id_col, args.target_col, args.num_pairs, args.seed + 1 + ) + # Test: query from test_query, reference from test_ref (COMPLETE SEPARATION) + print("\nGenerating TEST tasks (numerical)...") + test_tasks = generate_comparison_tasks_numerical( + splits['test_query'], splits['test_ref'], meta, image_dict, + args.subject_id_col, args.target_col, args.num_pairs, args.seed + 2 + ) + + # Save + train_path = os.path.join(args.output_dir, f"{args.output_prefix}_train.json") + val_path = os.path.join(args.output_dir, f"{args.output_prefix}_val.json") + test_path = os.path.join(args.output_dir, f"{args.output_prefix}_test.json") + + with open(train_path, 'w') as f: + json.dump(train_tasks, f, indent=2) + with open(val_path, 'w') as f: + json.dump(val_tasks, f, indent=2) + with open(test_path, 'w') as f: + json.dump(test_tasks, f, indent=2) + + print(f"\nSaved: {train_path}") + print(f"Saved: {val_path}") + print(f"Saved: {test_path}") + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"Task type: {task_type}") + print(f"Target: {args.target_col}") + print(f"Train tasks: {len(train_tasks)}") + print(f"Val tasks: {len(val_tasks)}") + print(f"Test tasks: {len(test_tasks)}") + print(f"Total: {len(train_tasks) + len(val_tasks) + len(test_tasks)}") + print("=" * 70) + + print("\nSample task:") + print(json.dumps(train_tasks[0], indent=2)) + + +if __name__ == '__main__': + main() diff --git a/BLIP_MRI/project/data/generate_json_sex_comparison_split.py b/BLIP_MRI/project/data/generate_json_sex_comparison_split.py new file mode 100644 index 0000000..8f6c90f --- /dev/null +++ b/BLIP_MRI/project/data/generate_json_sex_comparison_split.py @@ -0,0 +1,500 @@ +""" +Generate JSON files for Sex Comparison Task (Multi-turn Conversation) + +Task: Given a reference image with known sex, predict query image's sex through comparison +Format: 2-turn conversation +- Turn 1: User provides reference image + sex label -> Assistant acknowledges +- Turn 2: User provides query image + asks comparison -> Assistant predicts sex +""" + +import os +import json +import pandas as pd +import glob +import numpy as np +from pathlib import Path +import random + + +def load_subjects_and_images(meta_path, img_dir, subject_id_col, sex_col, study_sample='ABCD'): + """Load metadata and available images""" + + # Load metadata + meta = pd.read_csv(meta_path) + meta = meta[[subject_id_col, sex_col]].dropna() + + # Load available images + image_files = glob.glob(os.path.join(img_dir, '*.nii.gz')) + image_dict = {} + + # Determine suffix length based on study sample + suffix_len = -7 # Remove '.nii.gz' + + for img_path in image_files: + filename = os.path.basename(img_path) + subject_id = filename[:suffix_len] + image_dict[subject_id] = img_path + + # Filter subjects with both metadata and images + meta = meta[meta[subject_id_col].isin(image_dict.keys())].reset_index(drop=True) + + # Remap sex values to 0/1 if needed + unique_sex_values = meta[sex_col].unique() + if set(unique_sex_values).issubset({1, 2}): + meta[sex_col] = meta[sex_col] - 1 + + return meta, image_dict + + +def split_subjects(meta, subject_id_col, sex_col, train_ratio=0.7, val_ratio=0.15, seed=1234): + """ + Split subjects into train/val/test with COMPLETE SEPARATION + + Each split is further divided into query and reference pools (50/50) + to ensure no subject appears as both query and reference. + + Returns: + Dictionary with keys: + - train_query, train_ref + - val_query, val_ref + - test_query, test_ref + """ + + random.seed(seed) + np.random.seed(seed) + + # Separate by sex + males = meta[meta[sex_col] == 0][subject_id_col].values.tolist() + females = meta[meta[sex_col] == 1][subject_id_col].values.tolist() + + random.shuffle(males) + random.shuffle(females) + + # Split males into train/val/test + n_males = len(males) + n_train_males = int(n_males * train_ratio) + n_val_males = int(n_males * val_ratio) + + train_males = males[:n_train_males] + val_males = males[n_train_males:n_train_males+n_val_males] + test_males = males[n_train_males+n_val_males:] + + # Split females into train/val/test + n_females = len(females) + n_train_females = int(n_females * train_ratio) + n_val_females = int(n_females * val_ratio) + + train_females = females[:n_train_females] + val_females = females[n_train_females:n_train_females+n_val_females] + test_females = females[n_train_females+n_val_females:] + + # Further split each set into query and reference (50/50) + def split_query_ref(males_list, females_list): + """Split into query and reference pools""" + # Males + n_m = len(males_list) + query_males = males_list[:n_m//2] + ref_males = males_list[n_m//2:] + + # Females + n_f = len(females_list) + query_females = females_list[:n_f//2] + ref_females = females_list[n_f//2:] + + query = query_males + query_females + ref = ref_males + ref_females + + return query, ref + + train_query, train_ref = split_query_ref(train_males, train_females) + val_query, val_ref = split_query_ref(val_males, val_females) + test_query, test_ref = split_query_ref(test_males, test_females) + + # Print summary + print(f"Total subjects: {len(meta)}") + print(f" Males: {len(males)}, Females: {len(females)}") + + print(f"\nTrain: {len(train_males) + len(train_females)} total") + print(f" Query: {len(train_query)} (Males: {len([s for s in train_query if s in males])}, Females: {len([s for s in train_query if s in females])})") + print(f" Reference: {len(train_ref)} (Males: {len([s for s in train_ref if s in males])}, Females: {len([s for s in train_ref if s in females])})") + + print(f"\nVal: {len(val_males) + len(val_females)} total") + print(f" Query: {len(val_query)} (Males: {len([s for s in val_query if s in males])}, Females: {len([s for s in val_query if s in females])})") + print(f" Reference: {len(val_ref)} (Males: {len([s for s in val_ref if s in males])}, Females: {len([s for s in val_ref if s in females])})") + + print(f"\nTest: {len(test_males) + len(test_females)} total") + print(f" Query: {len(test_query)} (Males: {len([s for s in test_query if s in males])}, Females: {len([s for s in test_query if s in females])})") + print(f" Reference: {len(test_ref)} (Males: {len([s for s in test_ref if s in males])}, Females: {len([s for s in test_ref if s in females])})") + + return { + 'train_query': train_query, + 'train_ref': train_ref, + 'val_query': val_query, + 'val_ref': val_ref, + 'test_query': test_query, + 'test_ref': test_ref + } + + +def generate_comparison_tasks( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + sex_col, + num_pairs_per_subject=5, + same_sex_ratio=0.5, + seed=1234 +): + """ + Generate comparison tasks + + Args: + query_subjects: List of subjects to use as queries + reference_subjects: List of subjects to use as references + meta: Full metadata DataFrame + image_dict: Dict mapping subject_id to image path + num_pairs_per_subject: Number of reference pairs per query + same_sex_ratio: Ratio of same-sex vs different-sex comparisons + """ + + random.seed(seed) + + # Filter metadata to query subjects + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + + # Separate reference subjects by sex + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + ref_males = ref_meta[ref_meta[sex_col] == 0][subject_id_col].values.tolist() + ref_females = ref_meta[ref_meta[sex_col] == 1][subject_id_col].values.tolist() + + print(f"\nReference pool: {len(reference_subjects)} subjects") + print(f" Males: {len(ref_males)}, Females: {len(ref_females)}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + query_sex = int(row[sex_col]) + query_sex_label = 'male' if query_sex == 0 else 'female' + query_img_path = image_dict[query_id] + + # Determine how many same-sex vs different-sex pairs + num_same = int(num_pairs_per_subject * same_sex_ratio) + num_diff = num_pairs_per_subject - num_same + + # Sample reference subjects (exclude query itself if in reference pool) + if query_sex == 0: # Query is male + same_pool = [s for s in ref_males if s != query_id] + diff_pool = ref_females + else: # Query is female + same_pool = [s for s in ref_females if s != query_id] + diff_pool = ref_males + + # Sample same-sex references + if len(same_pool) >= num_same: + same_refs = random.sample(same_pool, num_same) + else: + same_refs = same_pool + if len(same_refs) < num_same: + print(f"Warning: Query {query_id} has only {len(same_refs)} same-sex references (requested {num_same})") + + # Sample different-sex references + if len(diff_pool) >= num_diff: + diff_refs = random.sample(diff_pool, num_diff) + else: + diff_refs = diff_pool + if len(diff_refs) < num_diff: + print(f"Warning: Query {query_id} has only {len(diff_refs)} different-sex references (requested {num_diff})") + + # Create tasks for same-sex comparisons + for ref_id in same_refs: + ref_sex = query_sex + ref_sex_label = query_sex_label + ref_img_path = image_dict[ref_id] + + task = create_task( + query_id=query_id, + query_sex=query_sex, + query_sex_label=query_sex_label, + query_img_path=query_img_path, + ref_id=ref_id, + ref_sex=ref_sex, + ref_sex_label=ref_sex_label, + ref_img_path=ref_img_path, + comparison_type='same' + ) + all_tasks.append(task) + + # Create tasks for different-sex comparisons + for ref_id in diff_refs: + ref_sex = 1 - query_sex + ref_sex_label = 'female' if ref_sex == 1 else 'male' + ref_img_path = image_dict[ref_id] + + task = create_task( + query_id=query_id, + query_sex=query_sex, + query_sex_label=query_sex_label, + query_img_path=query_img_path, + ref_id=ref_id, + ref_sex=ref_sex, + ref_sex_label=ref_sex_label, + ref_img_path=ref_img_path, + comparison_type='different' + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + + return all_tasks + + +def create_task(query_id, query_sex, query_sex_label, query_img_path, + ref_id, ref_sex, ref_sex_label, ref_img_path, comparison_type): + """Create a single comparison task in JSON format""" + + task_id = f"{query_id}_{comparison_type}_sex_comparison" + + # Generate assistant responses based on comparison + if comparison_type == 'same': + assistant_reasoning = ( + f"Based on comparison with the reference scan, this appears to be a {query_sex_label} subject. " + f"Structural similarities include comparable gray matter volumes and white matter distribution patterns " + f"typical of {query_sex_label} brain anatomy." + ) + else: + assistant_reasoning = ( + f"Based on comparison with the reference scan, this appears to be a {query_sex_label} subject. " + f"Despite being compared with a {ref_sex_label} reference, I observe distinct structural differences " + f"in gray matter distribution and white matter patterns characteristic of {query_sex_label} brain anatomy." + ) + + task = { + "task_id": task_id, + "task_type": "T1", + "subject_ids": [ref_id, query_id], + "modalities": ["sMRI", "sMRI"], + "images": [ + { + "path": ref_img_path, + "token": "", + "modality": "sMRI" + }, + { + "path": query_img_path, + "token": "", + "modality": "sMRI" + } + ], + "conversations": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"Here is a T1-weighted brain MRI from a {ref_sex_label} participant. This will serve as your reference scan." + }, + { + "type": "image", + "modality": "sMRI", + "image_path": ref_img_path + } + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": f"Understood. I've analyzed the reference {ref_sex_label} brain scan." + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Compare this brain scan with the reference. What is the likely biological sex of this subject?" + }, + { + "type": "image", + "modality": "sMRI", + "image_path": query_img_path + } + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": assistant_reasoning + } + ] + } + ], + "metadata": { + "subject_id": query_id, + "subject_label": query_sex_label, + "subject_label_numeric": query_sex, + "reference_id": ref_id, + "reference_label": ref_sex_label, + "reference_label_numeric": ref_sex, + "comparison_type": comparison_type, + "task": "sex_classification_via_comparison" + } + } + + return task + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description='Generate sex comparison task JSON with proper train/val/test split (NO DATA LEAKAGE)' + ) + parser.add_argument('--study_sample', type=str, default='ABCD', choices=['ABCD', 'UKB'], + help='Study sample name') + parser.add_argument('--meta_path', type=str, required=True, + help='Path to phenotype CSV file') + parser.add_argument('--img_dir', type=str, required=True, + help='Directory containing MRI images') + parser.add_argument('--output_dir', type=str, default='./data', + help='Output directory for JSON files') + parser.add_argument('--output_prefix', type=str, default='ABCD_sex_comparison_tasks', + help='Output file prefix') + parser.add_argument('--subject_id_col', type=str, default=None, + help='Subject ID column name (default: subjectkey for ABCD, eid for UKB)') + parser.add_argument('--sex_col', type=str, default='sex', + help='Sex column name') + parser.add_argument('--num_pairs', type=int, default=5, + help='Number of comparison pairs per query subject') + parser.add_argument('--same_sex_ratio', type=float, default=0.5, + help='Ratio of same-sex comparisons') + parser.add_argument('--train_ratio', type=float, default=0.7, + help='Train set ratio') + parser.add_argument('--val_ratio', type=float, default=0.15, + help='Validation set ratio') + parser.add_argument('--seed', type=int, default=1234, + help='Random seed') + + args = parser.parse_args() + + # Set default subject ID column if not specified + if args.subject_id_col is None: + if args.study_sample == 'ABCD': + args.subject_id_col = 'subjectkey' + elif args.study_sample == 'UKB': + args.subject_id_col = 'eid' + + print("=" * 70) + print("GENERATING SEX COMPARISON TASKS WITH PROPER SPLIT") + print("=" * 70) + print(f"Study: {args.study_sample}") + print(f"Metadata: {args.meta_path}") + print(f"Images: {args.img_dir}") + print(f"Output: {args.output_dir}/{args.output_prefix}_*.json") + print("=" * 70) + + # Load subjects and images + print("\n Loading subjects and images...") + meta, image_dict = load_subjects_and_images( + meta_path=args.meta_path, + img_dir=args.img_dir, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + study_sample=args.study_sample + ) + + # Split subjects into train/val/test with query/reference separation + print("\nSplitting subjects with COMPLETE SEPARATION...") + splits = split_subjects( + meta=meta, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + train_ratio=args.train_ratio, + val_ratio=args.val_ratio, + seed=args.seed + ) + + # Generate tasks for each split + os.makedirs(args.output_dir, exist_ok=True) + + # Train: query from train_query, reference from train_ref (NO OVERLAP!) + print("\nGenerating TRAIN tasks...") + train_tasks = generate_comparison_tasks( + query_subjects=splits['train_query'], + reference_subjects=splits['train_ref'], + meta=meta, + image_dict=image_dict, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + num_pairs_per_subject=args.num_pairs, + same_sex_ratio=args.same_sex_ratio, + seed=args.seed + ) + + train_path = os.path.join(args.output_dir, f"{args.output_prefix}_train.json") + with open(train_path, 'w') as f: + json.dump(train_tasks, f, indent=2) + print(f"✓ Saved: {train_path}") + + # Val: query from val_query, reference from val_ref (NO OVERLAP!) + print("\nGenerating VAL tasks...") + val_tasks = generate_comparison_tasks( + query_subjects=splits['val_query'], + reference_subjects=splits['val_ref'], + meta=meta, + image_dict=image_dict, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + num_pairs_per_subject=args.num_pairs, + same_sex_ratio=args.same_sex_ratio, + seed=args.seed + 1 + ) + + val_path = os.path.join(args.output_dir, f"{args.output_prefix}_val.json") + with open(val_path, 'w') as f: + json.dump(val_tasks, f, indent=2) + print(f"Saved: {val_path}") + + # Test: query from test_query, reference from test_ref (NO OVERLAP!) + print("\nGenerating TEST tasks...") + test_tasks = generate_comparison_tasks( + query_subjects=splits['test_query'], + reference_subjects=splits['test_ref'], + meta=meta, + image_dict=image_dict, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + num_pairs_per_subject=args.num_pairs, + same_sex_ratio=args.same_sex_ratio, + seed=args.seed + 2 + ) + + test_path = os.path.join(args.output_dir, f"{args.output_prefix}_test.json") + with open(test_path, 'w') as f: + json.dump(test_tasks, f, indent=2) + print(f"Saved: {test_path}") + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"Train tasks: {len(train_tasks)}") + print(f"Val tasks: {len(val_tasks)}") + print(f"Test tasks: {len(test_tasks)}") + print(f"Total tasks: {len(train_tasks) + len(val_tasks) + len(test_tasks)}") + print("=" * 70) + + # Print sample task + print("\nSample TRAIN task:") + print(json.dumps(train_tasks[0], indent=2)) + + +if __name__ == '__main__': + main() diff --git a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 7a2f529..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-39.pyc b/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 9b3cbeb..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-311.pyc deleted file mode 100644 index 0df9552..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-39.pyc b/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-39.pyc deleted file mode 100644 index 55e993d..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-311.pyc deleted file mode 100644 index 1e30807..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-39.pyc b/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-39.pyc deleted file mode 100644 index 5dde042..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-311.pyc deleted file mode 100644 index 342a18f..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-39.pyc b/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-39.pyc deleted file mode 100644 index 2246190..0000000 Binary files a/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/project/dataset/dataset_T1_LLaVa.py b/BLIP_MRI/project/dataset/dataset_T1_LLaVa.py index 5f089a7..f0955b5 100644 --- a/BLIP_MRI/project/dataset/dataset_T1_LLaVa.py +++ b/BLIP_MRI/project/dataset/dataset_T1_LLaVa.py @@ -74,18 +74,18 @@ def __transform_image__(self, image_file): image = apply_transform(self.image_transform, image, map_items=False) image = torch.tensor(image) return image + - def __transform_text__(self, label, add_context=False, sex=None, age=None): + def __transform_text__(self, label, add_context=False, sex=None, age=None): if len(self.label_names) == 1 and 'sex' in self.label_names: - if int(label) == 1: - inst = f"{self.quest_template} Estimate sex of subject from this image. {self.ans_template} " - answer = f'male' - elif int(label) == 2: - inst = f"{self.quest_template} Estimate sex of subject from this image. {self.ans_template} " - answer = f'female' + sex_text = 'male' if int(label) == 1 else 'female' + inst = f"{self.quest_template} Estimate sex of subject from this image. {self.ans_template}" + answer = f"The brain shows {sex_text} characteristics." + elif len(self.label_names) == 1 and 'age' in self.label_names: - inst = f"{self.quest_template} Estimate age of subject from this image." - answer = f'{self.ans_template} {label}' + inst = f"{self.quest_template} Estimate age of subject from this image. {self.ans_template}" + answer = f"{label} years" + return inst, answer diff --git a/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py b/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py new file mode 100644 index 0000000..ee33243 --- /dev/null +++ b/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py @@ -0,0 +1,393 @@ +""" +Multi-turn Conversation Dataset for Comparison Tasks + +Supports: +- Sex comparison (reference + query) +- Age comparison (reference + query) +- 2-turn conversation format +""" + +import os +import json +import numpy as np +import torch +from torch.utils.data import Dataset + +from monai.data import NibabelReader +from monai.transforms import LoadImage, Randomizable, apply_transform, AddChannel, Compose, Resize, NormalizeIntensity, RandAxisFlip, ToTensor +from monai.utils import MAX_SEED, get_seed + +from utils.utils import to_3tuple + + +class MultiTurnComparisonDataset(Dataset, Randomizable): + """ + Multi-turn conversation dataset for comparison-based tasks + + Format: + Turn 1: User shows reference image + label → Assistant acknowledges + Turn 2: User shows query image + asks question → Assistant predicts + + Args: + json_path: Path to JSON file with comparison tasks + processor: HuggingFace processor + img_size: Image size [H, W, D] + mode: 'train' or 'eval' + """ + + def __init__(self, + json_path=None, + processor=None, + img_size=None, + mode='train'): + + self.json_path = json_path + self.processor = processor + self.tokenizer = processor.tokenizer if processor is not None else None + self.img_size = img_size + self.mode = mode + + # Load JSON data + with open(json_path, 'r') as f: + self.tasks = json.load(f) + + print(f"Loaded {len(self.tasks)} tasks from {json_path}") + + # Define image transform + self.image_transform = self.define_augmentation(mode=mode) + self.image_loader = LoadImage(reader=None, image_only=True, dtype=np.float32) + + self.set_random_state(seed=get_seed()) + self._seed = 0 + + + def define_augmentation(self, mode='train'): + """Define image augmentation""" + img_size = to_3tuple(self.img_size) + if mode == 'train': + transform = Compose([ + AddChannel(), + Resize(img_size), + RandAxisFlip(prob=0.5), + NormalizeIntensity() + ]) + elif mode == 'eval': + transform = Compose([ + AddChannel(), + Resize(img_size), + NormalizeIntensity() + ]) + return transform + + + def randomize(self, data=None) -> None: + self._seed = self.R.randint(MAX_SEED, dtype='uint32') + + + def __transform_image__(self, image_file): + """Load and transform a single image""" + image = self.image_loader(image_file) + if self.image_transform is not None: + if isinstance(self.image_transform, Randomizable): + self.image_transform.set_random_state(seed=self._seed) + image = apply_transform(self.image_transform, image, map_items=False) + image = torch.tensor(image) + return image + + + def __build_conversation_text__(self, task): + """ + Build multi-turn conversation text from task + + Returns: + full_text: Complete conversation in LLaVA-NeXT-Interleave format + answer_start_pos: Position where assistant's final answer starts (for label masking) + """ + + conversations = task['conversations'] + + # Build conversation following Qwen2 format + full_text = "" + + for i, turn in enumerate(conversations): + role = turn['role'] + content_list = turn['content'] + + if role == 'user': + full_text += "<|im_start|>user\n" + elif role == 'assistant': + full_text += "<|im_start|>assistant\n" + + # Process content (text + image tokens) + for content_item in content_list: + if content_item['type'] == 'text': + full_text += content_item['text'] + elif content_item['type'] == 'image': + full_text += "" + + full_text += "<|im_end|>\n" + + return full_text + + + # def __preprocess_as_hf__(self, images, full_text): + # """ + # Tokenize multi-turn conversation and apply instruction masking + + # Args: + # images: List of [ref_image_tensor, query_image_tensor] + # full_text: Complete conversation text + + # Returns: + # Dictionary with pixel_values, input_ids, attention_mask, labels + # """ + # inputs = {} + # inputs['pixel_values'] = {} + # inputs['input_ids'] = {} + # inputs['attention_mask'] = {} + # inputs['labels'] = {} + + # # ========== 핵심 수정! ========== + # # 두 이미지를 개별적으로 batch 차원 추가한 후 합치기 + # # ref_image: [C, H, W, D] → [1, C, H, W, D] + # # query_image: [C, H, W, D] → [1, C, H, W, D] + # # 합치기: [2, C, H, W, D] → 이제 PatchEmbed가 batch=2로 처리 + + # processed_images = [] + # for img in images: + # # Add batch dimension to each image + # processed_images.append(img.unsqueeze(0)) # [1, C, H, W, D] + + # # Concatenate along batch dimension + # batched_images = torch.cat(processed_images, dim=0) # [2, C, H, W, D] + + # inputs['pixel_values']['T1'] = batched_images + # # ================================== + + # # Tokenize full conversation + # full_encoding = self.tokenizer( + # full_text, + # add_special_tokens=True, + # padding='max_length', + # max_length=512, + # truncation=True, + # return_tensors='pt' + # ) + + # input_ids = full_encoding['input_ids'].squeeze(0) + # attention_mask = full_encoding['attention_mask'].squeeze(0) + + # # Initialize labels + # labels = input_ids.clone() + # labels[attention_mask == 0] = -100 # Mask padding + + # # Apply instruction masking: mask everything except the LAST assistant's response + # # We want to train only on the final answer, not on the intermediate "Understood" response + + # # Find all assistant tokens + # assistant_pattern = "<|im_start|>assistant\n" + # assistant_tokens = self.tokenizer.encode(assistant_pattern, add_special_tokens=False) + # assistant_tensor = torch.tensor(assistant_tokens, device=input_ids.device) + + # assistant_positions = [] + # for i in range(len(input_ids) - len(assistant_tokens) + 1): + # if torch.equal(input_ids[i:i+len(assistant_tokens)], assistant_tensor): + # assistant_positions.append(i + len(assistant_tokens)) + + # if len(assistant_positions) >= 2: + # # Mask everything before the LAST assistant response + # last_assistant_pos = assistant_positions[-1] + # labels[:last_assistant_pos] = -100 + # elif len(assistant_positions) == 1: + # # Only one assistant response (shouldn't happen in 2-turn, but handle it) + # labels[:assistant_positions[0]] = -100 + + # inputs['input_ids']['T1'] = input_ids + # inputs['attention_mask']['T1'] = attention_mask + # inputs['labels']['T1'] = labels + + # return inputs + + def __preprocess_as_hf__(self, images, full_text): + """ + Tokenize multi-turn conversation and apply instruction masking + + Args: + images: List of [ref_image_tensor, query_image_tensor] + full_text: Complete conversation text + + Returns: + Dictionary with pixel_values, input_ids, attention_mask, labels + """ + inputs = {} + inputs['pixel_values'] = {} + inputs['input_ids'] = {} + inputs['attention_mask'] = {} + inputs['labels'] = {} + + # Process multiple images for interleave-style multi-image handling + # Each image: [C, H, W, D] = [1, 120, 120, 120] + # Stack along batch dimension: [num_images, C, H, W, D] + # Examples: + # - 2 images (1 ref + 1 query): [2, 1, 120, 120, 120] + # - 4 images (3 refs + 1 query): [4, 1, 120, 120, 120] + # + # This allows PatchEmbedInterleave to process each image independently + # The batch dimension here represents multiple images, NOT multiple samples + # Each will be independently processed through vision encoder + stacked_images = torch.stack(images) # [num_images, 1, 120, 120, 120] + + inputs['pixel_values']['T1'] = stacked_images + + # Tokenize full conversation + full_encoding = self.tokenizer( + full_text, + add_special_tokens=True, + padding='max_length', + max_length=512, + truncation=True, + return_tensors='pt' + ) + + input_ids = full_encoding['input_ids'].squeeze(0) + attention_mask = full_encoding['attention_mask'].squeeze(0) + + # Initialize labels + labels = input_ids.clone() + labels[attention_mask == 0] = -100 # Mask padding + + # Apply instruction masking: mask everything except the LAST assistant's response + # We want to train only on the final answer, not on the intermediate "Understood" response + + # Find all assistant start tokens + assistant_pattern = "<|im_start|>assistant\n" + assistant_tokens = self.tokenizer.encode(assistant_pattern, add_special_tokens=False) + assistant_tensor = torch.tensor(assistant_tokens, device=input_ids.device) + + assistant_positions = [] + for i in range(len(input_ids) - len(assistant_tokens) + 1): + if torch.equal(input_ids[i:i+len(assistant_tokens)], assistant_tensor): + assistant_positions.append(i + len(assistant_tokens)) + + if len(assistant_positions) >= 2: + # Mask everything before the LAST assistant response (including intermediate assistants) + # This includes: all user turns, all previous assistant responses + last_assistant_pos = assistant_positions[-1] + labels[:last_assistant_pos] = -100 + + # Additionally mask all PREVIOUS assistant responses (between first and last) + # Find <|im_end|> tokens to identify where each assistant response ends + im_end_pattern = "<|im_end|>\n" + im_end_tokens = self.tokenizer.encode(im_end_pattern, add_special_tokens=False) + im_end_tensor = torch.tensor(im_end_tokens, device=input_ids.device) + + # Mask intermediate assistant responses (between positions[0] and positions[-1]) + for assistant_start in assistant_positions[:-1]: # All except last + # Find the next <|im_end|> after this assistant start + for j in range(assistant_start, last_assistant_pos): + if j + len(im_end_tokens) <= len(input_ids): + if torch.equal(input_ids[j:j+len(im_end_tokens)], im_end_tensor): + # Mask from assistant_start to end of <|im_end|> + labels[assistant_start:j+len(im_end_tokens)] = -100 + break + elif len(assistant_positions) == 1: + # Only one assistant response (shouldn't happen in 2-turn, but handle it) + labels[:assistant_positions[0]] = -100 + + inputs['input_ids']['T1'] = input_ids + inputs['attention_mask']['T1'] = attention_mask + inputs['labels']['T1'] = labels + + return inputs + + def __len__(self) -> int: + return len(self.tasks) + + + def __getitem__(self, index: int): + """ + Returns a multi-turn comparison sample + + Returns: + Dictionary with: + - pixel_values: Tensor [num_images, C, H, W, D] (dynamically determined from JSON) + - input_ids, attention_mask, labels: Tokenized multi-turn conversation + - modality: 'Comparison' + """ + + task = self.tasks[index] + + # Load ALL images dynamically (supports N references + 1 query) + # JSON format: images = [ref1, ref2, ..., refN, query] + images = [] + for img_info in task['images']: + img_path = img_info['path'] + img_tensor = self.__transform_image__(img_path) + images.append(img_tensor) + + # Build conversation text + full_text = self.__build_conversation_text__(task) + + # Preprocess for model + inputs = self.__preprocess_as_hf__(images=images, full_text=full_text) + # Don't add 'modality' key - trainer extracts modality from dict keys (T1, rsfMRI, etc.) + + return inputs + + +class ComparisonDataModule: + """ + Data module for comparison tasks (train/val/test splits) + """ + + def __init__(self, + train_json=None, + val_json=None, + test_json=None, + processor=None, + img_size=None): + + self.train_json = train_json + self.val_json = val_json + self.test_json = test_json + self.processor = processor + self.img_size = img_size + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + self.setup() + + + def setup(self): + """Create train/val/test datasets""" + + if self.train_json is not None: + self.train_dataset = MultiTurnComparisonDataset( + json_path=self.train_json, + processor=self.processor, + img_size=self.img_size, + mode='train' + ) + print(f"Train: {len(self.train_dataset)} tasks") + + if self.val_json is not None: + self.val_dataset = MultiTurnComparisonDataset( + json_path=self.val_json, + processor=self.processor, + img_size=self.img_size, + mode='eval' + ) + print(f"Val: {len(self.val_dataset)} tasks") + + if self.test_json is not None: + self.test_dataset = MultiTurnComparisonDataset( + json_path=self.test_json, + processor=self.processor, + img_size=self.img_size, + mode='eval' + ) + print(f"Test: {len(self.test_dataset)} tasks") + + return self.train_dataset, self.val_dataset, self.test_dataset diff --git a/BLIP_MRI/project/main_BLLaVaNextInterleave_comparison_hf_joint_T1.py b/BLIP_MRI/project/main_BLLaVaNextInterleave_comparison_hf_joint_T1.py new file mode 100644 index 0000000..0c46fcd --- /dev/null +++ b/BLIP_MRI/project/main_BLLaVaNextInterleave_comparison_hf_joint_T1.py @@ -0,0 +1,257 @@ +""" +Main training script for Multi-turn Comparison Tasks with LLaVA-NeXT-Interleave + +Supports: +- Sex comparison (reference + query → predict sex) +- Numerical values comparison (reference + query → predict age, bmi, ...) +- 2-turn conversation format +""" + +import datetime +import hashlib +from omegaconf import OmegaConf +from omegaconf import ListConfig + +import torch +import transformers +from transformers import Trainer, TrainingArguments +from utils.Trainer_LLaVaNextInterleave_comparison import CustomTrainer +from utils.Trainer_LLaVaNextInterleave_comparison import compute_metrics_with_tokenizer, preprocess_logits_for_metrics + +from utils.data import CustomDataCollatorWithPadding +from dataset.dataset_T1_LLaVaNextInterleave_comparison import ComparisonDataModule + +import os +import wandb + +import warnings +warnings.filterwarnings('ignore') + +def __main__(): + ### setting huggingface verbose + transformers.logging.set_verbosity_info() + + ### make experiment ID + time_hash = datetime.datetime.now().time() + hash_key = hashlib.sha1(str(time_hash).encode()).hexdigest()[:6] + + config = OmegaConf.load("./config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml") + + ### setting logger + wandb.login(key=config.wandb.API_KEY) + os.environ['WANDB_PROJECT'] = "BLIP_sMRI_LLaVA_Next_Interleave_MultiTurn_Comparison" + os.environ["WANDB_RUN_ID"] = f'{hash_key}' + + ### setting seed + transformers.set_seed(config.seed) + + ### setting processor and tokenizer for LLaVA-NeXT-Interleave + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf", trust_remote_code=True) + tokenizer = processor.tokenizer + + ### Load comparison task datasets from multiple sources + from utils.data import InterleaveDataset + + train_datasets = [] + eval_datasets = [] + test_datasets = [] + + + # Support multiple JSON files (e.g., ABCD, UKB, etc.) + if isinstance(config.dataset.train_json, (list, ListConfig)): + train_json_list = list(config.dataset.train_json) + val_json_list = list(config.dataset.val_json) + test_json_list = list(config.dataset.test_json) + else: + train_json_list = [config.dataset.train_json] + val_json_list = [config.dataset.val_json] + test_json_list = [config.dataset.test_json] + + for train_json, val_json, test_json in zip(train_json_list, val_json_list, test_json_list): + data_module = ComparisonDataModule( + train_json=train_json, + val_json=val_json, + test_json=test_json, + processor=processor, + img_size=config.dataset.img_size + ) + + if data_module.train_dataset is not None: + train_datasets.append(data_module.train_dataset) + if data_module.val_dataset is not None: + eval_datasets.append(data_module.val_dataset) + if data_module.test_dataset is not None: + test_datasets.append(data_module.test_dataset) + + # Concatenate all datasets + if len(train_datasets) > 1: + train_dataset = InterleaveDataset(train_datasets, shuffle=True, seed=config.seed) + elif len(train_datasets) == 1: + train_dataset = train_datasets[0] + else: + train_dataset = None + + if len(eval_datasets) > 1: + eval_dataset = InterleaveDataset(eval_datasets, shuffle=False, seed=config.seed) + elif len(eval_datasets) == 1: + eval_dataset = eval_datasets[0] + else: + eval_dataset = None + + if len(test_datasets) > 1: + test_dataset = InterleaveDataset(test_datasets, shuffle=False, seed=config.seed) + elif len(test_datasets) == 1: + test_dataset = test_datasets[0] + else: + test_dataset = None + + #### setting model - LLaVA-NeXT-Interleave + # from model.Bblip_t5 import PatchEmbed + from model.Bblip_t5_interleave import PatchEmbedInterleave + from transformers import LlavaForConditionalGeneration + + # Load LLaVA-NeXT-Interleave model (Qwen-based) + model = LlavaForConditionalGeneration.from_pretrained( + "llava-hf/llava-interleave-qwen-0.5b-hf", + # torch_dtype=torch.bfloat16, + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # attn_implementation="eager" + ) + + patch_embed = PatchEmbedInterleave( + T1_size=config.dataset.img_size, + T1_patch_size=config.model.patch_size, + in_chans=1, + embed_dim=int(model.vision_tower.vision_model.embeddings.patch_embedding.out_channels) + ) + + # # Replace vision encoder's patch embedding layer for 3D brain MRI + # patch_embed = PatchEmbedInterleave( + # T1_size=config.dataset.img_size, + # T1_patch_size=config.model.patch_size, + # rsfMRI_size=[96, 96, 96, 24], # Placeholder (not used) + # rsfMRI_patch_size=[16, 16, 16, 3], # Placeholder (not used) + # in_chans=1, + # embed_dim=int(model.vision_tower.vision_model.embeddings.patch_embedding.out_channels)) + + setattr(model.vision_tower.vision_model, "embeddings", patch_embed) + + # Freeze vision encoder except embeddings + for name, param in model.vision_tower.vision_model.named_parameters(): + if 'encoder' in name: + param.requires_grad = False + if 'pre_layernorm' in name: + param.requires_grad = False + if 'post_layernorm' in name: + param.requires_grad = False + if 'embeddings' in name: + param.requires_grad = True + + # Freeze multi-modal projector + for name, param in model.named_parameters(): + if 'multi_modal_projector' in name: + param.requires_grad = False + + # Freeze language model + for name, param in model.named_parameters(): + if 'model.layers' in name: # Qwen2 uses model.layers + param.requires_grad = False + if 'lm_head' in name: + param.requires_grad = False + + # set gradient checkpointing + model.gradient_checkpointing_enable() + + training_args = TrainingArguments( + # basic settings + output_dir=f'./hf_results/{os.environ["WANDB_RUN_ID"]}', + do_train=True, + do_eval=True, + remove_unused_columns=False, + # training + num_train_epochs=config.trainer.max_epochs, + learning_rate=config.trainer.learning_rate, + warmup_steps=config.trainer.warmup_steps, + weight_decay=config.trainer.weight_decay, + per_device_train_batch_size=config.trainer.per_device_batch_size, + per_device_eval_batch_size=config.trainer.per_device_batch_size, + gradient_accumulation_steps=config.trainer.gradient_accumulation_steps, + # # arguments for reducing memory + # bf16=True, + # bf16_full_eval=True, + # for evaluation and loggings + report_to = 'wandb', + logging_dir=f'./hf_logs/{os.environ["WANDB_RUN_ID"]}', + logging_steps=config.trainer.logging_steps, + eval_strategy="steps", + eval_steps=1000, + eval_accumulation_steps=1, + save_steps=1000, + disable_tqdm=False, + # checkpoint saving + save_strategy="steps", + save_total_limit=3, + load_best_model_at_end=True + ) + + # Determine task type for metrics + # Support both single target and multi-target (list) + target_col = config.dataset.get('target_col', None) + task_type = config.dataset.get('task_type', 'categorical') + + if target_col: + # target_col can be string or list + if isinstance(target_col, list): + targets = target_col + print(f"[INFO] Multi-task mode with targets: {targets}") + else: + targets = [target_col] + print(f"[INFO] Single-task mode with target: {target_col}") + else: + print(f"[WARN] No target_col or task_type specified") + + + # Use existing compute_metrics - it already handles long reasoning text! + # The key difference is that multi-turn generates longer responses, + # but the extraction logic (regex search for 'male'/'female' or numbers) is the same + trainer = CustomTrainer( + args=training_args, + model=model, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + compute_metrics=compute_metrics_with_tokenizer(tokenizer=tokenizer, targets=targets), + data_collator = CustomDataCollatorWithPadding( + tokenizer=tokenizer, + padding=True, + max_length=512 + ), + model_optimization_type = 'joint', + ) + + # training + trainer.train() + + # test + if test_dataset is not None: + trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix='test') + + +if __name__ == '__main__': + __main__() + + +""" +##TODO + +1. Generate JSON files: + python generate_json_general_comparison_split.py + +2. Update config with correct paths + +3. Train: + python main_MultiTurn_Comparison.py +""" diff --git a/BLIP_MRI/project/model/Bblip_t5_interleave.py b/BLIP_MRI/project/model/Bblip_t5_interleave.py new file mode 100644 index 0000000..9363aae --- /dev/null +++ b/BLIP_MRI/project/model/Bblip_t5_interleave.py @@ -0,0 +1,153 @@ +""" +Multi-Image PatchEmbed for LLaVA-NeXT-Interleave style processing + +This module supports processing multiple 3D brain MRI images independently, +similar to how LLaVA-NeXT-Interleave handles multiple 2D images. + +Key differences from Bblip_t5.py: +- Supports batch dimension containing multiple images (e.g., reference + query) +- Each image is processed independently through the same patch embedding layer +- Returns concatenated features that can be interleaved in the language model +""" + +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ + + +class PatchEmbedInterleave(nn.Module): + """ + Image to Patch Embedding with Multi-Image Support + + Supports processing multiple 3D brain MRI images where the batch dimension + represents individual images that should be processed independently. + + Args: + T1_size: Size of T1 image [D, H, W] + T1_patch_size: Patch size for T1 [pD, pH, pW] + in_chans: Number of input channels (default: 1) + embed_dim: Embedding dimension + dtype: Data type for parameters + """ + + def __init__(self, + T1_size=[120, 120, 120], + T1_patch_size=[10, 10, 10], + in_chans=1, + embed_dim=1152, # SigLIP hidden_size for llava-interleave-qwen-0.5b-hf + dtype=torch.float32): + super().__init__() + self.embed_dim = embed_dim + + # Patchifying layer for T1 images + T1_num_patches = (T1_size[0] // T1_patch_size[0]) * \ + (T1_size[1] // T1_patch_size[1]) * \ + (T1_size[2] // T1_patch_size[2]) + + self.T1_grid_size = ( + T1_size[0] // T1_patch_size[0], + T1_size[1] // T1_patch_size[1], + T1_size[2] // T1_patch_size[2] + ) + self.T1_size = T1_size + self.T1_patch_size = T1_patch_size + self.T1_num_patches = T1_num_patches + + # Convolutional projection layer + self.T1_proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=T1_patch_size, + stride=T1_patch_size, + dtype=dtype + ) + + # Positional embeddings + self.T1_positional_embeddings = nn.Parameter( + torch.zeros(1, T1_num_patches, embed_dim) + ) + trunc_normal_(self.T1_positional_embeddings, std=.02) + + + def forward_embeddings(self, x): + """ + Process 3D brain MRI through patch embedding + + Args: + x: Input tensor of shape [B, C, D, H, W] + B can be batch_size OR batch_size * num_images + Each image along the B dimension is processed independently + + Returns: + Patch embeddings of shape [B, num_patches, embed_dim] + """ + if len(x.shape) == 5: + B, C, D, H, W = x.shape + + # Validate input size + assert D == self.T1_size[0] and H == self.T1_size[1] and W == self.T1_size[2], \ + f"Input image size ({D}*{H}*{W}) doesn't match model ({self.T1_size[0]}*{self.T1_size[1]}*{self.T1_size[2]})." + + # Apply convolutional projection + # Input: [B, C, D, H, W] + # Output: [B, embed_dim, grid_D, grid_H, grid_W] + x = self.T1_proj(x) + + # Flatten spatial dimensions and transpose + # [B, embed_dim, grid_D, grid_H, grid_W] -> [B, embed_dim, num_patches] + x = x.flatten(2) + + # [B, embed_dim, num_patches] -> [B, num_patches, embed_dim] + x = x.transpose(1, 2) + + # Add positional embeddings + # Positional embeddings are shared across all images in the batch + x = x + self.T1_positional_embeddings + + return x # [B, num_patches, embed_dim] + else: + raise ValueError(f"Expected 5D tensor [B, C, D, H, W], got shape {x.shape}") + + + def forward(self, x, interpolate_pos_encoding=False): + """ + Forward pass supporting both single and multi-image inputs + + Args: + x: Input tensor, can be: + - [B, C, D, H, W]: Standard batch of images + - [1, num_images, C, D, H, W]: Batch with multiple images per sample + interpolate_pos_encoding: Not used, for API compatibility + + Returns: + Patch embeddings [B*num_images, num_patches, embed_dim] + + Note: + Unlike the original PatchEmbed, this version does NOT concatenate + embeddings from multiple images along the batch dimension. + Instead, it keeps them separate so they can be interleaved properly + with text tokens in the language model. + """ + if isinstance(x, dict): + # Handle dict input (multi-modality case) + # For multi-turn comparison, we only use 'T1' modality + raise NotImplementedError( + "Multi-modality dict input not supported in PatchEmbedInterleave. " + "Use separate forward passes for each modality." + ) + else: + # Check if input has extra batch dimension from data collator + if len(x.shape) == 6: + # Shape: [batch_size, num_images, C, D, H, W] + # Reshape to: [batch_size * num_images, C, D, H, W] + batch_size, num_images, C, D, H, W = x.shape + x = x.reshape(batch_size * num_images, C, D, H, W) + + # Process all images in the batch + outputs = self.forward_embeddings(x) + return outputs + + + def get_num_patches(self): + """Return the number of patches per image""" + return self.T1_num_patches diff --git a/BLIP_MRI/project/model/__pycache__/Bblip_t5.cpython-311.pyc b/BLIP_MRI/project/model/__pycache__/Bblip_t5.cpython-311.pyc index 2bdc2b7..ffef048 100644 Binary files a/BLIP_MRI/project/model/__pycache__/Bblip_t5.cpython-311.pyc and b/BLIP_MRI/project/model/__pycache__/Bblip_t5.cpython-311.pyc differ diff --git a/BLIP_MRI/project/model/__pycache__/__init__.cpython-311.pyc b/BLIP_MRI/project/model/__pycache__/__init__.cpython-311.pyc index 987014c..ca6bd95 100644 Binary files a/BLIP_MRI/project/model/__pycache__/__init__.cpython-311.pyc and b/BLIP_MRI/project/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/BLIP_MRI/project/utils/Trainer.py b/BLIP_MRI/project/utils/Trainer.py index c525c52..b102404 100644 --- a/BLIP_MRI/project/utils/Trainer.py +++ b/BLIP_MRI/project/utils/Trainer.py @@ -40,8 +40,7 @@ def preprocess_logits_for_metrics(logits, labels): pred_ids = torch.argmax(logits, dim=-1) return pred_ids - -@torch.no_grad() +# @torch.no_grad() def compute_metrics_with_tokenizer(tokenizer): def compute_metrics(eval_preds): predictions, labels = eval_preds @@ -53,10 +52,9 @@ def compute_metrics(eval_preds): pred_genders = [] true_genders = [] + import re for pred in decoded_preds: pred_clean = pred.lower().strip() - - import re if re.search(r'\bfemale\b', pred_clean): pred_genders.append(1) elif re.search(r'\bmale\b', pred_clean): @@ -66,8 +64,6 @@ def compute_metrics(eval_preds): for label in decoded_labels: label_clean = label.lower().strip() - - import re if re.search(r'\bfemale\b', label_clean): true_genders.append(1) elif re.search(r'\bmale\b', label_clean): @@ -75,20 +71,52 @@ def compute_metrics(eval_preds): else: true_genders.append(-1) + # Valid pairs valid_pairs = [(p, t) for p, t in zip(pred_genders, true_genders) if p != -1 and t != -1] if valid_pairs: valid_preds, valid_trues = zip(*valid_pairs) - accuracy = balanced_accuracy_score(valid_trues, valid_preds) - f1 = f1_score(valid_trues, valid_preds, average='macro') + valid_accuracy = balanced_accuracy_score(valid_trues, valid_preds) + valid_f1 = f1_score(valid_trues, valid_preds, average='macro') + else: + valid_accuracy = 0.0 + valid_f1 = 0.0 + + # Overall 메트릭 (invalid를 오답 처리) + overall_preds = [] + overall_trues = [] + + for p, t in zip(pred_genders, true_genders): + if t != -1: # ground truth가 유효한 경우만 + overall_trues.append(t) + if p == -1: + overall_preds.append(1 - t) + # overall_preds.append(-1) + else: + overall_preds.append(p) + + if overall_preds: + overall_accuracy = balanced_accuracy_score(overall_trues, overall_preds) + overall_f1 = f1_score(overall_trues, overall_preds, average='macro') else: - accuracy = 0.0 - f1 = 0.0 + overall_accuracy = 0.0 + overall_f1 = 0.0 + + total_samples = len(pred_genders) + invalid_predictions = pred_genders.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 metrics = { - 'accuracy': accuracy, - 'f1': f1 - } + 'accuracy': valid_accuracy, + 'f1': valid_f1, + 'overall_accuracy': overall_accuracy, + 'overall_f1': overall_f1, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + return metrics return compute_metrics @@ -297,9 +325,9 @@ def compute_loss(self, model, inputs, return_outputs=False): def training_step(self, model, inputs): loss = super().training_step(model, inputs) - # generation result - if self.state.global_step % 50 == 0 and self.state.global_step > 0: - self.log_generated_result(model, inputs) + # # generation result + # if self.state.global_step % 50 == 0 and self.state.global_step > 0: + # self.log_generated_result(model, inputs, mode="training") # Log gradients at logging steps modalities = list(inputs.keys()) @@ -476,35 +504,60 @@ def prediction_step( print(f" - logits shape: {logits.shape if logits is not None else None}") print(f" - labels shape: {labels.shape if labels is not None else None}") + # Log generated result during evaluation (first sample of each eval) + if not prediction_loss_only and not hasattr(self, '_eval_generation_logged'): + self._eval_generation_logged = True + self.log_generated_result(model, inputs, mode="evaluation") + return (loss, logits, labels) - def log_generated_result(self, model, inputs): + def log_generated_result(self, model, inputs, mode="training"): + """ + Log generated result during training or evaluation + + Args: + model: The model to use for generation + inputs: Input dictionary (wrapped or unwrapped) + mode: "training" or "evaluation" + """ actual_model = model.module if hasattr(model, 'module') else model - - actual_model.eval() + + # Only set eval mode for training (already in eval during evaluation) + if mode == "training": + actual_model.eval() + with torch.no_grad(): try: - modality = list(inputs.keys())[0] - sample_input = inputs[modality] - + # Handle input format (different for training vs evaluation) + if 'pixel_values' in inputs and 'input_ids' in inputs: + sample_input = inputs + else: + # Still wrapped in modality key (typical for training) + modality_keys = [k for k in inputs.keys() if k in ['T1', 'rsfMRI']] + if modality_keys: + sample_input = inputs[modality_keys[0]] + else: + sample_input = inputs + input_ids = sample_input['input_ids'][0] - + # Search ASSISTANT: token assistant_tokens = self.tokenizer.encode("ASSISTANT:", add_special_tokens=False) assistant_pos = None - + for i in range(len(input_ids) - len(assistant_tokens)): - if torch.equal(input_ids[i:i+len(assistant_tokens)], + if torch.equal(input_ids[i:i+len(assistant_tokens)], torch.tensor(assistant_tokens, device=input_ids.device)): assistant_pos = i + len(assistant_tokens) break - + if assistant_pos is None: - print("Warning: ASSISTANT: not found in input") + print(f"[WARN] ASSISTANT: not found in {mode} input") return - + prompt_ids = input_ids[:assistant_pos].unsqueeze(0) - + + # Generate generated_ids = actual_model.generate( pixel_values=sample_input['pixel_values'][0:1], input_ids=prompt_ids, @@ -513,37 +566,70 @@ def log_generated_result(self, model, inputs): temperature=0.1, pad_token_id=self.tokenizer.pad_token_id, ) - + generated_only = generated_ids[0][len(prompt_ids[0]):] generated_text = self.tokenizer.decode(generated_only, skip_special_tokens=True) - + + # Build result dictionary result = { + "type": mode, "step": self.state.global_step, "epoch": float(self.state.epoch) if hasattr(self.state, 'epoch') else 0, "generated_text": generated_text, } - + + # Add ground truth for evaluation mode + if mode == "evaluation": + labels = sample_input.get('labels', None) + if labels is not None: + labels_clean = labels[0].clone() + labels_clean[labels_clean == -100] = self.tokenizer.pad_token_id + ground_truth = self.tokenizer.decode(labels_clean, skip_special_tokens=True) + else: + ground_truth = "N/A" + result["ground_truth"] = ground_truth + + # Save to JSON json_file = "generation_logs.json" if os.path.exists(json_file): with open(json_file, 'r') as f: logs = json.load(f) else: logs = [] - + logs.append(result) - + with open(json_file, 'w') as f: json.dump(logs, f, indent=2, ensure_ascii=False) - print(f"Step: {self.state.global_step}") - print(f"Generated: {generated_text}") + # Print output + prefix = "[TRAIN]" if mode == "training" else "[EVAL]" + if mode == "evaluation": + print("\n" + "="*80) + print(f"{prefix} Step: {self.state.global_step}, Epoch: {result['epoch']}") + print(f"{prefix} Generated: {generated_text}") + print(f"{prefix} Ground Truth: {result.get('ground_truth', 'N/A')}") + print("="*80 + "\n") + else: + print(f"{prefix} Step: {self.state.global_step}") + print(f"{prefix} Generated: {generated_text}") except Exception as e: - print(f"[ERROR] Generation failed: {e}") + print(f"[ERROR] {mode.capitalize()} generation failed: {e}") import traceback traceback.print_exc() - - actual_model.train() + + # Restore train mode only if we changed it + if mode == "training": + actual_model.train() + + def evaluation_loop(self, *args, **kwargs): + """Override to reset generation flag at start of each evaluation""" + # Reset flag so we log generation once per eval + if hasattr(self, '_eval_generation_logged'): + delattr(self, '_eval_generation_logged') + + return super().evaluation_loop(*args, **kwargs) \ No newline at end of file diff --git a/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py b/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py new file mode 100644 index 0000000..844e3f6 --- /dev/null +++ b/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py @@ -0,0 +1,893 @@ +import os +import json +import numpy as np + +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +import datasets + +from transformers import Trainer +from transformers.trainer_utils import has_length, seed_worker +from transformers.training_args import ParallelMode +from transformers.utils import ( + is_datasets_available, + is_sagemaker_mp_enabled, + +) +from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, + LengthGroupedSampler, + SequentialDistributedSampler, + nested_detach, + IterableDatasetShard, + +) + +from sklearn.metrics import balanced_accuracy_score, f1_score +from dataclasses import dataclass + + +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + + +def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + logits = logits[0] + + pred_ids = torch.argmax(logits, dim=-1) + return pred_ids + +# @torch.no_grad() +def compute_metrics_with_tokenizer(tokenizer, targets): + """ + Automatically compute metrics based on target types. + Categorical: sex -> accuracy, f1 + Numerical: age, bmi, glucose -> MAE, RMSE + """ + import re + from sklearn.metrics import mean_absolute_error, mean_squared_error + + def compute_metrics(eval_preds): + predictions, labels = eval_preds + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + metrics = {} + total_samples = len(decoded_preds) + + # Check if single-task or multi-task + is_single_task = len(targets) == 1 + + # Determine task type + task_name = targets[0] if is_single_task else None + is_sex_task = task_name and 'sex' in task_name.lower() + + # Check if numerical (simple heuristic: common numerical task names) + numerical_keywords = ['age', 'bmi', 'glucose', 'weight', 'height', 'score'] + is_numerical_task = task_name and any(kw in task_name.lower() for kw in numerical_keywords) + + if is_single_task and is_sex_task: + # Sex classification task + pred_genders = [] + true_genders = [] + + for pred in decoded_preds: + pred_clean = pred.lower().strip() + if re.search(r'\bfemale\b', pred_clean): + pred_genders.append(1) + elif re.search(r'\bmale\b', pred_clean): + pred_genders.append(0) + else: + pred_genders.append(-1) + + for label in decoded_labels: + label_clean = label.lower().strip() + if re.search(r'\bfemale\b', label_clean): + true_genders.append(1) + elif re.search(r'\bmale\b', label_clean): + true_genders.append(0) + else: + true_genders.append(-1) + + # Valid pairs (only for metrics on valid predictions) + valid_pairs = [(p, t) for p, t in zip(pred_genders, true_genders) if p != -1 and t != -1] + + if valid_pairs: + valid_preds, valid_trues = zip(*valid_pairs) + valid_accuracy = balanced_accuracy_score(valid_trues, valid_preds) + valid_f1 = f1_score(valid_trues, valid_preds, average='macro') + else: + valid_accuracy = 0.0 + valid_f1 = 0.0 + + # Overall metrics (treat invalid as wrong answer) + overall_preds = [] + overall_trues = [] + + for p, t in zip(pred_genders, true_genders): + if t != -1: # Only when ground truth is valid + overall_trues.append(t) + if p == -1: + # Treat invalid as wrong (flip the answer) + overall_preds.append(1 - t) + else: + overall_preds.append(p) + + if overall_preds: + overall_accuracy = balanced_accuracy_score(overall_trues, overall_preds) + overall_f1 = f1_score(overall_trues, overall_preds, average='macro') + else: + overall_accuracy = 0.0 + overall_f1 = 0.0 + + invalid_predictions = pred_genders.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + metrics = { + 'accuracy': valid_accuracy, + 'f1': valid_f1, + 'overall_accuracy': overall_accuracy, + 'overall_f1': overall_f1, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + + elif is_single_task and is_numerical_task: + # Single numerical task (e.g., age, bmi, glucose) + task_name = targets[0] + pred_values = [] + true_values = [] + + for pred in decoded_preds: + pred_clean = pred.strip() + # Extract first number + match = re.search(r'(-?\d+\.?\d*)', pred_clean) + if match: + pred_values.append(float(match.group(1))) + else: + pred_values.append(-1) + + for label in decoded_labels: + label_clean = label.strip() + match = re.search(r'(-?\d+\.?\d*)', pred_clean) + if match: + true_values.append(float(match.group(1))) + else: + true_values.append(-1) + + # Valid pairs + valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + if valid_pairs: + valid_preds, valid_trues = zip(*valid_pairs) + mae = mean_absolute_error(valid_trues, valid_preds) + rmse = np.sqrt(mean_squared_error(valid_trues, valid_preds)) + else: + mae = 0.0 + rmse = 0.0 + + invalid_predictions = pred_values.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + metrics = { + f'{task_name}_mae': mae, + f'{task_name}_rmse': rmse, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + + elif is_single_task: + # Other categorical tasks (not sex) + # Extract unique labels from ground truth + print(f"[INFO] Generic categorical task detected: {task_name}") + print(f"[INFO] Extracting labels from ground truth...") + + # First pass: collect all possible labels from ground truth + all_labels = set() + for label in decoded_labels: + label_clean = label.lower().strip() + # Try to extract label after common patterns + # Pattern 1: "appears to be X" + match = re.search(r'appears to be\s+(\w+)', label_clean) + if match: + all_labels.add(match.group(1)) + # Pattern 2: "is X" + elif re.search(r'\bis\s+(\w+)', label_clean): + match = re.search(r'\bis\s+(\w+)', label_clean) + all_labels.add(match.group(1)) + # Pattern 3: just the label itself (e.g., "control", "patient") + else: + words = label_clean.split() + if len(words) == 1: # Only single word + all_labels.add(words[0]) + + # Create label to idx mapping + label_to_idx = {label: idx for idx, label in enumerate(sorted(all_labels))} + print(f"[INFO] Detected labels: {label_to_idx}") + + pred_values = [] + true_values = [] + + for pred in decoded_preds: + pred_clean = pred.lower().strip() + found = False + for label_text in label_to_idx.keys(): + if re.search(rf'\b{label_text}\b', pred_clean): + pred_values.append(label_to_idx[label_text]) + found = True + break + if not found: + pred_values.append(-1) + + for label in decoded_labels: + label_clean = label.lower().strip() + found = False + for label_text in label_to_idx.keys(): + if re.search(rf'\b{label_text}\b', label_clean): + true_values.append(label_to_idx[label_text]) + found = True + break + if not found: + true_values.append(-1) + + # Valid pairs + valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + if valid_pairs: + valid_preds, valid_trues = zip(*valid_pairs) + accuracy = balanced_accuracy_score(valid_trues, valid_preds) + f1 = f1_score(valid_trues, valid_preds, average='macro') + else: + accuracy = 0.0 + f1 = 0.0 + + invalid_predictions = pred_values.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + metrics = { + 'accuracy': accuracy, + 'f1': f1, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + + # else: + # # Multi-task + # for task_name in targets: + # if task_name in categorical_tasks: + # # Categorical task + # pattern = rf'{task_name}:\s*(\w+)' + # pred_values = [] + # true_values = [] + + # for pred in decoded_preds: + # pred_clean = pred.lower().strip() + # match = re.search(pattern, pred_clean) + # if match: + # label_text = match.group(1) + # if label_text in categorical_tasks[task_name]: + # pred_values.append(categorical_tasks[task_name][label_text]) + # else: + # pred_values.append(-1) + # else: + # pred_values.append(-1) + + # for label in decoded_labels: + # label_clean = label.lower().strip() + # match = re.search(pattern, label_clean) + # if match: + # label_text = match.group(1) + # if label_text in categorical_tasks[task_name]: + # true_values.append(categorical_tasks[task_name][label_text]) + # else: + # true_values.append(-1) + # else: + # true_values.append(-1) + + # valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + # if valid_pairs: + # valid_preds, valid_trues = zip(*valid_pairs) + # accuracy = balanced_accuracy_score(valid_trues, valid_preds) + # f1 = f1_score(valid_trues, valid_preds, average='macro') + # else: + # accuracy = 0.0 + # f1 = 0.0 + + # metrics[f'{task_name}_accuracy'] = accuracy + # metrics[f'{task_name}_f1'] = f1 + + # else: + # # Numerical task + # pattern = rf'{task_name}:\s*(\d+\.?\d*)' + # pred_values = [] + # true_values = [] + + # for pred in decoded_preds: + # pred_clean = pred.strip() + # match = re.search(pattern, pred_clean) + # if match: + # pred_values.append(float(match.group(1))) + # else: + # pred_values.append(-1) + + # for label in decoded_labels: + # label_clean = label.strip() + # match = re.search(pattern, label_clean) + # if match: + # true_values.append(float(match.group(1))) + # else: + # true_values.append(-1) + + # valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + # if valid_pairs: + # valid_preds, valid_trues = zip(*valid_pairs) + # mae = mean_absolute_error(valid_trues, valid_preds) + # rmse = np.sqrt(mean_squared_error(valid_trues, valid_preds)) + # else: + # mae = 0.0 + # rmse = 0.0 + + # metrics[f'{task_name}_mae'] = mae + # metrics[f'{task_name}_rmse'] = rmse + + # # Overall response rate + # all_pred_values = [] + # for pred in decoded_preds: + # valid = True + # for task_name in targets: + # if task_name in categorical_tasks: + # pattern = rf'{task_name}:\s*(\w+)' + # else: + # pattern = rf'{task_name}:\s*(\d+\.?\d*)' + # if not re.search(pattern, pred.lower()): + # valid = False + # break + # all_pred_values.append(1 if valid else -1) + + # invalid_predictions = all_pred_values.count(-1) + # response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + # metrics['response_rate'] = response_rate + # metrics['total_samples'] = total_samples + # metrics['invalid_predictions'] = invalid_predictions + + return metrics + + return compute_metrics + + +class CustomTrainer(Trainer): + """ + Modified based on https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L294 + """ + + def __init__(self, model_optimization_type='sequential', *args, **kwargs): + # Set static graph for DDP + super().__init__(*args, **kwargs) + self._static_graph_set = False + self.model_optimization_type= model_optimization_type + + + def _ensure_set_static_graph(self, model): + if not self._static_graph_set and self.is_in_train: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model._set_static_graph() + self._static_graph_set = True + + + def repack_inputs_except_for_pixel_values(self, inputs, modalities): + """ + inputs = + { + 'T1': + { + 'pixel_values': torch.tensor([torch.tensor]), + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]) + }, + + 'rsfMRI': + { + 'pixel_values': torch.tensor([torch.tensor]), + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]) + }, + + } + + outputs = + { + 'pixel_values': { + 'T1': torch.tensor([torch.tensor]), + 'rsfMRI':torch.tensor([torch.tensor]), + } + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]), + + } + """ + assert len(modalities) > 1 + + outputs = {} + outputs['pixel_values'] = {} + outputs['input_ids'] = [] + outputs['attention_mask'] = [] + outputs['labels'] = [] + + for modality in modalities: + modality_data = inputs[modality] + #print(modality_data) + outputs['pixel_values'][modality] = modality_data['pixel_values'] + outputs['input_ids'].append(modality_data['input_ids']) + outputs['attention_mask'].append(modality_data['attention_mask']) + outputs['labels'].append(modality_data['labels']) + + outputs['input_ids'] = torch.cat(outputs['input_ids'], dim=0) + outputs['attention_mask'] = torch.cat(outputs['attention_mask'], dim=0) + outputs['labels'] = torch.cat(outputs['labels'], dim=0) + + return outputs + + + def _compute_modality_loss(self, model, inputs, labels=None): + """Helper function to compute loss for a single modality""" + outputs = model(**inputs) + + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + model_name = unwrapped_model.base_model.model._get_name() if _is_peft_model(unwrapped_model) else unwrapped_model._get_name() + loss = self.label_smoother(outputs, labels, shift_labels=model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError(f"Model did not return loss. Got keys: {','.join(outputs.keys())}") + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return loss, outputs + + + def _compute_dummy_gradient(self, model, active_modality): + """Compute dummy gradient for inactive modality parameters.""" + skip_modality = 'rsfMRI' if active_modality == 'T1' else 'T1' + + # Get embeddings module + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + base_model = model.module + else: + base_model = model + + embeddings = (base_model.vision_tower.vision_model.embeddings + if hasattr(base_model, 'vision_tower') + else base_model.vision_model.embeddings) # vision_tower is for LLaVA + + # Compute dummy loss + dummy_loss = 0. + for name, param in embeddings.named_parameters(): + if skip_modality in name: + dummy_loss += param.sum() * 0. + + return dummy_loss + + + def _compute_loss_with_labels(self, model, inputs): + """Compute loss handling both label_smoother and direct cases.""" + # Extract labels if using label smoother + if self.label_smoother and "labels" in inputs: + labels = inputs.pop("labels") + return self._compute_modality_loss(model, inputs, labels) + + outputs = model(**inputs) + + # Extract loss from various output formats + if hasattr(outputs, 'loss'): + loss = outputs.loss + elif isinstance(outputs, dict) and 'loss' in outputs: + loss = outputs['loss'] + elif isinstance(outputs, (tuple, list)) and len(outputs) > 0: + loss = outputs[0] + else: + raise ValueError(f"Model did not return a loss. Output type: {type(outputs)}") + + return loss, outputs + + + def compute_loss(self, model, inputs, return_outputs=False): + #TODO + #현재 방식의 코드에서는 태생적으로 순차적으로 두개의 모달리티로부터 각각 loss를 얻어서 합한 loss로 최적화할 수가 없다. + #왜냐하면 한개의 모달리티로부터 Loss를 얻기 위해서는 patch layer를 제외한 나머지 layer들을 전부 거쳐야하는데, 이렇게 하고 나면 거쳐간 layer들을 업데이트하지 않은 상태에서 두번째 모달리티의 데이터가 이런 layer들을 거치게 되면서 backward()에서 에러가 발생한다. + #그런데 흥미로운 점은 x-instruct-BLIP 페이퍼에서는 다양한 모달리티로부터 얻은 Loss들을 joint optimization하지 않아도 multi-modal network를 학습할 수 있음을 보였다. + #다만, OneLLM은 애초에 라우팅하는 것을 특장점으로 삼았기 때문에 joint optimization을 한다 + # joint optimization을 위해서는 BLIP2의 원래 코드를 짜고, 그 코드 위에다가 weight를 얹는 방식으로 진행해야할 것 같다. + + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + + inputs = + { + 'T1': + { + 'pixel_values': torch.tensor([torch.tensor]), + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]) + }, + + 'rsfMRI': + { + 'pixel_values': torch.tensor([torch.tensor]), + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]) + }, + + } + + """ + self._ensure_set_static_graph(model) + total_loss = 0. + outputs = None + modalities = list(inputs.keys()) + + if len(modalities) == 1: + # Single modality: add dummy gradient for stability + modality = modalities[0] + inputs_single = inputs[modality].copy() + + # Dummy loss for unused modality parameters + dummy_loss = self._compute_dummy_gradient(model, modality) + + # Compute actual loss + loss, outputs = self._compute_loss_with_labels(model, inputs_single) + total_loss = dummy_loss + loss + + else: # len(modalities) >= 2 + # Multiple modalities: repack and compute + inputs_repacked = self.repack_inputs_except_for_pixel_values(inputs, modalities) + loss, outputs = self._compute_loss_with_labels(model, inputs_repacked) + total_loss = loss + + return (total_loss, outputs) if return_outputs else total_loss + + + def training_step(self, model, inputs): + loss = super().training_step(model, inputs) + + # generation result + if self.state.global_step % 50 == 0 and self.state.global_step > 0: + self.log_generated_result(model, inputs, mode="training") + + # Log gradients at logging steps + modalities = list(inputs.keys()) + if len(modalities) == 1: + if self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0: + grad_norms = {} + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + if modalities[0] in name: + if 'bias' in name: + continue + else: + grad_norms[f"grad/{name}"] = param.grad.norm().item() + + else: + if self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0: + grad_norms = {} + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + if 'bias' in name: + continue + else: + grad_norms[f"grad/{name}"] = param.grad.norm().item() + + # Log to loggers through trainer's log() method + self.log(grad_norms) + + + """ + # Check gradients after backward + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + print(f"{name} grad norm: {param.grad.norm().item()}") + """ + + return loss + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + + modalities = list(inputs.keys()) + if len(modalities) == 1 and modalities[0] in ['T1', 'rsfMRI']: + inputs = inputs[modalities[0]] + elif len(modalities) > 1: + inputs = self.repack_inputs_except_for_pixel_values(inputs, modalities) + + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + if len(modalities) == 1 and modalities[0] in ['T1', 'rsfMRI']: # do we need this logic + wrapped_inputs = {modalities[0]: inputs} + loss, outputs = self.compute_loss(model, wrapped_inputs, return_outputs=True) + else: + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + + if loss is not None: + if isinstance(loss, torch.Tensor): + loss = loss.mean().detach() + else: + loss = torch.tensor(loss) + + if isinstance(outputs, dict): + # LLaVA + logits = outputs.get('logits', None) + if logits is None: + # fallback + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + if len(logits) == 1: + logits = logits[0] + elif hasattr(outputs, 'logits'): + logits = outputs.logits + else: + logits = outputs[1:] if len(outputs) > 1 else None + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = outputs.get('logits', None) + if logits is None: + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + elif hasattr(outputs, 'logits'): + logits = outputs.logits + else: + logits = outputs + + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0 and hasattr(outputs, '__getitem__'): + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + if logits is not None: + logits = nested_detach(logits) + if isinstance(logits, (tuple, list)) and len(logits) == 1: + logits = logits[0] + + # Log generated result during evaluation (first sample of each eval) + if not prediction_loss_only and not hasattr(self, '_eval_generation_logged'): + self._eval_generation_logged = True + self.log_generated_result(model, inputs, mode="evaluation") + + return (loss, logits, labels) + + + def log_generated_result(self, model, inputs, mode="training"): + """ + Log generated result during training or evaluation + + Args: + model: The model to use for generation + inputs: Input dictionary (wrapped or unwrapped) + mode: "training" or "evaluation" + """ + actual_model = model.module if hasattr(model, 'module') else model + + # Only set eval mode for training (already in eval during evaluation) + if mode == "training": + actual_model.eval() + + with torch.no_grad(): + try: + # Handle input format (different for training vs evaluation) + if 'pixel_values' in inputs and 'input_ids' in inputs: + sample_input = inputs + else: + # Still wrapped in modality key (typical for training) + modality_keys = [k for k in inputs.keys() if k in ['T1', 'rsfMRI']] + if modality_keys: + sample_input = inputs[modality_keys[0]] + else: + sample_input = inputs + + # Get first sample from batch + input_ids = sample_input['input_ids'][0] + + # Handle pixel_values (supports both single-image and multi-image) + pixel_values = sample_input['pixel_values'] + if len(pixel_values.shape) == 6: + # Multi-image: [batch, num_images, C, D, H, W] -> take first batch + pixel_values_sample = pixel_values[0:1] # [1, num_images, C, D, H, W] + elif len(pixel_values.shape) == 5: + # Single-image: [batch, C, D, H, W] -> take first batch + pixel_values_sample = pixel_values[0:1] # [1, C, D, H, W] + else: + print(f"[WARN] Unexpected pixel_values shape: {pixel_values.shape}") + return + + # Search for LAST assistant token (multi-turn: we want to generate the final answer) + # Conversation structure: + # Turn 1: user → assistant (acknowledgment) + # Turn 2: user → assistant (answer) ← We generate this! + assistant_variants = ["<|im_start|>assistant\n", "<|im_start|>assistant"] + assistant_positions = [] + + for variant in assistant_variants: + assistant_tokens = self.tokenizer.encode(variant, add_special_tokens=False) + for i in range(len(input_ids) - len(assistant_tokens)): + if torch.equal(input_ids[i:i+len(assistant_tokens)], + torch.tensor(assistant_tokens, device=input_ids.device)): + assistant_positions.append(i + len(assistant_tokens)) + + if len(assistant_positions) == 0: + print(f"[WARN] Assistant token not found in {mode} input") + return + + # Use LAST assistant position (for multi-turn, this is the final answer) + last_assistant_pos = assistant_positions[-1] + prompt_ids = input_ids[:last_assistant_pos].unsqueeze(0) + + # Generate + generated_ids = actual_model.generate( + pixel_values=pixel_values_sample, # Use prepared pixel_values + input_ids=prompt_ids, + max_new_tokens=150, + do_sample=False, + temperature=0.1, + pad_token_id=self.tokenizer.pad_token_id, + ) + + generated_only = generated_ids[0][len(prompt_ids[0]):] + generated_text = self.tokenizer.decode(generated_only, skip_special_tokens=True) + + # Build result dictionary + result = { + "type": mode, + "step": self.state.global_step, + "epoch": float(self.state.epoch) if hasattr(self.state, 'epoch') else 0, + "generated_text": generated_text, + } + + # Add ground truth for evaluation mode + if mode == "evaluation": + labels = sample_input.get('labels', None) + if labels is not None: + labels_clean = labels[0].clone() + labels_clean[labels_clean == -100] = self.tokenizer.pad_token_id + ground_truth = self.tokenizer.decode(labels_clean, skip_special_tokens=True) + else: + ground_truth = "N/A" + result["ground_truth"] = ground_truth + + # Save to JSON + json_file = "generation_logs.json" + if os.path.exists(json_file): + with open(json_file, 'r') as f: + logs = json.load(f) + else: + logs = [] + + logs.append(result) + + with open(json_file, 'w') as f: + json.dump(logs, f, indent=2, ensure_ascii=False) + + # Print output + prefix = "[TRAIN]" if mode == "training" else "[EVAL]" + if mode == "evaluation": + print("\n" + "="*80) + print(f"{prefix} Step: {self.state.global_step}, Epoch: {result['epoch']}") + print(f"{prefix} Generated: {generated_text}") + print(f"{prefix} Ground Truth: {result.get('ground_truth', 'N/A')}") + print("="*80 + "\n") + else: + print(f"{prefix} Step: {self.state.global_step}") + print(f"{prefix} Generated: {generated_text}") + + except Exception as e: + print(f"[ERROR] {mode.capitalize()} generation failed: {e}") + import traceback + traceback.print_exc() + + # Restore train mode only if we changed it + if mode == "training": + actual_model.train() + + def evaluation_loop(self, *args, **kwargs): + """Override to reset generation flag at start of each evaluation""" + # Reset flag so we log generation once per eval + if hasattr(self, '_eval_generation_logged'): + delattr(self, '_eval_generation_logged') + + return super().evaluation_loop(*args, **kwargs) diff --git a/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-311.pyc b/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-311.pyc deleted file mode 100644 index b20beb1..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-39.pyc b/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-39.pyc deleted file mode 100644 index 1c2bdd9..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/project/utils/__pycache__/__init__.cpython-311.pyc b/BLIP_MRI/project/utils/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index eb449ff..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/utils/__pycache__/__init__.cpython-39.pyc b/BLIP_MRI/project/utils/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 8011179..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/project/utils/__pycache__/data.cpython-311.pyc b/BLIP_MRI/project/utils/__pycache__/data.cpython-311.pyc deleted file mode 100644 index 0e87629..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/data.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/utils/__pycache__/data.cpython-39.pyc b/BLIP_MRI/project/utils/__pycache__/data.cpython-39.pyc deleted file mode 100644 index 952195c..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/data.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/project/utils/__pycache__/utils.cpython-311.pyc b/BLIP_MRI/project/utils/__pycache__/utils.cpython-311.pyc deleted file mode 100644 index c4337d0..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/utils.cpython-311.pyc and /dev/null differ diff --git a/BLIP_MRI/project/utils/__pycache__/utils.cpython-39.pyc b/BLIP_MRI/project/utils/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 1f2a546..0000000 Binary files a/BLIP_MRI/project/utils/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_DDP_interactive.sh old mode 100644 new mode 100755 diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_T1_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_T1_DDP_interactive.sh old mode 100644 new mode 100755 diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVaNextInterleave_T1_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVaNextInterleave_T1_DDP_interactive.sh new file mode 100755 index 0000000..0c57451 --- /dev/null +++ b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVaNextInterleave_T1_DDP_interactive.sh @@ -0,0 +1,21 @@ + +set +x + +cd YOUR_PROJECT_ROOT #TODO: Change to your own scratch space + + +module load python +#module load pytorch/1.13.1 +module load cpe/23.03 + +conda activate BLIP_MRI_llava #TODO: Change to your own conda env + +export LIBRARY_PATH=$LD_LIBRARY_PATH +export TORCH_EXTENSIONS_DIR= #TODO: Change to your own scratch space +export HF_HOME= #TODO: Change to your own scratch space +export TORCH_HOME= #TODO: Change to your own scratch space + + +# LLaVA-NeXT-Interleave Multi-turn Comparison Task (Sex or Age) +torchrun --nnodes 1 --nproc_per_node 1 main_BLLaVaNextInterleave_comparison_hf_joint_T1.py + diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh old mode 100644 new mode 100755 index f6c6ffd..f4f1765 --- a/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh +++ b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh @@ -1,14 +1,14 @@ set +x -cd /pscratch/sd/h/heehaw/BLIP_MRI/project #TODO: Change to your own scratch space +cd /YOUR_PROJECT_DIRECTORY #TODO: Change to your own scratch space module load python #module load pytorch/1.13.1 module load cpe/23.03 -conda activate /pscratch/sd/h/heehaw/anaconda/BLIP_MRI #TODO: Change to your own conda env +conda activate BLIP_MRI_llava #TODO: Change to your own conda env # conda activate py39 # pip install timm #export MASTER_ADDR=`/bin/hostname -s` @@ -16,9 +16,9 @@ conda activate /pscratch/sd/h/heehaw/anaconda/BLIP_MRI #TODO: Change to your o #export MASTER_PORT=$(shuf -i 29500-65535 -n 1) export LIBRARY_PATH=$LD_LIBRARY_PATH -export TORCH_EXTENSIONS_DIR=/pscratch/sd/h/heehaw #TODO: Change to your own scratch space -export HF_HOME=/pscratch/sd/h/heehaw/huggingface #TODO: Change to your own scratch space -export TORCH_HOME=/pscratch/sd/h/heehaw/ #TODO: Change to your own scratch space +export TORCH_EXTENSIONS_DIR=/pscratch/sd/ #TODO: Change to your own scratch space +export HF_HOME=/pscratch/sd/ #TODO: Change to your own scratch space +export TORCH_HOME=/pscratch/sd/ #TODO: Change to your own scratch space # #recent version (24.3.30)