-
Notifications
You must be signed in to change notification settings - Fork 1
/
gemini_pro.py
167 lines (137 loc) · 5 KB
/
gemini_pro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
This script demonstrates how to use the Gemini Pro API.
"""
import os
import base64
import asyncio
import mimetypes

import spacy
from spacy.matcher import Matcher
from vertexai.preview.generative_models import (
    GenerativeModel,
    HarmCategory,
    HarmBlockThreshold,
    Part,
)
from rich.console import Console
from rich.markdown import Markdown
from rich.prompt import Prompt
from rich.traceback import install
# Enable pretty printing of exceptions with Rich tracebacks.
install()
# Load the small English spaCy pipeline; used only to tokenise user input
# so the filename matcher can run over the tokens.
nlp = spacy.load("en_core_web_sm")
# Shared Rich console for all styled terminal output.
console = Console()
# Built lazily on first use: constructing a Matcher and its pattern on
# every call is wasted work, since neither depends on the input text.
_FILENAME_MATCHER = None


def extract_filename(text):
    """Extract media filenames mentioned in *text*.

    Args:
        text: Free-form user input to scan.

    Returns:
        A list of matched filename strings (tokens shaped like
        ``name.ext`` with a supported media extension), in document order.
    """
    global _FILENAME_MATCHER
    # Tokenise with spaCy so the matcher can test whole tokens.
    doc = nlp(text)
    if _FILENAME_MATCHER is None:
        # Token must contain no whitespace or slash and end in a known
        # image/video extension.
        pattern = [
            {"TEXT": {"REGEX": "^[^\\s\\/]+\\.(jpg|png|mkv|mov|mp4|webm)$"}}
        ]
        _FILENAME_MATCHER = Matcher(nlp.vocab)
        _FILENAME_MATCHER.add("FILENAME", [pattern])
    matches = _FILENAME_MATCHER(doc)
    return [doc[start:end].text for _, start, end in matches]
# Module-level Gemini Pro chat session. `history=[]` starts a fresh
# conversation; send_message() appends turns to it, so context is kept
# across successive ask_gemini_pro() calls.
model = GenerativeModel("gemini-pro")
chat = model.start_chat(history=[])
async def ask_gemini_pro(question):
    """Ask Gemini Pro a question and print the response using ChatSession.

    Args:
        question: The user's question text.

    Side effects:
        Appends the turn to the module-level chat history and prints
        each response part to the console.
    """
    # chat.send_message is a blocking network call; run it in a worker
    # thread so the asyncio event loop is not stalled while waiting.
    response = await asyncio.to_thread(chat.send_message, question)
    for part in response.candidates[0].content.parts:
        console.print(part.text, style="bold green")
async def ask_gemini_pro_vision(question, source_folder, specific_file_name):
    """Ask Gemini Pro Vision a question about a specific media file.

    Args:
        question: The question to ask about the file.
        source_folder: The folder containing the media file.
        specific_file_name: The name of the media file.

    Side effects:
        Prints the streamed model answer to the console.
    """
    media_path = os.path.join(source_folder, specific_file_name)
    with open(media_path, "rb") as media_file:
        media_bytes = media_file.read()
    # Part.from_data expects the raw bytes; base64-encoding them first
    # would double-encode the payload and corrupt the request.
    # Guess the MIME type from the extension rather than assuming JPEG —
    # the filename matcher also accepts png/mkv/mov/mp4/webm. Fall back
    # to the previous default when the extension is unknown.
    mime_type, _ = mimetypes.guess_type(specific_file_name)
    media_part = Part.from_data(
        data=media_bytes,
        mime_type=mime_type or "image/jpeg",
    )
    # Generation configuration for the vision request.
    generation_config = {
        "max_output_tokens": 2048,
        "temperature": 0.4,
        "top_p": 1,
        "top_k": 32,
    }
    # Block harmful content at medium severity and above.
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH:
            HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_HARASSMENT:
            HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
            HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
            HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    }
    vision_model = GenerativeModel("gemini-pro-vision")

    def _generate():
        # stream=True returns a lazy iterator; materialise it here so
        # all blocking network reads happen in this worker thread
        # instead of stalling the asyncio event loop.
        return list(
            vision_model.generate_content(
                [media_part, question],
                generation_config=generation_config,
                safety_settings=safety_settings,
                stream=True,
            )
        )

    responses = await asyncio.to_thread(_generate)
    # Print each streamed chunk as styled console output.
    for response in responses:
        if response.candidates:
            console.print(
                response.candidates[0].content.parts[0].text,
                style="bold green"
            )
        else:
            console.print("No response candidates found.")
async def main():
    """Run the interactive prompt loop, routing each question to the
    text model or — when a media filename is mentioned — the vision model."""
    # Start with a clean screen, portably across Windows and POSIX.
    os.system('cls' if os.name == 'nt' else 'clear')
    console.print(Markdown("# Welcome to Gemini Pro"), style="bold magenta")
    while True:
        user_input = Prompt.ask(
            "\nAsk your question (or type 'exit' to quit)",
            default="exit"
        )
        if user_input.lower() == 'exit':
            console.print("\nExiting the program.", style="bold red")
            break
        filenames = extract_filename(user_input)
        # Visually separate the answer from the prompt line.
        console.print("\n")
        if filenames:
            # A media filename was mentioned: ask the vision model about
            # the first matched file under the "workspace" folder.
            await ask_gemini_pro_vision(user_input, "workspace", filenames[0])
        else:
            await ask_gemini_pro(user_input)


if __name__ == "__main__":
    asyncio.run(main())