Skip to content

Commit

Permalink
from_json method (#37)
Browse files Browse the repository at this point in the history
* feat : Response class from_json method added

* feat : CustomPromptTemplate class from_json method added

* feat : Prompt class from_json method added

* fix : test_response.py updated

* fix : tests updated

* fix : autopep8

* fix : minor edits

* fix : test_prompt.py updated

* fix : extra space bug fixed
  • Loading branch information
sepandhaghighi authored Feb 5, 2025
1 parent aa83d2c commit 761e7c7
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 48 deletions.
6 changes: 3 additions & 3 deletions memor/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
INVALID_RESPONSES_MESSAGE = "Invalid responses. It must be a list of `Response` objects."
INVALID_CUSTOM_MAP_MESSAGE = "Invalid custom map: it must be a dictionary with keys and values that can be converted to strings."
INVALID_ROLE_MESSAGE = "Invalid role. It must be an instance of Role enum."
INVALID_TEMPLATE_FILE_MESSAGE = "Invalid template file. It should be a JSON file with proper fields."
INVALID_PROMPT_FILE_MESSAGE = "Invalid prompt file. It should be a JSON file with proper fields."
INVALID_RESPONSE_FILE_MESSAGE = "Invalid response file. It should be a JSON file with proper fields."
INVALID_TEMPLATE_STRUCTURE_MESSAGE = "Invalid template structure. It should be a JSON object with proper fields."
INVALID_PROMPT_STRUCTURE_MESSAGE = "Invalid prompt structure. It should be a JSON object with proper fields."
INVALID_RESPONSE_STRUCTURE_MESSAGE = "Invalid response structure. It should be a JSON object with proper fields."
INVALID_RENDER_FORMAT_MESSAGE = "Invalid render format. It must be an instance of PromptRenderFormat enum."
PROMPT_RENDER_ERROR_MESSAGE = "Prompt template and properties are incompatible."
DATA_SAVE_SUCCESS_MESSAGE = "Everything seems good."
Expand Down
57 changes: 37 additions & 20 deletions memor/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .params import DATE_TIME_FORMAT
from .params import PromptRenderFormat, DATA_SAVE_SUCCESS_MESSAGE
from .params import Role
from .params import INVALID_PROMPT_FILE_MESSAGE, INVALID_TEMPLATE_MESSAGE
from .params import INVALID_PROMPT_STRUCTURE_MESSAGE, INVALID_TEMPLATE_MESSAGE
from .params import INVALID_ROLE_MESSAGE, INVALID_RESPONSE_MESSAGE
from .params import PROMPT_RENDER_ERROR_MESSAGE, INVALID_RESPONSES_MESSAGE
from .params import INVALID_RENDER_FORMAT_MESSAGE
Expand Down Expand Up @@ -236,7 +236,7 @@ def save(self, file_path, save_template=True):
try:
with open(file_path, "w") as file:
data = self.to_json(save_template=save_template)
file.write(data, indent=4)
file.write(data)
except Exception as e:
result["status"] = False
result["message"] = str(e)
Expand All @@ -252,21 +252,38 @@ def load(self, file_path):
"""
validate_path(file_path)
with open(file_path, "r") as file:
try:
loaded_obj = json.loads(file.read())
self._message = loaded_obj["message"]
self._responses = loaded_obj["responses"]
self._role = Role(loaded_obj["role"])
self._template = PresetPromptTemplate.DEFAULT.value
if "template" in loaded_obj:
self._template = CustomPromptTemplate(**loaded_obj["template"])
self._memor_version = loaded_obj["memor_version"]
self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
self._selected_response_index = loaded_obj["selected_response_index"]
self.select_response(index=self._selected_response_index)
except Exception:
raise MemorValidationError(INVALID_PROMPT_FILE_MESSAGE)
self.from_json(file.read())

def from_json(self, json_doc):
"""
Load attributes from the JSON document.
:param json_doc: JSON document
:type json_doc: str
:return: None
"""
try:
loaded_obj = json.loads(json_doc)
self._message = loaded_obj["message"]
responses = []
for response in loaded_obj["responses"]:
response_obj = Response()
response_obj.from_json(response)
responses.append(response_obj)
self._responses = responses
self._role = Role(loaded_obj["role"])
self._template = PresetPromptTemplate.DEFAULT.value
if "template" in loaded_obj:
template_obj = CustomPromptTemplate()
template_obj.from_json(loaded_obj["template"])
self._template = template_obj
self._memor_version = loaded_obj["memor_version"]
self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
self._selected_response_index = loaded_obj["selected_response_index"]
self.select_response(index=self._selected_response_index)
except Exception:
raise MemorValidationError(INVALID_PROMPT_STRUCTURE_MESSAGE)

def to_json(self, save_template=True):
"""
Expand All @@ -278,10 +295,10 @@ def to_json(self, save_template=True):
"""
data = self.to_dict(save_template=save_template)
for index, response in enumerate(data["responses"]):
data["responses"][index] = response.to_dict()
data["responses"][index] = response.to_json()
if "template" in data:
data["template"] = data["template"].to_dict()
data["role"] = str(data["role"])
data["template"] = data["template"].to_json()
data["role"] = data["role"].value
data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
return json.dumps(data, indent=4)
Expand Down
38 changes: 24 additions & 14 deletions memor/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .params import MEMOR_VERSION
from .params import DATE_TIME_FORMAT
from .params import DATA_SAVE_SUCCESS_MESSAGE
from .params import INVALID_RESPONSE_FILE_MESSAGE
from .params import INVALID_RESPONSE_STRUCTURE_MESSAGE
from .params import INVALID_ROLE_MESSAGE
from .params import Role
from .errors import MemorValidationError
Expand Down Expand Up @@ -204,18 +204,28 @@ def load(self, file_path):
"""
validate_path(file_path)
with open(file_path, "r") as file:
try:
loaded_obj = json.loads(file.read())
self._message = loaded_obj["message"]
self._score = loaded_obj["score"]
self._temperature = loaded_obj["temperature"]
self._model = loaded_obj["model"]
self._role = Role(loaded_obj["role"])
self._memor_version = loaded_obj["memor_version"]
self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
except Exception:
raise MemorValidationError(INVALID_RESPONSE_FILE_MESSAGE)
self.from_json(file.read())

def from_json(self, json_doc):
"""
Load attributes from the JSON document.
:param json_doc: JSON document
:type json_doc: str
:return: None
"""
try:
loaded_obj = json.loads(json_doc)
self._message = loaded_obj["message"]
self._score = loaded_obj["score"]
self._temperature = loaded_obj["temperature"]
self._model = loaded_obj["model"]
self._role = Role(loaded_obj["role"])
self._memor_version = loaded_obj["memor_version"]
self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
except Exception:
raise MemorValidationError(INVALID_RESPONSE_STRUCTURE_MESSAGE)

def to_json(self):
"""
Expand All @@ -226,7 +236,7 @@ def to_json(self):
data = self.to_dict()
data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
data["role"] = str(data["role"])
data["role"] = data["role"].value
return json.dumps(data, indent=4)

def to_dict(self):
Expand Down
32 changes: 21 additions & 11 deletions memor/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from .params import DATE_TIME_FORMAT
from .params import DATA_SAVE_SUCCESS_MESSAGE
from .params import INVALID_TEMPLATE_FILE_MESSAGE
from .params import INVALID_TEMPLATE_STRUCTURE_MESSAGE
from .params import MEMOR_VERSION
from .errors import MemorValidationError
from .functions import get_time_utc
Expand Down Expand Up @@ -157,16 +157,26 @@ def load(self, file_path):
"""
validate_path(file_path)
with open(file_path, "r") as file:
try:
loaded_obj = json.loads(file.read())
self._content = loaded_obj["content"]
self._title = loaded_obj["title"]
self._memor_version = loaded_obj["memor_version"]
self._custom_map = loaded_obj["custom_map"]
self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
except Exception:
raise MemorValidationError(INVALID_TEMPLATE_FILE_MESSAGE)
self.from_json(file.read())

def from_json(self, json_doc):
"""
Load attributes from the JSON document.
:param json_doc: JSON document
:type json_doc: str
:return: None
"""
try:
loaded_obj = json.loads(json_doc)
self._content = loaded_obj["content"]
self._title = loaded_obj["title"]
self._memor_version = loaded_obj["memor_version"]
self._custom_map = loaded_obj["custom_map"]
self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
except Exception:
raise MemorValidationError(INVALID_TEMPLATE_STRUCTURE_MESSAGE)

def to_json(self):
"""
Expand Down
33 changes: 33 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,39 @@ def test_repr():
assert repr(prompt) == "Prompt(message={message})".format(message=prompt.message)


def test_json():
message = "Hello, how are you?"
response1 = Response(message="I am fine.", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
response2 = Response(message="Thanks!", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
prompt1 = Prompt(
message=message,
responses=[
response1,
response2],
role=Role.USER,
template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
prompt1_json = prompt1.to_json()
prompt2 = Prompt()
prompt2.from_json(prompt1_json)
assert prompt1 == prompt2


def test_load():
message = "Hello, how are you?"
response1 = Response(message="I am fine.", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
response2 = Response(message="Thanks!", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
prompt1 = Prompt(
message=message,
responses=[
response1,
response2],
role=Role.USER,
template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
result = prompt1.save("prompt_test1.json")
prompt2 = Prompt(file_path="prompt_test1.json")
assert result["status"] and prompt1 == prompt2


def test_equality1():
message = "Hello, how are you?"
response1 = Response(message="I am fine.", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def test_date():
assert response.date_created == date_time_utc


def test_json():
response1 = Response(message="I am fine.", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
response1_json = response1.to_json()
response2 = Response()
response2.from_json(response1_json)
assert response1 == response2


def test_save():
response = Response(message="I am fine.", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
result = response.save("response_test1.json")
Expand All @@ -75,6 +83,13 @@ def test_save():
assert result["status"] and json.loads(response.to_json()) == saved_response


def test_load():
response1 = Response(message="I am fine.", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
result = response1.save("response_test2.json")
response2 = Response(file_path="response_test2.json")
assert result["status"] and response1 == response2


def test_copy1():
response1 = Response(message="I am fine.", model="GPT-4", temperature=0.5, role=Role.USER, score=0.8)
response2 = copy.copy(response1)
Expand Down

0 comments on commit 761e7c7

Please sign in to comment.