22 changes: 21 additions & 1 deletion src/python/tools/preprocess/converters/readers/spark_readers.py
@@ -1,3 +1,4 @@
+import os
from pathlib import Path

from pyspark.sql import SparkSession
@@ -9,7 +10,7 @@
class SparkDelimitedFileReader(Reader):
def __init__(
self,
-        spark: SparkSession,
+        spark: SparkSession = SparkSession.builder.appName("marius_spark").getOrCreate(),
train_edges: Path,
valid_edges: Path = None,
test_edges: Path = None,
@@ -39,6 +40,25 @@ def __init__(

self.spark = spark

+        if str(train_edges).startswith("s3a://"):
+            if (
+                "AWS_ACCESS_KEY_ID" not in os.environ
+                or "AWS_SECRET_ACCESS_KEY" not in os.environ
+                or "S3_BUCKET" not in os.environ
+            ):
+                print(
+                    "Edge path is an s3 path, but the required environment variables are not set. {}, {} and {} need to be set".format(
+                        "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_BUCKET"
+                    )
+                )
+                exit()
+            self.spark._jsc.hadoopConfiguration().set(
+                "fs.s3a.aws.credentials.provider", "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider"
+            )
+            self.spark._jsc.hadoopConfiguration().set("fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID"))
+            self.spark._jsc.hadoopConfiguration().set("fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY"))
+            self.spark._jsc.hadoopConfiguration().set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")

self.train_edges = train_edges
self.valid_edges = valid_edges
self.test_edges = test_edges
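For anyone exercising the new S3 branch, here is a minimal usage sketch of what the reader now expects. It is a sketch under assumptions, not code from this PR: it assumes the hadoop-aws package is on the Spark classpath, that the import path matches your install, and that only the constructor arguments visible in this excerpt are required; the bucket, key, and credential values are placeholders.

import os

from pyspark.sql import SparkSession

# Hypothetical import path; adjust to however the package is installed.
from tools.preprocess.converters.readers.spark_readers import SparkDelimitedFileReader

# All three variables are checked before the reader applies the fs.s3a settings.
os.environ["AWS_ACCESS_KEY_ID"] = "<access-key>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<secret-key>"
os.environ["S3_BUCKET"] = "<bucket-name>"

spark = SparkSession.builder.appName("marius_spark").getOrCreate()

# The s3a path is passed as a plain string: pathlib.Path would collapse the
# "//" in "s3a://" and defeat the reader's startswith("s3a://") check.
reader = SparkDelimitedFileReader(spark=spark, train_edges="s3a://<bucket-name>/edges/train.csv")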
3 changes: 2 additions & 1 deletion src/python/tools/preprocess/converters/spark_converter.py
@@ -110,6 +110,7 @@ def __init__(
spark_executor_memory: str = "4g",
):
self.output_dir = output_dir
+        self.use_s3 = str(train_edges).startswith("s3a://")

self.spark = (
SparkSession.builder.appName(SPARK_APP_NAME)
@@ -135,7 +136,7 @@
else:
self.partitioner = None

-        self.writer = SparkWriter(self.spark, self.output_dir, partitioned_evaluation)
+        self.writer = SparkWriter(self.spark, self.output_dir, self.use_s3, partitioned_evaluation)

self.train_split = None
self.valid_split = None
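The updated call site implies a matching change to SparkWriter.__init__ elsewhere in this PR, with use_s3 as the third positional parameter, ahead of partitioned_evaluation. A sketch of the signature this call assumes follows; the body is a guess at the flag's role, not the PR's actual code.

class SparkWriter:
    def __init__(self, spark, output_dir, use_s3, partitioned_evaluation):
        self.spark = spark
        self.output_dir = output_dir
        # Assumption: when True, the writer targets the s3a:// output location
        # instead of the local filesystem.
        self.use_s3 = use_s3
        self.partitioned_evaluation = partitioned_evaluation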