Skip to content

Commit

Permalink
#2 forless: init forless functional monad cps with implementation of …
Browse files Browse the repository at this point in the history
…functional monad cps
  • Loading branch information
AraneusRota committed Jul 26, 2021
1 parent 6cc782a commit 32c73ef
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package differentiable.reversemode.monadcps.functional.forless

import differentiable.reversemode.monadcps.functional.forless.Monad.wrap

import java.lang.Math.{pow, toIntExact}
import scala.annotation.showAsInfix
import scala.compiletime.ops.int._
import scala.compiletime.{constValue, erasedValue}
import scala.language.implicitConversions
import scala.quoted._

case class Monad(val y: Num, derivationUpdater: Deriv => Deriv):
def flatMap(k: Num => Monad): Monad =
val that = k(y)
Monad(that.y, that.derivationUpdater andThen this.derivationUpdater )

def map(k: Num => Num): Monad =
flatMap(k andThen wrap)

object Monad:
def wrap(n: Num): Monad = Monad(n, identity)

class Num(val x: Double):
def +(that: Num): Monad =
val y = Num(this.x + that.x)
Monad(
y,
deriv =>
val derivWithThis = updateDeriv(deriv, this, y){ identity }
updateDeriv(derivWithThis, that, y){ identity }
)

def *(that: Num): Monad =
val y = Num(this.x * that.x)
Monad(
y,
deriv =>
val derivWithThis = updateDeriv(deriv, this, y){ that.x * _ }
updateDeriv(derivWithThis, that, y){ this.x * _ }
)

private def updateDeriv(deriv: Deriv, key: Num, y: Num)(op: (yd: Double) => Double): Deriv =
deriv + (key -> (
deriv(key) + op(deriv(y))
))

override def toString: String = x.toString
end Num

given Conversion[Double, Num] = Num(_)
given Conversion[Int, Num] = Num(_)

def grad(f: Monad => Monad)(x: Double): Double =
val xM = wrap(Num(x))
val topMonad = f(xM)
val initialDeriv = Map.empty.withDefaultValue(0.0).updated(topMonad.y, 1.0)
println(topMonad)
topMonad.derivationUpdater(initialDeriv)(xM.y)

type Deriv = Map[Num, Double]

@main def main() =
def f(xM: Monad): Monad =
for
x <- xM
y1 <- x * 2
y2 <- x * x
y3 <- y2 * x // y3' = y2' * x + y2 * x'
y4 <- y1 + y3
yield y4

def g(xM: Monad): Monad =
for
x <- xM
y <- x * x
yield y

println(grad(f andThen g)(3))
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package differentiable.reversemode.monadcps.functional.forless

import scala.quoted._

inline def !(inline n: Num): Monad = ${impl('n)}

private def impl(n: Expr[Num])(using Quotes): Expr[Monad] = ???

0 comments on commit 32c73ef

Please sign in to comment.