Skip to content

Commit aa54757

Browse files
authored
Merge pull request #70 from cubenlp/rex/change-param
change params
2 parents ee2b2a2 + c4a2656 commit aa54757

File tree

6 files changed

+45
-50
lines changed

6 files changed

+45
-50
lines changed

chattool/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
__author__ = """Rex Wang"""
44
__email__ = '1073853456@qq.com'
5-
__version__ = '3.0.0'
5+
__version__ = '3.0.1'
66

77
import os, sys, requests
88
from .chattool import Chat, Resp
@@ -129,7 +129,7 @@ def debug_log( net_url:str="https://www.baidu.com"
129129
if test_response:
130130
print("\nTest response:", message)
131131
chat = Chat(message)
132-
chat.getresponse(max_requests=3)
132+
chat.getresponse(max_tries=3)
133133
chat.print_log()
134134

135135
print("\nDebug is finished.")

chattool/asynctool.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ async def async_post( session
1010
, url
1111
, data:str
1212
, headers:Dict
13-
, max_requests:int=1
13+
, max_tries:int=1
1414
, timeinterval=0
1515
, timeout=0):
1616
"""Asynchronous post request
@@ -21,7 +21,7 @@ async def async_post( session
2121
url (str): chat completion url
2222
data (str): payload of the request
2323
headers (Dict): request headers
24-
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
24+
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
2525
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
2626
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
2727
@@ -30,15 +30,15 @@ async def async_post( session
3030
"""
3131
async with sem:
3232
ntries = 0
33-
while max_requests > 0:
33+
while max_tries > 0:
3434
try:
3535
async with session.post(url, headers=headers, data=data, timeout=timeout) as response:
3636
resp = await response.text()
3737
resp = Resp(json.loads(resp))
3838
assert resp.is_valid(), resp.error_message
3939
return resp
4040
except Exception as e:
41-
max_requests -= 1
41+
max_tries -= 1
4242
ntries += 1
4343
time.sleep(random.random() * timeinterval)
4444
print(f"Request Failed({ntries}):{e}")
@@ -50,11 +50,10 @@ async def async_process_msgs( chatlogs:List[List[Dict]]
5050
, chkpoint:str
5151
, api_key:str
5252
, chat_url:str
53-
, max_requests:int=1
54-
, ncoroutines:int=1
53+
, max_tries:int=1
54+
, nproc:int=1
5555
, timeout:int=0
5656
, timeinterval:int=0
57-
, max_tokens:Union[Callable, None]=None
5857
, **options
5958
)->List[bool]:
6059
"""Process messages asynchronously
@@ -63,38 +62,36 @@ async def async_process_msgs( chatlogs:List[List[Dict]]
6362
chatlogs (List[List[Dict]]): list of chat logs
6463
chkpoint (str): checkpoint file
6564
api_key (Union[str, None], optional): API key. Defaults to None.
66-
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
67-
ncoroutines (int, optional): number of coroutines. Defaults to 5.
65+
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
66+
nproc (int, optional): number of coroutines. Defaults to 5.
6867
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
6968
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
7069
7170
Returns:
7271
List[bool]: list of responses
7372
"""
7473
# load from checkpoint
75-
chats = load_chats(chkpoint, withid=True) if os.path.exists(chkpoint) else []
74+
chats = load_chats(chkpoint) if os.path.exists(chkpoint) else []
7675
chats.extend([None] * (len(chatlogs) - len(chats)))
7776
costs = [0] * len(chatlogs)
7877
headers = {
7978
"Content-Type": "application/json",
8079
"Authorization": "Bearer " + api_key
8180
}
82-
ncoroutines += 1 # add one for the main coroutine
83-
sem = asyncio.Semaphore(ncoroutines)
81+
nproc += 1 # add one for the main coroutine
82+
sem = asyncio.Semaphore(nproc)
8483
locker = asyncio.Lock()
8584

8685
async def chat_complete(ind, locker, chat_log, chkpoint, **options):
8786
payload = {"messages": chat_log}
8887
payload.update(options)
89-
if max_tokens is not None:
90-
payload['max_tokens'] = max_tokens(chat_log)
9188
data = json.dumps(payload)
9289
resp = await async_post( session=session
9390
, sem=sem
9491
, url=chat_url
9592
, data=data
9693
, headers=headers
97-
, max_requests=max_requests
94+
, max_tries=max_tries
9895
, timeinterval=timeinterval
9996
, timeout=timeout)
10097
## saving files
@@ -130,16 +127,16 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
130127
, model:str='gpt-3.5-turbo'
131128
, api_key:Union[str, None]=None
132129
, chat_url:Union[str, None]=None
133-
, max_requests:int=1
134-
, ncoroutines:int=1
130+
, max_tries:int=1
135131
, nproc:int=1
136132
, timeout:int=0
137133
, timeinterval:int=0
138134
, clearfile:bool=False
139135
, notrun:bool=False
140136
, msg2log:Union[Callable, None]=None
141137
, data2chat:Union[Callable, None]=None
142-
, max_tokens:Union[Callable, int, None]=None
138+
, max_requests:int=-1
139+
, ncoroutines:int=1
143140
, **options
144141
):
145142
"""Asynchronous chat completion
@@ -149,8 +146,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
149146
chkpoint (str): checkpoint file
150147
model (str, optional): model to use. Defaults to 'gpt-3.5-turbo'.
151148
api_key (Union[str, None], optional): API key. Defaults to None.
152-
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
153-
ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1.
149+
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
154150
nproc (int, optional): number of coroutines. Defaults to 1.
155151
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
156152
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
@@ -161,8 +157,8 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
161157
Defaults to None.
162158
data2chat (Union[Callable, None], optional): function to convert data to Chat object.
163159
Defaults to None.
164-
max_tokens (Union[Callable, int, None], optional): function to calculate the maximum
165-
number of tokens for the API call. Defaults to None.
160+
max_requests (int, optional): (Deprecated)maximum number of requests to make. Defaults to -1.
161+
ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1.
166162
167163
Returns:
168164
List[Dict]: list of responses
@@ -184,20 +180,18 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
184180
chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
185181
chat_url = chattool.request.normalize_url(chat_url)
186182
# run async process
187-
assert ncoroutines > 0, "ncoroutines must be greater than 0!"
188-
if isinstance(max_tokens, int):
189-
max_tokens = lambda chat_log: max_tokens
183+
assert nproc > 0, "nproc must be greater than 0!"
184+
max_tries = max(max_tries, max_requests)
190185
args = {
191186
"chatlogs": chatlogs,
192187
"chkpoint": chkpoint,
193188
"api_key": api_key,
194189
"chat_url": chat_url,
195-
"max_requests": max_requests,
196-
"ncoroutines": nproc,
190+
"max_tries": max_tries,
191+
"nproc": nproc,
197192
"timeout": timeout,
198193
"timeinterval": timeinterval,
199194
"model": model,
200-
"max_tokens": max_tokens,
201195
**options
202196
}
203197
if notrun: # when use in Jupyter Notebook

chattool/chattool.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,22 @@ def print_log(self, sep: Union[str, None]=None):
170170

171171
# Part2: response and async response
172172
def getresponse( self
173-
, max_requests:int=1
173+
, max_tries:int = 1
174174
, timeout:int = 0
175175
, timeinterval:int = 0
176176
, update:bool = True
177177
, stream:bool = False
178+
, max_requests:int=-1
178179
, **options)->Resp:
179180
"""Get the API response
180181
181182
Args:
182-
max_requests (int, optional): maximum number of requests to make. Defaults to 1.
183+
max_tries (int, optional): maximum number of requests to make. Defaults to 1.
183184
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
184185
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
185186
update (bool, optional): whether to update the chat log. Defaults to True.
186187
options (dict, optional): other options like `temperature`, `top_p`, etc.
188+
max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit
187189
188190
Returns:
189191
Resp: API response
@@ -194,10 +196,11 @@ def getresponse( self
194196
func_call = options.get('function_call', self.function_call)
195197
if api_key is None: warnings.warn("API key is not set!")
196198
msg, resp, numoftries = self.chat_log, None, 0
199+
max_tries = max(max_tries, max_requests)
197200
if stream: # TODO: add the `usage` key to the response
198201
warnings.warn("stream mode is not supported yet! Use `async_stream_responses()` instead.")
199202
# make requests
200-
while max_requests:
203+
while max_tries:
201204
try:
202205
# make API Call
203206
if funcs is not None: options['functions'] = funcs
@@ -209,7 +212,7 @@ def getresponse( self
209212
assert resp.is_valid(), resp.error_message
210213
break
211214
except Exception as e:
212-
max_requests -= 1
215+
max_tries -= 1
213216
numoftries += 1
214217
time.sleep(random.random() * timeinterval)
215218
print(f"Try again ({numoftries}):{e}\n")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
with open('README.md') as readme_file:
88
readme = readme_file.read()
99

10-
VERSION = '3.0.0'
10+
VERSION = '3.0.1'
1111

1212
requirements = [
1313
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',

tests/test_async.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ async def show_resp(chat):
4444
def test_async_process():
4545
chkpoint = testpath + "test_async.jsonl"
4646
t = time.time()
47-
resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, ncoroutines=3)
48-
resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, ncoroutines=3)
47+
resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3)
48+
resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, nproc=3)
4949
assert all(resp)
5050
print(f"Time elapsed: {time.time() - t:.2f}s")
5151

