diff --git a/examples/community_lm/community_lm.ipynb b/examples/community_lm/community_lm.ipynb index 4b36845..aa27aee 100644 --- a/examples/community_lm/community_lm.ipynb +++ b/examples/community_lm/community_lm.ipynb @@ -30,7 +30,7 @@ "from community_lm_constants import politician_feelings, groups_feelings, anes_df\n", "from community_lm_utils import generate_community_opinion, compute_group_stance\n", "\n", - "device = 'mps' # change to 'mps' if you have a mac, or 'cuda:0' if you have an NVIDIA GPU " + "device = 'cpu' # change to 'mps' if you have a mac, or 'cuda:0' if you have an NVIDIA GPU " ] }, { @@ -384,7 +384,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py index df04065..0cc5840 100644 --- a/llments/lm/base/hugging_face.py +++ b/llments/lm/base/hugging_face.py @@ -1,5 +1,4 @@ from llments.lm.lm import LanguageModel -from transformers import pipeline, set_seed, TextGenerationPipeline class HuggingFaceLM(LanguageModel): @@ -14,6 +13,12 @@ def __init__( model: The name of the model. device: The device to run the model on. """ + try: + from transformers import pipeline, set_seed, TextGenerationPipeline + except ImportError: + raise ImportError( + "You need to install the `transformers` package to use this class." + ) self.text_generator: TextGenerationPipeline = pipeline( "text-generation", model=model, device=device ) diff --git a/pyproject.toml b/pyproject.toml index 73b392d..d1a57d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ ] dependencies = [ "pandas", + "tqdm", ] dynamic = ["version"]