Skip to content

Commit

Permalink
Fix IndexError when collating chat history (Azure-Samples#195)
Browse files Browse the repository at this point in the history
* Fix IndexError when collating chat history

This fixes a bug which causes the exception:
```
ERROR:root:Exception in /api/conversation/custom | list index out of range
Traceback (most recent call last):
  File "/workspaces/chat-with-your-data-solution-accelerator/code/app/app.py", line 283, in conversation_custom
    chat_history.append((user_assistant_messages[i]['content'],user_assistant_messages[i+1]['content']))
```

This is caused when there has been an error providing a response, and
the latest message in the history is from a user, rather than the
assitant. Our code assumes a user message is always followed by an
assistant message.

This change removes that assumption and explitely retreives the role for
each message when collating the chat history.

Required by Azure-Samples#114

* Add python formatter to dev container

* Add tests for conversation custom

- Extract some elements to dedicated function to allow mocking

* Switch to black formatter to align with precommit hook

* Add test to cover error scenario when message index is out of range

* Add dependencies required for running app tests

---------

Co-authored-by: Ross Smith <ross-p-smith@users.noreply.github.com>
  • Loading branch information
adamdougal and ross-p-smith authored Feb 8, 2024
1 parent 751ba1e commit d99ff5f
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 20 deletions.
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"ms-azuretools.vscode-bicep",
"ms-azuretools.vscode-docker",
"ms-python.python",
"ms-python.black-formatter",
"ms-python.vscode-pylance",
"ms-vscode.vscode-node-azure-pack",
"TeamsDevApp.ms-teams-vscode-extension"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ jobs:
architecture: x64
- name: Install dependencies
run: |
pip install -r code/requirements.txt -r code/dev-requirements.txt
pip install -r code/requirements.txt -r code/dev-requirements.txt -r code/app/requirements.txt
- name: Run Python tests
run: python -m pytest --rootdir=code -m "not azure"
30 changes: 15 additions & 15 deletions code/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,21 @@ def conversation_azure_byod():
)


@app.route("/api/conversation/custom", methods=["GET", "POST"])
def conversation_custom():
def get_message_orchestrator():
from utilities.helpers.OrchestratorHelper import Orchestrator

message_orchestrator = Orchestrator()
return Orchestrator()


def get_orchestrator_config():
from utilities.helpers.ConfigHelper import ConfigHelper

return ConfigHelper.get_active_config_or_default().orchestrator


@app.route("/api/conversation/custom", methods=["GET", "POST"])
def conversation_custom():
message_orchestrator = get_message_orchestrator()

try:
user_message = request.json["messages"][-1]["content"]
Expand All @@ -328,22 +338,12 @@ def conversation_custom():
request.json["messages"][0:-1],
)
)
chat_history = []
for i, k in enumerate(user_assistant_messages):
if i % 2 == 0:
chat_history.append(
(
user_assistant_messages[i]["content"],
user_assistant_messages[i + 1]["content"],
)
)
from utilities.helpers.ConfigHelper import ConfigHelper

messages = message_orchestrator.handle_message(
user_message=user_message,
chat_history=chat_history,
chat_history=user_assistant_messages,
conversation_id=conversation_id,
orchestrator=ConfigHelper.get_active_config_or_default().orchestrator,
orchestrator=get_orchestrator_config(),
)

