diff --git a/reverse-engineering-assistant/reverse_engineering_assistant/chat_client.py b/reverse-engineering-assistant/reverse_engineering_assistant/chat_client.py index 95f44e0..8cf6f67 100644 --- a/reverse-engineering-assistant/reverse_engineering_assistant/chat_client.py +++ b/reverse-engineering-assistant/reverse_engineering_assistant/chat_client.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 +import asyncio + from concurrent.futures import thread import queue -from re import M -from typing import Generator, List, Tuple +from typing import Generator, List, Optional, Tuple from uuid import uuid4 from prompt_toolkit import PromptSession import prompt_toolkit @@ -11,6 +12,8 @@ import argparse from pathlib import Path import random +import time + import threading @@ -69,20 +72,60 @@ def get_thinking_emoji() -> str: "🧙‍♀️", ]) -def find_connectable_extensions() -> Generator[Tuple[Path, str, str], None, None]: +def find_connectable_extensions() -> List[Tuple[Path, str, str]]: reva_temp_directory = os.path.join(Path.home(), '.reva') reva_temp = Path(reva_temp_directory) + connection_details: List[Tuple[Path, str, str]] = [] if reva_temp.exists(): for file in reva_temp.glob("reva-connection-*.connection"): connection_string = file.read_text() content = connection_string.split(":") if len(content) == 2: - yield file, content[0], content[1] + connection_details.append((file, content[0], content[1])) else: # If the connection string is the wrong format, we will remove it # this is to clean anything that dropped bad content in the directory logger.warning(f"Invalid connection string: {connection_string}. Cleaning.") file.unlink() + return connection_details + +async def check_connectivity(file: Path, host: str, port: str) -> Optional[RevaHeartbeatResponse]: + # First try a heartbeat to see if the connection is still alive + retries = 10 + logger.debug(f"Checking connection to {host}:{port} from {file}") + for _ in range(retries + 1): + try: + channel = grpc.insecure_channel(f"{host}:{port}") + stub = RevaHeartbeatStub(channel) + response = stub.heartbeat(RevaHeartbeatRequest()) + logger.info(f"Found connectable extension: {response}") + return response + except grpc.RpcError as e: + if retries > 0: + retries -= 1 + logger.warning(f"Failed to connect to {host}:{port}. Retrying in 1 second.") + await asyncio.sleep(1) + else: + # If we can't connect, clean it up + logger.debug(f"Removing old connection file: {file}") + file.unlink() + return None + +async def get_active_extensions(connection_details) -> List[RevaHeartbeatResponse]: + connectable_extensions = [] + + tasks = [] + for file, host, port in connection_details: + tasks.append(check_connectivity(file, host, port)) + + logger.debug(f"Checking {len(tasks)} connectable extensions") + for result in asyncio.as_completed(tasks): + response = await result + logger.debug(f"Got response: {response}") + if response: + connectable_extensions.append(response) + + return connectable_extensions def main(): parser = argparse.ArgumentParser(description="Reva Chat Client") @@ -119,27 +162,23 @@ def main(): console = Console(record=True) if not args.port: connectable_extensions: List[RevaHeartbeatResponse] = [] - for file, host, port in find_connectable_extensions(): - # First try a heartbeat to see if the connection is still alive - retries = 10 - for _ in range(2): - try: - channel = grpc.insecure_channel(f"{host}:{port}") - stub = RevaHeartbeatStub(channel) - response = stub.heartbeat(RevaHeartbeatRequest()) - connectable_extensions.append(response) - logger.info(f"Found connectable extension: {response}") + + with console.status(f"🐉 Searching for Ghidra extension..."): + connection_details = find_connectable_extensions() + logger.debug(f"Found connection details: {connection_details}") + + connectable_extensions = asyncio.run(get_active_extensions(connection_details), debug=args.debug) + if not connectable_extensions: + console.print("Please open a file in Ghidra to start the extension.") + # If there are no connection details, let's assume the user is starting Ghidra + for _ in range(30): + if not connectable_extensions: + time.sleep(1) + connection_details = find_connectable_extensions() + connectable_extensions = asyncio.run(get_active_extensions(connection_details), debug=args.debug) + else: break - except grpc.RpcError as e: - if retries > 0: - import time - retries -= 1 - logger.warning(f"Failed to connect to {host}:{port}. Retrying in 1 second.") - time.sleep(1) - else: - # If we can't connect, clean it up - logger.debug(f"Removing old connection file: {file}") - file.unlink() + if len(connectable_extensions) == 0: logger.error("No connectable extensions found. Is Ghidra running? Is the extension enabled?") parser.error("No connectable extensions found. Is Ghidra running? Is the extension enabled?")