Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LIVY-998][THRIF] Support connecting to an existing sessions using session name #445

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
}

/**
* If the user specified an existing sessionId to use, the corresponding session is returned,
* otherwise a new session is created and returned.
* If the user specified an existing sessionId or session name to use, the corresponding session
* is returned, otherwise a new session is created and returned.
*/
private def getOrCreateLivySession(
def getOrCreateLivySession(
sessionHandle: SessionHandle,
sessionId: Option[Int],
sessionName: Option[String],
username: String,
createLivySession: () => InteractiveSession): InteractiveSession = {
sessionId match {
Expand All @@ -183,7 +184,27 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
}
}
case None =>
createLivySession()
sessionName match {
case Some(name) =>
server.livySessionManager.get(name) match {
case None =>
createLivySession()
case Some(session) if !server.isAllowedToUse(username, session) =>
warn(s"$username has no modify permissions to InteractiveSession $name.")
throw new IllegalAccessException(
s"$username is not allowed to use InteractiveSession $name.")
case Some(session) =>
if (session.state.isActive) {
info(s"Reusing Session $name for $sessionHandle.")
session
} else {
warn(s"InteractiveSession $name is not active anymore.")
throw new IllegalArgumentException(s"Session $name is not active anymore.")
}
}
case None =>
createLivySession()
}
}
}

