diff --git a/preprocess/processor.py b/preprocess/processor.py index 2276390..8bc471d 100644 --- a/preprocess/processor.py +++ b/preprocess/processor.py @@ -12,12 +12,14 @@ class TaskProcessor(object): - def __init__(self, task: BaseTask, data_path: str, output_path: str, model_path: str, resample: str): + def __init__(self, task: BaseTask, data_path: str, output_path: str, model_path: str, resample: str, + token_shapes: bool=False): self.task: BaseTask = task self.data_path: str = data_path self.model_path = model_path self.output_path = output_path self.task_output_path = os.path.join(self.output_path, task.spec().output_path()) + self.token_shapes = token_shapes self.resample = self._parse_resample_string(resample) if not os.path.exists(self.task_output_path): os.makedirs(self.task_output_path, exist_ok=True) @@ -142,5 +144,8 @@ def _run_fairseq_preprocess(self, input_name: str, destdir: str): dict_path: str = os.path.join(self.model_path, "dict.txt") cmd.append("--srcdict") cmd.append(dict_path) + if self.token_shapes: + cmd.append("--task") + cmd.append("masked_lm_with_token_shapes") logging.info("running %s", cmd.__repr__()) subprocess.run(cmd) \ No newline at end of file diff --git a/run_tasks.py b/run_tasks.py index 36d7443..6629b04 100644 --- a/run_tasks.py +++ b/run_tasks.py @@ -63,13 +63,14 @@ def __init__(self, task: BaseTask, task_id: str, input_dir: str, output_dir: str self.model_name: str = os.path.basename(model_dir) self.task_output_dir: str = os.path.join(self.output_dir, f"{task.spec().output_path()}-bin") - def prepare_task(self, resample: str): - processor = TaskProcessor(self.task, self.input_dir, self.output_dir, self.model_dir, resample) + def prepare_task(self, resample: str, token_shapes: bool): + processor = TaskProcessor(self.task, self.input_dir, self.output_dir, self.model_dir, resample, token_shapes) processor.prepare() - def train_task(self, train_epochs: int, fp16: bool, max_sentences: int, update_freq: int): + def train_task(self, train_epochs: int, fp16: bool, max_sentences: int, update_freq: int, token_shapes: bool): train_size = self._count_train() - trainer = TaskTrainer(self.task, self.output_dir, self.model_dir, train_size, arch=self.arch, fp16=fp16) + trainer = TaskTrainer(self.task, self.output_dir, self.model_dir, train_size, + arch=self.arch, fp16=fp16, token_shapes=token_shapes) trainer.train(train_epochs=train_epochs, max_sentences=max_sentences, update_freq=update_freq) def evaluate_task(self): @@ -109,7 +110,7 @@ def log_score(self, task_name: str, task_id: str, params: Dict, scores: Dict): def run_tasks(arch: str, model_dir: str, input_dir: str="data", output_dir: str="data_processed", tasks: str=None, train_epochs: int=10, fp16: bool=False, max_sentences: int=1, update_freq: int=16, - evaluation_only: bool=False, resample: str=None, seed: int=None): + evaluation_only: bool=False, resample: str=None, token_shapes: bool=False, seed: int=None): assert arch in ("roberta_base", "roberta_large", "bart_base", "bart_large") params = locals() if tasks is None: @@ -127,8 +128,8 @@ def run_tasks(arch: str, model_dir: str, input_dir: str="data", output_dir: str= task = task_class() runner: TaskRunner = TaskRunner(task, task_id, input_dir, output_dir, model_dir, arch, seed) if not evaluation_only: - runner.prepare_task(resample) - runner.train_task(train_epochs, fp16, max_sentences, update_freq) + runner.prepare_task(resample, token_shapes) + runner.train_task(train_epochs, fp16, max_sentences, update_freq, token_shapes) score = runner.evaluate_task() runner.log_score(task_name, task_id, params, score) diff --git a/train/trainer.py b/train/trainer.py index 4da6270..c4e7f8f 100644 --- a/train/trainer.py +++ b/train/trainer.py @@ -1,4 +1,3 @@ -import importlib import logging import os import random @@ -12,8 +11,8 @@ class TaskTrainer(object): - def __init__(self, task: BaseTask, data_path: str, model_path: str, train_size: int, - checkpoint: str="model.pt", arch: str="roberta_large", fp16: bool=False): + def __init__(self, task: BaseTask, data_path: str, model_path: str, train_size: int, checkpoint: str="model.pt", + arch: str="roberta_large", fp16: bool=False, token_shapes: bool=False): self.task: BaseTask = task self.train_size: int = train_size self.data_path: str = data_path @@ -24,6 +23,7 @@ def __init__(self, task: BaseTask, data_path: str, model_path: str, train_size: self.arch: str = arch self.learning_rate = "1e-5" self.fp16 = fp16 + self.token_shapes = token_shapes def train(self, max_sentences: int=1, update_freq: int=16, train_epochs: int=10, seed: int=None): self._run_fairseq_train(seed, max_sentences=max_sentences, update_freq=update_freq, max_epoch=train_epochs) @@ -42,6 +42,7 @@ def _run_fairseq_train(self, seed: int, max_sentences: int=16, update_freq: int= restore_file = os.path.join(self.model_path, self.checkpoint) assert os.path.exists(restore_file) checkpoint_path = os.path.join("checkpoints", self.model_name, self.task.spec().output_path()) + task = "sentence_prediction" if not self.token_shapes else "sentence_prediction_with_token_shapes" self._remove_previous_checkpoints(checkpoint_path) cmd = [ self.task_data_path, @@ -50,7 +51,7 @@ def _run_fairseq_train(self, seed: int, max_sentences: int=16, update_freq: int= "--max-sentences", str(max_sentences), "--update-freq", str(update_freq), "--max-tokens", "4400", - "--task", "sentence_prediction", + "--task", task, "--reset-optimizer", "--reset-dataloader", "--reset-meters", @@ -107,12 +108,12 @@ def _run_fairseq_train(self, seed: int, max_sentences: int=16, update_freq: int= def _run_training(self, cmd: List[str]): try: - from fairseq_cli.train import cli_main_helper + from fairseq_cli.train import main parser = options.get_training_parser() if self.arch.startswith("bart"): parser.add_argument("--max-positions", type=int) args = options.parse_args_and_arch(parser, input_args=cmd) - cli_main_helper(args) + main(args) except ImportError: cmd.insert(0, "fairseq-train") subprocess.run(cmd)