Skip to content

Commit

Permalink
improve whitelist check and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Schmidt committed Feb 16, 2024
1 parent b25d701 commit 42cf874
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 87 deletions.
10 changes: 4 additions & 6 deletions openeo_fastapi/client/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ async def get_collections():
if response.status == 200 and resp.get("collections"):
collections_list = []
for collection_json in resp["collections"]:
# For the collections from STAC, only let them through if they're on the whitelist
# This has to be before the legacy collections are added.
if (
len(app_settings.STAC_COLLECTIONS_WHITELIST) < 1
if not (
app_settings.STAC_COLLECTIONS_WHITELIST
or collection_json["id"]
in app_settings.STAC_COLLECTIONS_WHITELIST
):
Expand Down Expand Up @@ -59,8 +57,8 @@ async def get_collection(collection_id):
) as response:
resp = await response.json()
if response.status == 200 and resp.get("id"):
if (
len(app_settings.STAC_COLLECTIONS_WHITELIST) < 1
if not (
app_settings.STAC_COLLECTIONS_WHITELIST
or resp["id"] in app_settings.STAC_COLLECTIONS_WHITELIST
):
return Collection(**resp)
Expand Down
71 changes: 2 additions & 69 deletions openeo_fastapi/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from pydantic import AnyUrl, BaseModel, Extra, Field, confloat, constr

# Most of these models are based on previous work from EODC openeo-python-api

# Avoids a Pydantic error:
# TypeError: You should use `typing_extensions.TypedDict` instead of
# `typing.TypedDict` with Python < 3.9.2. Without it, there is no way to
Expand Down Expand Up @@ -230,10 +232,6 @@ class Capabilities(BaseModel):


class CollectionId(str):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

collection_id: constr(regex=rb"^[\w\-\.~\/]+$") = Field(
...,
description="A unique identifier for the collection, which MUST match the specified pattern.",
Expand All @@ -242,10 +240,6 @@ class CollectionId(str):


class StacExtensions(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

__root__: list[Union[AnyUrl, str]] = Field(
...,
description=(
Expand All @@ -258,32 +252,20 @@ class StacExtensions(BaseModel):


class StacAssets(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

pass

class Config:
extra = Extra.allow


class Role(Enum):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

producer = "producer"
licensor = "licensor"
processor = "processor"
host = "host"


class StacProvider(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

name: str = Field(
...,
description="The name of the organization or the individual.",
Expand Down Expand Up @@ -321,10 +303,6 @@ class StacProvider(BaseModel):


class StacProviders(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

__root__: list[StacProvider] = Field(
...,
description=(
Expand All @@ -336,10 +314,6 @@ class StacProviders(BaseModel):


class Description(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

__root__: str = Field(
...,
description="""Detailed description to explain the entity.
Expand All @@ -348,19 +322,11 @@ class Description(BaseModel):


class Dimension(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

type: Type2 = Field(..., description="Type of the dimension.")
description: Optional[Description] = None


class Spatial(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

bbox: Optional[list[list[float]]] = Field(
None,
description=(
Expand All @@ -375,10 +341,6 @@ class Spatial(BaseModel):


class IntervalItem(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

__root__: list[Any] = Field(
...,
description=(
Expand All @@ -390,10 +352,6 @@ class IntervalItem(BaseModel):


class Temporal(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

interval: Optional[list[IntervalItem]] = Field(
None,
description=(
Expand All @@ -408,10 +366,6 @@ class Temporal(BaseModel):


class Extent(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

spatial: Spatial = Field(
...,
description="The *potential* spatial extents of the features in the collection.",
Expand All @@ -425,19 +379,11 @@ class Extent(BaseModel):


class CollectionSummaryStats(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

min: Union[str, float] = Field(alias="minimum")
max: Union[str, float] = Field(alias="maximum")


class StacLicense(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

__root__: str = Field(
...,
description=(
Expand All @@ -454,10 +400,6 @@ class StacLicense(BaseModel):


class Collection(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

stac_version: StacVersion
stac_extensions: Optional[StacExtensions] = None
type: Optional[Type1] = Field(
Expand Down Expand Up @@ -577,10 +519,6 @@ class Config:


class LinksPagination(BaseModel):
"""
Based on https://github.com/stac-utils/stac-fastapi/tree/main/stac_fastapi/types/stac_fastapi/types
"""

__root__: list[Link] = Field(
...,
description="""Links related to this list of resources, for example links for pagination\nor
Expand All @@ -595,11 +533,6 @@ class LinksPagination(BaseModel):


class Collections(TypedDict, total=False):
"""
All collections endpoint.
https://github.com/radiantearth/stac-api-spec/tree/master/collections
"""

collections: list[Collection]
links: list[dict[str, Any]]

Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from unittest import mock

Expand All @@ -9,6 +10,8 @@
from openeo_fastapi.client.core import OpenEOCore

pytestmark = pytest.mark.unit
path_to_current_file = os.path.realpath(__file__)
current_directory = os.path.split(path_to_current_file)[0]


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -59,3 +62,14 @@ def core_api():
api = OpenEOApi(client=client, app=FastAPI())

return api


@pytest.fixture()
def collections():
with open(os.path.join(current_directory, "collections.json")) as f_in:
return json.load(f_in)


@pytest.fixture
def s2a_collection(collections):
return collections[0]
37 changes: 25 additions & 12 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from unittest import mock

import pytest
from aioresponses import aioresponses
Expand All @@ -10,15 +11,9 @@
from openeo_fastapi.client.collections import get_collection, get_collections
from openeo_fastapi.client.models import Collection

path_to_current_file = os.path.realpath(__file__)
current_directory = os.path.split(path_to_current_file)[0]


@pytest.mark.asyncio
async def test_get_collections():
# TODO: Make collections a fixture
with open(os.path.join(current_directory, "collections.json")) as f_in:
collections = json.load(f_in)
async def test_get_collections(collections):
with aioresponses() as m:
m.get("http://test-stac-api.mock.com/api/collections", payload=collections)

Expand All @@ -29,18 +24,36 @@ async def test_get_collections():


@pytest.mark.asyncio
async def test_get_collection():
with open(os.path.join(current_directory, "collections.json")) as f_in:
collection = json.load(f_in)["collections"][0]
async def test_get_collections_whitelist(collections, s2a_collection):
with mock.patch.dict(os.environ, {"STAC_COLLECTIONS_WHITELIST": "Sentinel-2A"}):
with aioresponses() as m:
m.get(
"http://test-stac-api.mock.com/api/collections",
payload={
"collections": [s2a_collection],
"links": collections["links"],
},
)

data = await get_collections()

col = data["collections"][0]

assert col == s2a_collection
m.assert_called_once_with("http://test-stac-api.mock.com/api/collections")


@pytest.mark.asyncio
async def test_get_collection(s2a_collection):
with aioresponses() as m:
m.get(
"http://test-stac-api.mock.com/api/collections/Sentinel-2A",
payload=collection,
payload=s2a_collection,
)

data = await get_collection("Sentinel-2A")

assert data == Collection(**collection)
assert data == Collection(**s2a_collection)
m.assert_called_once_with(
"http://test-stac-api.mock.com/api/collections/Sentinel-2A"
)
Expand Down

0 comments on commit 42cf874

Please sign in to comment.