Skip to content
Snippets Groups Projects
Verified Commit 4a620ce9 authored by Janne Mareike Koschinski's avatar Janne Mareike Koschinski
Browse files

fix: correct issue with coroutine contexts

parent 8a638126
No related branches found
No related tags found
No related merge requests found
Showing
with 123 additions and 60 deletions
......@@ -16,7 +16,6 @@ import de.justjanne.libquassel.protocol.models.HandshakeMessage
import de.justjanne.libquassel.protocol.serializers.HandshakeMessageSerializer
import de.justjanne.libquassel.protocol.session.CoreState
import de.justjanne.libquassel.protocol.session.HandshakeHandler
import de.justjanne.libquassel.protocol.session.MessageChannelReadThread
import de.justjanne.libquassel.protocol.session.Session
import de.justjanne.libquassel.protocol.util.log.trace
import de.justjanne.libquassel.protocol.variant.QVariantMap
......@@ -56,18 +55,13 @@ class ClientHandshakeHandler(
buildDate: String,
featureSet: FeatureSet
): CoreState {
emit(
HandshakeMessage.ClientInit(
clientVersion,
buildDate,
featureSet
)
)
when (
val response = messageQueue.wait(
HandshakeMessage.ClientInitAck::class.java,
HandshakeMessage.ClientInitReject::class.java
)
) {
emit(HandshakeMessage.ClientInit(clientVersion, buildDate, featureSet))
}
) {
is HandshakeMessage.ClientInitReject ->
throw HandshakeException.InitException(response.errorString ?: "Unknown Error")
......@@ -89,17 +83,13 @@ class ClientHandshakeHandler(
}
override suspend fun login(username: String, password: String) {
emit(
HandshakeMessage.ClientLogin(
username,
password
)
)
when (
val response = messageQueue.wait(
HandshakeMessage.ClientLoginAck::class.java,
HandshakeMessage.ClientLoginReject::class.java
)
) {
emit(HandshakeMessage.ClientLogin(username, password))
}
) {
is HandshakeMessage.ClientLoginReject ->
throw HandshakeException.LoginException(response.errorString ?: "Unknown Error")
......@@ -117,21 +107,17 @@ class ClientHandshakeHandler(
authenticator: String,
authenticatorConfiguration: QVariantMap
) {
emit(
HandshakeMessage.CoreSetupData(
adminUsername,
adminPassword,
backend,
backendConfiguration,
authenticator,
authenticatorConfiguration
)
)
when (
val response = messageQueue.wait(
HandshakeMessage.CoreSetupAck::class.java,
HandshakeMessage.CoreSetupReject::class.java
) {
emit(
HandshakeMessage.CoreSetupData(
adminUsername, adminPassword, backend, backendConfiguration, authenticator, authenticatorConfiguration
)
)
}
) {
is HandshakeMessage.CoreSetupReject ->
throw HandshakeException.SetupException(response.errorString ?: "Unknown Error")
......@@ -142,6 +128,6 @@ class ClientHandshakeHandler(
}
companion object {
private val logger = LoggerFactory.getLogger(MessageChannelReadThread::class.java)
private val logger = LoggerFactory.getLogger(ClientHandshakeHandler::class.java)
}
}
......@@ -20,7 +20,7 @@ import de.justjanne.libquassel.protocol.models.ids.NetworkId
import de.justjanne.libquassel.protocol.serializers.qt.StringSerializerUtf8
import de.justjanne.libquassel.protocol.session.CommonSyncProxy
import de.justjanne.libquassel.protocol.session.MessageChannel
import de.justjanne.libquassel.protocol.session.MessageChannelReadThread
import de.justjanne.libquassel.protocol.session.MessageChannelReader
import de.justjanne.libquassel.protocol.session.Session
import de.justjanne.libquassel.protocol.syncables.HeartBeatHandler
import de.justjanne.libquassel.protocol.syncables.ObjectRepository
......@@ -78,7 +78,7 @@ class ClientSession(
messageChannel.register(magicHandler)
messageChannel.register(handshakeHandler)
messageChannel.register(proxyMessageHandler)
MessageChannelReadThread(messageChannel).start()
MessageChannelReader(messageChannel).start()
}
override fun init(
......
......@@ -37,9 +37,9 @@ class ClientBacklogManager(
last: MsgId = MsgId(-1),
limit: Int = -1,
additional: Int = 0
): QVariantList {
): QVariantList =
bufferQueue.wait(BacklogData.Buffer(bufferId, first, last, limit, additional)) {
requestBacklog(bufferId, first, last, limit, additional)
return bufferQueue.wait(BacklogData.Buffer(bufferId, first, last, limit, additional))
}
suspend fun backlogFiltered(
......@@ -50,9 +50,9 @@ class ClientBacklogManager(
additional: Int = 0,
type: MessageTypes = MessageType.all,
flags: MessageFlags = MessageFlag.all
): QVariantList {
): QVariantList =
bufferFilteredQueue.wait(BacklogData.BufferFiltered(bufferId, first, last, limit, additional, type, flags)) {
requestBacklogFiltered(bufferId, first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt())
return bufferFilteredQueue.wait(BacklogData.BufferFiltered(bufferId, first, last, limit, additional, type, flags))
}
suspend fun backlogForward(
......@@ -62,9 +62,9 @@ class ClientBacklogManager(
limit: Int = -1,
type: MessageTypes = MessageType.all,
flags: MessageFlags = MessageFlag.all
): QVariantList {
): QVariantList =
bufferForwardQueue.wait(BacklogData.BufferForward(bufferId, first, last, limit, type, flags)) {
requestBacklogForward(bufferId, first, last, limit, type.toBits().toInt(), flags.toBits().toInt())
return bufferForwardQueue.wait(BacklogData.BufferForward(bufferId, first, last, limit, type, flags))
}
suspend fun backlogAll(
......@@ -72,9 +72,9 @@ class ClientBacklogManager(
last: MsgId = MsgId(-1),
limit: Int = -1,
additional: Int = 0
): QVariantList {
): QVariantList =
allQueue.wait(BacklogData.All(first, last, limit, additional)) {
requestBacklogAll(first, last, limit, additional)
return allQueue.wait(BacklogData.All(first, last, limit, additional))
}
suspend fun backlogAllFiltered(
......@@ -84,9 +84,9 @@ class ClientBacklogManager(
additional: Int = 0,
type: MessageTypes = MessageType.all,
flags: MessageFlags = MessageFlag.all
): QVariantList {
): QVariantList =
allFilteredQueue.wait(BacklogData.AllFiltered(first, last, limit, additional, type, flags)) {
requestBacklogAllFiltered(first, last, limit, additional, type.toBits().toInt(), flags.toBits().toInt())
return allFilteredQueue.wait(BacklogData.AllFiltered(first, last, limit, additional, type, flags))
}
override fun receiveBacklog(
......
......@@ -9,6 +9,9 @@
package de.justjanne.libquassel.client.util
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import org.slf4j.LoggerFactory
import kotlin.coroutines.Continuation
import kotlin.coroutines.resume
......@@ -16,9 +19,12 @@ import kotlin.coroutines.suspendCoroutine
class CoroutineKeyedQueue<Key, Value> {
private val waiting = mutableMapOf<Key, MutableList<Continuation<Value>>>()
suspend fun wait(vararg keys: Key): Value = suspendCoroutine {
suspend fun wait(vararg keys: Key, beforeWait: (suspend CoroutineScope.() -> Unit)? = null): Value = coroutineScope {
suspendCoroutine { continuation ->
for (key in keys) {
waiting.getOrPut(key, ::mutableListOf).add(it)
waiting.getOrPut(key, ::mutableListOf).add(continuation)
}
beforeWait?.let { launch(block = it) }
}
}
......
......@@ -19,7 +19,6 @@ import de.justjanne.libquassel.protocol.exceptions.HandshakeException
import de.justjanne.libquassel.protocol.features.FeatureSet
import de.justjanne.libquassel.protocol.io.CoroutineChannel
import de.justjanne.libquassel.protocol.models.ids.BufferId
import de.justjanne.libquassel.protocol.models.ids.MsgId
import de.justjanne.libquassel.protocol.session.CoreState
import de.justjanne.testcontainersci.api.providedContainer
import de.justjanne.testcontainersci.extension.CiContainers
......@@ -29,8 +28,8 @@ import kotlinx.coroutines.withTimeout
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.slf4j.LoggerFactory
import java.net.InetSocketAddress
import java.time.Duration
import javax.net.ssl.SSLContext
import kotlin.test.assertEquals
......@@ -41,6 +40,8 @@ class ClientTest {
QuasselCoreContainer()
}
private val logger = LoggerFactory.getLogger(ClientTest::class.java)
private val username = "AzureDiamond"
private val password = "hunter2"
......@@ -96,19 +97,23 @@ class ClientTest {
}
session.handshakeHandler.login(username, password)
session.baseInitHandler.waitForInitDone()
withTimeout(Duration.ofSeconds(5).toMillis()) {
logger.trace("Init Done")
withTimeout(5_000L) {
assertEquals(
emptyList(),
session.backlogManager.backlog(bufferId = BufferId(1), limit = 5)
)
logger.trace("Backlog Test #1 Done")
assertEquals(
emptyList(),
session.backlogManager.backlogAll(limit = 5)
)
logger.trace("Backlog Test #2 Done")
assertEquals(
emptyList(),
session.backlogManager.backlogForward(bufferId = BufferId(1), limit = 5)
)
logger.trace("Backlog Test #3 Done")
}
channel.close()
}
......
......@@ -16,13 +16,14 @@ import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.runInterruptible
import java.io.Closeable
import java.net.InetSocketAddress
import java.net.Socket
import java.nio.ByteBuffer
import java.util.concurrent.Executors
import javax.net.ssl.SSLContext
class CoroutineChannel : StateHolder<CoroutineChannelState> {
class CoroutineChannel : StateHolder<CoroutineChannelState>, Closeable {
private lateinit var channel: StreamChannel
private val writeContext = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
private val readContext = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
......@@ -75,7 +76,7 @@ class CoroutineChannel : StateHolder<CoroutineChannelState> {
channel.flush()
}
fun close() {
override fun close() {
channel.close()
state.update {
copy(connected = false)
......
......@@ -17,12 +17,14 @@ import de.justjanne.libquassel.protocol.models.SignalProxyMessage
import de.justjanne.libquassel.protocol.serializers.HandshakeMessageSerializer
import de.justjanne.libquassel.protocol.serializers.SignalProxyMessageSerializer
import de.justjanne.libquassel.protocol.util.log.trace
import kotlinx.coroutines.coroutineScope
import org.slf4j.LoggerFactory
import java.io.Closeable
import java.nio.ByteBuffer
class MessageChannel(
val channel: CoroutineChannel
) {
) : Closeable {
var negotiatedFeatures = FeatureSet.none()
private var handlers = mutableListOf<ConnectionHandler>()
......@@ -98,7 +100,7 @@ class MessageChannel(
SignalProxyMessageSerializer.serialize(it, message, negotiatedFeatures)
}
suspend fun emit(sizePrefix: Boolean = true, f: (ChainedByteBuffer) -> Unit) {
suspend fun emit(sizePrefix: Boolean = true, f: (ChainedByteBuffer) -> Unit) = coroutineScope {
val sendBuffer = sendBuffer.get()
val sizeBuffer = sizeBuffer.get()
......@@ -115,6 +117,10 @@ class MessageChannel(
sendBuffer.clear()
}
override fun close() {
channel.close()
}
companion object {
private val logger = LoggerFactory.getLogger(MessageChannel::class.java)
}
......
......@@ -10,27 +10,50 @@
package de.justjanne.libquassel.protocol.session
import de.justjanne.libquassel.protocol.util.log.info
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.cancel
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import java.io.Closeable
import java.nio.channels.ClosedChannelException
import java.util.concurrent.Executors
class MessageChannelReadThread(
val channel: MessageChannel
) : Thread("Message Channel Read Thread") {
override fun run() {
runBlocking {
class MessageChannelReader(
private val channel: MessageChannel
) : Closeable {
private val executor = Executors.newSingleThreadExecutor()
private val dispatcher = executor.asCoroutineDispatcher()
private val scope = CoroutineScope(dispatcher)
private var job: Job? = null
fun start() {
job = scope.launch {
try {
channel.init()
while (channel.channel.state().connected) {
while (isActive && channel.channel.state().connected) {
channel.read()
}
} catch (e: ClosedChannelException) {
logger.info { "Channel closed" }
close()
}
}
}
override fun close() {
channel.close()
runBlocking { job?.cancelAndJoin() }
scope.cancel()
dispatcher.cancel()
executor.shutdown()
}
companion object {
private val logger = LoggerFactory.getLogger(MessageChannelReadThread::class.java)
private val logger = LoggerFactory.getLogger(MessageChannelReader::class.java)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment