Skip to content

Commit

Permalink
Merge pull request #1 from h0tk3y/migration-to-1.1-M04
Browse files Browse the repository at this point in the history
Updated to 1.1-M04.
  • Loading branch information
h0tk3y committed Jan 9, 2017
2 parents 73b94f5 + a613d46 commit e1e8f44
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 32 deletions.
3 changes: 2 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
buildscript {
ext.kotlin_version = '1.1-M03'
ext.kotlin_version = '1.1-M04'

repositories {
mavenCentral()
Expand All @@ -19,6 +19,7 @@ repositories {

dependencies {
compile "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
testCompile "junit:junit:4.12"
}

Expand Down
58 changes: 33 additions & 25 deletions src/main/kotlin/com/github/h0tk3y/kotlinMonads/DoNotation.kt
Original file line number Diff line number Diff line change
@@ -1,55 +1,61 @@
@file:Suppress("EXPERIMENTAL_FEATURE_WARNING")

package com.github.h0tk3y.kotlinMonads

import java.io.Serializable
import java.util.*
import kotlin.jvm.internal.CoroutineImpl
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineIntrinsics
import kotlin.coroutines.startCoroutine

fun <M : Monad<M, *>, T> doWith(m: Monad<M, T>,
coroutine c: DoController<M, T>.(T) -> Continuation<Unit>): Monad<M, T> {
return (m bind { x -> doWith(this, x, c) })
}
c: suspend DoController<M, T>.() -> Unit): Monad<M, T> =
m.bind { t -> doWith(this, t, c) }


fun <M : Monad<M, *>, T> doWith(aReturn: Return<M>,
defaultValue: T,
coroutine c: DoController<M, T>.(T) -> Continuation<Unit>): Monad<M, T> {
c: suspend DoController<M, T>.() -> Unit): Monad<M, T> {
val controller = DoController(aReturn, defaultValue)
c(controller, defaultValue).resume(Unit)
c.startCoroutine(controller, object : Continuation<Unit> {
override fun resume(value: Unit) {}
override fun resumeWithException(exception: Throwable) = throw exception
})
return controller.lastResult
}

private fun <T, R> backupLabel(c: Continuation<T>, block: Continuation<T>.() -> R): R {
val reflect = CoroutineImpl::class.java
val labelField = reflect.getDeclaredField("label")
val labelField by lazy {
val jClass = Class.forName("kotlin.jvm.internal.RestrictedCoroutineImpl")
return@lazy jClass.getDeclaredField("label").apply { isAccessible = true }
}

labelField.isAccessible = true
val l = labelField.get(c)
labelField.isAccessible = false
var <T> Continuation<T>.label
get() = labelField.get(this)
set(value) = labelField.set(this@label, value)

private fun <T, R> backupLabel(c: Continuation<T>, block: Continuation<T>.() -> R): R {
val backupLabel = c.label
val r = block(c)

labelField.isAccessible = true
labelField.set(c, l)
labelField.isAccessible = false

c.label = backupLabel
return r
}

class DoController<M : Monad<M, *>, T>(val returning: Return<M>,
initialValue: T) : Serializable {
var lastResult: Monad<M, T> = returning.returns(initialValue)
private set
val value: T) : Serializable, Return<M> by returning {
var lastResult: Monad<M, T> = returning.returns(value)
internal set

private val stackSignals = Stack<Boolean>().apply { push(false) }

fun <T> returns(t: T) = returning.returns(t)

suspend fun bind(m: Monad<M, T>, c: Continuation<T>) {
suspend fun bind(m: Monad<M, T>): T = CoroutineIntrinsics.suspendCoroutineOrReturn { c ->
stackSignals.pop()
stackSignals.push(true)
var anyCont = false
val o = m.bind { x ->
stackSignals.push(false)
backupLabel(c) { c.resume(x) }
backupLabel(c) {
c.resume(x)
}
val contHasMonad = stackSignals.pop()
if (contHasMonad) {
anyCont = true
Expand All @@ -59,9 +65,10 @@ class DoController<M : Monad<M, *>, T>(val returning: Return<M>,
}
}
lastResult = if (anyCont) o else m
CoroutineIntrinsics.SUSPENDED
}

suspend fun then(m: Monad<M, T>, c: Continuation<Unit>) {
suspend fun then(m: Monad<M, T>) = CoroutineIntrinsics.suspendCoroutineOrReturn<Unit> { c ->
stackSignals.pop()
stackSignals.push(true)
var anyCont = false
Expand All @@ -77,6 +84,7 @@ class DoController<M : Monad<M, *>, T>(val returning: Return<M>,
}
}
lastResult = if (anyCont) o else m
CoroutineIntrinsics.SUSPENDED
}
}

Expand Down
14 changes: 8 additions & 6 deletions src/test/kotlin/com/github/h0tk3y/kotlinMonads/DoNotationTest.kt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
@file:Suppress("EXPERIMENTAL_FEATURE_WARNING")

package com.github.h0tk3y.kotlinMonads

import org.junit.Assert.assertEquals
Expand All @@ -6,8 +8,8 @@ import org.junit.Test

class DoNotationTest {
@Test fun testLinearDo() {
val m = doWith(just(1)) { i ->
val j = bind(returns(i * 2))
val m = doWith(just(1)) {
val j = bind(returns(value * 2))
val k = bind(returns(j * 3))
then(returns(k + 1))
}
Expand All @@ -16,8 +18,8 @@ class DoNotationTest {

@Test fun testControlFlow() {
var called = false
val m = doWith(just(1)) { i ->
val j = bind(returns(i * 2))
val m = doWith(just(1)) {
val j = bind(returns(value * 2))
val k = bind(if (j % 2 == 0) none() else just(j))
called = true
then(returns(k))
Expand Down Expand Up @@ -56,8 +58,8 @@ class DoNotationTest {

@Test fun testBindLastStatement() {
val results = mutableListOf<Int>()
val m = doWith(monadListOf(2)) { i ->
val x = bind(monadListOf(i + 1, i * i))
val m = doWith(monadListOf(2)) {
val x = bind(monadListOf(value + 1, value * value))
val z = bind(returns(x))
results.add(z)
}
Expand Down

0 comments on commit e1e8f44

Please sign in to comment.