Skip to content

Commit

Permalink
Added check on seq length when running on GPU to catch instances wher…
Browse files Browse the repository at this point in the history
…e sequences are longer than 65535, which cannot run on CUDA.
  • Loading branch information
ryanemenecker committed Nov 14, 2024
1 parent af05ea5 commit 7d96114
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
31 changes: 31 additions & 0 deletions metapredict/backend/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,34 @@ def write_caid_format(input_dict, output_path, version):
current_output.write(f'{res_and_score_index+1}\t{cur_residue}\t{write_score}\t{cur_binary}\n')

current_output.close()

# check max length
def exceeds_max_length(data, max_length=65535):
    """
    Recursively check whether any string in ``data`` exceeds ``max_length``.

    Used to guard GPU (CUDA) prediction, which cannot handle sequences
    longer than 65535 residues.

    Parameters
    ----------
    data : str, list, tuple, set, dict, or any
        The input data to check. Containers are searched recursively.
        For dictionaries only the values are inspected (keys are assumed
        to be identifiers, matching the original behavior).
    max_length : int, optional
        The maximum allowed length (default is 65535 characters).

    Returns
    -------
    bool
        True if any string exceeds max_length, False otherwise.
    """
    if isinstance(data, str):
        # base case: a single sequence
        return len(data) > max_length
    if isinstance(data, dict):
        # only values can hold sequences; keys are not checked
        return any(exceeds_max_length(value, max_length) for value in data.values())
    if isinstance(data, (list, tuple, set)):
        # generalized beyond list so sequences held in tuples or sets
        # are also caught instead of silently passing the check
        return any(exceeds_max_length(item, max_length) for item in data)
    # anything else (ints, None, arbitrary objects) has no sequence to check
    return False
8 changes: 7 additions & 1 deletion metapredict/backend/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
import gc

# local imports
from metapredict.backend.meta_tools import exceeds_max_length
from metapredict.backend.data_structures import DisorderObject as _DisorderObject
from metapredict.backend import domain_definition as _domain_definition
from metapredict.backend.network_parameters import metapredict_networks, pplddt_networks
from metapredict.parameters import DEFAULT_NETWORK, DEFAULT_NETWORK_PLDDT
from metapredict.parameters import DEFAULT_NETWORK, DEFAULT_NETWORK_PLDDT, MAX_CUDA_LENGTH
from metapredict.backend import encode_sequence
from metapredict.backend import architectures
from metapredict.metapredict_exceptions import MetapredictError
Expand Down Expand Up @@ -560,6 +561,11 @@ def predict(inputs,
else:
device_string = check_device(use_device, default_device=default_to_device)

# check if using gpu, specifically cuda
if 'cuda' in device_string:
if exceeds_max_length(inputs, max_length=MAX_CUDA_LENGTH):
raise MetapredictError(f'The input sequence is too long to run on GPU. The max length for a sequence on a CUDA GPU is {MAX_CUDA_LENGTH}.\nPlease use CPU if you want to run sequences longer than 65535 amino acids.')

# set device
device=torch.device(device_string)

Expand Down
3 changes: 3 additions & 0 deletions metapredict/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@
# set the current default network for metapredict.
DEFAULT_NETWORK = 'V3'
DEFAULT_NETWORK_PLDDT = 'V2'

# various constraints on predictions we've run across
MAX_CUDA_LENGTH=65535

0 comments on commit 7d96114

Please sign in to comment.