Skip to content

Commit

Permalink
Split up requirements for rxn and graph
Browse files Browse the repository at this point in the history
featurisations/kernels.
  • Loading branch information
leojklarner committed Nov 3, 2023
1 parent a6da1ab commit 8f7aacf
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATES/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ If applicable, add screenshots to help explain your problem.

- OS: [e.g. iOS]
- Python Version
- Graphein Version [e.g. 22] & how it was installed
- Gauche Version [e.g. 1.0.0] & how it was installed

**Additional context**
Add any other context about the problem here.
1 change: 0 additions & 1 deletion .requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ scikit-learn
rdkit
tqdm
selfies
graphein
torch
gpytorch
botorch
1 change: 1 addition & 0 deletions .requirements/graphs.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
graphein
2 changes: 2 additions & 0 deletions .requirements/rxn.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
rxnfp
drfp
13 changes: 3 additions & 10 deletions gauche/dataloader/molprop_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def validate(
"""

invalid_mols = np.array(
[
True if MolFromSmiles(x) is None else False
for x in self.features
]
[True if MolFromSmiles(x) is None else False for x in self.features]
)
if np.any(invalid_mols):
print(
Expand Down Expand Up @@ -88,9 +85,7 @@ def validate(
for smiles in self.features
]

def featurize(
self, representation: Union[str, Callable], **kwargs
) -> None:
def featurize(self, representation: Union[str, Callable], **kwargs) -> None:
"""Transforms SMILES into the specified molecular representation.
:param representation: the desired molecular representation, one of [ecfp_fingerprints, fragments, ecfp_fragprints, molecular_graphs, bag_of_smiles, bag_of_selfies, mqn] or a callable that takes a list of SMILES strings as input and returns the desired featurization.
Expand Down Expand Up @@ -151,9 +146,7 @@ def featurize(
elif representation == "bag_of_selfies":
from gauche.representations.strings import bag_of_characters

self.features = bag_of_characters(
self.features, selfies=True, **kwargs
)
self.features = bag_of_characters(self.features, selfies=True, **kwargs)

elif representation == "bag_of_smiles":
from gauche.representations.strings import bag_of_characters
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ force_grid_wrap = 0
include_trailing_comma = true
line_length = 79 # match our custom config above
multi_line_output = 3
src_paths = ["graphein", "test"]
src_paths = ["gauche", "test"]
float_to_top = true
use_parentheses = true

Expand Down
15 changes: 7 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

HERE = os.path.abspath(os.path.dirname(__file__))


def read(*parts):
# intentionally *not* adding an encoding option to open
return codecs.open(os.path.join(HERE, *parts), "r").read()
Expand Down Expand Up @@ -54,25 +55,23 @@ def read_requirements(*parts) -> List[str]:

INSTALL_REQUIRES: List[str] = read_requirements(".requirements/base.in")
EXTRA_REQUIRES: Dict[str, List[str]] = {
"rxn": read_requirements(".requirements/rxn.in"),
"graphs": read_requirements(".requirements/graphs.in"),
"dev": read_requirements(".requirements/dev.in"),
"docs": read_requirements(".requirements/docs.in"),
"cpu": read_requirements(".requirements/cpu.in"),
"cu116": read_requirements(".requirements/cu116.in"),
"cu117": read_requirements(".requirements/cu117.in"),
}

# Add all requires
all_requires: List[str] = []
for k, v in EXTRA_REQUIRES.items():
if k not in ["cu116", "cu117"]:
if k not in ["dev", "docs"]:
all_requires.extend(v)
EXTRA_REQUIRES["all"] = list(set(all_requires))


def find_version(*file_paths):
version_file = read(*file_paths)
version_match = re.search(
r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M
)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
Expand All @@ -86,7 +85,7 @@ def find_version(*file_paths):
setup(
name="gauche",
version=version,
description="Gaussian Process Library for Molecules, Proteins and General Chemistry in PyTorch.",
description="Gaussian Process Library for Molecules, Chemical Reactions and Proteins.",
long_description=readme,
long_description_content_type="text/markdown",
license="MIT",
Expand Down
6 changes: 4 additions & 2 deletions tests/test_dataloaders/test_molprop_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def test_benchmark_loader(dataset, representation, kwargs):
"""

dataset_root = os.path.abspath(
os.path.join("..", "..", "data", "property_prediction")
os.path.join(
os.path.abspath(__file__), "..", "..", "..", "data", "property_prediction"
)
)

# load through benchmark loading method
Expand Down Expand Up @@ -89,7 +91,7 @@ def test_invalid_data():

dataloader = MolPropLoader()
dataloader.read_csv(
path=os.path.join(os.getcwd(), "invalid_molprop_data.csv"),
path=os.path.join(os.path.abspath(__file__), "invalid_molprop_data.csv"),
smiles_column="SMILES",
label_column="labels",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataloaders/test_reaction_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_invalid_data():

dataloader = ReactionLoader()
dataloader.read_csv(
path=os.path.join(os.getcwd(), "invalid_reaction_data.csv"),
path=os.path.join(os.path.abspath(__file__), "invalid_reaction_data.csv"),
reactant_column=["ligand", "additive", "base", "aryl halide"],
label_column="yield",
)
Expand Down

0 comments on commit 8f7aacf

Please sign in to comment.