Refactor NeonLLM to remove unused internal variables
Update docstrings and type annotations
Add `query_model` method to accept an `LLMRequest` and return an `LLMResponse`
Add unit test coverage for `NeonLLM`
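A rough usage sketch of the new `query_model` flow (illustrative only: `MyNeonLLM` is a hypothetical subclass, and any `LLMRequest`/`Persona` fields beyond those visible in this diff are assumptions):

    from neon_data_models.models.api import LLMRequest

    llm = MyNeonLLM(config={})  # hypothetical NeonLLM subclass
    request = LLMRequest(query="What is the capital of France?",
                         history=[("user", "hi"), ("llm", "Hello!")],
                         persona={"name": "assistant"},  # Persona fields assumed
                         model=llm.llm_model_name)  # must match the configured model
    reply = llm.query_model(request)
    print(reply.response)  # generated text
    print(reply.history)   # input history plus ("llm", <response>) appended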
NeonDaniel committed Dec 12, 2024
1 parent 8a270c3 commit ed256da
Showing 4 changed files with 208 additions and 68 deletions.
93 changes: 74 additions & 19 deletions neon_llm_core/llm.py
@@ -23,8 +23,13 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from abc import ABC, abstractmethod
from typing import List
from typing import List, Optional, Tuple, Union

from neon_data_models.models.api import LLMRequest, LLMResponse
from openai import OpenAI
from ovos_utils.log import log_deprecation


class NeonLLM(ABC):
@@ -36,8 +41,6 @@ def __init__(self, config: dict):
@param config: Dict LLM configuration for this specific LLM
"""
self._llm_config = config
self._tokenizer = None
self._model = None

@property
def llm_config(self):
@@ -48,75 +51,127 @@ def llm_config(self):

@property
@abstractmethod
def tokenizer(self):
def tokenizer(self) -> Optional[object]:
"""
Get a Tokenizer object for the loaded model, if available.
:return: optional transformers.PreTrainedTokenizerBase object
"""
pass

@property
@abstractmethod
def tokenizer_model_name(self) -> str:
"""
Get a string tokenizer model name (i.e. a Huggingface `model id`)
associated with `self.tokenizer`.
"""
pass

@property
@abstractmethod
def model(self):
def model(self) -> OpenAI:
"""
Get an OpenAI client object to send requests to.
"""
pass

@property
@abstractmethod
def llm_model_name(self) -> str:
"""
Get a string model name for the configured `model`
"""
pass

@property
@abstractmethod
def _system_prompt(self) -> str:
"""
Get a default string system prompt to use when not included in requests
"""
pass

def ask(self, message: str, chat_history: List[List[str]], persona: dict) -> str:
""" Generates llm response based on user message and (user, llm) chat history """
"""
Generates llm response based on user message and (user, llm) chat history
"""
log_deprecation("This method is replaced by `query_model` which "
"accepts a single `LLMRequest` arg", "1.0.0")
prompt = self._assemble_prompt(message, chat_history, persona)
llm_text_output = self._call_model(prompt)
return llm_text_output

def query_model(self, request: LLMRequest) -> LLMResponse:
"""
Calls `self._assemble_prompt` to allow subclass to modify the input
query and then passes the updated query to `self._call_model`
:param request: LLMRequest object to generate a response to
:return: LLMResponse containing the generated response and updated chat history
"""
if request.model != self.llm_model_name:
raise ValueError(f"Requested model ({request.model}) is not this "
f"model ({self.llm_model_name}")
request.query = self._assemble_prompt(request.query, request.history,
request.persona.model_dump())
response = self._call_model(request.query, request)
history = request.history + [("llm", response)]
return LLMResponse(response=response, history=history)

@abstractmethod
def get_sorted_answer_indexes(self, question: str, answers: List[str], persona: dict) -> List[int]:
"""
Creates sorted list of answer indexes with respect to order provided in :param answers
Results should be sorted from best to worst
:param question: incoming question
:param answers: list of answers to rank
:returns list of indexes
Creates sorted list of answer indexes with respect to order provided in
`answers`. Results should be sorted from best to worst
:param question: incoming question
:param answers: list of answers to rank
:param persona: dict representation of Persona to use for sorting
:returns list of indexes
"""
pass

@abstractmethod
def _call_model(self, prompt: str) -> str:
def _call_model(self, prompt: str,
request: Optional[LLMRequest] = None) -> str:
"""
Wrapper for Model generation logic. This method may be called
asynchronously, so it is up to the extending class to use locks or
queue inputs as necessary.
:param prompt: Input text sequence
:param request: Optional LLMRequest object containing parameters to
include in model requests
:returns: Output text sequence generated by model
"""
pass

@abstractmethod
def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict):
def _assemble_prompt(self, message: str,
chat_history: List[Union[List[str], Tuple[str, str]]],
persona: dict) -> str:
"""
Assembles prompt engineering logic
:param message: Incoming prompt
:param chat_history: History of preceding conversation
:returns: assembled prompt
Assemble the prompt to send to the LLM
:param message: Input prompt to optionally modify
:param chat_history: History of preceding conversation
:param persona: dict representation of Persona that is requested
:returns: assembled prompt string
"""
pass

@abstractmethod
def _tokenize(self, prompt: str) -> List[str]:
"""
Tokenize the input prompt into a list of strings
:param prompt: Input to tokenize
:return: Tokenized representation of input prompt
"""
pass

@classmethod
def convert_role(cls, role: str) -> str:
""" Maps MQ role to LLM's internal domain """
"""
Maps MQ role to LLM's internal domain
:param role: Role in Neon LLM format
:return: Role in LLM internal format
"""
matching_llm_role = cls.mq_to_llm_role.get(role)
if not matching_llm_role:
raise ValueError(f"role={role} is undefined, supported are: "
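To make the abstract `NeonLLM` interface above concrete, here is a minimal, hypothetical subclass sketch. It is not part of this commit, it assumes the abstract members visible in this diff are the complete interface, and the OpenAI client usage and default model name are illustrative:

    from typing import List, Optional

    from openai import OpenAI

    from neon_llm_core.llm import NeonLLM


    class EchoLLM(NeonLLM):
        # Role mapping consumed by NeonLLM.convert_role
        mq_to_llm_role = {"user": "user", "llm": "assistant"}

        @property
        def tokenizer(self) -> Optional[object]:
            return None  # no local tokenizer in this sketch

        @property
        def tokenizer_model_name(self) -> str:
            return ""

        @property
        def model(self) -> OpenAI:
            return OpenAI(api_key=self.llm_config.get("key"))

        @property
        def llm_model_name(self) -> str:
            return self.llm_config.get("model", "gpt-3.5-turbo")

        @property
        def _system_prompt(self) -> str:
            return "You are a helpful assistant."

        def get_sorted_answer_indexes(self, question: str, answers: List[str],
                                      persona: dict) -> List[int]:
            return list(range(len(answers)))  # no re-ranking in this sketch

        def _call_model(self, prompt: str, request=None) -> str:
            return f"echo: {prompt}"  # a real subclass would call self.model here

        def _assemble_prompt(self, message: str, chat_history, persona: dict) -> str:
            return f"{self._system_prompt}\n{message}"

        def _tokenize(self, prompt: str) -> List[str]:
            return prompt.split()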
80 changes: 34 additions & 46 deletions neon_llm_core/rmq.py
@@ -28,12 +28,13 @@
from threading import Thread
from typing import Optional

from neon_data_models.models.api import LLMRequest, LLMResponse
from neon_mq_connector.connector import MQConnector
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
from neon_utils.logger import LOG

from neon_data_models.models.api.mq import (LLMProposeResponse,
LLMDiscussResponse, LLMVoteResponse)
LLMDiscussResponse, LLMVoteResponse, LLMDiscussRequest, LLMVoteRequest)

from neon_llm_core.utils.config import load_config
from neon_llm_core.llm import NeonLLM
@@ -123,21 +124,16 @@ def _handle_request_async(self, request: dict):
message_id = request["message_id"]
routing_key = request["routing_key"]

query = request["query"]
history = request["history"]
persona = request.get("persona", {})

try:
response = self.model.ask(message=query, chat_history=history,
persona=persona)
response = self.model.query_model(LLMRequest(**request))
except ValueError as err:
LOG.error(f'ValueError={err}')
response = LLMResponse(response='Sorry, but I cannot respond to your '
'message at the moment, please try again later',
history=request.get("history", []))
api_response = LLMProposeResponse(message_id=message_id,
response=response,
response=response.response,
routing_key=routing_key)
LOG.info(f"Sending response: {response}")
LOG.info(f"Sending response: {api_response}")
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
LOG.info(f"Handled ask request for message_id={message_id}")
@@ -149,76 +145,68 @@ def handle_score_request(self, body: dict):
Handles score requests (vote) from MQ to LLM
:param body: request body (dict)
"""
message_id = body["message_id"]
routing_key = body["routing_key"]

query = body["query"]
responses = body["responses"]
persona = body.get("persona", {})
request = LLMVoteRequest(**body)

if not responses:
if not request.responses:
sorted_answer_idx = []
else:
try:
sorted_answer_idx = self.model.get_sorted_answer_indexes(
question=query, answers=responses, persona=persona)
question=request.query, answers=request.responses,
persona=request.persona.model_dump())
except ValueError as err:
LOG.error(f'ValueError={err}')
sorted_answer_idx = []

api_response = LLMVoteResponse(message_id=message_id,
routing_key=routing_key,
api_response = LLMVoteResponse(message_id=request.message_id,
routing_key=request.routing_key,
sorted_answer_indexes=sorted_answer_idx)
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
LOG.info(f"Handled score request for message_id={message_id}")
queue=request.routing_key)
LOG.info(f"Handled score request for message_id={request.message_id}")

@create_mq_callback()
def handle_opinion_request(self, body: dict):
"""
Handles opinion requests (discuss) from MQ to LLM
:param body: request body (dict)
"""
message_id = body["message_id"]
routing_key = body["routing_key"]

query = body["query"]
options = body["options"]
persona = body.get("persona", {})
responses = list(options.values())
request = LLMDiscussRequest(**body)

if not responses:
if not request.options:
opinion = "Sorry, but I got no options to choose from."
else:
try:
sorted_answer_indexes = self.model.get_sorted_answer_indexes(
question=query, answers=responses, persona=persona)
best_respondent_nick, best_response = list(options.items())[
sorted_answer_indexes[0]]
question=request.query,
answers=list(request.options.values()),
persona=request.persona.model_dump())
best_respondent_nick, best_response = \
list(request.options.items())[sorted_answer_indexes[0]]
opinion = self._ask_model_for_opinion(
respondent_nick=best_respondent_nick,
question=query, answer=best_response, persona=persona)
llm_request=LLMRequest(**body), answer=best_response)
except ValueError as err:
LOG.error(f'ValueError={err}')
opinion = ("Sorry, but I experienced an issue trying to form "
"an opinion on this topic")

api_response = LLMDiscussResponse(message_id=message_id,
routing_key=routing_key,
api_response = LLMDiscussResponse(message_id=request.message_id,
routing_key=request.routing_key,
opinion=opinion)
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
LOG.info(f"Handled ask request for message_id={message_id}")
queue=request.routing_key)
LOG.info(f"Handled ask request for message_id={request.message_id}")

def _ask_model_for_opinion(self, respondent_nick: str, question: str,
answer: str, persona: dict) -> str:
prompt = self.compose_opinion_prompt(respondent_nick=respondent_nick,
question=question,
answer=answer)
opinion = self.model.ask(message=prompt, chat_history=[],
persona=persona)
LOG.info(f'Received LLM opinion={opinion}, prompt={prompt}')
return opinion
def _ask_model_for_opinion(self, llm_request: LLMRequest,
respondent_nick: str,
answer: str) -> str:
llm_request.query = self.compose_opinion_prompt(
respondent_nick=respondent_nick, question=llm_request.query,
answer=answer)
opinion = self.model.query_model(llm_request)
LOG.info(f'Received LLM opinion={opinion}, prompt={llm_request.query}')
return opinion.response

@staticmethod
@abstractmethod
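For orientation, a sketch of the kind of "discuss" (opinion) payload these handlers now parse into pydantic request models. The keys mirror those read in the handlers above; the concrete values, queue name, and acceptance of a `model` key are assumptions:

    body = {
        "message_id": "abc-123",
        "routing_key": "neon_llm_response",  # assumed response queue name
        "query": "What is the best pizza topping?",
        "options": {"user_1": "Pineapple", "user_2": "Mushrooms"},
        "history": [],
        "persona": {"name": "assistant"},  # Persona fields assumed
        "model": "gpt-3.5-turbo",
    }
    # handle_opinion_request parses this via LLMDiscussRequest(**body), ranks the
    # options with get_sorted_answer_indexes, then asks the model for an opinion
    # on the best-ranked option via query_model.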
3 changes: 2 additions & 1 deletion requirements/requirements.txt
@@ -1,6 +1,7 @@
# networking
neon-mq-connector>=0.7.2a1
neon_utils[sentry]==1.11.1a5
ovos-config~=0.0.10
ovos-config~=0.0
ovos-utils~=0.0
pydantic~=2.6
neon-data-models@git+https://github.com/neongeckocom/neon-data-models@FEAT_LLMRequestModels
