Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dspy.adapters.baml_adapter import BAMLAdapter
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
Expand All @@ -18,4 +19,5 @@
"TwoStepAdapter",
"Tool",
"ToolCalls",
"BAMLAdapter",
]
34 changes: 28 additions & 6 deletions dspy/adapters/types/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __init__(self, url: Any = None, *, download: bool = False, **data):

Any additional keyword arguments are passed to :class:`pydantic.BaseModel`.
"""

if url is not None and "url" not in data:
# Support a positional argument while allowing ``url=`` in **data.
if isinstance(url, dict) and set(url.keys()) == {"url"}:
Expand All @@ -66,13 +65,36 @@ def __init__(self, url: Any = None, *, download: bool = False, **data):
# Delegate the rest of initialization to pydantic's BaseModel.
super().__init__(**data)

@pydantic.model_validator(mode="before")
@classmethod
def validate_input(cls, data: Any):
"""Validate and normalize image input data."""
if isinstance(data, cls):
return data

# Handle positional argument case where data is not a dict
if not isinstance(data, dict):
# Convert non-dict input to dict format
data = {"url": data}

# Handle legacy dict form with single "url" key
if isinstance(data.get("url"), dict) and set(data["url"].keys()) == {"url"}:
data["url"] = data["url"]["url"]

# Extract download parameter if present, defaulting to False
download = data.pop("download", False) if isinstance(data, dict) else False

if "url" not in data:
raise ValueError("url field is required for Image")

# Normalize any accepted input into a base64 data URI or plain URL
data["url"] = encode_image(data["url"], download_images=download)

return data

@lru_cache(maxsize=32)
def format(self) -> list[dict[str, Any]] | str:
try:
image_url = encode_image(self.url)
except Exception as e:
raise ValueError(f"Failed to format image for DSPy: {e}")
return [{"type": "image_url", "image_url": {"url": image_url}}]
return [{"type": "image_url", "image_url": {"url": self.url}}]

@classmethod
def from_url(cls, url: str, download: bool = False):
Expand Down
18 changes: 17 additions & 1 deletion dspy/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,34 @@ def serialize_for_json(value: Any) -> Any:
return str(value)


def format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> str | dict:
def format_field_value(field_info: FieldInfo, value: Any, assume_text=True, is_placeholder=False) -> str | dict:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
annotation (e.g. str, int, etc.), and the type of the value itself.

This is used to format both the placeholder that goes in the system prompt and the actual value itself.

Args:
field_info: Information about the field, including its DSPy field type and annotation.
value: The value of the field.
Returns:
The formatted value of the field, represented as a string.
"""
string_value = None

# If the annotation is a Type subclass, but the value is not an instance of that type, and not a placeholder,
# then check if it's a valid serialized custom type string; otherwise, raise.
if (
issubclass(field_info.annotation, DspyType)
and not isinstance(value, field_info.annotation)
and not is_placeholder
):
# Allow serialized custom type string (e.g., when loading from saved state)
if isinstance(value, str) and value.startswith("<<CUSTOM-TYPE-START-IDENTIFIER>>") and value.endswith("<<CUSTOM-TYPE-END-IDENTIFIER>>"):
pass # Accept as valid serialized custom type
else:
raise TypeError(f"Value {value} is not an instance of {field_info.annotation}")

if isinstance(value, list) and field_info.annotation is str:
# If the field has no special type requirements, format it as a nice numbered list for the LM.
string_value = _format_input_list_field_value(value)
Expand Down
3 changes: 1 addition & 2 deletions tests/adapters/test_json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ async def test_json_adapter_async_call():
signature = dspy.make_signature("question->answer")
adapter = dspy.JSONAdapter()
lm = dspy.utils.DummyLM([{"answer": "Paris"}], adapter=adapter)
with dspy.context(adapter=adapter):
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
assert result == [{"answer": "Paris"}]


Expand Down
92 changes: 73 additions & 19 deletions tests/signatures/test_adapter_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def count_patterns(obj, pattern):
return 0


def setup_predictor(signature, expected_output):
def setup_predictor(signature, expected_output, adapter=None):
"""Helper to set up a predictor with DummyLM"""
lm = DummyLM([expected_output])
lm = DummyLM([expected_output], adapter=adapter)
dspy.settings.configure(lm=lm)
return dspy.Predict(signature), lm

