From 014c6e9fcdcb591e0526b78fc0e030bd50f91e98 Mon Sep 17 00:00:00 2001 From: m-barker Date: Wed, 6 Mar 2024 14:30:37 +0000 Subject: [PATCH] feat: tts question answering --- skills/src/lasr_skills/xml_question_answer.py | 5 +++ tasks/gpsr/launch/question_answer.launch | 3 +- tasks/gpsr/nodes/question_answer | 22 +++++++----- tasks/gpsr/src/gpsr/states/get_question.py | 35 +++++++++++++++++++ 4 files changed, 55 insertions(+), 10 deletions(-) create mode 100644 tasks/gpsr/src/gpsr/states/get_question.py diff --git a/skills/src/lasr_skills/xml_question_answer.py b/skills/src/lasr_skills/xml_question_answer.py index 7736ffc70..ca58bc64b 100644 --- a/skills/src/lasr_skills/xml_question_answer.py +++ b/skills/src/lasr_skills/xml_question_answer.py @@ -3,6 +3,7 @@ import rospy import smach import xml.etree.ElementTree as ET +from lasr_voice import Voice from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest @@ -64,4 +65,8 @@ def execute(self, userdata): except rospy.ServiceException as e: rospy.logwarn(f"Unable to perform Index Query. ({str(e)})") userdata.closest_answers = [] + voice = Voice() + voice.sync_tts( + "I'm sorry, I couldn't find an answer to your question. Please ask me another question." + ) return "failed" diff --git a/tasks/gpsr/launch/question_answer.launch b/tasks/gpsr/launch/question_answer.launch index 60128141b..87de56930 100644 --- a/tasks/gpsr/launch/question_answer.launch +++ b/tasks/gpsr/launch/question_answer.launch @@ -1,5 +1,4 @@ - @@ -17,7 +16,7 @@ type="question_answer" name="question_answer" output="screen" - args="--question $(arg question) --k $(arg k) --index_path $(arg index_path) --txt_path $(arg text_path) --xml_path $(arg xml_path)" + args="--k $(arg k) --index_path $(arg index_path) --txt_path $(arg text_path) --xml_path $(arg xml_path)" /> \ No newline at end of file diff --git a/tasks/gpsr/nodes/question_answer b/tasks/gpsr/nodes/question_answer index 9c715a460..fe00ce725 100644 --- a/tasks/gpsr/nodes/question_answer +++ b/tasks/gpsr/nodes/question_answer @@ -2,7 +2,9 @@ import rospy import argparse import smach +from lasr_voice import Voice from lasr_skills.xml_question_answer import XmlQuestionAnswer +from gpsr.states.get_question import GetQuestion class QuestionAnswerStateMachine(smach.StateMachine): @@ -12,7 +14,6 @@ class QuestionAnswerStateMachine(smach.StateMachine): outcomes=["succeeded", "failed"], output_keys=["closest_answers"], ) - self.userdata.query_sentence = input_data["question"] self.userdata.k = input_data["k"] self.userdata.index_path = input_data["index_path"] self.userdata.txt_path = input_data["txt_path"] @@ -20,10 +21,16 @@ class QuestionAnswerStateMachine(smach.StateMachine): print(self.userdata) with self: + smach.StateMachine.add( + "GET_QUESTION", + GetQuestion(), + transitions={"succeeded": "XML_QUESTION_ANSWER", "failed": "failed"}, + remapping={"question": "query_sentence"}, + ) smach.StateMachine.add( "XML_QUESTION_ANSWER", XmlQuestionAnswer(), - transitions={"succeeded": "succeeded", "failed": "failed"}, + transitions={"succeeded": "succeeded", "failed": "GET_QUESTION"}, remapping={ "query_sentence": "query_sentence", "k": "k", @@ -37,12 +44,6 @@ class QuestionAnswerStateMachine(smach.StateMachine): def parse_args() -> dict: parser = argparse.ArgumentParser(description="GPSR Question Answer") - parser.add_argument( - "--question", - type=str, - help="The question to query", - required=True, - ) parser.add_argument( "--k", type=int, @@ -78,10 +79,15 @@ if __name__ == "__main__": print(args) q_a_sm = QuestionAnswerStateMachine(args) outcome = q_a_sm.execute() + voice = Voice() if outcome == "succeeded": rospy.loginfo(f"Question: {args['question']}") rospy.loginfo(f"Closest Answers: {q_a_sm.userdata.closest_answers}") + voice.sync_tts( + f"The answer to your question is: {q_a_sm.userdata.closest_answers[0]}" + ) else: rospy.logerr("Question Answer State Machine failed") + voice.sync_tts(f"Sorry, I wasn't able to find an answer to your question") rospy.spin() diff --git a/tasks/gpsr/src/gpsr/states/get_question.py b/tasks/gpsr/src/gpsr/states/get_question.py new file mode 100644 index 000000000..e7719034c --- /dev/null +++ b/tasks/gpsr/src/gpsr/states/get_question.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +import smach +import rospy +import actionlib +from lasr_voice import Voice +from lasr_speech_recognition_msgs.msg import ( + TranscribeSpeechAction, + TranscribeSpeechGoal, +) + + +class GetQuestion(smach.State): + def __init__(self): + smach.State.__init__( + self, outcomes=["succeeded", "failed"], output_keys=["question"] + ) + self.voice = Voice() + self.client = actionlib.SimpleActionClient( + "transcribe_speech", TranscribeSpeechAction + ) + + def execute(self, userdata): + try: + self.client.wait_for_server() + self.voice.sync_tts("Hello, I hear you have a question for me, ask away!") + goal = TranscribeSpeechGoal() + self.client.send_goal(goal) + self.client.wait_for_result() + result = self.client.get_result() + text = result.sequence + userdata.question = text + return "succeeded" + except Exception as e: + rospy.loginfo(f"Failed to get question: {e}") + return "failed"