Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Nov 22, 2023
1 parent f4705dd commit 87c909a
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 11 deletions.
83 changes: 83 additions & 0 deletions examples/party_stance_detection/party_stance_detection.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Detecting Stance with Respect to Political Parties\n",
"\n",
"This notebook tries to reproduce the results of [Fair and Balanced? Quantifying Media Bias through Crowdsourced Content Analysis](https://www.jstor.org/stable/44014619) by Budak et al. (2016)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llments.lm import empirical, hugging_face\n",
"from llments.distance.norm import L1Distance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the base LM\n",
"base_lm = hugging_face.load_from_spec('base_lm_spec.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use few-shot learning from this dataset to turn it into a few-shot classifier\n",
"# https://huggingface.co/datasets/SetFit/tweet_eval_stance/viewer/stance_hillary\n",
"stance_dataset = empirical.EmpiricalDistribution(\n",
" \"Passage: If a man demanded staff to get him an ice tea he'd be called a sexists elitist pig.. Oink oink #Hillary #SemST\\nTarget: Hillary Clinton\\nStance: against\",\n",
" \"Passage: We're out here in G-town, and where are you #sctweets #SemST\\nTarget: Hillary Clinton\\nStance: none\",\n",
" \"Passage: If you're not watching @user speech right now you're missing her drop tons of wisdom. #SemST\\nTarget: Hillary Clinton\\nStance: favor\",\n",
")\n",
"stance_lm = base_lm.fit(stance_dataset, task_description=\"Predict the stance of a passage with respect to the target.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Classify stance in various different passages\n",
"answer_distances = {}\n",
"distance_function = L1Distance()\n",
"\n",
"news_source_names = [\"cnn\", \"wsj\", \"fox\", \"npr\"]\n",
"party_names = [\"Democratic\", \"Republican\"]\n",
"for news_source in news_source_names:\n",
" # Load the dataset (empirical distribution) for this source\n",
" news_dataset = empirical.load_from_text_file(f'news_data_{news_source}.txt')\n",
" # Enumerate the entire news dataset\n",
" for news_datapoint in news_dataset:\n",
" for party_name in party_names:\n",
" probs = stance_lm.log_probability([\"favor\", \"against\", \"none\"], f\"Passage: {news_datapoint}\\nTarget: {party_name}\\nStance: \")\n",
" raise NotImplementedError(\"Not finished yet.\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
25 changes: 19 additions & 6 deletions llments/lm/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@

from llments.lm.lm import LanguageModel


class HuggingFaceLM(LanguageModel):
    """Hugging Face implementation of ``LanguageModel``.

    Nothing is implemented yet: every method raises ``NotImplementedError``.
    """

    def generate(
        self,
        condition: str | None,
        **kwargs,
    ) -> str:
        """Generate from the language model, possibly conditioned on a prefix.

        Args:
            condition: Prefix to condition the generation on, or ``None`` for
                no conditioning.
            **kwargs: Extra generation options; none are interpreted yet.

        Returns:
            The generated text.

        Raises:
            NotImplementedError: Always; generation is not implemented yet.
        """
        raise NotImplementedError("This is not implemented yet.")

    def sample(
        self,
        condition: str | None,
        **kwargs,
    ) -> str:
        """Generate from the language model (older name for ``generate``).

        Kept so that existing callers of ``sample`` keep working; new code
        should call ``generate``, which is the name the base class declares.

        Args:
            condition: Prefix to condition the generation on, or ``None`` for
                no conditioning.
            **kwargs: Forwarded unchanged to ``generate``.

        Returns:
            The generated text.

        Raises:
            NotImplementedError: Always; generation is not implemented yet.
        """
        return self.generate(condition, **kwargs)

    def fit(
        self, target: LanguageModel, task_description: str | None = None
    ) -> LanguageModel:
        """Fit the language model to a target language model's distribution.

        Args:
            target: The language model that should be fitted to.
            task_description: A task description (a prompt) that explains what
                the fitted language model is supposed to be doing.

        Returns:
            The fitted language model.

        Raises:
            NotImplementedError: Always; fitting is not implemented yet.
        """
        raise NotImplementedError("This is not implemented yet.")


Expand Down
24 changes: 19 additions & 5 deletions llments/lm/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,27 @@


class LanguageModel:

@abc.abstractclassmethod
def sample(self, condition: str | None) -> str:
"""Sample from the language model, possibly conditioned on a prefix."""
def generate(
self,
condition: str | None,
**kwargs,
) -> str:
"""Generate from the language model, possibly conditioned on a prefix."""
...

@abc.abstractclassmethod
def fit(self, target: "LanguageModel") -> "LanguageModel":
"""Fit the language model to a target language model's distribution."""
def fit(
self, target: "LanguageModel", task_description: str | None = None
) -> "LanguageModel":
"""Fit the language model to a target language model's distribution.
Args:
target: The language model that should be fitted to.
task_description: A task description that explains more about
what the language model that should be fit is doing (a prompt).
Returns:
The fitted language model.
"""
...

0 comments on commit 87c909a

Please sign in to comment.