Skip to content

Commit

Permalink
Refactor gnn training (#256)
Browse files Browse the repository at this point in the history
* minor upds

* refactor: training pipeline

* feat: find gcs image path

* feat: feature generation in trainer

* feat: validation sets in training

* bug: hgraph forward passes with missing edge types

* refactor: hgnn trainer

* feat: functional training pipeline

* bug: set validation data

* refactor: combined train engine and pipeline

* refactor: infernce pipeline, evaluation

* moved files

* upds

---------

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Oct 1, 2024
1 parent 4f04abc commit 774b03f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
@author: Anna Grim
@email: anna.grim@alleninstitute.org
Routines for running inference with machine models that classifies edge
proposals.
Routines for running inference with a machine model that classifies edge proposals.
"""

from datetime import datetime
from time import time
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader
from tqdm import tqdm

import networkx as nx
Expand Down
6 changes: 3 additions & 3 deletions src/deep_neurographs/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def write_txt(path, contents):
f.close()


def write_to_s3(local_path, bucket_name, s3_key):
def write_to_s3(local_path, bucket_name, prefix):
"""
Writes a single file on local machine to an s3 bucket.
Expand All @@ -382,7 +382,7 @@ def write_to_s3(local_path, bucket_name, s3_key):
Path to file to be written to s3.
bucket_name : str
Name of s3 bucket.
s3_key : str
prefix : str
Path within s3 bucket.
Returns
Expand All @@ -391,7 +391,7 @@ def write_to_s3(local_path, bucket_name, s3_key):
"""
s3 = boto3.client('s3')
s3.upload_file(local_path, bucket_name, s3_key)
s3.upload_file(local_path, bucket_name, prefix)


# --- math utils ---
Expand Down

0 comments on commit 774b03f

Please sign in to comment.