Skip to content

Commit

Permalink
Removed FFI in place of exception package
Browse files Browse the repository at this point in the history
  • Loading branch information
lucaspellegrinelli committed Jun 25, 2024
1 parent d45db22 commit 145e213
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 19 deletions.
1 change: 1 addition & 0 deletions gleam.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ gleam_stdlib = ">= 0.34.0 and < 2.0.0"
gleam_otp = ">= 0.10.0 and < 1.0.0"
nx = ">= 0.7.2 and < 1.0.0"
gleam_community_maths = ">= 1.1.1 and < 2.0.0"
exception = ">= 2.0.0 and < 3.0.0"

[dev-dependencies]
gleeunit = ">= 1.0.0 and < 2.0.0"
Expand Down
4 changes: 3 additions & 1 deletion manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

packages = [
{ name = "complex", version = "0.5.0", build_tools = ["mix"], requirements = [], otp_app = "complex", source = "hex", outer_checksum = "2683BD3C184466CFB94FAD74CBFDDFAA94B860E27AD4CA1BFFE3BFF169D91EF1" },
{ name = "exception", version = "2.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "exception", source = "hex", outer_checksum = "F5580D584F16A20B7FCDCABF9E9BE9A2C1F6AC4F9176FA6DD0B63E3B20D450AA" },
{ name = "gleam_bitwise", version = "1.3.1", build_tools = ["gleam"], requirements = [], otp_app = "gleam_bitwise", source = "hex", outer_checksum = "B36E1D3188D7F594C7FD4F43D0D2CE17561DE896202017548578B16FE1FE9EFC" },
{ name = "gleam_community_maths", version = "1.1.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_community_maths", source = "hex", outer_checksum = "6C4ED7BC7E7DF6977719B5F2CFE717EE8280D1CF6EA81D55FD9953758C7FD14E" },
{ name = "gleam_erlang", version = "0.25.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_erlang", source = "hex", outer_checksum = "054D571A7092D2A9727B3E5D183B7507DAB0DA41556EC9133606F09C15497373" },
Expand All @@ -15,7 +16,8 @@ packages = [
]

[requirements]
gleam_community_maths = { version = ">= 1.1.1 and < 2.0.0"}
exception = { version = ">= 2.0.0 and < 3.0.0"}
gleam_community_maths = { version = ">= 1.1.1 and < 2.0.0" }
gleam_otp = { version = ">= 0.10.0 and < 1.0.0" }
gleam_stdlib = { version = ">= 0.34.0 and < 2.0.0" }
gleeunit = { version = ">= 1.0.0 and < 2.0.0" }
Expand Down
12 changes: 0 additions & 12 deletions src/ffi.ex

This file was deleted.

2 changes: 1 addition & 1 deletion src/gleastsq/errors.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ pub type FitErrors {
NonConverged
WrongParameters(String)
JacobianTaskError
SolveError
SolveError(String)
}
2 changes: 1 addition & 1 deletion src/gleastsq/internal/methods/gauss_newton.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fn do_gauss_newton(
let jtj = nx.add(nx.dot(jt, j), eye)
let jt_r = nx.dot(jt, r)

use delta_solve <- result.try(result.replace_error(
use delta_solve <- result.try(result.map_error(
nx.solve(jtj, jt_r),
SolveError,
))
Expand Down
2 changes: 1 addition & 1 deletion src/gleastsq/internal/methods/levenberg_marquardt.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ fn do_levenberg_marquardt(
let h_damped = nx.add(nx.dot(jt, j), lambda_eye)
let g = nx.dot(jt, r)

use delta_solve <- result.try(result.replace_error(
use delta_solve <- result.try(result.map_error(
nx.solve(h_damped, g),
SolveError,
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ fn dogleg(
let p_u_denominator = nx.dot(g, nx.dot(jt, jg))
let p_u = nx.multiply_mat(nx.divide_mat(p_u_numerator, p_u_denominator), g)

use bg_solve <- result.try(result.replace_error(nx.solve(b, g), SolveError))
use bg_solve <- result.try(result.map_error(nx.solve(b, g), SolveError))
let p_b = nx.negate(bg_solve)

let p_b_norm = nx.norm(p_b) |> nx.to_number
Expand Down
20 changes: 18 additions & 2 deletions src/gleastsq/internal/nx.gleam
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import exception
import gleam/dynamic
import gleam/result

pub type NxTensor =
Nil

Expand Down Expand Up @@ -53,8 +57,8 @@ pub fn to_list_1d(a: NxTensor) -> List(Float)
@external(erlang, "Elixir.Nx", "to_number")
pub fn to_number(a: NxTensor) -> Float

@external(erlang, "Elixir.NxBindings", "safe_solve")
pub fn solve(a: NxTensor, b: NxTensor) -> Result(NxTensor, String)
@external(erlang, "Elixir.Nx.LinAlg", "solve")
pub fn unsafe_solve(a: NxTensor, b: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx.LinAlg", "norm")
pub fn norm(a: NxTensor) -> NxTensor
Expand All @@ -70,3 +74,15 @@ pub fn divide_mat(a: NxTensor, b: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "concatenate")
pub fn concatenate(a: List(NxTensor), opts opts: List(NxOpts)) -> NxTensor

pub fn solve(a: NxTensor, b: NxTensor) -> Result(NxTensor, String) {
case exception.rescue(fn() { unsafe_solve(a, b) }) {
Ok(r) -> Ok(r)
Error(e) ->
case e {
exception.Errored(e) ->
Error(result.unwrap(dynamic.string(e), "Error solving matrix"))
_ -> panic as "Unexpected error while solving matrix"
}
}
}

0 comments on commit 145e213

Please sign in to comment.