Skip to content

Commit

Permalink
Improve experience starting reva-chat
Browse files Browse the repository at this point in the history
Handle when no file is open, file is currently opening and
better handle old connection files.

Now we have a nice emoji and instructions!
  • Loading branch information
cyberkaida committed Jul 8, 2024
1 parent bf4445c commit 2b5595d
Showing 1 changed file with 63 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/usr/bin/env python3
import asyncio

from concurrent.futures import thread
import queue
from re import M
from typing import Generator, List, Tuple
from typing import Generator, List, Optional, Tuple
from uuid import uuid4
from prompt_toolkit import PromptSession
import prompt_toolkit
Expand All @@ -11,6 +12,8 @@
import argparse
from pathlib import Path
import random
import time


import threading

Expand Down Expand Up @@ -69,20 +72,60 @@ def get_thinking_emoji() -> str:
"🧙‍♀️",
])

def find_connectable_extensions() -> Generator[Tuple[Path, str, str], None, None]:
def find_connectable_extensions() -> List[Tuple[Path, str, str]]:
reva_temp_directory = os.path.join(Path.home(), '.reva')
reva_temp = Path(reva_temp_directory)
connection_details: List[Tuple[Path, str, str]] = []
if reva_temp.exists():
for file in reva_temp.glob("reva-connection-*.connection"):
connection_string = file.read_text()
content = connection_string.split(":")
if len(content) == 2:
yield file, content[0], content[1]
connection_details.append((file, content[0], content[1]))
else:
# If the connection string is the wrong format, we will remove it
# this is to clean anything that dropped bad content in the directory
logger.warning(f"Invalid connection string: {connection_string}. Cleaning.")
file.unlink()
return connection_details

async def check_connectivity(file: Path, host: str, port: str) -> Optional[RevaHeartbeatResponse]:
# First try a heartbeat to see if the connection is still alive
retries = 10
logger.debug(f"Checking connection to {host}:{port} from {file}")
for _ in range(retries + 1):
try:
channel = grpc.insecure_channel(f"{host}:{port}")
stub = RevaHeartbeatStub(channel)
response = stub.heartbeat(RevaHeartbeatRequest())
logger.info(f"Found connectable extension: {response}")
return response
except grpc.RpcError as e:
if retries > 0:
retries -= 1
logger.warning(f"Failed to connect to {host}:{port}. Retrying in 1 second.")
await asyncio.sleep(1)
else:
# If we can't connect, clean it up
logger.debug(f"Removing old connection file: {file}")
file.unlink()
return None

async def get_active_extensions(connection_details) -> List[RevaHeartbeatResponse]:
connectable_extensions = []

tasks = []
for file, host, port in connection_details:
tasks.append(check_connectivity(file, host, port))

logger.debug(f"Checking {len(tasks)} connectable extensions")
for result in asyncio.as_completed(tasks):
response = await result
logger.debug(f"Got response: {response}")
if response:
connectable_extensions.append(response)

return connectable_extensions

def main():
parser = argparse.ArgumentParser(description="Reva Chat Client")
Expand Down Expand Up @@ -119,27 +162,23 @@ def main():
console = Console(record=True)
if not args.port:
connectable_extensions: List[RevaHeartbeatResponse] = []
for file, host, port in find_connectable_extensions():
# First try a heartbeat to see if the connection is still alive
retries = 10
for _ in range(2):
try:
channel = grpc.insecure_channel(f"{host}:{port}")
stub = RevaHeartbeatStub(channel)
response = stub.heartbeat(RevaHeartbeatRequest())
connectable_extensions.append(response)
logger.info(f"Found connectable extension: {response}")

with console.status(f"🐉 Searching for Ghidra extension..."):
connection_details = find_connectable_extensions()
logger.debug(f"Found connection details: {connection_details}")

connectable_extensions = asyncio.run(get_active_extensions(connection_details), debug=args.debug)
if not connectable_extensions:
console.print("Please open a file in Ghidra to start the extension.")
# If there are no connection details, let's assume the user is starting Ghidra
for _ in range(30):
if not connectable_extensions:
time.sleep(1)
connection_details = find_connectable_extensions()
connectable_extensions = asyncio.run(get_active_extensions(connection_details), debug=args.debug)
else:
break
except grpc.RpcError as e:
if retries > 0:
import time
retries -= 1
logger.warning(f"Failed to connect to {host}:{port}. Retrying in 1 second.")
time.sleep(1)
else:
# If we can't connect, clean it up
logger.debug(f"Removing old connection file: {file}")
file.unlink()

if len(connectable_extensions) == 0:
logger.error("No connectable extensions found. Is Ghidra running? Is the extension enabled?")
parser.error("No connectable extensions found. Is Ghidra running? Is the extension enabled?")
Expand Down

0 comments on commit 2b5595d

Please sign in to comment.