Skip to content

Commit

Permalink
Added az data set
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Probst committed Sep 1, 2023
1 parent adc48f8 commit 2a1874e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
18 changes: 11 additions & 7 deletions scripts/encoding/encode_az_reactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from drfp import DrfpEncoder
from tqdm import tqdm


# %%
Expand Down Expand Up @@ -46,12 +47,15 @@ def get_az_rxns(fold_idx: int = 0):
r2_scores = []
rmse_scores = []

for fold_idx in range(10):
output = []

for fold_idx in tqdm(range(10)):
root_path = Path(__file__).resolve().parent
az_path = Path(root_path, "../../data/az")

train, valid, test = get_az_rxns(fold_idx)

output_splits = {}
for data, split in [(train, "train"), (valid, "valid"), (test, "test")]:
X, mapping = DrfpEncoder.encode(
data.smiles.to_numpy(),
Expand All @@ -68,11 +72,11 @@ def get_az_rxns(fold_idx: int = 0):

y = data["yield"].to_numpy()

fingerprints_file_name = Path(az_path, f"{fold_idx}-{split}-2048-3-true.pkl")
map_file_name = Path(az_path, f"{fold_idx}-{split}-2048-3-true.map.pkl")
output_splits[split] = {"X": X, "y": y, "mapping": mapping}

output.append(output_splits)

with open(map_file_name, "wb+") as f:
pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
out_file_name = Path(az_path, f"az-2048-3-true.pkl")

with open(fingerprints_file_name, "wb+") as f:
pickle.dump((X, y), f, protocol=pickle.HIGHEST_PROTOCOL)
with open(out_file_name, "wb+") as f:
pickle.dump(output, f)
79 changes: 79 additions & 0 deletions scripts/training/yield_prediction_az.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pickle
from pathlib import Path
from typing import Tuple
from statistics import stdev
import numpy as np
from xgboost import XGBRegressor
from sklearn.metrics import r2_score, mean_absolute_error


def save_results(
set_name: str,
split_id: str,
file_name: str,
ground_truth: np.ndarray,
prediction: np.ndarray,
) -> None:
with open(f"{set_name}_{split_id}_{file_name}.csv", "w+") as f:
for gt, pred in zip(ground_truth, prediction):
f.write(f"{set_name},{split_id},{file_name},{gt},{pred}\n")


def predict_az():
root_path = Path(__file__).resolve().parent
az_file_path = Path(root_path, "../../data/az/az-2048-3-true.pkl")

data = pickle.load(open(az_file_path, "rb"))

r2s = []
maes = []

for i, split in enumerate(data):
X_train, y_train, X_valid, y_valid, X_test, y_test = (
split["train"]["X"],
split["train"]["y"],
split["valid"]["X"],
split["valid"]["y"],
split["test"]["X"],
split["test"]["y"],
)

# Vanilla hyperparams
model = XGBRegressor(
n_estimators=999999,
learning_rate=0.01,
max_depth=12,
min_child_weight=6,
colsample_bytree=0.6,
subsample=0.8,
random_state=42,
)

model.fit(
X_train,
y_train,
eval_set=[(X_valid, y_valid)],
early_stopping_rounds=10,
verbose=False,
)

y_pred = model.predict(X_test, ntree_limit=model.best_ntree_limit)
y_pred[y_pred < 0.0] = 0.0

# save_results("az", split, sample_file, y_test, y_pred)
r_squared = r2_score(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
print(f"Test {i + 1}", r_squared, mae)
r2s.append(r_squared)
maes.append(mae)

print("Tests R2:", sum(r2s) / len(r2s), stdev(r2s))
print("Tests MAE:", sum(maes) / len(maes), stdev(maes))


def main():
predict_az()


if __name__ == "__main__":
main()

0 comments on commit 2a1874e

Please sign in to comment.