10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
BLIP_MRI/project/hf_results
BLIP_MRI/project/hf_logs
BLIP_MRI/logs
BLIP_MRI/project/wandb
BLIP_MRI/project/dataset/__pycache__
BLIP_MRI/project/model/__pycache__
BLIP_MRI/project/utils/__pycache__
BLIP_MRI/project/*.json
__pycache__/
*.pyc
6 changes: 3 additions & 3 deletions BLIP_MRI/environment_llava.yaml
@@ -1,4 +1,4 @@
name: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava
name: BLIP_MRI_llava
channels:
- conda-forge
dependencies:
@@ -125,7 +125,7 @@ dependencies:
- sympy==1.14.0
- threadpoolctl==3.6.0
- timm==0.4.12
- tokenizers==0.13.3
- tokenizers>=0.20
- torch==2.8.0
- torchvision==0.23.0
- tqdm==4.67.1
@@ -138,4 +138,4 @@ dependencies:
- wandb==0.17.0
- xxhash==3.5.0
- yarl==1.20.1
prefix: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava
prefix: /YOUR_DIRECTORY
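The relaxed pin above accepts any tokenizers release from 0.20 onward. A minimal sanity check, assuming the environment was created from this YAML and that the `packaging` distribution is available (it ships alongside pip in most environments):

```python
# Hypothetical check, not part of this PR: confirm the resolved tokenizers
# version satisfies the relaxed ">=0.20" pin from the YAML above.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("tokenizers"))
assert installed >= Version("0.20"), f"tokenizers {installed} is older than 0.20"
```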
18 changes: 9 additions & 9 deletions BLIP_MRI/project/dataset/dataset_T1_LLaVa.py
@@ -74,18 +74,18 @@ def __transform_image__(self, image_file):
image = apply_transform(self.image_transform, image, map_items=False)
image = torch.tensor(image)
return image


def __transform_text__(self, label, add_context=False, sex=None, age=None):
if len(self.label_names) == 1 and 'sex' in self.label_names:
if int(label) == 1:
inst = f"{self.quest_template} Estimate sex of subject from this image. {self.ans_template} "
answer = f'male'
elif int(label) == 2:
inst = f"{self.quest_template} Estimate sex of subject from this image. {self.ans_template} "
answer = f'female'
sex_text = 'male' if int(label) == 1 else 'female'
inst = f"{self.quest_template} Estimate sex of subject from this image. {self.ans_template}"
answer = f"The brain shows {sex_text} characteristics."

elif len(self.label_names) == 1 and 'age' in self.label_names:
inst = f"{self.quest_template} Estimate age of subject from this image."
answer = f'{self.ans_template} {label}'
inst = f"{self.quest_template} Estimate age of subject from this image. {self.ans_template}"
answer = f"{label} years"

return inst, answer


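For reference, a standalone sketch of what the rewritten `__transform_text__` now returns for the sex task. The template values below are assumptions for illustration; the real ones live on the dataset class:

```python
# Illustrative only: assumed LLaVA-style templates, not the dataset's actual values.
quest_template = "USER: <image>"
ans_template = "ASSISTANT:"

label = 2  # this dataset encodes 1 = male, 2 = female
sex_text = 'male' if int(label) == 1 else 'female'
inst = f"{quest_template} Estimate sex of subject from this image. {ans_template}"
answer = f"The brain shows {sex_text} characteristics."

print(inst)    # USER: <image> Estimate sex of subject from this image. ASSISTANT:
print(answer)  # The brain shows female characteristics.
```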
Binary file modified BLIP_MRI/project/model/__pycache__/Bblip_t5.cpython-311.pyc
Binary file modified BLIP_MRI/project/model/__pycache__/__init__.cpython-311.pyc
164 changes: 125 additions & 39 deletions BLIP_MRI/project/utils/Trainer.py
@@ -40,8 +40,7 @@ def preprocess_logits_for_metrics(logits, labels):
pred_ids = torch.argmax(logits, dim=-1)
return pred_ids
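`preprocess_logits_for_metrics` collapses the vocabulary dimension before the Trainer accumulates eval predictions. A sketch with assumed shapes:

```python
# Assumed shapes for illustration: batch=2, seq_len=5, vocab=32000.
# Keeping only argmax token ids instead of full float logits bounds the
# memory the Trainer gathers across the evaluation set.
import torch

logits = torch.randn(2, 5, 32000)
pred_ids = torch.argmax(logits, dim=-1)   # shape (2, 5), dtype int64
print(pred_ids.shape)
```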


@torch.no_grad()
# @torch.no_grad()
def compute_metrics_with_tokenizer(tokenizer):
def compute_metrics(eval_preds):
predictions, labels = eval_preds
@@ -53,10 +52,9 @@ def compute_metrics(eval_preds):
pred_genders = []
true_genders = []

import re
for pred in decoded_preds:
pred_clean = pred.lower().strip()

import re
if re.search(r'\bfemale\b', pred_clean):
pred_genders.append(1)
elif re.search(r'\bmale\b', pred_clean):
@@ -66,29 +64,59 @@

for label in decoded_labels:
label_clean = label.lower().strip()

import re
if re.search(r'\bfemale\b', label_clean):
true_genders.append(1)
elif re.search(r'\bmale\b', label_clean):
true_genders.append(0)
else:
true_genders.append(-1)

# Valid pairs
valid_pairs = [(p, t) for p, t in zip(pred_genders, true_genders) if p != -1 and t != -1]

if valid_pairs:
valid_preds, valid_trues = zip(*valid_pairs)
accuracy = balanced_accuracy_score(valid_trues, valid_preds)
f1 = f1_score(valid_trues, valid_preds, average='macro')
valid_accuracy = balanced_accuracy_score(valid_trues, valid_preds)
valid_f1 = f1_score(valid_trues, valid_preds, average='macro')
else:
valid_accuracy = 0.0
valid_f1 = 0.0

# Overall metrics (treat invalid predictions as incorrect)
overall_preds = []
overall_trues = []

for p, t in zip(pred_genders, true_genders):
if t != -1: # only when the ground truth is valid
overall_trues.append(t)
if p == -1:
overall_preds.append(1 - t)
# overall_preds.append(-1)
else:
overall_preds.append(p)

if overall_preds:
overall_accuracy = balanced_accuracy_score(overall_trues, overall_preds)
overall_f1 = f1_score(overall_trues, overall_preds, average='macro')
else:
accuracy = 0.0
f1 = 0.0
overall_accuracy = 0.0
overall_f1 = 0.0

total_samples = len(pred_genders)
invalid_predictions = pred_genders.count(-1)
response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0

metrics = {
'accuracy': accuracy,
'f1': f1
}
'accuracy': valid_accuracy,
'f1': valid_f1,
'overall_accuracy': overall_accuracy,
'overall_f1': overall_f1,
'response_rate': response_rate,
'valid_samples': len(valid_pairs),
'total_samples': total_samples,
'invalid_predictions': invalid_predictions
}

return metrics

return compute_metrics
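To make the valid-versus-overall split concrete, here is a toy walkthrough under assumed parse results (four valid ground truths, one unparsable prediction):

```python
# Toy example, not part of the PR: how the two metric views diverge.
from sklearn.metrics import balanced_accuracy_score

pred_genders = [1, 0, -1, 1]   # -1 = neither "male" nor "female" was found
true_genders = [1, 0, 1, 0]

# Valid view: drop unparsable predictions entirely.
valid = [(p, t) for p, t in zip(pred_genders, true_genders) if p != -1 and t != -1]
vp, vt = zip(*valid)
print(balanced_accuracy_score(vt, vp))       # 0.75 over the 3 parsable pairs

# Overall view: score an unparsable prediction as the wrong class (1 - t).
overall_preds = [(1 - t) if p == -1 else p for p, t in zip(pred_genders, true_genders)]
print(balanced_accuracy_score(true_genders, overall_preds))   # 0.5

# response_rate = (4 - 1) / 4 = 0.75
```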
@@ -297,9 +325,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
def training_step(self, model, inputs):
loss = super().training_step(model, inputs)

# generation result
if self.state.global_step % 50 == 0 and self.state.global_step > 0:
self.log_generated_result(model, inputs)
# # generation result
# if self.state.global_step % 50 == 0 and self.state.global_step > 0:
# self.log_generated_result(model, inputs, mode="training")

# Log gradients at logging steps
modalities = list(inputs.keys())
@@ -476,35 +504,60 @@ def prediction_step(
print(f" - logits shape: {logits.shape if logits is not None else None}")
print(f" - labels shape: {labels.shape if labels is not None else None}")

# Log generated result during evaluation (first sample of each eval)
if not prediction_loss_only and not hasattr(self, '_eval_generation_logged'):
self._eval_generation_logged = True
self.log_generated_result(model, inputs, mode="evaluation")

return (loss, logits, labels)

def log_generated_result(self, model, inputs):
def log_generated_result(self, model, inputs, mode="training"):
"""
Log generated result during training or evaluation

Args:
model: The model to use for generation
inputs: Input dictionary (wrapped or unwrapped)
mode: "training" or "evaluation"
"""
actual_model = model.module if hasattr(model, 'module') else model

actual_model.eval()

# Only set eval mode for training (already in eval during evaluation)
if mode == "training":
actual_model.eval()

with torch.no_grad():
try:
modality = list(inputs.keys())[0]
sample_input = inputs[modality]

# Handle input format (different for training vs evaluation)
if 'pixel_values' in inputs and 'input_ids' in inputs:
sample_input = inputs
else:
# Still wrapped in modality key (typical for training)
modality_keys = [k for k in inputs.keys() if k in ['T1', 'rsfMRI']]
if modality_keys:
sample_input = inputs[modality_keys[0]]
else:
sample_input = inputs

input_ids = sample_input['input_ids'][0]

# Search for the ASSISTANT: token
assistant_tokens = self.tokenizer.encode("ASSISTANT:", add_special_tokens=False)
assistant_pos = None

for i in range(len(input_ids) - len(assistant_tokens) + 1):
if torch.equal(input_ids[i:i+len(assistant_tokens)],
torch.tensor(assistant_tokens, device=input_ids.device)):
assistant_pos = i + len(assistant_tokens)
break

if assistant_pos is None:
print("Warning: ASSISTANT: not found in input")
print(f"[WARN] ASSISTANT: not found in {mode} input")
return

prompt_ids = input_ids[:assistant_pos].unsqueeze(0)


# Generate
generated_ids = actual_model.generate(
pixel_values=sample_input['pixel_values'][0:1],
input_ids=prompt_ids,
@@ -513,37 +566,70 @@ def log_generated_result(self, model, inputs):
temperature=0.1,
pad_token_id=self.tokenizer.pad_token_id,
)

generated_only = generated_ids[0][len(prompt_ids[0]):]
generated_text = self.tokenizer.decode(generated_only, skip_special_tokens=True)


# Build result dictionary
result = {
"type": mode,
"step": self.state.global_step,
"epoch": float(self.state.epoch) if hasattr(self.state, 'epoch') else 0,
"generated_text": generated_text,
}


# Add ground truth for evaluation mode
if mode == "evaluation":
labels = sample_input.get('labels', None)
if labels is not None:
labels_clean = labels[0].clone()
labels_clean[labels_clean == -100] = self.tokenizer.pad_token_id
ground_truth = self.tokenizer.decode(labels_clean, skip_special_tokens=True)
else:
ground_truth = "N/A"
result["ground_truth"] = ground_truth

# Save to JSON
json_file = "generation_logs.json"
if os.path.exists(json_file):
with open(json_file, 'r') as f:
logs = json.load(f)
else:
logs = []

logs.append(result)

with open(json_file, 'w') as f:
json.dump(logs, f, indent=2, ensure_ascii=False)

print(f"Step: {self.state.global_step}")
print(f"Generated: {generated_text}")
# Print output
prefix = "[TRAIN]" if mode == "training" else "[EVAL]"
if mode == "evaluation":
print("\n" + "="*80)
print(f"{prefix} Step: {self.state.global_step}, Epoch: {result['epoch']}")
print(f"{prefix} Generated: {generated_text}")
print(f"{prefix} Ground Truth: {result.get('ground_truth', 'N/A')}")
print("="*80 + "\n")
else:
print(f"{prefix} Step: {self.state.global_step}")
print(f"{prefix} Generated: {generated_text}")

except Exception as e:
print(f"[ERROR] Generation failed: {e}")
print(f"[ERROR] {mode.capitalize()} generation failed: {e}")
import traceback
traceback.print_exc()

actual_model.train()

# Restore train mode only if we changed it
if mode == "training":
actual_model.train()
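The prompt-boundary search in `log_generated_result` scans for the tokenized "ASSISTANT:" marker. A self-contained sketch of the same scan, with invented token ids standing in for the real tokenizer output:

```python
# Stand-in example: the ids below are made up; in the Trainer the marker
# comes from tokenizer.encode("ASSISTANT:", add_special_tokens=False).
import torch

input_ids = torch.tensor([5, 8, 42, 99, 7, 31, 12])   # prompt + answer tokens
assistant_tokens = [99, 7]                             # pretend marker ids

assistant_pos = None
for i in range(len(input_ids) - len(assistant_tokens) + 1):
    if torch.equal(input_ids[i:i + len(assistant_tokens)],
                   torch.tensor(assistant_tokens)):
        assistant_pos = i + len(assistant_tokens)      # generation starts here
        break

print(assistant_pos)                  # 5
print(input_ids[:assistant_pos])      # tensor([ 5,  8, 42, 99,  7]) -> prompt_ids
```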

def evaluation_loop(self, *args, **kwargs):
"""Override to reset generation flag at start of each evaluation"""
# Reset flag so we log generation once per eval
if hasattr(self, '_eval_generation_logged'):
delattr(self, '_eval_generation_logged')

return super().evaluation_loop(*args, **kwargs)
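The once-per-evaluation logging relies on a flag that `prediction_step` sets and `evaluation_loop` clears. Stripped to its essentials, the pattern looks like this (a minimal sketch assuming a Hugging Face Trainer subclass, not the full class above):

```python
# Minimal sketch of the log-once-per-eval pattern.
from transformers import Trainer

class LogOnceTrainer(Trainer):
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # First prediction_step of an eval pass: log one sample, then stay quiet.
        if not prediction_loss_only and not hasattr(self, "_eval_generation_logged"):
            self._eval_generation_logged = True
            print("logging one generated sample for this evaluation pass")
        return super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )

    def evaluation_loop(self, *args, **kwargs):
        # Reset the flag so the next evaluation logs exactly one sample again.
        if hasattr(self, "_eval_generation_logged"):
            delattr(self, "_eval_generation_logged")
        return super().evaluation_loop(*args, **kwargs)
```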



Empty file modified BLIP_MRI/sample_scripts/BLIP_MRI_Blip_DDP_interactive.sh
100644 → 100755
Empty file modified BLIP_MRI/sample_scripts/BLIP_MRI_Blip_T1_DDP_interactive.sh
100644 → 100755
10 changes: 5 additions & 5 deletions BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh
100644 → 100755
@@ -1,24 +1,24 @@

set +x

cd /pscratch/sd/h/heehaw/BLIP_MRI/project #TODO: Change to your own scratch space
cd /YOUR_PROJECT_DIRECTORY #TODO: Change to your own scratch space


module load python
#module load pytorch/1.13.1
module load cpe/23.03

conda activate /pscratch/sd/h/heehaw/anaconda/BLIP_MRI #TODO: Change to your own conda env
conda activate BLIP_MRI_llava #TODO: Change to your own conda env
# conda activate py39
# pip install timm
#export MASTER_ADDR=`/bin/hostname -s`
#export MASTER_PORT=29500
#export MASTER_PORT=$(shuf -i 29500-65535 -n 1)

export LIBRARY_PATH=$LD_LIBRARY_PATH
export TORCH_EXTENSIONS_DIR=/pscratch/sd/h/heehaw #TODO: Change to your own scratch space
export HF_HOME=/pscratch/sd/h/heehaw/huggingface #TODO: Change to your own scratch space
export TORCH_HOME=/pscratch/sd/h/heehaw/ #TODO: Change to your own scratch space
export TORCH_EXTENSIONS_DIR=/pscratch/sd/ #TODO: Change to your own scratch space
export HF_HOME=/pscratch/sd/ #TODO: Change to your own scratch space
export TORCH_HOME=/pscratch/sd/ #TODO: Change to your own scratch space


# #recent version (24.3.30)