Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,19 @@ internal object PackStream {
@OptIn(ExperimentalContracts::class)
fun <T> ByteArray.unpack(block: Unpacker.() -> T): T {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return Unpacker(this).run(block)
return try {
Unpacker(this).run(block)
} catch (bufferException: java.nio.BufferUnderflowException) {
throw ServerError.ProtocolError.PackStreamParseError(
"ByteBuffer underflow",
bufferException
)
} catch (indexException: IndexOutOfBoundsException) {
throw ServerError.ProtocolError.PackStreamParseError(
"Array index out of bounds",
indexException
)
}
}

/** A [Structure](https://neo4j.com/docs/bolt/current/packstream/#data-type-structure). */
Expand Down Expand Up @@ -385,7 +397,10 @@ internal object PackStream {
buffer.put(STRUCT_16)
buffer.putShort(value.fields.size.toShort())
}
else -> error("Structure size '${value.fields.size}' is invalid")
else -> throw ServerError.ProtocolError.SerializationError(
"Structure",
IllegalArgumentException("Structure size '${value.fields.size}' exceeds maximum supported size")
)
}
buffer.put(value.id)
value.fields.forEach(::any)
Expand Down Expand Up @@ -429,7 +444,10 @@ internal object PackStream {
zonedDateTime.toInstant().epochSecond,
zonedDateTime.nano,
zone.totalSeconds)))
else -> error("ZonedDateTime '$zonedDateTime' is invalid")
else -> throw ServerError.ProtocolError.SerializationError(
"ZonedDateTime",
IllegalArgumentException("ZonedDateTime '$zonedDateTime' has unsupported zone type")
)
}

/**
Expand Down Expand Up @@ -493,7 +511,10 @@ internal object PackStream {
is LocalDateTime -> localDateTime(value)
is Duration -> duration(value)
is Structure -> structure(value)
else -> error("Value '$value' isn't packable")
else -> throw ServerError.ProtocolError.SerializationError(
value::class.simpleName ?: "Unknown",
IllegalArgumentException("Value of type '${value::class.simpleName}' is not supported for PackStream serialization")
)
}
}
}
Expand Down Expand Up @@ -720,8 +741,11 @@ internal object PackStream {
}
else -> this
}
} catch (_: Exception) {
error("Structure (${Char(id.toInt())}) '$this' is invalid")
} catch (conversionException: Exception) {
throw ServerError.ProtocolError.PackStreamParseError(
"Structure (${Char(id.toInt())}) conversion",
conversionException
)
}
}

Expand All @@ -740,7 +764,12 @@ internal object PackStream {
*/
fun ByteBuffer.getUInt32(): Int {
val uint32 = getInt().toUInt().toLong()
check(uint32 <= Int.MAX_VALUE) { "Size '$uint32' is too big" }
if (uint32 > Int.MAX_VALUE) {
throw ServerError.ProtocolError.PackStreamParseError(
"uint32 size",
IllegalArgumentException("Size '$uint32' exceeds maximum supported value")
)
}
return uint32.toInt()
}

Expand All @@ -752,8 +781,11 @@ internal object PackStream {
return bytes
}

/** Throw an [IllegalStateException] because the marker [Byte] is unexpected. */
fun Byte.unexpected(): Nothing = error("Unexpected marker '${toHex()}'")
/** Throw a [ServerError.ProtocolError] because the marker [Byte] is unexpected. */
fun Byte.unexpected(): Nothing = throw ServerError.ProtocolError.PackStreamParseError(
"marker byte",
IllegalArgumentException("Unexpected PackStream marker '${toHex()}'")
)
}
}

