diff --git a/examples/party_stance_detection/party_stance_detection.ipynb b/examples/party_stance_detection/party_stance_detection.ipynb new file mode 100644 index 0000000..e6fbeb9 --- /dev/null +++ b/examples/party_stance_detection/party_stance_detection.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Detecting Stance with Respect to Political Parties\n", + "\n", + "This tries to reproduce the results of [Fair and Balanced?](https://www.jstor.org/stable/44014619) by Budak et al. (2016)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llments.lm import empirical, hugging_face\n", + "from llments.distance.norm import L1Distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the base LM\n", + "base_lm = hugging_face.load_from_spec('base_lm_spec.json')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use few-shot learning from this dataset to turn it into a few-shot classifier\n", + "# https://huggingface.co/datasets/SetFit/tweet_eval_stance/viewer/stance_hillary\n", + "stance_dataset = empirical.EmpiricalDistribution(\n", + " \"Passage: If a man demanded staff to get him an ice tea he'd be called a sexists elitist pig.. Oink oink #Hillary #SemST\\nTarget: Hillary Clinton\\nStance: against\",\n", + " \"Passage: We're out here in G-town, and where are you #sctweets #SemST\\nTarget: Hillary Clinton\\nStance: none\",\n", + " \"Passage: If you're not watching @user speech right now you're missing her drop tons of wisdom. 
#SemST\\nTarget: Hillary Clinton\\nStance: favor\",\n", + ")\n", + "stance_lm = base_lm.fit(stance_dataset, task_description=\"Predict the stance of a passage with respect to the target.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Classify stance in various different passages\n", + "answer_distances = {}\n", + "distance_function = L1Distance()\n", + "\n", + "news_source_names = [\"cnn\", \"wsj\", \"fox\", \"npr\"]\n", + "party_names = [\"Democratic\", \"Republican\"]\n", + "for news_source in news_source_names:\n", + " # Load the dataset (empirical distribution) for this source\n", + " news_dataset = empirical.load_from_text_file(f'news_data_{news_source}.txt')\n", + " # Enumerate the entire news dataset\n", + " for news_datapoint in news_dataset:\n", + " for party_name in party_names:\n", + " probs = stance_lm.log_probability([\"favor\", \"against\", \"none\"], f\"Passage: {news_datapoint}\\nTarget: {party_name}\\nStance: \")\n", + " raise NotImplementedError(\"Not finished yet.\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llments/lm/hugging_face.py b/llments/lm/hugging_face.py index 4779a78..7ed6eb3 100644 --- a/llments/lm/hugging_face.py +++ b/llments/lm/hugging_face.py @@ -1,15 +1,28 @@ - from llments.lm.lm import LanguageModel class HuggingFaceLM(LanguageModel): - - def sample(self, condition: str | None) -> str: - """Sample from the language model, possibly conditioned on a prefix.""" + def generate( + self, + condition: str | None, + **kwargs, + ) -> str: + """Generate from the language model, possibly conditioned on a prefix.""" raise NotImplementedError("This is not implemented yet.") - def fit(self, target: LanguageModel) -> LanguageModel: - """Fit the language model to a target language model's distribution.""" + def fit( + self, target: LanguageModel, task_description: str | None = None + 
) -> LanguageModel: + """Fit the language model to a target language model's distribution. + + Args: + target: The language model that should be fitted to. + task_description: A task description that explains more about + the task that the fitted language model should perform (a prompt). + + Returns: + The fitted language model. + """ raise NotImplementedError("This is not implemented yet.") diff --git a/llments/lm/lm.py b/llments/lm/lm.py index ac973ee..54ff1eb 100644 --- a/llments/lm/lm.py +++ b/llments/lm/lm.py @@ -2,13 +2,27 @@ class LanguageModel: - @abc.abstractclassmethod - def sample(self, condition: str | None) -> str: - """Sample from the language model, possibly conditioned on a prefix.""" + def generate( + self, + condition: str | None, + **kwargs, + ) -> str: + """Generate from the language model, possibly conditioned on a prefix.""" ... @abc.abstractclassmethod - def fit(self, target: "LanguageModel") -> "LanguageModel": - """Fit the language model to a target language model's distribution.""" + def fit( + self, target: "LanguageModel", task_description: str | None = None + ) -> "LanguageModel": + """Fit the language model to a target language model's distribution. + + Args: + target: The language model that should be fitted to. + task_description: A task description that explains more about + the task that the fitted language model should perform (a prompt). + + Returns: + The fitted language model. + """ ...