Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support specifying a client certificate for mTLS auth #940

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ class FeverSecurityKey private constructor() : SecurityKey() {
var serverUrl: String? = null
var username: String? = null
var password: String? = null
var clientCertificateAlias: String? = null

constructor(serverUrl: String?, username: String?, password: String?) : this() {
constructor(serverUrl: String?, username: String?, password: String?, clientCertificateAlias: String?) : this() {
this.serverUrl = serverUrl
this.username = username
this.password = password
this.clientCertificateAlias = clientCertificateAlias
}

constructor(value: String? = DESUtils.empty) : this() {
decode(value, FeverSecurityKey::class.java).let {
serverUrl = it.serverUrl
username = it.username
password = it.password
clientCertificateAlias = it.clientCertificateAlias
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ class FreshRSSSecurityKey private constructor() : SecurityKey() {
var serverUrl: String? = null
var username: String? = null
var password: String? = null
var clientCertificateAlias: String? = null

constructor(serverUrl: String?, username: String?, password: String?) : this() {
constructor(serverUrl: String?, username: String?, password: String?, clientCertificateAlias: String?) : this() {
this.serverUrl = serverUrl
this.username = username
this.password = password
this.clientCertificateAlias = clientCertificateAlias
}

constructor(value: String? = DESUtils.empty) : this() {
decode(value, FreshRSSSecurityKey::class.java).let {
serverUrl = it.serverUrl
username = it.username
password = it.password
clientCertificateAlias = it.clientCertificateAlias
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ class GoogleReaderSecurityKey private constructor() : SecurityKey() {
var serverUrl: String? = null
var username: String? = null
var password: String? = null
var clientCertificateAlias: String? = null

constructor(serverUrl: String?, username: String?, password: String?) : this() {
constructor(serverUrl: String?, username: String?, password: String?, clientCertificateAlias: String?) : this() {
this.serverUrl = serverUrl
this.username = username
this.password = password
this.clientCertificateAlias = clientCertificateAlias
}

constructor(value: String? = DESUtils.empty) : this() {
decode(value, GoogleReaderSecurityKey::class.java).let {
serverUrl = it.serverUrl
username = it.username
password = it.password
clientCertificateAlias = it.clientCertificateAlias
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ class FeverRssService @Inject constructor(
private suspend fun getFeverAPI() =
FeverSecurityKey(accountDao.queryById(context.currentAccountId)!!.securityKey).run {
FeverAPI.getInstance(
context = context,
serverUrl = serverUrl!!,
username = username!!,
password = password!!,
httpUsername = null,
httpPassword = null,
clientCertificateAlias = clientCertificateAlias,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ class GoogleReaderRssService @Inject constructor(
private suspend fun getGoogleReaderAPI() =
GoogleReaderSecurityKey(accountDao.queryById(context.currentAccountId)!!.securityKey).run {
GoogleReaderAPI.getInstance(
context = context,
serverUrl = serverUrl!!,
username = username!!,
password = password!!,
httpUsername = null,
httpPassword = null,
clientCertificateAlias = clientCertificateAlias,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

package me.ash.reader.infrastructure.di

import android.annotation.SuppressLint
import android.content.Context
import android.security.KeyChain
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
Expand All @@ -31,15 +33,18 @@ import okhttp3.Cache
import okhttp3.Interceptor
import okhttp3.OkHttpClient
import okhttp3.Response
import okhttp3.internal.platform.Platform
import java.io.File
import java.net.Socket
import java.security.KeyManagementException
import java.security.NoSuchAlgorithmException
import java.security.Principal
import java.security.PrivateKey
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import javax.inject.Singleton
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManager
import javax.net.ssl.X509KeyManager
import javax.net.ssl.X509TrustManager

/**
Expand All @@ -54,18 +59,21 @@ object OkHttpClientModule {
fun provideOkHttpClient(
@ApplicationContext context: Context,
): OkHttpClient = cachingHttpClient(
context = context,
cacheDirectory = context.cacheDir.resolve("http")
).newBuilder()
.addNetworkInterceptor(UserAgentInterceptor)
.build()
}

fun cachingHttpClient(
context: Context,
cacheDirectory: File? = null,
cacheSize: Long = 10L * 1024L * 1024L,
trustAllCerts: Boolean = true,
connectTimeoutSecs: Long = 30L,
readTimeoutSecs: Long = 30L,
clientCertificateAlias: String? = null,
): OkHttpClient {
val builder: OkHttpClient.Builder = OkHttpClient.Builder()

Expand All @@ -78,31 +86,75 @@ fun cachingHttpClient(
.readTimeout(readTimeoutSecs, TimeUnit.SECONDS)
.followRedirects(true)

if (trustAllCerts) {
builder.trustAllCerts()
if (!clientCertificateAlias.isNullOrBlank() || trustAllCerts) {
builder.setupSsl(context, clientCertificateAlias, trustAllCerts)
}

return builder.build()
}

fun OkHttpClient.Builder.trustAllCerts() {
fun OkHttpClient.Builder.setupSsl(
context: Context,
clientCertificateAlias: String?,
trustAllCerts: Boolean
) {
try {
val trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {
val clientKeyManager = clientCertificateAlias?.let { clientAlias ->
object : X509KeyManager {
override fun getClientAliases(keyType: String?, issuers: Array<Principal>?) =
throw UnsupportedOperationException("getClientAliases")

override fun chooseClientAlias(
keyType: Array<String>?,
issuers: Array<Principal>?,
socket: Socket?
) = clientCertificateAlias

override fun getServerAliases(keyType: String?, issuers: Array<Principal>?) =
throw UnsupportedOperationException("getServerAliases")

override fun chooseServerAlias(
keyType: String?,
issuers: Array<Principal>?,
socket: Socket?
) = throw UnsupportedOperationException("chooseServerAlias")

override fun getCertificateChain(alias: String?): Array<X509Certificate>? {
return if (alias == clientAlias) KeyChain.getCertificateChain(context, clientAlias) else null
}

override fun getPrivateKey(alias: String?): PrivateKey? {
return if (alias == clientAlias) KeyChain.getPrivateKey(context, clientAlias) else null
}
}
}

override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
}
val trustManager = if (trustAllCerts) {
hostnameVerifier { _, _ -> true }

@SuppressLint("CustomX509TrustManager")
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?
) = Unit

override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
override fun checkServerTrusted(
chain: Array<out X509Certificate>?,
authType: String?
) = Unit

override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
}
} else {
Platform.get().platformTrustManager()
}

val sslContext = SSLContext.getInstance("TLS")
sslContext.init(null, arrayOf<TrustManager>(trustManager), null)
sslContext.init(arrayOf(clientKeyManager), arrayOf(trustManager), null)
val sslSocketFactory = sslContext.socketFactory

sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(HostnameVerifier { _, _ -> true })
} catch (e: NoSuchAlgorithmException) {
// ignore
} catch (e: KeyManagementException) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package me.ash.reader.infrastructure.rss.provider

import android.content.Context
import com.google.gson.Gson
import com.google.gson.GsonBuilder
import me.ash.reader.infrastructure.di.UserAgentInterceptor
import me.ash.reader.infrastructure.di.cachingHttpClient
import okhttp3.OkHttpClient

abstract class ProviderAPI {
abstract class ProviderAPI(context: Context, clientCertificateAlias: String?) {

protected val client: OkHttpClient = cachingHttpClient()
protected val client: OkHttpClient = cachingHttpClient(
context = context,
clientCertificateAlias = clientCertificateAlias,
)
.newBuilder()
.addNetworkInterceptor(UserAgentInterceptor)
.build()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package me.ash.reader.infrastructure.rss.provider.fever

import android.content.Context
import me.ash.reader.infrastructure.exception.FeverAPIException
import me.ash.reader.infrastructure.rss.provider.ProviderAPI
import me.ash.reader.ui.ext.encodeBase64
Expand All @@ -10,11 +11,13 @@ import okhttp3.executeAsync
import java.util.concurrent.ConcurrentHashMap

class FeverAPI private constructor(
context: Context,
private val serverUrl: String,
private val apiKey: String,
private val httpUsername: String? = null,
private val httpPassword: String? = null,
) : ProviderAPI() {
clientCertificateAlias: String? = null,
) : ProviderAPI(context, clientCertificateAlias) {

private suspend inline fun <reified T> postRequest(query: String?): T {
val response = client.newCall(
Expand Down Expand Up @@ -104,14 +107,16 @@ class FeverAPI private constructor(
private val instances: ConcurrentHashMap<String, FeverAPI> = ConcurrentHashMap()

fun getInstance(
context: Context,
serverUrl: String,
username: String,
password: String,
httpUsername: String? = null,
httpPassword: String? = null,
clientCertificateAlias: String? = null,
): FeverAPI = "$username:$password".md5().run {
instances.getOrPut("$serverUrl$this$httpUsername$httpPassword") {
FeverAPI(serverUrl, this, httpUsername, httpPassword)
instances.getOrPut("$serverUrl$this$httpUsername$httpPassword$clientCertificateAlias") {
FeverAPI(context, serverUrl, this, httpUsername, httpPassword, clientCertificateAlias)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package me.ash.reader.infrastructure.rss.provider.greader

import android.content.Context
import me.ash.reader.infrastructure.di.USER_AGENT_STRING
import me.ash.reader.infrastructure.exception.GoogleReaderAPIException
import me.ash.reader.infrastructure.exception.RetryException
Expand All @@ -10,12 +11,14 @@ import okhttp3.executeAsync
import java.util.concurrent.ConcurrentHashMap

class GoogleReaderAPI private constructor(
context: Context,
private val serverUrl: String,
private val username: String,
private val password: String,
private val httpUsername: String? = null,
private val httpPassword: String? = null,
) : ProviderAPI() {
clientCertificateAlias: String? = null,
) : ProviderAPI(context, clientCertificateAlias) {

enum class Stream(val tag: String) {
ALL_ITEMS("user/-/state/com.google/reading-list"),
Expand Down Expand Up @@ -350,13 +353,15 @@ class GoogleReaderAPI private constructor(
private val instances: ConcurrentHashMap<String, GoogleReaderAPI> = ConcurrentHashMap()

fun getInstance(
context: Context,
serverUrl: String,
username: String,
password: String,
httpUsername: String? = null,
httpPassword: String? = null,
): GoogleReaderAPI = instances.getOrPut("$serverUrl$username$password$httpUsername$httpPassword") {
GoogleReaderAPI(serverUrl, username, password, httpUsername, httpPassword)
clientCertificateAlias: String? = null
): GoogleReaderAPI = instances.getOrPut("$serverUrl$username$password$httpUsername$httpPassword$clientCertificateAlias") {
GoogleReaderAPI(context, serverUrl, username, password, httpUsername, httpPassword, clientCertificateAlias)
}

fun clearInstance() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package me.ash.reader.ui.component.base

import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.interaction.PressInteraction
import androidx.compose.foundation.text.KeyboardActions
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material.icons.Icons
Expand All @@ -22,6 +24,7 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusProperties
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
Expand All @@ -46,6 +49,7 @@ fun RYOutlineTextField(
errorMessage: String = "",
keyboardOptions: KeyboardOptions = KeyboardOptions.Default,
keyboardActions: KeyboardActions = KeyboardActions(),
onClick: (() -> Unit)? = null,
) {
val clipboardManager = LocalClipboardManager.current
val focusRequester = remember { FocusRequester() }
Expand All @@ -59,7 +63,11 @@ fun RYOutlineTextField(
}

OutlinedTextField(
modifier = Modifier.focusRequester(focusRequester),
modifier = if (onClick != null) {
Modifier.focusProperties { canFocus = false }
} else {
Modifier.focusRequester(focusRequester)
},
colors = TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent
Expand Down Expand Up @@ -115,5 +123,18 @@ fun RYOutlineTextField(
},
keyboardOptions = keyboardOptions,
keyboardActions = keyboardActions,
readOnly = onClick != null,
interactionSource = onClick?.let {
remember { MutableInteractionSource() }
.also { interactionSource ->
LaunchedEffect(interactionSource) {
interactionSource.interactions.collect {
if (it is PressInteraction.Release) {
onClick.invoke()
}
}
}
}
}
)
}
Loading
Loading