Skip to content

Commit 569cf42

Browse files
authored
refactor: preprocess_image + tests (#702)
1 parent 4a5dfcd commit 569cf42

File tree

3 files changed

+163
-31
lines changed

3 files changed

+163
-31
lines changed

src/rai_core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[tool.poetry]
66
name = "rai_core"
7-
version = "2.5.2"
7+
version = "2.5.3"
88
description = "Core functionality for RAI framework"
99
authors = ["Maciej Majek <maciej.majek@robotec.ai>", "Bartłomiej Boczek <bartlomiej.boczek@robotec.ai>", "Kajetan Rachwał <kajetan.rachwal@robotec.ai>"]
1010
readme = "README.md"

src/rai_core/rai/messages/conversion.py

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,40 +14,104 @@
1414

1515

1616
import base64
17-
from typing import Any, Callable, Union
17+
import os
18+
from io import BytesIO
19+
from typing import Callable, Union
1820

19-
import cv2
2021
import numpy as np
2122
import requests
23+
from PIL import Image as PILImage
2224
from PIL.Image import Image
2325

2426

2527
def preprocess_image(
    image: Union[Image, str, bytes, np.ndarray],
    encoding_function: Callable[[bytes], str] = lambda b: base64.b64encode(b).decode(
        "utf-8"
    ),
) -> str:
    """Convert various image inputs into an encoded PNG string.

    Parameters
    ----------
    image : PIL.Image.Image or str or bytes or numpy.ndarray
        Supported inputs:
        - PIL Image
        - Path to a file, ``file://`` URL, or HTTP(S) URL
        - Raw bytes containing image data
        - ``numpy.ndarray`` with dtype ``uint8`` or any floating dtype;
          grayscale (``(H, W)`` or ``(H, W, 1)``) and 3/4-channel arrays
          are supported.
    encoding_function : callable, optional
        Function that converts PNG bytes to the final string representation.
        By default, returns base64-encoded UTF-8 string.

    Returns
    -------
    str
        Result of ``encoding_function`` applied to the PNG bytes
        (a base64-encoded PNG string with the default encoder).

    Raises
    ------
    FileNotFoundError
        If a file path (or ``file://`` URL) does not exist.
    TypeError
        If the input type is not supported.
    requests.HTTPError
        If fetching an HTTP(S) URL fails with a non-2xx response.
    requests.RequestException
        If a network error occurs while fetching an HTTP(S) URL.
    OSError
        If the input cannot be decoded as an image by Pillow.

    Notes
    -----
    - All inputs are decoded and re-encoded to PNG to guarantee consistent output.
    - Float arrays are assumed to be in [0, 1] and are scaled to ``uint8``;
      out-of-range values are clipped and NaN is mapped to 0.
    - Images in a colour mode PNG cannot store (CMYK, YCbCr, HSV) are
      converted to RGB before encoding.
    - Network requests use a timeout of ``(5, 15)`` seconds (connect, read).

    Examples
    --------
    >>> b64_png = preprocess_image("path/to/image.jpg")
    >>> import numpy as np
    >>> arr = np.random.rand(64, 64, 3).astype(np.float32)
    >>> b64_png = preprocess_image(arr)
    """

    def _to_pil_from_ndarray(arr: np.ndarray) -> Image:
        # Build a PIL image from an array, scaling float data to uint8 first.
        a = arr
        if np.issubdtype(a.dtype, np.floating):
            # Covers float16/float32/float64/longdouble. NaN has no defined
            # uint8 value, so pin it (and infinities) before the cast.
            a = np.nan_to_num(a, nan=0.0, posinf=1.0, neginf=0.0)
            a = np.clip(a, 0.0, 1.0)
            a = (a * 255.0).round().astype(np.uint8)
        if a.ndim == 3 and a.shape[2] == 1:
            # Pillow rejects a trailing singleton channel; treat it as grayscale.
            a = a[:, :, 0]
        a = np.ascontiguousarray(a)
        return PILImage.fromarray(a)

    def _ensure_pil(img: Union[Image, str, bytes, np.ndarray]) -> Image:
        # Dispatch on the input type and return a PIL image.
        if isinstance(img, Image):
            return img
        if isinstance(img, np.ndarray):  # type: ignore
            return _to_pil_from_ndarray(img)
        if isinstance(img, str):
            if img.startswith(("http://", "https://")):
                response = requests.get(img, timeout=(5, 15))
                response.raise_for_status()
                return PILImage.open(BytesIO(response.content))
            if img.startswith("file://"):
                file_path = img[len("file://") :]
            else:
                # fallback to file path if not marked with file://
                file_path = img
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File not found: {file_path}")
            # PIL.Image.open is lazy and would keep the file handle open;
            # read the bytes up front so the handle is closed deterministically.
            with open(file_path, "rb") as image_file:
                return PILImage.open(BytesIO(image_file.read()))
        if isinstance(img, bytes):
            return PILImage.open(BytesIO(img))
        raise TypeError(f"Unsupported image type: {type(img).__name__}")

    pil_image = _ensure_pil(image)

    # The PNG writer cannot store these colour modes (e.g. CMYK JPEGs).
    if pil_image.mode in ("CMYK", "YCbCr", "HSV"):
        pil_image = pil_image.convert("RGB")

    # Normalize to PNG bytes
    with BytesIO() as buffer:
        pil_image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()

    return encoding_function(png_bytes)

tests/messages/test_utils.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
# limitations under the License.
1414

1515
import base64
16+
import threading
17+
from functools import partial
18+
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
1619
from io import BytesIO
20+
from pathlib import Path
1721

1822
import numpy as np
1923
import pytest
@@ -30,10 +34,74 @@ def decode_image(base64_string: str) -> Image.Image:
3034
@pytest.mark.parametrize(
    "test_image",
    [
        # uint8 arrays: grayscale, RGB, RGBA
        np.zeros((300, 300), dtype=np.uint8),
        np.zeros((300, 300, 3), dtype=np.uint8),
        np.zeros((300, 300, 4), dtype=np.uint8),
        # float32 arrays in [0, 1)
        np.random.rand(300, 300).astype(np.float32),
        np.random.rand(300, 300, 3).astype(np.float32),
        np.random.rand(300, 300, 4).astype(np.float32),
        # float64 arrays in [0, 1)
        np.random.rand(300, 300).astype(np.float64),
        np.random.rand(300, 300, 3).astype(np.float64),
        np.random.rand(300, 300, 4).astype(np.float64),
        # path to an image on disk
        "tests/resources/image.png",
    ],
)
def test_preprocess_image_always_png(test_image):
    """Every supported input must come back as a non-empty PNG."""
    decoded = decode_image(preprocess_image(test_image))
    assert decoded.format == "PNG"
    width, height = decoded.size[0], decoded.size[1]
    assert width > 0 and height > 0
54+
55+
56+
def test_preprocess_image_from_bytes_and_file_url(tmp_path: Path):
    """Raw PNG bytes and a ``file://`` URL both round-trip to a 32x32 PNG."""
    # Write a small random RGB image to a temporary PNG file.
    pixels = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
    file_path = tmp_path / "tmp.png"
    Image.fromarray(pixels).save(file_path, format="PNG")

    # Input given as the raw bytes of the file.
    raw_bytes = file_path.read_bytes()
    img_from_bytes = decode_image(preprocess_image(raw_bytes))
    assert img_from_bytes.format == "PNG"
    assert img_from_bytes.size == (32, 32)

    # Input given as a file:// URL pointing at the same file.
    file_url = f"file://{file_path}"
    img_from_url = decode_image(preprocess_image(file_url))
    assert img_from_url.format == "PNG"
    assert img_from_url.size == (32, 32)
77+
78+
79+
def test_preprocess_image_unsupported_type():
    """An object of an unrelated type is rejected with ``TypeError``."""

    class NotAnImage:
        pass

    unsupported = NotAnImage()
    with pytest.raises(TypeError):
        preprocess_image(unsupported)
85+
86+
87+
def test_preprocess_image_http_url():
    """An HTTP URL served by a local throwaway server is fetched and re-encoded as PNG."""
    resources_dir = Path(__file__).resolve().parents[1] / "resources"
    image_path = resources_dir / "image.png"
    assert image_path.exists(), "tests/resources/image.png must exist for this test"

    handler = partial(SimpleHTTPRequestHandler, directory=str(resources_dir))

    # The server's context manager calls server_close() on exit, so the
    # listening socket is released even if starting the thread fails.
    with ThreadingHTTPServer(("127.0.0.1", 0), handler) as httpd:
        port = httpd.server_address[1]
        thread = threading.Thread(target=httpd.serve_forever, daemon=True)
        # Start outside the try block: shutdown() deadlocks when
        # serve_forever() was never started in another thread.
        thread.start()
        try:
            url = f"http://127.0.0.1:{port}/image.png"
            base64_image = preprocess_image(url)
            img = decode_image(base64_image)
            assert img.format == "PNG"
            assert img.size[0] > 0 and img.size[1] > 0
        finally:
            httpd.shutdown()
            # Wait for the serving thread so it does not outlive the test.
            thread.join(timeout=5)

0 commit comments

Comments
 (0)