Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add macro trixi_include_changeprecision to make a double precision elixir run with single precision #35

Merged
merged 16 commits into from
Jan 28, 2025
Merged
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "TrixiBase"
uuid = "9a0f1c46-06d5-4909-a5a3-ce25d3fa3284"
authors = ["Michael Schlottke-Lakemper <michael@sloede.com>"]
version = "0.1.5-DEV"
version = "0.1.5"

[deps]
ChangePrecision = "3cb15238-376d-56a3-8042-d33272777c9a"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
Expand All @@ -13,6 +14,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
TrixiBaseMPIExt = "MPI"

[compat]
ChangePrecision = "1.1.0"
MPI = "0.20"
TimerOutputs = "0.5.25"
julia = "1.8"
Expand Down
3 changes: 2 additions & 1 deletion src/TrixiBase.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module TrixiBase

using ChangePrecision: ChangePrecision
using TimerOutputs: TimerOutput, TimerOutputs

include("trixi_include.jl")
include("trixi_timeit.jl")

export trixi_include
export trixi_include, trixi_include_changeprecision
export @trixi_timeit, timer, timeit_debug_enabled,
disable_debug_timings, enable_debug_timings

Expand Down
60 changes: 57 additions & 3 deletions src/trixi_include.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# of `TrixiBase`. However, users will want to evaluate in the global scope of `Main` or something
# similar to manage dependencies on their own.
"""
trixi_include([mod::Module=Main,] elixir::AbstractString; kwargs...)
trixi_include([mapexpr::Function,] [mod::Module=Main,] elixir::AbstractString; kwargs...)
ranocha marked this conversation as resolved.
Show resolved Hide resolved

`include` the file `elixir` and evaluate its content in the global scope of module `mod`.
You can override specific assignments in `elixir` by supplying keyword arguments.
Expand All @@ -16,6 +16,10 @@ into calls to `solve` with it's default value used in the SciML ecosystem
for ODEs, see the "Miscellaneous" section of the
[documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/).

The optional first argument `mapexpr` can be used to transform the included code before
it is evaluated: for each parsed expression `expr` in `elixir`, the `include` function
actually evaluates `mapexpr(expr)`. If it is omitted, `mapexpr` defaults to `identity`.

# Examples

```@example
Expand All @@ -30,7 +34,7 @@ julia> redirect_stdout(devnull) do
0.1
```
"""
function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString; kwargs...)
efaulhaber marked this conversation as resolved.
Show resolved Hide resolved
# Check that all kwargs exist as assignments
code = read(elixir, String)
expr = Meta.parse("begin \n$code \nend")
Expand All @@ -45,13 +49,63 @@ function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
if !mpi_isparallel(Val{:MPIExt}())
@info "You just called `trixi_include`. Julia may now compile the code, please be patient."
end
Base.include(ex -> replace_assignments(insert_maxiters(ex); kwargs...), mod, elixir)
Base.include(ex -> replace_assignments(insert_maxiters(mapexpr(ex)); kwargs...),
ranocha marked this conversation as resolved.
Show resolved Hide resolved
mod, elixir)
ranocha marked this conversation as resolved.
Show resolved Hide resolved
end

function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
trixi_include(identity, mod, elixir; kwargs...)
end

function trixi_include(elixir::AbstractString; kwargs...)
trixi_include(Main, elixir; kwargs...)
end

"""
trixi_include_changeprecision(T, [mod::Module=Main,] elixir::AbstractString; kwargs...)

`include` the elixir `elixir` and evaluate its content in the global scope of module `mod`.
You can override specific assignments in `elixir` by supplying keyword arguments,
similar to [`trixi_include`](@ref).

The only difference to [`trixi_include`](@ref) is that the precision of floating-point
numbers in the included elixir is changed to `T`.
More precisely, the package [ChangePrecision.jl](https://github.com/JuliaMath/ChangePrecision.jl)
is used to convert all `Float64` literals, operations like `/` that produce `Float64` results,
and functions like `ones` that return `Float64` arrays by default, to the desired type `T`.
See the documentation of ChangePrecision.jl for more details.

The purpose of this function is to conveniently run a full simulation with `Float32`,
which is orders of magnitude faster on most GPUs than `Float64`, by just including
the elixir with `trixi_include_changeprecision(Float32, elixir)`.
Many constructors in the Trixi.jl framework are written in a way that changing all floating-point
arguments to `Float32` will change the element type to `Float32` as well.
In TrixiParticles.jl, including an elixir with this macro should be sufficient
to run the full simulation with single precision.
"""
function trixi_include_changeprecision(T, mod::Module, filename::AbstractString; kwargs...)
trixi_include(expr -> ChangePrecision.changeprecision(T, replace_trixi_include(T, expr)),
mod, filename; kwargs...)
end

function trixi_include_changeprecision(T, filename::AbstractString; kwargs...)
trixi_include_changeprecision(T, Main, filename; kwargs...)
end
sloede marked this conversation as resolved.
Show resolved Hide resolved

function replace_trixi_include(T, expr)
expr = TrixiBase.walkexpr(expr) do x
if x isa Expr
if x.head === :call && x.args[1] === :trixi_include
x.args[1] = :trixi_include_changeprecision
insert!(x.args, 2, :($T))
end
end
return x
end

return expr
end
Comment on lines +95 to +107
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more problem here. This doesn't work when, for some reason, we would include an elixir that looks like this:

using TrixiBase

TrixiBase.trixi_include(...)

It only works when we directly write the elixir as

using TrixiBase

trixi_include(...)

@ranocha any ideas how to fix this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you special case this with something like Symbol("TrixiBase.trixi_include")?

Alternatively, @vchuravy might have an idea how this can be done?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried, didn't work. Also, we would need to add a special case for each package that re-exports this name.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be a GlobalRef(mod, name).

So x.args[1] isa GlobalRef and then check the name and replace it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't work. x.args[1] isa GlobalRef is false.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

julia> TrixiBase.trixi_include_changeprecision(Float32, "../TrixiBase.jl/test1.jl")
[ Info: You just called `trixi_include`. Julia may now compile the code, please be patient.
ERROR: LoadError: UndefVarError: `TrixiBase.trixi_include_changeprecision` not defined in `Main`
Suggestion: check for spelling errors or missing imports.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it looks like TrixiBase.trixi_include_changeprecision is defined... Does your test file create a module? How does it import trixi_include?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idk. The test file just contains this single line:

TrixiBase.trixi_include("test2.jl")

And the error message says we're in Main.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or should we merge it as is and discuss this in an issue? I don't feel like wasting more time on this edge case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. IMHO if it is possible I'd just hard-code two versions for trixi_include and TrixiBase.trixi_include. If that's not feasible, just put it in the docstring that you need to invoke this in a certain way.


# Insert the keyword argument `maxiters` into calls to `solve` and `Trixi.solve`
# with default value `10^5` if it is not already present.
function insert_maxiters(expr)
Expand Down
Loading