diff --git a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt index 36e2571705..3d00af7d09 100644 --- a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt +++ b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt @@ -1,9 +1,12 @@ package org.modelix.model.oauth import io.ktor.client.HttpClientConfig -import io.ktor.client.plugins.auth.Auth -import io.ktor.client.plugins.auth.providers.BearerTokens -import io.ktor.client.plugins.auth.providers.bearer +import io.ktor.client.plugins.api.Send +import io.ktor.client.plugins.api.createClientPlugin +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.takeFrom +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode /** * Functions and states for authenticating to a model server. @@ -15,8 +18,7 @@ expect class ModelixAuthClient() { * * @param config Config for the HTTP client to be created. * This config will be modified to enable authentication. - * @param baseUrl Base url of model server. - * Required for PKCE flow in JVM. + * @param authConfig Authentication configuration (OAuth or token provider). */ fun installAuth( config: HttpClientConfig<*>, @@ -25,21 +27,61 @@ expect class ModelixAuthClient() { } internal fun installAuthWithAuthTokenProvider(config: HttpClientConfig<*>, authTokenProvider: suspend () -> String?) { - config.apply { - install(Auth) { - bearer { - loadTokens { - authTokenProvider()?.let { authToken -> BearerTokens(authToken, "") } + // Single custom plugin that handles ALL auth: + // - Adds token to every request + // - Retries on 401 with refreshed token + // - Retries on 403 with refreshed token + // This avoids conflicts with Ktor's Auth plugin which doesn't handle 403 + val authPlugin = createClientPlugin("ModelixAuthPlugin") { + var cachedToken: String? = null + var attemptedRefresh = false + + on(Send) { request -> + // Get token for request if not already present + if (request.headers[HttpHeaders.Authorization] == null) { + val token = cachedToken ?: authTokenProvider() + cachedToken = token + if (token != null) { + request.headers.append(HttpHeaders.Authorization, "Bearer $token") + } + } + + val call = proceed(request) + val status = call.response.status + + // Retry on 401 or 403 with a fresh token + if ((status == HttpStatusCode.Unauthorized || status == HttpStatusCode.Forbidden) && !attemptedRefresh) { + attemptedRefresh = true + + // Get fresh token + val freshToken = authTokenProvider() + + if (freshToken != null && freshToken != cachedToken) { + cachedToken = freshToken + + // Copy request using takeFrom (copies method, url, body, headers) + val newRequest = HttpRequestBuilder().takeFrom(request) + + // Replace Authorization header with fresh token + newRequest.headers.remove(HttpHeaders.Authorization) + newRequest.headers.append(HttpHeaders.Authorization, "Bearer $freshToken") + + proceed(newRequest) + } else { + call } - refreshTokens { - val providedToken = authTokenProvider() - if (providedToken != null && providedToken != this.oldTokens?.accessToken) { - BearerTokens(providedToken, "") - } else { - null - } + } else { + // Reset flag on successful responses + @Suppress("MagicNumber") + if (status.value in 200..299) { + attemptedRefresh = false } + call } } } + + config.apply { + install(authPlugin) + } } diff --git a/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelixAuthClientTest.kt b/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelixAuthClientTest.kt index 9501b71b05..0cffa02f1b 100644 --- a/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelixAuthClientTest.kt +++ b/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelixAuthClientTest.kt @@ -1,12 +1,22 @@ package org.modelix.model.client2 +import io.ktor.client.request.get +import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.call +import io.ktor.server.response.respond +import io.ktor.server.routing.get +import io.ktor.server.routing.routing +import io.ktor.server.testing.testApplication import kotlinx.coroutines.CancellationException import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout import org.modelix.model.oauth.IAuthRequestHandler import org.modelix.model.oauth.ModelixAuthClient import org.modelix.model.oauth.OAuthConfig +import org.modelix.model.oauth.TokenProviderAuthConfig import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertTrue import kotlin.time.Duration.Companion.milliseconds @@ -41,4 +51,220 @@ class ModelixAuthClientTest { assertTrue(browseCalled) } + + @Test + fun `401 response triggers token refresh and retry with new token`() = testApplication { + var tokenProviderCallCount = 0 + val token1 = "expired-token" + val token2 = "fresh-token" + var serverRequestCount = 0 + + application { + routing { + get("/protected") { + serverRequestCount++ + val authHeader = call.request.headers["Authorization"] + when (authHeader) { + "Bearer $token1" -> call.respond(HttpStatusCode.Unauthorized, "Token expired") + "Bearer $token2" -> call.respond(HttpStatusCode.OK, "Success") + else -> call.respond(HttpStatusCode.Unauthorized, "No token") + } + } + } + } + + val authClient = createClient { + ModelixAuthClient().installAuth( + this, + TokenProviderAuthConfig { + tokenProviderCallCount++ + if (tokenProviderCallCount == 1) token1 else token2 + }, + ) + } + + val response = authClient.get("/protected") + + assertEquals(HttpStatusCode.OK, response.status, "Should succeed after token refresh") + assertEquals("Success", response.bodyAsText()) + assertTrue(tokenProviderCallCount >= 2, "Token provider should be called at least twice: initial + refresh. Was called $tokenProviderCallCount times") + assertTrue(serverRequestCount >= 2, "Server should receive at least 2 requests: initial 401 + retry. Received $serverRequestCount") + } + + @Test + fun `403 response does not cause infinite retry loop`() = testApplication { + var tokenProviderCallCount = 0 + var serverRequestCount = 0 + + application { + routing { + get("/always-forbidden") { + serverRequestCount++ + // Always return 403, simulating permanently insufficient permissions + call.respond(HttpStatusCode.Forbidden, "Permission denied") + } + } + } + + val authClient = createClient { + ModelixAuthClient().installAuth( + this, + TokenProviderAuthConfig { + tokenProviderCallCount++ + "token-$tokenProviderCallCount" + }, + ) + } + + val response = authClient.get("/always-forbidden") + + assertEquals(HttpStatusCode.Forbidden, response.status, "Should return 403 after retry attempts exhausted") + // Key assertion: should NOT retry indefinitely + assertTrue(serverRequestCount <= 3, "Should not retry more than a few times. Server received $serverRequestCount requests") + assertTrue(tokenProviderCallCount <= 3, "Token provider should not be called excessively. Called $tokenProviderCallCount times") + } + + @Test + fun `403 triggers token refresh attempt on first occurrence`() = testApplication { + var tokenProviderCallCount = 0 + val token1 = "old-branch-token" + val token2 = "new-branch-token" + var serverRequestCount = 0 + + application { + routing { + get("/branch-resource") { + serverRequestCount++ + val authHeader = call.request.headers["Authorization"] + when (authHeader) { + "Bearer $token1" -> { + // First token doesn't have permission for new branch + call.respond(HttpStatusCode.Forbidden, "No permission for this branch") + } + "Bearer $token2" -> { + // Refreshed token has correct permissions + call.respond(HttpStatusCode.OK, "Branch access granted") + } + else -> call.respond(HttpStatusCode.Unauthorized, "No token") + } + } + } + } + + val authClient = createClient { + ModelixAuthClient().installAuth( + this, + TokenProviderAuthConfig { + tokenProviderCallCount++ + if (tokenProviderCallCount == 1) token1 else token2 + }, + ) + } + + val response = authClient.get("/branch-resource") + + // Verify that 403 triggered a token refresh and successful retry + assertEquals(HttpStatusCode.OK, response.status, "Should succeed after token refresh on 403") + assertEquals("Branch access granted", response.bodyAsText()) + assertEquals(2, tokenProviderCallCount, "Token provider should be called twice: initial load + refresh on 403") + assertEquals(2, serverRequestCount, "Server should receive 2 requests: initial 403 + retry with new token") + } + + @Test + fun `successful response after 403 allows future 403 retry`() = testApplication { + var tokenProviderCallCount = 0 + var serverRequestCount = 0 + var returnForbidden = true + + application { + routing { + get("/resource") { + serverRequestCount++ + if (returnForbidden) { + call.respond(HttpStatusCode.Forbidden, "Forbidden") + } else { + call.respond(HttpStatusCode.OK, "Success") + } + } + } + } + + val authClient = createClient { + ModelixAuthClient().installAuth( + this, + TokenProviderAuthConfig { + tokenProviderCallCount++ + "token-$tokenProviderCallCount" + }, + ) + } + + // First request - gets 403 + val response1 = authClient.get("/resource") + assertEquals(HttpStatusCode.Forbidden, response1.status) + val requestsAfterFirst403 = serverRequestCount + + // Second request - succeeds (simulating fix to permissions) + returnForbidden = false + val response2 = authClient.get("/resource") + assertEquals(HttpStatusCode.OK, response2.status) + + // Third request - gets 403 again (new branch change) + returnForbidden = true + val response3 = authClient.get("/resource") + assertEquals(HttpStatusCode.Forbidden, response3.status) + + // Verify the flag was reset after success, allowing retry attempt on third request + // This is the key behavior: after success, a new 403 should trigger a fresh retry attempt + assertTrue( + serverRequestCount > requestsAfterFirst403 + 1, + "After successful response, new 403 should trigger retry attempt. Total requests: $serverRequestCount", + ) + } + + @Test + fun `403 retry uses the new token in Authorization header`() = testApplication { + val tokensReceivedByServer = mutableListOf() + var tokenProviderCallCount = 0 + val token1 = "initial-token-abc" + val token2 = "refreshed-token-xyz" + + application { + routing { + get("/verify-token") { + val authHeader = call.request.headers["Authorization"] + tokensReceivedByServer.add(authHeader) + + when (authHeader) { + "Bearer $token1" -> call.respond(HttpStatusCode.Forbidden, "Old token") + "Bearer $token2" -> call.respond(HttpStatusCode.OK, "New token accepted") + else -> call.respond(HttpStatusCode.Unauthorized, "Unknown: $authHeader") + } + } + } + } + + val authClient = createClient { + ModelixAuthClient().installAuth( + this, + TokenProviderAuthConfig { + tokenProviderCallCount++ + if (tokenProviderCallCount == 1) token1 else token2 + }, + ) + } + + val response = authClient.get("/verify-token") + + // Print debug info + println("Tokens received by server in order: $tokensReceivedByServer") + println("Token provider called $tokenProviderCallCount times") + println("Response status: ${response.status}") + + // Verify server received both tokens in correct order + assertEquals(2, tokensReceivedByServer.size, "Server should receive exactly 2 requests") + assertEquals("Bearer $token1", tokensReceivedByServer[0], "First request should use initial token") + assertEquals("Bearer $token2", tokensReceivedByServer[1], "Second request (retry) should use refreshed token") + assertEquals(HttpStatusCode.OK, response.status, "Final response should be OK") + } }