Skip to content

Commit

Permalink
set max length to 120
Browse files: browse the repository at this point in the history
  • Loading branch information
helpmefindaname committed Dec 14, 2021
1 parent 819417f commit f6f9f42
Show file tree
Hide file tree
Showing 59 changed files with 1,586 additions and 4,781 deletions.
54 changes: 13 additions & 41 deletions examples/ner/run_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,58 +28,38 @@ class ModelArguments:
default="first",
metadata={"help": "Subtoken pooling strategy used for fine-tuned."},
)
hidden_size: int = field(
default=256, metadata={"help": "Hidden size for NER model."}
)
use_crf: bool = field(
default=False, metadata={"help": "Whether to use a CRF on-top or not."}
)
hidden_size: int = field(default=256, metadata={"help": "Hidden size for NER model."})
use_crf: bool = field(default=False, metadata={"help": "Whether to use a CRF on-top or not."})


@dataclass
class TrainingArguments:
num_epochs: int = field(
default=10, metadata={"help": "The number of training epochs."}
)
batch_size: int = field(
default=8, metadata={"help": "Batch size used for training."}
)
num_epochs: int = field(default=10, metadata={"help": "The number of training epochs."})
batch_size: int = field(default=8, metadata={"help": "Batch size used for training."})
mini_batch_chunk_size: int = field(
default=1,
metadata={"help": "If smaller than batch size, batches will be chunked."},
)
learning_rate: float = field(default=5e-05, metadata={"help": "Learning rate"})
seed: int = field(
default=42, metadata={"help": "Seed used for reproducible fine-tuning results."}
)
seed: int = field(default=42, metadata={"help": "Seed used for reproducible fine-tuning results."})
device: str = field(default="cuda:0", metadata={"help": "CUDA device string."})
weight_decay: float = field(
default=0.0, metadata={"help": "Weight decay for optimizer."}
)
embeddings_storage_mode: str = field(
default="none", metadata={"help": "Defines embedding storage method."}
)
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for optimizer."})
embeddings_storage_mode: str = field(default="none", metadata={"help": "Defines embedding storage method."})


@dataclass
class FlertArguments:
context_size: int = field(
default=0, metadata={"help": "Context size when using FLERT approach."}
)
context_size: int = field(default=0, metadata={"help": "Context size when using FLERT approach."})
respect_document_boundaries: bool = field(
default=False,
metadata={
"help": "Whether to respect document boundaries or not when using FLERT."
},
metadata={"help": "Whether to respect document boundaries or not when using FLERT."},
)


@dataclass
class DataArguments:
dataset_name: str = field(metadata={"help": "Flair NER dataset name."})
dataset_arguments: str = field(
default="", metadata={"help": "Dataset arguments for Flair NER dataset."}
)
dataset_arguments: str = field(default="", metadata={"help": "Dataset arguments for Flair NER dataset."})
output_dir: str = field(
default="resources/taggers/ner",
metadata={"help": "Defines output directory for final fine-tuned model."},
Expand All @@ -91,11 +71,7 @@ def get_flair_corpus(data_args):

for name, obj in inspect.getmembers(flair.datasets.sequence_labeling):
if inspect.isclass(obj):
if (
name.startswith("NER")
or name.startswith("CONLL")
or name.startswith("WNUT")
):
if name.startswith("NER") or name.startswith("CONLL") or name.startswith("WNUT"):
ner_task_mapping[name] = obj

dataset_args = {}
Expand All @@ -105,17 +81,13 @@ def get_flair_corpus(data_args):
dataset_args = json.loads(data_args.dataset_arguments)

if not dataset_name in ner_task_mapping:
raise ValueError(
f"Dataset name {dataset_name} is not a valid Flair datasets name!"
)
raise ValueError(f"Dataset name {dataset_name} is not a valid Flair datasets name!")

return ner_task_mapping[dataset_name](**dataset_args)


def main():
parser = HfArgumentParser(
(ModelArguments, TrainingArguments, FlertArguments, DataArguments)
)
parser = HfArgumentParser((ModelArguments, TrainingArguments, FlertArguments, DataArguments))

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
(
Expand Down
4 changes: 1 addition & 3 deletions flair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
"stream": "ext://sys.stdout",
}
},
"loggers": {
"flair": {"handlers": ["console"], "level": "INFO", "propagate": False}
},
"loggers": {"flair": {"handlers": ["console"], "level": "INFO", "propagate": False}},
}
)

Expand Down
Loading

0 comments on commit f6f9f42

Please sign in to comment.