Commit 2596e25: Added persona to llm
NeonBohdan committed Oct 29, 2023 (parent: 4169996)

1 changed file: neon_llm_chatgpt/chatgpt.py (9 additions, 8 deletions)
@@ -74,7 +74,7 @@ def _system_prompt(self) -> str:
     def warmup(self):
         self.model

-    def get_sorted_answer_indexes(self, question: str, answers: List[str]) -> List[int]:
+    def get_sorted_answer_indexes(self, question: str, answers: List[str], persona: dict) -> List[int]:
         """
         Creates sorted list of answer indexes with respect to order provided in :param answers based on PPL score
         Answers are sorted from best to worst
@@ -84,7 +84,7 @@ def get_sorted_answer_indexes(self, question: str, answers: List[str]) -> List[int]:
         """
         if not answers:
             return []
-        scores = self._score(prompt=question, targets=answers)
+        scores = self._score(prompt=question, targets=answers, persona=persona)
         sorted_items = sorted(zip(range(len(answers)), scores), key=lambda x: x[1])
         sorted_items_indexes = [x[0] for x in sorted_items]
         return sorted_items_indexes
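
The ranking step itself is unchanged by this commit: sorted() orders (index, score) pairs by score ascending, so the lowest-distance answer index comes first. A minimal standalone sketch with made-up scores:

    # Hypothetical values, for illustration only.
    answers = ["Paris", "London", "Berlin"]
    scores = [0.12, 0.45, 0.30]  # e.g. embedding distances; lower is better

    sorted_items = sorted(zip(range(len(answers)), scores), key=lambda x: x[1])
    sorted_indexes = [idx for idx, _ in sorted_items]
    print(sorted_indexes)  # [0, 2, 1] -> "Paris" ranked best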
@@ -106,7 +106,7 @@ def _call_model(self, prompt: List[Dict[str, str]]) -> str:

         return text

-    def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[Dict[str, str]]:
+    def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict) -> List[Dict[str, str]]:
         """
         Assembles prompt engineering logic
         Setup Guidance:
@@ -116,8 +116,9 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[Dict[str, str]]:
         :param chat_history: History of preceding conversation
         :returns: assembled prompt
         """
+        system_prompt = persona.get("description", self._system_prompt)
         messages = [
-            {"role": "system", "content": self._system_prompt},
+            {"role": "system", "content": system_prompt},
         ]
         # Context N messages
         for role, content in chat_history[-self.context_depth:]:
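
The added system_prompt line is the heart of the change: when the incoming persona dict carries a "description", it overrides the default system prompt; otherwise behaviour is as before. A minimal sketch of the assembled message list, with a hypothetical persona and stand-ins for the default prompt and role mapping (neither is taken from the source):

    # Hypothetical persona; only the "description" key is read by the code above.
    persona = {"name": "pirate", "description": "You are a pirate. Answer in pirate speak."}
    default_prompt = "You are a helpful assistant."  # placeholder for self._system_prompt

    system_prompt = persona.get("description", default_prompt)
    messages = [{"role": "system", "content": system_prompt}]
    for role, content in [["user", "Hello"], ["llm", "Ahoy!"]]:  # stand-in chat history
        # assumed mapping of internal roles to ChatGPT API roles
        messages.append({"role": "assistant" if role == "llm" else "user", "content": content})
    messages.append({"role": "user", "content": "Where is the treasure?"})
    # messages[0]["content"] is now the persona description, not the default prompt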
@@ -126,29 +127,29 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[Dict[str, str]]:
         messages.append({"role": "user", "content": message})
         return messages

-    def _score(self, prompt: str, targets: List[str]) -> List[float]:
+    def _score(self, prompt: str, targets: List[str], persona: dict) -> List[float]:
         """
         Calculates logarithmic probabilities for the list of provided text sequences
         :param prompt: Input text sequence
         :param targets: Output text sequences
         :returns: List of calculated logarithmic probabilities per output text sequence
         """

-        question_embeddings, answers_embeddings = self._embeddings(question=prompt, answers=targets)
+        question_embeddings, answers_embeddings = self._embeddings(question=prompt, answers=targets, persona=persona)
         scores_list = distances_from_embeddings(question_embeddings, answers_embeddings)
         return scores_list

     def _tokenize(self, prompt: str) -> None:
         pass

-    def _embeddings(self, question: str, answers: List[str]) -> (List[float], List[List[float]]):
+    def _embeddings(self, question: str, answers: List[str], persona: dict) -> (List[float], List[List[float]]):
         """
         Computes embeddings for the list of provided answers
         :param question: Question for LLM to response to
         :param answers: List of provided answers
         :returns ppl values for each answer
         """
-        response = self.ask(question, [])
+        response = self.ask(question, [], persona=persona)
         texts = [response] + answers
         embeddings = get_embeddings(texts, engine="text-embedding-ada-002")
         question_embeddings = embeddings[0]
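Taken together, the persona dict now threads through the whole scoring path (get_sorted_answer_indexes -> _score -> _embeddings -> ask), so the reference response that candidate answers are compared against is generated under the same persona. A hypothetical caller, where llm stands in for an initialised ChatGPT instance:

    # Hypothetical usage; "llm" is assumed to be a configured ChatGPT object.
    persona = {"description": "You are a concise geography tutor."}
    question = "What is the capital of France?"
    candidates = ["Paris.", "It is London.", "France has no capital."]

    ranked = llm.get_sorted_answer_indexes(question, candidates, persona=persona)
    best = candidates[ranked[0]]  # lowest-distance answer under this persona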
