diff --git a/Project.toml b/Project.toml index ac7a29d8..7f1394de 100644 --- a/Project.toml +++ b/Project.toml @@ -43,6 +43,7 @@ Artifacts = "1.10" ChainRulesCore = "1.24" ComponentArrays = "0.15.13" ConcreteStructs = "0.2.3" +DataInterpolations = "5.2.0" ExplicitImports = "1.5" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 080bc240..4f92ae81 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -150,8 +150,8 @@ end ongpu && continue @testset "$(spl): train_grid $(train_grid), dims $(dims)" for spl in ( - ConstantInterpolation, LinearInterpolation, QuadraticInterpolation, - QuadraticSpline, CubicSpline), + ConstantInterpolation, LinearInterpolation, + QuadraticInterpolation, QuadraticSpline, CubicSpline), train_grid in (true, false), dims in ((), (8,)) @@ -164,12 +164,12 @@ end y, st = spline(x, ps, st) @test size(y) == (dims..., 4) - @jet spline(x, ps, st) + @jet spline(x, ps, st) opt_broken=!ongpu # See SciML/DataInterpolations.jl/issues/267 y, st = spline(x, ps_ca, st) @test size(y) == (dims..., 4) - @jet spline(x, ps_ca, st) + @jet spline(x, ps_ca, st) opt_broken=!ongpu # See SciML/DataInterpolations.jl/issues/267 ∂x, ∂ps = Zygote.gradient((x, ps) -> sum(abs2, first(spline(x, ps, st))), x, ps) spl !== ConstantInterpolation && @test ∂x !== nothing