diff --git a/llments/lm/base/empirical.py b/llments/lm/base/empirical.py index e9e8399..f09ce71 100644 --- a/llments/lm/base/empirical.py +++ b/llments/lm/base/empirical.py @@ -51,6 +51,9 @@ def calculate_probability(self, x: str) -> float: # Implementation logic raise NotImplementedError("This is not implemented yet.") + def set_seed(self, seed: int): + random.seed(seed) + def load_from_text_file(text_file: str): """Load the distribution from a text file.""" diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py index 52ae6f3..d26dd66 100644 --- a/llments/lm/base/hugging_face.py +++ b/llments/lm/base/hugging_face.py @@ -1,5 +1,5 @@ from llments.lm.lm import LanguageModel -from transformers import pipeline +from transformers import pipeline, set_seed class HuggingFaceLM(LanguageModel): @@ -64,6 +64,14 @@ def generate( ) return [res["generated_text"] for res in results] + def set_seed(self, seed: int): + """Set the seed for the language model. + + Args: + seed: The seed to set for the language model. + """ + set_seed(seed) + def load_from_spec(spec_file: str) -> HuggingFaceLM: """Load a language model from a specification file. diff --git a/llments/lm/lm.py b/llments/lm/lm.py index 862a708..54bbc82 100644 --- a/llments/lm/lm.py +++ b/llments/lm/lm.py @@ -39,3 +39,12 @@ def generate( str: Sampled output sequences from the language model. """ ... + + @abc.abstractmethod + def set_seed(self, seed: int): + """Set the seed for the language model. + + Args: + seed: The seed to set for the language model. + """ + ...