Skip to content

Commit 319ab85

Browse files
committed
Check training sample size.
1 parent c19ef0f commit 319ab85

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

scripts/train_intervention.py

+8 -5
Original file line number | Diff line number | Diff line change
@@ -23,12 +23,13 @@ def train_intervention(config, model, tokenizer, split_to_dataset):
2323
split_to_dataset[f'{task_name}-{split}'].select(
2424
np.random.choice(
2525
min(
26-
len(split_to_dataset[f'{task_name}-{split}']) - 1,
26+
len(split_to_dataset[f'{task_name}-{split}']),
2727
config['cause_task_sample_size']),
28-
size=config['iso_task_sample_size'] if
28+
size=min(config['iso_task_sample_size'],
29+
len(split_to_dataset[f'{task_name}-{split}'])) if
2930
config['training_tasks'][task_name] == 'match_base' else min(
30-
len(split_to_dataset[f'{task_name}-{split}']) -
31-
1, config['cause_task_sample_size']),
31+
len(split_to_dataset[f'{task_name}-{split}']),
32+
config['cause_task_sample_size']),
3233
replace=False))
3334
for task_name in config['training_tasks']
3435
if f'{task_name}-{split}' in split_to_dataset
@@ -78,6 +79,8 @@ def train_intervention(config, model, tokenizer, split_to_dataset):
7879
optimizer_params += [{'params': v[0].rotate_layer.parameters()}]
7980
elif isinstance(v[0], DifferentialBinaryMasking):
8081
optimizer_params += [{'params': v[0].parameters()}]
82+
else:
83+
raise NotImplementedError
8184
optimizer = torch.optim.AdamW(optimizer_params,
8285
lr=config['init_lr'],
8386
weight_decay=0)
@@ -112,7 +115,7 @@ def train_intervention(config, model, tokenizer, split_to_dataset):
112115
for step, inputs in enumerate(epoch_iterator):
113116
for k, v in inputs.items():
114117
if v is not None and isinstance(v, torch.Tensor):
115-
inputs[k] = v.to("cuda")
118+
inputs[k] = v.to(model.device)
116119
b_s = inputs["input_ids"].shape[0]
117120
position_ids = {
118121
f'{prefix}position_ids':

0 commit comments

Comments
 (0)