Commit
🔨[DEV] modify the optional kwargs
fairyshine committed Oct 1, 2024
1 parent bff9616 commit 15fb395
Showing 8 changed files with 66 additions and 30 deletions.
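
In short, the commit collects each backend's optional arguments into an optional_kwargs dict and filters it through a new clean_dict_null_value helper before splatting it into the underlying call, instead of passing every keyword explicitly. A minimal sketch of the pattern, with backend_call as a hypothetical stand-in for the real APIs:

```python
from typing import Optional

def clean_dict_null_value(d: dict) -> dict:
    # As committed in src/fastmindapi/utils/transform.py: keep truthy entries only.
    return {k: d[k] for k in d if d[k]}

def backend_call(**kwargs):
    # Hypothetical stand-in for self.model(...) / chat.completions.create(...).
    return kwargs

def generate(input_text: str,
             max_new_tokens: Optional[int] = None,
             stop_strings: Optional[list[str]] = None,
             config: Optional[dict] = None):
    # Gather the optional arguments in one place...
    optional_kwargs = {
        "max_tokens": max_new_tokens,
        "stop": stop_strings,
        "temperature": config.get("temperature") if config else None,
    }
    # ...and forward only those that were actually set, so the backend
    # falls back to its own defaults for everything omitted.
    return backend_call(prompt=input_text, **clean_dict_null_value(optional_kwargs))

print(generate("hi", max_new_tokens=16))  # {'prompt': 'hi', 'max_tokens': 16}
```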
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,4 +1,4 @@
-FM_log/
+log_FastMindAPI/
 local_tests/
 
 # Mac System File
3 changes: 2 additions & 1 deletion DEVELOPMENT.md
@@ -10,10 +10,11 @@ GenerationConfig

https://github.com/huggingface/transformers/blob/main/src/transformers/generation/configuration_utils.py#L94


# Llama.cpp

llama.\_\_call\_\_

https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py#L1836

# OpenAI

Binary file added asset/DevelopmentStatistics.xlsx
Binary file not shown.
29 changes: 20 additions & 9 deletions src/fastmindapi/model/llama_cpp/LLM.py
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from ...utils.transform import convert_numpy_float32_to_float
+from ...utils.transform import convert_numpy_float32_to_float, clean_dict_null_value
 from ... import logger
 
 class LlamacppLLM:
@@ -25,11 +25,19 @@ def generate(self,
                  max_new_tokens: Optional[int] = None,
                  return_logits: Optional[bool] = None,
                  logits_top_k: Optional[int] = None,
-                 stop_strings: Optional[list[str]] = None):
+                 stop_strings: Optional[list[str]] = None,
+                 config: Optional[dict] = None):
+        optional_kwargs = {
+            "max_tokens": max_new_tokens,
+            "logprobs": logits_top_k if return_logits else None,
+            "stop": stop_strings,
+            "temperature": (config["temperature"] if "temperature" in config else None) if config else None,
+            "top_p": (config["top_p"] if "top_p" in config else None) if config else None,
+            "top_k": (config["top_k"] if "top_k" in config else None) if config else None,
+            "repeat_penalty": (config["repetition_penalty"] if "repetition_penalty" in config else None) if config else None,
+        }
         response = self.model(input_text,
-                              max_tokens=max_new_tokens,
-                              logprobs = logits_top_k if return_logits else None,
-                              stop = stop_strings,
+                              **clean_dict_null_value(optional_kwargs),
                               echo = True,
                               )
         full_text = response["choices"][0]["text"]
@@ -78,9 +86,12 @@ def chat(self,
              logprobs: Optional[bool] = False,
              top_logprobs: Optional[int] = 10,
              stop: Optional[list[str]] = None):
+        optional_kwargs = {
+            "max_tokens": max_completion_tokens,
+            "logprobs": logprobs,
+            "top_logprobs": top_logprobs if logprobs else None,
+            "stop": stop
+        }
         response = self.model.create_chat_completion(messages,
-                                                     max_tokens=max_completion_tokens,
-                                                     logprobs=logprobs,
-                                                     top_logprobs=top_logprobs if logprobs else None,
-                                                     stop=stop)
+                                                     **clean_dict_null_value(optional_kwargs))
         return response
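
With the new config parameter, sampling settings reach llama-cpp-python's Llama.__call__ without being enumerated in the signature; note that "repetition_penalty" in the config is translated to llama.cpp's "repeat_penalty" kwarg. A hedged usage sketch (the wrapper's constructor arguments and the import path are inferred from the file layout, not shown in this diff):

```python
from llama_cpp import Llama
from fastmindapi.model.llama_cpp.LLM import LlamacppLLM  # path inferred from file layout

llm = LlamacppLLM(Llama(model_path="model.gguf"))  # wrapper construction assumed
output = llm.generate(
    "Hello,",
    max_new_tokens=32,
    config={"temperature": 0.8, "top_k": 40, "repetition_penalty": 1.1},
)
```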
33 changes: 20 additions & 13 deletions src/fastmindapi/model/openai/ChatModel.py
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from ...utils.transform import convert_openai_logprobs
+from ...utils.transform import convert_openai_logprobs, clean_dict_null_value
 from ... import logger
 
 class OpenAIChatModel:
@@ -28,7 +28,7 @@ def __call__(self, input_text: str,
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": input_text}
],
max_completion_tokens=max_new_tokens
max_tokens=max_new_tokens
)
return completion.choices[0].message.content
except Exception as e:
@@ -39,7 +39,16 @@ def generate(self,
                  max_new_tokens: Optional[int] = None,
                  return_logits: Optional[bool] = None,
                  logits_top_k: Optional[int] = None,
-                 stop_strings: Optional[list[str]] = None):
+                 stop_strings: Optional[list[str]] = None,
+                 config: Optional[dict] = None):
+        optional_kwargs = {
+            "max_tokens": max_new_tokens,
+            "logprobs": return_logits,
+            "top_logprobs": logits_top_k if return_logits else None,
+            "stop": stop_strings,
+            "temperature": (config["temperature"] if "temperature" in config else None) if config else None,
+            "top_p": (config["top_p"] if "top_p" in config else None) if config else None,
+        }
         while True:
             try:
                 completion = self.client.chat.completions.create(
@@ -48,10 +57,7 @@
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": input_text}
],
max_completion_tokens=max_new_tokens,
logprobs=return_logits,
top_logprobs=logits_top_k if return_logits else None,
stop=stop_strings
**clean_dict_null_value(optional_kwargs)
)
break
except Exception as e:
@@ -73,17 +79,18 @@
     def chat(self,
              messages: list[dict],
              max_completion_tokens: Optional[int] = None,
-             logprobs: Optional[bool] = False,
-             top_logprobs: Optional[int] = 10,
+             logprobs: Optional[bool] = None, # Defaults to false
+             top_logprobs: Optional[int] = None,
              stop: Optional[list[str]] = None):
+        optional_kwargs = {"max_tokens": max_completion_tokens,
+                           "logprobs": logprobs,
+                           "top_logprobs": top_logprobs if logprobs else None,
+                           "stop": stop}
         try:
             completion = self.client.chat.completions.create(
                 model= self.model_name,
                 messages=messages,
-                max_tokens=max_completion_tokens,
-                logprobs=logprobs,
-                top_logprobs=top_logprobs if logprobs else None,
-                stop=stop
+                **clean_dict_null_value(optional_kwargs)
             )
             return completion.model_dump()
         except Exception as e:
18 changes: 12 additions & 6 deletions src/fastmindapi/model/transformers/CausalLM.py
@@ -1,5 +1,7 @@
 from typing import Optional
 
+from ...utils.transform import clean_dict_null_value
+
 class TransformersCausalLM:
     def __init__(self,
                  tokenizer,
@@ -34,7 +36,8 @@ def generate(self,
                  max_new_tokens: Optional[int] = None,
                  return_logits: Optional[bool] = None,
                  logits_top_k: Optional[int] = None,
-                 stop_strings: Optional[list[str]] = None):
+                 stop_strings: Optional[list[str]] = None,
+                 config: Optional[dict] = None):
         import torch
         import torch.nn.functional as F
 
@@ -43,10 +46,12 @@
         input_token_list = [self.tokenizer.decode([token_id]) for token_id in input_id_list]
 
         with torch.no_grad():
-            outputs = self.model.generate(**inputs,
-                                          max_new_tokens=max_new_tokens,
-                                          stop_strings=stop_strings,
-                                          tokenizer=self.tokenizer)
+            generate_kwargs = {"generation_config": clean_dict_null_value(config) if config else None,
+                               "max_new_tokens": max_new_tokens,
+                               "stop_strings": stop_strings}
+            outputs = self.model.generate(inputs.input_ids,
+                                          **clean_dict_null_value(generate_kwargs),
+                                          tokenizer=self.tokenizer)
         full_id_list = outputs[0].tolist()
         full_token_list = [self.tokenizer.decode([token_id]) for token_id in full_id_list]
         full_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -135,7 +140,8 @@ def chat(self,
             }
 
             with torch.no_grad():
-                outputs = self.model.generate(**inputs, **generate_kwargs)
+                outputs = self.model.generate(**inputs,
+                                              **clean_dict_null_value(generate_kwargs))
 
             full_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
             re_inputs = self.tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
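
One caveat for the hunk above: transformers' model.generate expects a GenerationConfig instance for its generation_config argument, so passing the cleaned config dict straight through may depend on the transformers version in use. A hedged sketch of wrapping the dict first, using a hypothetical to_generation_config helper that is not part of the commit:

```python
from typing import Optional
from transformers import GenerationConfig

def to_generation_config(config: Optional[dict]) -> Optional[GenerationConfig]:
    # Drop unset entries, then wrap them in the object transformers expects.
    cleaned = {k: v for k, v in (config or {}).items() if v is not None}
    return GenerationConfig(**cleaned) if cleaned else None

print(to_generation_config({"temperature": 0.7, "top_k": None}))
```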
8 changes: 8 additions & 0 deletions src/fastmindapi/server/router/model.py
@@ -12,12 +12,20 @@ class BasicModel(BaseModel):

     model_config = ConfigDict(protected_namespaces=())
 
+class GenerationConfig(BaseModel):
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+
 class GenerationRequest(BaseModel):
     input_text: str
     max_new_tokens: Optional[int] = None
     return_logits: Optional[bool] = None
     logits_top_k: Optional[int] = None
     stop_strings: Optional[list[str]] = None
+    config: Optional[GenerationConfig] = None
 
     model_config=ConfigDict(protected_namespaces=())
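
Given the models above, a client-side request carrying the new nested config could be assembled like this (a sketch; the import path is inferred from the file layout, and only fields shown in the diff are used):

```python
from fastmindapi.server.router.model import GenerationConfig, GenerationRequest

request = GenerationRequest(
    input_text="Once upon a time",
    max_new_tokens=64,
    config=GenerationConfig(do_sample=True, temperature=0.7, top_p=0.9),
)
# exclude_none=True mirrors clean_dict_null_value on the server side.
print(request.model_dump(exclude_none=True))
```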
3 changes: 3 additions & 0 deletions src/fastmindapi/utils/transform.py
@@ -11,6 +11,9 @@ def convert_numpy_float32_to_float(d):
     else:
         return d
 
+def clean_dict_null_value(d):
+    return { k:d[k] for k in d if d[k] }
+
 def convert_openai_logprobs(logprobs):
     logprobs = logprobs.model_dump()
     logits_list = []
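
A quick illustration of the helper just added: the comprehension keeps only truthy values, so it drops None (the intended "null" case) but also drops falsy settings such as False or 0, which the call sites in this commit tolerate (an unset logprobs=False is simply omitted rather than passed):

```python
def clean_dict_null_value(d):
    return {k: d[k] for k in d if d[k]}

print(clean_dict_null_value({"max_tokens": 64, "stop": None, "logprobs": False}))
# {'max_tokens': 64}  (the None entry and the False entry are both dropped)
```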
