Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #7 from loyal812/feat/implement-mathpix
Browse files Browse the repository at this point in the history
feat: implement the mathpix api to extract the contents
  • Loading branch information
loyal812 authored Mar 26, 2024
2 parents 67621a7 + d85580c commit 54bb18d
Show file tree
Hide file tree
Showing 78 changed files with 1,456 additions and 64 deletions.
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
OPENAI_API_KEY=
OPENAI_API_KEY=
MATHPIX_APP_ID=
MATHPIX_APP_KEY=
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,14 @@ example:
py .\chatting.py --payload_dir="./test/regression/regression_testxxx/payload/chatting_payload.json" --question="what's the golf?"
```

You will see the result at console.
You will see the result at console.


### Full process: data preprocessing -> pdf to image -> image to text -> fine tuning
Please install requirement again.

```
pip install -r requirements.txt
py main.py
```
1 change: 1 addition & 0 deletions finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def run_fine_tuning(args):
# Call class instance
fine_tune = FineTuningClass(
data_path=payload_data["data_path"],
parent_path=payload_data["parent_path"],
api_key=payload_data["api_key"],
model=payload_data["model"],
temperature=payload_data["temperature"],
Expand Down
170 changes: 170 additions & 0 deletions img2txt_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@

import os
import gc
import time
import argparse
from pathlib import Path
import concurrent.futures
from datetime import datetime

from src.utils.read_json import read_json
from src.utils.image_translator import ImageTranslator
from src.utils.chatgpt_communicator import ChatGPTCommunicator


def main(args):
"""
main entry point
"""
# Timer
start_time = time.time()

# Payload
payload_data = read_json(args.payload_dir)

# Read images from the image directory
image_data_path = payload_data["images_data_path"]
image_list = [img for img in os.listdir(image_data_path) if img.endswith(".png") or img.endswith(".jpeg") or img.endswith(".jpg")]

# Call class instance
img_translator = ImageTranslator(api_key=payload_data["api_key"])

# Loop over number of images and append all images
# NOTE: User can upload image and add image URLs or just upload image or just add image URLs
images = []
if (len(image_list) > 0) and (len(payload_data["image_url"]) > 0):
for image in image_list:
image_path = os.path.join(image_data_path, image)
# Encode image
base64_image = img_translator.encode_image(image_path)
images.append((base64_image, False, "auto"))
for img_url in payload_data["image_url"]:
images.append((img_url, True, "auto"))
elif (len(image_list) > 0) and (len(payload_data["image_url"]) == 0):
for image in image_list:
image_path = os.path.join(image_data_path, image)
# Encode image
base64_image = img_translator.encode_image(image_path)
images.append((base64_image, False, "auto"))
elif (len(image_list) == 0) and (len(payload_data["image_url"]) > 0):
for img_url in payload_data["image_url"]:
images.append((img_url, True, "auto"))

for image in images:
if payload_data["is_parallel"]:
params = [{
img_translator: img_translator,
image: image
}] * payload_data["parallel_count"]

with concurrent.futures.ThreadPoolExecutor() as executor:
results = list(executor.map(lambda args: img2txt(*args), params))

result = make_one_result(payload_data, results)
else:
result = img2txt(img_translator, image)

save_to_txt(payload_data, result)


# Write into log file
end_time = time.time()
msg = f"Total processing time: {end_time - start_time} seconds"
print(msg)

# Delete class objects and clean the buffer memory using the garbage collection
gc.collect()

def save_to_txt(payload_data, result: str):
current_time = datetime.now().strftime('%y_%m_%d_%H_%M_%S')
train_path = os.path.join(payload_data["data_path"], "train_data")
os.makedirs(train_path, exist_ok=True) # This line will create the directory if it doesn't exist

with open(f'{train_path}/{current_time}_data.txt', "a", encoding="utf-8") as f:
f.write(result + "\n\n") # Append the new data to the end of the file

def img2txt(img_translator: ImageTranslator, image):
max_retries = 5
last_error = ""

img_translator_response = None # Define the variable and initialize it to None

for attempt in range(max_retries):
try:
response = img_translator.analyze_images([image])

if "choices" in response and response["choices"]:
first_choice = response["choices"][0]
if "message" in first_choice and "content" in first_choice["message"] and first_choice["message"]["content"]:
img_translator_response = first_choice["message"]["content"]
break # Successful response, break out of the loop
else:
last_error = "No valid content in the response."
else:
last_error = "The response structure is not as expected."

except Exception as e:
last_error = f"Attempt {attempt + 1} failed: {e}"

if img_translator_response:
break # If a successful response is obtained, exit the loop

if img_translator_response is None:
raise Exception("Failed to get a valid response after " + str(max_retries) + " attempts. Last error: " + last_error)

return img_translator_response

def make_one_result(payload_data, results: [str]):
response = payload_data["merge_prompt"]
for index, result in enumerate(results):
response += f"\nresult {index + 1}: {result}"

# Create chatGPT communicator
chatgpt_communicator = ChatGPTCommunicator(api_key=payload_data["api_key"], language_model=payload_data["language_model"])

# Start conversation with ChatGPT using the transcribed or translated text
chatgpt_communicator.create_chat(response)

# Get conversation with ChatGPT
max_retries = 3
chatgpt_response = None

for attempt in range(max_retries):
try:
chatgpt_response = chatgpt_communicator.get_response()
# Check if the response is valid (not None and not empty)
if chatgpt_response:
break # Valid response, break out of the loop
except Exception as e:
print(f"Attempt {attempt + 1} failed: {e}")
if attempt == max_retries - 1:
raise Exception(f"Failed to get a valid response from ChatGPT after {max_retries} attempts. Last error: {e}")

# Print response and use it somewhere else
# print(chatgpt_response)


return chatgpt_response

if __name__ == "__main__":
"""
Form command lines
"""
# Clean up buffer memory
gc.collect()

# Current directory
current_dir = os.path.dirname(os.path.abspath(__file__))

# Payload directory
test_name = "regression_test013"
payload_name = "img2txt_payload.json"
payload_dir = os.path.join(current_dir, "test", "regression", test_name, "payload", payload_name)

# Add options
p = argparse.ArgumentParser()
p = argparse.ArgumentParser(description="Translate text within an image.")
p.add_argument("--payload_dir", type=Path, default=payload_dir, help="payload directory to the test example")
args = p.parse_args()

main(args)
Loading

0 comments on commit 54bb18d

Please sign in to comment.