@@ -55,7 +55,7 @@ def test_failed_async():
5555
chattool.api_key = "sk-invalid"
5656
chkpoint = testpath + "test_async_fail.jsonl"
5757
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
58-
resp = async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3)
58+
resp = async_chat_completion(words, chkpoint, clearfile=True, nproc=3)
5959
chattool.api_key = api_key
6060

6161
def test_async_process_withfunc():
@@ -66,15 +66,13 @@ def msg2log(msg):
6666
chat.system("translate the words from English to Chinese")
6767
chat.user(msg)
6868
return chat.chat_log
69-
def max_tokens(chat_log):
70-
return Chat(chat_log).prompt_token()
71-
async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3, max_tokens=max_tokens, msg2log=msg2log)
69+
async_chat_completion(words, chkpoint, clearfile=True, nproc=3, msg2log=msg2log)
7270

7371
def test_normal_process():
7472
chkpoint = testpath + "test_nomal.jsonl"
7573
def data2chat(data):
7674
chat = Chat(data)
77-
chat.getresponse(max_requests=3)
75+
chat.getresponse(max_tries=3)
7876
return chat
7977
t = time.time()
8078
process_chats(chatlogs, data2chat, chkpoint, clearfile=True)

tests/test_function.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def test_call_weather():
3535
chat = Chat("What's the weather like in Boston?")
36-
resp = chat.getresponse(functions=functions, function_call='auto', max_requests=3)
36+
resp = chat.getresponse(functions=functions, function_call='auto', max_tries=3)
3737
# TODO: wrap the response
3838
if resp.finish_reason == 'function_call':
3939
# test response from chat api
@@ -54,12 +54,12 @@ def test_auto_response():
5454
chat = Chat("What's the weather like in Boston?")
5555
chat.functions, chat.function_call = functions, 'auto'
5656
chat.name2func = name2func
57-
chat.autoresponse(max_requests=2)
57+
chat.autoresponse(max_tries=2)
5858
chat.print_log()
5959
chat.clear()
6060
# response with nonempty content
6161
chat.user("what is the result of 1+1, and What's the weather like in Boston?")
62-
chat.autoresponse(max_requests=2)
62+
chat.autoresponse(max_tries=2)
6363

