From 774b03fca90dd48d19ee12ad15c174c87045a11e Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Mon, 30 Sep 2024 18:13:38 -0700 Subject: [PATCH] Refactor gnn training (#256) * minor upds * refactor: training pipeline * feat: find gcs image path * feat: feature generation in trainer * feat: validation sets in training * bug: hgraph forward passes with missing edge types * refactor: hgnn trainer * feat: functional training pipeline * bug: set validation data * refactor: combined train engine and pipeline * refactor: infernce pipeline, evaluation * moved files * upds --------- Co-authored-by: anna-grim --- src/deep_neurographs/inference.py | 4 ++-- src/deep_neurographs/utils/util.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index de94666..630d7a2 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -4,14 +4,14 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Routines for running inference with machine models that classifies edge -proposals. +Routines for running inference with a machine model that classifies edge proposals. """ from datetime import datetime from time import time from torch.nn.functional import sigmoid +from torch.utils.data import DataLoader from tqdm import tqdm import networkx as nx diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index e8c9a68..023b6a6 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -372,7 +372,7 @@ def write_txt(path, contents): f.close() -def write_to_s3(local_path, bucket_name, s3_key): +def write_to_s3(local_path, bucket_name, prefix): """ Writes a single file on local machine to an s3 bucket. @@ -382,7 +382,7 @@ def write_to_s3(local_path, bucket_name, s3_key): Path to file to be written to s3. bucket_name : str Name of s3 bucket. - s3_key : str + prefix : str Path within s3 bucket. Returns @@ -391,7 +391,7 @@ def write_to_s3(local_path, bucket_name, s3_key): """ s3 = boto3.client('s3') - s3.upload_file(local_path, bucket_name, s3_key) + s3.upload_file(local_path, bucket_name, prefix) # --- math utils ---