Skip to content

Commit

Permalink
Fix bugs in API tools (to fix #348) (#349)
Browse files Browse the repository at this point in the history
* Use litellm.utils.get_max_tokens appropriatel

* Set max_tokens appropriately

* Use 16k model for HyDE

* Use updated max_tokens logic in async call

* Reduce context size needed for HyDE
  • Loading branch information
viswavi authored Sep 11, 2023
1 parent 7223bf8 commit 5ca2858
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 192 deletions.
191 changes: 3 additions & 188 deletions prompt2model/model_retriever/generate_hypothetical_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import API_ERRORS, api_tools, handle_api_error

PROMPT_PREFIX = '''HuggingFace contains models, which are each given a user-generated description. The first section of the description, delimited with two "---" lines, consists of a YAML description of the model. This may contain fields like "language" (supported by model), "datasets" (used to train the model), "tags" (e.g. tasks relevant to the model), and "metrics" (used to evaluate the model). Create a hypothetical HuggingFace model description that would satisfy a given user instruction. Here are some examples:
PROMPT_PREFIX = """HuggingFace contains models, which are each given a user-generated description. The first section of the description, delimited with two "---" lines, consists of a YAML description of the model. This may contain fields like "language" (supported by model), "datasets" (used to train the model), "tags" (e.g. tasks relevant to the model), and "metrics" (used to evaluate the model). Create a hypothetical HuggingFace model description that would satisfy a given user instruction. Here are some examples:
Instruction: "Give me some translation from English to Vietnamese. Input English and output Vietnamese."
Hypothetical model description:
Expand Down Expand Up @@ -70,10 +70,6 @@
- prepro: normalization + SentencePiece (spm32k,spm32k)
- url_model: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.zip
- url_test_set: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.test.txt
- src_alpha3: eng
- tgt_alpha3: vie
Expand Down Expand Up @@ -102,14 +98,6 @@
- long_pair: eng-vie
- helsinki_git_sha: 480fcbe0ee1bf4774bcbe6226ad9f58e63f6c535
- transformers_git_sha: 2207e5d8cb224e954a7cba69fa4ac2309e9ff30b
- port_machine: brutasse
- port_time: 2020-08-21-14:41
Instruction: "I want to summarize things like news articles."
Hypothetical model description:
Expand Down Expand Up @@ -144,58 +132,11 @@
type: rouge
value: 23.8845
verified: true
- name: ROUGE-LSUM
type: rouge
value: 32.9017
verified: true
- name: loss
type: loss
value: 2.5757133960723877
verified: true
- name: gen_len
type: gen_len
value: 76.3984
verified: true
---
## Model description
[PEGASUS](https://github.com/google-research/pegasus) fine-tuned for summarization
## Install "sentencepiece" library required for tokenizer
```
pip install sentencepiece
```
## Model in Action 🚀
```
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
model_name = 'tuner007/pegasus_summarizer'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
def get_response(input_text):
batch = tokenizer([input_text],truncation=True,padding='longest',max_length=1024, return_tensors="pt").to(torch_device)
gen_out = model.generate(**batch,max_length=128,num_beams=5, num_return_sequences=1, temperature=1.5)
output_text = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
return output_text
```
#### Example:
context = """"
India wicket-keeper batsman Rishabh Pant has said someone from the crowd threw a ball on pacer Mohammed Siraj while he was fielding in the ongoing third Test against England on Wednesday. Pant revealed the incident made India skipper Virat Kohli "upset". "I think, somebody threw a ball inside, at Siraj, so he [Kohli] was upset," said Pant in a virtual press conference after the close of the first day\'s play."You can say whatever you want to chant, but don\'t throw things at the fielders and all those things. It is not good for cricket, I guess," he added.In the third session of the opening day of the third Test, a section of spectators seemed to have asked Siraj the score of the match to tease the pacer. The India pacer however came with a brilliant reply as he gestured 1-0 (India leading the Test series) towards the crowd.Earlier this month, during the second Test match, there was some bad crowd behaviour on a show as some unruly fans threw champagne corks at India batsman KL Rahul.Kohli also intervened and he was seen gesturing towards the opening batsman to know more about the incident. An over later, the TV visuals showed that many champagne corks were thrown inside the playing field, and the Indian players were visibly left frustrated.Coming back to the game, after bundling out India for 78, openers Rory Burns and Haseeb Hameed ensured that England took the honours on the opening day of the ongoing third Test.At stumps, England\'s score reads 120/0 and the hosts have extended their lead to 42 runs. For the Three Lions, Burns (52*) and Hameed (60*) are currently unbeaten at the crease.Talking about the pitch on opening day, Pant said, "They took the heavy roller, the wicket was much more settled down, and they batted nicely also," he said. "But when we batted, the wicket was slightly soft, and they bowled in good areas, but we could have applied [ourselves] much better."Both England batsmen managed to see off the final session and the hosts concluded the opening day with all ten wickets intact, extending the lead to 42.(ANI)
"""
```
get_response(context)
```
#### Output:
Team India wicketkeeper-batsman Rishabh Pant has said that Virat Kohli was "upset" after someone threw a ball on pacer Mohammed Siraj while he was fielding in the ongoing third Test against England. "You can say whatever you want to chant, but don't throw things at the fielders and all those things. It's not good for cricket, I guess," Pant added.'
#### [Inshort](https://www.inshorts.com/) (60 words News summary app, rated 4.4 by 5,27,246+ users on android playstore) summary:
India wicketkeeper-batsman Rishabh Pant has revealed that captain Virat Kohli was upset with the crowd during the first day of Leeds Test against England because someone threw a ball at pacer Mohammed Siraj. Pant added, "You can say whatever you want to chant, but don't throw things at the fielders and all those things. It is not good for cricket."
> Created by [Arpit Rajauria](https://twitter.com/arpit_rajauria)
[![Twitter icon](https://cdn0.iconfinder.com/data/icons/shift-logotypes/32/Twitter-32.png)](https://twitter.com/arpit_rajauria)
Expand Down Expand Up @@ -233,95 +174,6 @@ def get_response(input_text):
name: Accuracy
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiN2YyOGMxYjY2Y2JhMjkxNjIzN2FmMjNiNmM2ZWViNGY3MTNmNWI2YzhiYjYxZTY0ZGUyN2M1NGIxZjRiMjQwZiIsInZlcnNpb24iOjF9.uui0srxV5ZHRhxbYN6082EZdwpnBgubPJ5R2-Wk8HTWqmxYE3QHidevR9LLAhidqGw6Ih93fK0goAXncld_gBg
- type: precision
value: 0.8978260869565218
name: Precision
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMzgwYTYwYjA2MmM0ZTYwNDk0M2NmNTBkZmM2NGNhYzQ1OGEyN2NkNDQ3Mzc2NTQyMmZiNDJiNzBhNGVhZGUyOSIsInZlcnNpb24iOjF9.eHjLmw3K02OU69R2Au8eyuSqT3aBDHgZCn8jSzE3_urD6EUSSsLxUpiAYR4BGLD_U6-ZKcdxVo_A2rdXqvUJDA
- type: recall
value: 0.9301801801801802
name: Recall
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMGIzM2E3MTI2Mzc2MDYwNmU3ZTVjYmZmZDBkNjY4ZTc5MGY0Y2FkNDU3NjY1MmVkNmE3Y2QzMzAwZDZhOWY1NiIsInZlcnNpb24iOjF9.PUZlqmct13-rJWBXdHm5tdkXgETL9F82GNbbSR4hI8MB-v39KrK59cqzFC2Ac7kJe_DtOeUyosj34O_mFt_1DQ
- type: auc
value: 0.9716626673402374
name: AUC
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMDM0YWIwZmQ4YjUwOGZmMWU2MjI1YjIxZGQ2MzNjMzRmZmYxMzZkNGFjODhlMDcyZDM1Y2RkMWZlOWQ0MWYwNSIsInZlcnNpb24iOjF9.E7GRlAXmmpEkTHlXheVkuL1W4WNjv4JO3qY_WCVsTVKiO7bUu0UVjPIyQ6g-J1OxsfqZmW3Leli1wY8vPBNNCQ
- type: f1
value: 0.9137168141592922
name: F1
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMGU4MjNmOGYwZjZjMDQ1ZTkyZTA4YTc1MWYwOTM0NDM4ZWY1ZGVkNDY5MzNhYTQyZGFlNzIyZmUwMDg3NDU0NyIsInZlcnNpb24iOjF9.mW5ftkq50Se58M-jm6a2Pu93QeKa3MfV7xcBwvG3PSB_KNJxZWTCpfMQp-Cmx_EMlmI2siKOyd8akYjJUrzJCA
- type: loss
value: 0.39013850688934326
name: loss
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMTZiNzAyZDc0MzUzMmE1MGJiN2JlYzFiODE5ZTNlNGE4MmI4YzRiMTc2ODEzMTUwZmEzOTgxNzc4YjJjZTRmNiIsInZlcnNpb24iOjF9.VqIC7uYC-ZZ8ss9zQOlRV39YVOOLc5R36sIzCcVz8lolh61ux_5djm2XjpP6ARc6KqEnXC4ZtfNXsX2HZfrtCQ
- task:
type: text-classification
name: Text Classification
dataset:
name: sst2
type: sst2
config: default
split: train
metrics:
- type: accuracy
value: 0.9885521685548412
name: Accuracy
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiY2I3NzU3YzhmMDkxZTViY2M3OTY1NmI0ZTdmMDQxNjNjYzJiZmQxNzczM2E4YmExYTY5ODY0NDBkY2I4ZjNkOCIsInZlcnNpb24iOjF9.4Gtk3FeVc9sPWSqZIaeUXJ9oVlPzm-NmujnWpK2y5s1Vhp1l6Y1pK5_78wW0-NxSvQqV6qd5KQf_OAEpVAkQDA
- type: precision
value: 0.9881965062029833
name: Precision Macro
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZDdlZDMzY2I3MTAwYTljNmM4MGMyMzU2YjAzZDg1NDYwN2ZmM2Y5OWZhMjUyMGJiNjY1YmZiMzFhMDI2ODFhNyIsInZlcnNpb24iOjF9.cqmv6yBxu4St2mykRWrZ07tDsiSLdtLTz2hbqQ7Gm1rMzq9tdlkZ8MyJRxtME_Y8UaOG9rs68pV-gKVUs8wABw
- type: precision
value: 0.9885521685548412
name: Precision Micro
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZjFlYzAzNmE1YjljNjUwNzBjZjEzZDY0ZDQyMmY5ZWM2OTBhNzNjYjYzYTk1YWE1NjU3YTMxZDQwOTE1Y2FkNyIsInZlcnNpb24iOjF9.jnCHOkUHuAOZZ_ZMVOnetx__OVJCS6LOno4caWECAmfrUaIPnPNV9iJ6izRO3sqkHRmxYpWBb-27GJ4N3LU-BQ
- type: precision
value: 0.9885639626373408
name: Precision Weighted
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZGUyODFjNjBlNTE2MTY3ZDAxOGU1N2U0YjUyY2NiZjhkOGVmYThjYjBkNGU3NTRkYzkzNDQ2MmMwMjkwMWNiMyIsInZlcnNpb24iOjF9.zTNabMwApiZyXdr76QUn7WgGB7D7lP-iqS3bn35piqVTNsv3wnKjZOaKFVLIUvtBXq4gKw7N2oWxvWc4OcSNDg
- type: recall
value: 0.9886145346602994
name: Recall Macro
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiNTU1YjlhODU3YTkyNTdiZDcwZGFlZDBiYjY0N2NjMGM2NTRiNjQ3MDNjNGMxOWY2ZGQ4NWU1YmMzY2UwZTI3YSIsInZlcnNpb24iOjF9.xaLPY7U-wHsJ3DDui1yyyM-xWjL0Jz5puRThy7fczal9x05eKEQ9s0a_WD-iLmapvJs0caXpV70hDe2NLcs-DA
- type: recall
value: 0.9885521685548412
name: Recall Micro
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiODE0YTU0MDBlOGY4YzU0MjY5MzA3OTk2OGNhOGVkMmU5OGRjZmFiZWI2ZjY5ODEzZTQzMTI0N2NiOTVkNDliYiIsInZlcnNpb24iOjF9.SOt1baTBbuZRrsvGcak2sUwoTrQzmNCbyV2m1_yjGsU48SBH0NcKXicidNBSnJ6ihM5jf_Lv_B5_eOBkLfNWDQ
- type: recall
value: 0.9885521685548412
name: Recall Weighted
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZWNkNmM0ZGRlNmYxYzIwNDk4OTI5MzIwZWU1NzZjZDVhMDcyNDFlMjBhNDQxODU5OWMwMWNhNGEzNjY3ZGUyOSIsInZlcnNpb24iOjF9.b15Fh70GwtlG3cSqPW-8VEZT2oy0CtgvgEOtWiYonOovjkIQ4RSLFVzVG-YfslaIyfg9RzMWzjhLnMY7Bpn2Aw
- type: f1
value: 0.9884019815052447
name: F1 Macro
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiYmM4NjQ5Yjk5ODRhYTU1MTY3MmRhZDBmODM1NTg3OTFiNWM4NDRmYjI0MzZkNmQ1MzE3MzcxODZlYzBkYTMyYSIsInZlcnNpb24iOjF9.74RaDK8nBVuGRl2Se_-hwQvP6c4lvVxGHpcCWB4uZUCf2_HoC9NT9u7P3pMJfH_tK2cpV7U3VWGgSDhQDi-UBQ
- type: f1
value: 0.9885521685548412
name: F1 Micro
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZDRmYWRmMmQ0YjViZmQxMzhhYTUyOTE1MTc0ZDU1ZjQyZjFhMDYzYzMzZDE0NzZlYzQyOTBhMTBhNmM5NTlkMiIsInZlcnNpb24iOjF9.VMn_psdAHIZTlW6GbjERZDe8MHhwzJ0rbjV_VJyuMrsdOh5QDmko-wEvaBWNEdT0cEKsbggm-6jd3Gh81PfHAQ
- type: f1
value: 0.9885546181087554
name: F1 Weighted
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMjUyZWFhZDZhMGQ3MzBmYmRiNDVmN2FkZDBjMjk3ODk0OTAxNGZkMWE0NzU5ZjI0NzE0NGZiNzM0N2Y2NDYyOSIsInZlcnNpb24iOjF9.YsXBhnzEEFEW6jw3mQlFUuIrW7Gabad2Ils-iunYJr-myg0heF8NEnEWABKFE1SnvCWt-69jkLza6SupeyLVCA
- type: loss
value: 0.040652573108673096
name: loss
verified: true
verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZTc3YjU3MjdjMzkxODA5MjU5NGUyY2NkMGVhZDg3ZWEzMmU1YWVjMmI0NmU2OWEyZTkzMTVjNDZiYTc0YjIyNCIsInZlcnNpb24iOjF9.lA90qXZVYiILHMFlr6t6H81Oe8a-4KmeX-vyCC1BDia2ofudegv6Vb46-4RzmbtuKeV6yy6YNNXxXxqVak1pAg
---
# DistilBERT base uncased finetuned SST-2
Expand All @@ -345,26 +197,6 @@ def get_response(input_text):
- [Model Documentation](https://huggingface.co/docs/transformers/main/en/model_doc/distilbert#transformers.DistilBertForSequenceClassification)
- [DistilBERT paper](https://arxiv.org/abs/1910.01108)
## How to Get Started With the Model
Example of single-label classification:
​​
```python
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
```
## Uses
#### Direct Use
Expand All @@ -379,13 +211,7 @@ def get_response(input_text):
Based on a few experimentations, we observed that this model could produce biased predictions that target underrepresented populations.
For instance, for sentences like `This film was filmed in COUNTRY`, this binary classification model will give radically different probabilities for the positive label depending on the country (0.89 if the country is France, but 0.08 if the country is Afghanistan) when nothing in the input indicates such a strong semantic shift. In this [colab](https://colab.research.google.com/gist/ageron/fb2f64fb145b4bc7c49efc97e5f114d3/biasmap.ipynb), [Aurélien Géron](https://twitter.com/aureliengeron) made an interesting map plotting these probabilities for each country.
<img src="https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/map.jpeg" alt="Map of positive probabilities per country." width="500"/>
We strongly advise users to thoroughly probe these aspects on their use-cases in order to evaluate the risks of this model. We recommend looking at the following bias evaluation datasets as a place to start: [WinoBias](https://huggingface.co/datasets/wino_bias), [WinoGender](https://huggingface.co/datasets/super_glue), [Stereoset](https://huggingface.co/datasets/stereoset).
For instance, for sentences like `This film was filmed in COUNTRY`, this binary classification model will give radically different probabilities for the positive label depending on the country (0.89 if the country is France, but 0.08 if the country is Afghanistan) when nothing in the input indicates such a strong semantic shift.
# Training
Expand All @@ -394,19 +220,8 @@ def get_response(input_text):
The authors use the following Stanford Sentiment Treebank([sst2](https://huggingface.co/datasets/sst2)) corpora for the model.
#### Training Procedure
###### Fine-tuning hyper-parameters
- learning_rate = 1e-5
- batch_size = 32
- warmup = 600
- max_seq_length = 128
- num_train_epochs = 3.0
```
:''' # noqa: E501
:""" # noqa: E501


def generate_hypothetical_model_description(
Expand Down
18 changes: 14 additions & 4 deletions prompt2model/utils/api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(
if max_tokens is None:
try:
self.max_tokens = litellm.utils.get_max_tokens(model_name)
if isinstance(self.max_tokens, dict):
self.max_tokens = self.max_tokens["max_tokens"]
except Exception:
pass

Expand Down Expand Up @@ -86,7 +88,12 @@ def generate_one_completion(
An OpenAI-like response object if there were no errors in generation.
In case of API-specific error, Exception object is captured and returned.
"""
max_tokens = self.max_tokens or 4 * count_tokens_from_string(prompt)
num_prompt_tokens = count_tokens_from_string(prompt)
if self.max_tokens:
max_tokens = self.max_tokens - num_prompt_tokens
else:
max_tokens = 4 * num_prompt_tokens

response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
Expand Down Expand Up @@ -169,9 +176,12 @@ async def _throttled_completion_acreate(
await asyncio.sleep(10)
return {"choices": [{"message": {"content": ""}}]}

max_tokens = self.max_tokens or 4 * max(
count_tokens_from_string(prompt) for prompt in prompts
)
num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
if self.max_tokens:
max_tokens = self.max_tokens - num_prompt_tokens
else:
max_tokens = 4 * num_prompt_tokens

async_responses = [
_throttled_completion_acreate(
model=self.model_name,
Expand Down

0 comments on commit 5ca2858

Please sign in to comment.