diff --git a/justfile b/justfile index dc26436aa..05457a79c 100644 --- a/justfile +++ b/justfile @@ -131,3 +131,7 @@ generate-dag-docs fail_on_diff="false": exit 1 fi fi + +# Generate files for a new provider +add-provider provider_name endpoint +media_types="image": + python3 openverse_catalog/templates/create_provider_ingester.py "{{ provider_name }}" "{{ endpoint }}" -m {{ media_types }} diff --git a/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py b/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py index 7028e0c4c..222ca3a76 100644 --- a/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py +++ b/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py @@ -68,8 +68,20 @@ class ProviderDataIngester(ABC): @abstractmethod def providers(self) -> dict[str, str]: """ - A dictionary whose keys are the supported `media_types`, and values are - the `provider` string in the `media` table of the DB for that type. + A dictionary mapping each supported media type to its corresponding + `provider` string (the string that will populate the `provider` field + in the Catalog DB). These strings should be defined as constants in + common.loader.provider_details.py + + By convention, when a provider supports multiple media types we set + separate provider strings for each type. 
For example: + + ``` + providers = { + "image": provider_details.MYPROVIDER_IMAGE_PROVIDER, + "audio": provider_details.MYPROVIDER_AUDIO_PROVIDER, + } + ``` """ pass @@ -105,7 +117,7 @@ def __init__(self, conf: dict = None, date: str = None): self.delayed_requester = DelayedRequester( delay=self.delay, headers=self.headers ) - self.media_stores = self.init_media_stores() + self.media_stores = self._init_media_stores() self.date = date # dag_run configuration options @@ -126,7 +138,7 @@ def __init__(self, conf: dict = None, date: str = None): # Create a generator to facilitate fetching the next set of query_params. self.override_query_params = (qp for qp in query_params_list) - def init_media_stores(self) -> dict[str, MediaStore]: + def _init_media_stores(self) -> dict[str, MediaStore]: """ Initialize a media store for each media type supported by this provider. @@ -153,7 +165,7 @@ def ingest_records(self, **kwargs) -> None: logger.info(f"Begin ingestion for {self.__class__.__name__}") while should_continue: - query_params = self.get_query_params(query_params, **kwargs) + query_params = self._get_query_params(query_params, **kwargs) if query_params is None: # Break out of ingestion if no query_params are supplied. This can # happen when the final `override_query_params` is processed. @@ -175,7 +187,7 @@ def ingest_records(self, **kwargs) -> None: # If errors have already been caught during processing, raise them # as well. - if error_summary := self.get_ingestion_errors(): + if error_summary := self._get_ingestion_errors(): raise error_summary from error raise @@ -192,7 +204,7 @@ def ingest_records(self, **kwargs) -> None: # Commit whatever records we were able to process, and rethrow the # exception so the taskrun fails. 
- self.commit_records() + self._commit_records() raise error from ingestion_error if self.limit and record_count >= self.limit: @@ -200,13 +212,13 @@ def ingest_records(self, **kwargs) -> None: should_continue = False # Commit whatever records we were able to process - self.commit_records() + self._commit_records() # If errors were caught during processing, raise them now - if error_summary := self.get_ingestion_errors(): + if error_summary := self._get_ingestion_errors(): raise error_summary - def get_ingestion_errors(self) -> AggregateIngestionError | None: + def _get_ingestion_errors(self) -> AggregateIngestionError | None: """ If any errors were skipped during ingestion, log them as well as the associated query parameters. Then return an AggregateIngestionError. @@ -235,10 +247,13 @@ def get_ingestion_errors(self) -> AggregateIngestionError | None: ) return None - def get_query_params(self, prev_query_params: dict | None, **kwargs) -> dict | None: + def _get_query_params( + self, prev_query_params: dict | None, **kwargs + ) -> dict | None: """ Returns the next set of query_params for the next request, handling - optional overrides via the dag_run conf. + optional overrides via the dag_run conf. This method should not be overridden; + instead override get_next_query_params. """ # If we are getting query_params for the first batch and initial_query_params # have been set, return them. 
@@ -391,7 +406,7 @@ def get_record_data(self, data: dict) -> dict | list[dict] | None: """ pass - def commit_records(self) -> int: + def _commit_records(self) -> int: total = 0 for store in self.media_stores.values(): total += store.commit() diff --git a/openverse_catalog/docs/adding_a_new_provider.md b/openverse_catalog/docs/adding_a_new_provider.md new file mode 100644 index 000000000..59d7f5f86 --- /dev/null +++ b/openverse_catalog/docs/adding_a_new_provider.md @@ -0,0 +1,99 @@ +# Openverse Providers + +## Overview + +The Openverse Catalog collects data from the APIs of sites that share openly-licensed media, and saves them in our Catalog database. This process is automated by [Airflow DAGs](https://airflow.apache.org/docs/apache-airflow/stable/concepts/dags.html) generated for each provider. A simple provider DAG looks like this: + +![Example DAG](assets/provider_dags/simple_dag.png) + +At a high level the steps are: + +1. `generate_filename`: Generates the name of a TSV (tab-separated values) text file that will be used for saving the data to the disk in later steps +2. `pull_data`: Pulls records from the provider API, collects just the data we need, and commits it to local storage in TSVs. +3. `load_data`: Loads the data from TSVs into the Catalog database, updating old records and discarding duplicates. +4. `report_load_completion`: Reports a summary of added and updated records. + +When a provider supports multiple media types (for example, `audio` *and* `images`), the `pull` step consumes data of all types, but separate `load` steps are generated: + +![Example Multi-Media DAG](assets/provider_dags/multi_media_dag.png) + +## Adding a New Provider + +Adding a new provider to Openverse means adding a new provider DAG. Fortunately, our DAG factories automate most of this process. To generate a fully functioning provider DAG, you need to: + +1. Implement a `ProviderDataIngester` +2. 
Add a `ProviderWorkflow` configuration class + +### Implementing a `ProviderDataIngester` class + +We call the code that pulls data from our provider APIs "Provider API scripts". You can find examples in [`provider_api_scripts` folder](../dags/providers/provider_api_scripts). This code will be run during the `pull` steps of the provider DAG. + +At a high level, a provider script should iteratively request batches of records from the provider API, extract data in the format required by Openverse, and commit it to local storage. Much of this logic is implemented in a [`ProviderDataIngester` base class](../dags/providers/provider_api_scripts/provider_data_ingester.py) (which also provides additional testing features **). To add a new provider, extend this class and implement its abstract methods. + +We provide a [script](../templates/create_provider_ingester.py) that can be used to generate the files you'll need and get you started: + +``` +# PROVIDER_NAME: The name of the provider +# ENDPOINT: The API endpoint from which to fetch data +# MEDIA: Optionally, a space-delimited list of media types ingested by this provider +# (and supported by Openverse). If not provided, defaults to "image". 
+ +> just add-provider + +# Example usages: + +# Creates a provider that supports just audio +> just add-provider TestProvider https://test.test/search audio + +# Creates a provider that supports images and audio +> just add-provider "Foobar Museum" https://foobar.museum.org/api/v1 image audio + +# Creates a provider that supports the default, just image +> just add-provider TestProvider https://test.test/search +``` + +You should see output similar to this: +``` +Creating files in /Users/staci/projects/openverse-projects/openverse-catalog +API script: openverse-catalog/openverse_catalog/dags/providers/provider_api_scripts/foobar_museum.py +API script test: openverse-catalog/tests/dags/providers/provider_api_scripts/test_foobar_museum.py + +NOTE: You will also need to add a new ProviderWorkflow dataclass configuration to the PROVIDER_WORKFLOWS list in `openverse-catalog/dags/providers/provider_workflows.py`. +``` + +This generates a provider script with a templated `ProviderDataIngester` for you in the [`provider_api_scripts` folder](../dags/providers/provider_api_scripts), as well as a corresponding test file. Complete the TODOs detailed in the generated files to implement behavior specific to your API. + +Some APIs may not fit perfectly into the established `ProviderDataIngester` pattern. For advanced use cases and examples of how to modify the ingestion flow, see the [`ProviderDataIngester` FAQ](provider_data_ingester_faq.md). + + +### Add a `ProviderWorkflow` configuration class + +Now that you have an ingester class, you're ready to wire up a provider DAG in Airflow to automatically pull data and load it into our Catalog database. This is done by defining a `ProviderWorkflow` configuration dataclass and adding it to the `PROVIDER_WORKFLOWS` list in [`provider_workflows.py`](../dags/providers/provider_workflows.py). Our DAG factories will pick up the configuration and generate a complete new DAG in Airflow! 
+ +At minimum, you'll need to provide the following in your configuration: +* `provider_script`: the name of the file where you defined your `ProviderDataIngester` class +* `ingestion_callable`: the `ProviderDataIngester` class itself +* `media_types`: the media types your provider handles + +Example: +```python +# In openverse_catalog/dags/providers/provider_workflows.py +from providers.provider_api_scripts.foobar_museum import FoobarMuseumDataIngester + +... + +PROVIDER_WORKFLOWS = [ + ... + ProviderWorkflow( + provider_script='foobar_museum', + ingestion_callable=FooBarMuseumDataIngester, + media_types=("image", "audio",) + ) +] +``` + +There are many other options that allow you to tweak the `schedule` (when and how often your DAG is run), timeouts for individual steps of the DAG, and more. These are documented in the definition of the `ProviderWorkflow` dataclass. ** + +After adding your configuration, run `just up` and you should now have a fully functioning provider DAG! ** + +*NOTE*: when your code is merged, the DAG will become available in production but will be disabled by default. A contributor with Airflow access will need to manually turn the DAG on in production. 
diff --git a/openverse_catalog/docs/assets/provider_dags/multi_media_dag.png b/openverse_catalog/docs/assets/provider_dags/multi_media_dag.png new file mode 100644 index 000000000..ac330030f Binary files /dev/null and b/openverse_catalog/docs/assets/provider_dags/multi_media_dag.png differ diff --git a/openverse_catalog/docs/assets/provider_dags/simple_dag.png b/openverse_catalog/docs/assets/provider_dags/simple_dag.png new file mode 100644 index 000000000..a49958313 Binary files /dev/null and b/openverse_catalog/docs/assets/provider_dags/simple_dag.png differ diff --git a/openverse_catalog/docs/data_models.md b/openverse_catalog/docs/data_models.md new file mode 100644 index 000000000..ec48aff13 --- /dev/null +++ b/openverse_catalog/docs/data_models.md @@ -0,0 +1,57 @@ +** + +# Data Models + +The following is temporary, limited documentation of the columns for each of our Catalog data models. + +## Required Fields + +| field name | description | +| --- | --- | +| *foreign_identifier* | Unique identifier for the record on the source site. | +| *foreign_landing_url* | URL of page where the record lives on the source website. | +| *audio_url* / *image_url* | Direct link to the media file. Note that until [issue #784 is addressed](https://github.com/WordPress/openverse-catalog/issues/784) the field name differs depending on media type. | +| *license_info* | [LicenseInfo object](https://github.com/WordPress/openverse-catalog/blob/8423590fd86a0a3272ca91bc11f2f37979048181/openverse_catalog/dags/common/licenses/licenses.py#L25) that has (1) the URL of the license for the record, (2) string representation of the license, (3) version of the license, (4) raw license URL that was by provider, if different from canonical URL. Usually generated by calling [`get_license_info`](https://github.com/WordPress/openverse-catalog/blob/8423590fd86a0a3272ca91bc11f2f37979048181/openverse_catalog/dags/common/licenses/licenses.py#L29) on respective fields returns/available from the API. 
| + +## Optional Fields + +The following fields are optional, but it is highly encouraged to populate as much data as possible: + +| field name | description | +| --- | --- | +| *thumbnail_url* | Direct link to a thumbnail-sized version of the record. | +| *filesize* | Size of the main file in bytes. | +| *filetype* | The filetype of the main file, eg. 'mp3', 'jpg', etc. | +| *creator* | The creator of the image. | +| *creator_url* | The user page, or home page of the creator. | +| *title* | Title of the record. | +| *meta_data* | Dictionary of metadata about the record. Currently, a key we prefer to have is `description`. | +| *raw_tags* | List of tags associated with the record. | +| *watermarked* | Boolean, true if the record has a watermark. | + +#### Image-specific fields + +Image also has the following fields: + +| field_name | description | +| --- | --- | +| *width* | Image width in pixels. | +| *height* | Image height in pixels. | + +#### Audio-specific fields + +Audio has the following fields: + +| field_name | description | +| --- | --- | +| *duration* | Audio duration in milliseconds. | +| *bit_rate* | Audio bit rate as int. | +| *sample_rate* | Audio sample rate as int. | +| *category* | Category such as 'music', 'sound', 'audio_book', or 'podcast'. | +| *genres* | List of genres. | +| *set_foreign_id* | Unique identifier for the audio set on the source site. | +| *audio_set* | The name of the set (album, pack, etc) the audio is part of. | +| *set_position* | Position of the audio in the audio_set. | +| *set_thumbnail* | URL of the audio_set thumbnail. | +| *set_url* | URL of the audio_set. | +| *alt_files* | A dictionary with information about alternative files for the audio (different formats/quality). Dict should have the following keys: *url*, *filesize*, *bit_rate*, *sample_rate*. 
diff --git a/openverse_catalog/docs/provider_data_ingester_faq.md b/openverse_catalog/docs/provider_data_ingester_faq.md new file mode 100644 index 000000000..612cf7bd8 --- /dev/null +++ b/openverse_catalog/docs/provider_data_ingester_faq.md @@ -0,0 +1,142 @@ +# ProviderDataIngester FAQ + +The most straightforward implementation of a `ProviderDataIngester` repeats the following process: + +* Builds a set of query params for the next request, based upon the previous params (for example, by updating offsets or page numbers) +* Makes a single GET request to the configured static `endpoint` using the built params +* Extracts a "batch" of records from the response, as a list of record JSON representations +* Iterates over the records in the batch, extracting desired data, and commits them to local storage + +Some provider APIs may not fit neatly into this workflow. This document addresses some common use cases. + +## How do I process a provider "record" that contains data about multiple records? + +**Example**: You're pulling data from a Museum database, and each "record" in a batch contains multiple photos of a single physical object. + +**Solution**: The `get_record_data` method takes a `data` object representing a single record from the provider API. Typically, it extracts required data and returns it as a single dict. However, it can also return a **list of dictionaries** for cases like the one described, where multiple Openverse records can be extracted. + +```python +def get_record_data(self, data: dict) -> dict | list[dict] | None: + records = [] + + for record in data.get("associatedImages", []): + # Perform data processing for each record and add it to a list + records.append({ + "foreign_landing_url": record.get("foreign_landing_url") + ... + }) + + return records +``` + +## What if I can't get all the necessary information for a record from a single API request? 
**Example**: A provider's `search` endpoint returns a list of image records containing *most* of the information we need for each record, but not image dimensions. This information *is* available via the API by hitting a `details` endpoint for a given image, though. + +**Solution**: In this case, you can reuse the `get_response_json` method by passing in the endpoint you need: + +```python +def get_record_data(self, data: dict) -> dict | list[dict] | None: + ... + + # Get data from the details endpoint + response_json = self.get_response_json( + query_params={"my_params": "foo"}, + endpoint=f"https://foobar.museum.org/api/v1/images/{data.get('uuid')}" + ) + + ... +``` + +**NOTE**: When doing this, keep in mind that adding too many requests may slow down ingestion. Be aware of rate limits from your provider API as well. + +## What if my API endpoint isn't static and needs to change from one request to another? + +**Example**: Rather than passing a `page` number in query parameters, a provider expects the `page` as part of the endpoint path itself. + +**Solution**: If your `endpoint` needs to change, you can implement it as a `property`: + +```python +@property +def endpoint(self) -> str: + # Compute the endpoint using some instance variable + return f"https://foobar.museum.org/images/page/{self.page_number}" +``` + +In this example, `self.page_number` is an instance variable that gets updated after each request. To set up the instance variable you can override `__init__`, **being careful to remember to call `super` and pass through kwargs**, and then update it in `get_next_query_params`: + +```python +def __init__(self, *args, **kwargs): + # REQUIRED! + super().__init__(*args, **kwargs) + + # Set up our instance variable + self.page_number = None + +def get_next_query_params(self, prev_query_params: dict | None, **kwargs) -> dict: + # Remember that `get_next_query_params` is called before every request, even + # the first one. 
+ + if self.page_number is None: + # Set initial value + self.page_number = 0 + else: + # Increment value on subsequent requests + self.page_number += 1 + + # Return your actual query params + return {} +``` + +Now each time `get_batch` is called, the `endpoint` is correctly updated. + +## How do I run ingestion for a set of discrete categories? + +**Example**: My provider has some set of categories that I'd like to iterate over and ingest data for. E.g., a particular audio provider's search endpoint requires you specify whether you're searching for "podcasts", "music", etc. I'd like to iterate over all the available categories and run ingestion for each. + +**Solution**: You can do this by overriding the `ingest_records` method, which accepts optional `kwargs` that it passes through on each call to `get_next_query_params`. This is best demonstrated with code: + +```python +CATEGORIES = ["music", "audio_book", "podcast"] + +def ingest_records(self, **kwargs): + # Iterate over categories and call the main ingestion function, passing in + # our category as a kwarg + for category in CATEGORIES: + super().ingest_records(category=category) + +def get_next_query_params(self, prev_query_params, **kwargs): + # Our category param will be available here + category = kwargs.get("category") + + # Add it to your query params + return { + "category": category, + ... + } +``` + +This will result in the ingestion function running once for each category. + +## What if I need to do more complex processing to get a batch? + +**Example**: A single GET request is insufficient to get a batch from a provider. Instead, several requests need to be made in sequence until a "batchcomplete" token is encountered. + +**Solution**: You can override `get_response_json` in order to implement more complex behavior. 
+ +```python +# Pseudo code serves as an example +def get_response_json( + self, query_params: dict, endpoint: str | None = None, **kwargs +): + batch_json = None + + while True: + partial_response = super().get_response_json(query_params) + batch_json = self.merge_data(batch_json, partial_response) + + if "batchcomplete" in partial_response: + break + + return batch_json +``` diff --git a/openverse_catalog/templates/__init__.py b/openverse_catalog/templates/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openverse_catalog/templates/create_provider_ingester.py b/openverse_catalog/templates/create_provider_ingester.py new file mode 100644 index 000000000..74f8e61a1 --- /dev/null +++ b/openverse_catalog/templates/create_provider_ingester.py @@ -0,0 +1,180 @@ +""" +Script used to generate a templated ProviderDataIngester. +""" + +import argparse +import re +from pathlib import Path + +import inflection + + +TEMPLATES_PATH = Path(__file__).parent +PROJECT_PATH = TEMPLATES_PATH.parent +REPO_PATH = PROJECT_PATH.parent +MEDIA_TYPES = ["audio", "image"] + + +def _render_provider_configuration(provider: str, media_type: str): + """ + Render the provider configuration string for a particular media type. 
""" + return f'"{media_type}": prov.{provider}_{media_type.upper()}_PROVIDER,' + + +def _get_filled_template( + template_path: Path, provider: str, endpoint: str, media_types: list[str] +): + with template_path.open("r", encoding="utf-8") as template: + camel_provider = inflection.camelize(provider) + screaming_snake_provider = inflection.underscore(provider).upper() + + # Build provider configuration + provider_configuration = "\n ".join( + _render_provider_configuration(screaming_snake_provider, media_type) + for media_type in media_types + ) + + template_string = template.read() + script_string = ( + template_string.replace("{provider}", camel_provider) + .replace("{screaming_snake_provider}", screaming_snake_provider) + .replace("{provider_underscore}", inflection.underscore(provider)) + .replace("{provider_data_ingester}", f"{camel_provider}DataIngester") + .replace("{endpoint}", endpoint) + .replace("{provider_configuration}", provider_configuration) + ) + + return script_string + + +def _render_file( + target: Path, + template_path: Path, + provider: str, + endpoint: str, + media_types: list[str], + name: str, +): + with target.open("w", encoding="utf-8") as target_file: + filled_template = _get_filled_template( + template_path, provider, endpoint, media_types + ) + target_file.write(filled_template) + print(f"{name + ':':<18} {target.relative_to(REPO_PATH)}") + + +def fill_template(provider, endpoint, media_types): + print(f"Creating files in {REPO_PATH}") + + dags_path = PROJECT_PATH / "dags" / "providers" + api_path = dags_path / "provider_api_scripts" + filename = inflection.underscore(provider) + + # Render the API file itself + script_template_path = TEMPLATES_PATH / "template_provider.py_template" + api_script_path = api_path / f"{filename}.py" + _render_file( + api_script_path, + script_template_path, + provider, + endpoint, + media_types, + "API script", + ) + + # Render the tests + script_template_path = TEMPLATES_PATH / 
"template_test.py_template" + tests_path = REPO_PATH / "tests" + # Mirror the directory structure, but under the "tests" top level directory + test_script_path = tests_path.joinpath(*api_path.parts[-3:]) / f"test_{filename}.py" + + _render_file( + test_script_path, + script_template_path, + provider, + endpoint, + media_types, + "API script test", + ) + + print( + """ +NOTE: You will also need to add a new ProviderWorkflow dataclass configuration to the \
PROVIDER_WORKFLOWS list in `openverse-catalog/dags/providers/provider_workflows.py`. +""" + ) + + +def sanitize_provider(provider: str) -> str: + """ + Takes a provider string from user input and sanitizes it by: + - removing leading and trailing whitespace + - replacing spaces and periods with underscores + - removing all characters other than alphanumeric characters, dashes, + and underscores. + + Eg: sanitize_provider("hello world.foo*/bar2&") -> "hello_world_foobar2" + """ + provider = provider.strip().replace(" ", "_").replace(".", "_") + + # Remove unsupported characters + return re.sub("[^0-9a-zA-Z-_]+", "", provider) + + +def parse_media_types(media_types: list[str]) -> list[str]: + """ + Parses valid media types out from user input. Defaults to ["image",] + """ + valid_media_types = [] + + if media_types is None: + media_types = [] + + for media_type in media_types: + if media_type in MEDIA_TYPES: + valid_media_types.append(media_type) + else: + print(f"Ignoring invalid type {media_type}") + + # Default to image if no valid types given + if not valid_media_types: + print('No media type given, defaulting to ["image",]') + return [ + "image", + ] + + return valid_media_types + + +def main(): + parser = argparse.ArgumentParser( + description="Create a new provider API ProviderDataIngester", + add_help=True, + ) + parser.add_argument( + "provider", help='Create the ingester for this provider (eg. "Wikimedia").' + ) + parser.add_argument( + "endpoint", + help="API endpoint to fetch data from" + ' (eg. 
"https://commons.wikimedia.org/w/api.php").', + ) + parser.add_argument( + "-m", + "--media", + type=str, + nargs="*", + help="Ingester will collect media of these types" + " ('audio'/'image'). Default value is ['image',]", + ) + args = parser.parse_args() + provider = sanitize_provider(args.provider) + endpoint = args.endpoint + media_types = parse_media_types(args.media) + + fill_template(provider, endpoint, media_types) + + +if __name__ == "__main__": + main() diff --git a/openverse_catalog/templates/template_provider.py_template b/openverse_catalog/templates/template_provider.py_template new file mode 100644 index 000000000..1adab17ec --- /dev/null +++ b/openverse_catalog/templates/template_provider.py_template @@ -0,0 +1,164 @@ +""" +TODO: This doc string will be used to generate documentation for the DAG in +DAGs.md. Update it to include any relevant information that you'd like to +be documented. + +Content Provider: {provider} + +ETL Process: Use the API to identify all CC licensed media. + +Output: TSV file containing the media and the + respective meta-data. + +Notes: {endpoint} +""" +import logging + +from airflow.models import Variable +from common import constants +from common.licenses import get_license_info +from common.loader import provider_details as prov +from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester + + +logger = logging.getLogger(__name__) + + +class {provider}DataIngester(ProviderDataIngester): + """ + This is a template for a ProviderDataIngester. + + Methods are shown with example implementations. Adjust them to suit your API. + """ + + # TODO: Add the provider constants to `common.loader.provider_details.py` + providers = { + {provider_configuration} + } + endpoint = "{endpoint}" + # TODO The following are set to their default values. Remove them if the defaults + # are acceptable, or override them. 
+ delay = 1 + retries = 3 + batch_limit = 100 + headers = {} + + def get_next_query_params(self, prev_query_params: dict | None, **kwargs) -> dict: + # On the first request, `prev_query_params` will be `None`. We can detect this + # and return our default params. + if not prev_query_params: + # TODO: Return your default params here. `batch_limit` is not automatically added to + # the params, so make sure to add it here if you need it! + # TODO: If you need an API key, add the (empty) key to `openverse_catalog/env.template` + # Do not hardcode API keys! + return { + "limit": self.batch_limit, + "cc": 1, + "offset": 0, + "api_key": Variable.get("API_KEY_{screaming_snake_provider}") + } + else: + # TODO: Update any query params that change on subsequent requests. + # Example case shows the offset being incremented by batch limit. + return { + **prev_query_params, + "offset": prev_query_params["offset"] + self.batch_limit, + } + + def get_batch_data(self, response_json): + # Takes the raw API response from calling `get` on the endpoint, and returns + # the list of records to process. + # TODO: Update based on your API. + if response_json: + return response_json.get("results") + return None + + def get_media_type(self, record: dict): + # For a given record json, return the media type it represents. + # TODO: Update based on your API. TIP: May be hard-coded if the provider only + # returns records of one type, eg `return constants.IMAGE` + return record['media_type'] + + def get_record_data(self, data: dict) -> dict | list[dict] | None: + # Parse out the necessary info from the record data into a dictionary. + # TODO: Update based on your API. + # TODO: Important! 
Refer to the most up-to-date documentation about the + # available fields in `openverse_catalog/docs/data_models.md` + + # REQUIRED FIELDS: + # - foreign_identifier + # - foreign_landing_url + # - license_info + # - image_url / audio_url + # + # If a required field is missing, return early to prevent unnecesary + # processing. + if (foreign_identifier := data.get("foreign_id")) is None: + return None + + if (foreign_landing_url := data.get("url")) is None: + return None + + # TODO: Note the url field name differs depending on field type. Append + # `image_url` or `audio_url` depending on the type of record being processed. + if (image_url := data.get("image_url")) is None: + return None + + # Use the `get_license_info` utility to get license information from a URL. + license_url = data.get("license") + license_info = get_license_info(license_url) + if license_info is None: + return None + + # OPTIONAL FIELDS + # Obtain as many optional fields as possible. + thumbnail_url = data.get("thumbnail") + filesize = data.get("filesize") + filetype = data.get("filetype") + creator = data.get("creator") + creator_url = data.get("creator_url") + title = data.get("title") + meta_data = data.get("meta_data") + raw_tags = data.get("tags") + watermarked = data.get("watermarked") + + # MEDIA TYPE-SPECIFIC FIELDS + # Each Media type may also have its own optional fields. See documentation. + # TODO: Populate media type-specific fields. + # If your provider supports more than one media type, you'll need to first + # determine the media type of the record being processed. + # + # Example: + # media_type = self.get_media_type(data) + # media_type_specific_fields = self.get_media_specific_fields(media_type, data) + # + # If only one media type is supported, simply extract the fields here. 
+ + return { + "foreign_landing_url": foreign_landing_url, + "image_url": image_url, + "license_info": license_info, + # Optional fields + "foreign_identifier": foreign_identifier, + "thumbnail_url": thumbnail_url, + "filesize": filesize, + "filetype": filetype, + "creator": creator, + "creator_url": creator_url, + "title": title, + "meta_data": meta_data, + "raw_tags": raw_tags, + "watermarked": watermarked, + # TODO: Remember to add any media-type specific fields here + } + + +def main(): + # Allows running ingestion from the CLI without Airflow running for debugging + # purposes. + ingester = {provider}DataIngester() + ingester.ingest_records() + + +if __name__ == "__main__": + main() diff --git a/openverse_catalog/templates/template_test.py_template b/openverse_catalog/templates/template_test.py_template new file mode 100644 index 000000000..3c4073c40 --- /dev/null +++ b/openverse_catalog/templates/template_test.py_template @@ -0,0 +1,67 @@ +""" +TODO: Add additional tests for any methods you added in your subclass. +Try to test edge cases (missing keys, different data types returned, Nones, etc). +You may also need to update the given test names to be more specific. 
+ +Run your tests locally with `just test -k {provider_underscore}` +""" + +import json +from pathlib import Path + +import pytest +from providers.provider_api_scripts.{provider_underscore} import {provider_data_ingester} + +# TODO: API responses used for testing can be added to this directory +RESOURCES = Path(__file__).parent / "resources/{provider_underscore}" + +# Set up test class +ingester = {provider_data_ingester}() + + +def test_get_next_query_params_default_response(): + actual_result = ingester.get_next_query_params(None) + expected_result = { + # TODO: Fill out expected default query params + } + assert actual_result == expected_result + + +def test_get_next_query_params_updates_parameters(): + previous_query_params = { + # TODO: Fill out a realistic set of previous query params + } + actual_result = ingester.get_next_query_params(previous_query_params) + + expected_result = { + # TODO: Fill out what the next set of query params should be, + # incrementing offsets or page numbers if necessary + } + assert actual_result == expected_result + + +def test_get_media_type(): + # TODO: Test the correct media type is returned for each possible media type. + pass + + +def test_get_record_data(): + # High level test for `get_record_data`. One way to test this is to create a + # `tests/resources/{provider}/single_item.json` file containing a sample json + # representation of a record from the API under test, call `get_record_data` with + # the json, and directly compare to expected output. + # + # Make sure to add additional tests for records of each media type supported by + # your provider. 
+ + # Sample code for loading in the sample json + with open(RESOURCES / "single_item.json") as f: + resource_json = json.load(f) + + actual_data = ingester.get_record_data(resource_json) + + expected_data = { + # TODO: Fill out the expected data which will be saved to the Catalog + } + + assert actual_data == expected_data diff --git a/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py b/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py index 7da644683..e6a6d0fa3 100644 --- a/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py +++ b/tests/dags/providers/provider_api_scripts/test_provider_data_ingester.py @@ -133,7 +133,7 @@ def test_ingest_records(): with ( patch.object(ingester, "get_batch") as get_batch_mock, patch.object(ingester, "process_batch", return_value=3) as process_batch_mock, - patch.object(ingester, "commit_records") as commit_mock, + patch.object(ingester, "_commit_records") as commit_mock, ): get_batch_mock.side_effect = [ (EXPECTED_BATCH_DATA, True), # First batch @@ -231,7 +231,7 @@ def test_ingest_records_commits_on_exception(): with ( patch.object(ingester, "get_batch") as get_batch_mock, patch.object(ingester, "process_batch", return_value=3) as process_batch_mock, - patch.object(ingester, "commit_records") as commit_mock, + patch.object(ingester, "_commit_records") as commit_mock, ): get_batch_mock.side_effect = [ (EXPECTED_BATCH_DATA, True), # First batch @@ -384,7 +384,7 @@ def test_commit_commits_all_stores(): patch.object(audio_store, "commit") as audio_store_mock, patch.object(image_store, "commit") as image_store_mock, ): - ingester.commit_records() + ingester._commit_records() assert audio_store_mock.called assert image_store_mock.called diff --git a/tests/templates/test_create_provider_ingester.py b/tests/templates/test_create_provider_ingester.py new file mode 100644 index 000000000..1d0c03635 --- /dev/null +++ b/tests/templates/test_create_provider_ingester.py @@ -0,0 
+1,74 @@ +from pathlib import Path + +import pytest + +from openverse_catalog.templates import create_provider_ingester + + +@pytest.mark.parametrize( + "media_types_str, expected_types", + [ + # Just image + (["image"], ["image"]), + # Just audio + (["audio"], ["audio"]), + # Multiple valid types + (["image", "audio"], ["image", "audio"]), + # Discard only invalid types + (["image", "blorfl"], ["image"]), + (["blorfl", "audio", "image"], ["audio", "image"]), + # Defaults to image when all given types are invalid + (["blorfl", "wat"], ["image"]), + # Defaults to image when no types are given at all + ([""], ["image"]), + (None, ["image"]), + ], +) +def test_parse_media_types(media_types_str, expected_types): + actual_result = create_provider_ingester.parse_media_types(media_types_str) + assert actual_result == expected_types + + +@pytest.mark.parametrize( + "provider, expected_result", + [ + ("FoobarIndustries", "FoobarIndustries"), + # Do not remove hyphens or underscores + ("hello-world_foo", "hello-world_foo"), + # Replace spaces + ("Foobar Industries", "Foobar_Industries"), + # Replace periods + ("foobar.com", "foobar_com"), + # Remove trailing whitespace + (" hello world ", "hello_world"), + # Replace special characters + ("hello.world-foo*/bar2", "hello_world-foobar2"), + ], +) +def test_sanitize_provider(provider, expected_result): + actual_result = create_provider_ingester.sanitize_provider(provider) + assert actual_result == expected_result + + +def test_files_created(): + provider = "foobar_industries" + endpoint = "https://myfakeapi/v1" + media_type = "image" + + dags_path = create_provider_ingester.TEMPLATES_PATH.parent / "dags" / "providers" + expected_provider = dags_path / "provider_api_scripts" / "foobar_industries.py" + expected_test = ( + Path(__file__).parents[1] + / "dags" + / "providers" + / "provider_api_scripts" + / "test_foobar_industries.py" + ) + try: + create_provider_ingester.fill_template(provider, endpoint, media_type) + assert 
expected_provider.exists() + assert expected_test.exists() + finally: + # Clean up + expected_provider.unlink(missing_ok=True) + expected_test.unlink(missing_ok=True)