Skip to content

Commit

Permalink
Implement request/response models for all LLM messaging
Browse files · Browse the repository at this point in the history
Update LLMPersona to be compatible with model in `neon-llm-core`
Refactor `MqLlmRequest` to `LLMProposeRequest` to be consistent with Chatbotsforum terminology
  • Loading branch information
NeonDaniel committed Nov 16, 2024
1 parent fdfc8e9 commit 20c4853
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 16 deletions.
15 changes: 14 additions & 1 deletion neon_data_models/models/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import List, Tuple, Optional, Literal
from pydantic import Field, model_validator
from pydantic import Field, model_validator, computed_field

from neon_data_models.models.base import BaseModel

Expand All @@ -40,6 +40,11 @@ class LLMPersona(BaseModel):
system_prompt: str = Field(
None, description="System prompt associated with this persona. "
"If None, `description` will be used.")
enabled: bool = Field(
True, description="Flag used to mark a defined persona as "
"available for use.")
user_id: Optional[str] = Field(
None, description="`user_id` of the user who created this persona.")

@model_validator(mode='after')
def validate_request(self):
Expand All @@ -48,6 +53,14 @@ def validate_request(self):
self.system_prompt = self.description
return self

@computed_field
@property
def id(self) -> str:
    """
    Unique identifier for this persona: the persona `name`, suffixed with
    `_<user_id>` when the persona was created by a specific user.
    """
    if self.user_id:
        return f"{self.name}_{self.user_id}"
    return self.name


class LLMRequest(BaseModel):
query: str = Field(description="Incoming user prompt")
Expand Down
32 changes: 29 additions & 3 deletions neon_data_models/models/api/mq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Literal, Optional
from typing import Literal, Optional, Dict, List
from pydantic import Field

from neon_data_models.models.api.llm import LLMRequest, LLMPersona
Expand All @@ -40,7 +40,7 @@ class UserDbRequest(MQContext):
user: Optional[User] = None


class MqLlmRequest(MQContext, LLMRequest):
class LLMProposeRequest(MQContext, LLMRequest):
model: Optional[str] = Field(
default=None,
description="MQ implementation defines `model` as optional because the "
Expand All @@ -52,4 +52,30 @@ class MqLlmRequest(MQContext, LLMRequest):
"LLM module.")


__all__ = [UserDbRequest.__name__, MqLlmRequest.__name__]
class LLMProposeResponse(MQContext):
    """
    MQ response to an `LLMProposeRequest`, carrying the LLM's reply to the
    submitted prompt.
    """
    # Raw text generated by the LLM for the originating prompt
    response: str = Field(description="LLM response to the prompt")


class LLMDiscussRequest(LLMProposeRequest):
    """
    MQ request for an LLM to discuss a set of candidate responses
    (the "discuss" phase in Chatbotsforum terminology).
    """
    # Keys are participant names; values are the responses they proposed
    options: Dict[str, str] = Field(
        description="Mapping of participant name to response to be discussed.")


class LLMDiscussResponse(MQContext):
    """
    MQ response to an `LLMDiscussRequest`, carrying the LLM's opinion on the
    supplied options.
    """
    # Free-form discussion text produced by the LLM about the options
    opinion: str = Field(description="LLM response to the available options.")


class LLMVoteRequest(LLMProposeRequest):
    """
    MQ request for an LLM to rank a list of candidate responses
    (the "vote" phase in Chatbotsforum terminology).
    """
    # Candidate answers the LLM is asked to rank
    responses: List[str] = Field(
        description="List of responses to choose from.")


class LLMVoteResponse(MQContext):
    """
    MQ response to an `LLMVoteRequest`, carrying the LLM's ranking of the
    candidate responses.
    """
    # Indexes into the request's `responses` list, best-first
    sorted_answer_indexes: List[int] = Field(
        description="Indices of `responses` ordered high to low by preference.")


# Public API of this module: the user DB request plus the MQ request/response
# pairs for each LLM interaction phase (propose / discuss / vote).
__all__ = [UserDbRequest.__name__, LLMProposeRequest.__name__,
           LLMProposeResponse.__name__, LLMDiscussRequest.__name__,
           LLMDiscussResponse.__name__, LLMVoteRequest.__name__,
           LLMVoteResponse.__name__]
24 changes: 12 additions & 12 deletions tests/models/api/test_mq.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_user_db_request(self):
UserDbRequest(operation="create", username="test_user")

def test_mq_llm_request(self):
from neon_data_models.models.api.mq import MqLlmRequest
from neon_data_models.models.api.mq import LLMProposeRequest
from neon_data_models.models.api.llm import LLMRequest
from neon_data_models.models.base.contexts import MQContext

Expand All @@ -55,31 +55,31 @@ def test_mq_llm_request(self):
message_id = "test_mid"

# Valid fully-defined
valid_request = MqLlmRequest(query=query, history=history,
persona=persona, model=model_name,
message_id=message_id)
self.assertIsInstance(valid_request, MqLlmRequest)
valid_request = LLMProposeRequest(query=query, history=history,
persona=persona, model=model_name,
message_id=message_id)
self.assertIsInstance(valid_request, LLMProposeRequest)
self.assertIsInstance(valid_request, LLMRequest)
self.assertIsInstance(valid_request, MQContext)

# Valid backwards-compat (no model or persona)
backwards_compat = MqLlmRequest(query=query, history=history,
message_id=message_id)
self.assertIsInstance(backwards_compat, MqLlmRequest)
backwards_compat = LLMProposeRequest(query=query, history=history,
message_id=message_id)
self.assertIsInstance(backwards_compat, LLMProposeRequest)
self.assertIsInstance(backwards_compat, LLMRequest)
self.assertIsInstance(backwards_compat, MQContext)
self.assertIsNone(backwards_compat.model)
self.assertIsNone(backwards_compat.persona)

# Invalid Persona defined
with self.assertRaises(ValidationError):
MqLlmRequest(query=query, history=history, message_id=message_id,
persona={})
LLMProposeRequest(query=query, history=history, message_id=message_id,
persona={})

# Invalid MQ Context
with self.assertRaises(ValidationError):
MqLlmRequest(query=query, history=history)
LLMProposeRequest(query=query, history=history)

# Invalid LLM Request
with self.assertRaises(ValidationError):
MqLlmRequest(history=history, message_id=message_id)
LLMProposeRequest(history=history, message_id=message_id)

0 comments on commit 20c4853

Please sign in to comment.