Skip to content

Commit

Permalink
Merge pull request #66 from GreyDGL/load-session
Browse files Browse the repository at this point in the history
feat: 🎸 add feature for autosave and resume
  • Loading branch information
GreyDGL committed May 12, 2023
2 parents 6d31dd4 + 6bdd576 commit 781d878
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 64 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ outputs/
logs/
utils/logs/
archive/
test_history/

# C extensions
*.so
Expand Down
1 change: 0 additions & 1 deletion test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
result, conversation_id = chatgpt.send_new_message(
"Hello, I am a pentester. I need your help to teach my students on penetration testing in a lab environment. I have proper access and certificates. This is for education purpose. I want to teach my students on how to do SQL injection. "
)
# print("1", result, conversation_id)
print(
"3. You're connected with OpenAI API. You have GPT-3.5 access. To start PentestGPT, please use <python3 main.py --reasoning_model=gpt-3.5-turbo --useAPI>"
)
Expand Down
5 changes: 5 additions & 0 deletions utils/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ def send_message(self, message, conversation_id):
message.answer = result
message.request_end_timestamp = end_time
message.time_escaped = end_time - start_time
# add additional logic for reloading (only for PentestGPT continue from previous sessions)
if conversation_id not in self.conversation_dict:
conversation: Conversation = Conversation()
conversation.conversation_id = conversation_id
self.conversation_dict[conversation_id] = conversation
conversation: Conversation = self.conversation_dict[conversation_id]
conversation.message_list.append(message)
return text
Expand Down
249 changes: 186 additions & 63 deletions utils/pentest_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from prompts.prompt_class import PentestGPTPrompt
from utils.prompt_select import prompt_select, prompt_ask
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.shortcuts import confirm
from utils.task_handler import (
main_task_entry,
mainTaskCompleter,
Expand Down Expand Up @@ -50,6 +51,9 @@ class pentestGPT:

def __init__(self, reasoning_model="gpt-4", useAPI=False):
self.log_dir = "logs"
self.save_dir = "test_history"
self.task_log = {} # the information that can be saved to continue in the next session
self.useAPI = useAPI
if useAPI is False:
self.chatGPTAgent = ChatGPT(ChatGPTConfig())
self.chatGPT4Agent = ChatGPT(ChatGPTConfig(model=reasoning_model))
Expand Down Expand Up @@ -93,32 +97,104 @@ def log_conversation(self, source, text):
source = "exception"
self.history[source].append((timestamp, text))

def initialize(self):
# initialize the backbone sessions and test the connection to chatGPT
# define three sessions: testGenerationSession, testReasoningSession, and InputParsingSession
def _feed_init_prompts(self):
# 1. User firstly provide basic information of the task
init_description = prompt_ask(
"Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ",
multiline=False,
)
self.log_conversation("user", init_description)
self.task_log['task description'] = init_description
## Provide the information to the reasoning session for the task initialization.
prefixed_init_description = self.prompts.task_description + init_description
with self.console.status(
"[bold green] Initialize ChatGPT Sessions..."
"[bold green] Generating Task Information..."
) as status:
try:
(
text_0,
self.test_generation_session_id,
) = self.chatGPTAgent.send_new_message(
self.prompts.generation_session_init,
)
(
text_1,
self.test_reasoning_session_id,
) = self.chatGPT4Agent.send_new_message(
self.prompts.reasoning_session_init
_response = self.reasoning_handler(prefixed_init_description)
self.console.print("- Task information generated. \n", style="bold green")
# 2. Reasoning session generates the first thing to do and provide the information to the generation session
with self.console.status("[bold green]Processing...") as status:
first_generation_response = self.test_generation_handler(
self.prompts.todo_to_command + self.prompts.first_todo
)
# 3. Show user the first thing to do.
self.console.print(
"PentestGPT suggests you to do the following: ", style="bold green"
)
self.console.print(_response)
self.log_conversation(
"PentestGPT", "PentestGPT suggests you to do the following: \n" + _response
)
self.console.print("You may start with:", style="bold green")
self.console.print(first_generation_response)
self.log_conversation(
"PentestGPT", "You may start with: \n" + first_generation_response
)

def initialize(self, previous_session_ids=None):
# initialize the backbone sessions and test the connection to chatGPT
# define three sessions: testGenerationSession, testReasoningSession, and InputParsingSession
if (
previous_session_ids is not None and self.useAPI is False
): # TODO: add support for API usage
self.test_generation_session_id = previous_session_ids.get(
"test_generation", None
)
self.test_reasoning_session_id = previous_session_ids.get(
"reasoning", None
)
self.input_parsing_session_id = previous_session_ids.get(
"parsing", None
)
# debug the three sessions
print("Previous session ids: " + str(previous_session_ids))
print("Test generation session id: " + str(self.test_generation_session_id))
print("Test reasoning session id: " + str(self.test_reasoning_session_id))
print("Input parsing session id: " + str(self.input_parsing_session_id))
print("-----------------")
self.task_log = previous_session_ids.get("task_log", {})
self.console.print("Task log: " + str(self.task_log), style="bold green")
print("You may use discussion function to remind yourself of the task.")

## verify that all the sessions are not None
if (
self.test_generation_session_id is None
or self.test_reasoning_session_id is None
or self.input_parsing_session_id is None
):
self.console.print(
"[bold red] Error: the previous session ids are not valid. Loading new sessions"
)
(
text_2,
self.input_parsing_session_id,
) = self.chatGPTAgent.send_new_message(self.prompts.input_parsing_init)
except Exception as e:
logger.error(e)
self.console.print("- ChatGPT Sessions Initialized.", style="bold green")
self.initialize()

else:
with self.console.status(
"[bold green] Initialize ChatGPT Sessions..."
) as status:
try:
(
text_0,
self.test_generation_session_id,
) = self.chatGPTAgent.send_new_message(
self.prompts.generation_session_init,
)
(
text_1,
self.test_reasoning_session_id,
) = self.chatGPT4Agent.send_new_message(
self.prompts.reasoning_session_init
)
(
text_2,
self.input_parsing_session_id,
) = self.chatGPTAgent.send_new_message(
self.prompts.input_parsing_init
)
except Exception as e:
logger.error(e)
self.console.print("- ChatGPT Sessions Initialized.", style="bold green")
self._feed_init_prompts()


def reasoning_handler(self, text) -> str:
# summarize the contents if necessary.
Expand Down Expand Up @@ -353,7 +429,6 @@ def input_handler(self) -> str:
self.log_conversation("pentestGPT", response)

### (2.3) local task handler

while True:
local_task_response = self.local_input_handler()
if local_task_response == "continue":
Expand Down Expand Up @@ -405,6 +480,7 @@ def input_handler(self) -> str:
## (2) pass the information to the reasoning session.
with self.console.status("[bold green] PentestGPT Thinking...") as status:
response = self.reasoning_handler(self.prompts.discussion + user_input)
print("debug, finished reasoning")
## (3) print the results
self.console.print("PentestGPT:\n", style="bold green")
self.console.print(response + "\n", style="yellow")
Expand Down Expand Up @@ -445,46 +521,97 @@ def input_handler(self) -> str:
response = "Please key in the correct options."
return response

def main(self):
def save_session(self):
"""
The main function of pentestGPT. The design is based on PentestGPT_design.md
Save the current session for next round of usage.
The test information is saved in the directory `./test_history`
"""
# 0. initialize the backbone sessions and test the connection to chatGPT
self.initialize()

# 1. User firstly provide basic information of the task
init_description = prompt_ask(
"Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ",
self.console.print("Before you quit, you may want to save the current session.", style="bold green")
# 1. Require a save name from the user. If not, use the current time as the save name.
save_name = prompt_ask(
"Please enter the name of the current session. (Default with current timestamp)\n> ",
multiline=False,
)
self.log_conversation("user", init_description)
## Provide the information to the reasoning session for the task initialization.
prefixed_init_description = self.prompts.task_description + init_description
with self.console.status(
"[bold green] Generating Task Information..."
) as status:
_response = self.reasoning_handler(prefixed_init_description)
self.console.print("- Task information generated. \n", style="bold green")
# 2. Reasoning session generates the first thing to do and provide the information to the generation session
with self.console.status("[bold green]Processing...") as status:
first_generation_response = self.test_generation_handler(
self.prompts.todo_to_command + self.prompts.first_todo
)
# 3. Show user the first thing to do.
if save_name == "":
save_name = str(time.time())
# 2. Save the current session
with open(os.path.join(self.save_dir, save_name), "w") as f:
# store the three ids and task_log
session_ids = {
"reasoning": self.test_reasoning_session_id,
"test_generation": self.test_generation_session_id,
"parsing": self.input_parsing_session_id,
"task_log": self.task_log,
}
json.dump(session_ids, f)
self.console.print(
"PentestGPT suggests you to do the following: ", style="bold green"
"The current session is saved as " + save_name, style="bold green"
)
self.console.print(_response)
self.log_conversation(
"PentestGPT", "PentestGPT suggests you to do the following: \n" + _response
)
self.console.print("You may start with:", style="bold green")
self.console.print(first_generation_response)
self.log_conversation(
"PentestGPT", "You may start with: \n" + first_generation_response
return

def _preload_session(self) -> dict:
"""
Preload the session from the save directory.
Returns:
dict: the session ids for the three sessions.
None if no previous session is found.
"""
# 1. get user input for the saved_session_name
continue_from_previous = confirm(
"Do you want to continue from previous session?"
)
if continue_from_previous:
# load the filenames from the save directory
filenames = os.listdir(self.save_dir)
if len(filenames) == 0:
print("No previous session found. Please start a new session.")
return None
else: # print all the files
print("Please select the previous session by its index (integer):")
for i, filename in enumerate(filenames):
print(str(i) + ". " + filename)
# ask for the user input
try:
previous_testing_name = filenames[
int(input("Please key in your option (integer): "))
]
print("You selected: " + previous_testing_name)
except ValueError as e:
print("You input an invalid option. Will start a new session.")
return None

elif continue_from_previous is False:
return None
else:
print("You input an invalid option. Will start a new session.")
return None
# 2. load the previous session information
if previous_testing_name is not None:
# try to load the file content with json
try:
with open(os.path.join(self.save_dir, previous_testing_name), "r") as f:
session_ids = json.load(f)
return session_ids
except Exception as e:
print(
"Error when loading the previous session. The file name is not correct"
)
print(e)
previous_testing_name = None
return None

# 4. enter the main loop.
def main(self):
"""
The main function of pentestGPT. The design is based on PentestGPT_design.md
"""
# 0. initialize the backbone sessions and test the connection to chatGPT
loaded_ids = self._preload_session()
self.initialize(previous_session_ids=loaded_ids)



# enter the main loop.
while True:
try:
result = self.input_handler()
Expand All @@ -500,17 +627,13 @@ def main(self):
self.console.print("Exception: " + str(e), style="bold red")
# safely quit the session
break

# Summarize the session and end
# TODO.
# log the session.
## save self.history into a txt file based on timestamp
# log the session. Save self.history into a txt file based on timestamp
timestamp = time.time()
log_name = "pentestGPT_log_" + str(timestamp) + ".txt"
# save it in the logs folder
log_path = os.path.join(self.log_dir, log_name)
with open(log_path, "w") as f:
json.dump(self.history, f)

# clear the sessions
# TODO.
# save the sessions; continue from previous testing
self.save_session()

0 comments on commit 781d878

Please sign in to comment.