Support for AWS Bedrock using boto3 #1287

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .cursorignore
@@ -0,0 +1 @@
# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
7 changes: 6 additions & 1 deletion instructor/__init__.py
@@ -92,7 +92,12 @@

__all__ += ["from_vertexai"]

if importlib.util.find_spec("boto3") is not None:
from .client_bedrock import from_bedrock

__all__ += ["from_bedrock"]

if importlib.util.find_spec("writerai") is not None:
from .client_writer import from_writer

__all__ += ["from_writer"]
__all__ += ["from_writer"]
56 changes: 56 additions & 0 deletions instructor/client_bedrock.py
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import Any
from botocore.client import BaseClient
import instructor
from instructor.client import Instructor


# boto3 Bedrock clients are synchronous, so only the sync Instructor
# entry point is provided; async support would need e.g. aioboto3.


def handle_bedrock_json(
    response_model: Any,
    new_kwargs: Any,
) -> tuple[Any, Any]:
    # Placeholder hook for BEDROCK_JSON requests: passes the response
    # model and request kwargs through unchanged for now.
    return response_model, new_kwargs


def from_bedrock(
    client: BaseClient,
    mode: instructor.Mode = instructor.Mode.BEDROCK_JSON,
    **kwargs: Any,
) -> Instructor:
    assert mode in {
        instructor.Mode.BEDROCK_TOOLS,
        instructor.Mode.BEDROCK_JSON,
    }, "Mode must be instructor.Mode.BEDROCK_TOOLS or instructor.Mode.BEDROCK_JSON"
    assert isinstance(
        client,
        BaseClient,
    ), "Client must be a boto3 client (an instance of botocore.client.BaseClient)"
    create = client.converse  # Bedrock Runtime Converse API

    return Instructor(
        client=client,
        create=instructor.patch(create=create, mode=mode),
        provider=instructor.Provider.BEDROCK,
        mode=mode,
        **kwargs,
    )
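For reviewers, a minimal usage sketch of the new entry point. This is untested scaffolding: the region and model ID are illustrative placeholders, and it assumes AWS credentials are configured and that the patched converse is reachable through the wrapper's create method as wired above.

import boto3
import instructor
from pydantic import BaseModel

class User(BaseModel):
    name: str
    age: int

client = boto3.client("bedrock-runtime", region_name="us-east-1")
inst = instructor.from_bedrock(client, mode=instructor.Mode.BEDROCK_JSON)

user = inst.create(
    response_model=User,
    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",  # illustrative
    messages=[{"role": "user", "content": [{"text": "John is 30 years old"}]}],
)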
105 changes: 88 additions & 17 deletions instructor/function_calls.py
@@ -1,6 +1,7 @@
# type: ignore
import json
import logging
import re
from functools import wraps
from typing import Annotated, Any, Optional, TypeVar, cast
from docstring_parser import parse
@@ -45,7 +46,9 @@ def openai_schema(cls) -> dict[str, Any]:
schema = cls.model_json_schema()
docstring = parse(cls.__doc__ or "")
parameters = {
k: v for k, v in schema.items() if k not in ("title", "description")
}
for param in docstring.params:
if (name := param.arg_name) in parameters["properties"] and (
@@ -55,7 +58,9 @@ def openai_schema(cls) -> dict[str, Any]:
parameters["properties"][name]["description"] = description

parameters["required"] = sorted(
k for k, v in parameters["properties"].items() if "default" not in v
)

if "description" not in schema:
@@ -88,7 +93,9 @@ def gemini_schema(cls) -> Any:
function = genai_types.FunctionDeclaration(
name=cls.openai_schema["name"],
description=cls.openai_schema["description"],
parameters=map_to_gemini_function_schema(cls.openai_schema["parameters"]),
)
return function

@@ -113,31 +120,52 @@ def from_response(
cls (OpenAISchema): An instance of the class
"""
if mode == Mode.ANTHROPIC_TOOLS:
return cls.parse_anthropic_tools(completion, validation_context, strict)

if mode == Mode.ANTHROPIC_JSON:
return cls.parse_anthropic_json(completion, validation_context, strict)

if mode == Mode.BEDROCK_JSON:
return cls.parse_bedrock_json(completion, validation_context, strict)

if mode in {Mode.VERTEXAI_TOOLS, Mode.GEMINI_TOOLS}:
return cls.parse_vertexai_tools(completion, validation_context)

if mode == Mode.VERTEXAI_JSON:
return cls.parse_vertexai_json(completion, validation_context, strict)

if mode == Mode.COHERE_TOOLS:
return cls.parse_cohere_tools(completion, validation_context, strict)

if mode == Mode.GEMINI_JSON:
return cls.parse_gemini_json(completion, validation_context, strict)

if mode == Mode.GEMINI_TOOLS:
return cls.parse_gemini_tools(completion, validation_context, strict)

if mode == Mode.COHERE_JSON_SCHEMA:
return cls.parse_cohere_json_schema(completion, validation_context, strict)

if mode == Mode.WRITER_TOOLS:
return cls.parse_writer_tools(completion, validation_context, strict)

if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException(last_completion=completion)
@@ -190,12 +218,17 @@ def parse_anthropic_tools(
) -> BaseModel:
from anthropic.types import Message

if isinstance(completion, Message) and completion.stop_reason == "max_tokens":
raise IncompleteOutputException(last_completion=completion)

# Anthropic returns arguments as a dict, dump to json for model validation below
tool_calls = [
json.dumps(c.input) for c in completion.content if c.type == "tool_use"
] # TODO update with anthropic specific types

tool_calls_validator = TypeAdapter(
@@ -237,7 +270,39 @@ def parse_anthropic_json(
# Allow control characters.
parsed = json.loads(extra_text, strict=False)
# Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
return cls.model_validate(parsed, context=validation_context, strict=False)

@classmethod
def parse_bedrock_json(
    cls: type[BaseModel],
    completion: Any,
    validation_context: Optional[dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    if isinstance(completion, dict):
        # Bedrock Converse responses nest the generated text under
        # output -> message -> content[0] -> text.
        text = completion["output"]["message"]["content"][0]["text"]

        # Strip a ```json ... ``` code fence if the model emitted one.
        match = re.search(r"```?json(.*?)```?", text, re.DOTALL)
        if match:
            text = match.group(1).strip()

        # Drop any stray fence markers and literal "\n" escapes left in
        # the extracted payload.
        text = re.sub(r"```?json|\\n", "", text).strip()
    else:
        text = completion.text
    return cls.model_validate_json(
        text, context=validation_context, strict=strict
    )
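For context, the bedrock-runtime converse() response shape this parser reads, abridged to the fields it actually touches (the payload value is illustrative):

completion = {
    "output": {
        "message": {
            "role": "assistant",
            "content": [{"text": '{"name": "John", "age": 30}'}],
        }
    }
}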

@classmethod
def parse_gemini_json(
@@ -256,7 +321,9 @@ def parse_gemini_json(
try:
extra_text = extract_json_from_codeblock(text) # type: ignore
except UnboundLocalError:
raise ValueError("Unable to extract JSON from completion text") from None

if strict:
return cls.model_validate_json(
@@ -266,7 +333,9 @@ def parse_vertexai_tools(
# Allow control characters.
parsed = json.loads(extra_text, strict=False)
# Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
return cls.model_validate(parsed, context=validation_context, strict=False)

@classmethod
def parse_vertexai_tools(
@@ -279,7 +348,9 @@ def parse_vertexai_json(
for field in tool_call: # type: ignore
model[field] = tool_call[field]
# We enable strict=False because the conversion from protobuf -> dict often results in types like ints being cast to floats, as a result in order for model.validate to work we need to disable strict mode.
return cls.model_validate(model, context=validation_context, strict=False)

@classmethod
def parse_vertexai_json(
2 changes: 2 additions & 0 deletions instructor/mode.py
@@ -28,6 +28,8 @@ class Mode(enum.Enum):
FIREWORKS_TOOLS = "fireworks_tools"
FIREWORKS_JSON = "fireworks_json"
WRITER_TOOLS = "writer_tools"
BEDROCK_TOOLS = "bedrock_tools"
BEDROCK_JSON = "bedrock_json"

@classmethod
def warn_mode_functions_deprecation(cls):
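The two new members follow the existing tools/JSON pairing used by the other providers. A trivial sanity check of the wiring:

from instructor.mode import Mode

assert Mode.BEDROCK_TOOLS.value == "bedrock_tools"
assert Mode.BEDROCK_JSON.value == "bedrock_json"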
9 changes: 7 additions & 2 deletions instructor/patch.py
@@ -131,6 +131,9 @@ def patch( # type: ignore

logger.debug(f"Patching `client.chat.completions.create` with {mode=}")

if create is not None:
func = create
elif client is not None:
@@ -183,7 +186,7 @@ def new_create_sync(
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
context = handle_context(context, validation_context)

print(f"instructor.patch: patched_function {func.__name__}")
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
) # type: ignore
@@ -228,6 +231,8 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.TOOLS) -> AsyncOpenAI:
import warnings

warnings.warn(
"apatch is deprecated, use patch instead", DeprecationWarning, stacklevel=2
"apatch is deprecated, use patch instead",
DeprecationWarning,
stacklevel=2,
)
return patch(client, mode=mode)