Skip to content

Commit

Permalink
feat: tts question answering
Browse files Browse the repository at this point in the history
  • Loading branch information
m-barker committed Mar 6, 2024
1 parent b012e23 commit 014c6e9
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 10 deletions.
5 changes: 5 additions & 0 deletions skills/src/lasr_skills/xml_question_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import rospy
import smach
import xml.etree.ElementTree as ET
from lasr_voice import Voice

from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest

Expand Down Expand Up @@ -64,4 +65,8 @@ def execute(self, userdata):
except rospy.ServiceException as e:
rospy.logwarn(f"Unable to perform Index Query. ({str(e)})")
userdata.closest_answers = []
voice = Voice()
voice.sync_tts(
"I'm sorry, I couldn't find an answer to your question. Please ask me another question."
)
return "failed"
3 changes: 1 addition & 2 deletions tasks/gpsr/launch/question_answer.launch
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
<launch>
<arg name="question" default='"Do the French like snails?"'/>
<arg name="k" default="1"/>
<arg name="index_path" default="/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.index"/>
<arg name="text_path" default="/home/mattbarker/LASR/lasr_ws/src/lasr-base/tasks/gpsr/data/questions.txt"/>
Expand All @@ -17,7 +16,7 @@
type="question_answer"
name="question_answer"
output="screen"
args="--question $(arg question) --k $(arg k) --index_path $(arg index_path) --txt_path $(arg text_path) --xml_path $(arg xml_path)"
args="--k $(arg k) --index_path $(arg index_path) --txt_path $(arg text_path) --xml_path $(arg xml_path)"
/>

</launch>
22 changes: 14 additions & 8 deletions tasks/gpsr/nodes/question_answer
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import rospy
import argparse
import smach
from lasr_voice import Voice
from lasr_skills.xml_question_answer import XmlQuestionAnswer
from gpsr.states.get_question import GetQuestion


class QuestionAnswerStateMachine(smach.StateMachine):
Expand All @@ -12,18 +14,23 @@ class QuestionAnswerStateMachine(smach.StateMachine):
outcomes=["succeeded", "failed"],
output_keys=["closest_answers"],
)
self.userdata.query_sentence = input_data["question"]
self.userdata.k = input_data["k"]
self.userdata.index_path = input_data["index_path"]
self.userdata.txt_path = input_data["txt_path"]
self.userdata.xml_path = input_data["xml_path"]
print(self.userdata)

with self:
smach.StateMachine.add(
"GET_QUESTION",
GetQuestion(),
transitions={"succeeded": "XML_QUESTION_ANSWER", "failed": "failed"},
remapping={"question": "query_sentence"},
)
smach.StateMachine.add(
"XML_QUESTION_ANSWER",
XmlQuestionAnswer(),
transitions={"succeeded": "succeeded", "failed": "failed"},
transitions={"succeeded": "succeeded", "failed": "GET_QUESTION"},
remapping={
"query_sentence": "query_sentence",
"k": "k",
Expand All @@ -37,12 +44,6 @@ class QuestionAnswerStateMachine(smach.StateMachine):

def parse_args() -> dict:
parser = argparse.ArgumentParser(description="GPSR Question Answer")
parser.add_argument(
"--question",
type=str,
help="The question to query",
required=True,
)
parser.add_argument(
"--k",
type=int,
Expand Down Expand Up @@ -78,10 +79,15 @@ if __name__ == "__main__":
print(args)
q_a_sm = QuestionAnswerStateMachine(args)
outcome = q_a_sm.execute()
voice = Voice()
if outcome == "succeeded":
    # The question is no longer passed on the command line; it is captured by
    # the GET_QUESTION state and stored in the state machine's userdata under
    # "query_sentence", so read it from there rather than from args.
    rospy.loginfo(f"Question: {q_a_sm.userdata.query_sentence}")
    closest_answers = q_a_sm.userdata.closest_answers
    rospy.loginfo(f"Closest Answers: {closest_answers}")
    if closest_answers:
        voice.sync_tts(f"The answer to your question is: {closest_answers[0]}")
    else:
        # Defensive: never index into an empty answer list.
        rospy.logerr("Question Answer State Machine returned no answers")
        voice.sync_tts("Sorry, I wasn't able to find an answer to your question")
else:
    rospy.logerr("Question Answer State Machine failed")
    voice.sync_tts("Sorry, I wasn't able to find an answer to your question")

# Keep the node alive after the answer has been spoken.
rospy.spin()
35 changes: 35 additions & 0 deletions tasks/gpsr/src/gpsr/states/get_question.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
import smach
import rospy
import actionlib
from lasr_voice import Voice
from lasr_speech_recognition_msgs.msg import (
TranscribeSpeechAction,
TranscribeSpeechGoal,
)


class GetQuestion(smach.State):
    """Smach state that prompts the user for a question and transcribes it.

    Outcomes:
        succeeded: a transcription result was received and stored in
            ``userdata.question``.
        failed: the action produced no result, or an exception was raised
            while prompting or transcribing.

    Output keys:
        question: the ``sequence`` field of the transcription result.
    """

    def __init__(self):
        smach.State.__init__(
            self, outcomes=["succeeded", "failed"], output_keys=["question"]
        )
        self.voice = Voice()
        self.client = actionlib.SimpleActionClient(
            "transcribe_speech", TranscribeSpeechAction
        )

    def execute(self, userdata):
        """Ask for a question via TTS, then run the transcription action.

        Args:
            userdata: smach userdata; ``question`` is written on success.

        Returns:
            ``"succeeded"`` or ``"failed"``.
        """
        try:
            self.client.wait_for_server()
            self.voice.sync_tts("Hello, I hear you have a question for me, ask away!")
            self.client.send_goal(TranscribeSpeechGoal())
            self.client.wait_for_result()
            result = self.client.get_result()
            if result is None:
                # get_result() yields None when the goal finished without a
                # result; report that explicitly instead of relying on an
                # AttributeError from ``None.sequence``.
                rospy.logwarn(
                    "Failed to get question: transcription returned no result"
                )
                return "failed"
            userdata.question = result.sequence
            return "succeeded"
        except Exception as e:
            # Broad catch is deliberate: an exception escaping a state would
            # abort the whole state machine instead of yielding "failed".
            rospy.logwarn(f"Failed to get question: {e}")
            return "failed"

0 comments on commit 014c6e9

Please sign in to comment.