Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions os_computer_use/llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,29 @@ def parse_json(s):
return None


def extract_json_objects(s):
"""Extract all balanced JSON objects from a string."""
objects = []
brace_level = 0
start_index = None
for i, char in enumerate(s):
if char == "{":
if brace_level == 0:
start_index = i
brace_level += 1
elif char == "}":
brace_level -= 1
if brace_level == 0 and start_index is not None:
candidate = s[start_index : i + 1]
try:
obj = json.loads(candidate)
objects.append(obj)
except json.JSONDecodeError:
pass
start_index = None
return objects


class LLMProvider:
"""
The LLM provider is used to make calls to an LLM given a provider and model name, with optional tool use support
Expand Down Expand Up @@ -52,6 +75,13 @@ def create_function_schema(self, definitions):
properties[param_name] = {"type": "string", "description": param_desc}
required.append(param_name)

# Add a dummy property if no parameters are provided, because providers like Gemini require a non-empty "properties" object.
if not properties:
properties["noop"] = {
"type": "string",
"description": "Dummy parameter for function with no parameters.",
}

function_def = self.create_function_def(name, details, properties, required)
functions.append(function_def)

Expand Down Expand Up @@ -142,16 +172,15 @@ def call(self, messages, functions=None):

# Sometimes, function calls are returned unparsed by the inference provider. This code parses them manually.
if message.content and not tool_calls:
tool_call_matches = re.search(r"\{.*\}", message.content)
if tool_call_matches:
tool_call = parse_json(tool_call_matches.group(0))
# Some models use "arguments" as the key instead of "parameters"
parameters = tool_call.get("parameters", tool_call.get("arguments"))
if tool_call.get("name") and parameters:
json_objs = extract_json_objects(message.content)
for obj in json_objs:
parameters = obj.get("parameters", obj.get("arguments"))
if obj.get("name") and parameters is not None:
combined_tool_calls.append(
self.create_tool_call(tool_call.get("name"), parameters)
self.create_tool_call(obj.get("name"), parameters)
)
return None, combined_tool_calls
if combined_tool_calls:
return None, combined_tool_calls

return message.content, combined_tool_calls

Expand Down