diff --git a/configExample.json b/configExample.json index 789ab4b4..11af66a4 100644 --- a/configExample.json +++ b/configExample.json @@ -8,21 +8,24 @@ "project": "production", "type": "core" }, - "regex": { + + "regex": { "mail_recipient": "username@suffix.com", "_comment": "File finding patterns. Only single capture group accepted (for reverse/forward identifier)", "file_pattern": "\\w{8,12}_\\w{8,10}(?:-\\d+)*_L\\d_(?:R)*(\\d{1}).fastq.gz", "_comment": "Organisms recognized enough to be considered stable", "verified_organisms": [] }, + "_comment": "Folders", - "folders": { + "folders": { "_comment": "Root folder for ALL output", "results": "/tmp/MLST/results/", "_comment": "Report collection folder", "reports": "/tmp/MLST/reports/", "_comment": "Log file position and name", "log_file": "/tmp/microsalt.log", + "_comment": "Root folder for input fasta sequencing data", "seqdata": "/tmp/projects/", "_comment": "ST profiles. Each ST profile file under 'profiles' have an identicial folder under references", @@ -32,18 +35,18 @@ "_comment": "Resistances. Commonly from resFinder", "resistances": "/tmp/MLST/references/resistances", "_comment": "Download path for NCBI genomes, for alignment usage", - "genomes": "/tmp/MLST/references/genomes", - "_comment": "PubMLST credentials", - "pubmlst_credentials": "/tmp/MLST/credentials" + "genomes": "/tmp/MLST/references/genomes" }, + "_comment": "Database/Flask configuration", "database": { "SQLALCHEMY_DATABASE_URI": "sqlite:////tmp/microsalt.db", "SQLALCHEMY_TRACK_MODIFICATIONS": "False", "DEBUG": "True" }, + "_comment": "Thresholds for Displayed results", - "threshold": { + "threshold": { "_comment": "Typing thresholds", "mlst_id": 100, "mlst_novel_id": 99.5, @@ -69,15 +72,11 @@ "bp_50x_warn": 50, "bp_100x_warn": 20 }, + "_comment": "Genologics temporary configuration file", "genologics": { "baseuri": "https://lims.facility.se/", "username": "limsuser", "password": "mypassword" - }, - "_comment": "PubMLST credentials", - "pubmlst": { - "client_id": "", - "client_secret": "" } -} \ No newline at end of file +} diff --git a/microSALT/utils/pubmlst/__init__.py b/microSALT/utils/pubmlst/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/microSALT/utils/pubmlst/authentication.py b/microSALT/utils/pubmlst/authentication.py deleted file mode 100644 index 87a2e0a1..00000000 --- a/microSALT/utils/pubmlst/authentication.py +++ /dev/null @@ -1,106 +0,0 @@ -import json -import os -from datetime import datetime, timedelta -from dateutil import parser -from rauth import OAuth1Session -from microSALT import logger -from microSALT.utils.pubmlst.helpers import BASE_API, save_session_token, load_auth_credentials, get_path, folders_config, credentials_path_key, pubmlst_session_credentials_file_name -from microSALT.utils.pubmlst.exceptions import ( - PUBMLSTError, - SessionTokenRequestError, - SessionTokenResponseError, -) - -session_token_validity = 12 # 12-hour validity -session_expiration_buffer = 60 # 60-second buffer - -def get_new_session_token(db: str): - """Request a new session token using all credentials for a specific database.""" - logger.debug("Fetching a new session token for database '{db}'...") - - try: - consumer_key, consumer_secret, access_token, access_secret = load_auth_credentials() - - url = f"{BASE_API}/db/{db}/oauth/get_session_token" - - session = OAuth1Session( - consumer_key=consumer_key, - consumer_secret=consumer_secret, - access_token=access_token, - access_token_secret=access_secret, - ) - - response = 
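For orientation, the session-token exchange that the deleted authentication.py performed reduces to a single OAuth1-signed GET against the database's oauth/get_session_token endpoint. A minimal sketch with rauth, reusing the names and the User-Agent string from the removed module:

from rauth import OAuth1Session

BASE_API = "https://rest.pubmlst.org"

def fetch_session_token(db, consumer_key, consumer_secret, access_token, access_secret):
    """Trade a long-lived access token for a 12-hour session token."""
    session = OAuth1Session(
        consumer_key=consumer_key,
        consumer_secret=consumer_secret,
        access_token=access_token,
        access_token_secret=access_secret,
    )
    url = f"{BASE_API}/db/{db}/oauth/get_session_token"
    response = session.get(url, headers={"User-Agent": "BIGSdb downloader"})
    response.raise_for_status()
    data = response.json()
    return data["oauth_token"], data["oauth_token_secret"]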
session.get(url, headers={"User-Agent": "BIGSdb downloader"}) - logger.debug("Response Status Code: {status_code}") - - if response.ok: - try: - token_data = response.json() - session_token = token_data.get("oauth_token") - session_secret = token_data.get("oauth_token_secret") - - if not session_token or not session_secret: - raise SessionTokenResponseError( - db, "Missing 'oauth_token' or 'oauth_token_secret' in response." - ) - - expiration_time = datetime.now() + timedelta(hours=session_token_validity) - - save_session_token(db, session_token, session_secret, expiration_time) - return session_token, session_secret - - except (ValueError, KeyError) as e: - raise SessionTokenResponseError(db, f"Invalid response format: {str(e)}") - else: - raise SessionTokenRequestError( - db, response.status_code, response.text - ) - - except PUBMLSTError as e: - logger.error(f"Error during token fetching: {e}") - raise - except Exception as e: - logger.error(f"Unexpected error: {e}") - raise PUBMLSTError(f"Unexpected error while fetching session token for database '{db}': {e}") - -def load_session_credentials(db: str): - """Load session token from file for a specific database.""" - try: - credentials_file = os.path.join( - get_path(folders_config, credentials_path_key), - pubmlst_session_credentials_file_name - ) - - if not os.path.exists(credentials_file): - logger.debug("Session file does not exist. Fetching a new session token.") - return get_new_session_token(db) - - with open(credentials_file, "r") as f: - try: - all_sessions = json.load(f) - except json.JSONDecodeError as e: - raise SessionTokenResponseError(db, f"Failed to parse session file: {str(e)}") - - db_session_data = all_sessions.get("databases", {}).get(db) - if not db_session_data: - logger.debug(f"No session token found for database '{db}'. Fetching a new session token.") - return get_new_session_token(db) - - expiration = parser.parse(db_session_data.get("expiration", "")) - if datetime.now() < expiration - timedelta(seconds=session_expiration_buffer): - logger.debug(f"Using existing session token for database '{db}'.") - session_token = db_session_data.get("token") - session_secret = db_session_data.get("secret") - - return session_token, session_secret - - logger.debug(f"Session token for database '{db}' has expired. 
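The token cache read here is a single JSON file holding a "databases" mapping, each entry carrying a token, a secret, and an ISO-8601 expiration; the 60-second buffer keeps a token from being handed out moments before it lapses. A minimal sketch of the reuse check, assuming the file layout written by save_session_token in helpers.py:

import json
from datetime import datetime, timedelta
from dateutil import parser

session_expiration_buffer = 60  # seconds, as in the removed module

def usable_session(path, db):
    # Returns (token, secret) if a cached session is still valid, else None.
    with open(path) as handle:
        sessions = json.load(handle)
    entry = sessions.get("databases", {}).get(db)
    if not entry:
        return None
    expiration = parser.parse(entry["expiration"])
    if datetime.now() < expiration - timedelta(seconds=session_expiration_buffer):
        return entry["token"], entry["secret"]
    return None  # expired or about to expire; caller fetches a new token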
Fetching a new session token.") - return get_new_session_token(db) - - except PUBMLSTError as e: - logger.error(f"PUBMLST-specific error occurred: {e}") - raise - except Exception as e: - logger.error(f"Unexpected error: {e}") - raise PUBMLSTError(f"Unexpected error while loading session token for database '{db}': {e}") - diff --git a/microSALT/utils/pubmlst/client.py b/microSALT/utils/pubmlst/client.py deleted file mode 100644 index f6ce9c16..00000000 --- a/microSALT/utils/pubmlst/client.py +++ /dev/null @@ -1,116 +0,0 @@ -import requests -from urllib.parse import urlencode -from microSALT.utils.pubmlst.helpers import ( - BASE_API, - generate_oauth_header, - load_auth_credentials, - parse_pubmlst_url -) -from microSALT.utils.pubmlst.constants import RequestType, HTTPMethod, ResponseHandler -from microSALT.utils.pubmlst.exceptions import PUBMLSTError, SessionTokenRequestError -from microSALT.utils.pubmlst.authentication import load_session_credentials -from microSALT import logger - -class PubMLSTClient: - """Client for interacting with the PubMLST authenticated API.""" - - def __init__(self): - """Initialize the PubMLST client.""" - try: - self.consumer_key, self.consumer_secret, self.access_token, self.access_secret = load_auth_credentials() - self.database = "pubmlst_test_seqdef" - self.session_token, self.session_secret = load_session_credentials(self.database) - except PUBMLSTError as e: - logger.error(f"Failed to initialize PubMLST client: {e}") - raise - - - @staticmethod - def parse_pubmlst_url(url: str): - """ - Wrapper for the parse_pubmlst_url function. - """ - return parse_pubmlst_url(url) - - - def _make_request(self, request_type: RequestType, method: HTTPMethod, url: str, db: str = None, response_handler: ResponseHandler = ResponseHandler.JSON): - """ Handle API requests.""" - try: - if db: - session_token, session_secret = load_session_credentials(db) - else: - session_token, session_secret = self.session_token, self.session_secret - - if request_type == RequestType.AUTH: - headers = { - "Authorization": generate_oauth_header(url, self.consumer_key, self.consumer_secret, self.access_token, self.access_secret) - } - elif request_type == RequestType.DB: - headers = { - "Authorization": generate_oauth_header(url, self.consumer_key, self.consumer_secret, session_token, session_secret) - } - else: - raise ValueError(f"Unsupported request type: {request_type}") - - if method == HTTPMethod.GET: - response = requests.get(url, headers=headers) - elif method == HTTPMethod.POST: - response = requests.post(url, headers=headers) - elif method == HTTPMethod.PUT: - response = requests.put(url, headers=headers) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - response.raise_for_status() - - if response_handler == ResponseHandler.CONTENT: - return response.content - elif response_handler == ResponseHandler.TEXT: - return response.text - elif response_handler == ResponseHandler.JSON: - return response.json() - else: - raise ValueError(f"Unsupported response handler: {response_handler}") - - except requests.exceptions.HTTPError as e: - raise SessionTokenRequestError(db or self.database, e.response.status_code, e.response.text) from e - except requests.exceptions.RequestException as e: - logger.error(f"Request failed: {e}") - raise PUBMLSTError(f"Request failed: {e}") from e - except Exception as e: - logger.error(f"Unexpected error during request: {e}") - raise PUBMLSTError(f"An unexpected error occurred: {e}") from e - - - def query_databases(self): - """Query available 
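The deleted client was the single entry point for authenticated PubMLST traffic: every public method funnels through _make_request, which picks AUTH- or DB-scoped credentials and re-signs each request. A usage sketch, valid only against the pre-revert tree; the database name is borrowed from the example URI in referencer.py below:

from microSALT.utils.pubmlst.client import PubMLSTClient

client = PubMLSTClient()  # loads auth credentials and a session token on init
databases = client.query_databases()  # JSON overview of available databases
scheme = client.retrieve_scheme_info("pubmlst_neisseria_seqdef", 1)
profiles_csv = client.download_profiles_csv("pubmlst_neisseria_seqdef", 1)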
PubMLST databases."""
-        url = f"{BASE_API}/db"
-        return self._make_request(RequestType.DB, HTTPMethod.GET, url, response_handler=ResponseHandler.JSON)
-
-
-    def download_locus(self, db: str, locus: str, **kwargs):
-        """Download locus sequence files."""
-        base_url = f"{BASE_API}/db/{db}/loci/{locus}/alleles_fasta"
-        query_string = urlencode(kwargs)
-        url = f"{base_url}?{query_string}" if query_string else base_url
-        return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT)
-
-
-    def download_profiles_csv(self, db: str, scheme_id: int):
-        """Download MLST profiles in CSV format."""
-        if not scheme_id:
-            raise ValueError("Scheme ID is required to download profiles CSV.")
-        url = f"{BASE_API}/db/{db}/schemes/{scheme_id}/profiles_csv"
-        return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT)
-
-
-    def retrieve_scheme_info(self, db: str, scheme_id: int):
-        """Retrieve information about a specific MLST scheme."""
-        url = f"{BASE_API}/db/{db}/schemes/{scheme_id}"
-        return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON)
-
-
-    def list_schemes(self, db: str):
-        """List available MLST schemes for a specific database."""
-        url = f"{BASE_API}/db/{db}/schemes"
-        return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON)
diff --git a/microSALT/utils/pubmlst/constants.py b/microSALT/utils/pubmlst/constants.py
deleted file mode 100644
index b77741ca..00000000
--- a/microSALT/utils/pubmlst/constants.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from enum import Enum
-from werkzeug.routing import Map, Rule
-
-class RequestType(Enum):
-    AUTH = "auth"
-    DB = "db"
-
-class CredentialsFile(Enum):
-    MAIN = "main"
-    SESSION = "session"
-
-class Encoding(Enum):
-    UTF8 = "utf-8"
-
-class HTTPMethod(Enum):
-    GET = "GET"
-    POST = "POST"
-    PUT = "PUT"
-    DELETE = "DELETE"
-    PATCH = "PATCH"
-    HEAD = "HEAD"
-    OPTIONS = "OPTIONS"
-
-class ResponseHandler(Enum):
-    CONTENT = "content"
-    TEXT = "text"
-    JSON = "json"
-
-url_map = Map([
-    Rule('/', endpoint='root'),
-    Rule('/db', endpoint='db_root'),
-    Rule('/db/<db>', endpoint='database_root'),
-    Rule('/db/<db>/classification_schemes', endpoint='classification_schemes'),
-    Rule('/db/<db>/classification_schemes/<int:classification_scheme_id>', endpoint='classification_scheme'),
-    Rule('/db/<db>/classification_schemes/<int:classification_scheme_id>/groups', endpoint='classification_scheme_groups'),
-    Rule('/db/<db>/classification_schemes/<int:classification_scheme_id>/groups/<int:group_id>', endpoint='classification_scheme_group'),
-    Rule('/db/<db>/loci', endpoint='loci'),
-    Rule('/db/<db>/loci/<locus>', endpoint='locus'),
-    Rule('/db/<db>/loci/<locus>/alleles', endpoint='locus_alleles'),
-    Rule('/db/<db>/loci/<locus>/alleles_fasta', endpoint='locus_alleles_fasta'),
-    Rule('/db/<db>/loci/<locus>/alleles/<int:allele_id>', endpoint='locus_allele'),
-    Rule('/db/<db>/loci/<locus>/sequence', endpoint='locus_sequence_post'),
-    Rule('/db/<db>/sequence', endpoint='sequence_post'),
-    Rule('/db/<db>/sequences', endpoint='sequences'),
-    Rule('/db/<db>/schemes', endpoint='schemes'),
-    Rule('/db/<db>/schemes/<int:scheme_id>', endpoint='scheme'),
-    Rule('/db/<db>/schemes/<int:scheme_id>/loci', endpoint='scheme_loci'),
-    Rule('/db/<db>/schemes/<int:scheme_id>/fields/<field>', endpoint='scheme_field'),
-    Rule('/db/<db>/schemes/<int:scheme_id>/profiles', endpoint='scheme_profiles'),
-    Rule('/db/<db>/schemes/<int:scheme_id>/profiles_csv', endpoint='scheme_profiles_csv'),
-    Rule('/db/<db>/schemes/<int:scheme_id>/profiles/<int:profile_id>', endpoint='scheme_profile'),
-    Rule('/db/<db>/schemes/<int:scheme_id>/sequence', endpoint='scheme_sequence_post'),
-    Rule('/db/<db>/schemes/<int:scheme_id>/designations', endpoint='scheme_designations_post'),
-    Rule('/db/<db>/isolates', endpoint='isolates'),
-    Rule('/db/<db>/genomes', endpoint='genomes'),
-    Rule('/db/<db>/isolates/search', endpoint='isolates_search_post'),
-    Rule('/db/<db>/isolates/<int:isolate_id>', endpoint='isolate'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/allele_designations', endpoint='isolate_allele_designations'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/allele_designations/<locus>', endpoint='isolate_allele_designation_locus'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/allele_ids', endpoint='isolate_allele_ids'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/schemes/<int:scheme_id>/allele_designations', endpoint='isolate_scheme_allele_designations'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/schemes/<int:scheme_id>/allele_ids', endpoint='isolate_scheme_allele_ids'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/contigs', endpoint='isolate_contigs'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/contigs_fasta', endpoint='isolate_contigs_fasta'),
-    Rule('/db/<db>/isolates/<int:isolate_id>/history', endpoint='isolate_history'),
-    Rule('/db/<db>/contigs/<int:contig_id>', endpoint='contig'),
-    Rule('/db/<db>/fields', endpoint='fields'),
-    Rule('/db/<db>/fields/<field>', endpoint='field'),
-    Rule('/db/<db>/users/<int:user_id>', endpoint='user'),
-    Rule('/db/<db>/curators', endpoint='curators'),
-    Rule('/db/<db>/projects', endpoint='projects'),
-    Rule('/db/<db>/projects/<int:project_id>', endpoint='project'),
-    Rule('/db/<db>/projects/<int:project_id>/isolates', endpoint='project_isolates'),
-    Rule('/db/<db>/submissions', endpoint='submissions'),
-    Rule('/db/<db>/submissions/<submission_id>', endpoint='submission'),
-    Rule('/db/<db>/submissions/<submission_id>/messages', endpoint='submission_messages'),
-    Rule('/db/<db>/submissions/<submission_id>/files', endpoint='submission_files'),
-    Rule('/db/<db>/submissions/<submission_id>/files/<filename>', endpoint='submission_file'),
-])
diff --git a/microSALT/utils/pubmlst/exceptions.py b/microSALT/utils/pubmlst/exceptions.py
deleted file mode 100644
index 018ece63..00000000
--- a/microSALT/utils/pubmlst/exceptions.py
+++ /dev/null
@@ -1,65 +0,0 @@
-class PUBMLSTError(Exception):
-    """Base exception for PUBMLST utilities."""
-    def __init__(self, message=None):
-        super(PUBMLSTError, self).__init__(f"PUBMLST: {message}")
-
-
-class CredentialsFileNotFound(PUBMLSTError):
-    """Raised when the PUBMLST credentials file is not found."""
-    def __init__(self, credentials_file):
-        message = (
-            f"Credentials file not found: {credentials_file}. "
-            "Please generate it using the get_credentials script."
-        )
-        super(CredentialsFileNotFound, self).__init__(message)
-
-
-class InvalidCredentials(PUBMLSTError):
-    """Raised when the credentials file contains invalid or missing fields."""
-    def __init__(self, missing_fields):
-        message = (
-            "Invalid credentials: All fields (CLIENT_ID, CLIENT_SECRET, ACCESS_TOKEN, ACCESS_SECRET) "
-            f"must be non-empty. Missing or empty fields: {', '.join(missing_fields)}. "
-            "Please regenerate the credentials file using the get_credentials script."
-        )
-        super(InvalidCredentials, self).__init__(message)
-
-
-class PathResolutionError(PUBMLSTError):
-    """Raised when the file path cannot be resolved from the configuration."""
-    def __init__(self, config_key):
-        message = (
-            f"Failed to resolve the path for configuration key: '{config_key}'. "
-            "Ensure it is correctly set in the configuration."
- ) - super(PathResolutionError, self).__init__(message) - - -class SaveSessionError(PUBMLSTError): - """Raised when saving the session token fails.""" - def __init__(self, db, reason): - message = f"Failed to save session token for database '{db}': {reason}" - super(SaveSessionError, self).__init__(message) - - -class SessionTokenRequestError(PUBMLSTError): - """Raised when requesting a session token fails.""" - def __init__(self, db, status_code, response_text): - message = f"Failed to fetch session token for database '{db}': {status_code} - {response_text}" - super(SessionTokenRequestError, self).__init__(message) - - -class SessionTokenResponseError(PUBMLSTError): - """Raised when the session token response is invalid.""" - def __init__(self, db, reason): - message = f"Invalid session token response for database '{db}': {reason}" - super(SessionTokenResponseError, self).__init__(message) - -class InvalidURLError(PUBMLSTError): - """Raised when the provided URL does not match any known patterns.""" - def __init__(self, href): - message = ( - f"The provided URL '{href}' does not match any known PUBMLST API patterns. " - "Please check the URL for correctness." - ) - super(InvalidURLError, self).__init__(message) diff --git a/microSALT/utils/pubmlst/get_credentials.py b/microSALT/utils/pubmlst/get_credentials.py deleted file mode 100644 index 4fe21e92..00000000 --- a/microSALT/utils/pubmlst/get_credentials.py +++ /dev/null @@ -1,88 +0,0 @@ -import sys -import os -from rauth import OAuth1Service -from microSALT import app -from microSALT.utils.pubmlst.helpers import get_path, BASE_API, BASE_WEB, folders_config, credentials_path_key, pubmlst_auth_credentials_file_name - -db = "pubmlst_test_seqdef" - - -def validate_credentials(client_id, client_secret): - """Ensure client_id and client_secret are not empty.""" - if not client_id or not client_id.strip(): - raise ValueError("Invalid CLIENT_ID: It must not be empty.") - if not client_secret or not client_secret.strip(): - raise ValueError("Invalid CLIENT_SECRET: It must not be empty.") - - -def get_request_token(service): - """Handle JSON response from the request token endpoint.""" - response = service.get_raw_request_token(params={"oauth_callback": "oob"}) - if not response.ok: - print(f"Error obtaining request token: {response.text}") - sys.exit(1) - data = response.json() - return data["oauth_token"], data["oauth_token_secret"] - - -def get_new_access_token(client_id, client_secret): - """Obtain a new access token and secret.""" - service = OAuth1Service( - name="BIGSdb_downloader", - consumer_key=client_id, - consumer_secret=client_secret, - request_token_url=f"{BASE_API}/db/{db}/oauth/get_request_token", - access_token_url=f"{BASE_API}/db/{db}/oauth/get_access_token", - base_url=BASE_API, - ) - request_token, request_secret = get_request_token(service) - print( - "Please log in using your user account at " - f"{BASE_WEB}?db={db}&page=authorizeClient&oauth_token={request_token} " - "using a web browser to obtain a verification code." 
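The script above is a standard three-legged OAuth1 flow: obtain a temporary request token, have the user authorize it in a browser, then trade the verified token for a long-lived access token. Condensed, reusing the names defined above (client_id, client_secret, db, BASE_API, get_request_token):

from rauth import OAuth1Service

service = OAuth1Service(
    name="BIGSdb_downloader",
    consumer_key=client_id,
    consumer_secret=client_secret,
    request_token_url=f"{BASE_API}/db/{db}/oauth/get_request_token",
    access_token_url=f"{BASE_API}/db/{db}/oauth/get_access_token",
    base_url=BASE_API,
)
# Leg 1: request token ("oob" means the user pastes a code back out-of-band)
request_token, request_secret = get_request_token(service)
# Leg 2: the user authorizes in the browser and reads off a verification code
verifier = input("Please enter verification code: ")
# Leg 3: exchange the verified request token for a reusable access token
raw = service.get_raw_access_token(
    request_token, request_secret, params={"oauth_verifier": verifier}
)
access_token = raw.json()["oauth_token"]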
- ) - verifier = input("Please enter verification code: ") - - raw_access = service.get_raw_access_token( - request_token, request_secret, params={"oauth_verifier": verifier} - ) - if not raw_access.ok: - print(f"Error obtaining access token: {raw_access.text}") - sys.exit(1) - - access_data = raw_access.json() - return access_data["oauth_token"], access_data["oauth_token_secret"] - - -def save_to_credentials_py(client_id, client_secret, access_token, access_secret, credentials_path, credentials_file): - """Save tokens in the credentials.py file.""" - credentials_path.mkdir(parents=True, exist_ok=True) - - with open(credentials_file, "w") as f: - f.write(f'CLIENT_ID = "{client_id}"\n') - f.write(f'CLIENT_SECRET = "{client_secret}"\n') - f.write(f'ACCESS_TOKEN = "{access_token}"\n') - f.write(f'ACCESS_SECRET = "{access_secret}"\n') - print(f"Tokens saved to {credentials_file}") - - -def main(): - try: - pubmlst_config = app.config["pubmlst"] - client_id = pubmlst_config["client_id"] - client_secret = pubmlst_config["client_secret"] - validate_credentials(client_id, client_secret) - credentials_path = get_path(folders_config, credentials_path_key) - credentials_file = os.path.join(credentials_path, pubmlst_auth_credentials_file_name) - access_token, access_secret = get_new_access_token(client_id, client_secret) - print(f"\nAccess Token: {access_token}") - print(f"Access Token Secret: {access_secret}") - save_to_credentials_py(client_id, client_secret, access_token, access_secret, credentials_path, credentials_file) - - except Exception as e: - print(f"Error: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/microSALT/utils/pubmlst/helpers.py b/microSALT/utils/pubmlst/helpers.py deleted file mode 100644 index dfc881a3..00000000 --- a/microSALT/utils/pubmlst/helpers.py +++ /dev/null @@ -1,164 +0,0 @@ -import os -import base64 -import hashlib -import json -import hmac -import time -from pathlib import Path -from urllib.parse import quote_plus, urlencode -from werkzeug.exceptions import NotFound -from microSALT import app, logger -from microSALT.utils.pubmlst.exceptions import PUBMLSTError, PathResolutionError, CredentialsFileNotFound, InvalidCredentials, SaveSessionError, InvalidURLError -from microSALT.utils.pubmlst.constants import Encoding, url_map - -BASE_WEB = "https://pubmlst.org/bigsdb" -BASE_API = "https://rest.pubmlst.org" -BASE_API_HOST = "rest.pubmlst.org" - -credentials_path_key = "pubmlst_credentials" -pubmlst_auth_credentials_file_name = "pubmlst_credentials.env" -pubmlst_session_credentials_file_name = "pubmlst_session_credentials.json" -pubmlst_config = app.config["pubmlst"] -folders_config = app.config["folders"] - -def get_path(config, config_key: str): - """Get and expand the file path from the configuration.""" - try: - path = config.get(config_key) - if not path: - raise PathResolutionError(config_key) - - path = os.path.expandvars(path) - path = os.path.expanduser(path) - - return Path(path).resolve() - - except Exception as e: - raise PathResolutionError(config_key) from e - - -def load_auth_credentials(): - """Load client ID, client secret, access token, and access secret from credentials file.""" - try: - credentials_file = os.path.join( - get_path(folders_config, credentials_path_key), - pubmlst_auth_credentials_file_name - ) - - if not os.path.exists(credentials_file): - raise CredentialsFileNotFound(credentials_file) - - credentials = {} - with open(credentials_file, "r") as f: - exec(f.read(), credentials) - - consumer_key = 
credentials.get("CLIENT_ID", "").strip() - consumer_secret = credentials.get("CLIENT_SECRET", "").strip() - access_token = credentials.get("ACCESS_TOKEN", "").strip() - access_secret = credentials.get("ACCESS_SECRET", "").strip() - - missing_fields = [] - if not consumer_key: - missing_fields.append("CLIENT_ID") - if not consumer_secret: - missing_fields.append("CLIENT_SECRET") - if not access_token: - missing_fields.append("ACCESS_TOKEN") - if not access_secret: - missing_fields.append("ACCESS_SECRET") - - if missing_fields: - raise InvalidCredentials(missing_fields) - - return consumer_key, consumer_secret, access_token, access_secret - - except CredentialsFileNotFound: - raise - except InvalidCredentials: - raise - except PUBMLSTError as e: - logger.error(f"Unexpected error in load_credentials: {e}") - raise - except Exception as e: - raise PUBMLSTError("An unexpected error occurred while loading credentials: {e}") - - -def generate_oauth_header(url: str, oauth_consumer_key: str, oauth_consumer_secret: str, oauth_token: str, oauth_token_secret: str): - """Generate the OAuth1 Authorization header.""" - oauth_timestamp = str(int(time.time())) - oauth_nonce = base64.urlsafe_b64encode(os.urandom(32)).decode(Encoding.UTF8.value).strip("=") - oauth_signature_method = "HMAC-SHA1" - oauth_version = "1.0" - - oauth_params = { - "oauth_consumer_key": oauth_consumer_key, - "oauth_token": oauth_token, - "oauth_signature_method": oauth_signature_method, - "oauth_timestamp": oauth_timestamp, - "oauth_nonce": oauth_nonce, - "oauth_version": oauth_version, - } - - params_encoded = urlencode(sorted(oauth_params.items())) - base_string = f"GET&{quote_plus(url)}&{quote_plus(params_encoded)}" - signing_key = f"{oauth_consumer_secret}&{oauth_token_secret}" - - hashed = hmac.new(signing_key.encode(Encoding.UTF8.value), base_string.encode(Encoding.UTF8.value), hashlib.sha1) - oauth_signature = base64.b64encode(hashed.digest()).decode(Encoding.UTF8.value) - - oauth_params["oauth_signature"] = oauth_signature - - auth_header = "OAuth " + ", ".join( - [f'{quote_plus(k)}="{quote_plus(v)}"' for k, v in oauth_params.items()] - ) - return auth_header - -def save_session_token(db: str, token: str, secret: str, expiration_date: str): - """Save session token, secret, and expiration to a JSON file for the specified database.""" - try: - session_data = { - "token": token, - "secret": secret, - "expiration": expiration_date.isoformat(), - } - - credentials_file = os.path.join( - get_path(folders_config, credentials_path_key), - pubmlst_session_credentials_file_name - ) - - if os.path.exists(credentials_file): - with open(credentials_file, "r") as f: - all_sessions = json.load(f) - else: - all_sessions = {} - - if "databases" not in all_sessions: - all_sessions["databases"] = {} - - all_sessions["databases"][db] = session_data - - with open(credentials_file, "w") as f: - json.dump(all_sessions, f, indent=4) - - logger.debug( - f"Session token for database '{db}' saved to '{credentials_file}'." - ) - except (IOError, OSError) as e: - raise SaveSessionError(db, f"I/O error: {e}") - except ValueError as e: - raise SaveSessionError(db, f"Invalid data format: {e}") - except Exception as e: - raise SaveSessionError(db, f"Unexpected error: {e}") - -def parse_pubmlst_url(url: str): - """ - Match a URL against the URL map and return extracted parameters. 
- """ - adapter = url_map.bind("") - parsed_url = url.split(BASE_API_HOST)[-1] - try: - endpoint, values = adapter.match(parsed_url) - return {"endpoint": endpoint, **values} - except NotFound: - raise InvalidURLError(url) diff --git a/microSALT/utils/referencer.py b/microSALT/utils/referencer.py index aeac8593..2fa1b6c5 100644 --- a/microSALT/utils/referencer.py +++ b/microSALT/utils/referencer.py @@ -10,7 +10,6 @@ import subprocess import urllib.request import zipfile -from microSALT.utils.pubmlst.client import PubMLSTClient from Bio import Entrez import xml.etree.ElementTree as ET @@ -44,8 +43,6 @@ def __init__(self, config, log, sampleinfo={}, force=False): self.sampleinfo = self.sampleinfo[0] self.name = self.sampleinfo.get("CG_ID_sample") self.sample = self.sampleinfo - self.client = PubMLSTClient() - def identify_new(self, cg_id="", project=False): """ Automatically downloads pubMLST & NCBI organisms not already downloaded """ @@ -388,133 +385,92 @@ def add_pubmlst(self, organism): def query_pubmlst(self): """ Returns a json object containing all organisms available via pubmlst.org """ - db_query = self.client.query_databases() + # Example request URI: http://rest.pubmlst.org/db/pubmlst_neisseria_seqdef/schemes/1/profiles_csv + seqdef_url = dict() + databases = "http://rest.pubmlst.org/db" + db_req = urllib.request.Request(databases) + with urllib.request.urlopen(db_req) as response: + db_query = json.loads(response.read().decode("utf-8")) return db_query - def get_mlst_scheme(self, subtype_href): """ Returns the path for the MLST data scheme at pubMLST """ try: - parsed_data = self.client.parse_pubmlst_url(subtype_href) - db = parsed_data.get('db') - if not db: - self.logger.warning(f"Could not extract database name from URL: {subtype_href}") - return None - - # First, check scheme 1 - scheme_query_1 = self.client.retrieve_scheme_info(db, 1) - mlst = None - if "MLST" in scheme_query_1.get("description", ""): - mlst = f"{subtype_href}/schemes/1" - else: - # If scheme 1 isn't MLST, list all schemes and find the one with 'description' == 'MLST' - record_query = self.client.list_schemes(db) - for scheme in record_query.get("schemes", []): - if scheme.get("description") == "MLST": - mlst = scheme.get("scheme") - break - + mlst = False + record_req_1 = urllib.request.Request("{}/schemes/1".format(subtype_href)) + with urllib.request.urlopen(record_req_1) as response: + scheme_query_1 = json.loads(response.read().decode("utf-8")) + if "MLST" in scheme_query_1["description"]: + mlst = "{}/schemes/1".format(subtype_href) + if not mlst: + record_req = urllib.request.Request("{}/schemes".format(subtype_href)) + with urllib.request.urlopen(record_req) as response: + record_query = json.loads(response.read().decode("utf-8")) + for scheme in record_query["schemes"]: + if scheme["description"] == "MLST": + mlst = scheme["scheme"] if mlst: - self.logger.debug(f"Found data at pubMLST: {mlst}") + self.logger.debug("Found data at pubMLST: {}".format(mlst)) return mlst - else: - self.logger.warning(f"Could not find MLST data at {subtype_href}") - return None + else: + self.logger.warning("Could not find MLST data at {}".format(subtype_href)) except Exception as e: self.logger.warning(e) - return None - def external_version(self, organism, subtype_href): """ Returns the version (date) of the data available on pubMLST """ + mlst_href = self.get_mlst_scheme(subtype_href) try: - mlst_href = self.get_mlst_scheme(subtype_href) - if not mlst_href: - self.logger.warning(f"MLST scheme not found for URL: 
{subtype_href}") - return None - - parsed_data = self.client.parse_pubmlst_url(mlst_href) - db = parsed_data.get('db') - scheme_id = parsed_data.get('scheme_id') - if not db or not scheme_id: - self.logger.warning(f"Could not extract database name or scheme ID from MLST URL: {mlst_href}") - return None - - scheme_info = self.client.retrieve_scheme_info(db, scheme_id) - last_updated = scheme_info.get("last_updated") - if last_updated: - self.logger.debug(f"Retrieved last_updated: {last_updated} for organism: {organism}") - return last_updated - else: - self.logger.warning(f"No 'last_updated' field found for db: {db}, scheme_id: {scheme_id}") - return None + with urllib.request.urlopen(mlst_href) as response: + ver_query = json.loads(response.read().decode("utf-8")) + return ver_query["last_updated"] except Exception as e: - self.logger.warning(f"Could not determine pubMLST version for {organism}") + self.logger.warning("Could not determine pubMLST version for {}".format(organism)) self.logger.warning(e) - return None - def download_pubmlst(self, organism, subtype_href, force=False): """ Downloads ST and loci for a given organism stored on pubMLST if it is more recent. Returns update date """ organism = organism.lower().replace(" ", "_") + + # Pull version + extver = self.external_version(organism, subtype_href) + currver = self.db_access.get_version("profile_{}".format(organism)) + if ( + int(extver.replace("-", "")) + <= int(currver.replace("-", "")) + and not force + ): + # self.logger.info("Profile for {} already at latest version".format(organism.replace('_' ,' ').capitalize())) + return currver + + # Pull ST file + mlst_href = self.get_mlst_scheme(subtype_href) + st_target = "{}/{}".format(self.config["folders"]["profiles"], organism) + st_input = "{}/profiles_csv".format(mlst_href) + urllib.request.urlretrieve(st_input, st_target) + + # Pull locus files + loci_input = mlst_href + loci_req = urllib.request.Request(loci_input) + with urllib.request.urlopen(loci_req) as response: + loci_query = json.loads(response.read().decode("utf-8")) + + output = "{}/{}".format(self.config["folders"]["references"], organism) + try: - # Pull version - extver = self.external_version(organism, subtype_href) - currver = self.db_access.get_version(f"profile_{organism}") - if ( - int(extver.replace("-", "")) - <= int(currver.replace("-", "")) - and not force - ): - self.logger.info(f"Profile for {organism.replace('_', ' ').capitalize()} already at the latest version.") - return currver - - # Retrieve the MLST scheme URL - mlst_href = self.get_mlst_scheme(subtype_href) - if not mlst_href: - self.logger.warning(f"MLST scheme not found for URL: {subtype_href}") - return None - - # Parse the database name and scheme ID - parsed_data = self.client.parse_pubmlst_url(mlst_href) - db = parsed_data.get('db') - scheme_id = parsed_data.get('scheme_id') - if not db or not scheme_id: - self.logger.warning(f"Could not extract database name or scheme ID from MLST URL: {mlst_href}") - return None - - # Step 1: Download the profiles CSV - st_target = f"{self.config['folders']['profiles']}/{organism}" - profiles_csv = self.client.download_profiles_csv(db, scheme_id) - with open(st_target, "w") as profile_file: - profile_file.write(profiles_csv) - self.logger.info(f"Profiles CSV downloaded to {st_target}") - - # Step 2: Fetch scheme information to get loci - scheme_info = self.client.retrieve_scheme_info(db, scheme_id) - loci_list = scheme_info.get("loci", []) - - # Step 3: Download loci FASTA files - output = 
f"{self.config['folders']['references']}/{organism}" if os.path.isdir(output): shutil.rmtree(output) - os.makedirs(output) - - for locus_uri in loci_list: - locus_name = os.path.basename(os.path.normpath(locus_uri)) - loci_fasta = self.client.download_locus(db, locus_name) - with open(f"{output}/{locus_name}.tfa", "w") as fasta_file: - fasta_file.write(loci_fasta) - self.logger.info(f"Locus FASTA downloaded: {locus_name}.tfa") - - # Step 4: Create new indexes - self.index_db(output, ".tfa") - - return extver - except Exception as e: - self.logger.error(f"Failed to download data for {organism}: {e}") - return None - + except FileNotFoundError as e: + pass + os.makedirs(output) + + for locipath in loci_query["loci"]: + loci = os.path.basename(os.path.normpath(locipath)) + urllib.request.urlretrieve( + "{}/alleles_fasta".format(locipath), "{}/{}.tfa".format(output, loci) + ) + # Create new indexes + self.index_db(output, ".tfa") def fetch_pubmlst(self, force=False): """ Updates reference for data that is stored on pubMLST """ diff --git a/requirements.txt b/requirements.txt index 5cdd9804..6efdd7f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,3 @@ pymysql==0.10.1 pyyaml==5.4.1 sqlalchemy==1.3.19 genologics==0.4.6 -rauth==0.7.3 - diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 00000000..6dc37722 --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python + +import builtins +import click +import json +import logging +import pathlib +import pdb +import pytest +import re +import mock +import os +import sys + +from microSALT import __version__ + +from click.testing import CliRunner +from distutils.sysconfig import get_python_lib +from unittest.mock import patch, mock_open + +from microSALT import preset_config, logger +from microSALT.cli import root +from microSALT.store.db_manipulator import DB_Manipulator + + +def unpack_db_json(filename): + testdata = os.path.abspath( + os.path.join( + pathlib.Path(__file__).parent.parent, "tests/testdata/{}".format(filename) + ) + ) + # Check if release install exists + for entry in os.listdir(get_python_lib()): + if "microSALT-" in entry: + testdata = os.path.abspath( + os.path.join( + os.path.expandvars("$CONDA_PREFIX"), "testdata/{}".format(filename) + ) + ) + with open(testdata) as json_file: + data = json.load(json_file) + return data + + +@pytest.fixture +def dbm(): + db_file = re.search( + "sqlite:///(.+)", preset_config["database"]["SQLALCHEMY_DATABASE_URI"] + ).group(1) + dbm = DB_Manipulator(config=preset_config, log=logger) + dbm.create_tables() + + for entry in unpack_db_json("sampleinfo_projects.json"): + dbm.add_rec(entry, "Projects") + for entry in unpack_db_json("sampleinfo_mlst.json"): + dbm.add_rec(entry, "Seq_types") + for bentry in unpack_db_json("sampleinfo_resistance.json"): + dbm.add_rec(bentry, "Resistances") + for centry in unpack_db_json("sampleinfo_expec.json"): + dbm.add_rec(centry, "Expacs") + for dentry in unpack_db_json("sampleinfo_reports.json"): + dbm.add_rec(dentry, "Reports") + return dbm + + +@pytest.fixture(autouse=True) +def no_requests(monkeypatch): + """Remove requests.sessions.Session.request for all tests.""" + monkeypatch.delattr("requests.sessions.Session.request") + + +@pytest.fixture +def runner(): + runnah = CliRunner() + return runnah + + +@pytest.fixture +def config(): + config = os.path.abspath( + os.path.join(pathlib.Path(__file__).parent.parent, "configExample.json") + ) + # Check if release install exists + for entry in 
os.listdir(get_python_lib()): + if "microSALT-" in entry: + config = os.path.abspath( + os.path.join( + os.path.expandvars("$CONDA_PREFIX"), "testdata/configExample.json" + ) + ) + return config + + +@pytest.fixture +def path_testdata(): + testdata = os.path.abspath( + os.path.join( + pathlib.Path(__file__).parent.parent, + "tests/testdata/sampleinfo_samples.json", + ) + ) + # Check if release install exists + for entry in os.listdir(get_python_lib()): + if "microSALT-" in entry: + testdata = os.path.abspath( + os.path.join( + os.path.expandvars("$CONDA_PREFIX"), + "testdata/sampleinfo_samples.json", + ) + ) + return testdata + + +@pytest.fixture +def path_testproject(): + testproject = os.path.abspath( + os.path.join( + pathlib.Path(__file__).parent.parent, + "tests/testdata/AAA1234_2000.1.2_3.4.5", + ) + ) + # Check if release install exists + for entry in os.listdir(get_python_lib()): + if "microSALT-" in entry: + testproject = os.path.abspath( + os.path.join( + os.path.expandvars("$CONDA_PREFIX"), + "testdata/AAA1234_2000.1.2_3.4.5", + ) + ) + return testproject + + +def test_version(runner): + res = runner.invoke(root, "--version") + assert res.exit_code == 0 + assert __version__ in res.stdout + + +def test_groups(runner): + """These groups should only return the help text""" + base = runner.invoke(root, ["utils"]) + assert base.exit_code == 0 + base_invoke = runner.invoke(root, ["utils", "resync"]) + assert base_invoke.exit_code == 0 + base_invoke = runner.invoke(root, ["utils", "refer"]) + assert base_invoke.exit_code == 0 + +@patch("microSALT.utils.job_creator.Job_Creator.create_project") +@patch("microSALT.utils.reporter.Reporter.start_web") +@patch("multiprocessing.Process.terminate") +@patch("multiprocessing.Process.join") +@patch("microSALT.utils.reporter.requests.get") +@patch("microSALT.utils.reporter.smtplib") +@patch("microSALT.cli.os.path.isdir") +def test_finish_typical( + isdir, + smtp, + reqs_get, + proc_join, + proc_term, + webstart, + create_projct, + runner, + config, + path_testdata, + path_testproject, + caplog, + dbm, +): + caplog.set_level(logging.DEBUG, logger="main_logger") + caplog.clear() + + isdir.return_value = True + + # All subcommands + base_invoke = runner.invoke(root, ["utils", "finish"]) + assert base_invoke.exit_code == 2 + # Exhaustive parameter test + typical_run = runner.invoke( + root, + [ + "utils", + "finish", + path_testdata, + "--email", + "2@2.com", + "--config", + config, + "--report", + "default", + "--output", + "/tmp/", + "--input", + path_testproject, + ], + ) + assert typical_run.exit_code == 0 + assert "INFO - Execution finished!" in caplog.text + caplog.clear() + + +@patch("microSALT.utils.job_creator.Job_Creator.create_project") +@patch("microSALT.utils.reporter.Reporter.start_web") +@patch("multiprocessing.Process.terminate") +@patch("multiprocessing.Process.join") +@patch("microSALT.utils.reporter.requests.get") +@patch("microSALT.utils.reporter.smtplib") +@patch("microSALT.cli.os.path.isdir") +def test_finish_qc( + isdir, + smtp, + reqs_get, + proc_join, + proc_term, + webstart, + create_projct, + runner, + config, + path_testdata, + path_testproject, + caplog, + dbm, +): + caplog.set_level(logging.DEBUG, logger="main_logger") + caplog.clear() + + isdir.return_value = True + + special_run = runner.invoke( + root, + [ + "utils", + "finish", + path_testdata, + "--report", + "qc", + "--output", + "/tmp/", + "--input", + path_testproject, + ], + ) + assert special_run.exit_code == 0 + assert "INFO - Execution finished!" 
in caplog.text + caplog.clear() + + +@patch("microSALT.utils.job_creator.Job_Creator.create_project") +@patch("microSALT.utils.reporter.Reporter.start_web") +@patch("multiprocessing.Process.terminate") +@patch("multiprocessing.Process.join") +@patch("microSALT.utils.reporter.requests.get") +@patch("microSALT.utils.reporter.smtplib") +@patch("microSALT.cli.os.path.isdir") +def test_finish_motif( + isdir, + smtp, + reqs_get, + proc_join, + proc_term, + webstart, + create_projct, + runner, + config, + path_testdata, + path_testproject, + caplog, + dbm, +): + caplog.set_level(logging.DEBUG, logger="main_logger") + caplog.clear() + + isdir.return_value = True + + unique_report = runner.invoke( + root, + [ + "utils", + "finish", + path_testdata, + "--report", + "motif_overview", + "--output", + "/tmp/", + "--input", + path_testproject, + ], + ) + assert unique_report.exit_code == 0 + assert "INFO - Execution finished!" in caplog.text + caplog.clear() + + +@patch("microSALT.utils.reporter.Reporter.start_web") +@patch("multiprocessing.Process.terminate") +@patch("multiprocessing.Process.join") +@patch("microSALT.utils.reporter.requests.get") +@patch("microSALT.utils.reporter.smtplib") +def test_report( + smtplib, reqget, join, term, webstart, runner, path_testdata, caplog, dbm +): + caplog.set_level(logging.DEBUG, logger="main_logger") + caplog.clear() + + base_invoke = runner.invoke(root, ["utils", "report"]) + assert base_invoke.exit_code == 2 + + # Exhaustive parameter test + for rep_type in [ + "default", + "typing", + "motif_overview", + "qc", + "json_dump", + "st_update", + ]: + normal_report = runner.invoke( + root, + [ + "utils", + "report", + path_testdata, + "--type", + rep_type, + "--email", + "2@2.com", + "--output", + "/tmp/", + ], + ) + assert normal_report.exit_code == 0 + assert "INFO - Execution finished!" in caplog.text + caplog.clear() + collection_report = runner.invoke( + root, + [ + "utils", + "report", + path_testdata, + "--type", + rep_type, + "--collection", + "--output", + "/tmp/", + ], + ) + assert collection_report.exit_code == 0 + assert "INFO - Execution finished!" in caplog.text + caplog.clear() + + +@patch("microSALT.utils.reporter.Reporter.start_web") +@patch("multiprocessing.Process.terminate") +@patch("multiprocessing.Process.join") +@patch("microSALT.utils.reporter.requests.get") +@patch("microSALT.utils.reporter.smtplib") +def test_resync_overwrite(smtplib, reqget, join, term, webstart, runner, caplog, dbm): + caplog.set_level(logging.DEBUG, logger="main_logger") + caplog.clear() + + a = runner.invoke(root, ["utils", "resync", "overwrite", "AAA1234A1"]) + assert a.exit_code == 0 + assert "INFO - Execution finished!" in caplog.text + caplog.clear() + b = runner.invoke(root, ["utils", "resync", "overwrite", "AAA1234A1", "--force"]) + assert b.exit_code == 0 + assert "INFO - Execution finished!" 
in caplog.text + caplog.clear() + + +@patch("microSALT.utils.reporter.Reporter.start_web") +@patch("multiprocessing.Process.terminate") +@patch("multiprocessing.Process.join") +@patch("microSALT.utils.reporter.requests.get") +@patch("microSALT.utils.reporter.smtplib") +def test_resync_review(smtplib, reqget, join, term, webstart, runner, caplog, dbm): + caplog.set_level(logging.DEBUG, logger="main_logger") + caplog.clear() + + # Exhaustive parameter test + for rep_type in ["list", "report"]: + typical_work = runner.invoke( + root, + [ + "utils", + "resync", + "review", + "--email", + "2@2.com", + "--type", + rep_type, + "--output", + "/tmp/", + ], + ) + assert typical_work.exit_code == 0 + assert "INFO - Execution finished!" in caplog.text + caplog.clear() + delimited_work = runner.invoke( + root, + [ + "utils", + "resync", + "review", + "--skip_update", + "--customer", + "custX", + "--type", + rep_type, + "--output", + "/tmp/", + ], + ) + assert delimited_work.exit_code == 0 + assert "INFO - Execution finished!" in caplog.text + caplog.clear() + + +def test_refer(runner, caplog, dbm): + caplog.set_level(logging.DEBUG, logger="main_logger") + + list_invoke = runner.invoke(root, ["utils", "refer", "observe"]) + assert list_invoke.exit_code == 0 + + a = runner.invoke(root, ["utils", "refer", "add", "Homosapiens_Trams"]) + assert a.exit_code == 0 + # assert "INFO - Execution finished!" in caplog.text + caplog.clear() + b = runner.invoke(root, ["utils", "refer", "add", "Homosapiens_Trams", "--force"]) + assert b.exit_code == 0 + # assert "INFO - Execution finished!" in caplog.text + caplog.clear() + + +@patch("microSALT.utils.reporter.Reporter.start_web") +def test_view(webstart, runner, caplog, dbm): + caplog.set_level(logging.DEBUG, logger="main_logger") + + view = runner.invoke(root, ["utils", "view"]) + assert view.exit_code == 0 + # assert "INFO - Execution finished!" 
in caplog.text + caplog.clear() + + +@patch("os.path.isdir") +def test_generate(isdir, runner, caplog, dbm): + caplog.set_level(logging.DEBUG, logger="main_logger") + gent = runner.invoke(root, ["utils", "generate", "--input", "/tmp/"]) + assert gent.exit_code == 0 + fent = runner.invoke(root, ["utils", "generate"]) + assert fent.exit_code == 0 diff --git a/tests/test_config.py b/tests/test_config.py index d2332d93..d61bcd2d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,71 +9,74 @@ @pytest.fixture def exp_config(): - precon = { - 'slurm_header': {'time', 'threads', 'qos', 'job_prefix', 'project', 'type'}, - 'regex': {'file_pattern', 'mail_recipient', 'verified_organisms'}, - 'folders': {'results', 'reports', 'log_file', 'seqdata', 'profiles', 'references', 'resistances', 'genomes', 'expec', 'adapters', 'pubmlst_credentials'}, - 'threshold': {'mlst_id', 'mlst_novel_id', 'mlst_span', 'motif_id', 'motif_span', 'total_reads_warn', 'total_reads_fail', - 'NTC_total_reads_warn', 'NTC_total_reads_fail', 'mapped_rate_warn', 'mapped_rate_fail', 'duplication_rate_warn', - 'duplication_rate_fail', 'insert_size_warn', 'insert_size_fail', 'average_coverage_warn', 'average_coverage_fail', - 'bp_10x_warn', 'bp_10x_fail', 'bp_30x_warn', 'bp_50x_warn', 'bp_100x_warn'}, - 'database': {'SQLALCHEMY_DATABASE_URI', 'SQLALCHEMY_TRACK_MODIFICATIONS', 'DEBUG'}, - 'genologics': {'baseuri', 'username', 'password'}, - 'pubmlst': {'client_id', 'client_secret'}, + precon = \ + { + 'slurm_header': + {'time','threads', 'qos', 'job_prefix','project', 'type'}, + 'regex': + {'file_pattern', 'mail_recipient', 'verified_organisms'}, + 'folders': + {'results', 'reports', 'log_file', 'seqdata', 'profiles', 'references', 'resistances', 'genomes', 'expec', 'adapters'}, + 'threshold': + {'mlst_id', 'mlst_novel_id', 'mlst_span', 'motif_id', 'motif_span', 'total_reads_warn', 'total_reads_fail', 'NTC_total_reads_warn', \ + 'NTC_total_reads_fail', 'mapped_rate_warn', 'mapped_rate_fail', 'duplication_rate_warn', 'duplication_rate_fail', 'insert_size_warn', 'insert_size_fail', \ + 'average_coverage_warn', 'average_coverage_fail', 'bp_10x_warn', 'bp_10x_fail', 'bp_30x_warn', 'bp_50x_warn', 'bp_100x_warn'}, + 'database': + {'SQLALCHEMY_DATABASE_URI' ,'SQLALCHEMY_TRACK_MODIFICATIONS' , 'DEBUG'}, + 'genologics': + {'baseuri', 'username', 'password'}, 'dry': True, } return precon def test_existence(exp_config): """Checks that the configuration contains certain key variables""" - # level one + + #level one config_level_one = preset_config.keys() for entry in exp_config.keys(): if entry != 'dry': assert entry in config_level_one - # level two + #level two if isinstance(preset_config[entry], collections.Mapping): config_level_two = preset_config[entry].keys() for thing in exp_config[entry]: assert thing in config_level_two def test_reverse_existence(exp_config): - """Check that the configuration doesn't contain outdated variables""" + """Check that the configuration doesnt contain outdated variables""" - # level one + #level one config_level_one = exp_config.keys() for entry in preset_config.keys(): if entry not in ['_comment']: assert entry in config_level_one - # level two + #level two config_level_two = exp_config[entry] if isinstance(preset_config[entry], collections.Mapping): for thing in preset_config[entry].keys(): if thing != '_comment': assert thing in config_level_two +#def test_type(exp_config): +# """Verify that each variable uses the correct format""" +# pass + def test_paths(exp_config): """Tests existence for 
all paths mentioned in variables""" - # level one + #level one for entry in preset_config.keys(): if entry != '_comment': if isinstance(preset_config[entry], str) and '/' in preset_config[entry] and entry not in ['database', 'genologics']: unmade_fldr = preset_config[entry] - # Embed logic to expand vars and user here - unmade_fldr = os.path.expandvars(unmade_fldr) - unmade_fldr = os.path.expanduser(unmade_fldr) - unmade_fldr = os.path.abspath(unmade_fldr) assert (pathlib.Path(unmade_fldr).exists()) - # level two + #level two elif isinstance(preset_config[entry], collections.Mapping): for thing in preset_config[entry].keys(): if isinstance(preset_config[entry][thing], str) and '/' in preset_config[entry][thing] and entry not in ['database', 'genologics']: unmade_fldr = preset_config[entry][thing] - # Embed logic to expand vars and user here - unmade_fldr = os.path.expandvars(unmade_fldr) - unmade_fldr = os.path.expanduser(unmade_fldr) - unmade_fldr = os.path.abspath(unmade_fldr) assert (pathlib.Path(unmade_fldr).exists()) + diff --git a/tests/test_database.py b/tests/test_database.py index e9ca73d8..7b6f1e67 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -17,164 +17,136 @@ from microSALT import preset_config, logger from microSALT.cli import root - def unpack_db_json(filename): - testdata = os.path.abspath(os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/{}'.format(filename))) - #Check if release install exists - for entry in os.listdir(get_python_lib()): - if 'microSALT-' in entry: - testdata = os.path.abspath( - os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/{}'.format(filename))) - with open(testdata) as json_file: - data = json.load(json_file) - return data - + testdata = os.path.abspath(os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/{}'.format(filename))) + #Check if release install exists + for entry in os.listdir(get_python_lib()): + if 'microSALT-' in entry: + testdata = os.path.abspath(os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/{}'.format(filename))) + with open(testdata) as json_file: + data = json.load(json_file) + return data @pytest.fixture def dbm(): - db_file = re.search('sqlite:///(.+)', preset_config['database']['SQLALCHEMY_DATABASE_URI']).group(1) - dbm = DB_Manipulator(config=preset_config, log=logger) - dbm.create_tables() - - for antry in unpack_db_json('sampleinfo_projects.json'): - dbm.add_rec(antry, 'Projects') - for entry in unpack_db_json('sampleinfo_mlst.json'): - dbm.add_rec(entry, 'Seq_types') - for bentry in unpack_db_json('sampleinfo_resistance.json'): - dbm.add_rec(bentry, 'Resistances') - for centry in unpack_db_json('sampleinfo_expec.json'): - dbm.add_rec(centry, 'Expacs') - for dentry in unpack_db_json('sampleinfo_reports.json'): - dbm.add_rec(dentry, 'Reports') - return dbm - + db_file = re.search('sqlite:///(.+)', preset_config['database']['SQLALCHEMY_DATABASE_URI']).group(1) + dbm = DB_Manipulator(config=preset_config,log=logger) + dbm.create_tables() + + for antry in unpack_db_json('sampleinfo_projects.json'): + dbm.add_rec(antry, 'Projects') + for entry in unpack_db_json('sampleinfo_mlst.json'): + dbm.add_rec(entry, 'Seq_types') + for bentry in unpack_db_json('sampleinfo_resistance.json'): + dbm.add_rec(bentry, 'Resistances') + for centry in unpack_db_json('sampleinfo_expec.json'): + dbm.add_rec(centry, 'Expacs') + for dentry in unpack_db_json('sampleinfo_reports.json'): + dbm.add_rec(dentry, 'Reports') + return dbm def test_create_every_table(dbm): - assert 
dbm.engine.dialect.has_table(dbm.engine, 'samples') - assert dbm.engine.dialect.has_table(dbm.engine, 'seq_types') - assert dbm.engine.dialect.has_table(dbm.engine, 'resistances') - assert dbm.engine.dialect.has_table(dbm.engine, 'expacs') - assert dbm.engine.dialect.has_table(dbm.engine, 'projects') - assert dbm.engine.dialect.has_table(dbm.engine, 'reports') - assert dbm.engine.dialect.has_table(dbm.engine, 'collections') + assert dbm.engine.dialect.has_table(dbm.engine, 'samples') + assert dbm.engine.dialect.has_table(dbm.engine, 'seq_types') + assert dbm.engine.dialect.has_table(dbm.engine, 'resistances') + assert dbm.engine.dialect.has_table(dbm.engine, 'expacs') + assert dbm.engine.dialect.has_table(dbm.engine, 'projects') + assert dbm.engine.dialect.has_table(dbm.engine, 'reports') + assert dbm.engine.dialect.has_table(dbm.engine, 'collections') - -@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating") def test_add_rec(caplog, dbm): - #Adds records to all databases - dbm.add_rec( - {'ST': '130', 'arcC': '6', 'aroE': '57', 'glpF': '45', 'gmk': '2', 'pta': '7', 'tpi': '58', 'yqiL': '52', - 'clonal_complex': 'CC1'}, dbm.profiles['staphylococcus_aureus']) - assert len(dbm.query_rec(dbm.profiles['staphylococcus_aureus'], {'ST': '130'})) == 1 - assert len(dbm.query_rec(dbm.profiles['staphylococcus_aureus'], {'ST': '-1'})) == 0 - - dbm.add_rec( - {'ST': '130', 'arcC': '6', 'aroE': '57', 'glpF': '45', 'gmk': '2', 'pta': '7', 'tpi': '58', 'yqiL': '52', - 'clonal_complex': 'CC1'}, dbm.novel['staphylococcus_aureus']) - assert len(dbm.query_rec(dbm.novel['staphylococcus_aureus'], {'ST': '130'})) == 1 - assert len(dbm.query_rec(dbm.novel['staphylococcus_aureus'], {'ST': '-1'})) == 0 - - dbm.add_rec({'CG_ID_sample': 'ADD1234A1'}, 'Samples') - assert len(dbm.query_rec('Samples', {'CG_ID_sample': 'ADD1234A1'})) > 0 - assert len(dbm.query_rec('Samples', {'CG_ID_sample': 'XXX1234A10'})) == 0 - - dbm.add_rec({'CG_ID_sample': 'ADD1234A1', 'loci': 'mdh', 'contig_name': 'NODE_1'}, 'Seq_types') - assert len(dbm.query_rec('Seq_types', {'CG_ID_sample': 'ADD1234A1', 'loci': 'mdh', 'contig_name': 'NODE_1'})) > 0 - assert len(dbm.query_rec('Seq_types', {'CG_ID_sample': 'XXX1234A10', 'loci': 'mdh', 'contig_name': 'NODE_1'})) == 0 - - dbm.add_rec({'CG_ID_sample': 'ADD1234A1', 'gene': 'Type 1', 'instance': 'Type 1', 'contig_name': 'NODE_1'}, - 'Resistances') - assert len(dbm.query_rec('Resistances', {'CG_ID_sample': 'ADD1234A1', 'gene': 'Type 1', 'instance': 'Type 1', - 'contig_name': 'NODE_1'})) > 0 - assert len(dbm.query_rec('Resistances', {'CG_ID_sample': 'XXX1234A10', 'gene': 'Type 1', 'instance': 'Type 1', - 'contig_name': 'NODE_1'})) == 0 - - dbm.add_rec({'CG_ID_sample': 'ADD1234A1', 'gene': 'Type 1', 'instance': 'Type 1', 'contig_name': 'NODE_1'}, - 'Expacs') - assert len(dbm.query_rec('Expacs', {'CG_ID_sample': 'ADD1234A1', 'gene': 'Type 1', 'instance': 'Type 1', - 'contig_name': 'NODE_1'})) > 0 - assert len(dbm.query_rec('Expacs', {'CG_ID_sample': 'XXX1234A10', 'gene': 'Type 1', 'instance': 'Type 1', - 'contig_name': 'NODE_1'})) == 0 - - dbm.add_rec({'CG_ID_project': 'ADD1234'}, 'Projects') - assert len(dbm.query_rec('Projects', {'CG_ID_project': 'ADD1234'})) > 0 - assert len(dbm.query_rec('Projects', {'CG_ID_project': 'XXX1234'})) == 0 - - dbm.add_rec({'CG_ID_project': 'ADD1234', 'version': '1'}, 'Reports') - assert len(dbm.query_rec('Reports', {'CG_ID_project': 'ADD1234', 'version': '1'})) > 0 - assert len(dbm.query_rec('Reports', {'CG_ID_project': 'XXX1234', 'version': 
'1'})) == 0 - - dbm.add_rec({'CG_ID_sample': 'ADD1234', 'ID_collection': 'MyCollectionFolder'}, 'Collections') - assert len(dbm.query_rec('Collections', {'CG_ID_sample': 'ADD1234', 'ID_collection': 'MyCollectionFolder'})) > 0 - assert len(dbm.query_rec('Collections', {'CG_ID_sample': 'XXX1234', 'ID_collection': 'MyCollectionFolder'})) == 0 - - caplog.clear() - with pytest.raises(Exception): - dbm.add_rec({'CG_ID_sample': 'ADD1234A1'}, 'An_entry_that_does_not_exist') - assert "Attempted to access table" in caplog.text + #Adds records to all databases + dbm.add_rec({'ST':'130','arcC':'6','aroE':'57','glpF':'45','gmk':'2','pta':'7','tpi':'58','yqiL':'52','clonal_complex':'CC1'}, dbm.profiles['staphylococcus_aureus']) + assert len(dbm.query_rec(dbm.profiles['staphylococcus_aureus'], {'ST':'130'})) == 1 + assert len(dbm.query_rec(dbm.profiles['staphylococcus_aureus'], {'ST':'-1'})) == 0 + + dbm.add_rec({'ST':'130','arcC':'6','aroE':'57','glpF':'45','gmk':'2','pta':'7','tpi':'58','yqiL':'52','clonal_complex':'CC1'}, dbm.novel['staphylococcus_aureus']) + assert len(dbm.query_rec(dbm.novel['staphylococcus_aureus'], {'ST':'130'})) == 1 + assert len(dbm.query_rec(dbm.novel['staphylococcus_aureus'], {'ST':'-1'})) == 0 + + dbm.add_rec({'CG_ID_sample':'ADD1234A1'}, 'Samples') + assert len(dbm.query_rec('Samples', {'CG_ID_sample':'ADD1234A1'})) > 0 + assert len(dbm.query_rec('Samples', {'CG_ID_sample':'XXX1234A10'})) == 0 + + dbm.add_rec({'CG_ID_sample':'ADD1234A1', 'loci':'mdh', 'contig_name':'NODE_1'}, 'Seq_types') + assert len(dbm.query_rec('Seq_types', {'CG_ID_sample':'ADD1234A1', 'loci':'mdh', 'contig_name':'NODE_1'})) > 0 + assert len(dbm.query_rec('Seq_types', {'CG_ID_sample':'XXX1234A10', 'loci':'mdh', 'contig_name':'NODE_1'})) == 0 + + dbm.add_rec({'CG_ID_sample':'ADD1234A1', 'gene':'Type 1', 'instance':'Type 1', 'contig_name':'NODE_1'}, 'Resistances') + assert len(dbm.query_rec('Resistances',{'CG_ID_sample':'ADD1234A1', 'gene':'Type 1', 'instance':'Type 1', 'contig_name':'NODE_1'})) > 0 + assert len(dbm.query_rec('Resistances',{'CG_ID_sample':'XXX1234A10', 'gene':'Type 1', 'instance':'Type 1', 'contig_name':'NODE_1'})) == 0 + + dbm.add_rec({'CG_ID_sample':'ADD1234A1','gene':'Type 1', 'instance':'Type 1', 'contig_name':'NODE_1'}, 'Expacs') + assert len(dbm.query_rec('Expacs',{'CG_ID_sample':'ADD1234A1','gene':'Type 1', 'instance':'Type 1', 'contig_name':'NODE_1'})) > 0 + assert len(dbm.query_rec('Expacs',{'CG_ID_sample':'XXX1234A10','gene':'Type 1', 'instance':'Type 1', 'contig_name':'NODE_1'})) == 0 + dbm.add_rec({'CG_ID_project':'ADD1234'}, 'Projects') + assert len(dbm.query_rec('Projects',{'CG_ID_project':'ADD1234'})) > 0 + assert len(dbm.query_rec('Projects',{'CG_ID_project':'XXX1234'})) == 0 + + dbm.add_rec({'CG_ID_project':'ADD1234','version':'1'}, 'Reports') + assert len(dbm.query_rec('Reports',{'CG_ID_project':'ADD1234','version':'1'})) > 0 + assert len(dbm.query_rec('Reports',{'CG_ID_project':'XXX1234','version':'1'})) == 0 + + dbm.add_rec({'CG_ID_sample':'ADD1234', 'ID_collection':'MyCollectionFolder'}, 'Collections') + assert len(dbm.query_rec('Collections',{'CG_ID_sample':'ADD1234', 'ID_collection':'MyCollectionFolder'})) > 0 + assert len(dbm.query_rec('Collections',{'CG_ID_sample':'XXX1234', 'ID_collection':'MyCollectionFolder'})) == 0 + + caplog.clear() + with pytest.raises(Exception): + dbm.add_rec({'CG_ID_sample': 'ADD1234A1'}, 'An_entry_that_does_not_exist') + assert "Attempted to access table" in caplog.text @patch('sys.exit') def test_upd_rec(sysexit, caplog, dbm): - 
+  caplog.clear()
+  with pytest.raises(Exception):
+    dbm.add_rec({'CG_ID_sample': 'ADD1234A1'}, 'An_entry_that_does_not_exist')
+  assert "Attempted to access table" in caplog.text

 @patch('sys.exit')
 def test_upd_rec(sysexit, caplog, dbm):
-    dbm.add_rec({'CG_ID_sample': 'UPD1234A1'}, 'Samples')
-    assert len(dbm.query_rec('Samples', {'CG_ID_sample': 'UPD1234A1'})) == 1
-    assert len(dbm.query_rec('Samples', {'CG_ID_sample': 'UPD1234A2'})) == 0
-
-    dbm.upd_rec({'CG_ID_sample': 'UPD1234A1'}, 'Samples', {'CG_ID_sample': 'UPD1234A2'})
-    assert len(dbm.query_rec('Samples', {'CG_ID_sample': 'UPD1234A1'})) == 0
-    assert len(dbm.query_rec('Samples', {'CG_ID_sample': 'UPD1234A2'})) == 1
+  dbm.add_rec({'CG_ID_sample':'UPD1234A1'}, 'Samples')
+  assert len(dbm.query_rec('Samples', {'CG_ID_sample':'UPD1234A1'})) == 1
+  assert len(dbm.query_rec('Samples', {'CG_ID_sample':'UPD1234A2'})) == 0
-    dbm.upd_rec({'CG_ID_sample': 'UPD1234A2'}, 'Samples', {'CG_ID_sample': 'UPD1234A1'})
+  dbm.upd_rec({'CG_ID_sample':'UPD1234A1'}, 'Samples', {'CG_ID_sample':'UPD1234A2'})
+  assert len(dbm.query_rec('Samples', {'CG_ID_sample':'UPD1234A1'})) == 0
+  assert len(dbm.query_rec('Samples', {'CG_ID_sample':'UPD1234A2'})) == 1
-    caplog.clear()
-    dbm.add_rec({'CG_ID_sample': 'UPD1234A1_uniq', 'Customer_ID_sample': 'cust000'}, 'Samples')
-    dbm.add_rec({'CG_ID_sample': 'UPD1234A2_uniq', 'Customer_ID_sample': 'cust000'}, 'Samples')
-    dbm.upd_rec({'Customer_ID_sample': 'cust000'}, 'Samples', {'Customer_ID_sample': 'cust030'})
-    dbm.upd_rec({'Customer_ID_sample': 'cust000'}, 'Samples', {'Customer_ID_sample': 'cust030'})
-    assert "More than 1 record found" in caplog.text
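+  #Rename back, then verify that an update matching more than one record only warns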
+  dbm.upd_rec({'CG_ID_sample': 'UPD1234A2'}, 'Samples', {'CG_ID_sample': 'UPD1234A1'})
+  caplog.clear()
+  dbm.add_rec({'CG_ID_sample': 'UPD1234A1_uniq', 'Customer_ID_sample': 'cust000'}, 'Samples')
+  dbm.add_rec({'CG_ID_sample': 'UPD1234A2_uniq', 'Customer_ID_sample': 'cust000'}, 'Samples')
+  dbm.upd_rec({'Customer_ID_sample': 'cust000'}, 'Samples', {'Customer_ID_sample': 'cust030'})
+  dbm.upd_rec({'Customer_ID_sample': 'cust000'}, 'Samples', {'Customer_ID_sample': 'cust030'})
+  assert "More than 1 record found" in caplog.text

-@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating")
 def test_allele_ranker(dbm):
-    dbm.add_rec({'CG_ID_sample': 'MLS1234A1', 'CG_ID_project': 'MLS1234', 'organism': 'staphylococcus_aureus'},
-                'Samples')
-    assert dbm.alleles2st('MLS1234A1') == 130
-    best_alleles = {'arcC': {'contig_name': 'NODE_1', 'allele': 6}, 'aroE': {'contig_name': 'NODE_1', 'allele': 57},
-                    'glpF': {'contig_name': 'NODE_1', 'allele': 45}, 'gmk': {'contig_name': 'NODE_1', 'allele': 2},
-                    'pta': {'contig_name': 'NODE_1', 'allele': 7}, 'tpi': {'contig_name': 'NODE_1', 'allele': 58},
-                    'yqiL': {'contig_name': 'NODE_1', 'allele': 52}}
-    assert dbm.bestAlleles('MLS1234A1') == best_alleles
-
-    for entry in unpack_db_json('sampleinfo_mlst.json'):
-        entry['allele'] = 0
-        entry['CG_ID_sample'] = 'MLS1234A2'
-        dbm.add_rec(entry, 'Seq_types')
-    dbm.alleles2st('MLS1234A2') == -1
-
-
-@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating")
-def test_get_and_set_report(dbm):
-    dbm.add_rec({'CG_ID_sample': 'ADD1234A1', 'method_sequencing': '1000:1'}, 'Samples')
-    dbm.add_rec({'CG_ID_project': 'ADD1234', 'version': '1'}, 'Reports')
-    assert dbm.get_report('ADD1234').version == 1
+  dbm.add_rec({'CG_ID_sample':'MLS1234A1', 'CG_ID_project':'MLS1234','organism':'staphylococcus_aureus'}, 'Samples')
+  assert dbm.alleles2st('MLS1234A1') == 130
+  best_alleles = {'arcC': {'contig_name': 'NODE_1', 'allele': 6}, 'aroE': {'contig_name': 'NODE_1', 'allele': 57}, 'glpF': {'contig_name': 'NODE_1', 'allele': 45}, 'gmk': {'contig_name': 'NODE_1', 'allele': 2}, 'pta': {'contig_name': 'NODE_1', 'allele': 7}, 'tpi': {'contig_name': 'NODE_1', 'allele': 58}, 'yqiL': {'contig_name': 'NODE_1', 'allele': 52}}
+  assert dbm.bestAlleles('MLS1234A1') == best_alleles
+
+  for entry in unpack_db_json('sampleinfo_mlst.json'):
+    entry['allele'] = 0
+    entry['CG_ID_sample'] = 'MLS1234A2'
+    dbm.add_rec(entry, 'Seq_types')
+  assert dbm.alleles2st('MLS1234A2') == -1

-    dbm.upd_rec({'CG_ID_sample': 'ADD1234A1', 'method_sequencing': '1000:1'}, 'Samples',
-                {'CG_ID_sample': 'ADD1234A1', 'method_sequencing': '1000:2'})
-    dbm.set_report('ADD1234')
-    assert dbm.get_report('ADD1234').version != 1
+def test_get_and_set_report(dbm):
+  dbm.add_rec({'CG_ID_sample':'ADD1234A1', 'method_sequencing':'1000:1'}, 'Samples')
+  dbm.add_rec({'CG_ID_project':'ADD1234','version':'1'}, 'Reports')
+  assert dbm.get_report('ADD1234').version == 1
+  dbm.upd_rec({'CG_ID_sample':'ADD1234A1', 'method_sequencing':'1000:1'}, 'Samples', {'CG_ID_sample':'ADD1234A1', 'method_sequencing':'1000:2'})
+  dbm.set_report('ADD1234')
+  assert dbm.get_report('ADD1234').version != 1

 @patch('sys.exit')
 def test_purge_rec(sysexit, caplog, dbm):
-    dbm.add_rec({'CG_ID_sample': 'UPD1234A1'}, 'Samples')
-    dbm.purge_rec('UPD1234A1', 'Collections')
-
-    caplog.clear()
-    dbm.purge_rec('UPD1234A1', 'Not_Samples_nor_Collections')
-    assert "Incorrect type" in caplog.text
+  dbm.add_rec({'CG_ID_sample':'UPD1234A1'}, 'Samples')
+  dbm.purge_rec('UPD1234A1', 'Collections')
+  caplog.clear()
+  dbm.purge_rec('UPD1234A1', 'Not_Samples_nor_Collections')
+  assert "Incorrect type" in caplog.text

 def test_top_index(dbm):
-    dbm.add_rec({'CG_ID_sample': 'Uniq_ID_123', 'total_reads': 100}, 'Samples')
-    dbm.add_rec({'CG_ID_sample': 'Uniq_ID_321', 'total_reads': 100}, 'Samples')
-    ti_returned = dbm.top_index('Samples', {'total_reads': '100'}, 'total_reads')
+  dbm.add_rec({'CG_ID_sample': 'Uniq_ID_123', 'total_reads':100}, 'Samples')
+  dbm.add_rec({'CG_ID_sample': 'Uniq_ID_321', 'total_reads':100}, 'Samples')
+  ti_returned = dbm.top_index('Samples', {'total_reads':'100'}, 'total_reads')
diff --git a/tests/test_jobcreator.py b/tests/test_jobcreator.py
index c3ad7c51..f401395f 100644
--- a/tests/test_jobcreator.py
+++ b/tests/test_jobcreator.py
@@ -16,96 +16,80 @@
 from microSALT import preset_config, logger
 from microSALT.cli import root

-
 @pytest.fixture
 def testdata():
-    testdata = os.path.abspath(
-        os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/sampleinfo_samples.json'))
-    #Check if release install exists
-    for entry in os.listdir(get_python_lib()):
-        if 'microSALT-' in entry:
-            testdata = os.path.abspath(
-                os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/sampleinfo_samples.json'))
-    with open(testdata) as json_file:
-        data = json.load(json_file)
-    return data
-
+  testdata = os.path.abspath(os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/sampleinfo_samples.json'))
+  #Check if release install exists
+  for entry in os.listdir(get_python_lib()):
+    if 'microSALT-' in entry:
+      testdata = os.path.abspath(os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/sampleinfo_samples.json'))
+  with open(testdata) as json_file:
+    data = json.load(json_file)
+  return data

 def fake_search(int):
-    return "fake"
-
-
+  return "fake"

 @patch('os.listdir')
 @patch('os.stat')
 @patch('gzip.open')
-@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating")
 def test_verify_fastq(gopen, stat, listdir, testdata):
-    listdir.return_value = ["ACC6438A3_HVMHWDSXX_L1_1.fastq.gz", "ACC6438A3_HVMHWDSXX_L1_2.fastq.gz",
-                            "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz", "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz"]
-    stata = mock.MagicMock()
-    stata.st_size = 2000
-    stat.return_value = stata
-
-    jc = Job_Creator(run_settings={'input': '/tmp/'}, config=preset_config, log=logger, sampleinfo=testdata)
-    t = jc.verify_fastq()
-    assert len(t) > 0
-
+  listdir.return_value = ["ACC6438A3_HVMHWDSXX_L1_1.fastq.gz", "ACC6438A3_HVMHWDSXX_L1_2.fastq.gz", "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz", "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz"]
+  stata = mock.MagicMock()
+  stata.st_size = 2000
+  stat.return_value = stata
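+  #With os.listdir, os.stat and gzip.open mocked, verify_fastq should accept the sample's fastq files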
"ACC6438A3_HVMHWDSXX_L1_2.fastq.gz", - "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz", "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz"] - stata = mock.MagicMock() - stata.st_size = 2000 - stat.return_value = stata - - jc = Job_Creator(run_settings={'input': '/tmp/'}, config=preset_config, log=logger, sampleinfo=testdata) - t = jc.verify_fastq() - assert len(t) > 0 - + listdir.return_value = ["ACC6438A3_HVMHWDSXX_L1_1.fastq.gz", "ACC6438A3_HVMHWDSXX_L1_2.fastq.gz", "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz", "ACC6438A3_HVMHWDSXX_L2_2.fastq.gz"] + stata = mock.MagicMock() + stata.st_size = 2000 + stat.return_value = stata + jc = Job_Creator(run_settings={'input':'/tmp/'}, config=preset_config, log=logger,sampleinfo=testdata) + t = jc.verify_fastq() + assert len(t) > 0 @patch('re.search') @patch('microSALT.utils.job_creator.glob.glob') -@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating") def test_blast_subset(glob_search, research, testdata): - jc = Job_Creator(run_settings={'input': '/tmp/'}, config=preset_config, log=logger, sampleinfo=testdata) - researcha = mock.MagicMock() - researcha.group = fake_search - research.return_value = researcha - glob_search.return_value = ["/a/a/a", "/a/a/b", "/a/a/c"] - - jc.blast_subset('mlst', '/tmp/*') - jc.blast_subset('other', '/tmp/*') - outfile = open(jc.get_sbatch(), 'r') - count = 0 - for x in outfile.readlines(): - if "blastn -db" in x: - count = count + 1 - assert count > 0 + jc = Job_Creator(run_settings={'input':'/tmp/'}, config=preset_config, log=logger,sampleinfo=testdata) + researcha = mock.MagicMock() + researcha.group = fake_search + research.return_value = researcha + glob_search.return_value = ["/a/a/a", "/a/a/b","/a/a/c"] + + jc.blast_subset('mlst', '/tmp/*') + jc.blast_subset('other', '/tmp/*') + outfile = open(jc.get_sbatch(), 'r') + count = 0 + for x in outfile.readlines(): + if "blastn -db" in x: + count = count + 1 + assert count > 0 - -@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating") -def test_create_snpsection(subproc, testdata): - #Sets up subprocess mocking - process_mock = mock.Mock() - attrs = {'communicate.return_value': ('output 123456789', 'error')} - process_mock.configure_mock(**attrs) - subproc.return_value = process_mock - - testdata = [testdata[0]] - jc = Job_Creator(run_settings={'input': ['AAA1234A1', 'AAA1234A2']}, config=preset_config, log=logger, - sampleinfo=testdata) - jc.snp_job() - outfile = open(jc.get_sbatch(), 'r') - count = 0 - for x in outfile.readlines(): - if "# SNP pair-wise distance" in x: - count = count + 1 - assert count > 0 - - -@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating") @patch('subprocess.Popen') -def test_project_job(subproc, testdata): - #Sets up subprocess mocking - process_mock = mock.Mock() - attrs = {'communicate.return_value': ('output 123456789', 'error')} - process_mock.configure_mock(**attrs) - subproc.return_value = process_mock +def test_create_snpsection(subproc,testdata): + #Sets up subprocess mocking + process_mock = mock.Mock() + attrs = {'communicate.return_value': ('output 123456789', 'error')} + process_mock.configure_mock(**attrs) + subproc.return_value = process_mock + + testdata = [testdata[0]] + jc = Job_Creator(run_settings={'input':['AAA1234A1','AAA1234A2']}, config=preset_config, log=logger,sampleinfo=testdata) + jc.snp_job() + outfile = open(jc.get_sbatch(), 'r') + count = 0 + for x in outfile.readlines(): + if "# SNP pair-wise distance" in x: + count = count + 1 + assert count > 0 - 
+  outfile = open(jc.get_sbatch(), 'r')
+  count = 0
+  for x in outfile.readlines():
+    if "# SNP pair-wise distance" in x:
+      count = count + 1
+  assert count > 0

-    jc = Job_Creator(config=preset_config, log=logger, sampleinfo=testdata,
-                     run_settings={'pool': ["AAA1234A1", "AAA1234A2"], 'input': '/tmp/AAA1234'})
-    jc.project_job()
+@patch('subprocess.Popen')
+def test_project_job(subproc,testdata):
+  #Sets up subprocess mocking
+  process_mock = mock.Mock()
+  attrs = {'communicate.return_value': ('output 123456789', 'error')}
+  process_mock.configure_mock(**attrs)
+  subproc.return_value = process_mock
+  jc = Job_Creator(config=preset_config, log=logger, sampleinfo=testdata, run_settings={'pool':["AAA1234A1","AAA1234A2"], 'input':'/tmp/AAA1234'})
+  jc.project_job()

 def test_create_collection():
-    pass
+  pass
+
diff --git a/tests/test_scraper.py b/tests/test_scraper.py
index 8046bce3..82689df1 100644
--- a/tests/test_scraper.py
+++ b/tests/test_scraper.py
@@ -14,63 +14,51 @@
 from microSALT.utils.scraper import Scraper
 from microSALT.utils.referencer import Referencer

-
 @pytest.fixture
 def testdata_prefix():
-    test_path = os.path.abspath(os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/'))
-    #Check if release install exists
-    for entry in os.listdir(get_python_lib()):
-        if 'microSALT-' in entry:
-            test_path = os.path.abspath(os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/'))
-    return test_path
-
+  test_path = os.path.abspath(os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/'))
+  #Check if release install exists
+  for entry in os.listdir(get_python_lib()):
+    if 'microSALT-' in entry:
+      test_path = os.path.abspath(os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/'))
+  return test_path

 @pytest.fixture
 def testdata():
-    testdata = os.path.abspath(
-        os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/sampleinfo_samples.json'))
-    #Check if release install exists
-    for entry in os.listdir(get_python_lib()):
-        if 'microSALT-' in entry:
-            testdata = os.path.abspath(
-                os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/sampleinfo_samples.json'))
-    with open(testdata) as json_file:
-        data = json.load(json_file)
-    return data
-
+  testdata = os.path.abspath(os.path.join(pathlib.Path(__file__).parent.parent, 'tests/testdata/sampleinfo_samples.json'))
+  #Check if release install exists
+  for entry in os.listdir(get_python_lib()):
+    if 'microSALT-' in entry:
+      testdata = os.path.abspath(os.path.join(os.path.expandvars('$CONDA_PREFIX'), 'testdata/sampleinfo_samples.json'))
+  with open(testdata) as json_file:
+    data = json.load(json_file)
+  return data

 @pytest.fixture
 def scraper(testdata):
-    scrape_obj = Scraper(config=preset_config, log=logger, sampleinfo=testdata[0])
-    return scrape_obj
-
+  scrape_obj = Scraper(config=preset_config, log=logger,sampleinfo=testdata[0])
+  return scrape_obj

 @pytest.fixture
 def init_references(testdata):
-    ref_obj = Referencer(config=preset_config, log=logger, sampleinfo=testdata)
-    ref_obj.identify_new(testdata[0].get('CG_ID_project'), project=True)
-    ref_obj.update_refs()
+  ref_obj = Referencer(config=preset_config, log=logger, sampleinfo=testdata)
+  ref_obj.identify_new(testdata[0].get('CG_ID_project'),project=True)
+  ref_obj.update_refs()

-
-@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating")
 def test_quast_scraping(scraper, testdata_prefix, caplog):
-    scraper.scrape_quast(filename="{}/quast_results.tsv".format(testdata_prefix))
-
+  scraper.scrape_quast(filename="{}/quast_results.tsv".format(testdata_prefix))

-@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating")
 def test_blast_scraping(scraper, testdata_prefix, caplog):
-    caplog.set_level(logging.DEBUG)
-    scraper.scrape_blast(type='seq_type', file_list=["{}/blast_single_loci.txt".format(testdata_prefix)])
-    assert "candidate" in caplog.text
-
-    caplog.clear()
-    hits = scraper.scrape_blast(type='resistance', file_list=["{}/blast_single_resistance.txt".format(testdata_prefix)])
-    genes = [h["gene"] for h in hits]
+  caplog.set_level(logging.DEBUG)
+  scraper.scrape_blast(type='seq_type',file_list=["{}/blast_single_loci.txt".format(testdata_prefix)])
+  assert "candidate" in caplog.text
-    assert "blaOXA-48" in genes
-    assert "blaVIM-4" in genes
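+  #Resistance scraping should recover both expected resistance genes from the test report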
+  caplog.clear()
+  hits = scraper.scrape_blast(type='resistance',file_list=["{}/blast_single_resistance.txt".format(testdata_prefix)])
+  genes = [h["gene"] for h in hits]
+  assert "blaOXA-48" in genes
+  assert "blaVIM-4" in genes

-@pytest.mark.xfail(reason="Can no longer fetch from databases without authenticating")
 def test_alignment_scraping(scraper, init_references, testdata_prefix):
-    scraper.scrape_alignment(file_list=glob.glob("{}/*.stats.*".format(testdata_prefix)))
+  scraper.scrape_alignment(file_list=glob.glob("{}/*.stats.*".format(testdata_prefix)))