Skip to content

Commit

Permalink
Update type annotations (#9917)
Browse files Browse the repository at this point in the history
Unblocks merging to master.
  • Loading branch information
akihironitta authored Jan 4, 2025
1 parent c300f38 commit 5d1b898
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
# Skip workflow if only certain files have been changed.
- name: Get changed files
id: changed-files-specific
uses: tj-actions/changed-files@v41
uses: tj-actions/changed-files@v45
with:
files: |
benchmark/**
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/datasets/git_mol_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def process(self) -> None:
img = self.img_transform(img).unsqueeze(0)
# graph
atom_features_list = []
for atom in mol.GetAtoms(): # type: ignore
for atom in mol.GetAtoms():
atom_feature = [
safe_index(
allowable_features['possible_atomic_num_list'],
Expand Down Expand Up @@ -219,7 +219,7 @@ def process(self) -> None:

edges_list = []
edge_features_list = []
for bond in mol.GetBonds(): # type: ignore
for bond in mol.GetBonds():
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
edge_feature = [
safe_index(
Expand Down
9 changes: 6 additions & 3 deletions torch_geometric/datasets/molecule_gpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ def clean_up_description(description: str) -> str:
return first_sentence


def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]:
def extract_name(
name_raw: str,
description: str,
) -> Tuple[Optional[str], str, str]:
first_sentence = clean_up_description(description)

splitter = ' -- -- '
Expand Down Expand Up @@ -446,12 +449,12 @@ def extract_one_SDF_file(block_id: int) -> None:

x: torch.Tensor = torch.tensor([
types[atom.GetSymbol()] if atom.GetSymbol() in types else 5
for atom in m.GetAtoms() # type: ignore
for atom in m.GetAtoms()
])
x = one_hot(x, num_classes=len(types), dtype=torch.float)

rows, cols, edge_types = [], [], []
for bond in m.GetBonds(): # type: ignore
for bond in m.GetBonds():
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
edge_types += [bonds[bond.GetBondType()]] * 2
rows += [i, j]
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/utils/smiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
assert isinstance(mol, Chem.Mol)

xs: List[List[int]] = []
for atom in mol.GetAtoms(): # type: ignore
for atom in mol.GetAtoms():
row: List[int] = []
row.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
row.append(x_map['chirality'].index(str(atom.GetChiralTag())))
Expand All @@ -108,7 +108,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

edge_indices, edge_attrs = [], []
for bond in mol.GetBonds(): # type: ignore
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()

Expand Down

0 comments on commit 5d1b898

Please sign in to comment.