From 4a620ce9b1b3dd8f5b47c53f454d6c9356f8ec9f Mon Sep 17 00:00:00 2001
From: Janne Mareike Koschinski <janne@kuschku.de>
Date: Tue, 1 Mar 2022 16:20:43 +0100
Subject: [PATCH] fix: correct issue with coroutine contexts

---
 .../client/session/ClientHandshakeHandler.kt  | 42 +++++--------
 .../client/session/ClientSession.kt           |  4 +-
 .../client/syncables/ClientBacklogManager.kt  | 40 ++++++-------
 .../client/util/CoroutineKeyedQueue.kt        | 12 +++-
 .../justjanne/libquassel/client/ClientTest.kt | 11 +++-
 .../protocol/io/CoroutineChannel.kt           |  5 +-
 .../protocol/session/MessageChannel.kt        | 10 +++-
 .../session/MessageChannelReadThread.kt       | 36 -----------
 .../protocol/session/MessageChannelReader.kt  | 59 +++++++++++++++++++
 9 files changed, 123 insertions(+), 96 deletions(-)
 delete mode 100644 libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt
 create mode 100644 libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt

diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt
index b92936b..5c751fc 100644
--- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt
+++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientHandshakeHandler.kt
@@ -16,7 +16,6 @@ import de.justjanne.libquassel.protocol.models.HandshakeMessage
 import de.justjanne.libquassel.protocol.serializers.HandshakeMessageSerializer
 import de.justjanne.libquassel.protocol.session.CoreState
 import de.justjanne.libquassel.protocol.session.HandshakeHandler
-import de.justjanne.libquassel.protocol.session.MessageChannelReadThread
 import de.justjanne.libquassel.protocol.session.Session
 import de.justjanne.libquassel.protocol.util.log.trace
 import de.justjanne.libquassel.protocol.variant.QVariantMap
