From 7d14025ee0084be534772d0924560377194d4ac8 Mon Sep 17 00:00:00 2001 From: Anthony Marcozzi Date: Mon, 3 Nov 2025 15:40:44 -0700 Subject: [PATCH] Refactor API initialization to enable lazy loading and improve modularity - Replaced direct API client instantiations with factory methods (e.g., `get_features_api`, `get_grids_api`) for better flexibility and lazy initialization. - Simplified `set_api_key` to clear cached API instances, ensuring new credentials are applied consistently. - Added robust error handling for missing API keys in SDK usage scenarios (`ensure_client` function). - Removed unused API client imports and instance variables across modules. - Introduced `test_api.py` to validate API initialization, lazy loading, and edge cases (e.g., missing API keys, programmatic key updates). --- fastfuels_sdk/api.py | 225 +++++++++++++++++++-- fastfuels_sdk/domains.py | 19 +- fastfuels_sdk/exports.py | 49 +++-- fastfuels_sdk/features.py | 31 ++- fastfuels_sdk/grids/feature_grid.py | 13 +- fastfuels_sdk/grids/grids.py | 37 ++-- fastfuels_sdk/grids/surface_grid.py | 17 +- fastfuels_sdk/grids/topography_grid.py | 19 +- fastfuels_sdk/grids/tree_grid.py | 19 +- fastfuels_sdk/inventories.py | 26 ++- tests/test_api.py | 261 +++++++++++++++++++++++++ 11 files changed, 575 insertions(+), 141 deletions(-) create mode 100644 tests/test_api.py diff --git a/fastfuels_sdk/api.py b/fastfuels_sdk/api.py index 016c78c..d3d4eff 100644 --- a/fastfuels_sdk/api.py +++ b/fastfuels_sdk/api.py @@ -6,22 +6,73 @@ from typing import Optional from fastfuels_sdk.client_library.api_client import ApiClient +from fastfuels_sdk.client_library.api import ( + DomainsApi, + InventoriesApi, + TreeInventoryApi, + FeaturesApi, + RoadFeatureApi, + WaterFeatureApi, + GridsApi, + TreeGridApi, + SurfaceGridApi, + TopographyGridApi, + FeatureGridApi, +) _client: Optional[ApiClient] = None +_domains_api: Optional[DomainsApi] = None +_inventories_api: Optional[InventoriesApi] = None +_tree_inventory_api: Optional[TreeInventoryApi] = None +_features_api: Optional[FeaturesApi] = None +_road_feature_api: Optional[RoadFeatureApi] = None +_water_feature_api: Optional[WaterFeatureApi] = None +_grids_api: Optional[GridsApi] = None +_tree_grid_api: Optional[TreeGridApi] = None +_surface_grid_api: Optional[SurfaceGridApi] = None +_topography_grid_api: Optional[TopographyGridApi] = None +_feature_grid_api: Optional[FeatureGridApi] = None def set_api_key(api_key: str) -> None: - global _client + """Set the API key for the FastFuels SDK. - config = { - "header_name": "api-key", - "header_value": api_key, - } + This will invalidate the cached API client and all API instances, + ensuring that subsequent API calls use the new credentials. - _client = ApiClient(**config) + Args: + api_key: The API key to use for authentication + """ + global _client, _domains_api, _inventories_api, _tree_inventory_api + global _features_api, _road_feature_api, _water_feature_api + global _grids_api, _tree_grid_api, _surface_grid_api, _topography_grid_api, _feature_grid_api + + _client = None + _domains_api = None + _inventories_api = None + _tree_inventory_api = None + _features_api = None + _road_feature_api = None + _water_feature_api = None + _grids_api = None + _tree_grid_api = None + _surface_grid_api = None + _topography_grid_api = None + _feature_grid_api = None + os.environ["FASTFUELS_API_KEY"] = api_key -def get_client() -> ApiClient: + +def get_client() -> Optional[ApiClient]: + """Get the current API client, creating one if necessary. + + This function will attempt to get the API client from: + 1. The existing _client instance if set_api_key() was called + 2. The FASTFUELS_API_KEY environment variable + + Returns: + The ApiClient instance, or None if no API key is configured + """ global _client if _client is not None: @@ -29,10 +80,8 @@ def get_client() -> ApiClient: api_key = os.getenv("FASTFUELS_API_KEY") if not api_key: - raise RuntimeError( - "FASTFUELS_API_KEY environment variable not set. " - "Please set this variable with your API key." - ) + return None + config = { "header_name": "api-key", "header_value": api_key, @@ -41,3 +90,157 @@ def get_client() -> ApiClient: _client = ApiClient(**config) return _client + + +def ensure_client() -> ApiClient: + """Ensure an API client is configured and return it. + + This function will raise a RuntimeError with a helpful message if no + API key has been configured. + + Returns: + The ApiClient instance + + Raises: + RuntimeError: If no API key is configured + """ + client = get_client() + if client is None: + raise RuntimeError( + "FastFuels API key not configured. Please either:\n" + " 1. Set the FASTFUELS_API_KEY environment variable, or\n" + " 2. Call fastfuels_sdk.api.set_api_key('your-api-key') before making API calls" + ) + return client + + +def get_domains_api() -> DomainsApi: + """Get the cached DomainsApi instance, creating it if necessary. + + Returns: + The DomainsApi instance + """ + global _domains_api + if _domains_api is None: + _domains_api = DomainsApi(ensure_client()) + return _domains_api + + +def get_inventories_api() -> InventoriesApi: + """Get the cached InventoriesApi instance, creating it if necessary. + + Returns: + The InventoriesApi instance + """ + global _inventories_api + if _inventories_api is None: + _inventories_api = InventoriesApi(ensure_client()) + return _inventories_api + + +def get_tree_inventory_api() -> TreeInventoryApi: + """Get the cached TreeInventoryApi instance, creating it if necessary. + + Returns: + The TreeInventoryApi instance + """ + global _tree_inventory_api + if _tree_inventory_api is None: + _tree_inventory_api = TreeInventoryApi(ensure_client()) + return _tree_inventory_api + + +def get_features_api() -> FeaturesApi: + """Get the cached FeaturesApi instance, creating it if necessary. + + Returns: + The FeaturesApi instance + """ + global _features_api + if _features_api is None: + _features_api = FeaturesApi(ensure_client()) + return _features_api + + +def get_road_feature_api() -> RoadFeatureApi: + """Get the cached RoadFeatureApi instance, creating it if necessary. + + Returns: + The RoadFeatureApi instance + """ + global _road_feature_api + if _road_feature_api is None: + _road_feature_api = RoadFeatureApi(ensure_client()) + return _road_feature_api + + +def get_water_feature_api() -> WaterFeatureApi: + """Get the cached WaterFeatureApi instance, creating it if necessary. + + Returns: + The WaterFeatureApi instance + """ + global _water_feature_api + if _water_feature_api is None: + _water_feature_api = WaterFeatureApi(ensure_client()) + return _water_feature_api + + +def get_grids_api() -> GridsApi: + """Get the cached GridsApi instance, creating it if necessary. + + Returns: + The GridsApi instance + """ + global _grids_api + if _grids_api is None: + _grids_api = GridsApi(ensure_client()) + return _grids_api + + +def get_tree_grid_api() -> TreeGridApi: + """Get the cached TreeGridApi instance, creating it if necessary. + + Returns: + The TreeGridApi instance + """ + global _tree_grid_api + if _tree_grid_api is None: + _tree_grid_api = TreeGridApi(ensure_client()) + return _tree_grid_api + + +def get_surface_grid_api() -> SurfaceGridApi: + """Get the cached SurfaceGridApi instance, creating it if necessary. + + Returns: + The SurfaceGridApi instance + """ + global _surface_grid_api + if _surface_grid_api is None: + _surface_grid_api = SurfaceGridApi(ensure_client()) + return _surface_grid_api + + +def get_topography_grid_api() -> TopographyGridApi: + """Get the cached TopographyGridApi instance, creating it if necessary. + + Returns: + The TopographyGridApi instance + """ + global _topography_grid_api + if _topography_grid_api is None: + _topography_grid_api = TopographyGridApi(ensure_client()) + return _topography_grid_api + + +def get_feature_grid_api() -> FeatureGridApi: + """Get the cached FeatureGridApi instance, creating it if necessary. + + Returns: + The FeatureGridApi instance + """ + global _feature_grid_api + if _feature_grid_api is None: + _feature_grid_api = FeatureGridApi(ensure_client()) + return _feature_grid_api diff --git a/fastfuels_sdk/domains.py b/fastfuels_sdk/domains.py index ad04370..1831db0 100644 --- a/fastfuels_sdk/domains.py +++ b/fastfuels_sdk/domains.py @@ -7,8 +7,7 @@ from typing import Optional, List # Internal imports -from fastfuels_sdk.api import get_client -from fastfuels_sdk.client_library.api import DomainsApi +from fastfuels_sdk.api import get_domains_api from fastfuels_sdk.client_library.models import ( Domain as DomainModel, CreateDomainRequest, @@ -21,8 +20,6 @@ # External imports import geopandas as gpd -_DOMAIN_API = DomainsApi(get_client()) - class Domain(DomainModel): """Domain resource for the FastFuels API. @@ -110,7 +107,7 @@ def from_id(cls, domain_id: str) -> "Domain": >>> domain.id 'abc123' """ - get_domain_response = _DOMAIN_API.get_domain(domain_id) + get_domain_response = get_domains_api().get_domain(domain_id) return cls(**get_domain_response.model_dump()) @classmethod @@ -178,7 +175,7 @@ def from_geojson( } request = CreateDomainRequest.from_dict(feature_data) - response = _DOMAIN_API.create_domain( + response = get_domains_api().create_domain( create_domain_request=request.model_dump() # noqa ) return cls(**response.model_dump()) if response else None @@ -297,7 +294,7 @@ def get(self, in_place: bool = False) -> "Domain": ensure all references to this Domain instance see the updated data. """ # Fetch latest data from API - response = _DOMAIN_API.get_domain(self.id) + response = get_domains_api().get_domain(self.id) if in_place: # Update all attributes of current instance @@ -371,7 +368,7 @@ def update( # Only make API call if there are fields to update if update_data: request = UpdateDomainRequest(**update_data) - response = _DOMAIN_API.update_domain( + response = get_domains_api().update_domain( domain_id=self.id, update_domain_request=request ) @@ -504,7 +501,7 @@ def export(self) -> dict: - Grid array data: Use grid export endpoints - Tree inventory records: Use inventory export endpoints """ - return _DOMAIN_API.export_domain_data(domain_id=self.id) + return get_domains_api().export_domain_data(domain_id=self.id) def delete(self) -> None: """Delete an existing domain resource based on the domain ID. @@ -526,7 +523,7 @@ def delete(self) -> None: >>> domain.get() # Raises NotFoundException """ - _DOMAIN_API.delete_domain(domain_id=self.id) + get_domains_api().delete_domain(domain_id=self.id) return None @@ -589,7 +586,7 @@ def list_domains( """ sort_by = DomainSortField(sort_by) if sort_by else None sort_order = DomainSortOrder(sort_order) if sort_order else None - list_response = _DOMAIN_API.list_domains( + list_response = get_domains_api().list_domains( page=page, size=size, sort_by=sort_by, sort_order=sort_order ) list_response.domains = [Domain(**d.model_dump()) for d in list_response.domains] diff --git a/fastfuels_sdk/exports.py b/fastfuels_sdk/exports.py index b269f64..170830d 100644 --- a/fastfuels_sdk/exports.py +++ b/fastfuels_sdk/exports.py @@ -10,34 +10,28 @@ from urllib.request import urlretrieve # Internal imports -from fastfuels_sdk.api import get_client +from fastfuels_sdk.api import ( + get_tree_inventory_api, + get_grids_api, + get_tree_grid_api, + get_surface_grid_api, + get_topography_grid_api, +) from fastfuels_sdk.utils import format_processing_error from fastfuels_sdk.client_library.models import Export as ExportModel -from fastfuels_sdk.client_library.api import ( - TreeInventoryApi, - GridsApi, - TreeGridApi, - SurfaceGridApi, - TopographyGridApi, - FeatureGridApi, -) - -# Initialize API clients -_TREE_INVENTORY_API = TreeInventoryApi(get_client()) -_GRIDS_API = GridsApi(get_client()) -_TREE_GRID_API = TreeGridApi(get_client()) -_SURFACE_GRID_API = SurfaceGridApi(get_client()) -_TOPOGRAPHY_GRID_API = TopographyGridApi(get_client()) -_FEATURE_GRID_API = FeatureGridApi(get_client()) -# Define a mapping of (resource, sub_resource) tuples to their corresponding API methods +# Define a mapping of (resource, sub_resource) tuples to functions that return the API methods +# This ensures that each call uses the current API client instance _API_METHODS = { - ("inventories", "tree"): _TREE_INVENTORY_API.get_tree_inventory_export, - ("grids", None): _GRIDS_API.get_grid_export, - ("grids", "tree"): _TREE_GRID_API.get_tree_grid_export, - ("grids", "surface"): _SURFACE_GRID_API.get_surface_grid_export, - ("grids", "topography"): _TOPOGRAPHY_GRID_API.get_topography_grid_export, - # ("grids", "feature"): _FEATURE_GRID_API.get_feature_grid_export, # Not yet implemented + ("inventories", "tree"): lambda: get_tree_inventory_api().get_tree_inventory_export, + ("grids", None): lambda: get_grids_api().get_grid_export, + ("grids", "tree"): lambda: get_tree_grid_api().get_tree_grid_export, + ("grids", "surface"): lambda: get_surface_grid_api().get_surface_grid_export, + ( + "grids", + "topography", + ): lambda: get_topography_grid_api().get_topography_grid_export, + # ("grids", "feature"): lambda: get_feature_grid_api().get_feature_grid_export, # Not yet implemented } _FILE_NAMES = { @@ -95,14 +89,15 @@ def __init__(self, **data: Any): """ super().__init__(**data) - api_method = _API_METHODS.get((self.resource, self.sub_resource)) - if api_method is None: + api_method_getter = _API_METHODS.get((self.resource, self.sub_resource)) + if api_method_getter is None: raise NotImplementedError( f"Export not implemented for resource={self.resource}, " f"sub_resource={self.sub_resource}" ) - self._api_get_method = lambda: api_method( + # Store a lambda that calls the getter function to get the current API method + self._api_get_method = lambda: api_method_getter()( domain_id=self.domain_id, export_format=self.format ) diff --git a/fastfuels_sdk/features.py b/fastfuels_sdk/features.py index b6c8ba4..c5b1b8f 100644 --- a/fastfuels_sdk/features.py +++ b/fastfuels_sdk/features.py @@ -8,13 +8,12 @@ from typing import Optional, List, Union, Dict, Any # Internal imports -from fastfuels_sdk.api import get_client -from fastfuels_sdk.utils import format_processing_error -from fastfuels_sdk.client_library.api import ( - FeaturesApi, - RoadFeatureApi, - WaterFeatureApi, +from fastfuels_sdk.api import ( + get_features_api, + get_road_feature_api, + get_water_feature_api, ) +from fastfuels_sdk.utils import format_processing_error from fastfuels_sdk.client_library.models import ( Features as FeaturesModel, RoadFeature as RoadFeatureModel, @@ -26,10 +25,6 @@ Geojson, ) -_FEATURES_API = FeaturesApi(get_client()) -_ROAD_FEATURE_API = RoadFeatureApi(get_client()) -_WATER_FEATURE_API = WaterFeatureApi(get_client()) - class Features(FeaturesModel): """Geographic features (roads and water bodies) associated with a domain. @@ -103,7 +98,7 @@ def from_domain_id(cls, domain_id: str) -> Features: -------- Features.get : Refresh feature data """ - features_response = _FEATURES_API.get_features(domain_id=domain_id) + features_response = get_features_api().get_features(domain_id=domain_id) response_data = _convert_api_models_to_sdk_classes( domain_id, features_response.model_dump(), features_response ) @@ -138,7 +133,7 @@ def get(self, in_place: bool = False) -> Features: -------- Features.from_domain : Get features for a specific domain """ - response = _FEATURES_API.get_features(domain_id=self.domain_id) + response = get_features_api().get_features(domain_id=self.domain_id) response_data = response.model_dump() response_data = _convert_api_models_to_sdk_classes( self.domain_id, response_data, response @@ -223,7 +218,7 @@ def create_road_feature( request = CreateRoadFeatureRequest(sources=sources_list, geojson=geojson_param) # Call API - response = _ROAD_FEATURE_API.create_road_feature( + response = get_road_feature_api().create_road_feature( domain_id=self.domain_id, create_road_feature_request=request ) @@ -386,7 +381,7 @@ def create_water_feature( ) # Call API - response = _WATER_FEATURE_API.create_water_feature( + response = get_water_feature_api().create_water_feature( domain_id=self.domain_id, create_water_feature_request=request ) @@ -505,7 +500,7 @@ def get(self, in_place: bool = False) -> RoadFeature: >>> # Or update the existing instance >>> road.get(in_place=True) """ - response = _ROAD_FEATURE_API.get_road_feature(domain_id=self.domain_id) + response = get_road_feature_api().get_road_feature(domain_id=self.domain_id) response_dict = response.model_dump() # The geojson field from response is already a Geojson instance, keep it @@ -602,7 +597,7 @@ def delete(self) -> None: >>> # Subsequent operations will raise NotFoundException >>> road.get() # raises NotFoundException """ - _ROAD_FEATURE_API.delete_road_feature(domain_id=self.domain_id) + get_road_feature_api().delete_road_feature(domain_id=self.domain_id) return None @@ -669,7 +664,7 @@ def get(self, in_place: bool = False) -> WaterFeature: >>> # Or update the existing instance >>> water.get(in_place=True) """ - response = _WATER_FEATURE_API.get_water_feature(domain_id=self.domain_id) + response = get_water_feature_api().get_water_feature(domain_id=self.domain_id) if in_place: # Update all attributes of current instance for key, value in response.model_dump().items(): @@ -761,7 +756,7 @@ def delete(self) -> None: >>> # Subsequent operations will raise NotFoundException >>> water.get() # raises NotFoundException """ - _WATER_FEATURE_API.delete_water_feature(domain_id=self.domain_id) + get_water_feature_api().delete_water_feature(domain_id=self.domain_id) return None diff --git a/fastfuels_sdk/grids/feature_grid.py b/fastfuels_sdk/grids/feature_grid.py index 69aab74..a7bf118 100644 --- a/fastfuels_sdk/grids/feature_grid.py +++ b/fastfuels_sdk/grids/feature_grid.py @@ -6,16 +6,13 @@ from __future__ import annotations # Internal imports -from fastfuels_sdk.api import get_client +from fastfuels_sdk.api import get_feature_grid_api from fastfuels_sdk.utils import format_processing_error -from fastfuels_sdk.client_library.api import FeatureGridApi from fastfuels_sdk.client_library.models import ( FeatureGrid as FeatureGridModel, GridAttributeMetadataResponse, ) -_SURFACE_GRID_API = FeatureGridApi(get_client()) - class FeatureGrid(FeatureGridModel): """Feature grid data within a domain's spatial boundaries.""" @@ -42,7 +39,7 @@ def from_domain_id(cls, domain_id: str) -> "FeatureGrid": >>> print(grid.status) 'completed' """ - response = _SURFACE_GRID_API.get_feature_grid(domain_id=domain_id) + response = get_feature_grid_api().get_feature_grid(domain_id=domain_id) return cls(domain_id=domain_id, **response.model_dump()) def get(self, in_place: bool = False) -> "FeatureGrid": @@ -70,7 +67,7 @@ def get(self, in_place: bool = False) -> "FeatureGrid": >>> # Or update the existing instance >>> grid.get(in_place=True) """ - response = _SURFACE_GRID_API.get_feature_grid(domain_id=self.domain_id) + response = get_feature_grid_api().get_feature_grid(domain_id=self.domain_id) if in_place: # Update all attributes of current instance for key, value in response.model_dump().items(): @@ -206,7 +203,7 @@ def get_attributes(self) -> GridAttributeMetadataResponse: >>> print(metadata.shape) [100, 100, 50] """ - return _SURFACE_GRID_API.get_feature_grid_attribute_metadata( + return get_feature_grid_api().get_feature_grid_attribute_metadata( domain_id=self.domain_id ) @@ -229,5 +226,5 @@ def delete(self) -> None: >>> # Subsequent operations will raise NotFoundException >>> grid.get() # raises NotFoundException """ - _SURFACE_GRID_API.delete_feature_grid(domain_id=self.domain_id) + get_feature_grid_api().delete_feature_grid(domain_id=self.domain_id) return None diff --git a/fastfuels_sdk/grids/grids.py b/fastfuels_sdk/grids/grids.py index 7688ffd..0077012 100644 --- a/fastfuels_sdk/grids/grids.py +++ b/fastfuels_sdk/grids/grids.py @@ -10,19 +10,18 @@ from typing import Optional, List # Internal imports -from fastfuels_sdk.api import get_client +from fastfuels_sdk.api import ( + get_grids_api, + get_tree_grid_api, + get_surface_grid_api, + get_topography_grid_api, + get_feature_grid_api, +) from fastfuels_sdk.exports import Export from fastfuels_sdk.grids.tree_grid import TreeGrid from fastfuels_sdk.grids.feature_grid import FeatureGrid from fastfuels_sdk.grids.surface_grid import SurfaceGrid from fastfuels_sdk.grids.topography_grid import TopographyGrid -from fastfuels_sdk.client_library.api import ( - GridsApi, - TreeGridApi, - SurfaceGridApi, - TopographyGridApi, - FeatureGridApi, -) from fastfuels_sdk.client_library.models import ( Grids as GridsModel, CreateSurfaceGridRequest, @@ -44,12 +43,6 @@ CreateFeatureGridRequest, ) -_GRIDS_API = GridsApi(get_client()) -_TREE_GRID_API = TreeGridApi(get_client()) -_SURFACE_GRID_API = SurfaceGridApi(get_client()) -_TOPOGRAPHY_GRID_API = TopographyGridApi(get_client()) -_FEATURE_GRID_API = FeatureGridApi(get_client()) - class Grids(GridsModel): """Container for different types of gridded data within a domain's spatial boundaries. @@ -127,7 +120,7 @@ def from_domain_id(cls, domain_id: str) -> Grids: >>> if grids.surface: ... print("Domain has surface grid data") """ - grids_response = _GRIDS_API.get_grids(domain_id=domain_id) + grids_response = get_grids_api().get_grids(domain_id=domain_id) response_data = grids_response.model_dump() response_data = _convert_api_models_to_sdk_classes(domain_id, response_data) @@ -157,7 +150,7 @@ def get(self, in_place: bool = False) -> Grids: >>> # Or update the existing instance >>> grids.get(in_place=True) """ - response = _GRIDS_API.get_grids(domain_id=self.domain_id) + response = get_grids_api().get_grids(domain_id=self.domain_id) response_data = response.model_dump() response_data = _convert_api_models_to_sdk_classes( self.domain_id, response_data @@ -270,7 +263,7 @@ def create_surface_grid( ), ) - response = _SURFACE_GRID_API.create_surface_grid( + response = get_surface_grid_api().create_surface_grid( domain_id=self.domain_id, create_surface_grid_request=request ) @@ -407,7 +400,7 @@ def create_topography_grid( aspect=(TopographyGridAspectSource.from_dict(aspect) if aspect else None), ) - response = _TOPOGRAPHY_GRID_API.create_topography_grid( + response = get_topography_grid_api().create_topography_grid( domain_id=self.domain_id, create_topography_grid_request=request ) @@ -541,7 +534,7 @@ def create_tree_grid( SAVR=(TreeGridSAVRSource.from_dict(savr) if savr else None), ) - response = _TREE_GRID_API.create_tree_grid( + response = get_tree_grid_api().create_tree_grid( domain_id=self.domain_id, create_tree_grid_request=request ) @@ -613,7 +606,7 @@ def create_feature_grid( attributes=attributes, # type: ignore # pydantic handles this for us ) - response = _FEATURE_GRID_API.create_feature_grid( + response = get_feature_grid_api().create_feature_grid( domain_id=self.domain_id, create_feature_grid_request=request ) @@ -656,7 +649,7 @@ def create_export(self, export_format: str) -> Export: >>> export.wait_until_completed() >>> export.to_file("grid_data.zip") """ - response = _GRIDS_API.create_grid_export( + response = get_grids_api().create_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -686,7 +679,7 @@ def get_export(self, export_format: str) -> Export: >>> if export.status == "completed": ... export.to_file("grid_data.zarr") """ - response = _GRIDS_API.get_grid_export( + response = get_grids_api().get_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) diff --git a/fastfuels_sdk/grids/surface_grid.py b/fastfuels_sdk/grids/surface_grid.py index 4208cfa..b8739a6 100644 --- a/fastfuels_sdk/grids/surface_grid.py +++ b/fastfuels_sdk/grids/surface_grid.py @@ -6,17 +6,14 @@ from __future__ import annotations # Internal imports -from fastfuels_sdk.api import get_client +from fastfuels_sdk.api import get_surface_grid_api from fastfuels_sdk.utils import format_processing_error from fastfuels_sdk.exports import Export -from fastfuels_sdk.client_library.api import SurfaceGridApi from fastfuels_sdk.client_library.models import ( SurfaceGrid as SurfaceGridModel, GridAttributeMetadataResponse, ) -_SURFACE_GRID_API = SurfaceGridApi(get_client()) - class SurfaceGrid(SurfaceGridModel): """Surface grid data within a domain's spatial boundaries.""" @@ -43,7 +40,7 @@ def from_domain_id(cls, domain_id: str) -> "SurfaceGrid": >>> print(grid.status) 'completed' """ - response = _SURFACE_GRID_API.get_surface_grid(domain_id=domain_id) + response = get_surface_grid_api().get_surface_grid(domain_id=domain_id) return cls(domain_id=domain_id, **response.model_dump()) def get(self, in_place: bool = False) -> "SurfaceGrid": @@ -71,7 +68,7 @@ def get(self, in_place: bool = False) -> "SurfaceGrid": >>> # Or update the existing instance >>> grid.get(in_place=True) """ - response = _SURFACE_GRID_API.get_surface_grid(domain_id=self.domain_id) + response = get_surface_grid_api().get_surface_grid(domain_id=self.domain_id) if in_place: # Update all attributes of current instance for key, value in response.model_dump().items(): @@ -207,7 +204,7 @@ def get_attributes(self) -> GridAttributeMetadataResponse: >>> print(metadata.shape) [100, 100, 50] """ - return _SURFACE_GRID_API.get_surface_grid_attribute_metadata( + return get_surface_grid_api().get_surface_grid_attribute_metadata( domain_id=self.domain_id ) @@ -234,7 +231,7 @@ def create_export(self, export_format: str) -> Export: >>> export.wait_until_completed() >>> export.to_file("grid_data.zarr") """ - response = _SURFACE_GRID_API.create_surface_grid_export( + response = get_surface_grid_api().create_surface_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -262,7 +259,7 @@ def get_export(self, export_format: str) -> Export: >>> export.wait_until_completed() >>> export.to_file("grid_data.zarr") """ - response = _SURFACE_GRID_API.get_surface_grid_export( + response = get_surface_grid_api().get_surface_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -286,5 +283,5 @@ def delete(self) -> None: >>> # Subsequent operations will raise NotFoundException >>> grid.get() # raises NotFoundException """ - _SURFACE_GRID_API.delete_surface_grid(domain_id=self.domain_id) + get_surface_grid_api().delete_surface_grid(domain_id=self.domain_id) return None diff --git a/fastfuels_sdk/grids/topography_grid.py b/fastfuels_sdk/grids/topography_grid.py index 8cbec3e..3bb6c7e 100644 --- a/fastfuels_sdk/grids/topography_grid.py +++ b/fastfuels_sdk/grids/topography_grid.py @@ -6,17 +6,14 @@ from __future__ import annotations # Internal imports -from fastfuels_sdk.api import get_client +from fastfuels_sdk.api import get_topography_grid_api from fastfuels_sdk.utils import format_processing_error from fastfuels_sdk.exports import Export -from fastfuels_sdk.client_library.api import TopographyGridApi from fastfuels_sdk.client_library.models import ( TopographyGrid as TopographyGridModel, GridAttributeMetadataResponse, ) -_TOPOGRAPHY_GRID_API = TopographyGridApi(get_client()) - class TopographyGrid(TopographyGridModel): """Topography grid data within a domain's spatial boundaries.""" @@ -43,7 +40,7 @@ def from_domain_id(cls, domain_id: str) -> "TopographyGrid": >>> print(grid.status) 'completed' """ - response = _TOPOGRAPHY_GRID_API.get_topography_grid(domain_id=domain_id) + response = get_topography_grid_api().get_topography_grid(domain_id=domain_id) return cls(domain_id=domain_id, **response.model_dump()) def get(self, in_place: bool = False) -> "TopographyGrid": @@ -71,7 +68,9 @@ def get(self, in_place: bool = False) -> "TopographyGrid": >>> # Or update the existing instance >>> grid.get(in_place=True) """ - response = _TOPOGRAPHY_GRID_API.get_topography_grid(domain_id=self.domain_id) + response = get_topography_grid_api().get_topography_grid( + domain_id=self.domain_id + ) if in_place: # Update all attributes of current instance for key, value in response.model_dump().items(): @@ -207,7 +206,7 @@ def get_attributes(self) -> GridAttributeMetadataResponse: >>> print(metadata.shape) [100, 100, 50] """ - return _TOPOGRAPHY_GRID_API.get_topography_grid_attribute_metadata( + return get_topography_grid_api().get_topography_grid_attribute_metadata( domain_id=self.domain_id ) @@ -234,7 +233,7 @@ def create_export(self, export_format: str) -> Export: >>> export.wait_until_completed() >>> export.to_file("grid_data.zarr") """ - response = _TOPOGRAPHY_GRID_API.create_topography_grid_export( + response = get_topography_grid_api().create_topography_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -262,7 +261,7 @@ def get_export(self, export_format: str) -> Export: >>> export.wait_until_completed() >>> export.to_file("grid_data.zarr") """ - response = _TOPOGRAPHY_GRID_API.get_topography_grid_export( + response = get_topography_grid_api().get_topography_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -286,5 +285,5 @@ def delete(self) -> None: >>> # Subsequent operations will raise NotFoundException >>> grid.get() # raises NotFoundException """ - _TOPOGRAPHY_GRID_API.delete_topography_grid(domain_id=self.domain_id) + get_topography_grid_api().delete_topography_grid(domain_id=self.domain_id) return None diff --git a/fastfuels_sdk/grids/tree_grid.py b/fastfuels_sdk/grids/tree_grid.py index db013d4..2fcee33 100644 --- a/fastfuels_sdk/grids/tree_grid.py +++ b/fastfuels_sdk/grids/tree_grid.py @@ -6,17 +6,14 @@ from __future__ import annotations # Internal imports -from fastfuels_sdk.api import get_client +from fastfuels_sdk.api import get_tree_grid_api from fastfuels_sdk.utils import format_processing_error from fastfuels_sdk.exports import Export -from fastfuels_sdk.client_library.api import TreeGridApi from fastfuels_sdk.client_library.models import ( TreeGrid as TreeGridModel, GridAttributeMetadataResponse, ) -_TREE_GRID_API = TreeGridApi(get_client()) - class TreeGrid(TreeGridModel): """Tree grid data within a domain's spatial boundaries.""" @@ -43,7 +40,7 @@ def from_domain_id(cls, domain_id: str) -> TreeGrid: >>> tree_grid.status 'completed' """ - response = _TREE_GRID_API.get_tree_grid(domain_id=domain_id) + response = get_tree_grid_api().get_tree_grid(domain_id=domain_id) return cls(domain_id=domain_id, **response.model_dump()) def get(self, in_place: bool = False) -> TreeGrid: @@ -71,7 +68,7 @@ def get(self, in_place: bool = False) -> TreeGrid: >>> # Or update the existing instance >>> tree_grid.get(in_place=True) """ - response = _TREE_GRID_API.get_tree_grid(domain_id=self.domain_id) + response = get_tree_grid_api().get_tree_grid(domain_id=self.domain_id) if in_place: # Update all attributes of current instance for key, value in response.model_dump().items(): @@ -194,7 +191,9 @@ def get_attributes(self) -> GridAttributeMetadataResponse: >>> print(metadata.shape) [100, 100, 50] """ - return _TREE_GRID_API.get_tree_grid_attribute_metadata(domain_id=self.domain_id) + return get_tree_grid_api().get_tree_grid_attribute_metadata( + domain_id=self.domain_id + ) def create_export(self, export_format: str) -> Export: """Create an export of the tree grid data. @@ -219,7 +218,7 @@ def create_export(self, export_format: str) -> Export: >>> export.wait_until_completed() >>> export.to_file("grid_data.zarr") """ - response = _TREE_GRID_API.create_tree_grid_export( + response = get_tree_grid_api().create_tree_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -246,7 +245,7 @@ def get_export(self, export_format: str) -> Export: >>> export.wait_until_completed() >>> export.to_file("grid_data.zarr") """ - response = _TREE_GRID_API.get_tree_grid_export( + response = get_tree_grid_api().get_tree_grid_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -270,5 +269,5 @@ def delete(self) -> None: >>> # Subsequent operations will raise NotFoundException >>> tree_grid.get() # raises NotFoundException """ - _TREE_GRID_API.delete_tree_grid(domain_id=self.domain_id) + get_tree_grid_api().delete_tree_grid(domain_id=self.domain_id) return None diff --git a/fastfuels_sdk/inventories.py b/fastfuels_sdk/inventories.py index 9bc66db..7b8b814 100644 --- a/fastfuels_sdk/inventories.py +++ b/fastfuels_sdk/inventories.py @@ -9,13 +9,12 @@ from typing import Optional # Internal imports -from fastfuels_sdk.api import get_client +from fastfuels_sdk.api import get_inventories_api, get_tree_inventory_api from fastfuels_sdk.utils import ( parse_dict_items_to_pydantic_list, format_processing_error, ) from fastfuels_sdk.exports import Export -from fastfuels_sdk.client_library.api import InventoriesApi, TreeInventoryApi from fastfuels_sdk.client_library.models import ( Inventories as InventoriesModel, TreeInventory as TreeInventoryModel, @@ -31,9 +30,6 @@ # External imports import requests -_INVENTORIES_API = InventoriesApi(get_client()) -_TREE_INVENTORY_API = TreeInventoryApi(get_client()) - class Inventories(InventoriesModel): """ @@ -76,7 +72,9 @@ def from_domain_id(cls, domain_id: str) -> Inventories: >>> from fastfuels_sdk import Inventories >>> inventories = Inventories.from_domain_id("abc123") """ - inventories_response = _INVENTORIES_API.get_inventories(domain_id=domain_id) + inventories_response = get_inventories_api().get_inventories( + domain_id=domain_id + ) response_data = inventories_response.model_dump() response_data = _convert_api_models_to_sdk_classes(domain_id, response_data) @@ -110,7 +108,7 @@ def get(self, in_place: bool = False): >>> # Fetch and update the inventory data in place >>> inventories.get(in_place=True) """ - response = _INVENTORIES_API.get_inventories(domain_id=self.domain_id) + response = get_inventories_api().get_inventories(domain_id=self.domain_id) response_data = response.model_dump() response_data = _convert_api_models_to_sdk_classes( self.domain_id, response_data @@ -277,7 +275,7 @@ def create_tree_inventory( [feature_masks] if isinstance(feature_masks, str) else feature_masks ), ) - response = _TREE_INVENTORY_API.create_tree_inventory( + response = get_tree_inventory_api().create_tree_inventory( self.domain_id, request_body ) tree_inventory = TreeInventory( @@ -575,7 +573,7 @@ def create_tree_inventory_from_file_upload( raise ValueError(f"File must be a CSV: {file_path}") # Create tree inventory resource with "file" source - signed_url_response = _TREE_INVENTORY_API.create_tree_inventory( + signed_url_response = get_tree_inventory_api().create_tree_inventory( self.domain_id, CreateTreeInventoryRequest(sources=[TreeInventorySource.FILE]), ) @@ -709,7 +707,7 @@ def from_domain_id(cls, domain_id: str) -> TreeInventory: - Use get() to refresh the inventory data and wait_until_completed() to wait for processing to finish """ - response = _TREE_INVENTORY_API.get_tree_inventory(domain_id=domain_id) + response = get_tree_inventory_api().get_tree_inventory(domain_id=domain_id) return cls(domain_id=domain_id, **response.model_dump()) def get(self, in_place: bool = False): @@ -774,7 +772,7 @@ def get(self, in_place: bool = False): - This method is often used in conjunction with wait_until_completed() to monitor the progress of tree inventory processing. """ - response = _TREE_INVENTORY_API.get_tree_inventory(domain_id=self.domain_id) + response = get_tree_inventory_api().get_tree_inventory(domain_id=self.domain_id) if in_place: # Update all attributes of current instance for key, value in response.model_dump().items(): @@ -932,7 +930,7 @@ def delete(self) -> None: - Consider creating an export of important inventory data before deletion using create_export() """ - _TREE_INVENTORY_API.delete_tree_inventory(domain_id=self.domain_id) + get_tree_inventory_api().delete_tree_inventory(domain_id=self.domain_id) return None @@ -1013,7 +1011,7 @@ def create_export(self, export_format: str) -> Export: process - use get() to check status, wait_until_completed() to wait for completion, and to_file() to download """ - response = _TREE_INVENTORY_API.create_tree_inventory_export( + response = get_tree_inventory_api().create_tree_inventory_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) @@ -1077,7 +1075,7 @@ def get_export(self, export_format: str) -> Export: create_export().wait_until_completed() is simpler - Always check the export's status before attempting to download using to_file() """ - response = _TREE_INVENTORY_API.get_tree_inventory_export( + response = get_tree_inventory_api().get_tree_inventory_export( domain_id=self.domain_id, export_format=export_format ) return Export(**response.model_dump()) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..b611912 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,261 @@ +""" +Tests for API initialization and lazy loading. + +These tests verify that: +1. The SDK can be imported without FASTFUELS_API_KEY env var (Issue #112) +2. The set_api_key() function properly updates all API clients (Issue #98) +""" + +import os +import pytest + + +@pytest.fixture(autouse=True) +def cleanup_api_state(): + """Fixture to clean up API state after each test.""" + # Store original state + original_env_key = os.environ.get("FASTFUELS_API_KEY") + + yield + + # Restore original environment variable + if original_env_key: + os.environ["FASTFUELS_API_KEY"] = original_env_key + # Reset the client to use the original key + from fastfuels_sdk import api + + api._client = None + # Trigger recreation with original key + api.get_client() + else: + os.environ.pop("FASTFUELS_API_KEY", None) + from fastfuels_sdk import api + + api._client = None + + +def test_import_without_env_var(): + """Test that the SDK can be imported without FASTFUELS_API_KEY set. + + This addresses Issue #112: Package import fails if fastfuels key not set as an environment variable. + + The SDK should allow imports without requiring the API key to be set upfront. + Users should be able to set the API key programmatically after import. + """ + # Save current env var state + original_key = os.environ.pop("FASTFUELS_API_KEY", None) + + try: + # These imports should not raise RuntimeError + from fastfuels_sdk import api + from fastfuels_sdk import Domain + from fastfuels_sdk import Inventories + from fastfuels_sdk import Features + from fastfuels_sdk import Grids + from fastfuels_sdk import Export + + # Import should succeed + assert api is not None + assert Domain is not None + assert Inventories is not None + assert Features is not None + assert Grids is not None + assert Export is not None + + finally: + # Restore env var + if original_key: + os.environ["FASTFUELS_API_KEY"] = original_key + + +def test_api_call_without_key_raises_error(): + """Test that API calls without setting a key raise helpful error.""" + # Save current env var state + original_key = os.environ.pop("FASTFUELS_API_KEY", None) + + try: + from fastfuels_sdk import api + + # Clear any cached client and API instances + api._client = None + api._domains_api = None + + # Attempting to use the API without a key should raise RuntimeError + with pytest.raises(RuntimeError) as exc_info: + api.get_domains_api() + + # Check that error message is helpful + error_msg = str(exc_info.value) + assert "API key not configured" in error_msg + assert "set_api_key" in error_msg or "FASTFUELS_API_KEY" in error_msg + + finally: + # Restore env var + if original_key: + os.environ["FASTFUELS_API_KEY"] = original_key + + +def test_set_api_key_updates_client(): + """Test that set_api_key() creates a new client and invalidates caches. + + This addresses Issue #98: Fix set_api_key function not updating the api key. + + When set_api_key() is called, it should: + 1. Create a new client with the new API key + 2. Invalidate all cached API instances + 3. Ensure subsequent API calls use the new key + """ + from fastfuels_sdk import api + + # Set initial API key + api.set_api_key("test-key-1") + + # Get initial client + client1 = api.get_client() + assert client1 is not None + assert client1.default_headers["api-key"] == "test-key-1" + + # Create a domain API instance with first key + domain_api1 = api.get_domains_api() + assert domain_api1 is not None + + # Set new API key + api.set_api_key("test-key-2") + + # Get new client - should be different + client2 = api.get_client() + assert client2 is not None + assert client2 is not client1 # Should be a new instance + assert client2.default_headers["api-key"] == "test-key-2" + + # Verify cached domain API was invalidated + assert api._domains_api is None + + # Get new domain API - should use new client + domain_api2 = api.get_domains_api() + assert domain_api2 is not None + assert domain_api2 is not domain_api1 # Should be a new instance + assert domain_api2.api_client.default_headers["api-key"] == "test-key-2" + + +def test_invalidation_affects_all_modules(): + """Test that set_api_key() invalidates cached APIs in all modules.""" + from fastfuels_sdk import api + + # Set initial API key and trigger lazy loading of all API instances + api.set_api_key("test-key-initial") + + api.get_domains_api() + api.get_inventories_api() + api.get_tree_inventory_api() + api.get_features_api() + api.get_road_feature_api() + api.get_water_feature_api() + api.get_grids_api() + api.get_tree_grid_api() + api.get_surface_grid_api() + api.get_topography_grid_api() + api.get_feature_grid_api() + + # Verify all API instances have been cached + assert api._domains_api is not None + assert api._inventories_api is not None + assert api._tree_inventory_api is not None + assert api._features_api is not None + assert api._road_feature_api is not None + assert api._water_feature_api is not None + assert api._grids_api is not None + assert api._tree_grid_api is not None + assert api._surface_grid_api is not None + assert api._topography_grid_api is not None + assert api._feature_grid_api is not None + + # Set new API key + api.set_api_key("test-key-new") + + # Verify all cached instances were invalidated + assert api._domains_api is None + assert api._inventories_api is None + assert api._tree_inventory_api is None + assert api._features_api is None + assert api._road_feature_api is None + assert api._water_feature_api is None + assert api._grids_api is None + assert api._tree_grid_api is None + assert api._surface_grid_api is None + assert api._topography_grid_api is None + assert api._feature_grid_api is None + + +def test_programmatic_api_key_setting(): + """Test the use case from Issue #112: setting API key programmatically.""" + # This is the use case that was broken before the fix + + # Remove env var to simulate user without env var set + original_key = os.environ.pop("FASTFUELS_API_KEY", None) + + try: + # Import should work without env var + from fastfuels_sdk import api # noqa: F401 + from fastfuels_sdk import Domain # noqa: F401 + + # Set API key programmatically + api.set_api_key("my-programmatic-key") + + # Verify client was created with correct key + client = api.get_client() + assert client is not None + assert client.default_headers["api-key"] == "my-programmatic-key" + + finally: + # Restore env var + if original_key: + os.environ["FASTFUELS_API_KEY"] = original_key + + +def test_env_var_takes_effect_on_first_use(): + """Test that FASTFUELS_API_KEY env var is used on first API call.""" + # Set env var + os.environ["FASTFUELS_API_KEY"] = "env-var-key" + + try: + from fastfuels_sdk import api + + # Clear any cached client + api._client = None + + # First call should use env var + client = api.get_client() + assert client is not None + assert client.default_headers["api-key"] == "env-var-key" + + finally: + # Clean up + os.environ.pop("FASTFUELS_API_KEY", None) + + +def test_set_api_key_overrides_env_var(): + """Test that set_api_key() overrides FASTFUELS_API_KEY env var.""" + # Set env var + os.environ["FASTFUELS_API_KEY"] = "env-var-key" + + try: + from fastfuels_sdk import api + + # Clear any cached client + api._client = None + + # First use env var + client1 = api.get_client() + assert client1.default_headers["api-key"] == "env-var-key" + + # Override with set_api_key + api.set_api_key("override-key") + + # Should now use override key + client2 = api.get_client() + assert client2.default_headers["api-key"] == "override-key" + + finally: + # Clean up + os.environ.pop("FASTFUELS_API_KEY", None)