
Commit c2df6a1

Add base session created in SparkConnectService
1 parent 20af57c commit c2df6a1

File tree

3 files changed: +59 -8 lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 1 addition & 0 deletions

@@ -436,6 +436,7 @@ object SparkConnectService extends Logging {
       return
     }
 
+    sessionManager.initializeBaseSession(sc)
     startGRPCService()
     createListenerAndUI(sc)
 
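Note on ordering: the new call sits before startGRPCService() because newIsolatedSession() in the session manager (next file) unwraps the base session with Option.get, so a client request handled before initialization would fail with a NoSuchElementException. A tiny self-contained sketch of that failure mode (illustrative only, not code from this commit; the object name and the String stand-in are made up):

// Illustrative only: Option.get on an uninitialized base session throws.
object UninitializedBaseSessionSketch {
  def main(args: Array[String]): Unit = {
    var baseSession: Option[String] = None // stands in for Option[SparkSession]

    // What a too-early newIsolatedSession() would boil down to:
    try {
      baseSession.get
    } catch {
      case e: NoSuchElementException =>
        println(s"a request before initializeBaseSession would fail: $e")
    }

    // After initialization the lookup succeeds.
    baseSession = Some("base")
    println(baseSession.get) // prints "base"
  }
}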
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala

Lines changed: 15 additions & 8 deletions

@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
 
 import com.google.common.cache.CacheBuilder
 
-import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO}
 import org.apache.spark.sql.classic.SparkSession
@@ -39,6 +39,9 @@ import org.apache.spark.util.ThreadUtils
  */
 class SparkConnectSessionManager extends Logging {
 
+  // Base SparkSession created from the SparkContext, used to create new isolated sessions
+  @volatile private var baseSession: Option[SparkSession] = None
+
   private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
     new ConcurrentHashMap[SessionKey, SessionHolder]()
 
@@ -48,6 +51,16 @@ class SparkConnectSessionManager extends Logging {
     .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
     .build[SessionKey, SessionHolderInfo]()
 
+  /**
+   * Initialize the base SparkSession from the provided SparkContext.
+   * This should be called once during SparkConnectService startup.
+   */
+  def initializeBaseSession(sc: SparkContext): Unit = synchronized {
+    if (baseSession.isEmpty) {
+      baseSession = Some(SparkSession.builder().sparkContext(sc).getOrCreate())
+    }
+  }
+
   /** Executor for the periodic maintenance */
   private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
     new AtomicReference[ScheduledExecutorService]()
@@ -333,13 +346,7 @@ class SparkConnectSessionManager extends Logging {
   }
 
   private def newIsolatedSession(): SparkSession = {
-    val active = SparkSession.active
-    if (active.sparkContext.isStopped) {
-      assert(SparkSession.getDefaultSession.nonEmpty)
-      SparkSession.getDefaultSession.get.newSession()
-    } else {
-      active.newSession()
-    }
+    baseSession.get.newSession()
  }
 
   private def validateSessionCreate(key: SessionKey): Unit = {
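For context, the rewritten newIsolatedSession() relies on SparkSession.newSession(): every session it returns shares the base session's SparkContext while keeping its own SQL configuration, temporary views and UDFs. Below is a minimal, self-contained sketch of that behaviour (not part of this commit; the object name, local[*] master and chosen config values are illustrative):

import org.apache.spark.sql.SparkSession

object NewSessionIsolationSketch {
  def main(args: Array[String]): Unit = {
    // Stands in for the base session the manager builds via SparkSession.builder().sparkContext(sc).
    val base = SparkSession.builder()
      .master("local[*]")
      .appName("new-session-isolation-sketch")
      .getOrCreate()

    // What newIsolatedSession() now does once per Spark Connect client session.
    val s1 = base.newSession()
    val s2 = base.newSession()

    // Session state (runtime SQL conf, temp views, UDFs) is isolated per session...
    s1.conf.set("spark.sql.shuffle.partitions", "7")
    s2.conf.set("spark.sql.shuffle.partitions", "13")
    assert(s1.conf.get("spark.sql.shuffle.partitions") == "7")

    // ...while the underlying SparkContext is shared.
    assert(s1.sparkContext eq s2.sparkContext)

    base.stop()
  }
}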

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala

Lines changed: 43 additions & 0 deletions

@@ -32,6 +32,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
   override def beforeEach(): Unit = {
     super.beforeEach()
     SparkConnectService.sessionManager.invalidateAllSessions()
+    SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
   }
 
   test("sessionId needs to be an UUID") {
@@ -171,4 +172,46 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
       sessionHolder.getPipelineExecution(graphId).isEmpty,
       "pipeline execution was not removed")
   }
+
+  test("initializeBaseSession initializes base session from SparkContext") {
+    // Create a new session manager to test initialization
+    val sessionManager = new SparkConnectSessionManager()
+
+    // Initialize the base session with the test SparkContext
+    sessionManager.initializeBaseSession(spark.sparkContext)
+
+    // Create an isolated session and verify it was created successfully
+    val key = SessionKey("user", UUID.randomUUID().toString)
+    val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
+
+    // Verify the session was created and is not the same as the active session
+    assert(sessionHolder != null)
+    assert(sessionHolder.session.sessionUUID != spark.sessionUUID)
+
+    // Clean up
+    sessionManager.closeSession(key)
+  }
+
+  test("initializeBaseSession is idempotent") {
+    // Create a new session manager to test initialization
+    val sessionManager = new SparkConnectSessionManager()
+
+    // Initialize the base session multiple times
+    sessionManager.initializeBaseSession(spark.sparkContext)
+    val key1 = SessionKey("user1", UUID.randomUUID().toString)
+    val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
+    val baseSessionUUID1 = sessionHolder1.session.sessionUUID
+
+    // Initialize again - should not change the base session
+    sessionManager.initializeBaseSession(spark.sparkContext)
+    val key2 = SessionKey("user2", UUID.randomUUID().toString)
+    val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
+
+    // Both sessions should be isolated from each other
+    assert(sessionHolder1.session.sessionUUID != sessionHolder2.session.sessionUUID)
+
+    // Clean up
+    sessionManager.closeSession(key1)
+    sessionManager.closeSession(key2)
+  }
 }
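One additional check that could sit alongside these tests (a sketch only, not part of the commit): sessions handed out by the manager should still run on the suite's SparkContext, since they derive from the base session initialized in beforeEach(). The sketch below would live inside this suite and reuses only members already exercised above (SessionKey, getOrCreateIsolatedSession, closeSession):

  // Illustrative sketch, not in the commit: isolated session state, shared SparkContext.
  test("isolated sessions run on the shared base SparkContext (sketch)") {
    val key = SessionKey("user", UUID.randomUUID().toString)
    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)

    // Session state is isolated, but the SparkContext is the one the base session
    // was initialized with in beforeEach().
    assert(sessionHolder.session.sparkContext eq spark.sparkContext)

    SparkConnectService.sessionManager.closeSession(key)
  }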
