diff --git a/frame_semantic_transformer/predict.py b/frame_semantic_transformer/predict.py index 54f0121..df67f11 100644 --- a/frame_semantic_transformer/predict.py +++ b/frame_semantic_transformer/predict.py @@ -13,7 +13,7 @@ def predict( num_beams: int = 5, top_k: int = 50, top_p: float = 0.95, - do_sample: bool = True, + do_sample: bool = False, repetition_penalty: float = 2.5, length_penalty: float = 1.0, early_stopping: bool = True, @@ -47,7 +47,7 @@ def batch_predict( num_beams: int = 5, top_k: int = 50, top_p: float = 0.95, - do_sample: bool = True, + do_sample: bool = False, repetition_penalty: float = 2.5, length_penalty: float = 1.0, early_stopping: bool = True, @@ -90,7 +90,7 @@ def predict_on_ids( num_beams: int = 5, top_k: int = 50, top_p: float = 0.95, - do_sample: bool = True, + do_sample: bool = False, repetition_penalty: float = 2.5, length_penalty: float = 1.0, early_stopping: bool = True,