Expand Down Expand Up @@ -248,7 +269,8 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] {
override def run(): InteractiveSession = {
livySession =
getOrCreateLivySession(sessionHandle, sessionId, username, createLivySession)
getOrCreateLivySession(sessionHandle, sessionId, createInteractiveRequest.name,
username, createLivySession)
synchronized {
managedLivySessionActiveUsers.get(livySession.id).foreach { numUsers =>
managedLivySessionActiveUsers(livySession.id) = numUsers + 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ import scala.concurrent.duration.Duration
import org.apache.hive.service.cli.{HiveSQLException, SessionHandle}
import org.junit.Assert._
import org.junit.Test
import org.mockito.Mockito.mock
import org.mockito.Mockito.{mock, when}

import org.apache.livy.LivyConf
import org.apache.livy.server.AccessManager
import org.apache.livy.server.interactive.InteractiveSession
import org.apache.livy.server.recovery.{SessionStore, StateStore}
import org.apache.livy.sessions.InteractiveSessionManager
import org.apache.livy.server.recovery.SessionStore
import org.apache.livy.server.AccessManager
import org.apache.livy.sessions.{InteractiveSessionManager, SessionState}
import org.apache.livy.utils.Clock.sleep

object ConnectionLimitType extends Enumeration {
Expand All @@ -46,7 +46,7 @@ class TestLivyThriftSessionManager {
import ConnectionLimitType._

private def createThriftSessionManager(
limitTypes: ConnectionLimitType*): LivyThriftSessionManager = {
limitTypes: ConnectionLimitType*): (LivyThriftSessionManager, LivyThriftServer) = {
val conf = new LivyConf()
conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
val limit = 3
Expand All @@ -62,21 +62,23 @@ class TestLivyThriftSessionManager {
}

private def createThriftSessionManager(
maxSessionWait: Option[String]): LivyThriftSessionManager = {
maxSessionWait: Option[String]): (LivyThriftSessionManager, LivyThriftServer) = {
val conf = new LivyConf()
conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
maxSessionWait.foreach(conf.set(LivyConf.THRIFT_SESSION_CREATION_TIMEOUT, _))
this.createThriftSessionManager(conf)
}

private def createThriftSessionManager(conf: LivyConf): LivyThriftSessionManager = {
private def createThriftSessionManager(conf: LivyConf): (LivyThriftSessionManager,
LivyThriftServer) = {
val server = new LivyThriftServer(
conf,
mock(classOf[InteractiveSessionManager]),
mock(classOf[SessionStore]),
mock(classOf[AccessManager])
)
new LivyThriftSessionManager(server, conf)
val sessionManager = new LivyThriftSessionManager(server, conf)
(sessionManager, server)
}

private def testLimit(
Expand All @@ -99,7 +101,7 @@ class TestLivyThriftSessionManager {

@Test
def testLimitConnectionsByUser(): Unit = {
val thriftSessionMgr = createThriftSessionManager(User)
val (thriftSessionMgr, _) = createThriftSessionManager(User)
val user = "alice"
val forwardedAddresses = new java.util.ArrayList[String]()
thriftSessionMgr.incrementConnections(user, "10.20.30.40", forwardedAddresses)
Expand All @@ -111,7 +113,7 @@ class TestLivyThriftSessionManager {

@Test
def testLimitConnectionsByIpAddress(): Unit = {
val thriftSessionMgr = createThriftSessionManager(IpAddress)
val (thriftSessionMgr, _) = createThriftSessionManager(IpAddress)
val ipAddress = "10.20.30.40"
val forwardedAddresses = new java.util.ArrayList[String]()
thriftSessionMgr.incrementConnections("alice", ipAddress, forwardedAddresses)
Expand All @@ -123,7 +125,7 @@ class TestLivyThriftSessionManager {

@Test
def testLimitConnectionsByUserAndIpAddress(): Unit = {
val thriftSessionMgr = createThriftSessionManager(UserIpAddress)
val (thriftSessionMgr, _) = createThriftSessionManager(UserIpAddress)
val user = "alice"
val ipAddress = "10.20.30.40"
val userAndAddress = user + ":" + ipAddress
Expand All @@ -149,7 +151,7 @@ class TestLivyThriftSessionManager {

@Test
def testMultipleConnectionLimits(): Unit = {
val thriftSessionMgr = createThriftSessionManager(User, IpAddress)
val (thriftSessionMgr, _) = createThriftSessionManager(User, IpAddress)
val user = "alice"
val ipAddress = "10.20.30.40"
val forwardedAddresses = new java.util.ArrayList[String]()
Expand All @@ -166,7 +168,7 @@ class TestLivyThriftSessionManager {

@Test(expected = classOf[TimeoutException])
def testGetLivySessionWaitForTimeout(): Unit = {
val thriftSessionMgr = createThriftSessionManager(Some("10ms"))
val (thriftSessionMgr, _) = createThriftSessionManager(Some("10ms"))
val sessionHandle = mock(classOf[SessionHandle])
val future = Future[InteractiveSession] {
sleep(100)
Expand All @@ -178,7 +180,7 @@ class TestLivyThriftSessionManager {

@Test(expected = classOf[TimeoutException])
def testGetLivySessionWithTimeoutException(): Unit = {
val thriftSessionMgr = createThriftSessionManager(None)
val (thriftSessionMgr, _) = createThriftSessionManager(None)
val sessionHandle = mock(classOf[SessionHandle])
val future = Future[InteractiveSession] {
throw new TimeoutException("Actively throw TimeoutException in Future.")
Expand All @@ -187,4 +189,72 @@ class TestLivyThriftSessionManager {
Await.ready(future, Duration(30, TimeUnit.SECONDS))
thriftSessionMgr.getLivySession(sessionHandle)
}


@Test
def testGetOrCreateLivySessionDifferentSessions(): Unit = {
val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
val sessionHandle = mock(classOf[SessionHandle])
val sessionUser = "testUser"
val sessionId1 = Some(1)
val session1 = mock(classOf[InteractiveSession])
when(session1.state).thenReturn(SessionState.Running)
when(session1.owner).thenReturn(sessionUser)
when(server.livySessionManager.get(1)).thenReturn(Some(session1))
val sessionId2 = Some(2)
val session2 = mock(classOf[InteractiveSession])
when(session2.state).thenReturn(SessionState.Running)
when(session2.owner).thenReturn(sessionUser)
when(server.livySessionManager.get(2)).thenReturn(Some(session2))
val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId1, None,
sessionUser, () => null)
val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId2, None,
sessionUser, () => null)

assertNotNull(result1)
assertNotNull(result2)
assertNotEquals(result1, result2)
}

@Test
def testGetOrCreateLivySessionExistingSessionByID(): Unit = {
val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
val sessionHandle = mock(classOf[SessionHandle])
val sessionUser = "testUser"
val sessionId = Some(1)
val session1 = mock(classOf[InteractiveSession])
when(session1.state).thenReturn(SessionState.Running)
when(session1.owner).thenReturn(sessionUser)
when(server.livySessionManager.get(1)).thenReturn(Some(session1))
val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None,
sessionUser, () => null)
val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None,
sessionUser, () => null)

assertNotNull(result1)
assertNotNull(result2)
assertEquals(result1, result2)
}


@Test
def testGetOrCreateLivySessionExistingSessionByName(): Unit = {
val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
val sessionHandle = mock(classOf[SessionHandle])
val sessionUser = "testUser"
val sessionName = Some("sessionName")
val session1 = mock(classOf[InteractiveSession])
when(session1.state).thenReturn(SessionState.Running)
when(session1.owner).thenReturn(sessionUser)
when(server.livySessionManager.get("sessionName")).thenReturn(Some(session1))
val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
sessionUser, () => null)
val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
sessionUser, () => null)

assertNotNull(result1)
assertNotNull(result2)
assertEquals(result1, result2)
}

}
Loading