Skip to content

Commit

Permalink
Generate handlers explicitly to cut down on runtime costs
Browse files Browse the repository at this point in the history
  • Loading branch information
b-studios committed Dec 9, 2023
1 parent 6097095 commit 8c98dcb
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ object TransformerDirect extends Transformer {
* Translation of expressions is trivial
*/
def toJS(expr: core.Expr)(using TransformerContext): js.Expr = expr match {
case Literal((), _) => js.Member($effekt, JSName("unit"))
case Literal((), _) => $effekt.field("unit")
case Literal(s: String, _) => JsString(s)
case literal: Literal => js.RawExpr(literal.value.toString)
case ValueVar(id, tpe) => nameRef(id)
Expand Down Expand Up @@ -149,15 +149,15 @@ object TransformerDirect extends Transformer {
def entrypoint(result: JSName, k: JSName, vars: List[JSName], s: List[js.Stmt]): List[js.Stmt] =
val suspension = freshName("suspension")
val frame = js.Lambda(List(result), js.Call(js.Variable(k), js.Variable(result) :: vars.map(js.Variable.apply)))
List(js.Try(s, suspension, List(js.Return(js.builtin("push",js.Variable(suspension), frame)))))
List(js.Try(s, suspension, List(js.Return($effekt.call("push",js.Variable(suspension), frame)))))

def toJS(s: core.Stmt)(using C: TransformerContext): Bind = s match {

case Scope(definitions, body) =>
Bind { k => definitions.flatMap { toJS } ++ toJS(body)(k) }

case Alloc(id, init, region, body) =>
val jsRegion = if region == symbols.builtins.globalRegion then js.Member($effekt, JSName("global")) else nameRef(region)
val jsRegion = if region == symbols.builtins.globalRegion then $effekt.field("global") else nameRef(region)
Bind { k =>
js.Const(nameDef(id), js.MethodCall(jsRegion, `fresh`, toJS(init))) :: toJS(body)(k)
}
Expand Down Expand Up @@ -210,7 +210,7 @@ object TransformerDirect extends Transformer {

case Var(id, init, cap, body) =>
Bind { k =>
js.Const(nameDef(id), js.builtin("fresh", toJS(init))) :: toJS(body)(k)
js.Const(nameDef(id), $effekt.call("fresh", toJS(init))) :: toJS(body)(k)
}

// obviously recursive calls
Expand Down Expand Up @@ -266,16 +266,16 @@ object TransformerDirect extends Transformer {
val suspension = freshName("suspension")
val prompt = freshName("prompt")

val promptDef = js.Const(prompt, js.builtin("freshPrompt"))
val freshRegion = js.ExprStmt(js.builtin("freshRegion"))
val regionCleanup = js.ExprStmt(js.builtin("leaveRegion"))
val promptDef = js.Const(prompt, $effekt.call("freshPrompt"))
val freshRegion = js.ExprStmt($effekt.call("freshRegion"))
val regionCleanup = js.ExprStmt($effekt.call("leaveRegion"))

val (handlerNames, handlerDefs) = (bps zip hs).map {
case (param, handler) => (toJS(param), js.Const(toJS(param), toJS(handler, prompt)))
}.unzip

Bind { k => promptDef :: handlerDefs ::: (js.Try(freshRegion :: toJS(body)(k), suspension,
List(k(js.builtin("handle", js.Variable(prompt), js.Variable(suspension)))), List(regionCleanup)) :: Nil)
List(k($effekt.call("handle", js.Variable(prompt), js.Variable(suspension)))), List(regionCleanup)) :: Nil)
}

case Try(_, _) =>
Expand All @@ -285,19 +285,19 @@ object TransformerDirect extends Transformer {
val suspension = freshName("suspension")
val region = nameDef(r.id)

val freshRegion = js.Const(region, js.builtin("freshRegion"))
val regionCleanup = js.ExprStmt(js.builtin("leaveRegion"))
val freshRegion = js.Const(region, $effekt.call("freshRegion"))
val regionCleanup = js.ExprStmt($effekt.call("leaveRegion"))

Bind { k =>
js.Try(freshRegion :: toJS(body)(k), suspension,
List(k(js.builtin("handle", js.Undefined, js.Variable(suspension)))), List(regionCleanup)) :: Nil
List(k($effekt.call("handle", js.Undefined, js.Variable(suspension)))), List(regionCleanup)) :: Nil
}

case Region(_) =>
Context.panic("Body of the region is expected to be a block literal in core.")

case Hole() =>
Return(js.builtin("hole"))
Return($effekt.call("hole"))

case Get(id, capt, tpe) => Context.panic("Should have been translated to direct style")
case Put(id, capt, value) => Context.panic("Should have been translated to direct style")
Expand All @@ -315,15 +315,15 @@ object TransformerDirect extends Transformer {
val biArgs = biParams.map { p => js.Variable(p) }

val lambda = js.Lambda((vps ++ bps).map(toJS) ++ biParams,
js.Return(js.builtin("suspend_bidirectional", js.Variable(prompt), js.ArrayLiteral(biArgs), js.Lambda(List(nameDef(resume)),
js.Return($effekt.call("suspend_bidirectional", js.Variable(prompt), js.ArrayLiteral(biArgs), js.Lambda(List(nameDef(resume)),
C.clearingScope { js.Block(toJS(body)(Continuation.Return)) }))))

nameDef(id) -> lambda

// (args...) => $effekt.suspend(prompt, (resume) => { ... BODY ... resume(v) ... })
case Operation(id, tps, cps, vps, bps, Some(resume), body) =>
val lambda = js.Lambda((vps ++ bps) map toJS,
js.Return(js.builtin("suspend", js.Variable(prompt),
js.Return($effekt.call("suspend", js.Variable(prompt),
js.Lambda(List(toJS(resume)), C.clearingScope { js.Block(toJS(body)(Continuation.Return)) }))))

nameDef(id) -> lambda
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ object TransformerMonadic extends Transformer {
}

def toJS(expr: core.Expr)(using DeclarationContext, Context): js.Expr = expr match {
case Literal((), _) => js.Member($effekt, JSName("unit"))
case Literal((), _) => $effekt.field("unit")
case Literal(s: String, _) => JsString(s)
case literal: Literal => js.RawExpr(literal.value.toString)
case ValueVar(id, tpe) => nameRef(id)
Expand All @@ -76,6 +76,36 @@ object TransformerMonadic extends Transformer {
nameDef(id) -> monadic.Lambda((vps ++ bps ++ resume.toList) map toJS, stmts, ret)
})

def toJS(handler: core.Implementation, prompt: js.Expr)(using DeclarationContext, Context): js.Expr =
js.Object(
handler.operations.map {
// // (args...cap...) => $effekt.suspend(prompt, (resume) => { ... body ... resume((cap...) => { ... }) ... })
// case Operation(id, tps, cps, vps, bps,
// Some(BlockParam(resume, core.BlockType.Function(_, _, _, List(core.BlockType.Function(_, _, _, bidirectionalTpes, _)), _), _)),
// body) =>
// // add parameters for bidirectional arguments
// val biParams = bidirectionalTpes.map { _ => freshName("cap") }
// val biArgs = biParams.map { p => js.Variable(p) }
//
// val lambda = js.Lambda((vps ++ bps).map(toJS) ++ biParams,
// js.Return($effekt.call("suspend_bidirectional", js.Variable(prompt), js.ArrayLiteral(biArgs), js.Lambda(List(nameDef(resume)),
// C.clearingScope { js.Block(toJS(body)(Continuation.Return)) }))))
//
// nameDef(id) -> lambda
// })

// (args...) => $effekt.shift(prompt, (resume) => { ... BODY ... resume(v) ... })
case Operation(id, tps, cps, vps, bps, Some(resume), body) =>
val (stmts, ret) = toJSStmt(body)
val lambda = js.Lambda((vps ++ bps) map toJS,
js.Return($effekt.call("shift", prompt,
monadic.Lambda(List(toJS(resume)), stmts, ret))))

nameDef(id) -> lambda

case Operation(id, tps, cps, vps, bps, None, body) => Context.panic("Effect handler should take continuation")
})

def toJS(module: core.ModuleDecl, imports: List[js.Import], exports: List[js.Export])(using DeclarationContext, Context): js.Module = {
val name = JSName(jsModuleName(module.path))
val externs = module.externs.map(toJS)
Expand Down Expand Up @@ -114,8 +144,22 @@ object TransformerMonadic extends Transformer {
case Return(e) =>
monadic.Pure(toJS(e))

case Try(body, hs) =>
monadic.Handle(hs map toJS, toJS(body))
// $effekt.handle(p => {
// const amb = { flip: ... };
//
// })
case Try(core.BlockLit(_, _, _, bps, body), hs) =>
val prompt = freshName("p")

val handlerDefs = (bps zip hs).map {
case (param, handler) => js.Const(toJS(param), toJS(handler, js.Variable(prompt)))
}
val (stmts, ret) = toJSStmt(body)

monadic.Handle(monadic.Lambda(List(prompt), handlerDefs ++ stmts, ret))

case Try(_, _) =>
Context.panic("Body of the try is expected to be a block literal in core.")

case Region(body) =>
monadic.Builtin("withRegion", toJS(body))
Expand Down Expand Up @@ -165,7 +209,7 @@ object TransformerMonadic extends Transformer {

case Alloc(id, init, region, body) if region == symbols.builtins.globalRegion =>
val (stmts, ret) = toJSStmt(body)
(js.Const(nameDef(id), js.MethodCall($effekt, `ref`, toJS(init))) :: stmts, ret)
(js.Const(nameDef(id), $effekt.call("ref", toJS(init))) :: stmts, ret)

case Alloc(id, init, region, body) =>
val (stmts, ret) = toJSStmt(body)
Expand Down
15 changes: 11 additions & 4 deletions effekt/shared/src/main/scala/effekt/generator/js/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@ package effekt
package generator
package js

import scala.collection.immutable.{ AbstractSeq, LinearSeq }

// TODO choose appropriate representation and apply conversions
case class JSName(name: String)

val $effekt = Variable(JSName("$effekt"))
def builtin(name: String, args: Expr*): js.Expr = js.MethodCall($effekt, JSName(name), args: _*)
object $effekt {
val namespace = Variable(JSName("$effekt"))
def field(name: String): js.Expr =
js.Member(namespace, JSName(name))
def call(name: String, args: js.Expr*): js.Expr =
js.MethodCall(namespace, JSName(name), args: _*)
}

enum Import {
// import * as <name> from "<file>";
Expand Down Expand Up @@ -219,9 +226,9 @@ object monadic {

def Call(callee: Expr, args: List[Expr]): Control = js.Call(callee, args)
def If(cond: Expr, thn: Control, els: Control): Control = js.IfExpr(cond, thn, els)
def Handle(handlers: List[Expr], body: Expr): Control = js.Call(Builtin("handleMonadic", js.ArrayLiteral(handlers)), List(body))
def Handle(body: Expr): Control = Builtin("handleMonadic", body)

def Builtin(name: String, args: Expr*): Control = js.MethodCall($effekt, JSName(name), args: _*)
def Builtin(name: String, args: Expr*): Control = $effekt.call(name, args: _*)

def Lambda(params: List[JSName], stmts: List[Stmt], ret: Control): Expr =
js.Lambda(params, js.Block(stmts :+ js.Return(ret)))
Expand Down
34 changes: 5 additions & 29 deletions libraries/js/effekt_runtime.js
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ const $runtime = (function() {

const delayed = a => Control(k => apply(k, a()))

const shift = p => f => Control(k => {
const shift = (p, f) => Control(k => {
const split = splitAt(k, p)
const localCont = a => Control(k =>
Step(pure(a), pushSubcont(split.head, k)))
Expand All @@ -222,36 +222,11 @@ const $runtime = (function() {
})).then(a => a.shouldRun ? a.cont() : $effekt.pure(a.cont))
}

const reset = p => c => Control(k => Step(c, Stack(Nil, Arena(), p, k)))
//const reset = (p, c => Control(k => Step(c, Stack(Nil, Arena(), p, k)))

function handleMonadic(handlers) {
function handleMonadic(body) {
const p = _prompt++;

// modify all implementations in the handlers to capture the continuation at prompt p
const caps = handlers.map(h => {
var cap = Object.create({})
for (var op in h) {
const impl = h[op];
cap[op] = function() {
// split two kinds of arguments, parameters of the operation and capabilities
const args = Array.from(arguments);
const arity = impl.length - 1
const oargs = args.slice(0, arity)
const caps = args.slice(arity)
var r = shift(p)(k => impl.apply(null, oargs.concat([k])))
// resume { caps => e}
if (caps.length > 0) {
return r.then(f => f.apply(null, caps))
}
// resume(v)
else {
return r
}
}
}
return cap;
});
return body => reset(p)(body.apply(null, caps))
return Control(k => Step(body(p), Stack(Nil, Arena(), p, k)))
}

// Direct Style Runtime
Expand Down Expand Up @@ -364,6 +339,7 @@ const $runtime = (function() {
handleMonadic: handleMonadic,
ref: Cell,
state: withState,
shift: shift,
_if: (c, thn, els) => c ? thn() : els(),
withRegion: withRegion,

Expand Down

0 comments on commit 8c98dcb

Please sign in to comment.