Skip to content

Commit

Permalink
Fixed TRR g_norm stopping condition
Browse files Browse the repository at this point in the history
  • Loading branch information
lucaspellegrinelli committed Jun 25, 2024
1 parent 2b6dd58 commit 0224be6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ fn do_trust_region_reflective(
let b = nx.add(nx.dot(jt, j), lambda_eye)
let g = nx.dot(jt, r)

let g_norm = nx.norm(g) |> nx.to_number
let g_norm = nx.reduce_max(nx.abs(g)) |> nx.to_number
use <- bool.guard(g_norm <. tolerance, Ok(params))

use non_bounded_p <- result.try(dogleg(j, g, b, delta))
Expand Down
6 changes: 6 additions & 0 deletions src/gleastsq/internal/nx.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@ pub fn eye(n: Int) -> NxTensor
@external(erlang, "Elixir.Nx", "min")
pub fn min(a: NxTensor, b: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "abs")
pub fn abs(a: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "max")
pub fn max(a: NxTensor, b: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "reduce_max")
pub fn reduce_max(a: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "dot")
pub fn dot(a: NxTensor, b: NxTensor) -> NxTensor

Expand Down

0 comments on commit 0224be6

Please sign in to comment.