Skip to content

Commit

Permalink
Add explicit tests for safe_strseq_iter
Browse files Browse the repository at this point in the history
Also make it handle `bytes` by doing an implicit decode. Although this
is not part of our advertised interface anywhere, handling bytes means
that we will have a more graceful behavior in APIs where this is used
if a user passes a bytestring at runtime.
  • Loading branch information
sirosen committed Jul 5, 2023
1 parent b9f2430 commit 264e4bb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/globus_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def slash_join(a: str, b: str | None) -> str:


def safe_strseq_iter(
value: t.Iterable[t.Any] | str | uuid.UUID,
) -> t.Generator[str, None, None]:
value: t.Iterable[t.Any] | bytes | str | uuid.UUID,
) -> t.Iterator[str]:
"""
Given an Iterable (typically of strings), produce an iterator over it of strings.
This is a passthrough with two caveats:
This is a passthrough with some caveats:
- if the value is a solitary string, yield only that value
- if the value is a bytestring, yield the utf-8 decoded string
- if the value is a solitary UUID, yield only that value (as a string)
- str values in the iterable which are not strings
Expand All @@ -60,6 +61,8 @@ def safe_strseq_iter(
"""
if isinstance(value, str):
yield value
elif isinstance(value, bytes):
yield value.decode()
elif isinstance(value, uuid.UUID):
yield str(value)
else:
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import uuid

import pytest

from globus_sdk import utils
Expand Down Expand Up @@ -68,3 +70,17 @@ def y(cls):
return cls.x["x"]

assert Foo.y == 1


@pytest.mark.parametrize(
"value, expected_result",
(
("foo", ["foo"]),
(b"foo", ["foo"]),
((1, 2, 3), ["1", "2", "3"]),
(uuid.UUID(int=10), [f"{uuid.UUID(int=10)}"]),
(["foo", uuid.UUID(int=5)], ["foo", f"{uuid.UUID(int=5)}"]),
),
)
def test_safe_strseq_iter(value, expected_result):
assert list(utils.safe_strseq_iter(value)) == expected_result

0 comments on commit 264e4bb

Please sign in to comment.