diff --git a/neon_llm_palm2/palm2.py b/neon_llm_palm2/palm2.py
index fe378bf..846b83c 100644
--- a/neon_llm_palm2/palm2.py
+++ b/neon_llm_palm2/palm2.py
@@ -24,27 +24,28 @@
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import openai
-from openai.embeddings_utils import get_embeddings, distances_from_embeddings
+import os
+from vertexai.language_models import ChatModel, ChatMessage, TextEmbeddingModel
+from openai.embeddings_utils import distances_from_embeddings
 
 from typing import List, Dict
 from neon_llm_core.llm import NeonLLM
 
 
-class ChatGPT(NeonLLM):
+class Palm2(NeonLLM):
 
     mq_to_llm_role = {
         "user": "user",
-        "llm": "assistant"
+        "llm": "bot"
     }
 
     def __init__(self, config):
         super().__init__(config)
-        self.model_name = config["model"]
+        self._embedding = None
         self.role = config["role"]
         self.context_depth = config["context_depth"]
         self.max_tokens = config["max_tokens"]
-        self.api_key = config["key"]
+        self.api_key_path = config["key_path"]
         self.warmup()
 
     @property
@@ -56,11 +57,16 @@ def tokenizer_model_name(self) -> str:
         return ""
 
     @property
-    def model(self) -> openai:
+    def model(self) -> ChatModel:
         if self._model is None:
-            openai.api_key = self.api_key
-            self._model = openai
+            self._model = ChatModel.from_pretrained("chat-bison@001")
         return self._model
+
+    @property
+    def embedding(self) -> TextEmbeddingModel:
+        if self._embedding is None:
+            self._embedding = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
+        return self._embedding
 
     @property
     def llm_model_name(self) -> str:
@@ -88,20 +94,23 @@ def get_sorted_answer_indexes(self, question: str, answers: List[str], persona:
         sorted_items_indexes = [x[0] for x in sorted_items]
         return sorted_items_indexes
 
-    def _call_model(self, prompt: List[Dict[str, str]]) -> str:
+    def _call_model(self, prompt: Dict) -> str:
         """
-        Wrapper for ChatGPT Model generation logic
+        Wrapper for Palm2 Model generation logic
 
         :param prompt: Input messages sequence
         :returns: Output text sequence generated by model
         """
-        response = openai.ChatCompletion.create(
-            model=self.llm_model_name,
-            messages=prompt,
+        chat = self._model.start_chat(
+            context=prompt["system_prompt"],
+            message_history=prompt["chat_history"],
+            max_output_tokens=self.max_tokens,
             temperature=0,
-            max_tokens=self.max_tokens,
         )
-        text = response.choices[0].message['content']
+        response = chat.send_message(
+            prompt["message"],
+        )
+        text = response.text
 
         return text
 
@@ -109,22 +118,25 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona:
         """
         Assembles prompt engineering logic
         Setup Guidance:
-        https://platform.openai.com/docs/guides/gpt/chat-completions-api
+        https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/overview
 
         :param message: Incoming prompt
         :param chat_history: History of preceding conversation
         :returns: assembled prompt
         """
         system_prompt = persona.get("description", self._system_prompt)
-        messages = [
-            {"role": "system", "content": system_prompt},
-        ]
         # Context N messages
+        messages = []
         for role, content in chat_history[-self.context_depth:]:
-            role_chatgpt = self.convert_role(role)
-            messages.append({"role": role_chatgpt, "content": content})
-        messages.append({"role": "user", "content": message})
-        return messages
+            role_palm2 = self.convert_role(role)
+            messages.append(ChatMessage(content, role_palm2))
+        prompt = {
+            "system_prompt": system_prompt,
+            "chat_history": messages,
+            "message": message
+        }
+
+        return prompt
 
     def _score(self, prompt: str, targets: List[str], persona: dict) -> List[float]:
         """
@@ -150,7 +162,8 @@ def _embeddings(self, question: str, answers: List[str], persona: dict) -> (List
         """
         response = self.ask(question, [], persona=persona)
         texts = [response] + answers
-        embeddings = get_embeddings(texts, engine="text-embedding-ada-002")
+        embeddings_obj = self._embedding.get_embeddings(texts)
+        embeddings = [embedding.values for embedding in embeddings_obj]
         question_embeddings = embeddings[0]
         answers_embeddings = embeddings[1:]
         return question_embeddings, answers_embeddings
\ No newline at end of file
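
For reference, a minimal sketch of the Vertex AI call pattern this patch adopts; it is not part of the change itself. It assumes the SDK has already been initialized for the active project (e.g. via `vertexai.init(...)` with credentials resolved from `GOOGLE_APPLICATION_CREDENTIALS`, which is presumably where the new `key_path` config value gets pointed). The model IDs and the `"user"`/`"bot"` author roles mirror the ones hard-coded in the patch; the prompt strings are placeholders.

```python
from vertexai.language_models import ChatModel, ChatMessage, TextEmbeddingModel

# Chat generation: "chat-bison@001" backs the new `model` property.
chat_model = ChatModel.from_pretrained("chat-bison@001")

# start_chat() takes the system prompt as `context` and prior turns as
# `message_history`; valid authors are "user" and "bot", hence the
# mq_to_llm_role change from "assistant" to "bot".
chat = chat_model.start_chat(
    context="You are a helpful assistant.",
    message_history=[
        ChatMessage(content="Hi there!", author="user"),
        ChatMessage(content="Hello! How can I help?", author="bot"),
    ],
    max_output_tokens=256,  # the patch passes self.max_tokens here
    temperature=0,
)
response = chat.send_message("Summarize our conversation so far.")
print(response.text)  # plain string, replacing response.choices[0].message['content']

# Embeddings: get_embeddings() returns TextEmbedding objects whose .values
# attribute holds the float vector, hence the list comprehension in the
# rewritten _embeddings().
embedder = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
vectors = [e.values for e in embedder.get_embeddings(["question", "answer"])]
```

Note that `distances_from_embeddings` is still imported from `openai.embeddings_utils`, so the `openai` package remains a runtime dependency for the cosine-distance helper even though generation and embeddings now go through Vertex AI.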