diff --git a/neon_data_models/models/api/llm.py b/neon_data_models/models/api/llm.py index b164a12..5b5ab58 100644 --- a/neon_data_models/models/api/llm.py +++ b/neon_data_models/models/api/llm.py @@ -25,7 +25,7 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from typing import List, Tuple, Optional, Literal -from pydantic import Field, model_validator +from pydantic import Field, model_validator, computed_field from neon_data_models.models.base import BaseModel @@ -40,6 +40,11 @@ class LLMPersona(BaseModel): system_prompt: str = Field( None, description="System prompt associated with this persona. " "If None, `description` will be used.") + enabled: bool = Field( + True, description="Flag used to mark a defined persona as " + "available for use.") + user_id: Optional[str] = Field( + None, description="`user_id` of the user who created this persona.") @model_validator(mode='after') def validate_request(self): @@ -48,6 +53,14 @@ def validate_request(self): self.system_prompt = self.description return self + @computed_field + @property + def id(self) -> str: + persona_id = self.name + if self.user_id: + persona_id += f"_{self.user_id}" + return persona_id + class LLMRequest(BaseModel): query: str = Field(description="Incoming user prompt") diff --git a/neon_data_models/models/api/mq.py b/neon_data_models/models/api/mq.py index c8bcd73..7cc5712 100644 --- a/neon_data_models/models/api/mq.py +++ b/neon_data_models/models/api/mq.py @@ -24,7 +24,7 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Literal, Optional +from typing import Literal, Optional, Dict, List from pydantic import Field from neon_data_models.models.api.llm import LLMRequest, LLMPersona @@ -40,7 +40,7 @@ class UserDbRequest(MQContext): user: Optional[User] = None -class MqLlmRequest(MQContext, LLMRequest): +class LLMProposeRequest(MQContext, LLMRequest): model: Optional[str] = Field( default=None, description="MQ implementation defines `model` as optional because the " @@ -52,4 +52,30 @@ class MqLlmRequest(MQContext, LLMRequest): "LLM module.") -__all__ = [UserDbRequest.__name__, MqLlmRequest.__name__] +class LLMProposeResponse(MQContext): + response: str = Field(description="LLM response to the prompt") + + +class LLMDiscussRequest(LLMProposeRequest): + options: Dict[str, str] = Field( + description="Mapping of participant name to response to be discussed.") + + +class LLMDiscussResponse(MQContext): + opinion: str = Field(description="LLM response to the available options.") + + +class LLMVoteRequest(LLMProposeRequest): + responses: List[str] = Field( + description="List of responses to choose from.") + + +class LLMVoteResponse(MQContext): + sorted_answer_indexes: List[int] = Field( + description="Indices of `responses` ordered high to low by preference.") + + +__all__ = [UserDbRequest.__name__, LLMProposeRequest.__name__, + LLMProposeResponse.__name__, LLMDiscussRequest.__name__, + LLMDiscussResponse.__name__, LLMVoteRequest.__name__, + LLMVoteResponse.__name__] diff --git a/tests/models/api/test_mq.py b/tests/models/api/test_mq.py index e7afa5d..000a20a 100644 --- a/tests/models/api/test_mq.py +++ b/tests/models/api/test_mq.py @@ -44,7 +44,7 @@ def test_user_db_request(self): UserDbRequest(operation="create", username="test_user") def test_mq_llm_request(self): - from neon_data_models.models.api.mq import MqLlmRequest + from neon_data_models.models.api.mq import LLMProposeRequest from neon_data_models.models.api.llm import LLMRequest from neon_data_models.models.base.contexts import MQContext @@ -55,17 +55,17 @@ def test_mq_llm_request(self): message_id = "test_mid" # Valid fully-defined - valid_request = MqLlmRequest(query=query, history=history, - persona=persona, model=model_name, - message_id=message_id) - self.assertIsInstance(valid_request, MqLlmRequest) + valid_request = LLMProposeRequest(query=query, history=history, + persona=persona, model=model_name, + message_id=message_id) + self.assertIsInstance(valid_request, LLMProposeRequest) self.assertIsInstance(valid_request, LLMRequest) self.assertIsInstance(valid_request, MQContext) # Valid backwards-compat (no model or persona) - backwards_compat = MqLlmRequest(query=query, history=history, - message_id=message_id) - self.assertIsInstance(backwards_compat, MqLlmRequest) + backwards_compat = LLMProposeRequest(query=query, history=history, + message_id=message_id) + self.assertIsInstance(backwards_compat, LLMProposeRequest) self.assertIsInstance(backwards_compat, LLMRequest) self.assertIsInstance(backwards_compat, MQContext) self.assertIsNone(backwards_compat.model) @@ -73,13 +73,13 @@ def test_mq_llm_request(self): # Invalid Persona defined with self.assertRaises(ValidationError): - MqLlmRequest(query=query, history=history, message_id=message_id, - persona={}) + LLMProposeRequest(query=query, history=history, message_id=message_id, + persona={}) # Invalid MQ Context with self.assertRaises(ValidationError): - MqLlmRequest(query=query, history=history) + LLMProposeRequest(query=query, history=history) # Invalid LLM Request with self.assertRaises(ValidationError): - MqLlmRequest(history=history, message_id=message_id) + LLMProposeRequest(history=history, message_id=message_id)