diff --git a/tasks/gpsr/src/gpsr/states/command_similarity_matcher.py b/tasks/gpsr/src/gpsr/states/command_similarity_matcher.py index 3982f9f67..b7df8875b 100755 --- a/tasks/gpsr/src/gpsr/states/command_similarity_matcher.py +++ b/tasks/gpsr/src/gpsr/states/command_similarity_matcher.py @@ -52,12 +52,12 @@ def execute(self, userdata): sm = smach.StateMachine(outcomes=["succeeded", "failed"]) with sm: sm.userdata.tts_phrase = "Please tell me your command." - # smach.StateMachine.add( - # "ASK_FOR_COMMAND", - # AskAndListen(), - # transitions={"succeeded": "COMMAND_SIMILARITY_MATCHER", "failed": "failed"}, - # remapping={"transcribed_speech": "command"}, - # ) + smach.StateMachine.add( + "ASK_FOR_COMMAND", + AskAndListen(), + transitions={"succeeded": "COMMAND_SIMILARITY_MATCHER", "failed": "failed"}, + remapping={"transcribed_speech": "command"}, + ) sm.add( "LISTEN", Listen(), @@ -71,18 +71,18 @@ def execute(self, userdata): sm.add( "COMMAND_SIMILARITY_MATCHER", CommandSimilarityMatcher([1177943] * 10), - transitions={"succeeded": "LISTEN"}, + transitions={"succeeded": "SAY_MATCHED_COMMAND", "failed": "failed"}, + ) + sm.add( + "SAY_MATCHED_COMMAND", + Say(), + transitions={ + "succeeded": "ASK_FOR_COMMAND", + "aborted": "failed", + "preempted": "failed", + }, + remapping={"text": "matched_command"}, ) - # smach.StateMachine.add( - # "SAY_MATCHED_COMMAND", - # Say(), - # transitions={ - # "succeeded": "ASK_FOR_COMMAND", - # "aborted": "failed", - # "preempted": "failed", - # }, - # remapping={"text": "matched_command"}, - # ) sm.execute() rospy.spin() diff --git a/tasks/receptionist/CMakeLists.txt b/tasks/receptionist/CMakeLists.txt index e99a3b839..373ffcf35 100644 --- a/tasks/receptionist/CMakeLists.txt +++ b/tasks/receptionist/CMakeLists.txt @@ -169,6 +169,7 @@ include_directories( catkin_install_python(PROGRAMS scripts/main.py scripts/test_find_and_look_at.py + scripts/create_name_and_drink_vector_db.py DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} ) diff --git a/tasks/receptionist/data/name_and_drink.index b/tasks/receptionist/data/name_and_drink.index new file mode 100644 index 000000000..4460434e9 Binary files /dev/null and b/tasks/receptionist/data/name_and_drink.index differ diff --git a/tasks/receptionist/data/name_and_drink_vector_db.txt b/tasks/receptionist/data/name_and_drink_vector_db.txt new file mode 100644 index 000000000..54b8581de --- /dev/null +++ b/tasks/receptionist/data/name_and_drink_vector_db.txt @@ -0,0 +1,70 @@ +My name is Adel and my favorite drink is cola +My name is Adel and my favorite drink is iced tea +My name is Adel and my favorite drink is juice pack +My name is Adel and my favorite drink is milk +My name is Adel and my favorite drink is orange juice +My name is Adel and my favorite drink is red wine +My name is Adel and my favorite drink is tropical juice +My name is Angel and my favorite drink is cola +My name is Angel and my favorite drink is iced tea +My name is Angel and my favorite drink is juice pack +My name is Angel and my favorite drink is milk +My name is Angel and my favorite drink is orange juice +My name is Angel and my favorite drink is red wine +My name is Angel and my favorite drink is tropical juice +My name is Axel and my favorite drink is cola +My name is Axel and my favorite drink is iced tea +My name is Axel and my favorite drink is juice pack +My name is Axel and my favorite drink is milk +My name is Axel and my favorite drink is orange juice +My name is Axel and my favorite drink is red wine +My name is Axel and my favorite drink is tropical juice +My name is Charlie and my favorite drink is cola +My name is Charlie and my favorite drink is iced tea +My name is Charlie and my favorite drink is juice pack +My name is Charlie and my favorite drink is milk +My name is Charlie and my favorite drink is orange juice +My name is Charlie and my favorite drink is red wine +My name is Charlie and my favorite drink is tropical juice +My name is Jane and my favorite drink is cola +My name is Jane and my favorite drink is iced tea +My name is Jane and my favorite drink is juice pack +My name is Jane and my favorite drink is milk +My name is Jane and my favorite drink is orange juice +My name is Jane and my favorite drink is red wine +My name is Jane and my favorite drink is tropical juice +My name is Jules and my favorite drink is cola +My name is Jules and my favorite drink is iced tea +My name is Jules and my favorite drink is juice pack +My name is Jules and my favorite drink is milk +My name is Jules and my favorite drink is orange juice +My name is Jules and my favorite drink is red wine +My name is Jules and my favorite drink is tropical juice +My name is Morgan and my favorite drink is cola +My name is Morgan and my favorite drink is iced tea +My name is Morgan and my favorite drink is juice pack +My name is Morgan and my favorite drink is milk +My name is Morgan and my favorite drink is orange juice +My name is Morgan and my favorite drink is red wine +My name is Morgan and my favorite drink is tropical juice +My name is Paris and my favorite drink is cola +My name is Paris and my favorite drink is iced tea +My name is Paris and my favorite drink is juice pack +My name is Paris and my favorite drink is milk +My name is Paris and my favorite drink is orange juice +My name is Paris and my favorite drink is red wine +My name is Paris and my favorite drink is tropical juice +My name is Robin and my favorite drink is cola +My name is Robin and my favorite drink is iced tea +My name is Robin and my favorite drink is juice pack +My name is Robin and my favorite drink is milk +My name is Robin and my favorite drink is orange juice +My name is Robin and my favorite drink is red wine +My name is Robin and my favorite drink is tropical juice +My name is Simone and my favorite drink is cola +My name is Simone and my favorite drink is iced tea +My name is Simone and my favorite drink is juice pack +My name is Simone and my favorite drink is milk +My name is Simone and my favorite drink is orange juice +My name is Simone and my favorite drink is red wine +My name is Simone and my favorite drink is tropical juice diff --git a/tasks/receptionist/scripts/create_name_and_drink_vector_db.py b/tasks/receptionist/scripts/create_name_and_drink_vector_db.py new file mode 100644 index 000000000..82f22d4ad --- /dev/null +++ b/tasks/receptionist/scripts/create_name_and_drink_vector_db.py @@ -0,0 +1,48 @@ +import os +from typing import List +import rospy +import rospkg +from lasr_vector_databases_msgs.srv import TxtIndex, TxtIndexRequest + + +def create_txt_file(output_path: str) -> None: + """Creates a txt file containing all permutations of + "My name is and my favorite drink is " + + Args: + output_path (str): Path to the output txt file + """ + + names: List[str] = rospy.get_param("/priors/names") + drinks: List[str] = rospy.get_param("/priors/drinks") + + with open(output_path, "w") as f: + for name in names: + for drink in drinks: + f.write(f"My name is {name} and my favorite drink is {drink}\n") + + +def create_vector_db(txt_path: str, output_path: str) -> None: + """Creates a vector database from a txt file containing + all permutations of "My name is and my favorite drink is " + + Args: + txt_path (str): Path to the txt file + output_path (str): Path to the output vector database + """ + + rospy.wait_for_service("lasr_faiss/txt_index") + txt_index = rospy.ServiceProxy("lasr_faiss/txt_index", TxtIndex) + + request = TxtIndexRequest() + request.txt_paths = [txt_path] + request.index_paths = [output_path] + request.index_factory_string = "Flat" + txt_index(request) + + +if __name__ == "__main__": + data_dir = os.path.join(rospkg.RosPack().get_path("receptionist"), "data") + txt_path = os.path.join(data_dir, "name_and_drink_vector_db.txt") + output_path = os.path.join(data_dir, "name_and_drink.index") + create_vector_db(txt_path, output_path) diff --git a/tasks/receptionist/src/receptionist/states/match_name_and_drink.py b/tasks/receptionist/src/receptionist/states/match_name_and_drink.py new file mode 100755 index 000000000..0982fa0a3 --- /dev/null +++ b/tasks/receptionist/src/receptionist/states/match_name_and_drink.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +import os +import rospy +import rospkg +import smach +from lasr_skills import AskAndListen, Say +from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest + + +class MatchNameAndDrink(smach.State): + def __init__(self): + smach.State.__init__( + self, + outcomes=["succeeded", "failed"], + input_keys=["sequence"], + output_keys=["matched_name", "matched_drink", "sequence"], + ) + + self._query_service = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery) + self._text_path = os.path.join( + rospkg.RosPack().get_path("receptionist"), + "data", + "name_and_drink_vector_db.txt", + ) + self._index_path = os.path.join( + rospkg.RosPack().get_path("receptionist"), "data", "name_and_drink.index" + ) + + def execute(self, userdata): + rospy.loginfo(f"Received transcript: {userdata.sequence}") + request = TxtQueryRequest() + request.txt_paths = [self._text_path] + request.index_paths = [self._index_path] + request.query_sentence = userdata.sequence + request.k = 1 + response = self._query_service(request) + matched_name, matched_drink = response.closest_sentences[0].split( + " and my favorite drink is " + ) + matched_name = matched_name.split("My name is ")[1] + userdata.matched_name = matched_name + userdata.matched_drink = matched_drink + rospy.loginfo( + f"Matched name: {matched_name} and matched drink: {matched_drink}" + ) + userdata.sequence = ( + f"Hello {matched_name}, I see that your favorite drink is {matched_drink}." + ) + return "succeeded" + + +if __name__ == "__main__": + rospy.init_node("match_name_and_drink") + sm = smach.StateMachine(outcomes=["succeeded", "failed"]) + with sm: + smach.StateMachine.add( + "ASK_FOR_NAME_AND_DRINK", + AskAndListen( + tts_phrase="Hello, please tell me your name and favorite drink." + ), + transitions={"succeeded": "MATCH_NAME_AND_DRINK", "failed": "failed"}, + remapping={"transcribed_speech": "sequence"}, + ) + smach.StateMachine.add( + "MATCH_NAME_AND_DRINK", + MatchNameAndDrink(), + transitions={"succeeded": "SAY_MATCHED_NAME_AND_DRINK", "failed": "failed"}, + remapping={"sequence": "sequence"}, + ) + smach.StateMachine.add( + "SAY_MATCHED_NAME_AND_DRINK", + Say(), + transitions={ + "succeeded": "succeeded", + "aborted": "failed", + "preempted": "failed", + }, + remapping={ + "text": "sequence", + }, + ) + sm.execute()