@@ -23,12 +23,13 @@ def train_intervention(config, model, tokenizer, split_to_dataset):
         split_to_dataset[f'{task_name}-{split}'].select(
             np.random.choice(
                 min(
-                    len(split_to_dataset[f'{task_name}-{split}']) - 1,
+                    len(split_to_dataset[f'{task_name}-{split}']),
                     config['cause_task_sample_size']),
-                size=config['iso_task_sample_size'] if
+                size=min(config['iso_task_sample_size'],
+                         len(split_to_dataset[f'{task_name}-{split}'])) if
                 config['training_tasks'][task_name] == 'match_base' else min(
-                    len(split_to_dataset[f'{task_name}-{split}']) -
-                    1, config['cause_task_sample_size']),
+                    len(split_to_dataset[f'{task_name}-{split}']),
+                    config['cause_task_sample_size']),
                 replace=False))
     for task_name in config['training_tasks']
     if f'{task_name}-{split}' in split_to_dataset
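This first hunk fixes a sampling bound: with `replace=False`, `np.random.choice` raises a `ValueError` whenever the requested `size` exceeds the population, and the old `len(...) - 1` both dropped a usable example and left the `match_base` branch's `size` unclamped. A minimal sketch of the failure mode and the clamp (the values of `n` and `k` are made up for illustration):

```python
import numpy as np

n = 5  # population, e.g. len(split_to_dataset[f'{task_name}-{split}'])
k = 8  # requested draws, e.g. config['iso_task_sample_size']

try:
    np.random.choice(n, size=k, replace=False)
except ValueError as e:
    # numpy: "Cannot take a larger sample than population when 'replace=False'"
    print(e)

# Clamping the sample size, as the patch does, keeps the call valid:
idx = np.random.choice(n, size=min(k, n), replace=False)
print(sorted(idx))  # at most n distinct indices drawn from range(n)
```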
@@ -78,6 +79,8 @@ def train_intervention(config, model, tokenizer, split_to_dataset):
             optimizer_params += [{'params': v[0].rotate_layer.parameters()}]
         elif isinstance(v[0], DifferentialBinaryMasking):
             optimizer_params += [{'params': v[0].parameters()}]
+        else:
+            raise NotImplementedError
     optimizer = torch.optim.AdamW(optimizer_params,
                                   lr=config['init_lr'],
                                   weight_decay=0)
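The second hunk adds a fail-fast guard: an intervention type matching neither branch used to fall through silently, contributing no parameters to the optimizer, and now raises `NotImplementedError` instead. A rough sketch of the same dispatch-plus-guard pattern with toy stand-in modules (the class names below are illustrative, not the repo's):

```python
import torch

# Toy stand-ins for the repo's intervention modules; names are illustrative.
class RotatedIntervention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rotate_layer = torch.nn.Linear(8, 8, bias=False)

class MaskIntervention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = torch.nn.Parameter(torch.zeros(8))

optimizer_params = []
for module in [RotatedIntervention(), MaskIntervention()]:
    if hasattr(module, 'rotate_layer'):
        # Only the rotation matrix is trainable for this intervention type.
        optimizer_params += [{'params': module.rotate_layer.parameters()}]
    elif isinstance(module, MaskIntervention):
        optimizer_params += [{'params': module.parameters()}]
    else:
        # Previously an unknown type fell through and trained nothing;
        # raising here surfaces the misconfiguration immediately.
        raise NotImplementedError(type(module))

optimizer = torch.optim.AdamW(optimizer_params, lr=1e-3, weight_decay=0)
print(len(optimizer.param_groups))  # 2
```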
@@ -112,7 +115,7 @@ def train_intervention(config, model, tokenizer, split_to_dataset):
         for step, inputs in enumerate(epoch_iterator):
             for k, v in inputs.items():
                 if v is not None and isinstance(v, torch.Tensor):
-                    inputs[k] = v.to("cuda")
+                    inputs[k] = v.to(model.device)
             b_s = inputs["input_ids"].shape[0]
             position_ids = {
                 f'{prefix}position_ids':
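The third hunk removes the hard-coded `"cuda"` so batches follow the model's actual placement (CPU, a specific GPU, or `mps`); Hugging Face models expose a `.device` property reporting where their parameters live. A small sketch of the idiom with a toy module (a plain `nn.Module` lacks `.device`, so the sketch reads the device from the parameters, which is what that property returns):

```python
import torch

model = torch.nn.Linear(4, 4)             # toy stand-in for the real model
device = next(model.parameters()).device  # what HF's model.device reports

inputs = {'input_ids': torch.ones(2, 3, dtype=torch.long), 'labels': None}
for k, v in inputs.items():
    if v is not None and isinstance(v, torch.Tensor):
        # Follow the model's placement instead of assuming a CUDA host;
        # the same loop now works on CPU and Apple-silicon (mps) machines.
        inputs[k] = v.to(device)
print(inputs['input_ids'].device)  # matches wherever the model lives
```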