diff --git a/neon_llm_chatgpt/chatgpt.py b/neon_llm_chatgpt/chatgpt.py
index 0f356c1..b12f124 100644
--- a/neon_llm_chatgpt/chatgpt.py
+++ b/neon_llm_chatgpt/chatgpt.py
@@ -74,7 +74,7 @@ def _system_prompt(self) -> str:
     def warmup(self):
         self.model
 
-    def get_sorted_answer_indexes(self, question: str, answers: List[str]) -> List[int]:
+    def get_sorted_answer_indexes(self, question: str, answers: List[str], persona: dict) -> List[int]:
         """
             Creates sorted list of answer indexes with respect to order provided in :param answers based on PPL score
             Answers are sorted from best to worst
@@ -84,7 +84,7 @@ def get_sorted_answer_indexes(self, question: str, answers: List[str]) -> List[i
         """
         if not answers:
             return []
-        scores = self._score(prompt=question, targets=answers)
+        scores = self._score(prompt=question, targets=answers, persona=persona)
         sorted_items = sorted(zip(range(len(answers)), scores), key=lambda x: x[1])
         sorted_items_indexes = [x[0] for x in sorted_items]
         return sorted_items_indexes
@@ -106,7 +106,7 @@ def _call_model(self, prompt: List[Dict[str, str]]) -> str:
 
         return text
 
-    def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[Dict[str, str]]:
+    def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict) -> List[Dict[str, str]]:
         """
             Assembles prompt engineering logic
             Setup Guidance:
@@ -116,8 +116,9 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[
             :param chat_history: History of preceding conversation
             :returns: assembled prompt
         """
+        system_prompt = persona.get("description", self._system_prompt)
         messages = [
-            {"role": "system", "content": self._system_prompt},
+            {"role": "system", "content": system_prompt},
         ]
         # Context N messages
         for role, content in chat_history[-self.context_depth:]:
@@ -126,7 +127,7 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> List[
         messages.append({"role": "user", "content": message})
         return messages
 
-    def _score(self, prompt: str, targets: List[str]) -> List[float]:
+    def _score(self, prompt: str, targets: List[str], persona: dict) -> List[float]:
         """
             Calculates logarithmic probabilities for the list of provided text sequences
             :param prompt: Input text sequence
@@ -134,21 +135,21 @@
             :returns: List of calculated logarithmic probabilities per output text sequence
         """
-        question_embeddings, answers_embeddings = self._embeddings(question=prompt, answers=targets)
+        question_embeddings, answers_embeddings = self._embeddings(question=prompt, answers=targets, persona=persona)
         scores_list = distances_from_embeddings(question_embeddings, answers_embeddings)
         return scores_list
 
     def _tokenize(self, prompt: str) -> None:
         pass
 
-    def _embeddings(self, question: str, answers: List[str]) -> (List[float], List[List[float]]):
+    def _embeddings(self, question: str, answers: List[str], persona: dict) -> (List[float], List[List[float]]):
         """
             Computes embeddings for the list of provided answers
             :param question: Question for LLM to response to
             :param answers: List of provided answers
             :returns ppl values for each answer
         """
-        response = self.ask(question, [])
+        response = self.ask(question, [], persona=persona)
         texts = [response] + answers
         embeddings = get_embeddings(texts, engine="text-embedding-ada-002")
         question_embeddings = embeddings[0]
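Taken together, these changes thread a persona dict through the scoring and prompt-assembly paths: _assemble_prompt now prefers persona["description"] over the default system prompt, and get_sorted_answer_indexes, _score, and _embeddings forward the dict down to ask(). A minimal usage sketch follows; the method signatures come from the diff above, but the llm instance and the persona contents are illustrative assumptions, not part of this change:

    # Sketch only: assumes a constructed ChatGPT instance bound to `llm`.
    persona = {"description": "You are a terse, factual assistant."}

    # ask() accepts the persona kwarg, as seen in _embeddings() above.
    reply = llm.ask("What is the capital of France?", [], persona=persona)

    # Ranking requires the same persona, since it routes through
    # _score() -> _embeddings() -> ask().
    ranked = llm.get_sorted_answer_indexes(
        question="What is the capital of France?",
        answers=["Paris is the capital.", "It might be Lyon."],
        persona=persona,
    )

If the dict has no "description" key, _assemble_prompt falls back to self._system_prompt, so existing callers can pass an empty dict to keep the previous behavior.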