[FEAT] Adding Google's Gemini to the MTLLM LLM roster #1492

Open · wants to merge 18 commits into base: main
2 changes: 2 additions & 0 deletions jac-mtllm/mtllm/llms/__init__.py
@@ -2,6 +2,7 @@

from .anthropic import Anthropic
from .base import BaseLLM
from .google import Gemini
from .groq import Groq
from .huggingface import Huggingface
from .ollama import Ollama
@@ -13,6 +14,7 @@
"Anthropic",
"Ollama",
"Huggingface",
"Gemini",
"Groq",
"OpenAI",
"TogetherAI",
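With the re-export in place, the new client is importable from the package root like the existing backends. A minimal smoke-test sketch (not part of the diff; assumes the `google` extra and a `GEMINI_API_KEY` are available):

```python
# Hypothetical smoke test of the new export path.
from mtllm.llms import Gemini

llm = Gemini()  # picks up the gemini-1.5-flash default from this PR
```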
72 changes: 69 additions & 3 deletions jac-mtllm/mtllm/llms/base.py
@@ -184,16 +184,82 @@ def resolve_output(
        """Resolve the output string to return the reasoning and output."""
        if self.verbose:
            logger.info(f"Meaning Out\n{meaning_out}")
            logger.info(f"Output Hint\n{output_hint.type}")
        primary_types = [
            "str",
            "int",
            "float",
            "bool",
            "list",
            "dict",
            "tuple",
            "set",
            "Any",
            "None",
        ]
        if re.search(r"\[Output\]", meaning_out, re.IGNORECASE):
            output_match = re.search(
                r"\[Output\](.*)", meaning_out, re.IGNORECASE | re.DOTALL
            )
            if output_match:
                output = output_match.group(1).strip()
        else:
            if output_hint.type.split("[")[0] in primary_types:
                primary_patterns = {
                    "int": r"[-+]?\d+",
                    "float": r"[-+]?\d*\.\d+",
                    "bool": r"True|False",
                    "str": r"'([^']*)'",
                    "list": r"\[.*?\]",
                    "dict": r"\{.*?\}",
                    "tuple": r"\(.*?\)",
                    "set": r"\{.*?\}",
                    "Any": r".+",
                    "None": r"None",
                }
                single_pattern = re.compile(
                    primary_patterns[output_hint.type.split("[")[0]], re.DOTALL
                )
                single_match = single_pattern.search(meaning_out)
                if not single_match and output_hint.type.split("[")[0] == "str":
                    # Fall back to treating the whole response as the string.
                    meaning_out = meaning_out.rstrip()
                    single_match = re.match(r".*", meaning_out)
                output_match = single_match
                if self.verbose:
                    logger.info(f"Single Output: {single_match}")
            else:
                custom_type_pattern = re.compile(
                    rf"{output_hint.type}\s*\((.*)\)", re.DOTALL
                )
                custom_match = custom_type_pattern.search(meaning_out)
                if custom_match:
                    extracted_output = (
                        f"{output_hint.type}({custom_match.group(1).strip()})"
                    )
                    if self.verbose:
                        logger.info(f"Custom Type Output: {extracted_output}")
                    output_match = custom_match
                else:
                    output_match = None
            if output_match:
                # Primary-type patterns such as int or list have no capture
                # group, so take the full match rather than group(1).
                output = output_match.group(0).strip()
        if not output_match:
            output = self._extract_output(
                meaning_out,
                output_hint,
                output_type_explanations,
                self.max_tries,
            )

        if self.type_check:
            is_in_desired_format = self._check_output(
                output, output_hint.type, output_type_explanations
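To see what the new fallback buys, here is an illustrative, standalone snippet (names hypothetical, not part of the diff) showing how the primary-type patterns above recover a value when the model reply omits the `[Output]` tag:

```python
import re

# Mirrors a subset of the primary_patterns table in resolve_output above.
primary_patterns = {
    "int": r"[-+]?\d+",
    "float": r"[-+]?\d*\.\d+",
    "bool": r"True|False",
}

reply = "The circumference is roughly 40075 km."
match = re.search(primary_patterns["int"], reply)
assert match is not None and match.group(0) == "40075"
# The first type-shaped token wins, which is why _extract_output remains
# the retry fallback for replies where no pattern matches at all.
```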
78 changes: 78 additions & 0 deletions jac-mtllm/mtllm/llms/google.py
@@ -0,0 +1,78 @@
"""Gemini API client for MTLLM."""

from mtllm.llms.base import BaseLLM


REASON_SUFFIX = """
Reason and return the output result(s) only, adhering to the provided Type in the following format.

[Reasoning] <Reason>
[Output] <result>
"""

NORMAL_SUFFIX = """Generate and return the output result(s) only, adhering to the provided Type in the following format without using markdown formatting.

[Output] <results>
""" # noqa E501

CHAIN_OF_THOUGHT_SUFFIX = """
Generate and return the output result(s) only, adhering to the provided Type in the following format. Perform the operation in a chain of thoughts.(Think Step by Step)

[Chain of Thoughts] <Chain of Thoughts>
[Output] <Result>
""" # noqa E501

REACT_SUFFIX = """
You are given with a list of tools you can use to do different things. To achieve the given [Action], incrementally think and provide tool_usage necessary to achieve what is thought.
Provide your answer adhering in the following format. tool_usage is a function call with the necessary arguments. Only provide one [THOUGHT] and [TOOL USAGE] at a time.

[Thought] <Thought>
[Tool Usage] <tool_usage>
""" # noqa E501


class Gemini(BaseLLM):
"""Google API client for MTLLM."""

MTLLM_METHOD_PROMPTS: dict[str, str] = {
"Normal": NORMAL_SUFFIX,
"Reason": REASON_SUFFIX,
"Chain-of-Thoughts": CHAIN_OF_THOUGHT_SUFFIX,
"ReAct": REACT_SUFFIX,
}

def __init__(
self,
verbose: bool = False,
max_tries: int = 10,
type_check: bool = False,
**kwargs: dict,
) -> None:
"""Initialize the Google API client."""
from google import genai # type: ignore
import os # type: ignore

super().__init__(verbose, max_tries, type_check)
self.client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
self.model_name = str(kwargs.get("model_name", "gemini-1.5-flash"))
self.temperature = kwargs.get("temperature", 0.7)
self.max_tokens = kwargs.get("max_tokens", 1024)

def __infer__(self, meaning_in: str | list[dict], **kwargs: dict) -> str:
"""Infer the output from the input meaning."""
if not isinstance(meaning_in, str):
assert self.model_name.startswith(
("gemini-1.0", "aqa, text")
), f"Model {self.model_name} does not support input of type {type(meaning_in)}. Choose a Multi-Modal model."
# messages = [{"role": "user", "content": meaning_in}]
# messages = 'What is the earth\'s circumference'
config = {
"temperature": kwargs.get("temperature", self.temperature),
# "max_tokens": kwargs.get("max_tokens", self.max_tokens),
}
output = self.client.models.generate_content(
model=kwargs.get("model_name", self.model_name),
contents=meaning_in,
config=config,
)
return output.text
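For reviewers, a usage sketch (not part of the diff; assumes `GEMINI_API_KEY` is set and the `google-genai` dependency is installed):

```python
from mtllm.llms import Gemini

# model_name and temperature here are illustrative values.
llm = Gemini(verbose=True, model_name="gemini-1.5-flash", temperature=0.2)
# __infer__ is what MTLLM calls under the hood; invoking it directly
# serves only as a quick smoke test.
print(llm.__infer__("What is the earth's circumference?"))
```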
7 changes: 5 additions & 2 deletions jac-mtllm/pyproject.toml
@@ -11,26 +11,29 @@ keywords = ["llm", "jaclang", "jaseci", "mtllm"]
[tool.poetry.dependencies]
jaclang = "~0.7.25"
loguru = "~0.7.2"
openai = { version = "~1.30.4", optional = true }
pillow = "~11.0.0"
openai = { version = "~1.58.1", optional = true }
anthropic = { version = "~0.26.1", optional = true }
ollama = { version = "~0.2.0", optional = true }
together = { version = "~1.2.0", optional = true }
transformers = { version = "~4.41.1", optional = true }
groq = { version = "~0.8.0", optional = true }
google-genai = { version = "~0.1.0", optional = true }
google-generativeai = { version = "~0.1.0", optional = true }

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.2"

[tool.poetry.extras]
tools = ["wikipedia"]
video = ["opencv-python-headless"]
image = ["pillow"]
groq = ["groq"]
transformers = ["transformers"]
ollama = ["ollama"]
anthropic = ["anthropic"]
openai = ["openai"]
together = ["together"]
google = ["google-genai", "google-generativeai"]

[tool.poetry.plugins."jac"]
mtllm = "mtllm.plugin:JacFeature"
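Assuming the package ships as `mtllm` with the extras wired as above, the new backend would presumably be pulled in with `pip install mtllm[google]`, or `poetry install --extras google` when working inside this repo.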