Commit 5ca2858

Fix bugs in API tools (to fix #348) (#349)
* Use litellm.utils.get_max_tokens appropriately
* Set max_tokens appropriately
* Use 16k model for HyDE
* Use updated max_tokens logic in async call
* Reduce context size needed for HyDE
1 parent 7223bf8 commit 5ca2858

File tree

2 files changed: +17 −192 lines changed

prompt2model/model_retriever/generate_hypothetical_document.py

Lines changed: 3 additions & 188 deletions
````diff
@@ -7,7 +7,7 @@
 from prompt2model.prompt_parser import PromptSpec
 from prompt2model.utils import API_ERRORS, api_tools, handle_api_error

-PROMPT_PREFIX = '''HuggingFace contains models, which are each given a user-generated description. The first section of the description, delimited with two "---" lines, consists of a YAML description of the model. This may contain fields like "language" (supported by model), "datasets" (used to train the model), "tags" (e.g. tasks relevant to the model), and "metrics" (used to evaluate the model). Create a hypothetical HuggingFace model description that would satisfy a given user instruction. Here are some examples:
+PROMPT_PREFIX = """HuggingFace contains models, which are each given a user-generated description. The first section of the description, delimited with two "---" lines, consists of a YAML description of the model. This may contain fields like "language" (supported by model), "datasets" (used to train the model), "tags" (e.g. tasks relevant to the model), and "metrics" (used to evaluate the model). Create a hypothetical HuggingFace model description that would satisfy a given user instruction. Here are some examples:

 Instruction: "Give me some translation from English to Vietnamese. Input English and output Vietnamese."
 Hypothetical model description:
````
````diff
@@ -70,10 +70,6 @@

 - prepro: normalization + SentencePiece (spm32k,spm32k)

-- url_model: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.zip
-
-- url_test_set: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.test.txt
-
 - src_alpha3: eng

 - tgt_alpha3: vie
````
````diff
@@ -102,14 +98,6 @@

 - long_pair: eng-vie

-- helsinki_git_sha: 480fcbe0ee1bf4774bcbe6226ad9f58e63f6c535
-
-- transformers_git_sha: 2207e5d8cb224e954a7cba69fa4ac2309e9ff30b
-
-- port_machine: brutasse
-
-- port_time: 2020-08-21-14:41
-
 Instruction: "I want to summarize things like news articles."
 Hypothetical model description:
````
````diff
@@ -144,58 +132,11 @@
 type: rouge
 value: 23.8845
 verified: true
-- name: ROUGE-LSUM
-type: rouge
-value: 32.9017
-verified: true
-- name: loss
-type: loss
-value: 2.5757133960723877
-verified: true
-- name: gen_len
-type: gen_len
-value: 76.3984
-verified: true
 ---

 ## Model description
 [PEGASUS](https://github.com/google-research/pegasus) fine-tuned for summarization

-## Install "sentencepiece" library required for tokenizer
-```
-pip install sentencepiece
-```
-
-## Model in Action 🚀
-```
-import torch
-from transformers import PegasusForConditionalGeneration, PegasusTokenizer
-model_name = 'tuner007/pegasus_summarizer'
-torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-tokenizer = PegasusTokenizer.from_pretrained(model_name)
-model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
-
-def get_response(input_text):
-  batch = tokenizer([input_text],truncation=True,padding='longest',max_length=1024, return_tensors="pt").to(torch_device)
-  gen_out = model.generate(**batch,max_length=128,num_beams=5, num_return_sequences=1, temperature=1.5)
-  output_text = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
-  return output_text
-```
-#### Example:
-context = """"
-India wicket-keeper batsman Rishabh Pant has said someone from the crowd threw a ball on pacer Mohammed Siraj while he was fielding in the ongoing third Test against England on Wednesday. Pant revealed the incident made India skipper Virat Kohli "upset". "I think, somebody threw a ball inside, at Siraj, so he [Kohli] was upset," said Pant in a virtual press conference after the close of the first day\'s play."You can say whatever you want to chant, but don\'t throw things at the fielders and all those things. It is not good for cricket, I guess," he added.In the third session of the opening day of the third Test, a section of spectators seemed to have asked Siraj the score of the match to tease the pacer. The India pacer however came with a brilliant reply as he gestured 1-0 (India leading the Test series) towards the crowd.Earlier this month, during the second Test match, there was some bad crowd behaviour on a show as some unruly fans threw champagne corks at India batsman KL Rahul.Kohli also intervened and he was seen gesturing towards the opening batsman to know more about the incident. An over later, the TV visuals showed that many champagne corks were thrown inside the playing field, and the Indian players were visibly left frustrated.Coming back to the game, after bundling out India for 78, openers Rory Burns and Haseeb Hameed ensured that England took the honours on the opening day of the ongoing third Test.At stumps, England\'s score reads 120/0 and the hosts have extended their lead to 42 runs. For the Three Lions, Burns (52*) and Hameed (60*) are currently unbeaten at the crease.Talking about the pitch on opening day, Pant said, "They took the heavy roller, the wicket was much more settled down, and they batted nicely also," he said. "But when we batted, the wicket was slightly soft, and they bowled in good areas, but we could have applied [ourselves] much better."Both England batsmen managed to see off the final session and the hosts concluded the opening day with all ten wickets intact, extending the lead to 42.(ANI)
-"""
-
-```
-get_response(context)
-```
-#### Output:
-Team India wicketkeeper-batsman Rishabh Pant has said that Virat Kohli was "upset" after someone threw a ball on pacer Mohammed Siraj while he was fielding in the ongoing third Test against England. "You can say whatever you want to chant, but don't throw things at the fielders and all those things. It's not good for cricket, I guess," Pant added.'
-
-#### [Inshort](https://www.inshorts.com/) (60 words News summary app, rated 4.4 by 5,27,246+ users on android playstore) summary:
-India wicketkeeper-batsman Rishabh Pant has revealed that captain Virat Kohli was upset with the crowd during the first day of Leeds Test against England because someone threw a ball at pacer Mohammed Siraj. Pant added, "You can say whatever you want to chant, but don't throw things at the fielders and all those things. It is not good for cricket."
-
-
 > Created by [Arpit Rajauria](https://twitter.com/arpit_rajauria)
 [![Twitter icon](https://cdn0.iconfinder.com/data/icons/shift-logotypes/32/Twitter-32.png)](https://twitter.com/arpit_rajauria)
````
````diff
@@ -233,95 +174,6 @@ def get_response(input_text):
 name: Accuracy
 verified: true
 verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiN2YyOGMxYjY2Y2JhMjkxNjIzN2FmMjNiNmM2ZWViNGY3MTNmNWI2YzhiYjYxZTY0ZGUyN2M1NGIxZjRiMjQwZiIsInZlcnNpb24iOjF9.uui0srxV5ZHRhxbYN6082EZdwpnBgubPJ5R2-Wk8HTWqmxYE3QHidevR9LLAhidqGw6Ih93fK0goAXncld_gBg
-- type: precision
-value: 0.8978260869565218
-name: Precision
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMzgwYTYwYjA2MmM0ZTYwNDk0M2NmNTBkZmM2NGNhYzQ1OGEyN2NkNDQ3Mzc2NTQyMmZiNDJiNzBhNGVhZGUyOSIsInZlcnNpb24iOjF9.eHjLmw3K02OU69R2Au8eyuSqT3aBDHgZCn8jSzE3_urD6EUSSsLxUpiAYR4BGLD_U6-ZKcdxVo_A2rdXqvUJDA
-- type: recall
-value: 0.9301801801801802
-name: Recall
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMGIzM2E3MTI2Mzc2MDYwNmU3ZTVjYmZmZDBkNjY4ZTc5MGY0Y2FkNDU3NjY1MmVkNmE3Y2QzMzAwZDZhOWY1NiIsInZlcnNpb24iOjF9.PUZlqmct13-rJWBXdHm5tdkXgETL9F82GNbbSR4hI8MB-v39KrK59cqzFC2Ac7kJe_DtOeUyosj34O_mFt_1DQ
-- type: auc
-value: 0.9716626673402374
-name: AUC
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMDM0YWIwZmQ4YjUwOGZmMWU2MjI1YjIxZGQ2MzNjMzRmZmYxMzZkNGFjODhlMDcyZDM1Y2RkMWZlOWQ0MWYwNSIsInZlcnNpb24iOjF9.E7GRlAXmmpEkTHlXheVkuL1W4WNjv4JO3qY_WCVsTVKiO7bUu0UVjPIyQ6g-J1OxsfqZmW3Leli1wY8vPBNNCQ
-- type: f1
-value: 0.9137168141592922
-name: F1
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMGU4MjNmOGYwZjZjMDQ1ZTkyZTA4YTc1MWYwOTM0NDM4ZWY1ZGVkNDY5MzNhYTQyZGFlNzIyZmUwMDg3NDU0NyIsInZlcnNpb24iOjF9.mW5ftkq50Se58M-jm6a2Pu93QeKa3MfV7xcBwvG3PSB_KNJxZWTCpfMQp-Cmx_EMlmI2siKOyd8akYjJUrzJCA
-- type: loss
-value: 0.39013850688934326
-name: loss
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMTZiNzAyZDc0MzUzMmE1MGJiN2JlYzFiODE5ZTNlNGE4MmI4YzRiMTc2ODEzMTUwZmEzOTgxNzc4YjJjZTRmNiIsInZlcnNpb24iOjF9.VqIC7uYC-ZZ8ss9zQOlRV39YVOOLc5R36sIzCcVz8lolh61ux_5djm2XjpP6ARc6KqEnXC4ZtfNXsX2HZfrtCQ
-- task:
-type: text-classification
-name: Text Classification
-dataset:
-name: sst2
-type: sst2
-config: default
-split: train
-metrics:
-- type: accuracy
-value: 0.9885521685548412
-name: Accuracy
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiY2I3NzU3YzhmMDkxZTViY2M3OTY1NmI0ZTdmMDQxNjNjYzJiZmQxNzczM2E4YmExYTY5ODY0NDBkY2I4ZjNkOCIsInZlcnNpb24iOjF9.4Gtk3FeVc9sPWSqZIaeUXJ9oVlPzm-NmujnWpK2y5s1Vhp1l6Y1pK5_78wW0-NxSvQqV6qd5KQf_OAEpVAkQDA
-- type: precision
-value: 0.9881965062029833
-name: Precision Macro
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZDdlZDMzY2I3MTAwYTljNmM4MGMyMzU2YjAzZDg1NDYwN2ZmM2Y5OWZhMjUyMGJiNjY1YmZiMzFhMDI2ODFhNyIsInZlcnNpb24iOjF9.cqmv6yBxu4St2mykRWrZ07tDsiSLdtLTz2hbqQ7Gm1rMzq9tdlkZ8MyJRxtME_Y8UaOG9rs68pV-gKVUs8wABw
-- type: precision
-value: 0.9885521685548412
-name: Precision Micro
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZjFlYzAzNmE1YjljNjUwNzBjZjEzZDY0ZDQyMmY5ZWM2OTBhNzNjYjYzYTk1YWE1NjU3YTMxZDQwOTE1Y2FkNyIsInZlcnNpb24iOjF9.jnCHOkUHuAOZZ_ZMVOnetx__OVJCS6LOno4caWECAmfrUaIPnPNV9iJ6izRO3sqkHRmxYpWBb-27GJ4N3LU-BQ
-- type: precision
-value: 0.9885639626373408
-name: Precision Weighted
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZGUyODFjNjBlNTE2MTY3ZDAxOGU1N2U0YjUyY2NiZjhkOGVmYThjYjBkNGU3NTRkYzkzNDQ2MmMwMjkwMWNiMyIsInZlcnNpb24iOjF9.zTNabMwApiZyXdr76QUn7WgGB7D7lP-iqS3bn35piqVTNsv3wnKjZOaKFVLIUvtBXq4gKw7N2oWxvWc4OcSNDg
-- type: recall
-value: 0.9886145346602994
-name: Recall Macro
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiNTU1YjlhODU3YTkyNTdiZDcwZGFlZDBiYjY0N2NjMGM2NTRiNjQ3MDNjNGMxOWY2ZGQ4NWU1YmMzY2UwZTI3YSIsInZlcnNpb24iOjF9.xaLPY7U-wHsJ3DDui1yyyM-xWjL0Jz5puRThy7fczal9x05eKEQ9s0a_WD-iLmapvJs0caXpV70hDe2NLcs-DA
-- type: recall
-value: 0.9885521685548412
-name: Recall Micro
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiODE0YTU0MDBlOGY4YzU0MjY5MzA3OTk2OGNhOGVkMmU5OGRjZmFiZWI2ZjY5ODEzZTQzMTI0N2NiOTVkNDliYiIsInZlcnNpb24iOjF9.SOt1baTBbuZRrsvGcak2sUwoTrQzmNCbyV2m1_yjGsU48SBH0NcKXicidNBSnJ6ihM5jf_Lv_B5_eOBkLfNWDQ
-- type: recall
-value: 0.9885521685548412
-name: Recall Weighted
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZWNkNmM0ZGRlNmYxYzIwNDk4OTI5MzIwZWU1NzZjZDVhMDcyNDFlMjBhNDQxODU5OWMwMWNhNGEzNjY3ZGUyOSIsInZlcnNpb24iOjF9.b15Fh70GwtlG3cSqPW-8VEZT2oy0CtgvgEOtWiYonOovjkIQ4RSLFVzVG-YfslaIyfg9RzMWzjhLnMY7Bpn2Aw
-- type: f1
-value: 0.9884019815052447
-name: F1 Macro
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiYmM4NjQ5Yjk5ODRhYTU1MTY3MmRhZDBmODM1NTg3OTFiNWM4NDRmYjI0MzZkNmQ1MzE3MzcxODZlYzBkYTMyYSIsInZlcnNpb24iOjF9.74RaDK8nBVuGRl2Se_-hwQvP6c4lvVxGHpcCWB4uZUCf2_HoC9NT9u7P3pMJfH_tK2cpV7U3VWGgSDhQDi-UBQ
-- type: f1
-value: 0.9885521685548412
-name: F1 Micro
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZDRmYWRmMmQ0YjViZmQxMzhhYTUyOTE1MTc0ZDU1ZjQyZjFhMDYzYzMzZDE0NzZlYzQyOTBhMTBhNmM5NTlkMiIsInZlcnNpb24iOjF9.VMn_psdAHIZTlW6GbjERZDe8MHhwzJ0rbjV_VJyuMrsdOh5QDmko-wEvaBWNEdT0cEKsbggm-6jd3Gh81PfHAQ
-- type: f1
-value: 0.9885546181087554
-name: F1 Weighted
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiMjUyZWFhZDZhMGQ3MzBmYmRiNDVmN2FkZDBjMjk3ODk0OTAxNGZkMWE0NzU5ZjI0NzE0NGZiNzM0N2Y2NDYyOSIsInZlcnNpb24iOjF9.YsXBhnzEEFEW6jw3mQlFUuIrW7Gabad2Ils-iunYJr-myg0heF8NEnEWABKFE1SnvCWt-69jkLza6SupeyLVCA
-- type: loss
-value: 0.040652573108673096
-name: loss
-verified: true
-verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiZTc3YjU3MjdjMzkxODA5MjU5NGUyY2NkMGVhZDg3ZWEzMmU1YWVjMmI0NmU2OWEyZTkzMTVjNDZiYTc0YjIyNCIsInZlcnNpb24iOjF9.lA90qXZVYiILHMFlr6t6H81Oe8a-4KmeX-vyCC1BDia2ofudegv6Vb46-4RzmbtuKeV6yy6YNNXxXxqVak1pAg
 ---

 # DistilBERT base uncased finetuned SST-2
````
````diff
@@ -345,26 +197,6 @@ def get_response(input_text):
 - [Model Documentation](https://huggingface.co/docs/transformers/main/en/model_doc/distilbert#transformers.DistilBertForSequenceClassification)
 - [DistilBERT paper](https://arxiv.org/abs/1910.01108)

-## How to Get Started With the Model
-
-Example of single-label classification:
-
-```python
-import torch
-from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
-
-tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
-model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
-
-inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
-with torch.no_grad():
-    logits = model(**inputs).logits
-
-predicted_class_id = logits.argmax().item()
-model.config.id2label[predicted_class_id]
-
-```
-
 ## Uses

 #### Direct Use
````
````diff
@@ -379,13 +211,7 @@ def get_response(input_text):

 Based on a few experimentations, we observed that this model could produce biased predictions that target underrepresented populations.

-For instance, for sentences like `This film was filmed in COUNTRY`, this binary classification model will give radically different probabilities for the positive label depending on the country (0.89 if the country is France, but 0.08 if the country is Afghanistan) when nothing in the input indicates such a strong semantic shift. In this [colab](https://colab.research.google.com/gist/ageron/fb2f64fb145b4bc7c49efc97e5f114d3/biasmap.ipynb), [Aurélien Géron](https://twitter.com/aureliengeron) made an interesting map plotting these probabilities for each country.
-
-<img src="https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/map.jpeg" alt="Map of positive probabilities per country." width="500"/>
-
-We strongly advise users to thoroughly probe these aspects on their use-cases in order to evaluate the risks of this model. We recommend looking at the following bias evaluation datasets as a place to start: [WinoBias](https://huggingface.co/datasets/wino_bias), [WinoGender](https://huggingface.co/datasets/super_glue), [Stereoset](https://huggingface.co/datasets/stereoset).
-
-
+For instance, for sentences like `This film was filmed in COUNTRY`, this binary classification model will give radically different probabilities for the positive label depending on the country (0.89 if the country is France, but 0.08 if the country is Afghanistan) when nothing in the input indicates such a strong semantic shift.

 # Training
````
````diff
@@ -394,19 +220,8 @@ def get_response(input_text):


 The authors use the following Stanford Sentiment Treebank([sst2](https://huggingface.co/datasets/sst2)) corpora for the model.
-
-#### Training Procedure
-
-###### Fine-tuning hyper-parameters
-
-
-- learning_rate = 1e-5
-- batch_size = 32
-- warmup = 600
-- max_seq_length = 128
-- num_train_epochs = 3.0
 ```
-:''' # noqa: E501
+:""" # noqa: E501


 def generate_hypothetical_model_description(
````

prompt2model/utils/api_tools.py

Lines changed: 14 additions & 4 deletions
````diff
@@ -58,6 +58,8 @@ def __init__(
         if max_tokens is None:
             try:
                 self.max_tokens = litellm.utils.get_max_tokens(model_name)
+                if isinstance(self.max_tokens, dict):
+                    self.max_tokens = self.max_tokens["max_tokens"]
             except Exception:
                 pass
````
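The hunk above guards against `litellm.utils.get_max_tokens` returning either a plain int or a dict carrying a `max_tokens` key (the return shape has varied across litellm versions). A minimal standalone sketch of that coercion — `normalize_max_tokens` is a hypothetical helper name, not a function in the repo:

```python
def normalize_max_tokens(raw):
    """Coerce a context-window lookup result to a plain int.

    Some versions of the lookup return an int directly; others return a
    dict such as {"max_tokens": 16385, ...}. Mirrors the isinstance
    check added in this commit.
    """
    if isinstance(raw, dict):
        return raw["max_tokens"]
    return raw


print(normalize_max_tokens(4096))                   # int passes through
print(normalize_max_tokens({"max_tokens": 16385}))  # dict is unwrapped
```

Either shape ends up as an int, so downstream arithmetic on `self.max_tokens` stays valid.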

````diff
@@ -86,7 +88,12 @@
             An OpenAI-like response object if there were no errors in generation.
             In case of API-specific error, Exception object is captured and returned.
         """
-        max_tokens = self.max_tokens or 4 * count_tokens_from_string(prompt)
+        num_prompt_tokens = count_tokens_from_string(prompt)
+        if self.max_tokens:
+            max_tokens = self.max_tokens - num_prompt_tokens
+        else:
+            max_tokens = 4 * num_prompt_tokens
+
         response = completion(  # completion gets the key from os.getenv
             model=self.model_name,
             messages=[
````
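The key fix here: previously the whole context window (`self.max_tokens`) was passed as the completion budget, but that window must hold the prompt *and* the completion, so long prompts could push the request over the model's limit. The new logic subtracts the prompt length first, falling back to the old 4×-prompt heuristic when the window is unknown. A sketch of the arithmetic, with a hypothetical function name:

```python
def completion_budget(context_window, num_prompt_tokens):
    """Tokens available for the completion.

    context_window: the model's total context size in tokens, or None
    if unknown. When known, the completion may only use what the prompt
    leaves over; otherwise fall back to a 4x-prompt heuristic, as in
    the code above.
    """
    if context_window:
        return context_window - num_prompt_tokens
    return 4 * num_prompt_tokens


print(completion_budget(4096, 1000))  # known window: 4096 - 1000
print(completion_budget(None, 1000))  # unknown window: 4 * 1000
```

Note the sketch (like the patch) does not clamp at zero, so a prompt longer than the window would still yield a negative budget and an API error.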
````diff
@@ -169,9 +176,12 @@ async def _throttled_completion_acreate(
                 await asyncio.sleep(10)
                 return {"choices": [{"message": {"content": ""}}]}

-        max_tokens = self.max_tokens or 4 * max(
-            count_tokens_from_string(prompt) for prompt in prompts
-        )
+        num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
+        if self.max_tokens:
+            max_tokens = self.max_tokens - num_prompt_tokens
+        else:
+            max_tokens = 4 * num_prompt_tokens
+
         async_responses = [
             _throttled_completion_acreate(
                 model=self.model_name,
````
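The async path applies the same budget once for a whole batch of prompts, sizing it by the longest prompt so that every request in the batch fits inside the context window. A sketch under the same assumptions (hypothetical function name, not repo code):

```python
def batch_completion_budget(context_window, prompt_token_counts):
    """Shared completion budget for a batch of prompts.

    The longest prompt governs: subtracting its length from the context
    window guarantees every request in the batch fits. With no known
    window, fall back to the 4x heuristic on the longest prompt.
    """
    longest = max(prompt_token_counts)
    if context_window:
        return context_window - longest
    return 4 * longest


print(batch_completion_budget(16385, [100, 900, 400]))  # 16385 - 900
print(batch_completion_budget(None, [100, 900, 400]))   # 4 * 900
```

Using the maximum is conservative: shorter prompts in the batch get less completion room than they strictly need, in exchange for one `max_tokens` value shared across all the throttled requests.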
