diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1e0162f..4b3dc6b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9, "3.10"] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 @@ -22,8 +22,6 @@ jobs: - name: Install Dependencies run: | pip install .[test] - # Aspera installation # - . install-aspera.sh - name: 🧪 - Testing run: | pytest -v diff --git a/Containerfile b/Containerfile index f8d8c63..23b9172 100644 --- a/Containerfile +++ b/Containerfile @@ -1,21 +1,22 @@ -FROM python:3.9-slim +FROM mambaorg/micromamba:1.5.8 LABEL maintainer="Microbiome Informatics" -LABEL version="0.9.0" -LABEL description="EBI Fetch Tool Docker Image." +LABEL version="1.0.0" +LABEL description="EBI Fetch Tool." -# We need curl to download aspera and ps for nextflow monitoring -ENV DEBIAN_FRONTEND=noninteractive +COPY --chown=$MAMBA_USER:$MAMBA_USER conda_environment.yml /tmp/env.yaml -RUN apt update && apt install -y curl procps && rm -rf /var/lib/apt/lists/* +RUN micromamba install -y -n base -f /tmp/env.yaml && \ + micromamba clean --all --yes -COPY . . +ARG MAMBA_DOCKERFILE_ACTIVATE=1 -RUN pip install --no-cache-dir . +ENV PATH="$MAMBA_ROOT_PREFIX/bin:$PATH" -# Aspera is an IBM library for data sharing -RUN ./install-aspera.sh +COPY --chown=$MAMBA_USER:$MAMBA_USER . /opt/fetch-tool-src -RUN export PATH=$PATH:/aspera-cli/cli/bin +WORKDIR /opt/fetch-tool-src -CMD [ fetch-read-tool ] +RUN pip install . --no-cache-dir + +ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"] diff --git a/README.md b/README.md index e5c4ef0..f7e1642 100644 --- a/README.md +++ b/README.md @@ -35,19 +35,12 @@ pre-commit will run a set of pre-configured tools before allowing you to commit This repo uses [pytest](https://docs.pytest.org). -It requires the aspera cli installed in the default location (`install-aspera.sh` with no parameters). - -To run the test suite: -```bash -pytest -``` - ## Install fetch tool ### Using Conda ```bash -$ conda create -q -n fetch_tool python=3.8 +$ conda create -q -n fetch_tool python=3.10 $ conda activate fetch_tool ``` @@ -60,33 +53,20 @@ $ pip install fetch-tool Install from the git repo ```bash -$ pip install git+ssh://git@github.com/EBI-Metagenomics/fetch_tool.git +$ pip install https://github.com/EBI-Metagenomics/fetch_tool/archive/master.zip ``` -#### Configuration file +#### Configuration options + +The tool has a number of options, with sensible defaults for the most common use cases. Setup the configuration file, the template [fetchdata-config-template.json](config/fetchdata-config-template.json) for the configuration file. The required fields are: -- For Aspera - - aspera_bin (the path to ascp, usually in the aspera installation under /cli/bin) - - aspera_cert (the path to the ascp provided cert, usually in the aspera installation under /cli/etc/asperaweb_id_dsa.openssh) -- To pull private ENA data +- - ena_api_user - ena_api_password -### Install Aspera - -## Install - -Run the `install-aspera.sh` command here, it has only one optional parameter (the installation folder). - -```bash -./install path/to/installation-i-want -``` - -Otherwise it will install it in $PWD/aspera-cli - ## Fetch read files (amplicon and WGS data) ### Usage @@ -122,7 +102,7 @@ optional arguments: Download amplicon study: ```bash -$ fetch-read-tool -p SRP062869 -c fetchdata-config-local.json -v -d /home//temp/ +$ fetch-read-tool -p SRP062869 -v -d /home//temp/ ``` ## Fetch assembly files @@ -163,5 +143,5 @@ optional arguments: Download assembly study: ```bash -$ fetch-assembly-tool -p ERP111288 -c fetchdata-config-local.json -v -d /home//temp/ +$ fetch-assembly-tool -p ERP111288 -v -d /home//temp/ ``` diff --git a/conda_environment.yml b/conda_environment.yml new file mode 100644 index 0000000..cc2e32b --- /dev/null +++ b/conda_environment.yml @@ -0,0 +1,16 @@ +name: fetchtool +channels: + - bioconda + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip=24.0 + - conda-forge::procps-ng=4.0.4 + - conda-forge::wget=1.21.4 + - conda-forge::rsync=3.3.0 + - conda-forge::pandas=2.2.2 + - pip: + - requests==2.32.3 + - flufl.lock==8.1.0 + - boto3==1.34.134 diff --git a/config/testing.json b/config/testing.json index a609671..0543b4a 100755 --- a/config/testing.json +++ b/config/testing.json @@ -1,7 +1,5 @@ { "url_max_attempts": 5, "ena_api_username": "", - "ena_api_password": "", - "aspera_bin": "", - "aspera_cert": "" + "ena_api_password": "" } diff --git a/fetchtool/__init__.py b/fetchtool/__init__.py index 3e2f46a..5becc17 100644 --- a/fetchtool/__init__.py +++ b/fetchtool/__init__.py @@ -1 +1 @@ -__version__ = "0.9.0" +__version__ = "1.0.0" diff --git a/fetchtool/abstract_fetch.py b/fetchtool/abstract_fetch.py index 089ae6c..f76269a 100644 --- a/fetchtool/abstract_fetch.py +++ b/fetchtool/abstract_fetch.py @@ -28,15 +28,32 @@ from abc import ABC, abstractmethod from importlib.metadata import version +import boto3 import pandas as pd import requests +from botocore import UNSIGNED +from botocore.config import Config from flufl.lock import Lock from pandas.errors import EmptyDataError +from tenacity import ( + RetryError, + before_log, + retry, + retry_if_result, + stop_after_attempt, + wait_exponential, +) from fetchtool.exceptions import ENAFetch204, ENAFetch401, ENAFetchFail PRIVATE_ENA_FTP = "ftp.dcc-private.ebi.ac.uk" PUBLIC_ENA_FTP = "ftp.ebi.ac.uk" +MAX_ATTEMPTS = 3 + + +def is_false(value): + """Tenacity - retry method""" + return value is False class AbstractDataFetcher(ABC): @@ -64,13 +81,16 @@ def __init__(self, argv=sys.argv[1:]): self.create_output_dir(self.args.dir) self.base_dir = self.args.dir + self.config = {} + self._load_default_config_values() + config_file = os.getenv("FETCH_TOOL_CONFIG", None) if not self.args.config_file and not config_file: - raise ValueError("Missing configuration file. It shoud be provided using -c or setting the env variable $FETCH_TOOL_CONFIG") - - with open(self.args.config_file or config_file) as f: - self.config = json.load(f) + logging.debug("No config file was provided, the tool will use the default values for public data") + else: + with open(self.args.config_file or config_file) as f: + self.config = self.config | json.load(f) self.ENA_API_USER = self.config["ena_api_username"] self.ENA_API_PASSWORD = self.config["ena_api_password"] @@ -80,8 +100,7 @@ def __init__(self, argv=sys.argv[1:]): self.force_mode = self.args.force self.desc_file_only = self.args.fix_desc_file self.ignore_errors = self.args.ignore_errors - - self.prod_user = os.environ.get("USER") == "emgpr" + self.ebi = self.args.ebi self._process_additional_args() if self.args.projects or self.args.project_list: @@ -148,9 +167,18 @@ def _parse_args(self, argv): help="Fixed runs in project description file", action="store_true", ) + parser.add_argument("-e", "--ebi", required=False, help="Set this flag when running on EBI infrastructure", action="store_true") parser = self.add_arguments(parser) return parser.parse_args(argv) + def _load_default_config_values(self): + """Load the default values in the config object""" + self.config["ena_api_username"] = "" + self.config["ena_api_password"] = "" + self.config["url_max_attempts"] = 5 + self.config["fire_endpoint"] = "http://hl.fire.sdo.ebi.ac.uk" + self.config["fire_ena_bucket"] = "era-public" + @staticmethod def add_arguments(parser): return parser @@ -226,29 +254,45 @@ def download_raw_files(self, project_accession, new_runs): file_md5s = run["MD5"] for dl_file, dl_name in zip(download_sources, filenames): dest = os.path.join(raw_dir, dl_name) - self.download_raw_file(dl_file, dest, file_md5s, self.private_mode) - - def download_raw_file(self, dl_file, dest, dl_md5s, is_public): + try: + self.download_raw_file(dl_file, dest, file_md5s) + except RetryError: + logging.error(f"Failed to download file {dl_file}.") + if not self.ignore_errors: + raise + + @retry( + retry=retry_if_result(is_false), + stop=stop_after_attempt(MAX_ATTEMPTS), + wait=wait_exponential(multiplier=1, min=2, max=5), + before=before_log(logging, logging.DEBUG), + ) + def download_raw_file(self, dl_file, dest, dl_md5s): """ Returns true if file was re-downloaded """ filename = os.path.basename(dest) file_downloaded = False if not self._is_file_valid(dest, dl_md5s) or self.force_mode: - silentremove(dest) + silent_remove(dest) try: - file_downloaded = self.download_aspera(dest, dl_file) + # Copying data from NFS within EBI infrastructure only works for public data + if not self.private_mode and self.ebi: + logging.info("Downloading using EBI's Fire AWS compatible storage") + file_downloaded = self.download_fire(dest, dl_file) if not file_downloaded: - logging.info("Aspera didn't work.. trying FTP with lftp") + logging.info("Downloading from the FTP server with lftp.") file_downloaded = self.download_lftp(dest, dl_file) + if not self.private_mode and not file_downloaded: + logging.info("Downloading with rsync using EBI's rsync server.") + file_downloaded = self.download_rsync(dest, dl_file) if not file_downloaded: - logging.info("FTP didn't worked... trying wget now...") + logging.info("Downloading with wget.") file_downloaded = self.download_wget(dest, dl_file) except Exception as e: - if self.ignore_errors: - logging.warning(e) - else: - raise e + logging.error(e) + if not self.ignore_errors: + return False else: logging.info("File {} already exists and MD5 matches, skipping download".format(filename)) @@ -398,14 +442,7 @@ def _filter_secondary_files(self, joined_file_names, md5s): return filtered_file_names, filtered_md5s def _get_raw_filenames(self, filepaths, md5s, run_id, is_submitted_file): - """ - Rename file names if submitted files or if generated assemblies - :param filepaths: - :param md5s: - :param run_id: - :param is_submitted_file: - :return: - """ + """Rename file names if submitted files or if generated assemblies""" filepaths, md5s = self._filter_secondary_files(filepaths, md5s) if is_submitted_file or (not is_submitted_file and run_id.startswith("ERZ")): file_names = self._rename_raw_files(filepaths, run_id) @@ -471,54 +508,54 @@ def _is_file_valid(dest, file_md5): logging.info("File {} exists, but MD5 does not match".format(basename)) return False - def download_wget(self, dest, url, auth=True): + def download_wget(self, dest, url): """Download the files on the url using wget.""" if url[:4] == "ftp.": url = "ftp://" + url - attempt = 0 - while True: - try: - logging.info("Downloading file from FTP server..." + url) - download_command = [ - "wget", - "-v", - "--user={}".format(self.ENA_API_USER), - "--password={}".format(self.ENA_API_PASSWORD) if auth else "-q", - "-t", - "5", - "-O", - dest, - url, - ] - retcode = subprocess.call(download_command) - if retcode: - logging.error("Error downloading the file from " + url) - else: - logging.info("Done.") - break - except IOError as err: - logging.error("Error downloading the file from " + url) - logging.error(err) - attempt += 1 - if attempt >= self.config["url_max_attempts"]: - logging.critical("Failed to retrieve" + url + " after " + str(attempt) + " attempts") - if self.interactive_mode: - var = input("Please type C to continue to fetch the next sequence file or anything else to exit: ") - if not var.upper().startswith("C"): - logging.info("Exiting now") - sys.exit(0) - else: - break - else: - if self.force_mode: - logging.warning("Force mode is activated. Will skip the download of this run and move onto the next sequence!") - break - else: - logging.warning( - "Too many failed attempts. Program will exit now. " - + "Try again to fetch the data in interactive mode (-i option)!" - ) - sys.exit(1) + logging.info("Downloading file from FTP server..." + url) + download_command = ["wget", f"--user={self.ENA_API_USER if self.private_mode else 'anonymous' }"] + if self.private_mode: + download_command.append( + f"--password={self.ENA_API_PASSWORD}", + ) + download_command.extend( + [ + "-q", + "-t", + "5", + "-O", + dest, + url, + ] + ) + logging.info(" ".join(download_command)) + result = subprocess.run(download_command, text=True, capture_output=True) + if result.returncode != 0: + logging.error(f"Error downloading the file with wget. Command: {' '.join(download_command)}.") + logging.error(f"Stdout: {result.stdout}") + logging.error(f"Stderr: {result.stderr}") + return False + return True + + def download_rsync(self, dest, url): + """Download from from the EBI rsync endpoint.""" + # replace protocol + if url[:4] == "ftp.": + url = "rsync://" + url + download_command = [ + "rsync", + "-v", + url, + dest, + ] + logging.info(" ".join(download_command)) + result = subprocess.run(download_command, capture_output=True, text=True) + if result.returncode != 0: + logging.error(f"Error rsyncing the file. Command: {' '.join(download_command)}.") + logging.error(f"Stdout: {result.stdout}") + logging.error(f"Stderr: {result.stderr}") + return False + return True def download_lftp(self, dest, url): """Download from ENA FTP server. @@ -532,84 +569,45 @@ def download_lftp(self, dest, url): path_list = url.split("ebi.ac.uk/")[-1].split("/")[:-1] path = "/".join(path_list) file_name = os.path.basename(url) - attempt = 0 - while attempt <= 3: - try: - with ftplib.FTP(server, timeout=300) as ftp: - logging.info("Downloading file from FTP server..." + url) - if self.private_mode: - logging.info("Logging in...") - ftp.login(self.ENA_API_USER, self.ENA_API_PASSWORD) - else: - logging.info("Logging as anonymous") - ftp.login() - ftp.cwd(path) - logging.info("Getting the file...") - # store with the same name - with open(dest, "wb") as output_file: - ftp.retrbinary("RETR " + file_name, output_file.write) - logging.info("File " + dest + " downloaded.") - return True - except ftplib.all_errors as e: - logging.error(e) - attempt += 1 - else: + + try: + with ftplib.FTP(server, timeout=300) as ftp: + logging.info("Downloading file from FTP server..." + url) + if self.private_mode: + logging.info("Logging in...") + ftp.login(self.ENA_API_USER, self.ENA_API_PASSWORD) + else: + logging.info("Logging as anonymous") + ftp.login() + ftp.cwd(path) + logging.info("Getting the file...") + # store with the same name + with open(dest, "wb") as output_file: + ftp.retrbinary("RETR " + file_name, output_file.write) + logging.info("File " + dest + " downloaded.") + return True + except ftplib.all_errors as e: + logging.error(e) return False - def download_aspera(self, dest: str, url: str) -> bool: - """Download using the aspera cli. + def download_fire(self, dest: str, url: str) -> bool: + """Copy the file using the aws cli to access EBI Fire. Only works within EBI Network Usage example, to get file path and names from full FTP URL - url = ftp.sra.ebi.ac.uk/vol1/sequence/ERZ166/ERZ1669403/contig.fa.gz + - dest = destination path """ - ASPERA_SERVER = self.config.get("aspera_server", "fasp.ebi.ac.uk") - ASPERA_BIN = os.environ.get("ASPERA_BIN") or self.config.get("aspera_bin") - # The cert is needed by the aspera cli tool (asperaweb_id_dsa.openssh) - which usually is in /cli/etc/" - ASPERA_CERT = os.environ.get("ASPERA_CERT") or self.config.get("aspera_cert") - if ASPERA_BIN is None or ASPERA_CERT is None: - logging.error("Aspera needs the binary ('aspera_bin') and the cert ('aspera_cert') config values") - return False - - ASPERA_PORT = self.config.get("aspera_port", 33001) - ASPERA_ENA_PUBLIC_USER = self.config.get("aspera_ena_public_user", "era-fasp") - - path = "/".join(url.split("ebi.ac.uk/")[-1].split("/")[:-1]) - file_name = os.path.basename(url) - - aspera_user_host = f"{ASPERA_ENA_PUBLIC_USER}@{ASPERA_SERVER}:{path}/{file_name}" - ascp_command = [ - ASPERA_BIN, - "-l", - "300m", - "-P", - str(ASPERA_PORT), - "-i", - ASPERA_CERT, - ] - if self.private_mode: - # For private ones we need to remove the certificate - os.environ["ASPERA_SCP_PASS"] = self.ENA_API_PASSWORD - aspera_user_host = f"{self.ENA_API_USER}@{ASPERA_SERVER}:{path}/{file_name}" - del ascp_command[-2:] - - ascp_command.extend([aspera_user_host, dest]) - - # only for debug logging level - if logging.DEBUG >= logging.root.level: - # "-L-", # print logging info, useful in case it fails - index = 5 if self.private_mode else 3 - ascp_command.insert(index, "-L-") - + fire_path = url.replace("ftp.sra.ebi.ac.uk/vol1/", "") + fire_endpoint = self.config["fire_endpoint"] + ena_bucket_name = self.config["fire_ena_bucket"] try: - logging.info("Downloading with Aspera") - logging.info(" ".join(ascp_command)) - result = subprocess.run(ascp_command, capture_output=True, text=True, check=True) - logging.info(result.stdout) - logging.debug(result.stderr) - except Exception as error: - logging.exception(error) - logging.error("Failed to download the files with aspera") + s3 = boto3.client("s3", endpoint_url=fire_endpoint, config=Config(signature_version=UNSIGNED)) + object_key = fire_path + s3.download_file(ena_bucket_name, object_key, dest) + logging.info("File downloaded successfully") + except Exception as ex: + logging.exception(ex) + logging.error(f"Download the file with boto3 (aws cli) failed source: {url}, dest: {dest}.") return False - return True @staticmethod @@ -655,7 +653,8 @@ def sanity_check_project_accessions(self): sys.exit(1) -def silentremove(filename): +def silent_remove(filename): + """Remove a file, if the file doesn't exist it will not raise an exception""" try: os.remove(filename) except OSError as e: diff --git a/install-aspera.sh b/install-aspera.sh deleted file mode 100755 index e9e32a3..0000000 --- a/install-aspera.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -set -e - -INSTALL_DIR="${1:-$PWD/aspera-cli}" -VERSION=3.9.6 -FILE=ibm-aspera-cli-3.9.6.1467.159c5b1-linux-64-release - -if [[ "$OSTYPE" == "darwin"* ]]; then - FILE=ibm-aspera-cli-3.9.6.1467.159c5b1-mac-10.11-64-release -fi - -curl https://download.asperasoft.com/download/sw/cli/$VERSION/$FILE.sh -o $FILE.sh - -if [[ "$OSTYPE" == "darwin"* ]]; then - LANG=C sed -i.bk -e "s@INSTALL_DIR=\"\$HOME\/Applications\"@INSTALL_DIR=$INSTALL_DIR@" $FILE.sh -else - sed -i.bk "s@INSTALL_DIR=\~\/.aspera@INSTALL_DIR=$INSTALL_DIR@" $FILE.sh -fi - -chmod +x $FILE.sh - -./$FILE.sh diff --git a/pyproject.toml b/pyproject.toml index 3ab8ca4..e6c853e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,9 +21,11 @@ classifiers = [ ] dependencies = [ - "requests==2.27.1", - "flufl.lock==8.0.2", - "pandas==1.1.5", + "requests==2.32.3", + "flufl.lock==8.1.0", + "pandas==2.2.2", + "boto3==1.34.134", + "tenacity==8.4.2", ] [project.optional-dependencies] @@ -81,7 +83,7 @@ target-version = ["py38"] profile = "black" [tool.bumpversion] -current_version = "0.9.0" +current_version = "1.0.0" commit = true tag = true diff --git a/pytest.ini b/pytest.ini index 15e4ffc..0c0550f 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,3 +2,5 @@ python_files = tests/*.py pythonpath = . addopts = --cov src/ --cov-report xml --cov-report html +log_cli = true +log_cli_level = INFO diff --git a/tests/fixtures/config/complete.json b/tests/fixtures/config/complete.json new file mode 100755 index 0000000..d0fb5ae --- /dev/null +++ b/tests/fixtures/config/complete.json @@ -0,0 +1,7 @@ +{ + "ena_api_username": "ENA_FAKE", + "ena_api_password": "FAKE", + "url_max_attempts": 10, + "fire_endpoint": "fake_endpoint", + "fire_ena_bucket": "fake_bucket" +} diff --git a/tests/fixtures/config/partial.json b/tests/fixtures/config/partial.json new file mode 100755 index 0000000..98f6a8c --- /dev/null +++ b/tests/fixtures/config/partial.json @@ -0,0 +1,5 @@ +{ + "url_max_attempts": 8, + "fire_endpoint": "fake_endpoint", + "fire_ena_bucket": "fake_bucket" +} diff --git a/tests/integration/test_int_fetch_assemblies.py b/tests/integration/test_int_fetch_assemblies.py index 0068f10..73764a3 100644 --- a/tests/integration/test_int_fetch_assemblies.py +++ b/tests/integration/test_int_fetch_assemblies.py @@ -16,17 +16,13 @@ import csv import os -import subprocess -import sys from unittest.mock import patch import pytest from fetchtool import fetch_assemblies -FIXTURES_DIR = os.path.abspath( - os.path.join(os.path.dirname(__file__), os.pardir, "fixtures") -) +FIXTURES_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "fixtures")) class WorkingDir: @@ -109,31 +105,31 @@ def validate_full_study(tmpdir): @pytest.mark.flaky class TestFetchCompleteStudyAssemblies: - @patch("fetchtool.abstract_fetch.subprocess.run") - def test_fetch_all_study_data(self, run_mock, tmpdir): - def raise_ex(*args, **kwargs): - raise Exception + @patch("fetchtool.abstract_fetch.AbstractDataFetcher.download_lftp") + @patch("fetchtool.abstract_fetch.AbstractDataFetcher.download_rsync") + def test_fetch_all_study_data(self, lftp_mock, rsync_mock, tmpdir): + lftp_mock.return_value = False + rsync_mock.return_value = False - run_mock.side_effect = raise_ex with WorkingDir(tmpdir): - fassemblies = fetch_assemblies.FetchAssemblies( - ["-p", study_id, "-v", "-d", str(tmpdir)] - ) + fassemblies = fetch_assemblies.FetchAssemblies(["-p", study_id, "-v", "-d", str(tmpdir)]) fassemblies.fetch() validate_full_study(tmpdir) - run_mock.call_count = 2 # 2 assemblies - tried with aspera + lftp_mock.call_count = 2 + rsync_mock.call_count = 2 @patch("fetchtool.fetch_assemblies.AbstractDataFetcher.download_lftp") @patch("fetchtool.fetch_assemblies.AbstractDataFetcher.download_wget") - def test_fetch_sequential_runs_with_aspera(self, lftp_mock, wget_mock, tmpdir): - def raise_ex(*args, **kwargs): - raise Exception + def test_fetch_sequential_runs_with_rsync(self, wget_mock, lftp_mock, tmpdir): + def return_false(*args, **kwargs): + return False - lftp_mock.side_effect = raise_ex - wget_mock.side_effect = raise_ex + lftp_mock.side_effect = return_false + wget_mock.side_effect = return_false with WorkingDir(tmpdir): - fassemblies = fetch_assemblies.FetchAssemblies( - ["-p", study_id, "-d", str(tmpdir)] - ) + fassemblies = fetch_assemblies.FetchAssemblies(["-p", study_id, "-d", str(tmpdir)]) fassemblies.fetch() validate_full_study(tmpdir) + + assert lftp_mock.called + assert wget_mock.called is False # this was it's only called after rsync fails diff --git a/tests/unit/test_config_loading.py b/tests/unit/test_config_loading.py new file mode 100644 index 0000000..5998720 --- /dev/null +++ b/tests/unit/test_config_loading.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2018-2024 EMBL - European Bioinformatics Institute +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from fetchtool import fetch_reads + +FIXTURES_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "fixtures")) + + +class TestConfigLoading: + def test_config_defaults(self): + fetch = fetch_reads.FetchReads(argv=["-p", "ERP001736"]) + assert fetch.config == { + "ena_api_username": "", + "ena_api_password": "", + "url_max_attempts": 5, + "fire_endpoint": "http://hl.fire.sdo.ebi.ac.uk", + "fire_ena_bucket": "era-public", + } + + def test_config_override_with_json_file(self): + fetch = fetch_reads.FetchReads(argv=["-p", "ERP001736", "-c", f"{FIXTURES_DIR}/config/complete.json"]) + assert fetch.config == { + "ena_api_username": "ENA_FAKE", + "ena_api_password": "FAKE", + "url_max_attempts": 10, + "fire_endpoint": "fake_endpoint", + "fire_ena_bucket": "fake_bucket", + } + + def test_config_override_partial_with_json(self): + fetch = fetch_reads.FetchReads(argv=["-p", "ERP001736", "-c", f"{FIXTURES_DIR}/config/partial.json"]) + assert fetch.config == { + "ena_api_username": "", + "ena_api_password": "", + "url_max_attempts": 8, + "fire_endpoint": "fake_endpoint", + "fire_ena_bucket": "fake_bucket", + } diff --git a/tests/unit/test_fetch_assemblies.py b/tests/unit/test_fetch_assemblies.py index 36b3cf8..a9cacb2 100644 --- a/tests/unit/test_fetch_assemblies.py +++ b/tests/unit/test_fetch_assemblies.py @@ -47,6 +47,7 @@ def test_argparse_should_include_additional_args(self): "assembly_type", "fix_desc_file", "ignore_errors", + "ebi", } assert set(vars(args)) == accepted_args diff --git a/tests/unit/test_fetch_reads.py b/tests/unit/test_fetch_reads.py index d09024a..b3eaa85 100644 --- a/tests/unit/test_fetch_reads.py +++ b/tests/unit/test_fetch_reads.py @@ -23,9 +23,7 @@ from fetchtool import fetch_reads -FIXTURES_DIR = os.path.abspath( - os.path.join(os.path.dirname(__file__), os.pardir, "fixtures") -) +FIXTURES_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "fixtures")) @pytest.mark.flaky @@ -47,6 +45,7 @@ def test_argparse_should_include_additional_args(self): "run_list", "fix_desc_file", "ignore_errors", + "ebi", } assert set(vars(args)) == accepted_args @@ -111,9 +110,7 @@ def test_map_project_info_db_row_should_copy_fields(self): @staticmethod def mock_get_study_from_run(self, *args, **kwargs): - return [ - {"run_accession": "ERR2777789", "secondary_study_accession": "ERP110686"} - ] + return [{"run_accession": "ERR2777789", "secondary_study_accession": "ERP110686"}] """ 1. INVALID = incorrect file format @@ -181,9 +178,7 @@ def mock_get_run_metadata(self, *args, **kwargs): ] @patch("fetchtool.fetch_reads.FetchReads._retrieve_ena_url") - def test_process_additional_args_should_find_study_accessions_for_runs( - self, mocked_class1, tmpdir - ): + def test_process_additional_args_should_find_study_accessions_for_runs(self, mocked_class1, tmpdir): study_accession = "ERP110686" run_id = "ERR2777789" fetch_reads.FetchReads._retrieve_ena_url = self.mock_get_study_from_run @@ -193,9 +188,7 @@ def test_process_additional_args_should_find_study_accessions_for_runs( assert fetch.args.projects == {study_accession} @patch("fetchtool.fetch_reads.FetchReads._retrieve_ena_url") - def test_retrieve_project_should_return_only_valid_reads_and_check_md5( - self, mocked_class1, tmpdir - ): + def test_retrieve_project_should_return_only_valid_reads_and_check_md5(self, mocked_class1, tmpdir): study_accession = "ERP110686" valid_file_for = ("ERR2777790_1.fastq.gz", "39f9956b66880e386d741eea2a0e54c1") valid_file_rev = ("ERR2777790_2.fastq.gz", "9e6db19a2ef56383e8e426784ffff424") @@ -207,19 +200,14 @@ def test_retrieve_project_should_return_only_valid_reads_and_check_md5( ] fetch_reads.FetchReads._retrieve_ena_url = self.mock_get_run_metadata fetch_reads.FetchReads.download_lftp = True - fetch = fetch_reads.FetchReads( - argv=["-p", study_accession, "-d", str(tmpdir), "--private"] - ) + fetch = fetch_reads.FetchReads(argv=["-p", study_accession, "-d", str(tmpdir), "--private"]) runs = fetch._retrieve_project_info_from_api(study_accession) for x in runs: for file in x["file"]: run_path = tmpdir / file Path(str(run_path)).touch() assert len(runs) == 1 - assert ( - os.listdir(str(tmpdir)).sort() - == ["ERR2777790_2.fastq.gz", "ERR2777790_1.fastq.gz"].sort() - ) + assert os.listdir(str(tmpdir)).sort() == ["ERR2777790_2.fastq.gz", "ERR2777790_1.fastq.gz"].sort() for x, y in [valid_file_for, valid_file_rev]: assert not fetch._is_file_valid(str(tmpdir / x), y) project_dir = tmpdir / study_accession