diff --git a/src/bria_client/toolkit/image.py b/src/bria_client/toolkit/image.py index 0a2add0..21a5fcd 100644 --- a/src/bria_client/toolkit/image.py +++ b/src/bria_client/toolkit/image.py @@ -62,6 +62,8 @@ def _process_image(self, image: ImageSource) -> Base64String | str: if isinstance(image, str): if image.startswith("http"): return image + if image.startswith("data:") and ";base64," in image: + return image.split(",", 1)[-1] if Image.is_base64(image): return image # infer it is a local path diff --git a/tests/unit/toolkit/test_image.py b/tests/unit/toolkit/test_image.py index 01da2fc..156ba54 100644 --- a/tests/unit/toolkit/test_image.py +++ b/tests/unit/toolkit/test_image.py @@ -15,6 +15,15 @@ def test_image_on_init_should_work_for_all_image_source_types(self, image_source # Assert assert image.as_bria_api_input is not None + @pytest.mark.parametrize("prefix", ["data:image/png;base64,", "data:image/jpeg;base64,"]) + def test_image_on_init_with_base64_data_uri_should_process_successfully(self, base64_image, prefix): + # Arrange + base64_with_header = f"{prefix}{base64_image}" + # Act + image = Image(base64_with_header) + # Assert + assert Image.is_base64(image.as_bria_api_input) + @pytest.mark.unit class TestImageSource: