diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala index cd8d2f50a..03e10c35a 100644 --- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala +++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala @@ -155,12 +155,13 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC } /** - * If the user specified an existing sessionId to use, the corresponding session is returned, - * otherwise a new session is created and returned. + * If the user specified an existing sessionId or session name to use, the corresponding session + * is returned, otherwise a new session is created and returned. */ - private def getOrCreateLivySession( + def getOrCreateLivySession( sessionHandle: SessionHandle, sessionId: Option[Int], + sessionName: Option[String], username: String, createLivySession: () => InteractiveSession): InteractiveSession = { sessionId match { @@ -183,7 +184,27 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC } } case None => - createLivySession() + sessionName match { + case Some(name) => + server.livySessionManager.get(name) match { + case None => + createLivySession() + case Some(session) if !server.isAllowedToUse(username, session) => + warn(s"$username has no modify permissions to InteractiveSession $name.") + throw new IllegalAccessException( + s"$username is not allowed to use InteractiveSession $name.") + case Some(session) => + if (session.state.isActive) { + info(s"Reusing Session $name for $sessionHandle.") + session + } else { + warn(s"InteractiveSession $name is not active anymore.") + throw new IllegalArgumentException(s"Session $name is not active anymore.") + } + } + case None => + createLivySession() + } } } @@ -248,7 +269,8 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] { override def run(): InteractiveSession = { livySession = - getOrCreateLivySession(sessionHandle, sessionId, username, createLivySession) + getOrCreateLivySession(sessionHandle, sessionId, createInteractiveRequest.name, + username, createLivySession) synchronized { managedLivySessionActiveUsers.get(livySession.id).foreach { numUsers => managedLivySessionActiveUsers(livySession.id) = numUsers + 1 diff --git a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala index 11eea31fb..cbfc006c8 100644 --- a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala +++ b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala @@ -27,13 +27,13 @@ import scala.concurrent.duration.Duration import org.apache.hive.service.cli.{HiveSQLException, SessionHandle} import org.junit.Assert._ import org.junit.Test -import org.mockito.Mockito.mock +import org.mockito.Mockito.{mock, when} import org.apache.livy.LivyConf -import org.apache.livy.server.AccessManager import org.apache.livy.server.interactive.InteractiveSession -import org.apache.livy.server.recovery.{SessionStore, StateStore} -import org.apache.livy.sessions.InteractiveSessionManager +import org.apache.livy.server.recovery.SessionStore +import org.apache.livy.server.AccessManager +import org.apache.livy.sessions.{InteractiveSessionManager, SessionState} import org.apache.livy.utils.Clock.sleep object ConnectionLimitType extends Enumeration { @@ -46,7 +46,7 @@ class TestLivyThriftSessionManager { import ConnectionLimitType._ private def createThriftSessionManager( - limitTypes: ConnectionLimitType*): LivyThriftSessionManager = { + limitTypes: ConnectionLimitType*): (LivyThriftSessionManager, LivyThriftServer) = { val conf = new LivyConf() conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION")) val limit = 3 @@ -62,21 +62,23 @@ class TestLivyThriftSessionManager { } private def createThriftSessionManager( - maxSessionWait: Option[String]): LivyThriftSessionManager = { + maxSessionWait: Option[String]): (LivyThriftSessionManager, LivyThriftServer) = { val conf = new LivyConf() conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION")) maxSessionWait.foreach(conf.set(LivyConf.THRIFT_SESSION_CREATION_TIMEOUT, _)) this.createThriftSessionManager(conf) } - private def createThriftSessionManager(conf: LivyConf): LivyThriftSessionManager = { + private def createThriftSessionManager(conf: LivyConf): (LivyThriftSessionManager, + LivyThriftServer) = { val server = new LivyThriftServer( conf, mock(classOf[InteractiveSessionManager]), mock(classOf[SessionStore]), mock(classOf[AccessManager]) ) - new LivyThriftSessionManager(server, conf) + val sessionManager = new LivyThriftSessionManager(server, conf) + (sessionManager, server) } private def testLimit( @@ -99,7 +101,7 @@ class TestLivyThriftSessionManager { @Test def testLimitConnectionsByUser(): Unit = { - val thriftSessionMgr = createThriftSessionManager(User) + val (thriftSessionMgr, _) = createThriftSessionManager(User) val user = "alice" val forwardedAddresses = new java.util.ArrayList[String]() thriftSessionMgr.incrementConnections(user, "10.20.30.40", forwardedAddresses) @@ -111,7 +113,7 @@ class TestLivyThriftSessionManager { @Test def testLimitConnectionsByIpAddress(): Unit = { - val thriftSessionMgr = createThriftSessionManager(IpAddress) + val (thriftSessionMgr, _) = createThriftSessionManager(IpAddress) val ipAddress = "10.20.30.40" val forwardedAddresses = new java.util.ArrayList[String]() thriftSessionMgr.incrementConnections("alice", ipAddress, forwardedAddresses) @@ -123,7 +125,7 @@ class TestLivyThriftSessionManager { @Test def testLimitConnectionsByUserAndIpAddress(): Unit = { - val thriftSessionMgr = createThriftSessionManager(UserIpAddress) + val (thriftSessionMgr, _) = createThriftSessionManager(UserIpAddress) val user = "alice" val ipAddress = "10.20.30.40" val userAndAddress = user + ":" + ipAddress @@ -149,7 +151,7 @@ class TestLivyThriftSessionManager { @Test def testMultipleConnectionLimits(): Unit = { - val thriftSessionMgr = createThriftSessionManager(User, IpAddress) + val (thriftSessionMgr, _) = createThriftSessionManager(User, IpAddress) val user = "alice" val ipAddress = "10.20.30.40" val forwardedAddresses = new java.util.ArrayList[String]() @@ -166,7 +168,7 @@ class TestLivyThriftSessionManager { @Test(expected = classOf[TimeoutException]) def testGetLivySessionWaitForTimeout(): Unit = { - val thriftSessionMgr = createThriftSessionManager(Some("10ms")) + val (thriftSessionMgr, _) = createThriftSessionManager(Some("10ms")) val sessionHandle = mock(classOf[SessionHandle]) val future = Future[InteractiveSession] { sleep(100) @@ -178,7 +180,7 @@ class TestLivyThriftSessionManager { @Test(expected = classOf[TimeoutException]) def testGetLivySessionWithTimeoutException(): Unit = { - val thriftSessionMgr = createThriftSessionManager(None) + val (thriftSessionMgr, _) = createThriftSessionManager(None) val sessionHandle = mock(classOf[SessionHandle]) val future = Future[InteractiveSession] { throw new TimeoutException("Actively throw TimeoutException in Future.") @@ -187,4 +189,72 @@ class TestLivyThriftSessionManager { Await.ready(future, Duration(30, TimeUnit.SECONDS)) thriftSessionMgr.getLivySession(sessionHandle) } + + + @Test + def testGetOrCreateLivySessionDifferentSessions(): Unit = { + val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress) + val sessionHandle = mock(classOf[SessionHandle]) + val sessionUser = "testUser" + val sessionId1 = Some(1) + val session1 = mock(classOf[InteractiveSession]) + when(session1.state).thenReturn(SessionState.Running) + when(session1.owner).thenReturn(sessionUser) + when(server.livySessionManager.get(1)).thenReturn(Some(session1)) + val sessionId2 = Some(2) + val session2 = mock(classOf[InteractiveSession]) + when(session2.state).thenReturn(SessionState.Running) + when(session2.owner).thenReturn(sessionUser) + when(server.livySessionManager.get(2)).thenReturn(Some(session2)) + val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId1, None, + sessionUser, () => null) + val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId2, None, + sessionUser, () => null) + + assertNotNull(result1) + assertNotNull(result2) + assertNotEquals(result1, result2) + } + + @Test + def testGetOrCreateLivySessionExistingSessionByID(): Unit = { + val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress) + val sessionHandle = mock(classOf[SessionHandle]) + val sessionUser = "testUser" + val sessionId = Some(1) + val session1 = mock(classOf[InteractiveSession]) + when(session1.state).thenReturn(SessionState.Running) + when(session1.owner).thenReturn(sessionUser) + when(server.livySessionManager.get(1)).thenReturn(Some(session1)) + val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None, + sessionUser, () => null) + val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None, + sessionUser, () => null) + + assertNotNull(result1) + assertNotNull(result2) + assertEquals(result1, result2) + } + + + @Test + def testGetOrCreateLivySessionExistingSessionByName(): Unit = { + val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress) + val sessionHandle = mock(classOf[SessionHandle]) + val sessionUser = "testUser" + val sessionName = Some("sessionName") + val session1 = mock(classOf[InteractiveSession]) + when(session1.state).thenReturn(SessionState.Running) + when(session1.owner).thenReturn(sessionUser) + when(server.livySessionManager.get("sessionName")).thenReturn(Some(session1)) + val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName, + sessionUser, () => null) + val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName, + sessionUser, () => null) + + assertNotNull(result1) + assertNotNull(result2) + assertEquals(result1, result2) + } + }