From 7d96114c702076b6981aadd9c3db30047b28923e Mon Sep 17 00:00:00 2001
From: Ryan
Date: Thu, 14 Nov 2024 16:18:48 -0600
Subject: [PATCH] Added check on seq length when running on GPU to catch instances where sequences are longer than 65535, which cannot run on CUDA.

---
 metapredict/backend/meta_tools.py | 31 +++++++++++++++++++++++++++++++
 metapredict/backend/predictor.py  |  8 +++++++-
 metapredict/parameters.py         |  3 +++
 3 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/metapredict/backend/meta_tools.py b/metapredict/backend/meta_tools.py
index c4ec6fe..f27bf99 100644
--- a/metapredict/backend/meta_tools.py
+++ b/metapredict/backend/meta_tools.py
@@ -457,3 +457,34 @@ def write_caid_format(input_dict, output_path, version):
             current_output.write(f'{res_and_score_index+1}\t{cur_residue}\t{write_score}\t{cur_binary}\n')
         current_output.close()
 
+
+# check max length
+def exceeds_max_length(data, max_length=65535):
+    """
+    Recursively checks if a string, any element in a list, or any value in a dictionary
+    exceeds the given maximum length.
+
+    Parameters
+    ----------
+    data : str, list, dict, or any
+        The input data to check. Can be a string, list, or dictionary.
+    max_length : int, optional
+        The maximum allowed length (default is 65535 characters).
+
+    Returns
+    -------
+    bool
+        True if any string exceeds the max_length, False otherwise. 
+    """
+    if isinstance(data, str):
+        # Check if the string exceeds the max length
+        return len(data) > max_length
+    elif isinstance(data, list):
+        # Check each element in the list
+        return any(exceeds_max_length(item, max_length) for item in data)
+    elif isinstance(data, dict):
+        # Check each value in the dictionary
+        return any(exceeds_max_length(value, max_length) for value in data.values())
+    else:
+        # If not a string, list, or dict, it's not something we need to check
+        return False
\ No newline at end of file
diff --git a/metapredict/backend/predictor.py b/metapredict/backend/predictor.py
index f56f155..d8bd9d0 100644
--- a/metapredict/backend/predictor.py
+++ b/metapredict/backend/predictor.py
@@ -19,10 +19,11 @@ import gc
 
 # local imports
+from metapredict.backend.meta_tools import exceeds_max_length
 from metapredict.backend.data_structures import DisorderObject as _DisorderObject
 from metapredict.backend import domain_definition as _domain_definition
 from metapredict.backend.network_parameters import metapredict_networks, pplddt_networks
-from metapredict.parameters import DEFAULT_NETWORK, DEFAULT_NETWORK_PLDDT
+from metapredict.parameters import DEFAULT_NETWORK, DEFAULT_NETWORK_PLDDT, MAX_CUDA_LENGTH
 from metapredict.backend import encode_sequence
 from metapredict.backend import architectures
 from metapredict.metapredict_exceptions import MetapredictError
 
@@ -560,6 +561,11 @@ def predict(inputs,
     else:
         device_string = check_device(use_device, default_device=default_to_device)
 
+    # check if using gpu, specifically cuda
+    if 'cuda' in device_string:
+        if exceeds_max_length(inputs, max_length=MAX_CUDA_LENGTH):
+            raise MetapredictError(f'The input sequence is too long to run on GPU. The max length for a sequence on a CUDA GPU is {MAX_CUDA_LENGTH}.\nPlease use CPU if you want to run sequences longer than 65535 amino acids.')
+
     # set device
     device=torch.device(device_string)
 
diff --git a/metapredict/parameters.py b/metapredict/parameters.py
index ad3704e..3cb7180 100644
--- a/metapredict/parameters.py
+++ b/metapredict/parameters.py
@@ -9,3 +9,6 @@
 # set the current default network for metapredict.
 DEFAULT_NETWORK = 'V3'
 DEFAULT_NETWORK_PLDDT = 'V2'
+
+# various constraints on predictions we've run across
+MAX_CUDA_LENGTH=65535