Skip to content

Commit f5f2ea4

Browse files
authored
Add set_seed (#17)
* Add sampling from hugging face * Add set seed function * Fix merge * Fix comment
1 parent f147f39 commit f5f2ea4

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

llments/lm/base/empirical.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def calculate_probability(self, x: str) -> float:
5151
# Implementation logic
5252
raise NotImplementedError("This is not implemented yet.")
5353

54+
def set_seed(self, seed: int):
55+
random.seed(seed)
56+
5457

5558
def load_from_text_file(text_file: str):
5659
"""Load the distribution from a text file."""

llments/lm/base/hugging_face.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from llments.lm.lm import LanguageModel
2-
from transformers import pipeline
2+
from transformers import pipeline, set_seed
33

44

55
class HuggingFaceLM(LanguageModel):
@@ -64,6 +64,14 @@ def generate(
6464
)
6565
return [res["generated_text"] for res in results]
6666

67+
def set_seed(self, seed: int):
68+
"""Set the seed for the language model.
69+
70+
Args:
71+
seed: The seed to set for the language model.
72+
"""
73+
set_seed(seed)
74+
6775

6876
def load_from_spec(spec_file: str) -> HuggingFaceLM:
6977
"""Load a language model from a specification file.

llments/lm/lm.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,12 @@ def generate(
3939
str: Sampled output sequences from the language model.
4040
"""
4141
...
42+
43+
@abc.abstractmethod
44+
def set_seed(self, seed: int):
45+
"""Set the seed for the language model.
46+
47+
Args:
48+
seed: The seed to set for the language model.
49+
"""
50+
...

0 commit comments

Comments
 (0)