Skip to content

Commit

Permalink
Merge pull request #1 from Pingdred/Validate_threshold
Browse files Browse the repository at this point in the history
Validate memory threshold
  • Loading branch information
nicola-corbellini authored Oct 31, 2023
2 parents c6ab51c + 3c22984 commit 008f965
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from cat.mad_hatter.decorators import plugin
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

def validate_threshold(value):
if value < 0 or value > 1:
return False

return True


class MySettings(BaseModel):
Expand All @@ -12,13 +18,31 @@ class MySettings(BaseModel):
extra={"type": "TextArea"}
)
episodic_memory_k: int = 3
episodic_memory_threshold: int = 0.7
episodic_memory_threshold: float = 0.7
declarative_memory_k: int = 3
declarative_memory_threshold: int = 0.7
declarative_memory_threshold: float = 0.7
procedural_memory_k: int = 3
procedural_memory_threshold: int = 0.7
procedural_memory_threshold: float = 0.7

@field_validator("episodic_memory_threshold")
@classmethod
def episodic_memory_threshold_validator(cls, threshold):
if not validate_threshold(threshold):
raise ValueError("Episodic memory threshold must be between 0 and 1")

@field_validator("declarative_memory_threshold")
@classmethod
def declarative_memory_threshold_validator(cls, threshold):
if not validate_threshold(threshold):
raise ValueError("Declarative memory threshold must be between 0 and 1")

@field_validator("procedural_memory_threshold")
@classmethod
def procedural_memory_threshold_validator(cls, threshold):
if not validate_threshold(threshold):
raise ValueError("Procedural memory threshold must be between 0 and 1")


@plugin
def settings_schema():
return MySettings.schema()
def settings_model():
return MySettings

0 comments on commit 008f965

Please sign in to comment.