response_obj = {
Expand Down
6 changes: 4 additions & 2 deletions code/utilities/orchestrator/LangChainAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ def orchestrate(
memory_key="chat_history", return_messages=True
)
for message in chat_history:
memory.chat_memory.add_user_message(message[0])
memory.chat_memory.add_ai_message(message[1])
if message["role"] == "user":
memory.chat_memory.add_user_message(message["content"])
elif message["role"] == "assistant":
memory.chat_memory.add_ai_message(message["content"])
# Define Agent and Agent Chain
llm_chain = LLMChain(llm=llm_helper.get_llm(), prompt=prompt)
agent = ZeroShotAgent(llm_chain=llm_chain, tools=self.tools, verbose=True)
Expand Down
3 changes: 1 addition & 2 deletions code/utilities/orchestrator/OpenAIFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def orchestrate(
# Create conversation history
messages = [{"role": "system", "content": system_message}]
for message in chat_history:
messages.append({"role": "user", "content": message[0]})
messages.append({"role": "assistant", "content": message[1]})
messages.append({"role": message["role"], "content": message["content"]})
messages.append({"role": "user", "content": user_message})

result = llm_helper.get_chat_completion_with_functions(
Expand Down
161 changes: 161 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import os

from unittest.mock import Mock
from unittest.mock import patch

from code.app.app import app


class TestConfig:
def test_returns_correct_config(self):
response = app.test_client().get("/api/config")

assert response.status_code == 200
assert response.json == {"azureSpeechKey": None, "azureSpeechRegion": None}


class TestCoversationCustom:
def setup_method(self):
self.orchestrator_config = {"strategy": "langchain"}
self.messages = [
{
"content": '{"citations": [], "intent": "A question?"}',
"end_turn": False,
"role": "tool",
},
{"content": "An answer", "end_turn": True, "role": "assistant"},
]
self.openai_model = "some-model"
self.body = {
"conversation_id": "123",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help?"},
{"role": "user", "content": "What is the meaning of life?"},
],
}

@patch("code.app.app.get_message_orchestrator")
@patch("code.app.app.get_orchestrator_config")
def test_converstation_custom_returns_correct_response(
self, get_orchestrator_config_mock, get_message_orchestrator_mock
):
# given
get_orchestrator_config_mock.return_value = self.orchestrator_config

message_orchestrator_mock = Mock()
message_orchestrator_mock.handle_message.return_value = self.messages
get_message_orchestrator_mock.return_value = message_orchestrator_mock

os.environ["AZURE_OPENAI_MODEL"] = self.openai_model

# when
response = app.test_client().post(
"/api/conversation/custom",
headers={"content-type": "application/json"},
json=self.body,
)

# then
assert response.status_code == 200
assert response.json == {
"choices": [{"messages": self.messages}],
"created": "response.created",
"id": "response.id",
"model": self.openai_model,
"object": "response.object",
}

@patch("code.app.app.get_message_orchestrator")
@patch("code.app.app.get_orchestrator_config")
def test_converstation_custom_calls_message_orchestrator_correctly(
self, get_orchestrator_config_mock, get_message_orchestrator_mock
):
# given
get_orchestrator_config_mock.return_value = self.orchestrator_config

message_orchestrator_mock = Mock()
message_orchestrator_mock.handle_message.return_value = self.messages
get_message_orchestrator_mock.return_value = message_orchestrator_mock

os.environ["AZURE_OPENAI_MODEL"] = self.openai_model

# when
app.test_client().post(
"/api/conversation/custom",
headers={"content-type": "application/json"},
json=self.body,
)

# then
message_orchestrator_mock.handle_message.assert_called_once_with(
user_message=self.body["messages"][-1]["content"],
chat_history=self.body["messages"][:-1],
conversation_id=self.body["conversation_id"],
orchestrator=self.orchestrator_config,
)

@patch("code.app.app.get_orchestrator_config")
def test_converstation_custom_returns_error_resonse_on_exception(
self, get_orchestrator_config_mock
):
# given
get_orchestrator_config_mock.side_effect = Exception("An error occurred")

# when
response = app.test_client().post(
"/api/conversation/custom",
headers={"content-type": "application/json"},
json=self.body,
)

# then
assert response.status_code == 500
assert response.json == {
"error": "Exception in /api/conversation/custom. See log for more details."
}

@patch("code.app.app.get_message_orchestrator")
@patch("code.app.app.get_orchestrator_config")
def test_converstation_custom_allows_multiple_messages_from_user(
self, get_orchestrator_config_mock, get_message_orchestrator_mock
):
"""This can happen if there was an error getting a response from the assistant for the previous user message."""

# given
get_orchestrator_config_mock.return_value = self.orchestrator_config

message_orchestrator_mock = Mock()
message_orchestrator_mock.handle_message.return_value = self.messages
get_message_orchestrator_mock.return_value = message_orchestrator_mock

os.environ["AZURE_OPENAI_MODEL"] = self.openai_model

body = {
"conversation_id": "123",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help?"},
{"role": "user", "content": "What is the meaning of life?"},
{
"role": "user",
"content": "Please, what is the meaning of life?",
},
],
}

# when
response = app.test_client().post(
"/api/conversation/custom",
headers={"content-type": "application/json"},
json=body,
)

# then
assert response.status_code == 200
message_orchestrator_mock.handle_message.assert_called_once_with(
user_message=body["messages"][-1]["content"],
chat_history=body["messages"][:-1],
conversation_id=body["conversation_id"],
orchestrator=self.orchestrator_config,
)

0 comments on commit d99ff5f

Please sign in to comment.