6464
# generate docstring from functions
6565
def add(a: int, b: int) -> int:
@@ -100,20 +100,20 @@ def test_add_and_mult():
100100
chat.name2func = {'add': add} # dictionary of functions
101101
chat.function_call = 'auto' # auto decision
102102
# run until success: maxturns=-1
103-
chat.autoresponse(max_requests=3, display=True, timeinterval=2)
103+
chat.autoresponse(max_tries=3, display=True, timeinterval=2)
104104
# response should be finished
105105
chat.simplify()
106106
chat.print_log()
107107
# use the setfuncs method
108108
chat = Chat("find the value of 124842 * 3423424")
109109
chat.setfuncs([add, mult]) # multi choice
110-
chat.autoresponse(max_requests=3, timeinterval=2)
110+
chat.autoresponse(max_tries=3, timeinterval=2)
111111
chat.simplify() # simplify the chat log
112112
chat.print_log()
113113
# test multichoice
114114
chat.clear()
115115
chat.user("find the value of 23723 + 12312, and 23723 * 12312")
116-
chat.autoresponse(max_requests=3, timeinterval=2)
116+
chat.autoresponse(max_tries=3, timeinterval=2)
117117

118118
def test_mock_resp():
119119
chat = Chat("find the sum of 1235 and 3423")
@@ -122,12 +122,12 @@ def test_mock_resp():
122122
para = {'name': 'add', 'arguments': '{\n "a": 1235,\n "b": 3423\n}'}
123123
chat.assistant(content=None, function_call=para)
124124
chat.callfunction()
125-
chat.getresponse(max_requests=2)
125+
chat.getresponse(max_tries=2)
126126

127127
def test_use_exec_function():
128128
chat = Chat("find the result of sqrt(121314)")
129129
chat.setfuncs([exec_python_code])
130-
chat.autoresponse(max_requests=2)
130+
chat.autoresponse(max_tries=2)
131131

132132
def test_find_permutation_group():
133133
pass

0 commit comments

Comments
 (0)