Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions openfold3/core/data/primitives/sequence/msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def concatenate(
f"deletion {d1.shape[1]} vs {d2.shape[1]})."
)
# Preserve metadata if both are list/ndarray
if isinstance(self.metadata, (list, np.ndarray)) and isinstance(
msa_array.metadata, (list, np.ndarray)
if isinstance(self.metadata, list | np.ndarray) and isinstance(
msa_array.metadata, list | np.ndarray
):
metadata_concat_fn = partial(np.concatenate, axis=0)
else:
Expand Down Expand Up @@ -223,7 +223,7 @@ def multi_concatenate(
# metadata: can only stitch if all are array-like
if all(isinstance(md, pd.DataFrame) for md in metas):
meta_concat = pd.DataFrame() # pd.concat(metas, ignore_index=True)
elif all(isinstance(md, (list, np.ndarray)) for md in metas):
elif all(isinstance(md, list | np.ndarray) for md in metas):
meta_concat = np.concatenate([np.asarray(md) for md in metas], axis=0)
else:
meta_concat = pd.DataFrame()
Expand Down
2 changes: 1 addition & 1 deletion openfold3/core/data/primitives/structure/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ class AtomArrayView:
"""Container to access underlying arrays holding AtomArray attributes."""

def __init__(self, atom_array: AtomArray, indices: np.ndarray | slice):
if not isinstance(indices, (np.ndarray, slice)):
if not isinstance(indices, np.ndarray | slice):
raise ValueError(
"The indices argument must be a NumPy array or a slice object."
)
Expand Down
14 changes: 10 additions & 4 deletions openfold3/core/data/tools/colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,17 +711,23 @@ def query_format_main(self):

# Read template alignments if the file exists and has content
template_alignments_file = self.output_directory / "raw/main/pdb70.m8"
if template_alignments_file.exists() and template_alignments_file.stat().st_size > 0:
if (
template_alignments_file.exists()
and template_alignments_file.stat().st_size > 0
):
template_alignments = pd.read_csv(
template_alignments_file, sep="\t", header=None
)
m_with_templates = set(template_alignments[0])
else:
# pdb70.m8 downloaded by Colabfold returned empty - No template alignments available
# pdb70.m8 downloaded by Colabfold returned empty
# --> No template alignments available
# Create empty DataFrame with expected column structure (at least column 0)
# to match the structure when file is read with header=None.
logger.warning(f"Colabfold returned no templates. \
Proceeding without template alignments for this batch.")
logger.warning(
"Colabfold returned no templates. \
Proceeding without template alignments for this batch."
)
template_alignments = pd.DataFrame()
m_with_templates = set()

Expand Down
2 changes: 1 addition & 1 deletion openfold3/core/utils/chunk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _compare_arg_caches(self, ac1, ac2):
consistent = True
for a1, a2 in zip(ac1, ac2, strict=True):
assert type(a1) is type(a2)
if isinstance(a1, (list, tuple)):
if isinstance(a1, list | tuple):
consistent &= self._compare_arg_caches(a1, a2)
elif isinstance(a1, dict):
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
Expand Down
2 changes: 1 addition & 1 deletion openfold3/core/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def dict_multimap(fn, dicts):
new_dict[k] = [
dict_multimap(fn, [x[idx] for x in all_v]) for idx in range(len(v))
]
elif isinstance(v, (AtomArray, str)):
elif isinstance(v, AtomArray | str):
new_dict[k] = all_v
else:
new_dict[k] = fn(all_v)
Expand Down
18 changes: 17 additions & 1 deletion openfold3/hacks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
from pathlib import Path

Expand All @@ -8,7 +9,22 @@ def prep_deepspeed():


def prep_cutlass():
# apparently need to set the headers for cutlass
cutlass_lib_is_installed = importlib.util.find_spec("cutlass_library") is not None
cutlass_path = Path(os.environ.get("CUTLASS_PATH", "placeholder"))

# This workaround is used when the conda environment is created with the
# environments/production.yml + installation of cutlass repo
if not cutlass_lib_is_installed:
if not cutlass_path.exists():
raise OSError(
"CUTLASS_PATH environment variable is not set to a valid path, "
"and cutlass_library is not installed. Please install nvidia-cutlass"
"via pip or set CUTLASS_PATH to the root of a local cutlass clone."
)

return

# otherwise, apparently need to set the headers for cutlass
import cutlass_library

headers_dir = Path(cutlass_library.__file__).parent / "source/include"
Expand Down
32 changes: 15 additions & 17 deletions openfold3/tests/test_colabfold_msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _make_dummy_template_file(path: Path):

@staticmethod
def _make_empty_template_file(path: Path):
"""Create an empty pdb70.m8 file to simulate ColabFold returning empty template file."""
"""Create an empty pdb70.m8 file to simulate ColabFold empty templates."""
raw_main_dir = path / "raw" / "main"
raw_main_dir.mkdir(parents=True, exist_ok=True)
# Create an empty file (0 bytes)
Expand Down Expand Up @@ -333,10 +333,9 @@ def test_empty_m8_file_handling(
"""Test that empty pdb70.m8 file is handled gracefully without crashing."""
test_sequence = "TESTSEQUENCE"
query = self._construct_monomer_query(test_sequence)

# Create an empty pdb70.m8 file (0 bytes) to simulate ColabFold returning empty template file

self._make_empty_template_file(tmp_path)

mapper = collect_colabfold_msa_data(query)
runner = ColabFoldQueryRunner(
colabfold_mapper=mapper,
Expand All @@ -345,19 +344,19 @@ def test_empty_m8_file_handling(
user_agent="test-agent",
host_url="https://dummy.url",
)

# Should not raise EmptyDataError or any other exception
runner.query_format_main()

# Verify MSA processing still works
expected_unpaired_dir = tmp_path / "main"
assert expected_unpaired_dir.exists(), "Expected main MSA directory to exist"

expected_file = f"{get_sequence_hash(test_sequence)}.npz"
assert (expected_unpaired_dir / expected_file).exists(), (
f"Expected MSA file {expected_file} to exist"
)

# Verify no template files are created (since m8 file is empty)
template_alignments_dir = tmp_path / "template"
if template_alignments_dir.exists():
Expand All @@ -366,7 +365,7 @@ def test_empty_m8_file_handling(
assert len(template_files) == 0, (
"Expected no template files to be created when m8 file is empty"
)

# Test preprocess_colabfold_msas with empty template file
msa_compute_settings = MsaComputationSettings(
msa_file_format="npz",
Expand All @@ -376,25 +375,24 @@ def test_empty_m8_file_handling(
msa_output_directory=tmp_path,
cleanup_msa_dir=False,
)

# Call preprocess_colabfold_msas - should not raise any exception
processed_query_set = preprocess_colabfold_msas(
inference_query_set=query,
compute_settings=msa_compute_settings
inference_query_set=query, compute_settings=msa_compute_settings
)

# Verify that template fields are None/empty for all chains
for query_name, query_obj in processed_query_set.queries.items():
for chain in query_obj.chains:
assert chain.template_alignment_file_path is None, (
f"Expected template_alignment_file_path to be None for chain "
f"{chain.chain_ids} of query {query_name} when template file is empty, "
f"but got {chain.template_alignment_file_path}"
f"{chain.chain_ids} of query {query_name} when template file "
f"is empty, but got {chain.template_alignment_file_path}"
)
assert chain.template_entry_chain_ids is None, (
f"Expected template_entry_chain_ids to be None for chain "
f"{chain.chain_ids} of query {query_name} when template file is empty, "
f"but got {chain.template_entry_chain_ids}"
f"{chain.chain_ids} of query {query_name} when template file"
f"is empty, but got {chain.template_entry_chain_ids}"
)


Expand Down
Loading