Expand Down
100 changes: 80 additions & 20 deletions graph-guard/src/main/kotlin/io/github/cfraser/graphguard/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,27 @@ constructor(
intercept { session, message ->
plugin
.runCatching { intercept(session, message) }
.onFailure { LOGGER.error("Failed to intercept '{}'", message, it) }
.onFailure { throwable ->
val pluginError = ServerError.PluginError.InterceptorFailure(
plugin::class.simpleName ?: "Unknown",
message::class.simpleName ?: "Unknown",
throwable
)
LOGGER.error("Plugin interceptor failed: {}", pluginError.message, pluginError)
}
.getOrDefault(message)
}
observe { event ->
plugin
.runCatching { observe(event) }
.onFailure { LOGGER.error("Failed to observe '{}'", event, it) }
.onFailure { throwable ->
val pluginError = ServerError.PluginError.ObserverFailure(
plugin::class.simpleName ?: "Unknown",
event::class.simpleName ?: "Unknown",
throwable
)
LOGGER.error("Plugin observer failed: {}", pluginError.message, pluginError)
}
}
}

Expand All @@ -120,7 +134,6 @@ constructor(
* The [Server] is ready to accept client connections after [Server.start] returns.
*/
@Synchronized
@Suppress("TooGenericExceptionCaught")
fun start() {
val latch = CountDownLatch(1)
check(running?.isActive?.not() ?: true) { "The proxy server is already running" }
Expand All @@ -135,7 +148,14 @@ constructor(
try {
run(latch)
} catch (thrown: Exception) {
LOGGER.error("Proxy server failed to run", thrown)
val serverError = when (thrown) {
is ServerError -> thrown
is java.net.BindException -> ServerError.ConnectionError.BindFailure(address.toString(), thrown)
is java.net.ConnectException -> ServerError.ConnectionError.GraphConnectionFailure(graph.toString(), thrown)
is java.net.SocketTimeoutException -> ServerError.ConnectionError.ConnectionTimeout(30000L)
else -> ServerError.ConnectionError.AcceptFailure(thrown)
}
LOGGER.error("Proxy server failed to start: {}", serverError.message, serverError)
}
}
latch.await()
Expand All @@ -154,7 +174,6 @@ constructor(
* [CancellationException] is **not** thrown after the server is stopped.
* > [CountDownLatch.countDown] when the [Server] is ready to accept client connections.
*/
@Suppress("TooGenericExceptionCaught")
private suspend fun run(latch: CountDownLatch) {
try {
bind { selector, serverSocket ->
Expand All @@ -175,7 +194,14 @@ constructor(
} catch (cancellation: CancellationException) {
LOGGER.debug("Proxy connection closed", cancellation)
} catch (exception: Exception) {
LOGGER.error("Proxy connection failure", exception)
val connectionError = when (exception) {
is ServerError -> exception
is java.net.ConnectException -> ServerError.ConnectionError.GraphConnectionFailure(graph.toString(), exception)
is java.net.SocketTimeoutException -> ServerError.ConnectionError.ConnectionTimeout(30000L)
is java.io.IOException -> ServerError.ConnectionError.ConnectionClosed(exception.message)
else -> ServerError.ConnectionError.AcceptFailure(exception)
}
LOGGER.error("Proxy connection failure: {}", connectionError.message, connectionError)
}
}
}
Expand All @@ -189,8 +215,11 @@ constructor(
withLoggingContext("graph-guard.server" to "$address", "graph-guard.graph" to "$graph") {
try {
SelectorManager(coroutineContext).use { selector ->
val socket =
aSocket(selector).tcp().bind(KInetSocketAddress(address.hostname, address.port))
val socket = try {
aSocket(selector).tcp().bind(KInetSocketAddress(address.hostname, address.port))
} catch (bindException: Exception) {
throw ServerError.ConnectionError.BindFailure(address.toString(), bindException)
}
LOGGER.info("Started proxy server on '{}'", socket.localAddress)
plugin.observe(Started)
socket.use { server -> block(selector, server) }
Expand Down Expand Up @@ -233,12 +262,24 @@ constructor(
selector: SelectorManager,
block: suspend CoroutineScope.(Connection, ByteReadChannel, ByteWriteChannel) -> Unit
) {
var socket = aSocket(selector).tcp().connect(KInetSocketAddress(graph.host, graph.port))
if ("+s" in graph.scheme)
socket =
socket.tls(coroutineContext = coroutineContext) {
trustManager = this@Server.trustManager
}
var socket = try {
aSocket(selector).tcp().connect(KInetSocketAddress(graph.host, graph.port))
} catch (connectException: Exception) {
throw ServerError.ConnectionError.GraphConnectionFailure(graph.toString(), connectException)
}

if ("+s" in graph.scheme) {
socket = try {
socket.tls(coroutineContext = coroutineContext) {
trustManager = this@Server.trustManager
}
} catch (tlsException: Exception) {
throw ServerError.ConfigurationError.TlsConfigurationError(
"Failed to establish TLS connection to graph", tlsException
)
}
}

val graphConnection = Connection.Graph(socket.remoteAddress.toInetSocketAddress())
try {
socket.withChannels { reader, writer ->
Expand All @@ -253,7 +294,7 @@ constructor(
}

/** Proxy a [Bolt.Session] between the *client* and *graph*. */
@Suppress("LongParameterList", "TooGenericExceptionCaught")
@Suppress("LongParameterList")
private suspend fun CoroutineScope.proxy(
clientConnection: Connection,
clientReader: ByteReadChannel,
Expand Down Expand Up @@ -313,7 +354,13 @@ constructor(
} catch (cancellation: CancellationException) {
LOGGER.debug("Proxy session closed", cancellation)
} catch (thrown: Exception) {
LOGGER.error("Proxy session failure", thrown)
val sessionError = when (thrown) {
is ServerError -> thrown
is java.io.IOException -> ServerError.ConnectionError.ConnectionClosed(thrown.message)
is java.nio.channels.ClosedChannelException -> ServerError.ConnectionError.ConnectionClosed("Channel closed")
else -> ServerError.ProtocolError.MalformedMessage("Session", thrown.message ?: "Unknown error")
}
LOGGER.error("Proxy session failure: {}", sessionError.message, sessionError)
}
}

Expand All @@ -323,7 +370,6 @@ constructor(
* [source] to the *resolved destination*.
* > Intercept [Bolt.Goodbye] and [cancel] the [CoroutineScope] to end the session.
*/
@Suppress("TooGenericExceptionCaught")
private fun CoroutineScope.proxy(
session: Bolt.Session,
source: Connection,
Expand All @@ -334,16 +380,30 @@ constructor(
val message =
try {
reader.readMessage()
} catch (_: Exception) {
} catch (readException: Exception) {
val protocolError = when (readException) {
is ServerError.ProtocolError -> readException
is java.io.EOFException -> ServerError.ConnectionError.ConnectionClosed("End of stream")
is java.nio.channels.ClosedChannelException -> ServerError.ConnectionError.ConnectionClosed("Channel closed")
is kotlinx.coroutines.TimeoutCancellationException -> ServerError.ConnectionError.ConnectionTimeout(30000L)
else -> ServerError.ProtocolError.PackStreamParseError("message read", readException)
}
LOGGER.debug("Failed to read message from {}: {}", source, protocolError.message)
break
}
LOGGER.debug("Read '{}' from {}", message, source)
val intercepted = plugin.intercept(session, message)
val (destination, writer) = resolver(intercepted)
try {
writer.writeMessage(intercepted)
} catch (thrown: Exception) {
LOGGER.error("Failed to write '{}' to {}", intercepted, destination, thrown)
} catch (writeException: Exception) {
val writeError = when (writeException) {
is ServerError -> writeException
is java.io.IOException -> ServerError.ConnectionError.ConnectionClosed(writeException.message)
is java.nio.channels.ClosedChannelException -> ServerError.ConnectionError.ConnectionClosed("Channel closed")
else -> ServerError.ProtocolError.SerializationError(intercepted::class.simpleName ?: "Unknown", writeException)
}
LOGGER.error("Failed to write '{}' to {}: {}", intercepted, destination, writeError.message, writeError)
break
}
LOGGER.debug("Wrote '{}' to {}", intercepted, destination)
Expand Down
Loading