diff --git a/build.gradle.kts b/build.gradle.kts index 6b04793..64dfdfd 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -14,7 +14,7 @@ plugins { apply() group = "io.github.doip-sim-ecu" -version = "0.13.0" +version = "0.14.0" repositories { gradlePluginPortal() diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index e1bef7e..0d18421 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.0.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.8-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/src/main/kotlin/NetworkHandler.kt b/src/main/kotlin/NetworkHandler.kt new file mode 100644 index 0000000..6fe6869 --- /dev/null +++ b/src/main/kotlin/NetworkHandler.kt @@ -0,0 +1,351 @@ +import io.ktor.network.selector.* +import io.ktor.network.sockets.* +import io.ktor.utils.io.* +import io.ktor.utils.io.core.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.slf4j.MDCContext +import library.* +import library.DelegatedKtorSocket +import library.SSLDoipTcpSocket +import nl.altindag.ssl.SSLFactory +import nl.altindag.ssl.pem.util.PemUtils +import org.slf4j.LoggerFactory +import org.slf4j.MDC +import java.net.InetAddress +import java.net.SocketException +import java.nio.file.Paths +import javax.net.ssl.SSLServerSocket +import javax.net.ssl.SSLSocket +import kotlin.collections.component1 +import kotlin.collections.component2 +import kotlin.concurrent.fixedRateTimer +import kotlin.concurrent.thread +import kotlin.system.exitProcess + +public open class UdpNetworkBinding( + private val localAddress: String, + private val port: Int = 13400, + private val broadcastEnabled: Boolean = true, + private val broadcastAddress: String = "255.255.255.255", + private val doipEntities: List>, +) { + private val logger = LoggerFactory.getLogger(UdpNetworkBinding::class.java) + + private lateinit var udpServerSocket: BoundDatagramSocket + + private val udpMessageHandlers = doipEntities.associateWith { it.createDoipUdpMessageHandler() } + + protected open suspend fun startVamTimer(socket: BoundDatagramSocket) { + if (broadcastEnabled) { + sendVams(socket) + } + } + + protected open suspend fun sendVams(socket: BoundDatagramSocket) { + var vamSentCounter = 0 + + val entries = doipEntities.associateWith { it.generateVehicleAnnouncementMessages() } + + fixedRateTimer("VAM", daemon = true, initialDelay = 500, period = 500) { + if (vamSentCounter >= 3) { + this.cancel() + return@fixedRateTimer + } + entries.forEach { (doipEntity, vams) -> + MDC.put("ecu", doipEntity.name) + vams.forEach { vam -> + logger.info("Sending VAM for ${vam.logicalAddress.toByteArray().toHexString()} as broadcast") + runBlocking(Dispatchers.IO) { + launch(MDCContext()) { + socket.send( + Datagram( + packet = ByteReadPacket(vam.asByteArray), + address = InetSocketAddress(broadcastAddress, port) + ) + ) + } + } + } + } + + vamSentCounter++ + } + } + + public fun start() { + thread(name = "UDP") { + runBlocking { + udpServerSocket = aSocket(ActorSelectorManager(Dispatchers.IO)) + .udp() + .bind(localAddress = InetSocketAddress(hostname = localAddress, port = port)) { + broadcast = true + reuseAddress = true +// reusePort = true // not supported on windows + typeOfService = TypeOfService.IPTOS_RELIABILITY +// socket.joinGroup(multicastAddress) + } + logger.info("Listening on udp: ${udpServerSocket.localAddress}") + startVamTimer(udpServerSocket) + + while (!udpServerSocket.isClosed) { + val datagram = udpServerSocket.receive() + withContext(Dispatchers.IO) { + handleUdpMessage(udpMessageHandlers, datagram, udpServerSocket) + } + } + } + } + } + + protected open fun CoroutineScope.handleUdpMessage( + udpMessageHandlers: Map, DoipUdpMessageHandler>, + datagram: Datagram, + socket: BoundDatagramSocket + ) { + val message = DoipUdpMessageParser.parseUDP(datagram.packet) + udpMessageHandlers.forEach { (doipEntity, datagramHandler) -> + runBlocking { + MDC.put("ecu", doipEntity.name) + try { + logger.traceIf { "Incoming UDP message for ${doipEntity.name}" } + datagramHandler.handleUdpMessage(socket.outgoing, datagram.address, message) + } catch (e: HeaderNegAckException) { + val code = when (e) { + is IncorrectPatternFormat -> DoipUdpHeaderNegAck.NACK_INCORRECT_PATTERN_FORMAT + is HeaderTooShort -> DoipUdpHeaderNegAck.NACK_INCORRECT_PATTERN_FORMAT + is InvalidPayloadLength -> DoipUdpHeaderNegAck.NACK_INVALID_PAYLOAD_LENGTH + is UnknownPayloadType -> DoipUdpHeaderNegAck.NACK_UNKNOWN_PAYLOAD_TYPE + else -> { + DoipUdpHeaderNegAck.NACK_UNKNOWN_PAYLOAD_TYPE + } + } + logger.debug("Error in Message-Header, sending negative acknowledgement", e) + datagramHandler.respondHeaderNegAck( + socket.outgoing, + datagram.address, + code + ) + return@runBlocking + } catch (e: Exception) { + logger.error("Unknown error while processing message", e) + } + } + } + } +} + +public open class TcpNetworkBinding( + private val networkManager: NetworkManager, + private val localAddress: String, + private val localPort: Int, + private val tlsOptions: TlsOptions?, + private val doipEntities: List> +) { + private val logger = LoggerFactory.getLogger(TcpNetworkBinding::class.java) + + private val serverSockets: MutableList = mutableListOf() + private val activeConnections: MutableMap> = mutableMapOf() + +// public fun pauseTcpServerSockets(duration: kotlin.time.Duration) { +// logger.warn("Closing serversockets") +// serverSockets.forEach { +// try { +// it.close() +// } catch (ignored: Exception) { +// } +// } +// serverSockets.clear() +// logger.warn("Pausing server sockets for ${duration.inWholeMilliseconds} ms") +// Thread.sleep(duration.inWholeMilliseconds) +// logger.warn("Restarting server sockets after ${duration.inWholeMilliseconds} ms") +// runBlocking { +// launch { +// startVamTimer(udpServerSocket) +// } +// launch { +// start() +// } +// } +// } + + public fun start() { + thread(name = "TCP") { + runBlocking { + withContext(Dispatchers.IO) { + val serverSocket = + aSocket(ActorSelectorManager(Dispatchers.IO)) + .tcp() + .bind(InetSocketAddress(localAddress, localPort)) + serverSockets.add(serverSocket) + logger.info("Listening on tcp: ${serverSocket.localAddress}") + while (!serverSocket.isClosed) { + val socket = serverSocket.accept() + val activeConnection = ActiveConnection(this@TcpNetworkBinding, doipEntities) + activeConnection.handleTcpSocket(this@withContext, DelegatedKtorSocket(socket), null) + } + } + } + } + +// TLS with ktor-network doesn't work yet https://youtrack.jetbrains.com/issue/KTOR-694 + if (tlsOptions != null && tlsOptions.tlsMode != TlsMode.DISABLED) { + if (tlsOptions.tlsCert == null) { + System.err.println("tlsCert is null") + exitProcess(-1) + } else if (tlsOptions.tlsKey == null) { + System.err.println("tlsKey is null") + exitProcess(-1) + } else if (!tlsOptions.tlsCert.isFile) { + System.err.println("${tlsOptions.tlsCert.absolutePath} doesn't exist or isn't a file") + exitProcess(-1) + } else if (!tlsOptions.tlsKey.isFile) { + System.err.println("${tlsOptions.tlsKey.absolutePath} doesn't exist or isn't a file") + exitProcess(-1) + } + + thread(name = "TLS") { + runBlocking { + val key = PemUtils.loadIdentityMaterial( + Paths.get(tlsOptions.tlsCert.toURI()), + Paths.get(tlsOptions.tlsKey.toURI()), + tlsOptions.tlsKeyPassword?.toCharArray() + ) + val trustMaterial = PemUtils.loadTrustMaterial(Paths.get(tlsOptions.tlsCert.toURI())) + + val sslFactory = SSLFactory.builder() + .withIdentityMaterial(key) + .withTrustMaterial(trustMaterial) + .build() + + val serverSocket = withContext(Dispatchers.IO) { + (sslFactory.sslServerSocketFactory.createServerSocket( + tlsOptions.tlsPort, + 50, + InetAddress.getByName(localAddress) + )) + } + serverSockets.add(serverSocket as ServerSocket) + val tlsServerSocket = serverSocket as SSLServerSocket + logger.info("Listening on tls: ${tlsServerSocket.localSocketAddress}") + + if (tlsOptions.tlsProtocols != null) { + val supportedProtocols = tlsServerSocket.supportedProtocols.toSet() + // Use filter to retain order of protocols/ciphers + tlsServerSocket.enabledProtocols = + tlsOptions.tlsProtocols.filter { supportedProtocols.contains(it) }.toTypedArray() + } + + if (tlsOptions.tlsCiphers != null) { + val supportedCipherSuites = tlsServerSocket.supportedCipherSuites.toSet() + // Use filter to retain order of protocols/ciphers + tlsServerSocket.enabledCipherSuites = + tlsOptions.tlsCiphers.filter { supportedCipherSuites.contains(it) }.toTypedArray() + } + + logger.info("Enabled TLS protocols: ${tlsServerSocket.enabledProtocols.joinToString(", ")}") + logger.info("Enabled TLS cipher suites: ${tlsServerSocket.enabledCipherSuites.joinToString(", ")}") + + while (!tlsServerSocket.isClosed) { + val socket = tlsServerSocket.accept() as SSLSocket + val activeConnection = ActiveConnection(this@TcpNetworkBinding, doipEntities) + activeConnection.handleTcpSocket(this, SSLDoipTcpSocket(socket), null) + } + } + } + } + } + + public open class ActiveConnection( + private val networkBinding: TcpNetworkBinding, + private val doipEntities: List> + ) { + private val logger = LoggerFactory.getLogger(ActiveConnection::class.java) + + public open suspend fun handleTcpSocket( + scope: CoroutineScope, + socket: DoipTcpSocket, + disableServerSocketCallback: ((kotlin.time.Duration) -> Unit)? + ) { + scope.launch(Dispatchers.IO) { + val handler = GroupDoipTcpConnectionMessageHandler(doipEntities, socket, networkBinding.tlsOptions) + + val entity = doipEntities.first() + + logger.debugIf { "New incoming data connection from ${socket.remoteAddress}" } + val input = socket.openReadChannel() + val output = socket.openWriteChannel() + try { + val parser = DoipTcpMessageParser(doipEntities.first().config.maxDataSize - 8) + while (!socket.isClosed) { + val message = parser.parseDoipTcpMessage(input) + + runBlocking { + try { + MDC.put("ecu", entity.name) + handler.handleTcpMessage(message, output) + } catch (e: ClosedReceiveChannelException) { + // ignore - socket was closed + logger.debugIf { "Socket was closed by remote ${socket.remoteAddress}" } + withContext(Dispatchers.IO) { + handler.connectionClosed(e) + socket.runCatching { this.close() } + } + } catch (e: SocketException) { + logger.error("Socket error: ${e.message} -> closing socket") + withContext(Dispatchers.IO) { + handler.connectionClosed(e) + socket.runCatching { this.close() } + } + } catch (e: HeaderNegAckException) { + if (!socket.isClosed) { + logger.debug( + "Error in Header while parsing message, sending negative acknowledgment", + e + ) + val response = + DoipTcpHeaderNegAck(DoipTcpDiagMessageNegAck.NACK_CODE_TRANSPORT_PROTOCOL_ERROR).asByteArray + output.writeFully(response) + withContext(Dispatchers.IO) { + handler.connectionClosed(e) + socket.runCatching { this.close() } + } + } + } catch (e: DoipEntityHardResetException) { + logger.warn("Simulating Hard Reset on ${entity.name} for ${e.duration.inWholeMilliseconds} ms") + output.flush() + socket.close() + + if (disableServerSocketCallback != null) { + disableServerSocketCallback(e.duration) + } + } catch (e: Exception) { + if (!socket.isClosed) { + logger.error( + "Unknown error parsing/handling message, sending negative acknowledgment", + e + ) + val response = + DoipTcpHeaderNegAck(DoipTcpDiagMessageNegAck.NACK_CODE_TRANSPORT_PROTOCOL_ERROR).asByteArray + output.writeFully(response) + withContext(Dispatchers.IO) { + handler.connectionClosed(e) + socket.runCatching { this.close() } + } + } + } + } + } + } catch (e: Throwable) { + logger.error("Unknown error inside socket processing loop, closing socket", e) + } finally { + try { + socket.close() + } finally { + networkBinding.activeConnections.remove(this@ActiveConnection) + } + } + } + } + } +} diff --git a/src/main/kotlin/NetworkManager.kt b/src/main/kotlin/NetworkManager.kt new file mode 100644 index 0000000..492072a --- /dev/null +++ b/src/main/kotlin/NetworkManager.kt @@ -0,0 +1,127 @@ +import library.DoipEntity +import org.slf4j.LoggerFactory +import java.net.Inet4Address +import java.net.InetAddress +import java.net.NetworkInterface + +public open class NetworkManager( + public val config: NetworkingData, + public val doipEntities: List>, +) { + private val log = LoggerFactory.getLogger(NetworkManager::class.java) + + protected fun findInterfaceByName(): NetworkInterface? { + var foundInterface: NetworkInterface? = null + NetworkInterface.getNetworkInterfaces()?.let { netIntf -> + while (netIntf.hasMoreElements()) { + val entry = netIntf.nextElement() + if (entry.displayName != null && entry.displayName.equals(config.networkInterface, true)) { + foundInterface = entry + break + } + entry.subInterfaces?.let { subInterfaces -> + while (subInterfaces.hasMoreElements()) { + val subInterface = subInterfaces.nextElement() + if (subInterface.displayName != null && subInterface.displayName.equals( + config.networkInterface, + true + ) + ) { + foundInterface = entry; + break + } + } + } + if (foundInterface != null) { + break + } + } + } + + return foundInterface + } + + protected fun getAvailableIPAddresses(): List { + if (config.networkInterface.isNullOrBlank() || config.networkInterface == "0.0.0.0") { + return listOf(InetAddress.getByName(config.networkInterface)) + } + val list = mutableListOf() + findInterfaceByName()?.let { intf -> + intf.inetAddresses?.let { inetAddresses -> + while (inetAddresses.hasMoreElements()) { + val address = inetAddresses.nextElement() + if (address is Inet4Address) { + list.add(address) + } + } + } + } + if (list.isEmpty()) { + InetAddress.getByName(config.networkInterface)?.let { addr -> + list.add(addr) + } + } + return list + } + + protected open fun buildStartupMap(): Map>> { + val ipAddresses = getAvailableIPAddresses().toMutableList() + if (ipAddresses.isEmpty()) { + throw IllegalArgumentException("No network interface with the identifier ${config.networkInterface} could be found") + } + log.info("There are ${ipAddresses.size} ip address available, and we have ${doipEntities.size} doip entities") + val entitiesByIP = mutableMapOf>>() + doipEntities.forEach { entity -> + val ip = if (ipAddresses.size == 1) { + ipAddresses.first() + } else { + ipAddresses.removeFirst() + } + log.info("Assigning entity ${entity.name} to $ip") + var entityList = entitiesByIP[ip.hostAddress] + if (entityList == null) { + entityList = mutableListOf() + entitiesByIP[ip.hostAddress] = entityList + } + entityList.add(entity) + } + return entitiesByIP + } + + public fun start() { + val map = buildStartupMap() + + // UDP + map.forEach { (address, entities) -> + val unb = createUdpNetworkBinding(address, entities) + unb.start() + } + + if (config.bindOnAnyForUdpAdditional && !map.containsKey("0.0.0.0")) { + val unb = createUdpNetworkBindingAny() + unb.start() + } + + // TCP + map.forEach { (address, entities) -> + val tnb = createTcpNetworkBinding(address, entities) + tnb.start() + } + } + + protected open fun createTcpNetworkBinding( + address: String, + entities: List> + ): TcpNetworkBinding = + TcpNetworkBinding(this, address, config.localPort, config.tlsOptions, entities) + + protected open fun createUdpNetworkBinding( + address: String, + entities: List> + ): UdpNetworkBinding = + UdpNetworkBinding(address, config.localPort, config.broadcastEnable, config.broadcastAddress, entities) + + protected open fun createUdpNetworkBindingAny(): UdpNetworkBinding = + UdpNetworkBinding("0.0.0.0", config.localPort, config.broadcastEnable, config.broadcastAddress, doipEntities) +} + diff --git a/src/main/kotlin/SimGateway.kt b/src/main/kotlin/SimDoipEntity.kt similarity index 54% rename from src/main/kotlin/SimGateway.kt rename to src/main/kotlin/SimDoipEntity.kt index 47003f7..3649531 100644 --- a/src/main/kotlin/SimGateway.kt +++ b/src/main/kotlin/SimDoipEntity.kt @@ -3,57 +3,13 @@ import kotlinx.coroutines.runBlocking import kotlinx.coroutines.slf4j.MDCContext import library.* import org.slf4j.MDC -import kotlin.properties.Delegates -import kotlin.time.Duration -import kotlin.time.Duration.Companion.seconds @Suppress("unused") -public open class GatewayData(name: String) : RequestsData(name) { - /** - * Network address this gateway should bind on (default: 0.0.0.0) - */ - public var localAddress: String = "0.0.0.0" - - /** - * Should udp be bound additionally on any? - * There's an issue when binding it to an network interface of not receiving 255.255.255.255 broadcasts - */ - public var bindOnAnyForUdpAdditional: Boolean = true - - /** - * Network port this gateway should bind on (default: 13400) - */ - public var localPort: Int = 13400 - - /** - * Multicast address - */ - public var multicastAddress: String? = null - - /** - * Whether VAM broadcasts shall be sent on startup (default: true) - */ - public var broadcastEnable: Boolean = true - - /** - * Default broadcast address for VAM messages (default: 255.255.255.255) - */ - public var broadcastAddress: String = "255.255.255.255" - - /** - * The logical address under which the gateway shall be reachable - */ - public var logicalAddress: Short by Delegates.notNull() - - /** - * The functional address under which the gateway (and other ecus) shall be reachable - */ - public var functionalAddress: Short by Delegates.notNull() - +public open class DoipEntityData(name: String, public val nodeType: DoipNodeType = DoipNodeType.GATEWAY) : EcuData(name) { /** * Vehicle identifier, 17 chars, will be filled with '0`, or if left null, set to 0xFF */ - public var vin: String? = null // 17 byte VIN + public var vin: String? = null // 17 byte VIN /** * Group ID of the gateway @@ -65,22 +21,12 @@ public open class GatewayData(name: String) : RequestsData(name) { */ public var eid: ByteArray = byteArrayOf(0, 0, 0, 0, 0, 0) // 6 byte entity identification (usually MAC) - /** - * Interval between sending pending NRC messages (0x78) - */ - public var pendingNrcSendInterval: Duration = 2.seconds - /** * Maximum payload data size allowed for a DoIP message */ public var maxDataSize: Int = Int.MAX_VALUE - public var tlsMode: TlsMode = TlsMode.DISABLED - public var tlsPort: Int = 3496 - public var tlsOptions: TlsOptions = TlsOptions() - private val _ecus: MutableList = mutableListOf() - private val _additionalVams: MutableList = mutableListOf() public val ecus: List get() = this._ecus.toList() @@ -93,33 +39,19 @@ public open class GatewayData(name: String) : RequestsData(name) { receiver.invoke(ecuData) _ecus.add(ecuData) } - - public fun doipEntity(name: String, vam: DoipUdpVehicleAnnouncementMessage, receiver: EcuData.() -> Unit) { - val ecuData = EcuData(name) - receiver.invoke(ecuData) - _ecus.add(ecuData) - _additionalVams.add(vam) - } } -private fun GatewayData.toGatewayConfig(): DoipEntityConfig { +private fun DoipEntityData.toDoipEntityConfig(): DoipEntityConfig { val config = DoipEntityConfig( name = this.name, gid = this.gid, eid = this.eid, - localAddress = this.localAddress, - bindOnAnyForUdpAdditional = this.bindOnAnyForUdpAdditional, - localPort = this.localPort, logicalAddress = this.logicalAddress, - broadcastEnabled = this.broadcastEnable, - broadcastAddress = this.broadcastAddress, pendingNrcSendInterval = this.pendingNrcSendInterval, - tlsMode = this.tlsMode, - tlsPort = this.tlsPort, - tlsOptions = this.tlsOptions, // Fill up too short vin's with 'Z' - if no vin is given, use 0xFF, as defined in ISO 13400 for when no vin is set (yet) - vin = this.vin?.padEnd(17, 'Z')?.toByteArray() ?: ByteArray(17).let { it.fill(0xFF.toByte()); it }, + vin = this.vin?.padEnd(17, '0')?.toByteArray() ?: ByteArray(17).let { it.fill(0xFF.toByte()); it }, maxDataSize = this.maxDataSize, + nodeType = nodeType, ) // Add the gateway itself as an ecu, so it too can receive requests @@ -137,13 +69,13 @@ private fun GatewayData.toGatewayConfig(): DoipEntityConfig { } @Suppress("MemberVisibilityCanBePrivate") -public class SimGateway(private val data: GatewayData) : DoipEntity(data.toGatewayConfig()) { +public open class SimDoipEntity(private val data: DoipEntityData) : DoipEntity(data.toDoipEntityConfig()) { public val requests: RequestList get() = data.requests override fun createEcu(config: EcuConfig): SimEcu { - // To be able to handle requests for the gateway itself, insert a dummy ecu with the gateways logicalAddress if (config.name == data.name) { + // To be able to handle requests for the gateway itself, insert a dummy ecu with the gateways logicalAddress val ecu = EcuData( name = data.name, logicalAddress = data.logicalAddress, @@ -163,12 +95,12 @@ public class SimGateway(private val data: GatewayData) : DoipEntity(data return SimEcu(ecuData) } - public fun reset(recursiveEcus: Boolean = true) { + public override fun reset(recursiveEcus: Boolean) { runBlocking { MDC.put("ecu", name) launch(MDCContext()) { - logger.infoIf { "Resetting gateway" } + logger.infoIf { "Resetting doip entity" } requests.forEach { it.reset() } if (recursiveEcus) { ecus.forEach { it.reset() } diff --git a/src/main/kotlin/SimDsl.kt b/src/main/kotlin/SimDsl.kt index c4dbd91..39a0122 100644 --- a/src/main/kotlin/SimDsl.kt +++ b/src/main/kotlin/SimDsl.kt @@ -11,9 +11,11 @@ public typealias InterceptorRequestData = ResponseData public typealias InterceptorRequestHandler = InterceptorRequestData.(request: RequestMessage) -> Boolean public typealias InterceptorResponseHandler = InterceptorResponseData.(response: ByteArray) -> Boolean public typealias EcuDataHandler = EcuData.() -> Unit -public typealias GatewayDataHandler = GatewayData.() -> Unit +public typealias DoipEntityDataHandler = DoipEntityData.() -> Unit +public typealias NetworkingDataHandler = NetworkingData.() -> Unit public typealias CreateEcuFunc = (name: String, receiver: EcuDataHandler) -> Unit -public typealias CreateGatewayFunc = (name: String, receiver: GatewayDataHandler) -> Unit +public typealias CreateDoipEntityFunc = (name: String, receiver: DoipEntityDataHandler) -> Unit +public typealias CreateNetworkFunc = (receiver: NetworkingDataHandler) -> Unit @Suppress("unused") public class InterceptorResponseData( @@ -520,10 +522,18 @@ public open class RequestsData( */ public open class EcuData( name: String, + /** + * The logical address under which the gateway shall be reachable + */ public var logicalAddress: Short = 0, + /** + * The functional address under which the gateway (and other ecus) shall be reachable + */ public var functionalAddress: Short = 0, + /** + * Interval between sending pending NRC messages (0x78) + */ public var pendingNrcSendInterval: Duration = 2.seconds, - public var additionalVam: EcuAdditionalVamData? = null, nrcOnNoMatch: Boolean = true, requests: List = emptyList(), resetHandler: List = emptyList(), @@ -536,33 +546,32 @@ public open class EcuData( ackBytesLengthMap = ackBytesLengthMap, ) -internal val gateways: MutableList = mutableListOf() -internal val gatewayInstances: MutableList = mutableListOf() +internal val networks: MutableList = mutableListOf() +internal val networkInstances: MutableList = mutableListOf() -public fun gatewayInstances(): List = - gatewayInstances.toList() +public fun networks(): List = + networks.toList() -public fun gateways(): List = - gateways.toList() +public fun networkInstances(): List = + networkInstances.toList() -/** - * Defines a DoIP-Gateway and the ECUs behind it - */ -public fun gateway(name: String, receiver: GatewayDataHandler) { - val gatewayData = GatewayData(name) - receiver.invoke(gatewayData) - gateways.add(gatewayData) +public fun network(receiver: NetworkingDataHandler) { + val networkingData = NetworkingData() + receiver.invoke(networkingData) + networks.add(networkingData) } public fun reset() { - gatewayInstances.forEach { it.reset() } + networkInstances.forEach { it.reset() } } @Suppress("unused") public fun start() { - gatewayInstances.addAll(gateways.map { SimGateway(it) }) + networkInstances.addAll(networks.map { SimDoipNetworking(it) }) + + val networkManager = networkInstances.map { NetworkManager(it.data, it.doipEntities) } - gatewayInstances.forEach { + networkManager.forEach { it.start() } } diff --git a/src/main/kotlin/SimEcu.kt b/src/main/kotlin/SimEcu.kt index 94e7ca3..50d54fc 100644 --- a/src/main/kotlin/SimEcu.kt +++ b/src/main/kotlin/SimEcu.kt @@ -33,7 +33,6 @@ internal fun EcuData.toEcuConfig(): EcuConfig = logicalAddress = logicalAddress, functionalAddress = functionalAddress, pendingNrcSendInterval = pendingNrcSendInterval, - additionalVam = additionalVam, ) @@ -389,13 +388,15 @@ public class SimEcu(private val data: EcuData) : SimulatedEcu(data.toEcuConfig() /** * Resets all the ECUs stored properties, timers, interceptors and requests */ - public fun reset() { + public override fun reset() { runBlocking(Dispatchers.Default) { MDC.put("ecu", name) launch(MDCContext()) { logger.debug("Resetting interceptors, timers and stored data") + super.reset() + inboundInterceptors.clear() synchronized(mainTimer) { diff --git a/src/main/kotlin/SimNetworking.kt b/src/main/kotlin/SimNetworking.kt new file mode 100644 index 0000000..c145f10 --- /dev/null +++ b/src/main/kotlin/SimNetworking.kt @@ -0,0 +1,110 @@ +import library.* + +public enum class NetworkMode { + AUTO, + SINGLE_IP, +} + +@Suppress("unused") +public open class NetworkingData { + /** + * The network interface that should be used to bind on, can be an IP, or name + */ + public var networkInterface: String? = "0.0.0.0" + + /** + * Mode for assigning ip addresses to doip entities + */ + public var networkMode: NetworkMode = NetworkMode.AUTO + + /** + * Should udp be bound additionally on any? + * There's an issue when binding it to a network interface with not receiving 255.255.255.255 broadcasts + */ + public var bindOnAnyForUdpAdditional: Boolean = true + + /** + * Network port this gateway should bind on (default: 13400) + */ + public var localPort: Int = 13400 + + /** + * Whether VAM broadcasts shall be sent on startup (default: true) + */ + public var broadcastEnable: Boolean = true + + /** + * Default broadcast address for VAM messages (default: 255.255.255.255) + */ + public var broadcastAddress: String = "255.255.255.255" + + public val tlsOptions: TlsOptions = TlsOptions() + + internal val _doipEntities: MutableList = mutableListOf() + + public val doipEntities: List + get() = _doipEntities + + /** + * Defines a DoIP-Gateway and the ECUs behind it + */ + public fun gateway(name: String, receiver: DoipEntityDataHandler) { + val gatewayData = DoipEntityData(name, DoipNodeType.GATEWAY) + receiver.invoke(gatewayData) + _doipEntities.add(gatewayData) + } + + /** + * Defines a DoIP-Gateway and the ECUs behind it + */ + public fun doipEntity(name: String, receiver: DoipEntityDataHandler) { + val gatewayData = DoipEntityData(name, DoipNodeType.NODE) + receiver.invoke(gatewayData) + _doipEntities.add(gatewayData) + } + +} + +public open class SimDoipNetworking(data: NetworkingData) : SimNetworking(data) { + override fun createDoipEntity(data: DoipEntityData): SimDoipEntity = + SimDoipEntity(data) +} + +public abstract class SimNetworking>(public val data: NetworkingData) { + public val doipEntities: List + get() = _doipEntities + + private val _doipEntities: MutableList = mutableListOf() + protected val _vams: MutableList = mutableListOf() + + protected fun addEntity(doipEntity: @UnsafeVariance T) { + _doipEntities.add(doipEntity) + } + + protected abstract fun createDoipEntity(data: DoipEntityData): T + + init { + start() + } + + private fun start() { + _doipEntities.clear() + + data._doipEntities.map { createDoipEntity(it) }.forEach { + it.start() + _doipEntities.add(it) + } + + _doipEntities.forEach { + it.start() + } + } + + public open fun reset() { + _doipEntities.forEach { it.reset() } + } + + public open fun findEcuByName(ecuName: String, ignoreCase: Boolean = true): E? = + _doipEntities.flatMap { it.ecus }.firstOrNull { ecuName.equals(it.name, ignoreCase) } +} + diff --git a/src/main/kotlin/client/DoipClient.kt b/src/main/kotlin/client/DoipClient.kt index 7a13f4b..552845e 100644 --- a/src/main/kotlin/client/DoipClient.kt +++ b/src/main/kotlin/client/DoipClient.kt @@ -93,11 +93,10 @@ public class DoipClient( thread(name = "UDP_RECV", isDaemon = true) { runBlocking { - val handler = DoipClientUdpMessageHandler() while (!udpServerSocket.isClosed) { try { val datagram = udpServerSocket.receive() - val message = handler.parseMessage(datagram) + val message = DoipUdpMessageParser.parseUDP(datagram.packet) if (message is DoipUdpVehicleAnnouncementMessage) { val sourceAddress = datagram.address _doipEntities[message.logicalAddress] = DoipEntityAnnouncement(sourceAddress, message) @@ -129,7 +128,7 @@ public class DoipEntityTcpConnection(private val socket: DoipTcpSocket, private writeChannel.writeFully(ByteBuffer.wrap(DoipTcpRoutingActivationRequest(testerAddress).asByteArray)) writeChannel.flush() - val msg = DoipTcpConnectionMessageHandler(socket).receiveTcpData(readChannel) as DoipTcpRoutingActivationResponse + val msg = DoipTcpMessageParser(Int.MAX_VALUE).parseDoipTcpMessage(readChannel) as DoipTcpRoutingActivationResponse if (msg.responseCode != DoipTcpRoutingActivationResponse.RC_OK) { throw ConnectException("Routing activation failed (${msg.responseCode})") } @@ -167,15 +166,17 @@ public class DoipEntityTcpConnection(private val socket: DoipTcpSocket, private ) writeChannel.writeFully(ByteBuffer.wrap(request.asByteArray)) writeChannel.flush() - val handler = DoipTcpConnectionMessageHandler(socket) - val diagResponse = handler.receiveTcpData(readChannel) + + val parser = DoipTcpMessageParser(Int.MAX_VALUE) + + val diagResponse = parser.parseDoipTcpMessage(readChannel) if (diagResponse is DoipTcpDiagMessagePosAck) { if (waitForResponse) { val response = withTimeoutOrNull(timeout = waitTimeout) { var msg: DoipTcpMessage? do { - msg = handler.receiveTcpData(readChannel) + msg = parser.parseDoipTcpMessage(readChannel) } while (msg !is DoipTcpDiagMessage) msg } ?: throw RuntimeException("No response within $waitTimeout") diff --git a/src/main/kotlin/library/DefaultDoipEntityTcpConnectionMessageHandler.kt b/src/main/kotlin/library/DefaultDoipEntityTcpConnectionMessageHandler.kt index 70f021c..3bb7f34 100644 --- a/src/main/kotlin/library/DefaultDoipEntityTcpConnectionMessageHandler.kt +++ b/src/main/kotlin/library/DefaultDoipEntityTcpConnectionMessageHandler.kt @@ -4,6 +4,7 @@ import io.ktor.utils.io.* import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.slf4j.MDCContext +import networkInstances import org.slf4j.Logger import org.slf4j.LoggerFactory import org.slf4j.MDC @@ -11,10 +12,10 @@ import org.slf4j.MDC public open class DefaultDoipEntityTcpConnectionMessageHandler( public val doipEntity: DoipEntity<*>, socket: DoipTcpSocket, - maxPayloadLength: Int, public val logicalAddress: Short, - public val diagMessageHandler: DiagnosticMessageHandler -) : DoipTcpConnectionMessageHandler(socket, maxPayloadLength) { + public var diagMessageHandler: DiagnosticMessageHandler, + private val tlsOptions: TlsOptions?, +) : DoipTcpConnectionMessageHandler(socket) { private val logger: Logger = LoggerFactory.getLogger(DefaultDoipEntityTcpConnectionMessageHandler::class.java) override suspend fun handleTcpMessage(message: DoipTcpMessage, output: ByteWriteChannel) { @@ -42,7 +43,7 @@ public open class DefaultDoipEntityTcpConnectionMessageHandler( ) return } else { - if (doipEntity.config.tlsMode == TlsMode.MANDATORY && socket.socketType != SocketType.TLS_DATA) { + if (tlsOptions != null && tlsOptions.tlsMode == TlsMode.MANDATORY && socket.socketType != SocketType.TLS_DATA) { logger.info("Routing activation for ${message.sourceAddress} denied (TLS required)") output.writeFully( DoipTcpRoutingActivationResponse( @@ -68,8 +69,8 @@ public open class DefaultDoipEntityTcpConnectionMessageHandler( ).asByteArray ) } else if ( - doipEntity.ecus.all { it.config.additionalVam == null } && - doipEntity.hasAlreadyActiveConnection(message.sourceAddress, this) + networkInstances().groupBy { it.data.networkInterface }.all { it.value.all { it.doipEntities.size == 1}} + && doipEntity.hasAlreadyActiveConnection(message.sourceAddress, this) ) { logger.error("Routing activation for ${message.sourceAddress} denied (Has already an active connection)") output.writeFully( diff --git a/src/main/kotlin/library/DefaultDoipEntityUdpMessageHandler.kt b/src/main/kotlin/library/DefaultDoipEntityUdpMessageHandler.kt index 9ac73dc..a4c8285 100644 --- a/src/main/kotlin/library/DefaultDoipEntityUdpMessageHandler.kt +++ b/src/main/kotlin/library/DefaultDoipEntityUdpMessageHandler.kt @@ -13,20 +13,13 @@ public open class DefaultDoipEntityUdpMessageHandler( ) : DoipUdpMessageHandler { private val logger: Logger = LoggerFactory.getLogger(DefaultDoipEntityUdpMessageHandler::class.java) - internal companion object { - fun generateVamByEntityConfig(doipEntity: DoipEntity<*>): List = - with(doipEntity.config) { - listOf(DoipUdpVehicleAnnouncementMessage(vin, logicalAddress, gid, eid)) + - doipEntity.ecus.filter { it.config.additionalVam != null }.map { it.config.additionalVam!!.toVam(it.config, doipEntity.config) } - } - } @Suppress("MemberVisibilityCanBePrivate") protected suspend fun sendVamResponse( sendChannel: SendChannel, sourceAddress: SocketAddress, ) { - val vams = generateVamByEntityConfig(doipEntity) + val vams = doipEntity.generateVehicleAnnouncementMessages() vams.forEach { vam -> logger.info("Sending VIR-Response (VAM) for ${vam.logicalAddress.toString(16)} to $sourceAddress") sendChannel.send( diff --git a/src/main/kotlin/library/DoipEntity.kt b/src/main/kotlin/library/DoipEntity.kt index 4a0da80..44eb2bf 100644 --- a/src/main/kotlin/library/DoipEntity.kt +++ b/src/main/kotlin/library/DoipEntity.kt @@ -1,25 +1,12 @@ package library -import io.ktor.network.selector.* -import io.ktor.network.sockets.* import io.ktor.utils.io.* -import io.ktor.utils.io.core.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.slf4j.MDCContext -import nl.altindag.ssl.SSLFactory -import nl.altindag.ssl.pem.util.PemUtils import org.slf4j.Logger import org.slf4j.LoggerFactory import org.slf4j.MDC import java.io.File -import java.net.InetAddress -import java.net.SocketException -import java.nio.file.Paths -import javax.net.ssl.* -import kotlin.concurrent.fixedRateTimer -import kotlin.concurrent.thread -import kotlin.system.exitProcess import kotlin.time.Duration.Companion.seconds public typealias GID = ByteArray @@ -40,11 +27,13 @@ public enum class TlsMode { } public data class TlsOptions( - val tlsCert: File? = null, - val tlsKey: File? = null, - val tlsKeyPassword: String? = null, - val tlsCiphers: List? = DefaultTlsCiphers, - val tlsProtocols: List? = DefaultTlsProtocols, + public val tlsMode: TlsMode = TlsMode.DISABLED, + public val tlsPort: Int = 3496, + public val tlsCert: File? = null, + public val tlsKey: File? = null, + public val tlsKeyPassword: String? = null, + public val tlsCiphers: List? = DefaultTlsCiphers, + public val tlsProtocols: List? = DefaultTlsProtocols, ) @Suppress("unused") @@ -55,15 +44,7 @@ public open class DoipEntityConfig( public val eid: EID, public val vin: VIN, public val maxDataSize: Int = Int.MAX_VALUE, - public val localAddress: String = "0.0.0.0", - public val bindOnAnyForUdpAdditional: Boolean = true, - public val localPort: Int = 13400, - public val broadcastEnabled: Boolean = true, - public val broadcastAddress: String = "255.255.255.255", public val pendingNrcSendInterval: kotlin.time.Duration = 2.seconds, - public val tlsMode: TlsMode = TlsMode.DISABLED, - public val tlsPort: Int = 3496, - public val tlsOptions: TlsOptions = TlsOptions(), public val ecuConfigList: MutableList = mutableListOf(), public val nodeType: DoipNodeType = DoipNodeType.GATEWAY, ) { @@ -105,59 +86,32 @@ public abstract class DoipEntity( public val ecus: List get() = _ecus - private lateinit var udpServerSocket: BoundDatagramSocket - protected abstract fun createEcu(config: EcuConfig): T + public abstract fun reset(recursiveEcus: Boolean = true) + + override fun existsTargetAddress(targetAddress: Short): Boolean = + targetEcusByLogical.containsKey(targetAddress) || targetEcusByFunctional.containsKey(targetAddress) + + public fun generateVehicleAnnouncementMessages(): List = + config.let { + listOf(DoipUdpVehicleAnnouncementMessage(it.vin, it.logicalAddress, it.gid, it.eid)) + } - protected open fun createDoipUdpMessageHandler(): DoipUdpMessageHandler = + public open fun createDoipUdpMessageHandler(): DoipUdpMessageHandler = DefaultDoipEntityUdpMessageHandler( doipEntity = this, config = config ) - protected open fun createDoipTcpMessageHandler(socket: DoipTcpSocket): DoipTcpConnectionMessageHandler = + public open fun createDoipTcpMessageHandler(socket: DoipTcpSocket, tlsOptions: TlsOptions?): DoipTcpConnectionMessageHandler = DefaultDoipEntityTcpConnectionMessageHandler( doipEntity = this, socket = socket, logicalAddress = config.logicalAddress, - maxPayloadLength = config.maxDataSize - 8, - diagMessageHandler = this + diagMessageHandler = this, + tlsOptions = tlsOptions, ) - protected open suspend fun sendVams(vams: List, socket: BoundDatagramSocket) { - var vamSentCounter = 0 - - fixedRateTimer("VAM", daemon = true, initialDelay = 500, period = 500) { - if (vamSentCounter >= 3) { - this.cancel() - return@fixedRateTimer - } - vams.forEach { vam -> - logger.info("Sending VAM for ${vam.logicalAddress.toByteArray().toHexString()} as broadcast") - runBlocking(Dispatchers.IO) { - MDC.put("ecu", name) - launch(MDCContext()) { - socket.send( - Datagram( - packet = ByteReadPacket(vam.asByteArray), - address = InetSocketAddress(config.broadcastAddress, 13400) - ) - ) - } - } - } - - vamSentCounter++ - } - } - - protected open suspend fun startVamTimer(socket: BoundDatagramSocket) { - if (config.broadcastEnabled) { - val vams = DefaultDoipEntityUdpMessageHandler.generateVamByEntityConfig(this) - sendVams(vams, socket) - } - } - protected open suspend fun sendResponse(request: DoipTcpDiagMessage, output: ByteWriteChannel, data: ByteArray) { if (data.isEmpty()) { return @@ -170,9 +124,6 @@ public abstract class DoipEntity( output.writeFully(response.asByteArray) } - override fun existsTargetAddress(targetAddress: Short): Boolean = - targetEcusByLogical.containsKey(targetAddress) || targetEcusByFunctional.containsKey(targetAddress) - override suspend fun onIncomingDiagMessage(diagMessage: DoipTcpDiagMessage, output: ByteWriteChannel) { val ecu = targetEcusByLogical[diagMessage.targetAddress] ecu?.run { @@ -200,216 +151,6 @@ public abstract class DoipEntity( public open fun findEcuByName(name: String, ignoreCase: Boolean = true): T? = this.ecus.firstOrNull { name.equals(it.name, ignoreCase = ignoreCase) } - protected open fun CoroutineScope.handleTcpSocket(socket: DoipTcpSocket, disableServerSocketCallback: (kotlin.time.Duration) -> Unit) { - launch { - logger.debugIf { "New incoming data connection from ${socket.remoteAddress}" } - val tcpMessageHandler = createDoipTcpMessageHandler(socket) - val input = socket.openReadChannel() - val output = socket.openWriteChannel() - try { - connectionHandlers.add(tcpMessageHandler) - while (!socket.isClosed) { - try { - val message = tcpMessageHandler.receiveTcpData(input) - tcpMessageHandler.handleTcpMessage(message, output) - } catch (e: ClosedReceiveChannelException) { - // ignore - socket was closed - logger.debugIf { "Socket was closed by remote ${socket.remoteAddress}" } - withContext(Dispatchers.IO) { - tcpMessageHandler.connectionClosed(e) - socket.runCatching { this.close() } - } - } catch (e: SocketException) { - logger.error("Socket error: ${e.message} -> closing socket") - withContext(Dispatchers.IO) { - tcpMessageHandler.connectionClosed(e) - socket.runCatching { this.close() } - } - } catch (e: HeaderNegAckException) { - if (!socket.isClosed) { - logger.debug("Error in Header while parsing message, sending negative acknowledgment", e) - val response = - DoipTcpHeaderNegAck(DoipTcpDiagMessageNegAck.NACK_CODE_TRANSPORT_PROTOCOL_ERROR).asByteArray - output.writeFully(response) - withContext(Dispatchers.IO) { - tcpMessageHandler.connectionClosed(e) - socket.runCatching { this.close() } - } - } - } catch (e: DoipEntityHardResetException) { - logger.warn("Simulating Hard Reset on ${this@DoipEntity.name} for ${e.duration.inWholeMilliseconds} ms") - output.flush() - socket.close() - - disableServerSocketCallback(e.duration) - } catch (e: Exception) { - if (!socket.isClosed) { - logger.error("Unknown error parsing/handling message, sending negative acknowledgment", e) - val response = - DoipTcpHeaderNegAck(DoipTcpDiagMessageNegAck.NACK_CODE_TRANSPORT_PROTOCOL_ERROR).asByteArray - output.writeFully(response) - withContext(Dispatchers.IO) { - tcpMessageHandler.connectionClosed(e) - socket.runCatching { this.close() } - } - } - } - } - } catch (e: Throwable) { - logger.error("Unknown error inside socket processing loop, closing socket", e) - } finally { - try { - withContext(Dispatchers.IO) { - tcpMessageHandler.closeSocket() - } - } finally { - connectionHandlers.remove(tcpMessageHandler) - } - } - } - } - - protected open fun CoroutineScope.handleUdpMessage( - udpMessageHandler: DoipUdpMessageHandler, - datagram: Datagram, - socket: BoundDatagramSocket - ) { - runBlocking { - MDC.put("ecu", name) - launch(MDCContext()) { - try { - logger.traceIf { "Incoming UDP message for $name" } - val message = udpMessageHandler.parseMessage(datagram) - logger.traceIf { "Message for $name is of type $message" } - udpMessageHandler.handleUdpMessage(socket.outgoing, datagram.address, message) - } catch (e: HeaderNegAckException) { - val code = when (e) { - is IncorrectPatternFormat -> DoipUdpHeaderNegAck.NACK_INCORRECT_PATTERN_FORMAT - is HeaderTooShort -> DoipUdpHeaderNegAck.NACK_INCORRECT_PATTERN_FORMAT - is InvalidPayloadLength -> DoipUdpHeaderNegAck.NACK_INVALID_PAYLOAD_LENGTH - is UnknownPayloadType -> DoipUdpHeaderNegAck.NACK_UNKNOWN_PAYLOAD_TYPE - else -> { - DoipUdpHeaderNegAck.NACK_UNKNOWN_PAYLOAD_TYPE - } - } - logger.debug("Error in Message-Header, sending negative acknowledgement", e) - udpMessageHandler.respondHeaderNegAck( - socket.outgoing, - datagram.address, - code - ) - } catch (e: Exception) { - logger.error("Unknown error while processing message", e) - } - } - } - } - - private val serverSockets: MutableList = mutableListOf() - - public fun pauseTcpServerSockets(duration: kotlin.time.Duration) { - logger.warn("Closing serversockets") - serverSockets.forEach { try { it.close() } catch (ignored: Exception) {} } - serverSockets.clear() - logger.warn("Pausing server sockets for ${duration.inWholeMilliseconds} ms") - Thread.sleep(duration.inWholeMilliseconds) - logger.warn("Restarting server sockets after ${duration.inWholeMilliseconds} ms") - runBlocking { - launch { - startVamTimer(udpServerSocket) - } - launch { - startTcpServerSockets() - } - } - } - - public fun startTcpServerSockets() { - thread(name = "TCP") { - runBlocking { - val serverSocket = - aSocket(ActorSelectorManager(Dispatchers.IO)) - .tcp() - .bind(InetSocketAddress(config.localAddress, config.localPort)) - serverSockets.add(serverSocket) - logger.info("Listening on tcp: ${serverSocket.localAddress}") - while (!serverSocket.isClosed) { - val socket = serverSocket.accept() - handleTcpSocket(DelegatedKtorSocket(socket), ::pauseTcpServerSockets) - } - } - } - -// TLS with ktor-network doesn't work yet https://youtrack.jetbrains.com/issue/KTOR-694 - if (config.tlsMode != TlsMode.DISABLED) { - val tlsOptions = config.tlsOptions - if (tlsOptions.tlsCert == null) { - System.err.println("tlsCert is null") - exitProcess(-1) - } else if (tlsOptions.tlsKey == null) { - System.err.println("tlsKey is null") - exitProcess(-1) - } else if (!tlsOptions.tlsCert.isFile) { - System.err.println("${tlsOptions.tlsCert.absolutePath} doesn't exist or isn't a file") - exitProcess(-1) - } else if (!tlsOptions.tlsKey.isFile) { - System.err.println("${tlsOptions.tlsKey.absolutePath} doesn't exist or isn't a file") - exitProcess(-1) - } - - thread(name = "TLS") { - runBlocking { - val key = PemUtils.loadIdentityMaterial( - Paths.get(tlsOptions.tlsCert.toURI()), - Paths.get(tlsOptions.tlsKey.toURI()), - tlsOptions.tlsKeyPassword?.toCharArray() - ) - val trustMaterial = PemUtils.loadTrustMaterial(Paths.get(tlsOptions.tlsCert.toURI())) - - val sslFactory = SSLFactory.builder() - .withIdentityMaterial(key) - .withTrustMaterial(trustMaterial) - .build() - - val serverSocket = withContext(Dispatchers.IO) { - (sslFactory.sslServerSocketFactory.createServerSocket( - config.tlsPort, - 50, - InetAddress.getByName(config.localAddress) - )) - } - serverSockets.add(serverSocket as ServerSocket) - val tlsServerSocket = serverSocket as SSLServerSocket - logger.info("Listening on tls: ${tlsServerSocket.localSocketAddress}") - - if (tlsOptions.tlsProtocols != null) { - val supportedProtocols = tlsServerSocket.supportedProtocols.toSet() - // Use filter to retain order of protocols/ciphers - tlsServerSocket.enabledProtocols = - tlsOptions.tlsProtocols.filter { supportedProtocols.contains(it) }.toTypedArray() - } - - if (tlsOptions.tlsCiphers != null) { - val supportedCipherSuites = tlsServerSocket.supportedCipherSuites.toSet() - // Use filter to retain order of protocols/ciphers - tlsServerSocket.enabledCipherSuites = - tlsOptions.tlsCiphers.filter { supportedCipherSuites.contains(it) }.toTypedArray() - } - - logger.info("Enabled TLS protocols: ${tlsServerSocket.enabledProtocols.joinToString(", ")}") - logger.info("Enabled TLS cipher suites: ${tlsServerSocket.enabledCipherSuites.joinToString(", ")}") - - while (!tlsServerSocket.isClosed) { - withContext(Dispatchers.IO) { - val socket = tlsServerSocket.accept() as SSLSocket - handleTcpSocket(SSLDoipTcpSocket(socket), ::pauseTcpServerSockets) - } - } - } - } - } - } - public fun start() { this._ecus.addAll(this.config.ecuConfigList.map { createEcu(it) }) @@ -419,57 +160,5 @@ public abstract class DoipEntity( _ecus.forEach { it.simStarted() } - - thread(name = "UDP") { - runBlocking { - udpServerSocket = - aSocket(ActorSelectorManager(Dispatchers.IO)) - .udp() - .bind(localAddress = InetSocketAddress(config.localAddress, 13400)) { - broadcast = true - reuseAddress = true -// reusePort = true // not supported on windows - typeOfService = TypeOfService.IPTOS_RELIABILITY -// socket.joinGroup(multicastAddress) - } - logger.info("Listening on udp: ${udpServerSocket.localAddress}") - startVamTimer(udpServerSocket) - val udpMessageHandler = createDoipUdpMessageHandler() - - if (config.localAddress != "0.0.0.0" && config.bindOnAnyForUdpAdditional) { - logger.info("Also listening on udp 0.0.0.0 for broadcasts") - val localAddress = InetSocketAddress("0.0.0.0", 13400) - val anyServerSocket = - aSocket(ActorSelectorManager(Dispatchers.IO)) - .udp() - .bind(localAddress = localAddress) { - broadcast = true - reuseAddress = true -// reusePort = true // not supported on windows - typeOfService = TypeOfService.IPTOS_RELIABILITY - } - thread(start = true, isDaemon = true) { - runBlocking { - while (!anyServerSocket.isClosed) { - val datagram = anyServerSocket.receive() - if (datagram.address is InetSocketAddress) { - if (datagram.address == localAddress) { - continue - } - } - handleUdpMessage(udpMessageHandler, datagram, anyServerSocket) - } - } - } - } - - while (!udpServerSocket.isClosed) { - val datagram = udpServerSocket.receive() - handleUdpMessage(udpMessageHandler, datagram, udpServerSocket) - } - } - } - - startTcpServerSockets() } } diff --git a/src/main/kotlin/library/DoipTcpMessages.kt b/src/main/kotlin/library/DoipTcpMessages.kt index 76f9f01..8ed32b1 100644 --- a/src/main/kotlin/library/DoipTcpMessages.kt +++ b/src/main/kotlin/library/DoipTcpMessages.kt @@ -8,21 +8,12 @@ import kotlin.experimental.inv public abstract class DoipTcpMessage : DoipMessage -@Suppress("MemberVisibilityCanBePrivate") -public open class DoipTcpConnectionMessageHandler( - public val socket: DoipTcpSocket, - public val maxPayloadLength: Int = Int.MAX_VALUE -) { - private var _registeredSourceAddress: Short? = null - - public var registeredSourceAddress: Short? - get() = _registeredSourceAddress - protected set(value) { - _registeredSourceAddress = value - } - +public class DoipTcpMessageParser(private val maxPayloadLength: Int) { + private companion object { + private val logger: Logger = LoggerFactory.getLogger(DoipTcpMessageParser::class.java) + } - public open suspend fun receiveTcpData(brc: ByteReadChannel): DoipTcpMessage { + public suspend fun parseDoipTcpMessage(brc: ByteReadChannel): DoipTcpMessage { logger.traceIf { "# receiveTcpData" } val protocolVersion = brc.readByte() val inverseProtocolVersion = brc.readByte() @@ -42,6 +33,7 @@ public open class DoipTcpConnectionMessageHandler( val code = brc.readByte() return DoipTcpHeaderNegAck(code) } + TYPE_TCP_ROUTING_REQ -> { val sourceAddress = brc.readShort() val activationType = brc.readByte() @@ -49,6 +41,7 @@ public open class DoipTcpConnectionMessageHandler( val oemData = if (payloadLength > 7) brc.readInt() else null return DoipTcpRoutingActivationRequest(sourceAddress, activationType, reserved, oemData) } + TYPE_TCP_ROUTING_RES -> { val testerAddress = brc.readShort() val entityAddress = brc.readShort() @@ -62,13 +55,16 @@ public open class DoipTcpConnectionMessageHandler( oemData = oemData ) } + TYPE_TCP_ALIVE_REQ -> { return DoipTcpAliveCheckRequest() } + TYPE_TCP_ALIVE_RES -> { val sourceAddress = brc.readShort() return DoipTcpAliveCheckResponse(sourceAddress) } + TYPE_TCP_DIAG_MESSAGE -> { val sourceAddress = brc.readShort() val targetAddress = brc.readShort() @@ -78,6 +74,7 @@ public open class DoipTcpConnectionMessageHandler( sourceAddress, targetAddress, payload ) } + TYPE_TCP_DIAG_MESSAGE_POS_ACK -> { val sourceAddress = brc.readShort() val targetAddress = brc.readShort() @@ -91,6 +88,7 @@ public open class DoipTcpConnectionMessageHandler( payload = payload ) } + TYPE_TCP_DIAG_MESSAGE_NEG_ACK -> { val sourceAddress = brc.readShort() val targetAddress = brc.readShort() @@ -104,9 +102,24 @@ public open class DoipTcpConnectionMessageHandler( payload = payload ) } + else -> throw UnknownPayloadType("Unknown payload type $payloadType") } } +} + +@Suppress("MemberVisibilityCanBePrivate") +public open class DoipTcpConnectionMessageHandler( + public val socket: DoipTcpSocket, +) { + private var _registeredSourceAddress: Short? = null + + public var registeredSourceAddress: Short? + get() = _registeredSourceAddress + protected set(value) { + _registeredSourceAddress = value + } + public open suspend fun handleTcpMessage(message: DoipTcpMessage, output: ByteWriteChannel) { logger.traceIf { "# handleTcpMessage $message" } diff --git a/src/main/kotlin/library/DoipUdpMessageHandler.kt b/src/main/kotlin/library/DoipUdpMessageHandler.kt index 34ad8a7..cd56cb3 100644 --- a/src/main/kotlin/library/DoipUdpMessageHandler.kt +++ b/src/main/kotlin/library/DoipUdpMessageHandler.kt @@ -144,7 +144,4 @@ public interface DoipUdpMessageHandler { ) ) } - - public suspend fun parseMessage(datagram: Datagram): DoipUdpMessage = - DoipUdpMessageParser.parseUDP(datagram.packet) } diff --git a/src/main/kotlin/library/EcuConfig.kt b/src/main/kotlin/library/EcuConfig.kt index 8a7b521..fed8f6b 100644 --- a/src/main/kotlin/library/EcuConfig.kt +++ b/src/main/kotlin/library/EcuConfig.kt @@ -7,5 +7,4 @@ public open class EcuConfig( public val logicalAddress: Short, public val functionalAddress: Short, public val pendingNrcSendInterval: kotlin.time.Duration = 2.seconds, - public val additionalVam: EcuAdditionalVamData? = null, ) diff --git a/src/main/kotlin/library/GroupDoipTcpConnectionMessageHandler.kt b/src/main/kotlin/library/GroupDoipTcpConnectionMessageHandler.kt new file mode 100644 index 0000000..a0211f0 --- /dev/null +++ b/src/main/kotlin/library/GroupDoipTcpConnectionMessageHandler.kt @@ -0,0 +1,29 @@ +package library + +import io.ktor.utils.io.ByteWriteChannel + +public class GroupDoipTcpConnectionMessageHandler( + entities: List>, + socket: DoipTcpSocket, + tlsOptions: TlsOptions?, +) : DefaultDoipEntityTcpConnectionMessageHandler(entities.first(), socket, entities.first().config.logicalAddress.toShort(), entities.first(), tlsOptions) { + + private val diagnosticMessageHandler: List = entities.map { it } + + init { + super.diagMessageHandler = GroupHandler(diagnosticMessageHandler) + } + + public class GroupHandler(private val list: List) : DiagnosticMessageHandler { + override fun existsTargetAddress(targetAddress: Short): Boolean = + list.any { it.existsTargetAddress(targetAddress) } + + override suspend fun onIncomingDiagMessage( + diagMessage: DoipTcpDiagMessage, + output: ByteWriteChannel + ) { + val handler = list.firstOrNull { it.existsTargetAddress(diagMessage.targetAddress) } ?: list.first() + handler.onIncomingDiagMessage(diagMessage, output) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/library/SimulatedEcu.kt b/src/main/kotlin/library/SimulatedEcu.kt index 8c5d8d6..1e73af2 100644 --- a/src/main/kotlin/library/SimulatedEcu.kt +++ b/src/main/kotlin/library/SimulatedEcu.kt @@ -17,7 +17,12 @@ public open class SimulatedEcu(public val config: EcuConfig) { private val isBusy: AtomicBoolean = AtomicBoolean(false) internal open fun simStarted() { + } + + public open fun reset() { + } + public open fun start() { } /** diff --git a/src/test/kotlin/SimGatewayTest.kt b/src/test/kotlin/SimDoipEntityTest.kt similarity index 92% rename from src/test/kotlin/SimGatewayTest.kt rename to src/test/kotlin/SimDoipEntityTest.kt index 18e0e56..01609bd 100644 --- a/src/test/kotlin/SimGatewayTest.kt +++ b/src/test/kotlin/SimDoipEntityTest.kt @@ -1,6 +1,5 @@ import assertk.assertThat import assertk.assertions.isEqualTo -import client.ConnectException import client.DoipClient import io.ktor.network.sockets.* import kotlinx.coroutines.channels.ClosedReceiveChannelException @@ -13,18 +12,17 @@ import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds @Disabled("only for manual testing") -class SimGatewayTest { +class SimDoipEntityTest { @Test fun `test byte read channel`() { - val doipEntity = SimGateway( - GatewayData( + val doipEntity = SimDoipEntity( + DoipEntityData( name = "TEST" ).also { it.logicalAddress = 0x1010 it.gid = GID(6) it.eid = EID(6) it.vin = "01234567890123456" - it.localAddress = "127.0.0.1" it.functionalAddress = 0x3030 it.requests.add( RequestMatcher( @@ -63,15 +61,14 @@ class SimGatewayTest { @OptIn(ExperimentalDoipDslApi::class) @Test fun `test hard reset`() { - val doipEntity = SimGateway( - GatewayData( + val doipEntity = SimDoipEntity( + DoipEntityData( name = "TEST" ).also { it.logicalAddress = 0x1010 it.gid = GID(6) it.eid = EID(6) it.vin = "01234567890123456" - it.localAddress = "127.0.0.1" it.functionalAddress = 0x3030 it.requests.add( RequestMatcher( diff --git a/src/test/kotlin/SimDslTest.kt b/src/test/kotlin/SimDslTest.kt index a6246c9..e56ffb4 100644 --- a/src/test/kotlin/SimDslTest.kt +++ b/src/test/kotlin/SimDslTest.kt @@ -16,58 +16,69 @@ import kotlin.time.Duration.Companion.seconds class SimDslTest { @AfterEach fun tearDown() { - gateways.clear() - gatewayInstances.clear() + networks.clear() } @Test fun `test dsl`() { - gateway("GW") { - request(byteArrayOf(0x10), "REQ1") { respond(byteArrayOf(0x50)) } - request("10", "REQ2", duplicateStrategy = DuplicateStrategy.APPEND) { respond("50") } - request("10 []", "REQ3") { ack() } - request("10.*", "REQ4", duplicateStrategy = DuplicateStrategy.APPEND) { - nrc() - addOrReplaceEcuTimer(name = "TEST", delay = 100.milliseconds) { - // do nothing + network { + gateway("GW") { + request(byteArrayOf(0x10), "REQ1") { respond(byteArrayOf(0x50)) } + request("10", "REQ2", duplicateStrategy = DuplicateStrategy.APPEND) { respond("50") } + request("10 []", "REQ3") { ack() } + request("10.*", "REQ4", duplicateStrategy = DuplicateStrategy.APPEND) { + nrc() + addOrReplaceEcuTimer(name = "TEST", delay = 100.milliseconds) { + // do nothing + } } - } - onReset("RESETIT") { - } + onReset("RESETIT") { + } - ecu("ECU1") { - request(byteArrayOf(0x10),"REQ1") { ack() } - request("10", "REQ2", duplicateStrategy = DuplicateStrategy.APPEND) { ack() } - request("10 []", "REQ3") { ack() } - request("10.*", "REQ4", duplicateStrategy = DuplicateStrategy.APPEND) { nrc(); addOrReplaceEcuInterceptor(duration = 1.seconds) { false } } - additionalVam = EcuAdditionalVamData(eid = "1234".decodeHex()) + ecu("ECU1") { + request(byteArrayOf(0x10), "REQ1") { ack() } + request("10", "REQ2", duplicateStrategy = DuplicateStrategy.APPEND) { ack() } + request("10 []", "REQ3") { ack() } + request( + "10.*", + "REQ4", + duplicateStrategy = DuplicateStrategy.APPEND + ) { nrc(); addOrReplaceEcuInterceptor(duration = 1.seconds) { false } } + } } } - assertThat(gateways.size).isEqualTo(1) - assertThat(gateways[0].name).isEqualTo("GW") - assertThat(gateways[0].requests.size).isEqualTo(4) - assertThat(gateways[0].resetHandler.size).isEqualTo(1) + assertThat(networks.size).isEqualTo(1) + + val doipEntities = networks().first()._doipEntities + assertThat(doipEntities.size).isEqualTo(1) + assertThat(doipEntities[0].name).isEqualTo("GW") + assertThat(doipEntities[0].requests.size).isEqualTo(4) + assertThat(doipEntities[0].resetHandler.size).isEqualTo(1) - assertThat(gateways[0].ecus.size).isEqualTo(1) - assertThat(gateways[0].ecus[0].name).isEqualTo("ECU1") - assertThat(gateways[0].ecus[0].requests.size).isEqualTo(4) - assertThat(gateways[0].ecus[0].additionalVam!!.eid).isEqualTo("1234".decodeHex()) + assertThat(doipEntities[0].ecus.size).isEqualTo(1) + assertThat(doipEntities[0].ecus[0].name).isEqualTo("ECU1") + assertThat(doipEntities[0].ecus[0].requests.size).isEqualTo(4) - assertThat(gatewayInstances.size).isEqualTo(0) + assertThat(networkInstances.size).isEqualTo(0) } @Test fun `test multibyte ack`() { - gateway("GW") { - ecu("ECU") { - ackBytesLengthMap = mapOf(0x22.toByte() to 3) - request(byteArrayOf(0x22, 0x10, 0x20), "REQ2") { ack() } + network { + gateway("GW") { + ecu("ECU") { + ackBytesLengthMap = mapOf(0x22.toByte() to 3) + request(byteArrayOf(0x22, 0x10, 0x20), "REQ2") { ack() } + } } } - assertThat(gateways.size).isEqualTo(1) - val ecuData = gateways[0].ecus[0] + assertThat(networks.size).isEqualTo(1) + + val doipEntities = networks().first()._doipEntities + assertThat(doipEntities.size).isEqualTo(1) + val ecuData = doipEntities[0].ecus[0] val msg = UdsMessage( 0x1, 0x2, @@ -95,14 +106,22 @@ class SimDslTest { assertThat(createEcuFunc).isNotNull() } - fun createGwFunc(createGwFunc: CreateGatewayFunc) { + fun createGwFunc(createGwFunc: CreateDoipEntityFunc) { assertThat(createGwFunc).isNotNull() createGwFunc("TEST") { createEcuFunc(::ecu) } } - createGwFunc(::gateway) + fun createNetwork(createNetwork: CreateNetworkFunc) { + assertThat(createNetwork).isNotNull() + createNetwork { + createGwFunc(::gateway) + } + } + + + createNetwork(::network) } @Test diff --git a/src/test/kotlin/library/DoipTcpConnectionMessageHandlerTest.kt b/src/test/kotlin/library/DoipTcpConnectionMessageHandlerTest.kt index 30e7684..f42085a 100644 --- a/src/test/kotlin/library/DoipTcpConnectionMessageHandlerTest.kt +++ b/src/test/kotlin/library/DoipTcpConnectionMessageHandlerTest.kt @@ -28,19 +28,19 @@ class DoipTcpConnectionMessageHandlerTest { @Test fun `test receive`() { - val tcpMessageHandler = DoipTcpConnectionMessageHandler(mock()) + val parser = DoipTcpMessageParser(65535) val data = Random.nextBytes(10) runBlocking { - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpHeaderNegAck(0x11).asByteArray)) - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpAliveCheckRequest().asByteArray)) - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpAliveCheckResponse(0x1234.toShort()).asByteArray)) - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpDiagMessage(0x1234.toShort(), 0x4321.toShort(), data).asByteArray)) - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpDiagMessageNegAck(0x11.toShort(), 0x22.toShort(), 0x11).asByteArray)) - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpDiagMessagePosAck(0x11.toShort(), 0x22.toShort(), 0x11).asByteArray)) - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpRoutingActivationResponse(0x11.toShort(), 0x22.toShort(), 0x11).asByteArray)) - tcpMessageHandler.receiveTcpData(ByteReadChannel(DoipTcpRoutingActivationRequest(0x11.toShort()).asByteArray)) - assertThrows { tcpMessageHandler.receiveTcpData(ByteReadChannel(byteArrayOf(0x0))) } - assertThrows { tcpMessageHandler.receiveTcpData(ByteReadChannel(doipMessage(0xffff.toShort()))) } + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpHeaderNegAck(0x11).asByteArray)) + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpAliveCheckRequest().asByteArray)) + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpAliveCheckResponse(0x1234.toShort()).asByteArray)) + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpDiagMessage(0x1234.toShort(), 0x4321.toShort(), data).asByteArray)) + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpDiagMessageNegAck(0x11.toShort(), 0x22.toShort(), 0x11).asByteArray)) + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpDiagMessagePosAck(0x11.toShort(), 0x22.toShort(), 0x11).asByteArray)) + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpRoutingActivationResponse(0x11.toShort(), 0x22.toShort(), 0x11).asByteArray)) + parser.parseDoipTcpMessage(ByteReadChannel(DoipTcpRoutingActivationRequest(0x11.toShort()).asByteArray)) + assertThrows { parser.parseDoipTcpMessage(ByteReadChannel(byteArrayOf(0x0))) } + assertThrows { parser.parseDoipTcpMessage(ByteReadChannel(doipMessage(0xffff.toShort()))) } } } }