Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package gg.agit.konect.global.auth.bridge;

import java.security.SecureRandom;
import java.time.Duration;
import java.util.Base64;
import java.util.List;
import java.util.Optional;

import org.springframework.context.annotation.Profile;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Service;

import lombok.RequiredArgsConstructor;

@Profile("!local")
@Service
@RequiredArgsConstructor
public class NativeSessionBridgeService {

private static final int TOKEN_BYTES = 32;
private static final String KEY_PREFIX = "native:session-bridge:";
private static final Duration TTL = Duration.ofSeconds(30);
private static final DefaultRedisScript<String> GET_DEL_SCRIPT =
new DefaultRedisScript<>(
"local v = redis.call('GET', KEYS[1]); " +
"if v then redis.call('DEL', KEYS[1]); end; " +
"return v;",
String.class
);

private final SecureRandom secureRandom = new SecureRandom();

private final StringRedisTemplate redis;

public String issue(Integer userId) {
if (userId == null) {
throw new IllegalArgumentException("userId is required");
}

String token = generateToken();
redis.opsForValue().set(KEY_PREFIX + token, userId.toString(), TTL);

return token;
}

public Optional<Integer> consume(@Nullable String token) {
if (token == null || token.isBlank()) {
return Optional.empty();
}

String key = KEY_PREFIX + token;
String value = redis.execute(GET_DEL_SCRIPT, List.of(key));

if (value == null || value.isBlank()) {
return Optional.empty();
}

try {
return Optional.of(Integer.parseInt(value));
} catch (NumberFormatException e) {
return Optional.empty();
}
}

private String generateToken() {
byte[] bytes = new byte[TOKEN_BYTES];
secureRandom.nextBytes(bytes);
return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package gg.agit.konect.global.auth.bridge;

import java.io.IOException;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.http.HttpStatus;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import gg.agit.konect.global.auth.annotation.PublicApi;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.RequiredArgsConstructor;

@Profile("!local")
@RestController
@RequiredArgsConstructor
public class NativeSessionController {

@Value("${app.frontend.base-url}")
private String frontendBaseUrl;

private final NativeSessionBridgeService nativeSessionBridgeService;

@PublicApi
@GetMapping("/native/session/bridge")
public void bridge(
@RequestParam(name = "bridge_token", required = false) String bridgeToken,
HttpServletRequest request,
HttpServletResponse response
) throws IOException {
response.setHeader("Cache-Control", "no-store, no-cache, must-revalidate");

if (!StringUtils.hasText(bridgeToken)) {
response.sendError(HttpStatus.UNAUTHORIZED.value());
return;
}

Integer userId = nativeSessionBridgeService.consume(bridgeToken).orElse(null);

if (userId == null) {
response.sendError(HttpStatus.UNAUTHORIZED.value());
return;
}

HttpSession existing = request.getSession(false);
if (existing != null) {
existing.invalidate();
}

HttpSession session = request.getSession(true);
session.setAttribute("userId", userId);

response.sendRedirect(frontendBaseUrl + "/home");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.Optional;
import java.util.Set;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
Expand All @@ -18,6 +19,7 @@
import gg.agit.konect.domain.user.model.User;
import gg.agit.konect.domain.user.repository.UnRegisteredUserRepository;
import gg.agit.konect.domain.user.repository.UserRepository;
import gg.agit.konect.global.auth.bridge.NativeSessionBridgeService;
import gg.agit.konect.global.code.ApiResponseCode;
import gg.agit.konect.global.config.SecurityProperties;
import gg.agit.konect.global.exception.CustomException;
Expand All @@ -38,6 +40,7 @@ public class OAuth2LoginSuccessHandler implements AuthenticationSuccessHandler {
private final UserRepository userRepository;
private final UnRegisteredUserRepository unRegisteredUserRepository;
private final SecurityProperties securityProperties;
private final ObjectProvider<NativeSessionBridgeService> nativeSessionBridgeService;

@Override
public void onAuthenticationSuccess(
Expand Down Expand Up @@ -110,7 +113,31 @@ private void sendLoginSuccessResponse(
String redirectUri = (String)session.getAttribute("redirect_uri");
session.removeAttribute("redirect_uri");

response.sendRedirect(resolveSafeRedirect(redirectUri));
String safeRedirect = resolveSafeRedirect(redirectUri);

if (isAppleOauthCallback(safeRedirect)) {
NativeSessionBridgeService svc = nativeSessionBridgeService.getIfAvailable();

if (svc != null) {
String bridgeToken = svc.issue(user.getId());
safeRedirect = appendBridgeToken(safeRedirect, bridgeToken);
}
}

response.sendRedirect(safeRedirect);
}

private boolean isAppleOauthCallback(String redirectUri) {
return redirectUri != null && redirectUri.startsWith("konect://oauth/callback");
}

private String appendBridgeToken(String redirectUri, String bridgeToken) {
if (redirectUri.contains("bridge_token=")) {
return redirectUri;
}

char joiner = redirectUri.contains("?") ? '&' : '?';
return redirectUri + joiner + "bridge_token=" + bridgeToken;
}

private String extractEmail(OAuth2User oauthUser, Provider provider) {
Expand Down