Skip to content

Commit

Permalink
update to GF 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PGimenez committed Feb 2, 2024
1 parent 9a2766f commit c0489e6
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 92 deletions.
40 changes: 16 additions & 24 deletions app.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
module App
using GenieFramework, PlotlyBase, JLD2, Statistics, DataFrames
using Interpolations
using Random
#= include("NodeUtils.jl") =#
using .NodeUtils
using .Delhi
#= include("delhi.jl") =#
#= include("utils.jl") =#
# these modules are automatically loaded when running Genie.loadapp()
using ..NodeUtils
using ..Delhi
using ..Utils
@genietools

prod_mode = (haskey(ENV, "GENIE_ENV") && ENV["GENIE_ENV"] == "prod") ? "true" : "false"
button_color = prod_mode == "true" ? "grey" : "primary"
button_tooltip = prod_mode == "true" ? "Run the app locally to enable this button" : ""
const prod_mode = (haskey(ENV, "GENIE_ENV") && ENV["GENIE_ENV"] == "prod") ? "true" : "false"
const button_color = prod_mode == "true" ? "grey" : "primary"
const button_tooltip = prod_mode == "true" ? "Run the app locally to enable this button" : ""

rng = MersenneTwister(123)
const rng = MersenneTwister(123)
if isfile("data.jld2")
@load "data.jld2" train_df test_df scaling
else
Expand All @@ -24,14 +24,14 @@ end
if isfile("params.jld")
@load "params.jld" θ
end
features = [:meantemp, :humidity, :wind_speed, :meanpressure]
units = ["Celsius", "g/m³ of water", "km/h", "hPa"]
feature_names = ["Mean temperature", "Humidity", "Wind speed", "Mean pressure"]
const features = [:meantemp, :humidity, :wind_speed, :meanpressure]
const units = ["Celsius", "g/m³ of water", "km/h", "hPa"]
const feature_names = ["Mean temperature", "Humidity", "Wind speed", "Mean pressure"]


data = vcat(train_df, test_df)
const data = vcat(train_df, test_df)
# Functions to interpolate when calculating the MSE
interpolators = [LinearInterpolation(data.t, data[!, col]) for col in names(data)]
const interpolators = [LinearInterpolation(data.t, data[!, col]) for col in names(data)]


# NODE parameters
Expand All @@ -48,7 +48,7 @@ t_grid = range(minimum(data.t), maximum(data.t), length=N_steps) |> collect
@in start=false
@in animate=false
@out prod_mode = prod_mode
@out θ=θ_new
@out θ=θ
@out losses = Float32[]
@out temp_pdata = [PlotlyBase.scatter(x=[1,2,3])]
@out hum_pdata = [PlotlyBase.scatter(x=[1,2,3])]
Expand Down Expand Up @@ -101,13 +101,5 @@ t_grid = range(minimum(data.t), maximum(data.t), length=N_steps) |> collect
end
end

ui() =[
h1("train and predict"),btn("Train", @click(:start), loading=:start),
range(1:100,:r),
cell(class="row"),
GenieFramework.plot(:temp_pdata, layout=:temp_layout),
GenieFramework.plot(:hum_pdata, layout=:hum_layout),
GenieFramework.plot(:wind_pdata, layout=:wind_layout),
GenieFramework.plot(:press_pdata, layout=:press_layout)
]
@page("/","app.jl.html")
end
117 changes: 60 additions & 57 deletions app.jl.html
Original file line number Diff line number Diff line change
@@ -1,66 +1,69 @@
<h1 style="text-align: center; color: #2c3e50; margin-bottom: 2rem; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; font-size: 2.5rem; text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);">Delhi weather forecast</h1>
<header class="st-header q-pa-sm" style="text-align:center">
<h1 class="st-header__title text-h3">Weather data forecast</h1>
</header>

