Skip to content

Commit 88e521e

Browse files
committed
Add base session created in SparkConnectService
1 parent 20af57c commit 88e521e

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ object SparkConnectService extends Logging {
436436
return
437437
}
438438

439+
sessionManager.initializeBaseSession(sc)
439440
startGRPCService()
440441
createListenerAndUI(sc)
441442

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
2727

2828
import com.google.common.cache.CacheBuilder
2929

30-
import org.apache.spark.{SparkEnv, SparkSQLException}
30+
import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
3131
import org.apache.spark.internal.Logging
3232
import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO}
3333
import org.apache.spark.sql.classic.SparkSession
@@ -39,6 +39,9 @@ import org.apache.spark.util.ThreadUtils
3939
*/
4040
class SparkConnectSessionManager extends Logging {
4141

42+
// Base SparkSession created from the SparkContext, used to create new isolated sessions
43+
@volatile private var baseSession: Option[SparkSession] = None
44+
4245
private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
4346
new ConcurrentHashMap[SessionKey, SessionHolder]()
4447

@@ -48,6 +51,16 @@ class SparkConnectSessionManager extends Logging {
4851
.maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
4952
.build[SessionKey, SessionHolderInfo]()
5053

54+
/**
55+
* Initialize the base SparkSession from the provided SparkContext.
56+
* This should be called once during SparkConnectService startup.
57+
*/
58+
def initializeBaseSession(sc: SparkContext): Unit = synchronized {
59+
if (baseSession.isEmpty) {
60+
baseSession = Some(SparkSession.builder().sparkContext(sc).getOrCreate())
61+
}
62+
}
63+
5164
/** Executor for the periodic maintenance */
5265
private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
5366
new AtomicReference[ScheduledExecutorService]()
@@ -333,13 +346,7 @@ class SparkConnectSessionManager extends Logging {
333346
}
334347

335348
private def newIsolatedSession(): SparkSession = {
336-
val active = SparkSession.active
337-
if (active.sparkContext.isStopped) {
338-
assert(SparkSession.getDefaultSession.nonEmpty)
339-
SparkSession.getDefaultSession.get.newSession()
340-
} else {
341-
active.newSession()
342-
}
349+
baseSession.get.newSession()
343350
}
344351

345352
private def validateSessionCreate(key: SessionKey): Unit = {

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterEach
2323
import org.scalatest.time.SpanSugar._
2424

2525
import org.apache.spark.SparkSQLException
26+
import org.apache.spark.sql.SparkSession
2627
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
2728
import org.apache.spark.sql.pipelines.logging.PipelineEvent
2829
import org.apache.spark.sql.test.SharedSparkSession
@@ -32,6 +33,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
3233
override def beforeEach(): Unit = {
3334
super.beforeEach()
3435
SparkConnectService.sessionManager.invalidateAllSessions()
36+
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
3537
}
3638

3739
test("sessionId needs to be an UUID") {
@@ -171,4 +173,51 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
171173
sessionHolder.getPipelineExecution(graphId).isEmpty,
172174
"pipeline execution was not removed")
173175
}
176+
177+
test("baseSession allows creating sessions after default session is cleared") {
178+
// Create a new session manager to test initialization
179+
val sessionManager = new SparkConnectSessionManager()
180+
181+
// Initialize the base session with the test SparkContext
182+
sessionManager.initializeBaseSession(spark.sparkContext)
183+
184+
// Clear the default and active sessions to simulate the scenario where
185+
// SparkSession.active or SparkSession.getDefaultSession would fail
186+
SparkSession.clearDefaultSession()
187+
SparkSession.clearActiveSession()
188+
189+
// Create an isolated session - this should still work because we have baseSession
190+
val key = SessionKey("user", UUID.randomUUID().toString)
191+
val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
192+
193+
// Verify the session was created successfully
194+
assert(sessionHolder != null)
195+
assert(sessionHolder.session != null)
196+
197+
// Clean up
198+
sessionManager.closeSession(key)
199+
}
200+
201+
test("initializeBaseSession is idempotent") {
202+
// Create a new session manager to test initialization
203+
val sessionManager = new SparkConnectSessionManager()
204+
205+
// Initialize the base session multiple times
206+
sessionManager.initializeBaseSession(spark.sparkContext)
207+
val key1 = SessionKey("user1", UUID.randomUUID().toString)
208+
val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
209+
val baseSessionUUID1 = sessionHolder1.session.sessionUUID
210+
211+
// Initialize again - should not change the base session
212+
sessionManager.initializeBaseSession(spark.sparkContext)
213+
val key2 = SessionKey("user2", UUID.randomUUID().toString)
214+
val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
215+
216+
// Both sessions should be isolated from each other
217+
assert(sessionHolder1.session.sessionUUID != sessionHolder2.session.sessionUUID)
218+
219+
// Clean up
220+
sessionManager.closeSession(key1)
221+
sessionManager.closeSession(key2)
222+
}
174223
}

0 commit comments

Comments
 (0)