From 4a620ce9b1b3dd8f5b47c53f454d6c9356f8ec9f Mon Sep 17 00:00:00 2001 From: Janne Mareike Koschinski <janne@kuschku.de> Date: Tue, 1 Mar 2022 16:20:43 +0100 Subject: [PATCH] fix: correct issue with coroutine contexts --- .../client/session/ClientHandshakeHandler.kt | 42 +++++-------- .../client/session/ClientSession.kt | 4 +- .../client/syncables/ClientBacklogManager.kt | 40 ++++++------- .../client/util/CoroutineKeyedQueue.kt | 12 +++- .../justjanne/libquassel/client/ClientTest.kt | 11 +++- .../protocol/io/CoroutineChannel.kt | 5 +- .../protocol/session/MessageChannel.kt | 10 +++- .../session/MessageChannelReadThread.kt | 36 ----------- .../protocol/session/MessageChannelReader.kt | 59 +++++++++++++++++++ 9 files changed, 123 insertions(+), 96 deletions(-) delete mode 100644 libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt create mode 100644 libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt index b92936b..5c751fc 100644 --- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt +++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt @@ -16,7 +16,6 @@ import de.justjanne.libquassel.protocol.models.HandshakeMessage import de.justjanne.libquassel.protocol.serializers.HandshakeMessageSerializer import de.justjanne.libquassel.protocol.session.CoreState import de.justjanne.libquassel.protocol.session.HandshakeHandler -import de.justjanne.libquassel.protocol.session.MessageChannelReadThread import de.justjanne.libquassel.protocol.session.Session import de.justjanne.libquassel.protocol.util.log.trace import de.justjanne.libquassel.protocol.variant.QVariantMap @@ -56,18 +55,13 @@ class ClientHandshakeHandler( buildDate: String, featureSet: FeatureSet ): CoreState { - emit( - HandshakeMessage.ClientInit( - clientVersion, - buildDate, - featureSet - ) - ) when ( val response = messageQueue.wait( HandshakeMessage.ClientInitAck::class.java, HandshakeMessage.ClientInitReject::class.java - ) + ) { + emit(HandshakeMessage.ClientInit(clientVersion, buildDate, featureSet)) + } ) { is HandshakeMessage.ClientInitReject -> throw HandshakeException.InitException(response.errorString ?: "Unknown Error") @@ -89,17 +83,13 @@ class ClientHandshakeHandler( } override suspend fun login(username: String, password: String) { - emit( - HandshakeMessage.ClientLogin( - username, - password - ) - ) when ( val response = messageQueue.wait( HandshakeMessage.ClientLoginAck::class.java, HandshakeMessage.ClientLoginReject::class.java - ) + ) { + emit(HandshakeMessage.ClientLogin(username, password)) + } ) { is HandshakeMessage.ClientLoginReject -> throw HandshakeException.LoginException(response.errorString ?: "Unknown Error") @@ -117,21 +107,17 @@ class ClientHandshakeHandler( authenticator: String, authenticatorConfiguration: QVariantMap ) { - emit( - HandshakeMessage.CoreSetupData( - adminUsername, - adminPassword, - backend, - backendConfiguration, - authenticator, - authenticatorConfiguration - ) - ) when ( val response = messageQueue.wait( HandshakeMessage.CoreSetupAck::class.java, HandshakeMessage.CoreSetupReject::class.java - ) + ) { + emit( + HandshakeMessage.CoreSetupData( + adminUsername, adminPassword, backend, backendConfiguration, authenticator, authenticatorConfiguration + ) + ) + } ) { is HandshakeMessage.CoreSetupReject -> throw HandshakeException.SetupException(response.errorString ?: "Unknown Error") @@ -142,6 +128,6 @@ class ClientHandshakeHandler( } companion object { - private val logger = LoggerFactory.getLogger(MessageChannelReadThread::class.java) + private val logger = LoggerFactory.getLogger(ClientHandshakeHandler::class.java) } } diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt index 58d4c6b..2d038b8 100644 --- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt +++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt @@ -20,7 +20,7 @@ import de.justjanne.libquassel.protocol.models.ids.NetworkId import de.justjanne.libquassel.protocol.serializers.qt.StringSerializerUtf8 import de.justjanne.libquassel.protocol.session.CommonSyncProxy import de.justjanne.libquassel.protocol.session.MessageChannel -import de.justjanne.libquassel.protocol.session.MessageChannelReadThread +import de.justjanne.libquassel.protocol.session.MessageChannelReader import de.justjanne.libquassel.protocol.session.Session import de.justjanne.libquassel.protocol.syncables.HeartBeatHandler import de.justjanne.libquassel.protocol.syncables.ObjectRepository @@ -78,7 +78,7 @@ class ClientSession( messageChannel.register(magicHandler) messageChannel.register(handshakeHandler) messageChannel.register(proxyMessageHandler) - MessageChannelReadThread(messageChannel).start() + MessageChannelReader(messageChannel).start() } override fun init( diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt index 3eafa65..31f88e4 100644 --- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt +++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt @@ -37,10 +37,10 @@ class ClientBacklogManager( last: MsgId = MsgId(-1), limit: Int = -1, additional: Int = 0 - ): QVariantList { - requestBacklog(bufferId, first, last, limit, additional) - return bufferQueue.wait(BacklogData.Buffer(bufferId, first, last, limit, additional)) - } + ): QVariantList = + bufferQueue.wait(BacklogData.Buffer(bufferId, first, last, limit, additional)) { + requestBacklog(bufferId, first, last, limit, additional) + } suspend fun backlogFiltered( bufferId: BufferId, @@ -50,10 +50,10 @@ class ClientBacklogManager( additional: Int = 0, type: MessageTypes = MessageType.all, flags: MessageFlags = MessageFlag.all - ): QVariantList { - requestBacklogFiltered(bufferId, first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt()) - return bufferFilteredQueue.wait(BacklogData.BufferFiltered(bufferId, first, last, limit, additional, type, flags)) - } + ): QVariantList = + bufferFilteredQueue.wait(BacklogData.BufferFiltered(bufferId, first, last, limit, additional, type, flags)) { + requestBacklogFiltered(bufferId, first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt()) + } suspend fun backlogForward( bufferId: BufferId, @@ -62,20 +62,20 @@ class ClientBacklogManager( limit: Int = -1, type: MessageTypes = MessageType.all, flags: MessageFlags = MessageFlag.all - ): QVariantList { - requestBacklogForward(bufferId, first, last, limit, type.toBits().toInt(), flags.toBits().toInt()) - return bufferForwardQueue.wait(BacklogData.BufferForward(bufferId, first, last, limit, type, flags)) - } + ): QVariantList = + bufferForwardQueue.wait(BacklogData.BufferForward(bufferId, first, last, limit, type, flags)) { + requestBacklogForward(bufferId, first, last, limit, type.toBits().toInt(), flags.toBits().toInt()) + } suspend fun backlogAll( first: MsgId = MsgId(-1), last: MsgId = MsgId(-1), limit: Int = -1, additional: Int = 0 - ): QVariantList { - requestBacklogAll(first, last, limit, additional) - return allQueue.wait(BacklogData.All(first, last, limit, additional)) - } + ): QVariantList = + allQueue.wait(BacklogData.All(first, last, limit, additional)) { + requestBacklogAll(first, last, limit, additional) + } suspend fun backlogAllFiltered( first: MsgId = MsgId(-1), @@ -84,10 +84,10 @@ class ClientBacklogManager( additional: Int = 0, type: MessageTypes = MessageType.all, flags: MessageFlags = MessageFlag.all - ): QVariantList { - requestBacklogAllFiltered(first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt()) - return allFilteredQueue.wait(BacklogData.AllFiltered(first, last, limit, additional, type, flags)) - } + ): QVariantList = + allFilteredQueue.wait(BacklogData.AllFiltered(first, last, limit, additional, type, flags)) { + requestBacklogAllFiltered(first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt()) + } override fun receiveBacklog( bufferId: BufferId, diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt index 62863cf..244df83 100644 --- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt +++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt @@ -9,6 +9,9 @@ package de.justjanne.libquassel.client.util +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch import org.slf4j.LoggerFactory import kotlin.coroutines.Continuation import kotlin.coroutines.resume @@ -16,9 +19,12 @@ import kotlin.coroutines.suspendCoroutine class CoroutineKeyedQueue<Key, Value> { private val waiting = mutableMapOf<Key, MutableList<Continuation<Value>>>() - suspend fun wait(vararg keys: Key): Value = suspendCoroutine { - for (key in keys) { - waiting.getOrPut(key, ::mutableListOf).add(it) + suspend fun wait(vararg keys: Key, beforeWait: (suspend CoroutineScope.() -> Unit)? = null): Value = coroutineScope { + suspendCoroutine { continuation -> + for (key in keys) { + waiting.getOrPut(key, ::mutableListOf).add(continuation) + } + beforeWait?.let { launch(block = it) } } } diff --git a/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt b/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt index 0d1f3c6..c698f1e 100644 --- a/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt +++ b/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt @@ -19,7 +19,6 @@ import de.justjanne.libquassel.protocol.exceptions.HandshakeException import de.justjanne.libquassel.protocol.features.FeatureSet import de.justjanne.libquassel.protocol.io.CoroutineChannel import de.justjanne.libquassel.protocol.models.ids.BufferId -import de.justjanne.libquassel.protocol.models.ids.MsgId import de.justjanne.libquassel.protocol.session.CoreState import de.justjanne.testcontainersci.api.providedContainer import de.justjanne.testcontainersci.extension.CiContainers @@ -29,8 +28,8 @@ import kotlinx.coroutines.withTimeout import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows +import org.slf4j.LoggerFactory import java.net.InetSocketAddress -import java.time.Duration import javax.net.ssl.SSLContext import kotlin.test.assertEquals @@ -41,6 +40,8 @@ class ClientTest { QuasselCoreContainer() } + private val logger = LoggerFactory.getLogger(ClientTest::class.java) + private val username = "AzureDiamond" private val password = "hunter2" @@ -96,19 +97,23 @@ class ClientTest { } session.handshakeHandler.login(username, password) session.baseInitHandler.waitForInitDone() - withTimeout(Duration.ofSeconds(5).toMillis()) { + logger.trace("Init Done") + withTimeout(5_000L) { assertEquals( emptyList(), session.backlogManager.backlog(bufferId = BufferId(1), limit = 5) ) + logger.trace("Backlog Test #1 Done") assertEquals( emptyList(), session.backlogManager.backlogAll(limit = 5) ) + logger.trace("Backlog Test #2 Done") assertEquals( emptyList(), session.backlogManager.backlogForward(bufferId = BufferId(1), limit = 5) ) + logger.trace("Backlog Test #3 Done") } channel.close() } diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt index 22e3de9..aa8f1fc 100644 --- a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt +++ b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt @@ -16,13 +16,14 @@ import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.runInterruptible +import java.io.Closeable import java.net.InetSocketAddress import java.net.Socket import java.nio.ByteBuffer import java.util.concurrent.Executors import javax.net.ssl.SSLContext -class CoroutineChannel : StateHolder<CoroutineChannelState> { +class CoroutineChannel : StateHolder<CoroutineChannelState>, Closeable { private lateinit var channel: StreamChannel private val writeContext = Executors.newSingleThreadExecutor().asCoroutineDispatcher() private val readContext = Executors.newSingleThreadExecutor().asCoroutineDispatcher() @@ -75,7 +76,7 @@ class CoroutineChannel : StateHolder<CoroutineChannelState> { channel.flush() } - fun close() { + override fun close() { channel.close() state.update { copy(connected = false) diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt index 8394110..935248d 100644 --- a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt +++ b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt @@ -17,12 +17,14 @@ import de.justjanne.libquassel.protocol.models.SignalProxyMessage import de.justjanne.libquassel.protocol.serializers.HandshakeMessageSerializer import de.justjanne.libquassel.protocol.serializers.SignalProxyMessageSerializer import de.justjanne.libquassel.protocol.util.log.trace +import kotlinx.coroutines.coroutineScope import org.slf4j.LoggerFactory +import java.io.Closeable import java.nio.ByteBuffer class MessageChannel( val channel: CoroutineChannel -) { +) : Closeable { var negotiatedFeatures = FeatureSet.none() private var handlers = mutableListOf<ConnectionHandler>() @@ -98,7 +100,7 @@ class MessageChannel( SignalProxyMessageSerializer.serialize(it, message, negotiatedFeatures) } - suspend fun emit(sizePrefix: Boolean = true, f: (ChainedByteBuffer) -> Unit) { + suspend fun emit(sizePrefix: Boolean = true, f: (ChainedByteBuffer) -> Unit) = coroutineScope { val sendBuffer = sendBuffer.get() val sizeBuffer = sizeBuffer.get() @@ -115,6 +117,10 @@ class MessageChannel( sendBuffer.clear() } + override fun close() { + channel.close() + } + companion object { private val logger = LoggerFactory.getLogger(MessageChannel::class.java) } diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt deleted file mode 100644 index b2efcf1..0000000 --- a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt +++ /dev/null @@ -1,36 +0,0 @@ -/* - * libquassel - * Copyright (c) 2021 Janne Mareike Koschinski - * - * This Source Code Form is subject to the terms of the Mozilla Public License, - * v. 2.0. If a copy of the MPL was not distributed with this file, You can - * obtain one at https://mozilla.org/MPL/2.0/. - */ - -package de.justjanne.libquassel.protocol.session - -import de.justjanne.libquassel.protocol.util.log.info -import kotlinx.coroutines.runBlocking -import org.slf4j.LoggerFactory -import java.nio.channels.ClosedChannelException - -class MessageChannelReadThread( - val channel: MessageChannel -) : Thread("Message Channel Read Thread") { - override fun run() { - runBlocking { - try { - channel.init() - while (channel.channel.state().connected) { - channel.read() - } - } catch (e: ClosedChannelException) { - logger.info { "Channel closed" } - } - } - } - - companion object { - private val logger = LoggerFactory.getLogger(MessageChannelReadThread::class.java) - } -} diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt new file mode 100644 index 0000000..9cb58c5 --- /dev/null +++ b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt @@ -0,0 +1,59 @@ +/* + * libquassel + * Copyright (c) 2021 Janne Mareike Koschinski + * + * This Source Code Form is subject to the terms of the Mozilla Public License, + * v. 2.0. If a copy of the MPL was not distributed with this file, You can + * obtain one at https://mozilla.org/MPL/2.0/. + */ + +package de.justjanne.libquassel.protocol.session + +import de.justjanne.libquassel.protocol.util.log.info +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.slf4j.LoggerFactory +import java.io.Closeable +import java.nio.channels.ClosedChannelException +import java.util.concurrent.Executors + +class MessageChannelReader( + private val channel: MessageChannel +) : Closeable { + private val executor = Executors.newSingleThreadExecutor() + private val dispatcher = executor.asCoroutineDispatcher() + private val scope = CoroutineScope(dispatcher) + private var job: Job? = null + + fun start() { + job = scope.launch { + try { + channel.init() + while (isActive && channel.channel.state().connected) { + channel.read() + } + } catch (e: ClosedChannelException) { + logger.info { "Channel closed" } + close() + } + } + } + + override fun close() { + channel.close() + runBlocking { job?.cancelAndJoin() } + scope.cancel() + dispatcher.cancel() + executor.shutdown() + } + + companion object { + private val logger = LoggerFactory.getLogger(MessageChannelReader::class.java) + } +} -- GitLab