<div class="row st-module" style="padding:15px">
<div class="col-3 col-sm">
<q-badge color="secondary">
Prediction horizon (samples)
</q-badge>
<q-slider :min=30 v-model="r" :max=100 :step=1 label-always></q-slider>
<q-badge color="secondary">
Prediction step size (samples)
</q-badge>
<div class="q-gutter-sm">
<q-radio val=1 label="1" v-model="pstep"></q-radio>
<q-radio val=2 label="2" v-model="pstep"></q-radio>
<q-radio val=5 label="5" v-model="pstep"></q-radio>
</div>
</div>
<div class="col-3 col-sm" style="text-align:center">
<q-btn :loading="start" label="Train" v-on:click="start = true" color="$button_color" :disable="$disable_train">
<q-tooltip>
$button_tooltip
</q-tooltip>
</q-btn><br>
<q-btn style="margin-top:15px" :loading="animate" label="Animate" v-on:click="animate = true" color="primary"></q-btn><br>
<q-btn style="margin-top:15px" push color="secondary" label="app info">
<q-popup-proxy>
<q-banner>
This app uses a neural ordinary differential equation (NODE) to forecast weather data from Delhi. The forecast is implemented with the DiffeqFlux, DifferentialEquations, Optimization and Lux packages, and the code is based on the <a href="https://sebastiancallh.github.io/post/neural-ode-weather-forecast/">blog post</a> by Sebastian Callh.
</q-banner>
</q-popup-proxy>
</q-btn>
</div>
<div class=" col-5">
<div class=" col-6">
<h4 style="text-align:center;margin:0px">Mean-squared error</h4>
<div class="row">
<st-big-number :number="mse[0]" title="Temperature" ></st-big-number>
<st-big-number :number="mse[1]" title="Humidity" ></st-big-number>
<st-big-number :number="mse[2]" title="Wind" ></st-big-number>
<st-big-number :number="mse[3]" title="Pressure" ></st-big-number>
</div>
</div>
</div>
<div class="row">
<div class="st-col col-12 st-module">
<div class="row">
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin:0px">Temperature</h4>
<plotly :data="temp_pdata" :layout="temp_layout" :displaylogo="false"></plotly>
</div>
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin:0px">Humidity</h4>
<plotly :data="hum_pdata" :layout="hum_layout" :displaylogo="false"></plotly>
<st-big-number :number="mse[0]" title="Temperature" ></st-big-number>
<st-big-number :number="mse[1]" title="Humidity" ></st-big-number>
<st-big-number :number="mse[2]" title="Wind" ></st-big-number>
<st-big-number :number="mse[3]" title="Pressure" ></st-big-number>
</div>
</div>
<q-separator vertical />

<div class="col-4 col-sm" style="padding-left:50px;padding-top:15px">
<q-badge color="secondary">
Prediction horizon (samples)
</q-badge>
<q-slider :min=30 v-model="r" :max=100 :step=1 label-always></q-slider>
<q-badge color="secondary">
Prediction step size (samples)
</q-badge>
<div class="q-gutter-sm">
<q-radio val=1 label="1" v-model="pstep"></q-radio>
<q-radio val=2 label="2" v-model="pstep"></q-radio>
<q-radio val=5 label="5" v-model="pstep"></q-radio>
</div>
</div>
<div class="col-2 col-sm" style="text-align:center">
<q-btn :loading="start" label="Train" v-on:click="start = true" color="$button_color" :disable="$prod_mode">
<q-tooltip>
$button_tooltip
</q-tooltip>
</q-btn><br>
<q-btn style="margin-top:15px" :loading="animate" label="Animate" v-on:click="animate = true" color="primary"></q-btn><br>
<q-btn style="margin-top:15px" push color="secondary" label="app info">
<q-popup-proxy>
<q-banner>
This app uses a neural ordinary differential equation (NODE) to forecast weather data from Delhi. The forecast is implemented with the DiffeqFlux, DifferentialEquations, Optimization and Lux packages, and the code is based on the <a href="https://sebastiancallh.github.io/post/neural-ode-weather-forecast/">blog post</a> by Sebastian Callh.
</q-banner>
</q-popup-proxy>
</q-btn>
</div></div>
<div class="row">
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin:0px">Wind</h4>
<plotly :data="wind_pdata" :layout="wind_layout" :displaylogo="false"></plotly>
</div>
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin:0px">Pressure</h4>
<plotly :data="press_pdata" :layout="press_layout" :displaylogo="false"></plotly>
<div class="st-col col-12 st-module">
<div class="row">
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin-bottom:5px">Temperature</h4>
<plotly :data="temp_pdata" :layout="temp_layout" :displaylogo="false"></plotly>
</div>
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin-bottom:5px">Humidity</h4>
<plotly :data="hum_pdata" :layout="hum_layout" :displaylogo="false"></plotly>
</div>
</div>
<div class="row">
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin:0px">Wind</h4>
<plotly :data="wind_pdata" :layout="wind_layout" :displaylogo="false"></plotly>
</div>
<div class="st-col col-12 col-sm ">
<h4 style="text-align:center;margin:0px">Pressure</h4>
<plotly :data="press_pdata" :layout="press_layout" :displaylogo="false"></plotly>
</div>
</div>
</div>
</div>
</div>
</div>
Binary file modified data.jld2
Binary file not shown.
21 changes: 10 additions & 11 deletions lib/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
module Utils
using PlotlyBase, DataFrames, Statistics, Interpolations
export rescale_t, rescale_y, calc_mse, plot_pred, get_layout, get_traces

