diff --git a/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/Session.kt b/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/Session.kt index 7e70812..37aa0d7 100644 --- a/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/Session.kt +++ b/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/Session.kt @@ -1,14 +1,43 @@ // Copyright (c) 2024 Erwin Kok. BSD-3-Clause license. See LICENSE file for more details. package org.erwinkok.libp2p.muxer.yamux +import io.ktor.utils.io.cancel +import io.ktor.utils.io.close +import io.ktor.utils.io.core.BytePacketBuilder +import io.ktor.utils.io.core.internal.ChunkBuffer +import io.ktor.utils.io.pool.ObjectPool +import kotlinx.atomicfu.locks.ReentrantLock +import kotlinx.atomicfu.locks.withLock +import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.launch +import kotlinx.coroutines.selects.select +import kotlinx.coroutines.withTimeoutOrNull import mu.KotlinLogging import org.erwinkok.libp2p.core.base.AwaitableClosable import org.erwinkok.libp2p.core.network.Connection +import org.erwinkok.libp2p.core.network.streammuxer.MuxedStream +import org.erwinkok.libp2p.core.util.SafeChannel +import org.erwinkok.libp2p.muxer.yamux.frame.CloseFrame +import org.erwinkok.libp2p.muxer.yamux.frame.Frame +import org.erwinkok.libp2p.muxer.yamux.frame.MessageFrame +import org.erwinkok.libp2p.muxer.yamux.frame.NewStreamFrame +import org.erwinkok.libp2p.muxer.yamux.frame.ResetFrame +import org.erwinkok.libp2p.muxer.yamux.frame.readMplexFrame +import org.erwinkok.libp2p.muxer.yamux.frame.writeMplexFrame +import org.erwinkok.result.Err +import org.erwinkok.result.Error import org.erwinkok.result.Ok import org.erwinkok.result.Result +import org.erwinkok.result.errorMessage +import org.erwinkok.result.map +import org.erwinkok.result.onFailure +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicLong +import kotlin.coroutines.cancellation.CancellationException +import kotlin.time.Duration.Companion.seconds private val logger = KotlinLogging.logger {} @@ -95,15 +124,229 @@ class Session( } } + + private val streamChannel = SafeChannel(16) + private val outputChannel = SafeChannel(16) + private val mutex = ReentrantLock() + private val streams = mutableMapOf() + private val nextId = AtomicLong(0) + private val isClosing = AtomicBoolean(false) + private var closeCause: Error? = null + private val receiverJob: Job + private val pool: ObjectPool get() = connection.pool + + init { + receiverJob = processInbound() + processOutbound() + } + + suspend fun openStream(name: String?): Result { + return newNamedStream(name) + } + + suspend fun acceptStream(): Result { + if (streamChannel.isClosedForReceive) { + return Err(closeCause ?: ErrShutdown) + } + return try { + select { + streamChannel.onReceive { + Ok(it) + } + receiverJob.onJoin { + Err(closeCause ?: ErrShutdown) + } + } + } catch (e: Exception) { + Err(ErrShutdown) + } + } + override fun close() { + isClosing.set(true) + streamChannel.close() + receiverJob.cancel() + if (streams.isEmpty()) { + outputChannel.close() + } + _context.complete() + } + + internal fun removeStream(streamId: MplexStreamId) { + mutex.withLock { + streams.remove(streamId) + if (isClosing.get() && streams.isEmpty()) { + outputChannel.close() + } + } + } + + private fun processInbound() = scope.launch(_context + CoroutineName("mplex-stream-input-loop")) { + while (!connection.input.isClosedForRead && !streamChannel.isClosedForSend) { + try { + connection.input.readMplexFrame() + .map { mplexFrame -> processFrame(mplexFrame) } + .onFailure { + closeCause = it + return@launch + } + } catch (e: CancellationException) { + break + } catch (e: Exception) { + logger.warn { "Unexpected error occurred in mplex multiplexer input loop: ${errorMessage(e)}" } + throw e + } + } + }.apply { + invokeOnCompletion { + connection.input.cancel() + streamChannel.cancel() + // do not cancel the input of the streams here, there might still be some pending frames in the input queue. + // instead, close the input loop gracefully. + streams.forEach { it.value.remoteClosesWriting() } + } + } + + private fun processOutbound() = scope.launch(_context + CoroutineName("mplex-stream-output-loop")) { + while (!outputChannel.isClosedForReceive && !connection.output.isClosedForWrite) { + try { + val frame = outputChannel.receive() + connection.output.writeMplexFrame(frame) + connection.output.flush() + } catch (e: ClosedReceiveChannelException) { + break + } catch (e: CancellationException) { + break + } catch (e: Exception) { + logger.warn { "Unexpected error occurred in mplex mux input loop: ${errorMessage(e)}" } + throw e + } + } + }.apply { + invokeOnCompletion { + connection.output.close() + outputChannel.close() + // It is safe here to close the output of all streams, closing will still process pending requests. + streams.forEach { it.value.output.close() } + } + } + + private suspend fun processFrame(mplexFrame: Frame) { + val id = mplexFrame.id + val initiator = mplexFrame.initiator + val mplexStreamId = MplexStreamId(!initiator, id) + mutex.lock() + val stream: YamuxMuxedStream? = streams[mplexStreamId] + when (mplexFrame) { + is NewStreamFrame -> { + if (stream != null) { + mutex.unlock() + logger.warn { "$this: Remote creates existing new stream: $id. Ignoring." } + } else { + logger.debug { "$this: Remote creates new stream: $id" } + val name = streamName(mplexFrame.name, mplexStreamId) + val newStream = YamuxMuxedStream(scope, this, outputChannel, mplexStreamId, name, pool) + streams[mplexStreamId] = newStream + mutex.unlock() + streamChannel.send(newStream) + } + } + + is MessageFrame -> { + if (logger.isDebugEnabled) { + if (initiator) { + logger.debug("$this: Remote sends message on his stream: $id") + } else { + logger.debug("$this: Remote sends message on our stream: $id") + } + } + if (stream != null) { + mutex.unlock() + val builder = BytePacketBuilder(pool) + val data = mplexFrame.packet + builder.writePacket(data.copy()) + // There is (almost) no backpressure. If the reader is slow/blocking, then the entire muxer is blocking. + // Give the reader "ReceiveTimeout" time to process, reset stream if too slow. + val timeout = withTimeoutOrNull(ReceivePushTimeout) { + stream.remoteSendsNewMessage(builder.build()) + } + if (timeout == null) { + logger.warn { "$this: Reader timeout for stream: $mplexStreamId. Reader is too slow, resetting the stream." } + stream.reset() + } + } else { + mutex.unlock() + logger.warn { "$this: Remote sends message on non-existing stream: $mplexStreamId" } + } + } + is CloseFrame -> { + if (logger.isDebugEnabled) { + if (initiator) { + logger.debug("$this: Remote closes his stream: $mplexStreamId") + } else { + logger.debug("$this: Remote closes our stream: $mplexStreamId") + } + } + if (stream != null) { + mutex.unlock() + stream.remoteClosesWriting() + } else { + mutex.unlock() + logger.debug { "$this: Remote closes non-existing stream: $mplexStreamId" } + } + } + + is ResetFrame -> { + if (logger.isDebugEnabled) { + if (initiator) { + logger.debug("$this: Remote resets his stream: $id") + } else { + logger.debug("$this: Remote resets our stream: $id") + } + } + if (stream != null) { + mutex.unlock() + stream.remoteResetsStream() + } else { + mutex.unlock() + logger.debug { "$this: Remote resets non-existing stream: $id" } + } + } + } + mplexFrame.close() + } + + private suspend fun newNamedStream(newName: String?): Result { + if (outputChannel.isClosedForSend) { + return Err("$this: Mplex is closed") + } + mutex.lock() + val id = nextId.getAndIncrement() + val streamId = MplexStreamId(true, id) + logger.debug { "$this: We create stream: $id" } + val name = streamName(newName, streamId) + val muxedStream = YamuxMuxedStream(scope, this, outputChannel, streamId, name, pool) + streams[streamId] = muxedStream + mutex.unlock() + outputChannel.send(NewStreamFrame(id, name)) + return Ok(muxedStream) + } + + private fun streamName(name: String?, streamId: MplexStreamId): String { + if (name != null) { + return name + } + return String.format("stream%08x", streamId.id) } - fun openStream(name: String?): Result { - TODO("Not yet implemented") + override fun toString(): String { + val initiator = if (client) "client" else "server" + return "yamux-muxer<$initiator>" } - fun acceptStream(): Result { - TODO("Not yet implemented") + companion object { + private val ErrShutdown = Error("session shut down") + private val ReceivePushTimeout = 5.seconds } } diff --git a/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMuxedStream.kt b/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMuxedStream.kt index ca9e664..b1c5f04 100644 --- a/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMuxedStream.kt +++ b/libp2p-muxer-yamux/src/main/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMuxedStream.kt @@ -1,12 +1,40 @@ // Copyright (c) 2024 Erwin Kok. BSD-3-Clause license. See LICENSE file for more details. package org.erwinkok.libp2p.muxer.yamux +import io.ktor.network.util.DefaultByteBufferPool +import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.ReaderJob +import io.ktor.utils.io.WriterJob +import io.ktor.utils.io.cancel +import io.ktor.utils.io.close +import io.ktor.utils.io.core.ByteReadPacket import io.ktor.utils.io.core.internal.ChunkBuffer +import io.ktor.utils.io.core.writeFully import io.ktor.utils.io.pool.ObjectPool +import io.ktor.utils.io.pool.useInstance +import io.ktor.utils.io.reader +import io.ktor.utils.io.writer +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ClosedSendChannelException +import kotlinx.coroutines.channels.consumeEach +import mu.KotlinLogging +import org.erwinkok.libp2p.core.network.StreamResetException import org.erwinkok.libp2p.core.network.streammuxer.MuxedStream +import org.erwinkok.libp2p.core.util.SafeChannel +import org.erwinkok.libp2p.core.util.buildPacket +import org.erwinkok.libp2p.muxer.yamux.frame.CloseFrame +import org.erwinkok.libp2p.muxer.yamux.frame.Frame +import org.erwinkok.libp2p.muxer.yamux.frame.MessageFrame +import org.erwinkok.libp2p.muxer.yamux.frame.ResetFrame +import org.erwinkok.result.errorMessage +import java.io.IOException +import java.nio.ByteBuffer enum class StreamState { StreamInit, @@ -22,10 +50,16 @@ enum class HalfStreamState { HalfReset, } +private val logger = KotlinLogging.logger {} + class YamuxMuxedStream( + private val scope: CoroutineScope, private val session: Session, - -): MuxedStream { + private val outputChannel: Channel, + private val mplexStreamId: MplexStreamId, + override val name: String, + override val pool: ObjectPool +) : MuxedStream { // sendWindow uint32 // // memorySpan MemoryManager @@ -47,26 +81,127 @@ class YamuxMuxedStream( // // readDeadline, writeDeadline pipeDeadline - override val name: String - get() = TODO("Not yet implemented") - override val id: String - get() = TODO("Not yet implemented") + private val inputChannel = SafeChannel(16) + private val _context = Job(scope.coroutineContext[Job]) + private val writerJob: WriterJob + private val readerJob: ReaderJob - override fun reset() { - TODO("Not yet implemented") + override val id + get() = mplexStreamId.toString() + override val jobContext: Job + get() = _context + + override val input: ByteReadChannel = ByteChannel(false).also { writerJob = attachForReading(it) } + override val output: ByteWriteChannel = ByteChannel(false).also { readerJob = attachForWriting(it) } + + private fun attachForReading(channel: ByteChannel): WriterJob = + scope.writer(_context + CoroutineName("mplex-stream-input-loop"), channel) { + inputDataLoop(this.channel) + }.apply { + invokeOnCompletion { + if (readerJob.isCompleted) { + session.removeStream(mplexStreamId) + } + } + } + + private fun attachForWriting(channel: ByteChannel): ReaderJob = + scope.reader(_context + CoroutineName("mplex-stream-output-loop"), channel) { + outputDataLoop(this.channel) + }.apply { + invokeOnCompletion { + if (writerJob.isCompleted) { + session.removeStream(mplexStreamId) + } + } + } + + private suspend fun inputDataLoop(channel: ByteWriteChannel) { + while (!inputChannel.isClosedForReceive && !channel.isClosedForWrite) { + try { + inputChannel.consumeEach { + channel.writePacket(it) + channel.flush() + } + } catch (e: CancellationException) { + break + } catch (e: Exception) { + logger.warn { "Unexpected error occurred in mplex mux input loop: ${errorMessage(e)}" } + throw e + } + } + if (!inputChannel.isClosedForReceive) { + inputChannel.cancel() + } } - override val pool: ObjectPool - get() = TODO("Not yet implemented") - override val input: ByteReadChannel - get() = TODO("Not yet implemented") - override val output: ByteWriteChannel - get() = TODO("Not yet implemented") + private suspend fun outputDataLoop(channel: ByteReadChannel): Unit = DefaultByteBufferPool.useInstance { buffer: ByteBuffer -> + while (!channel.isClosedForRead && !outputChannel.isClosedForSend) { + buffer.clear() + try { + val size = channel.readAvailable(buffer) + if (size > 0) { + buffer.flip() + val packet = buildPacket(pool) { writeFully(buffer) } + val messageFrame = MessageFrame(mplexStreamId, packet) + outputChannel.send(messageFrame) + } + } catch (e: CancellationException) { + break + } catch (e: ClosedSendChannelException) { + break + } catch (e: Exception) { + logger.warn { "Unexpected error occurred in mplex mux output loop: ${errorMessage(e)}" } + throw e + } + } + if (!channel.isClosedForRead) { + channel.cancel(IOException("Failed writing to closed connection")) + } + if (!outputChannel.isClosedForSend) { + if (channel.closedCause is StreamResetException) { + outputChannel.send(ResetFrame(mplexStreamId)) + } else { + outputChannel.send(CloseFrame(mplexStreamId)) + } + } + } + + override fun reset() { + inputChannel.cancel() + input.cancel(StreamResetException()) + output.close(StreamResetException()) + _context.complete() + } override fun close() { - TODO("Not yet implemented") + inputChannel.cancel() + input.cancel() + output.close() + _context.complete() } - override val jobContext: Job - get() = TODO("Not yet implemented") + override fun toString(): String { + return "mplex-<$mplexStreamId>" + } + + internal suspend fun remoteSendsNewMessage(packet: ByteReadPacket): Boolean { + if (inputChannel.isClosedForSend) { + packet.close() + return false + } + inputChannel.send(packet) + return true + } + + internal fun remoteClosesWriting() { + inputChannel.close() + } + + internal fun remoteResetsStream() { + inputChannel.cancel() + input.cancel(StreamResetException()) + output.close(StreamResetException()) + _context.completeExceptionally(StreamResetException()) + } } diff --git a/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMultiplexerTest.kt b/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMultiplexerTest.kt new file mode 100644 index 0000000..7f7e0b0 --- /dev/null +++ b/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMultiplexerTest.kt @@ -0,0 +1,478 @@ +// Copyright (c) 2023 Erwin Kok. BSD-3-Clause license. See LICENSE file for more details. +package org.erwinkok.libp2p.muxer.yamux + +import io.ktor.utils.io.close +import io.ktor.utils.io.core.readBytes +import io.ktor.utils.io.core.toByteArray +import io.ktor.utils.io.core.writeFully +import io.ktor.utils.io.readFully +import io.ktor.utils.io.writeFully +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.delay +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import org.erwinkok.libp2p.core.network.Connection +import org.erwinkok.libp2p.core.network.StreamResetException +import org.erwinkok.libp2p.core.network.streammuxer.MuxedStream +import org.erwinkok.libp2p.core.util.buildPacket +import org.erwinkok.libp2p.muxer.yamux.frame.CloseFrame +import org.erwinkok.libp2p.muxer.yamux.frame.MessageFrame +import org.erwinkok.libp2p.muxer.yamux.frame.NewStreamFrame +import org.erwinkok.libp2p.muxer.yamux.frame.readMplexFrame +import org.erwinkok.libp2p.muxer.yamux.frame.writeMplexFrame +import org.erwinkok.libp2p.testing.TestConnection +import org.erwinkok.libp2p.testing.TestWithLeakCheck +import org.erwinkok.libp2p.testing.VerifyingChunkBufferPool +import org.erwinkok.result.coAssertErrorResult +import org.erwinkok.result.expectNoErrors +import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertInstanceOf +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.experimental.xor +import kotlin.random.Random +import kotlin.time.Duration.Companion.minutes + +internal class YamuxMultiplexerTest : TestWithLeakCheck { + override val pool = VerifyingChunkBufferPool() + + private val maxStreamId = 0x1000000000000000L + + @Test + fun remoteRequestsNewStream() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val id = randomId() + connectionPair.remote.output.writeMplexFrame(NewStreamFrame(id, "aName$id")) + connectionPair.remote.output.flush() + val muxedStream = mplexMultiplexer.acceptStream().expectNoErrors() + assertEquals("aName$id", muxedStream.name) + assertStreamHasId(false, id, muxedStream) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun localRequestNewStream() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors() + assertEquals("newStreamName$it", muxedStream.name) + assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id) + val actual = connectionPair.remote.input.readMplexFrame().expectNoErrors() + assertInstanceOf(NewStreamFrame::class.java, actual) + assertTrue(actual.initiator) + assertEquals(it.toLong(), actual.id) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun remoteOpensAndRemoteSends() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val id = randomId() + connectionPair.remote.output.writeMplexFrame(NewStreamFrame(id, "aName$id")) + connectionPair.remote.output.flush() + val muxedStream = mplexMultiplexer.acceptStream().expectNoErrors() + assertEquals("aName$id", muxedStream.name) + assertStreamHasId(false, id, muxedStream) + val random1 = Random.nextBytes(1000) + connectionPair.remote.output.writeMplexFrame(MessageFrame(MplexStreamId(true, id), buildPacket(pool) { writeFully(random1) })) + connectionPair.remote.output.flush() + assertFalse(muxedStream.input.isClosedForRead) + val random2 = ByteArray(random1.size) + muxedStream.input.readFully(random2) + assertArrayEquals(random1, random2) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun remoteOpensAndLocalSends() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val id = randomId() + connectionPair.remote.output.writeMplexFrame(NewStreamFrame(id, "aName$id")) + connectionPair.remote.output.flush() + val muxedStream = mplexMultiplexer.acceptStream().expectNoErrors() + assertEquals("aName$id", muxedStream.name) + assertStreamHasId(false, id, muxedStream) + val random1 = Random.nextBytes(1000) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.output.writeFully(random1) + muxedStream.output.flush() + assertMessageFrameReceived(random1, connectionPair.remote) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun localOpenAndLocalSends() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors() + assertEquals("newStreamName$it", muxedStream.name) + assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id) + assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote) + val random1 = Random.nextBytes(1000) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.output.writeFully(random1) + muxedStream.output.flush() + assertMessageFrameReceived(random1, connectionPair.remote) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun localOpensAndRemoteSends() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors() + assertEquals("newStreamName$it", muxedStream.name) + assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id) + assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote) + val random1 = Random.nextBytes(1000) + connectionPair.remote.output.writeMplexFrame(MessageFrame(MplexStreamId(false, it.toLong()), buildPacket(pool) { writeFully(random1) })) + connectionPair.remote.output.flush() + assertFalse(muxedStream.input.isClosedForRead) + val random2 = ByteArray(random1.size) + muxedStream.input.readFully(random2) + assertArrayEquals(random1, random2) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun remoteRequestsNewStreamAndCloses() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + val id = randomId() + connectionPair.remote.output.writeMplexFrame(NewStreamFrame(id, "aName$id")) + connectionPair.remote.output.flush() + val muxedStream = mplexMultiplexer.acceptStream().expectNoErrors() + assertEquals("aName$id", muxedStream.name) + assertStreamHasId(false, id, muxedStream) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + connectionPair.remote.output.writeMplexFrame(CloseFrame(MplexStreamId(true, id))) + connectionPair.remote.output.flush() + val exception = assertThrows { + muxedStream.input.readPacket(10) + } + assertEquals("Unexpected EOF: expected 10 more bytes", exception.message) + assertTrue(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun remoteRequestsNewStreamAndLocalCloses() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val id = randomId() + connectionPair.remote.output.writeMplexFrame(NewStreamFrame(id, "aName$id")) + connectionPair.remote.output.flush() + val muxedStream = mplexMultiplexer.acceptStream().expectNoErrors() + assertEquals("aName$id", muxedStream.name) + assertStreamHasId(false, id, muxedStream) + muxedStream.output.close() + yield() + assertCloseFrameReceived(connectionPair.remote) + assertFalse(muxedStream.input.isClosedForRead) + assertTrue(muxedStream.output.isClosedForWrite) + val exception = assertThrows { + muxedStream.output.writeFully(Random.nextBytes(1000)) + } + assertEquals("The channel was closed", exception.message) + muxedStream.close() + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun localRequestNewStreamAndCloses() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors() + assertEquals("newStreamName$it", muxedStream.name) + assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id) + assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote) + muxedStream.output.close() + yield() + assertCloseFrameReceived(connectionPair.remote) + assertFalse(muxedStream.input.isClosedForRead) + assertTrue(muxedStream.output.isClosedForWrite) + val exception = assertThrows { + muxedStream.output.writeFully(Random.nextBytes(1000)) + } + assertEquals("The channel was closed", exception.message) + muxedStream.close() + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun localRequestNewStreamAndRemoteCloses() = runTest { + val connectionPair = TestConnection(pool) + val mplexMultiplexer = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + repeat(1000) { + val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors() + assertEquals("newStreamName$it", muxedStream.name) + assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id) + assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + connectionPair.remote.output.writeMplexFrame(CloseFrame(MplexStreamId(false, it.toLong()))) + connectionPair.remote.output.flush() + val exception = assertThrows { + muxedStream.input.readPacket(10) + } + assertEquals("Unexpected EOF: expected 10 more bytes", exception.message) + assertTrue(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.close() + assertCloseFrameReceived(connectionPair.remote) + } + mplexMultiplexer.close() + mplexMultiplexer.awaitClosed() + } + + @Test + fun basicStreams() = runTest { + val connectionPair = TestConnection(pool) + val muxa = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.local, true)) + val muxb = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), connectionPair.remote, false)) + repeat(100) { + val random1 = Random.nextBytes(40960) + val job = launch { + val sb = muxb.acceptStream().expectNoErrors() + sb.output.writeFully(random1) + sb.output.flush() + sb.close() + sb.awaitClosed() + } + val sa = muxa.openStream().expectNoErrors() + val random2 = ByteArray(random1.size) + sa.input.readFully(random2) + assertArrayEquals(random1, random2) + job.join() + sa.close() + } + muxa.close() + muxb.close() + } + + @Test + fun echo() = runTest { + val pipe = TestConnection(pool) + val muxa = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.local, true)) + val muxb = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.remote, false)) + repeat(100) { + val message = Random.nextBytes(40960) + val job = launch { + val sb = muxb.acceptStream().expectNoErrors() + val buf = ByteArray(message.size) + sb.input.readFully(buf) + sb.output.writeFully(buf) + sb.output.flush() + sb.close() + } + val sa = muxa.openStream().expectNoErrors() + sa.output.writeFully(message) + sa.output.flush() + val buf = ByteArray(message.size) + sa.input.readFully(buf) + assertArrayEquals(message, buf) + job.join() + sa.close() + } + muxa.close() + muxb.close() + } + + @Test + fun stress() = runTest(timeout = 1.minutes) { + val pipe = TestConnection(pool) + val muxa = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.local, true)) + val muxb = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.remote, false)) + val messageSize = 40960 + repeat(1000) { + val jobs = mutableListOf() + repeat(10) { + jobs.add( + launch { + delay(Random.nextLong(1000)) + val sb = muxb.acceptStream().expectNoErrors() + val buf = ByteArray(messageSize) + sb.input.readFully(buf) + for (i in buf.indices) { + buf[i] = buf[i] xor 123 + } + sb.output.writeFully(buf) + sb.output.flush() + sb.close() + }, + ) + } + repeat(10) { + jobs.add( + launch { + val message = Random.nextBytes(messageSize) + val sa = muxa.openStream().expectNoErrors() + sa.output.writeFully(message) + sa.output.flush() + val buf = ByteArray(messageSize) + sa.input.readFully(buf) + for (i in buf.indices) { + buf[i] = buf[i] xor 123 + } + assertArrayEquals(message, buf) + sa.close() + }, + ) + } + jobs.joinAll() + } + muxa.close() + muxb.close() + } + + @Test + fun writeAfterClose() = runTest { + val pipe = TestConnection(pool) + val muxa = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.local, true)) + val muxb = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.remote, false)) + val message = "Hello world".toByteArray() + launch { + val sb = muxb.acceptStream().expectNoErrors() + sb.output.writeFully(message) + sb.output.flush() + sb.close() + sb.output.writeFully(message) + sb.output.flush() + } + val sa = muxa.openStream().expectNoErrors() + assertFalse(sa.input.isClosedForRead) + val buf = ByteArray(message.size) + sa.input.readFully(buf) + assertArrayEquals(message, buf) + assertTrue(sa.input.isClosedForRead) + val exception1 = assertThrows { + sa.input.readFully(buf) + } + assertEquals("Unexpected EOF: expected 11 more bytes", exception1.message) + sa.close() + muxa.close() + muxb.close() + } + + @Test + fun slowReader() = runTest { + val pipe = TestConnection(pool) + val muxa = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.local, true)) + val muxb = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.remote, false)) + val message = "Hello world".toByteArray() + val sa = muxa.openStream().expectNoErrors() + val exception = assertThrows { + for (i in 0..10000) { + sa.output.writeFully(message) + sa.output.flush() + yield() + } + } + assertEquals("Stream was reset", exception.message) + muxa.close() + muxb.close() + } + + @Test + fun acceptingStreamWhileClosing() = runTest { + val pipe = TestConnection(pool) + val mux = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.local, true)) + val job = launch { + coAssertErrorResult("session shut down") { mux.acceptStream() } + } + mux.close() + job.join() + } + + @Test + fun acceptingStreamAfterClose() = runTest { + val pipe = TestConnection(pool) + val mux = YamuxStreamMuxerConnection(Session(this, YamuxConfig(), pipe.local, true)) + mux.close() + coAssertErrorResult("session shut down") { mux.acceptStream() } + } + + private fun randomId(): Long { + return Random.nextLong(maxStreamId) + } + + private fun assertStreamHasId(initiator: Boolean, id: Long, muxedStream: MuxedStream) { + assertEquals(MplexStreamId(initiator, id).toString(), muxedStream.id) + } + + private suspend fun assertMessageFrameReceived(expected: ByteArray, connection: Connection) { + val frame = connection.input.readMplexFrame().expectNoErrors() + if (frame is MessageFrame) { + assertArrayEquals(expected, frame.packet.readBytes()) + } else { + assertFalse(true, "MessageFrame expected") + } + } + + private suspend fun assertNewStreamFrameReceived(id: Int, name: String, connection: Connection) { + val frame = connection.input.readMplexFrame().expectNoErrors() + if (frame is NewStreamFrame) { + assertTrue(frame.initiator) + assertEquals(id.toLong(), frame.id) + assertEquals(name, frame.name) + } else { + assertFalse(true, "NewStreamFrame expected") + } + } + + private suspend fun assertCloseFrameReceived(connection: Connection) { + val frame = connection.input.readMplexFrame().expectNoErrors() + assertTrue(frame is CloseFrame) + } +} diff --git a/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMuxedStreamTest.kt b/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMuxedStreamTest.kt new file mode 100644 index 0000000..2892658 --- /dev/null +++ b/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxMuxedStreamTest.kt @@ -0,0 +1,451 @@ +// Copyright (c) 2023 Erwin Kok. BSD-3-Clause license. See LICENSE file for more details. +package org.erwinkok.libp2p.muxer.yamux + +import io.ktor.utils.io.CancellationException +import io.ktor.utils.io.cancel +import io.ktor.utils.io.close +import io.ktor.utils.io.core.BytePacketBuilder +import io.ktor.utils.io.core.readBytes +import io.ktor.utils.io.core.writeFully +import io.ktor.utils.io.readFully +import io.ktor.utils.io.writeFully +import io.mockk.Runs +import io.mockk.coVerify +import io.mockk.every +import io.mockk.just +import io.mockk.mockk +import io.mockk.slot +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.channels.consumeEach +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.coroutines.yield +import org.erwinkok.libp2p.core.network.StreamResetException +import org.erwinkok.libp2p.core.util.SafeChannel +import org.erwinkok.libp2p.core.util.buildPacket +import org.erwinkok.libp2p.muxer.yamux.frame.CloseFrame +import org.erwinkok.libp2p.muxer.yamux.frame.Frame +import org.erwinkok.libp2p.muxer.yamux.frame.MessageFrame +import org.erwinkok.libp2p.muxer.yamux.frame.ResetFrame +import org.erwinkok.libp2p.testing.TestWithLeakCheck +import org.erwinkok.libp2p.testing.VerifyingChunkBufferPool +import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.random.Random +import kotlin.time.Duration.Companion.seconds + +internal class YamuxMuxedStreamTest : TestWithLeakCheck { + override val pool = VerifyingChunkBufferPool() + + private val mplexStreamId = MplexStreamId(true, 1234) + private val mplexStreamName = "AName" + private val session = mockk() + private val streamIdSlot = slot() + + @BeforeEach + fun setup() { + every { session.removeStream(capture(streamIdSlot)) } just Runs + } + + @Test + fun testIdAndName() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + assertEquals("stream000004d2/initiator", muxedStream.id) + assertEquals(mplexStreamName, muxedStream.name) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + + @Test + fun testInitiallyNothingAvailableForRead() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + assertEquals(0, muxedStream.input.availableForRead) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + + @Test + fun testReadPacket() = runTest { + repeat(1000) { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + val random = Random.nextBytes(100000) + assertTrue(muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(random) })) + val bytes = ByteArray(random.size) + muxedStream.input.readFully(bytes) + assertArrayEquals(random, bytes) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + } + + @Test + fun testReadPacketSplit() = runTest { + repeat(1000) { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + val random = Random.nextBytes(50000) + assertTrue(muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(random) })) + for (j in 0 until 5) { + val bytes = ByteArray(10000) + muxedStream.input.readFully(bytes) + assertArrayEquals(random.copyOfRange(j * bytes.size, (j + 1) * bytes.size), bytes) + } + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + } + + @Test + fun testReadPacketCombined() = runTest { + repeat(1000) { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + val random = Random.nextBytes(50000) + for (j in 0 until 5) { + val bytes = random.copyOfRange(j * 10000, (j + 1) * 10000) + assertTrue(muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(bytes) })) + } + val bytes = ByteArray(random.size) + muxedStream.input.readFully(bytes) + assertArrayEquals(random, bytes) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + } + + @Test + fun testReadPacketWait() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + val result = withTimeoutOrNull(500) { + muxedStream.input.readPacket(10) + } + assertNull(result) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + + @Test + fun testReadPacketAfterCancel() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + muxedStream.input.cancel() + yield() // Give the input coroutine a chance to cancel + val exception1 = assertThrows { + muxedStream.input.readPacket(10) + } + assertEquals("Channel has been cancelled", exception1.message) + // Remote can not send messages + assertFalse(muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(Random.nextBytes(100000)) })) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + + @Test + fun testReadPacketAfterClose() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.close() + muxedStream.awaitClosed() + assertTrue(muxedStream.input.isClosedForRead) + assertTrue(muxedStream.output.isClosedForWrite) + assertStreamRemoved() + val exception1 = assertThrows { + muxedStream.input.readPacket(123) + } + assertEquals("Channel has been cancelled", exception1.message) + reader.stop() + reader.assertNoBytesReceived() + reader.assertCloseFrameReceived(mplexStreamId) + } + + @Test + fun testReadPacketAfterRemoteCloses() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.remoteClosesWriting() + yield() + assertTrue(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + reader.assertNoCloseFrameReceived() + assertStreamNotRemoved() + val exception1 = assertThrows { + muxedStream.input.readPacket(123) + } + assertEquals("Unexpected EOF: expected 123 more bytes", exception1.message) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + reader.assertCloseFrameReceived(mplexStreamId) + } + + @Test + fun testReadPacketAfterRemoteClosesDataInBuffer() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + val random = Random.nextBytes(50000) + assertTrue(muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(random) })) + muxedStream.remoteClosesWriting() + yield() + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + reader.assertNoCloseFrameReceived() + assertStreamNotRemoved() + val bytes = ByteArray(random.size) + muxedStream.input.readFully(bytes) + assertArrayEquals(random, bytes) + assertTrue(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + val exception1 = assertThrows { + muxedStream.input.readPacket(321) + } + assertEquals("Unexpected EOF: expected 321 more bytes", exception1.message) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + reader.assertCloseFrameReceived(mplexStreamId) + } + + @Test + fun testNotReading() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + // It seems that the maximum of the input ByteReadChannel is 4088 bytes. So we have to provide enough data + // to fill the input channel (~5 * 1000 bytes) and we also have to fill up the inputChannel with 16 packets. + // So we have to provide 5 + 16 = 21 packets. + for (i in 0 until 21) { + muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(Random.nextBytes(1000)) }) + yield() // Give the input coroutine a chance to process the packets + } + assertTrue(muxedStream.input.availableForRead > 0) + val timeout = withTimeoutOrNull(2.seconds) { + muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(Random.nextBytes(1000)) }) + } + assertNull(timeout) + muxedStream.close() // Causes all packets in the input channel to be closed + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + } + + @Test + fun testReadPacketAfterReset() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + val random = Random.nextBytes(50000) + muxedStream.remoteSendsNewMessage(buildPacket { writeFully(random) }) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.reset() + muxedStream.awaitClosed() + assertTrue(muxedStream.input.isClosedForRead) + assertTrue(muxedStream.output.isClosedForWrite) + reader.assertResetFrameReceived(mplexStreamId) + assertStreamRemoved() + val exception2 = assertThrows { + muxedStream.input.readPacket(random.size) + } + assertEquals("Stream was reset", exception2.message) + reader.stop() + reader.assertNoBytesReceived() + } + + // + // Write + // + + @Test + fun testWritePacket() = runTest { + repeat(1000) { + val reader = FrameReader(this, pool) + val random = Random.nextBytes(10000) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + muxedStream.output.writeFully(random) + muxedStream.output.flush() + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertBytesReceived(random) + } + } + + @Test + fun testWritePacketSplit() = runTest { + repeat(1000) { + val reader = FrameReader(this, pool) + val random = Random.nextBytes(10000) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + muxedStream.output.writeFully(random, 0, 5000) + muxedStream.output.writeFully(random, 5000, 5000) + muxedStream.output.flush() + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertBytesReceived(random) + } + } + + @Test + fun testWritePacketAfterChannelClose() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + muxedStream.output.close() + yield() // Give the input coroutine a chance to cancel + val exception1 = assertThrows { + muxedStream.output.writeFully(Random.nextBytes(100000)) + muxedStream.output.flush() + } + assertEquals("The channel was closed", exception1.message) + // Remote can send messages + assertTrue(muxedStream.remoteSendsNewMessage(buildPacket(pool) { writeFully(Random.nextBytes(1000)) })) + muxedStream.close() + muxedStream.awaitClosed() + reader.stop() + reader.assertNoBytesReceived() + reader.assertCloseFrameReceived(mplexStreamId) + } + + @Test + fun testWritePacketAfterClose() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.close() + muxedStream.awaitClosed() + assertTrue(muxedStream.input.isClosedForRead) + assertTrue(muxedStream.output.isClosedForWrite) + assertStreamRemoved() + val exception1 = assertThrows { + muxedStream.output.writeFully(Random.nextBytes(100000)) + muxedStream.output.flush() + } + assertEquals("The channel was closed", exception1.message) + reader.stop() + reader.assertNoBytesReceived() + reader.assertCloseFrameReceived(mplexStreamId) + } + + @Test + fun testWritePacketAfterReset() = runTest { + val reader = FrameReader(this, pool) + val muxedStream = YamuxMuxedStream(this, session, reader.frameChannel, mplexStreamId, mplexStreamName, pool) + val random = Random.nextBytes(50000) + muxedStream.remoteSendsNewMessage(buildPacket { writeFully(random) }) + assertFalse(muxedStream.input.isClosedForRead) + assertFalse(muxedStream.output.isClosedForWrite) + muxedStream.reset() + muxedStream.awaitClosed() + assertTrue(muxedStream.input.isClosedForRead) + assertTrue(muxedStream.output.isClosedForWrite) + reader.assertResetFrameReceived(mplexStreamId) + assertStreamRemoved() + val exception2 = assertThrows { + muxedStream.output.writeFully(Random.nextBytes(100000)) + muxedStream.output.flush() + } + assertEquals("Stream was reset", exception2.message) + reader.stop() + reader.assertNoBytesReceived() + } + + private fun assertStreamRemoved() { + coVerify { session.removeStream(any()) } + assertEquals(mplexStreamId, streamIdSlot.captured) + } + + private fun assertStreamNotRemoved() { + coVerify(exactly = 0) { session.removeStream(any()) } + } + + private class FrameReader(scope: CoroutineScope, pool: VerifyingChunkBufferPool) { + val frameChannel = SafeChannel(16) + private var closeFrame: CloseFrame? = null + private var resetFrame: ResetFrame? = null + private val builder = BytePacketBuilder(pool) + private val job = scope.launch { + frameChannel.consumeEach { + when (it) { + is MessageFrame -> { + builder.writePacket(it.packet) + } + + is CloseFrame -> { + assertNull(closeFrame) + closeFrame = it + } + + is ResetFrame -> { + assertNull(resetFrame) + resetFrame = it + } + + else -> { + assertTrue(false, "Unexpected frame type in FrameReader: $it") + } + } + } + } + + fun assertResetFrameReceived(streamId: MplexStreamId) { + assertNotNull(resetFrame) + assertEquals(streamId, resetFrame?.streamId) + } + + fun assertCloseFrameReceived(streamId: MplexStreamId) { + assertNotNull(closeFrame) + assertEquals(streamId, closeFrame?.streamId) + } + + fun assertNoCloseFrameReceived() { + assertNull(closeFrame) + } + + fun assertBytesReceived(expected: ByteArray) { + assertArrayEquals(expected, builder.build().readBytes()) + } + + fun assertNoBytesReceived() { + assertTrue(builder.isEmpty) + } + + suspend fun stop() { + frameChannel.close() + job.join() + } + } +} diff --git a/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxStreamMuxerTransportTest.kt b/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxStreamMuxerTransportTest.kt index 91c07ff..05cdbe8 100644 --- a/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxStreamMuxerTransportTest.kt +++ b/libp2p-muxer-yamux/src/test/kotlin/org/erwinkok/libp2p/muxer/yamux/YamuxStreamMuxerTransportTest.kt @@ -22,16 +22,16 @@ internal class YamuxStreamMuxerTransportTest : TestWithLeakCheck { assertEquals("/yamux/1.0.0", YamuxStreamMuxerTransport.protocolId.id) } -// @Test -// fun testCreate() = runTest { -// val transport = YamuxStreamMuxerTransport.create(this).expectNoErrors() -// assertNotNull(transport) -// val connection = mockk() -// val peerScope = mockk() -// every { connection.input } returns ByteChannel() -// every { connection.output } returns ByteChannel() -// val muxerConnection = transport.newConnection(connection, true, peerScope).expectNoErrors() -// assertNotNull(muxerConnection) -// muxerConnection.close() -// } + @Test + fun testCreate() = runTest { + val transport = YamuxStreamMuxerTransport.create(this).expectNoErrors() + assertNotNull(transport) + val connection = mockk() + val peerScope = mockk() + every { connection.input } returns ByteChannel() + every { connection.output } returns ByteChannel() + val muxerConnection = transport.newConnection(connection, true, peerScope).expectNoErrors() + assertNotNull(muxerConnection) + muxerConnection.close() + } }