From 3d76e0bbc30f735665cc4d84659c5737fb8af08a Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Fri, 21 Feb 2025 15:56:48 +0800
Subject: [PATCH] [SPARK-51275][PYTHON][ML][CONNECT] Session propagation in
 python readwrite

### What changes were proposed in this pull request?

Propagate the caller's Spark session through the Python ML read/write paths.

### Why are the changes needed?

To avoid recreating a session when one is already available.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests should cover this change.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50035 from zhengruifeng/py_ml_sc_session.

Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/ml/pipeline.py | 8 +++++---
 python/pyspark/ml/util.py     | 8 ++++----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 18f537cf197a8..b77392a50c7f2 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -420,10 +420,11 @@ def saveImpl(
         """
         stageUids = [stage.uid for stage in stages]
         jsonParams = {"stageUids": stageUids, "language": "Python"}
-        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
+        DefaultParamsWriter.saveMetadata(instance, path, spark, paramMap=jsonParams)
         stagesDir = os.path.join(path, "stages")
         for index, stage in enumerate(stages):
-            cast(MLWritable, stage).write().save(
+            cast(MLWritable, stage).write().session(spark).save(
                 PipelineSharedReadWrite.getStagePath(stage.uid, index, len(stages), stagesDir)
             )

@@ -443,12 +444,13 @@ def load(
         """
         stagesDir = os.path.join(path, "stages")
         stageUids = metadata["paramMap"]["stageUids"]
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
         stages = []
         for index, stageUid in enumerate(stageUids):
             stagePath = PipelineSharedReadWrite.getStagePath(
                 stageUid, index, len(stageUids), stagesDir
             )
-            stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+            stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, spark)
             stages.append(stage)
         return (metadata["uid"], stages)

diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 4919b828a35cf..6b3d6101c249f 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -462,7 +462,7 @@ def sparkSession(self) -> SparkSession:
         Returns the user-specified Spark Session or the default.
         """
         if self._sparkSession is None:
-            self._sparkSession = SparkSession._getActiveSessionOrCreate()
+            self._sparkSession = SparkSession.active()
         assert self._sparkSession is not None
         return self._sparkSession

@@ -809,10 +809,10 @@ def saveMetadata(
             If given, this is saved in the "paramMap" field.
         """
         metadataPath = os.path.join(path, "metadata")
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
         metadataJson = DefaultParamsWriter._get_metadata_to_save(
-            instance, sc, extraMetadata, paramMap
+            instance, spark, extraMetadata, paramMap
         )
-        spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
         spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text(
             metadataPath
         )

@@ -932,7 +932,7 @@ def loadMetadata(
             If non empty, this is checked against the loaded metadata.
         """
         metadataPath = os.path.join(path, "metadata")
-        spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
         metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[index]
         loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
         return loadedVals