@@ -56,18 +55,13 @@ class ClientHandshakeHandler(
     buildDate: String,
     featureSet: FeatureSet
   ): CoreState {
-    emit(
-      HandshakeMessage.ClientInit(
-        clientVersion,
-        buildDate,
-        featureSet
-      )
-    )
     when (
       val response = messageQueue.wait(
         HandshakeMessage.ClientInitAck::class.java,
         HandshakeMessage.ClientInitReject::class.java
-      )
+      ) {
+        emit(HandshakeMessage.ClientInit(clientVersion, buildDate, featureSet))
+      }
     ) {
       is HandshakeMessage.ClientInitReject ->
         throw HandshakeException.InitException(response.errorString ?: "Unknown Error")
@@ -89,17 +83,13 @@ class ClientHandshakeHandler(
   }
 
   override suspend fun login(username: String, password: String) {
-    emit(
-      HandshakeMessage.ClientLogin(
-        username,
-        password
-      )
-    )
     when (
       val response = messageQueue.wait(
         HandshakeMessage.ClientLoginAck::class.java,
         HandshakeMessage.ClientLoginReject::class.java
-      )
+      ) {
+        emit(HandshakeMessage.ClientLogin(username, password))
+      }
     ) {
       is HandshakeMessage.ClientLoginReject ->
         throw HandshakeException.LoginException(response.errorString ?: "Unknown Error")
@@ -117,21 +107,17 @@ class ClientHandshakeHandler(
     authenticator: String,
     authenticatorConfiguration: QVariantMap
   ) {
-    emit(
-      HandshakeMessage.CoreSetupData(
-        adminUsername,
-        adminPassword,
-        backend,
-        backendConfiguration,
-        authenticator,
-        authenticatorConfiguration
-      )
-    )
     when (
       val response = messageQueue.wait(
         HandshakeMessage.CoreSetupAck::class.java,
         HandshakeMessage.CoreSetupReject::class.java
-      )
+      ) {
+        emit(
+          HandshakeMessage.CoreSetupData(
+            adminUsername, adminPassword, backend, backendConfiguration, authenticator, authenticatorConfiguration
+          )
+        )
+      }
     ) {
       is HandshakeMessage.CoreSetupReject ->
         throw HandshakeException.SetupException(response.errorString ?: "Unknown Error")
@@ -142,6 +128,6 @@ class ClientHandshakeHandler(
   }
 
   companion object {
-    private val logger = LoggerFactory.getLogger(MessageChannelReadThread::class.java)
+    private val logger = LoggerFactory.getLogger(ClientHandshakeHandler::class.java)
   }
 }
diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt
index 58d4c6b..2d038b8 100644
--- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt
+++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/session/ClientSession.kt
@@ -20,7 +20,7 @@ import de.justjanne.libquassel.protocol.models.ids.NetworkId
 import de.justjanne.libquassel.protocol.serializers.qt.StringSerializerUtf8
 import de.justjanne.libquassel.protocol.session.CommonSyncProxy
 import de.justjanne.libquassel.protocol.session.MessageChannel
-import de.justjanne.libquassel.protocol.session.MessageChannelReadThread
+import de.justjanne.libquassel.protocol.session.MessageChannelReader
 import de.justjanne.libquassel.protocol.session.Session
 import de.justjanne.libquassel.protocol.syncables.HeartBeatHandler
 import de.justjanne.libquassel.protocol.syncables.ObjectRepository
@@ -78,7 +78,7 @@ class ClientSession(
     messageChannel.register(magicHandler)
     messageChannel.register(handshakeHandler)
     messageChannel.register(proxyMessageHandler)
-    MessageChannelReadThread(messageChannel).start()
+    MessageChannelReader(messageChannel).start()
   }
 
   override fun init(
diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt
index 3eafa65..31f88e4 100644
--- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt
+++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/syncables/ClientBacklogManager.kt
@@ -37,10 +37,10 @@ class ClientBacklogManager(
     last: MsgId = MsgId(-1),
     limit: Int = -1,
     additional: Int = 0
-  ): QVariantList {
-    requestBacklog(bufferId, first, last, limit, additional)
-    return bufferQueue.wait(BacklogData.Buffer(bufferId, first, last, limit, additional))
-  }
+  ): QVariantList =
+    bufferQueue.wait(BacklogData.Buffer(bufferId, first, last, limit, additional)) {
+      requestBacklog(bufferId, first, last, limit, additional)
+    }
 
   suspend fun backlogFiltered(
     bufferId: BufferId,
@@ -50,10 +50,10 @@ class ClientBacklogManager(
     additional: Int = 0,
     type: MessageTypes = MessageType.all,
     flags: MessageFlags = MessageFlag.all
-  ): QVariantList {
-    requestBacklogFiltered(bufferId, first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt())
-    return bufferFilteredQueue.wait(BacklogData.BufferFiltered(bufferId, first, last, limit, additional, type, flags))
-  }
+  ): QVariantList =
+    bufferFilteredQueue.wait(BacklogData.BufferFiltered(bufferId, first, last, limit, additional, type, flags)) {
+      requestBacklogFiltered(bufferId, first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt())
+    }
 
   suspend fun backlogForward(
     bufferId: BufferId,
@@ -62,20 +62,20 @@ class ClientBacklogManager(
     limit: Int = -1,
     type: MessageTypes = MessageType.all,
     flags: MessageFlags = MessageFlag.all
-  ): QVariantList {
-    requestBacklogForward(bufferId, first, last, limit, type.toBits().toInt(), flags.toBits().toInt())
-    return bufferForwardQueue.wait(BacklogData.BufferForward(bufferId, first, last, limit, type, flags))
-  }
+  ): QVariantList =
+    bufferForwardQueue.wait(BacklogData.BufferForward(bufferId, first, last, limit, type, flags)) {
+      requestBacklogForward(bufferId, first, last, limit, type.toBits().toInt(), flags.toBits().toInt())
+    }
 
   suspend fun backlogAll(
     first: MsgId = MsgId(-1),
     last: MsgId = MsgId(-1),
     limit: Int = -1,
     additional: Int = 0
-  ): QVariantList {
-    requestBacklogAll(first, last, limit, additional)
-    return allQueue.wait(BacklogData.All(first, last, limit, additional))
-  }
+  ): QVariantList =
+    allQueue.wait(BacklogData.All(first, last, limit, additional)) {
+      requestBacklogAll(first, last, limit, additional)
+    }
 
   suspend fun backlogAllFiltered(
     first: MsgId = MsgId(-1),
@@ -84,10 +84,10 @@ class ClientBacklogManager(
     additional: Int = 0,
     type: MessageTypes = MessageType.all,
     flags: MessageFlags = MessageFlag.all
-  ): QVariantList {
-    requestBacklogAllFiltered(first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt())
-    return allFilteredQueue.wait(BacklogData.AllFiltered(first, last, limit, additional, type, flags))
-  }
+  ): QVariantList =
+    allFilteredQueue.wait(BacklogData.AllFiltered(first, last, limit, additional, type, flags)) {
+      requestBacklogAllFiltered(first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt())
+    }
 
   override fun receiveBacklog(
     bufferId: BufferId,
diff --git a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt
index 62863cf..244df83 100644
--- a/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt
+++ b/libquassel-client/src/main/kotlin/de/justjanne/libquassel/client/util/CoroutineKeyedQueue.kt
@@ -9,6 +9,9 @@
 
 package de.justjanne.libquassel.client.util
 
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.launch
 import org.slf4j.LoggerFactory
 import kotlin.coroutines.Continuation
 import kotlin.coroutines.resume
@@ -16,9 +19,12 @@ import kotlin.coroutines.suspendCoroutine
 
 class CoroutineKeyedQueue<Key, Value> {
   private val waiting = mutableMapOf<Key, MutableList<Continuation<Value>>>()
-  suspend fun wait(vararg keys: Key): Value = suspendCoroutine {
-    for (key in keys) {
-      waiting.getOrPut(key, ::mutableListOf).add(it)
+  suspend fun wait(vararg keys: Key, beforeWait: (suspend CoroutineScope.() -> Unit)? = null): Value = coroutineScope {
+    suspendCoroutine { continuation ->
+      for (key in keys) {
+        waiting.getOrPut(key, ::mutableListOf).add(continuation)
+      }
+      beforeWait?.let { launch(block = it) }
     }
   }
 
diff --git a/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt b/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt
index 0d1f3c6..c698f1e 100644
--- a/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt
+++ b/libquassel-client/src/test/kotlin/de/justjanne/libquassel/client/ClientTest.kt
@@ -19,7 +19,6 @@ import de.justjanne.libquassel.protocol.exceptions.HandshakeException
 import de.justjanne.libquassel.protocol.features.FeatureSet
 import de.justjanne.libquassel.protocol.io.CoroutineChannel
 import de.justjanne.libquassel.protocol.models.ids.BufferId
-import de.justjanne.libquassel.protocol.models.ids.MsgId
 import de.justjanne.libquassel.protocol.session.CoreState
 import de.justjanne.testcontainersci.api.providedContainer
 import de.justjanne.testcontainersci.extension.CiContainers
@@ -29,8 +28,8 @@ import kotlinx.coroutines.withTimeout
 import org.junit.jupiter.api.Assertions.assertTrue
 import org.junit.jupiter.api.Test
 import org.junit.jupiter.api.assertThrows
+import org.slf4j.LoggerFactory
 import java.net.InetSocketAddress
-import java.time.Duration
 import javax.net.ssl.SSLContext
 import kotlin.test.assertEquals
 
@@ -41,6 +40,8 @@ class ClientTest {
     QuasselCoreContainer()
   }
 
+  private val logger = LoggerFactory.getLogger(ClientTest::class.java)
+
   private val username = "AzureDiamond"
   private val password = "hunter2"
 
@@ -96,19 +97,23 @@ class ClientTest {
     }
     session.handshakeHandler.login(username, password)
     session.baseInitHandler.waitForInitDone()
-    withTimeout(Duration.ofSeconds(5).toMillis()) {
+    logger.trace("Init Done")
+    withTimeout(5_000L) {
       assertEquals(
         emptyList(),
         session.backlogManager.backlog(bufferId = BufferId(1), limit = 5)
       )
+      logger.trace("Backlog Test #1 Done")
       assertEquals(
         emptyList(),
         session.backlogManager.backlogAll(limit = 5)
       )
+      logger.trace("Backlog Test #2 Done")
       assertEquals(
         emptyList(),
         session.backlogManager.backlogForward(bufferId = BufferId(1), limit = 5)
       )
+      logger.trace("Backlog Test #3 Done")
     }
     channel.close()
   }
diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt
index 22e3de9..aa8f1fc 100644
--- a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt
+++ b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/io/CoroutineChannel.kt
@@ -16,13 +16,14 @@ import kotlinx.coroutines.asCoroutineDispatcher
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.runInterruptible
+import java.io.Closeable
 import java.net.InetSocketAddress
 import java.net.Socket
 import java.nio.ByteBuffer
 import java.util.concurrent.Executors
 import javax.net.ssl.SSLContext
 
-class CoroutineChannel : StateHolder<CoroutineChannelState> {
+class CoroutineChannel : StateHolder<CoroutineChannelState>, Closeable {
   private lateinit var channel: StreamChannel
   private val writeContext = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
   private val readContext = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
@@ -75,7 +76,7 @@ class CoroutineChannel : StateHolder<CoroutineChannelState> {
     channel.flush()
   }
 
-  fun close() {
+  override fun close() {
     channel.close()
     state.update {
       copy(connected = false)
diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt
index 8394110..935248d 100644
--- a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt
+++ b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannel.kt
@@ -17,12 +17,14 @@ import de.justjanne.libquassel.protocol.models.SignalProxyMessage
 import de.justjanne.libquassel.protocol.serializers.HandshakeMessageSerializer
 import de.justjanne.libquassel.protocol.serializers.SignalProxyMessageSerializer
 import de.justjanne.libquassel.protocol.util.log.trace
+import kotlinx.coroutines.coroutineScope
 import org.slf4j.LoggerFactory
+import java.io.Closeable
 import java.nio.ByteBuffer
 
 class MessageChannel(
   val channel: CoroutineChannel
-) {
+) : Closeable {
   var negotiatedFeatures = FeatureSet.none()
 
   private var handlers = mutableListOf<ConnectionHandler>()
@@ -98,7 +100,7 @@ class MessageChannel(
     SignalProxyMessageSerializer.serialize(it, message, negotiatedFeatures)
   }
 
-  suspend fun emit(sizePrefix: Boolean = true, f: (ChainedByteBuffer) -> Unit) {
+  suspend fun emit(sizePrefix: Boolean = true, f: (ChainedByteBuffer) -> Unit) = coroutineScope {
     val sendBuffer = sendBuffer.get()
     val sizeBuffer = sizeBuffer.get()
 
@@ -115,6 +117,10 @@ class MessageChannel(
     sendBuffer.clear()
   }
 
+  override fun close() {
+    channel.close()
+  }
+
   companion object {
     private val logger = LoggerFactory.getLogger(MessageChannel::class.java)
   }
diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt
deleted file mode 100644
index b2efcf1..0000000
--- a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReadThread.kt
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * libquassel
- * Copyright (c) 2021 Janne Mareike Koschinski
- *
- * This Source Code Form is subject to the terms of the Mozilla Public License,
- * v. 2.0. If a copy of the MPL was not distributed with this file, You can
- * obtain one at https://mozilla.org/MPL/2.0/.
- */
-
-package de.justjanne.libquassel.protocol.session
-
-import de.justjanne.libquassel.protocol.util.log.info
-import kotlinx.coroutines.runBlocking
-import org.slf4j.LoggerFactory
-import java.nio.channels.ClosedChannelException
-
-class MessageChannelReadThread(
-  val channel: MessageChannel
-) : Thread("Message Channel Read Thread") {
-  override fun run() {
-    runBlocking {
-      try {
-        channel.init()
-        while (channel.channel.state().connected) {
-          channel.read()
-        }
-      } catch (e: ClosedChannelException) {
-        logger.info { "Channel closed" }
-      }
-    }
-  }
-
-  companion object {
-    private val logger = LoggerFactory.getLogger(MessageChannelReadThread::class.java)
-  }
-}
diff --git a/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt
new file mode 100644
index 0000000..9cb58c5
--- /dev/null
+++ b/libquassel-protocol/src/main/kotlin/de/justjanne/libquassel/protocol/session/MessageChannelReader.kt
@@ -0,0 +1,59 @@
+/*
+ * libquassel
+ * Copyright (c) 2021 Janne Mareike Koschinski
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public License,
+ * v. 2.0. If a copy of the MPL was not distributed with this file, You can
+ * obtain one at https://mozilla.org/MPL/2.0/.
+ */
+
+package de.justjanne.libquassel.protocol.session
+
+import de.justjanne.libquassel.protocol.util.log.info
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.cancel
+import kotlinx.coroutines.cancelAndJoin
+import kotlinx.coroutines.isActive
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.runBlocking
+import org.slf4j.LoggerFactory
+import java.io.Closeable
+import java.nio.channels.ClosedChannelException
+import java.util.concurrent.Executors
+
+class MessageChannelReader(
+  private val channel: MessageChannel
+) : Closeable {
+  private val executor = Executors.newSingleThreadExecutor()
+  private val dispatcher = executor.asCoroutineDispatcher()
+  private val scope = CoroutineScope(dispatcher)
+  private var job: Job? = null
+
+  fun start() {
+    job = scope.launch {
+      try {
+        channel.init()
+        while (isActive && channel.channel.state().connected) {
+          channel.read()
+        }
+      } catch (e: ClosedChannelException) {
+        logger.info { "Channel closed" }
+        close()
+      }
+    }
+  }
+
+  override fun close() {
+    channel.close()
+    runBlocking { job?.cancelAndJoin() }
+    scope.cancel()
+    dispatcher.cancel()
+    executor.shutdown()
+  }
+
+  companion object {
+    private val logger = LoggerFactory.getLogger(MessageChannelReader::class.java)
+  }
+}
-- 
GitLab