Skip to content

Commit

Permalink
Updated model
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonBohdan committed Nov 25, 2023
1 parent 1254f7a commit 8b84cce
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions neon_llm_palm2/palm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,28 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import openai
from openai.embeddings_utils import get_embeddings, distances_from_embeddings
import os
from vertexai.language_models import ChatModel, ChatMessage, TextEmbeddingModel
from openai.embeddings_utils import distances_from_embeddings

from typing import List, Dict
from neon_llm_core.llm import NeonLLM


class ChatGPT(NeonLLM):
class Palm2(NeonLLM):

mq_to_llm_role = {
"user": "user",
"llm": "assistant"
"llm": "bot"
}

def __init__(self, config):
    """
    Initialize the PaLM 2 LLM wrapper.

    :param config: service configuration dict; expects keys
        "model", "role", "context_depth", "max_tokens", "key_path"
    """
    super().__init__(config)
    self.model_name = config["model"]
    # Lazily-created Vertex AI TextEmbeddingModel (see `embedding` property)
    self._embedding = None
    self.role = config["role"]
    # Number of most-recent chat-history turns to include in each prompt
    self.context_depth = config["context_depth"]
    self.max_tokens = config["max_tokens"]
    # Path to the Vertex AI service-account credentials file
    # NOTE(review): presumably consumed by GOOGLE_APPLICATION_CREDENTIALS
    # machinery elsewhere — confirm against the deployment config
    self.api_key_path = config["key_path"]
    self.warmup()

@property
Expand All @@ -56,11 +57,16 @@ def tokenizer_model_name(self) -> str:
return ""

@property
@property
def model(self) -> ChatModel:
    """
    Lazily instantiate and cache the Vertex AI chat model.

    :returns: cached ChatModel instance
    """
    if self._model is None:
        # "chat-bison@001" is the PaLM 2 chat model on Vertex AI
        self._model = ChatModel.from_pretrained("chat-bison@001")
    return self._model

@property
def embedding(self) -> TextEmbeddingModel:
    """
    Lazily instantiate and cache the Vertex AI text-embedding model.

    :returns: cached TextEmbeddingModel instance
    """
    if self._embedding is not None:
        return self._embedding
    # First access: create the gecko embedding model and memoize it
    self._embedding = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
    return self._embedding

@property
def llm_model_name(self) -> str:
Expand Down Expand Up @@ -88,43 +94,49 @@ def get_sorted_answer_indexes(self, question: str, answers: List[str], persona:
sorted_items_indexes = [x[0] for x in sorted_items]
return sorted_items_indexes

def _call_model(self, prompt: Dict) -> str:
    """
    Wrapper for PaLM 2 model generation logic.

    :param prompt: dict with "system_prompt", "chat_history" and "message"
        keys, as produced by _assemble_prompt()
    :returns: output text sequence generated by the model
    """
    # Use the `model` property rather than self._model so the chat model
    # is lazily created even if warmup() has not populated it yet.
    chat = self.model.start_chat(
        context=prompt["system_prompt"],
        message_history=prompt["chat_history"],
        max_output_tokens=self.max_tokens,
        temperature=0,  # deterministic responses
    )
    response = chat.send_message(
        prompt["message"],
    )
    return response.text

def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict) -> Dict:
    """
    Assembles prompt engineering logic
    Setup Guidance:
    https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/overview

    :param message: Incoming prompt
    :param chat_history: History of preceding conversation as (role, text) pairs
    :param persona: persona dict; "description" overrides the default system prompt
    :returns: dict with "system_prompt" (str), "chat_history"
        (list of ChatMessage) and "message" (str) for _call_model()
    """
    system_prompt = persona.get("description", self._system_prompt)
    # Keep only the last `context_depth` turns of the conversation
    messages = []
    for role, content in chat_history[-self.context_depth:]:
        role_palm2 = self.convert_role(role)
        messages.append(ChatMessage(content, role_palm2))
    # Return annotation fixed: this builds a dict, not a list of messages
    return {
        "system_prompt": system_prompt,
        "chat_history": messages,
        "message": message,
    }

def _score(self, prompt: str, targets: List[str], persona: dict) -> List[float]:
"""
Expand All @@ -150,7 +162,8 @@ def _embeddings(self, question: str, answers: List[str], persona: dict) -> (List
"""
response = self.ask(question, [], persona=persona)
texts = [response] + answers
embeddings = get_embeddings(texts, engine="text-embedding-ada-002")
embeddings_obj = self._embedding.get_embeddings(texts)
embeddings = [embedding.values for embedding in embeddings_obj]
question_embeddings = embeddings[0]
answers_embeddings = embeddings[1:]
return question_embeddings, answers_embeddings

0 comments on commit 8b84cce

Please sign in to comment.