forked from bogoconic1/pii-detection-1st-place
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_multidropout.py
52 lines (46 loc) · 2.1 KB
/
train_multidropout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import subprocess
from datetime import datetime
from pathlib import Path

import yaml
# Launcher script: for every validation fold listed in the YAML config, run
# `deberta-multi-dropouts.py` through `accelerate launch`, echoing the child's
# combined stdout/stderr to the console and to a per-fold log file.

# Load the configuration file (YAML streams are UTF-8 unless marked otherwise).
with open('configs/multidropouts_config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Load training configurations
training_config = config['training']

# Opening a file in "w" mode does not create missing parent directories, so
# make sure the log directory exists before the first fold starts.
LOG_DIR = Path("logs")
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Folds whose training process exited with a non-zero status.
failed_folds = []

# Loop through each validation fold
for VALIDATION_FOLD in config['validation_folds']:
    OUTPUT_DIR = f"models/multidropouts-{VALIDATION_FOLD}-lr{training_config['learning_rate']}"
    MODEL_NAME = f"custom-model-{training_config['max_length']}-fold-{VALIDATION_FOLD}"
    # Timestamp (minute resolution) used to keep log files of repeated runs apart.
    current_date = datetime.now().strftime("%y%m%d_%H%M")

    # Construct the training command. Passing a list (no shell) means every
    # element reaches the child as a single argument, with no quoting issues.
    command = [
        "accelerate", "launch",
        # Optional `num_processes` key in the training config; defaults to the
        # previously hard-coded value of 8.
        "--num_processes", str(training_config.get('num_processes', 8)),
        "deberta-multi-dropouts.py",
        "--output_dir", OUTPUT_DIR,
        "--validation_fold", str(VALIDATION_FOLD),
        "--model_path", training_config['model_path'],
        "--max_length", str(training_config['max_length']),
        "--learning_rate", str(training_config['learning_rate']),
        "--per_device_train_batch_size", str(training_config['per_device_train_batch_size']),
        "--per_device_eval_batch_size", str(training_config['per_device_eval_batch_size']),
        "--num_train_epochs", str(training_config['num_train_epochs']),
        "--save_steps", str(training_config['save_steps']),
        "--o_weight", str(training_config['o_weight']),
        "--model_name", MODEL_NAME,
        "--hash", training_config['hash_name'],
        # Rendered lower-case, e.g. a YAML boolean True becomes "true".
        "--peft", str(training_config['peft']).lower(),
        "--seed", str(training_config['seed']),
        "--adv_mode", training_config['adv_stop_mode'],
        "--adv_start", str(training_config['adv_start']),
        "--loss", training_config['loss'],
    ]

    # Execute the command and tee its combined stdout/stderr to the console
    # and to the per-fold log file.
    log_path = LOG_DIR / f"dropouts-fold{VALIDATION_FOLD}-{current_date}.log"
    with open(log_path, "w", encoding="utf-8") as log_file:
        # The context manager closes the pipe and reaps the child even if the
        # read loop below raises. errors="replace" keeps stray undecodable
        # bytes in the child's output from aborting the launcher.
        with subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",
        ) as process:
            for line in process.stdout:
                print(line, end="")
                log_file.write(line)
            return_code = process.wait()

        # Surface a failed fold instead of silently moving on, but keep going
        # so the remaining folds still get trained.
        if return_code != 0:
            failure_note = (
                f"Training for fold {VALIDATION_FOLD} exited with code {return_code}\n"
            )
            print(failure_note, end="")
            log_file.write(failure_note)
            failed_folds.append(VALIDATION_FOLD)

# Make the launcher's own exit status reflect any failed fold.
if failed_folds:
    raise SystemExit(f"Training failed for validation fold(s): {failed_folds}")