diff --git a/neon_llm_core/llm.py b/neon_llm_core/llm.py index a4c9244..37235eb 100644 --- a/neon_llm_core/llm.py +++ b/neon_llm_core/llm.py @@ -23,8 +23,13 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional, Tuple, Union + +from neon_data_models.models.api import LLMRequest, LLMResponse +from openai import OpenAI +from ovos_utils.log import log_deprecation class NeonLLM(ABC): @@ -36,8 +41,6 @@ def __init__(self, config: dict): @param config: Dict LLM configuration for this specific LLM """ self._llm_config = config - self._tokenizer = None - self._model = None @property def llm_config(self): @@ -48,75 +51,127 @@ def llm_config(self): @property @abstractmethod - def tokenizer(self): + def tokenizer(self) -> Optional[object]: + """ + Get a Tokenizer object for the loaded model, if available. + :return: optional transformers.PreTrainedTokenizerBase object + """ pass @property @abstractmethod def tokenizer_model_name(self) -> str: + """ + Get a string tokenizer model name (i.e. a Huggingface `model id`) + associated with `self.tokenizer`. + """ pass @property @abstractmethod - def model(self): + def model(self) -> OpenAI: + """ + Get an OpenAI client object to send requests to. + """ pass @property @abstractmethod def llm_model_name(self) -> str: + """ + Get a string model name for the configured `model` + """ pass @property @abstractmethod def _system_prompt(self) -> str: + """ + Get a default string system prompt to use when not included in requests + """ pass def ask(self, message: str, chat_history: List[List[str]], persona: dict) -> str: - """ Generates llm response based on user message and (user, llm) chat history """ + """ + Generates llm response based on user message and (user, llm) chat history + """ + log_deprecation("This method is replaced by `query_model` which " + "accepts a single `LLMRequest` arg", "1.0.0") prompt = self._assemble_prompt(message, chat_history, persona) llm_text_output = self._call_model(prompt) return llm_text_output + def query_model(self, request: LLMRequest) -> LLMResponse: + """ + Calls `self._assemble_prompt` to allow subclass to modify the input + query and then passes the updated query to `self._call_model` + :param request: LLMRequest object to generate a response to + :return: + """ + if request.model != self.llm_model_name: + raise ValueError(f"Requested model ({request.model}) is not this " + f"model ({self.llm_model_name}") + request.query = self._assemble_prompt(request.query, request.history, + request.persona.model_dump()) + response = self._call_model(request.query, request) + history = request.history + [("llm", response)] + return LLMResponse(response=response, history=history) + @abstractmethod def get_sorted_answer_indexes(self, question: str, answers: List[str], persona: dict) -> List[int]: """ - Creates sorted list of answer indexes with respect to order provided in :param answers - Results should be sorted from best to worst - :param question: incoming question - :param answers: list of answers to rank - :returns list of indexes + Creates sorted list of answer indexes with respect to order provided in + `answers`. Results should be sorted from best to worst + :param question: incoming question + :param answers: list of answers to rank + :param persona: dict representation of Persona to use for sorting + :returns list of indexes """ pass @abstractmethod - def _call_model(self, prompt: str) -> str: + def _call_model(self, prompt: str, + request: Optional[LLMRequest] = None) -> str: """ Wrapper for Model generation logic. This method may be called asynchronously, so it is up to the extending class to use locks or queue inputs as necessary. :param prompt: Input text sequence + :param request: Optional LLMRequest object containing parameters to + include in model requests :returns: Output text sequence generated by model """ pass @abstractmethod - def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict): + def _assemble_prompt(self, message: str, + chat_history: List[Union[List[str], Tuple[str, str]]], + persona: dict) -> str: """ - Assembles prompt engineering logic - - :param message: Incoming prompt - :param chat_history: History of preceding conversation - :returns: assembled prompt + Assemble the prompt to send to the LLM + :param message: Input prompt to optionally modify + :param chat_history: History of preceding conversation + :param persona: dict representation of Persona that is requested + :returns: assembled prompt string """ pass @abstractmethod def _tokenize(self, prompt: str) -> List[str]: + """ + Tokenize the input prompt into a list of strings + :param prompt: Input to tokenize + :return: Tokenized representation of input prompt + """ pass @classmethod def convert_role(cls, role: str) -> str: - """ Maps MQ role to LLM's internal domain """ + """ + Maps MQ role to LLM's internal domain + :param role: Role in Neon LLM format + :return: Role in LLM internal format + """ matching_llm_role = cls.mq_to_llm_role.get(role) if not matching_llm_role: raise ValueError(f"role={role} is undefined, supported are: " diff --git a/neon_llm_core/rmq.py b/neon_llm_core/rmq.py index 6d87951..3444b11 100644 --- a/neon_llm_core/rmq.py +++ b/neon_llm_core/rmq.py @@ -28,12 +28,13 @@ from threading import Thread from typing import Optional +from neon_data_models.models.api import LLMRequest from neon_mq_connector.connector import MQConnector from neon_mq_connector.utils.rabbit_utils import create_mq_callback from neon_utils.logger import LOG from neon_data_models.models.api.mq import (LLMProposeResponse, - LLMDiscussResponse, LLMVoteResponse) + LLMDiscussResponse, LLMVoteResponse, LLMDiscussRequest, LLMVoteRequest) from neon_llm_core.utils.config import load_config from neon_llm_core.llm import NeonLLM @@ -123,21 +124,16 @@ def _handle_request_async(self, request: dict): message_id = request["message_id"] routing_key = request["routing_key"] - query = request["query"] - history = request["history"] - persona = request.get("persona", {}) - try: - response = self.model.ask(message=query, chat_history=history, - persona=persona) + response = self.model.query_model(LLMRequest(**request)) except ValueError as err: LOG.error(f'ValueError={err}') response = ('Sorry, but I cannot respond to your message at the ' 'moment, please try again later') api_response = LLMProposeResponse(message_id=message_id, - response=response, + response=response.response, routing_key=routing_key) - LOG.info(f"Sending response: {response}") + LOG.info(f"Sending response: {api_response}") self.send_message(request_data=api_response.model_dump(), queue=routing_key) LOG.info(f"Handled ask request for message_id={message_id}") @@ -149,29 +145,25 @@ def handle_score_request(self, body: dict): Handles score requests (vote) from MQ to LLM :param body: request body (dict) """ - message_id = body["message_id"] - routing_key = body["routing_key"] - - query = body["query"] - responses = body["responses"] - persona = body.get("persona", {}) + request = LLMVoteRequest(**body) - if not responses: + if not request.responses: sorted_answer_idx = [] else: try: sorted_answer_idx = self.model.get_sorted_answer_indexes( - question=query, answers=responses, persona=persona) + question=request.query, answers=request.responses, + persona=request.persona.model_dump()) except ValueError as err: LOG.error(f'ValueError={err}') sorted_answer_idx = [] - api_response = LLMVoteResponse(message_id=message_id, - routing_key=routing_key, + api_response = LLMVoteResponse(message_id=request.message_id, + routing_key=request.routing_key, sorted_answer_indexes=sorted_answer_idx) self.send_message(request_data=api_response.model_dump(), - queue=routing_key) - LOG.info(f"Handled score request for message_id={message_id}") + queue=request.routing_key) + LOG.info(f"Handled score request for message_id={request.message_id}") @create_mq_callback() def handle_opinion_request(self, body: dict): @@ -179,46 +171,42 @@ def handle_opinion_request(self, body: dict): Handles opinion requests (discuss) from MQ to LLM :param body: request body (dict) """ - message_id = body["message_id"] - routing_key = body["routing_key"] - - query = body["query"] - options = body["options"] - persona = body.get("persona", {}) - responses = list(options.values()) + request = LLMDiscussRequest(**body) - if not responses: + if not request.options: opinion = "Sorry, but I got no options to choose from." else: try: sorted_answer_indexes = self.model.get_sorted_answer_indexes( - question=query, answers=responses, persona=persona) - best_respondent_nick, best_response = list(options.items())[ - sorted_answer_indexes[0]] + question=request.query, + answers=list(request.options.values()), + persona=request.persona.model_dump()) + best_respondent_nick, best_response = \ + list(request.options.items())[sorted_answer_indexes[0]] opinion = self._ask_model_for_opinion( respondent_nick=best_respondent_nick, - question=query, answer=best_response, persona=persona) + llm_request=LLMRequest(**body), answer=best_response) except ValueError as err: LOG.error(f'ValueError={err}') opinion = ("Sorry, but I experienced an issue trying to form " "an opinion on this topic") - api_response = LLMDiscussResponse(message_id=message_id, - routing_key=routing_key, + api_response = LLMDiscussResponse(message_id=request.message_id, + routing_key=request.routing_key, opinion=opinion) self.send_message(request_data=api_response.model_dump(), - queue=routing_key) - LOG.info(f"Handled ask request for message_id={message_id}") + queue=request.routing_key) + LOG.info(f"Handled ask request for message_id={request.message_id}") - def _ask_model_for_opinion(self, respondent_nick: str, question: str, - answer: str, persona: dict) -> str: - prompt = self.compose_opinion_prompt(respondent_nick=respondent_nick, - question=question, - answer=answer) - opinion = self.model.ask(message=prompt, chat_history=[], - persona=persona) - LOG.info(f'Received LLM opinion={opinion}, prompt={prompt}') - return opinion + def _ask_model_for_opinion(self, llm_request: LLMRequest, + respondent_nick: str, + answer: str) -> str: + llm_request.query = self.compose_opinion_prompt( + respondent_nick=respondent_nick, question=llm_request.query, + answer=answer) + opinion = self.model.query_model(llm_request) + LOG.info(f'Received LLM opinion={opinion}, prompt={llm_request.query}') + return opinion.response @staticmethod @abstractmethod diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 95be90f..c025d80 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,7 @@ # networking neon-mq-connector>=0.7.2a1 neon_utils[sentry]==1.11.1a5 -ovos-config~=0.0.10 +ovos-config~=0.0 +ovos-utils~=0.0 pydantic~=2.6 neon-data-models@git+https://github.com/neongeckocom/neon-data-models@FEAT_LLMRequestModels \ No newline at end of file diff --git a/tests/test_llm.py b/tests/test_llm.py index 8a6a92f..16a7cea 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -25,8 +25,104 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from unittest import TestCase +from unittest.mock import Mock +from neon_data_models.models.api import LLMResponse + +from neon_llm_core.llm import NeonLLM + + +class MockLLM(NeonLLM): + + mq_to_llm_role = {"user": "user", + "llm": "assistant"} + + def __init__(self, *args, **kwargs): + NeonLLM.__init__(self, *args, **kwargs) + self._assemble_prompt = Mock(return_value=lambda *args: args[0]) + self._tokenize = Mock(return_value=lambda *args: args[0]) + self.get_sorted_answer_indexes = Mock(return_value=lambda *args: [i for i in range(len(args))]) + self._call_model = Mock(return_value="mock model response") + + @property + def tokenizer(self): + return None + + @property + def tokenizer_model_name(self) -> str: + return "mock_tokenizer" + + @property + def model(self): + return Mock() + + @property + def llm_model_name(self) -> str: + return "mock_model" + + @property + def _system_prompt(self) -> str: + return "mock system prompt" class TestNeonLLM(TestCase): - # TODO - pass + MockLLM.__abstractmethods__ = set() + config = {"test_config": True} + test_llm = MockLLM(config) + + def setUp(self): + self.test_llm._assemble_prompt.reset_mock() + self.test_llm._tokenize.reset_mock() + self.test_llm.get_sorted_answer_indexes.reset_mock() + self.test_llm._call_model.reset_mock() + + def test_init(self): + self.assertEqual(self.test_llm.llm_config, self.config) + self.assertIsNone(self.test_llm.tokenizer) + self.assertIsInstance(self.test_llm.tokenizer_model_name, str) + self.assertIsNotNone(self.test_llm.model) + self.assertIsInstance(self.test_llm.llm_model_name, str) + self.assertIsInstance(self.test_llm._system_prompt, str) + + def test_ask(self): + from neon_data_models.models.api import LLMPersona + message = "Test input" + history = [["user", "hello"], ["llm", "Hello. How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + + # Valid request + response = self.test_llm.ask(message, history, persona.model_dump()) + self.assertEqual(response, self.test_llm._call_model.return_value) + self.test_llm._assemble_prompt.assert_called_once_with(message, history, + persona.model_dump()) + self.test_llm._call_model.assert_called_once_with(self.test_llm._assemble_prompt.return_value) + + def test_query_model(self): + from neon_data_models.models.api import LLMPersona, LLMRequest + message = "Test input" + history = [["user", "hello"], ["llm", "Hello. How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + valid_request = LLMRequest(query=message, history=history, + persona=persona, + model=self.test_llm.llm_model_name) + response = self.test_llm.query_model(valid_request) + self.assertIsInstance(response, LLMResponse) + self.assertEqual(response.response, self.test_llm._call_model.return_value) + self.assertEqual(response.history[-1], + ("llm", self.test_llm._call_model.return_value)) + self.assertEqual(len(response.history), 3) + self.test_llm._assemble_prompt.assert_called_once_with( + message, valid_request.history, persona.model_dump()) + self.test_llm._call_model.assert_called_once_with( + self.test_llm._assemble_prompt.return_value, valid_request) + + # Request for a different model will raise an exception + invalid_request = LLMRequest(query=message, history=history, + persona=persona, model="invalid_model") + with self.assertRaises(ValueError): + self.test_llm.query_model(invalid_request) + + def test_convert_role(self): + self.assertEqual(self.test_llm.convert_role("user"), "user") + self.assertEqual(self.test_llm.convert_role("llm"), "assistant") + with self.assertRaises(ValueError): + self.test_llm.convert_role("assistant")