Expand Down Expand Up @@ -123,41 +123,95 @@ def test_basic_image_operations(test_case):
# Check result based on output field name
output_field = next(f for f in ["probabilities", "generated_code", "bboxes", "captions"] if hasattr(result, f))
assert getattr(result, output_field) == test_case["expected"][test_case["key_output"]]
assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1


@pytest.mark.parametrize(
"adapter_type",
[
"chat_adapter",
"json_adapter",
"baml_adapter",
"xml_adapter",
],
)
@pytest.mark.parametrize(
"image_input,description",
[
("pil_image", "PIL Image"),
("encoded_pil_image", "encoded PIL image string"),
("dspy_image_download", "dspy.Image with download=True"),
("dspy_pil_image_with_download", "PIL Image with download=True"),
("dspy_image_no_download", "dspy.Image without download"),
],
)
def test_image_input_formats(
request, sample_pil_image, sample_dspy_image_download, sample_dspy_image_no_download, image_input, description
def test_image_input_formats_valid(
request, sample_url, sample_pil_image, sample_dspy_image_download, sample_dspy_image_no_download, image_input, description, adapter_type
):
"""Test different input formats for image fields"""
"""Test valid input formats for image fields"""
input_map = {
"dspy_image_download": sample_dspy_image_download,
"dspy_pil_image_with_download": dspy.Image(sample_pil_image, download=True),
"dspy_image_no_download": sample_dspy_image_no_download,
}
adapter_output_map = {
"chat_adapter": (dspy.ChatAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
"json_adapter": (dspy.JSONAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
"baml_adapter": (dspy.adapters.BAMLAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
"xml_adapter": (dspy.XMLAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
}

actual_input = input_map[image_input]
signature = "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]"
expected = {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}
predictor, lm = setup_predictor(signature, expected)
expected_output = {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}
adapter, lm_output = adapter_output_map[adapter_type]
predictor, lm = setup_predictor(signature, lm_output, adapter)
with dspy.context(adapter=adapter):
result = predictor(image=actual_input, class_labels=["dog", "cat", "bird"])
assert result.probabilities == expected_output["probabilities"]
assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1

# Invalid input types: raw PIL image, encoded PIL image string, raw URL string
@pytest.mark.parametrize(
"adapter_type",
[
"chat_adapter",
"json_adapter",
"baml_adapter",
"xml_adapter",
],
)
@pytest.mark.parametrize(
"image_input,description",
[
("pil_image", "PIL Image"),
("encoded_pil_image", "encoded PIL image string"),
("url_non_dspy_image", "URL of an image"),
],
)
def test_image_input_formats_invalid(
request, sample_url, sample_pil_image, sample_dspy_image_download, sample_dspy_image_no_download, image_input, description, adapter_type
):
"""Test invalid input formats for image fields (should raise ValueError)"""
input_map = {
"pil_image": sample_pil_image,
"encoded_pil_image": encode_image(sample_pil_image),
"dspy_image_download": sample_dspy_image_download,
"dspy_image_no_download": sample_dspy_image_no_download,
"url_non_dspy_image": sample_url,
}
adapter_output_map = {
"chat_adapter": (dspy.ChatAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
"json_adapter": (dspy.JSONAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
"baml_adapter": (dspy.adapters.baml_adapter.BAMLAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
"xml_adapter": (dspy.XMLAdapter(), {"probabilities": {"dog": 0.8, "cat": 0.1, "bird": 0.1}}),
}

actual_input = input_map[image_input]
# TODO(isaacbmiller): Support the cases without direct dspy.Image coercion
if image_input in ["pil_image", "encoded_pil_image"]:
pytest.xfail(f"{description} not fully supported without dspy.Image coercion")
if adapter_type == "two_step_adapter":
pytest.xfail("TwoStepAdapter is not known to support image input")

result = predictor(image=actual_input, class_labels=["dog", "cat", "bird"])
assert result.probabilities == expected["probabilities"]
assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
actual_input = input_map[image_input]
signature = "image: dspy.Image, class_labels: list[str] -> probabilities: dict[str, float]"
adapter, lm_output = adapter_output_map[adapter_type]
predictor, lm = setup_predictor(signature, lm_output, adapter)
with dspy.context(adapter=adapter):
with pytest.raises(TypeError):
predictor(image=actual_input, class_labels=["dog", "cat", "bird"])


def test_predictor_save_load(sample_url, sample_pil_image):
Expand Down
Loading