rescale_t(x) = t_scale .* x .+ t_mean
rescale_y(x,i) = y_scale[i] .* x .+ y_mean[i]
Expand All @@ -16,24 +19,19 @@ function plot_pred(t, y, t̂, ŷ; kwargs...)
return traces
end

predict(y0, t, θ, state) = begin
node, _, _ = neural_ode(t, length(y0))
= Array(node(y0, θ, state)[1])
end

get_layout(title, xlabel, ylabel) = PlotlyBase.Layout(
#= title=title, =#
xaxis=attr( title=xlabel, showgrid=false),
yaxis=attr( title=ylabel, showgrid=true),
margin=attr(l=5, r=5, t=5, b=5),
#= margin=attr(l=5, r=5, t=15, b=5), =#
legend=attr( x=1, y=1.02, yanchor="bottom", xanchor="right", orientation="h"),
)

function get_traces(t_train, t_predict, y_train, ŷ, y_test, quantity_idx)
[
PlotlyBase.scatter(x=rescale_t(t_predict), y=rescale_y(ŷ,quantity_idx), mode="line", name=""),
PlotlyBase.scatter(x=rescale_t(t_predict), y=rescale_y(ŷ,quantity_idx), mode="line", name="",line=attr(color="black")),
PlotlyBase.scatter(x=rescale_t(t_train), y=rescale_y(y_train,quantity_idx), mode="markers", marker=attr(size=10, line=attr(width=2, color="DarkSlateGrey")), name = "y_train"),
PlotlyBase.scatter(x=rescale_t(t_test), y=rescale_y(y_test,quantity_idx), mode="markers", name = "y_test")
PlotlyBase.scatter(x=rescale_t(t_test), y=rescale_y(y_test,quantity_idx), mode="markers", name = "y_test", marker=attr(size=6, color="orange"))
]
end

Expand All @@ -51,9 +49,9 @@ function get_traces(train_df, test_df, predict_df, norm)
for (i, col_name) in zip(1:4, names(train_df)[2:end])
# Generate traces for each feature
feature_traces = [
PlotlyBase.scatter(x=rescale_t(predict_df.t), y=rescale_y(predict_df[!, col_name], i), mode="line", name="Predict"),
PlotlyBase.scatter(x=rescale_t(train_df.t), y=rescale_y(train_df[!, col_name], i), mode="markers", marker=attr(size=10, line=attr(width=2, color="DarkSlateGrey")), name = "Train"),
PlotlyBase.scatter(x=rescale_t(test_df.t), y=rescale_y(test_df[!, col_name], i), mode="markers", name = "Test")
PlotlyBase.scatter(x=rescale_t(predict_df.t), y=rescale_y(predict_df[!, col_name], i), mode="line", line=attr(color="black"), name="Predict"),
PlotlyBase.scatter(x=rescale_t(train_df.t), y=rescale_y(train_df[!, col_name], i), mode="markers", marker=attr(size=10, color="darkcyan", line=attr(width=0, color="DarkSlateGrey")), name = "Train"),
PlotlyBase.scatter(x=rescale_t(test_df.t), y=rescale_y(test_df[!, col_name], i), mode="markers", marker=attr(size=6, color="orange"), name = "Test")
]

# Add the set of traces to the list
Expand All @@ -62,3 +60,4 @@ function get_traces(train_df, test_df, predict_df, norm)

return all_traces
end
end

0 comments on commit c0489e6

Please sign in to comment.