diff --git a/ext/StatsPlotsExt.jl b/ext/StatsPlotsExt.jl index 9c60d34c1..fe0c62da5 100644 --- a/ext/StatsPlotsExt.jl +++ b/ext/StatsPlotsExt.jl @@ -112,6 +112,98 @@ Renaming and reexport of StatsPlots function `plotlyjs()` to define PlotlyJS.jl plotlyjs_backend(args...; kwargs...) = StatsPlots.plotlyjs(args...; kwargs...) +# ============================================================================= +# Helper functions for container handling and plot setup +# ============================================================================= + +""" + clear_container!(container::Vector{Dict}) + +Clear all elements from a plot container. +""" +function clear_container!(container::Vector{Dict}) + empty!(container) +end + +""" + setup_plot_attributes(plot_attributes::Dict) + +Setup plot attributes based on the current backend. +Returns a tuple of (gr_back, attributes, attributes_redux) where: +- gr_back: Boolean indicating if GR backend is active +- attributes: Merged plot attributes with framestyle +- attributes_redux: Copy of attributes without framestyle +""" +function setup_plot_attributes(plot_attributes::Dict) + gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() + + if !gr_back + attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) + else + attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) + end + + attributes = merge(attrbts, plot_attributes) + attributes_redux = copy(attributes) + delete!(attributes_redux, :framestyle) + + return gr_back, attributes, attributes_redux +end + +""" + create_extended_palette(attributes_redux::Dict; total_pal_len::Int=100, alpha_reduction_factor::Float64=0.7) + +Create an extended palette with alpha reduction from the base palette in attributes. +""" +function create_extended_palette(attributes_redux::Dict; total_pal_len::Int=100, alpha_reduction_factor::Float64=0.7) + orig_pal = StatsPlots.palette(attributes_redux[:palette]) + pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + return pal +end + +""" + group_container_by_model_and_merge_diffs(container::Vector{Dict}, args_and_kwargs::Dict, diffdict::Dict) + +Group container entries by model name and merge difference dictionaries. +Returns the updated diffdict with grouped differences merged. +""" +function group_container_by_model_and_merge_diffs(container::Vector{Dict}, args_and_kwargs::Dict, diffdict::Dict) + grouped_by_model = Dict{Any, Vector{Dict}}() + + for d in container + model = d[:model_name] + d_sub = Dict(k => d[k] for k in setdiff(keys(args_and_kwargs), keys(DEFAULT_ARGS_AND_KWARGS_NAMES)) if haskey(d, k)) + push!(get!(grouped_by_model, model, Vector{Dict}()), d_sub) + end + + model_names = unique([d[:model_name] for d in container]) + + for model in model_names + if length(grouped_by_model[model]) > 1 + diffdict_grouped = compare_args_and_kwargs(grouped_by_model[model]) + diffdict = merge_by_runid(diffdict, diffdict_grouped) + end + end + + return diffdict +end + +""" + create_reduced_vector_and_diffdict(container::Vector{Dict}) + +Create a reduced vector of dictionaries keeping only run_id, label, and DEFAULT_ARGS_AND_KWARGS_NAMES keys, +and compute the difference dictionary. +""" +function create_reduced_vector_and_diffdict(container::Vector{Dict}) + reduced_vector = [ + Dict(k => d[k] for k in vcat(:run_id, :label, keys(DEFAULT_ARGS_AND_KWARGS_NAMES)...) if haskey(d, k)) + for d in container + ] + + diffdict = compare_args_and_kwargs(reduced_vector) + + return reduced_vector, diffdict +end """ $(SIGNATURES) @@ -229,20 +321,7 @@ function plot_model_estimates(𝓂::β„³, sylvester_algorithmΒ³ = (isa(sylvester_algorithm, Symbol) || length(sylvester_algorithm) < 2) ? sum(k * (k + 1) Γ· 2 for k in 1:𝓂.timings.nPast_not_future_and_mixed + 1 + 𝓂.timings.nExo) > DEFAULT_SYLVESTER_THRESHOLD ? DEFAULT_LARGE_SYLVESTER_ALGORITHM : DEFAULT_SYLVESTER_ALGORITHM : sylvester_algorithm[2], lyapunov_algorithm = lyapunov_algorithm) - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) - + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) # write_parameters_input!(𝓂, parameters, verbose = verbose) @@ -367,21 +446,13 @@ function plot_model_estimates(𝓂::β„³, extended_x_axis = vcat(x_axis, [last_x + i * period for i in 1:forecast_periods]) end - orig_pal = StatsPlots.palette(attributes_redux[:palette]) - - total_pal_len = 100 - - alpha_reduction_factor = 0.7 - - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) estimate_color = :navy data_color = :orangered - while length(model_estimates_active_plot_container) > 0 - pop!(model_estimates_active_plot_container) - end + clear_container!(model_estimates_active_plot_container) args_and_kwargs = Dict(:run_id => length(model_estimates_active_plot_container) + 1, :model_name => 𝓂.model_name, @@ -875,20 +946,7 @@ function plot_model_estimates!(𝓂::β„³, sylvester_algorithmΒ³ = (isa(sylvester_algorithm, Symbol) || length(sylvester_algorithm) < 2) ? sum(k * (k + 1) Γ· 2 for k in 1:𝓂.timings.nPast_not_future_and_mixed + 1 + 𝓂.timings.nExo) > DEFAULT_SYLVESTER_THRESHOLD ? DEFAULT_LARGE_SYLVESTER_ALGORITHM : DEFAULT_SYLVESTER_ALGORITHM : sylvester_algorithm[2], lyapunov_algorithm = lyapunov_algorithm) - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) - + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) # write_parameters_input!(𝓂, parameters, verbose = verbose) @@ -1012,13 +1070,7 @@ function plot_model_estimates!(𝓂::β„³, extended_x_axis = vcat(x_axis, [last_x + i * period for i in 1:forecast_periods]) end - orig_pal = StatsPlots.palette(attributes_redux[:palette]) - - total_pal_len = 100 - - alpha_reduction_factor = 0.7 - - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) estimate_color = :navy @@ -1088,37 +1140,10 @@ function plot_model_estimates!(𝓂::β„³, @info "Plot with same parameters already exists. Using previous plot data to create plot." end - # 1. Keep only certain keys from each dictionary - reduced_vector = [ - Dict(k => d[k] for k in vcat(:run_id, keys(DEFAULT_ARGS_AND_KWARGS_NAMES)...) if haskey(d, k)) - for d in model_estimates_active_plot_container - ] - - diffdict = compare_args_and_kwargs(reduced_vector) - - # 2. Group the original vector by :model_name. Check difference for keys where they matter between models. Two different models might have different shocks so that difference is less important, but the same model with different shocks is a difference to highlight. - grouped_by_model = Dict{Any, Vector{Dict}}() - - for d in model_estimates_active_plot_container - model = d[:model_name] - d_sub = Dict(k => d[k] for k in setdiff(keys(args_and_kwargs), keys(DEFAULT_ARGS_AND_KWARGS_NAMES)) if haskey(d, k)) - push!(get!(grouped_by_model, model, Vector{Dict}()), d_sub) - end + _, diffdict = create_reduced_vector_and_diffdict(model_estimates_active_plot_container) - model_names = [] - - for d in model_estimates_active_plot_container - push!(model_names, d[:model_name]) - end - - model_names = unique(model_names) - - for model in model_names - if length(grouped_by_model[model]) > 1 - diffdict_grouped = compare_args_and_kwargs(grouped_by_model[model]) - diffdict = merge_by_runid(diffdict, diffdict_grouped) - end - end + # Group by model name and merge diffs for multi-model comparisons + diffdict = group_container_by_model_and_merge_diffs(model_estimates_active_plot_container, args_and_kwargs, diffdict) annotate_ss = Vector{Pair{String, Any}}[] @@ -1823,19 +1848,7 @@ function plot_irf(𝓂::β„³; sylvester_algorithmΒ² = isa(sylvester_algorithm, Symbol) ? sylvester_algorithm : sylvester_algorithm[1], sylvester_algorithmΒ³ = (isa(sylvester_algorithm, Symbol) || length(sylvester_algorithm) < 2) ? sum(k * (k + 1) Γ· 2 for k in 1:𝓂.timings.nPast_not_future_and_mixed + 1 + 𝓂.timings.nExo) > DEFAULT_SYLVESTER_THRESHOLD ? DEFAULT_LARGE_SYLVESTER_ALGORITHM : DEFAULT_SYLVESTER_ALGORITHM : sylvester_algorithm[2]) - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) shocks, negative_shock, shock_size, periods_extended, shock_idx, shock_history = process_shocks_input(shocks, negative_shock, shock_size, periods, 𝓂) @@ -1956,9 +1969,7 @@ function plot_irf(𝓂::β„³; push!(processed_rename_dictionary, k => rename_dictionary[k]) end - while length(irf_active_plot_container) > 0 - pop!(irf_active_plot_container) - end + clear_container!(irf_active_plot_container) args_and_kwargs = Dict(:run_id => length(irf_active_plot_container) + 1, :model_name => 𝓂.model_name, @@ -2000,13 +2011,7 @@ function plot_irf(𝓂::β„³; push!(irf_active_plot_container, args_and_kwargs) - orig_pal = StatsPlots.palette(attributes_redux[:palette]) - - total_pal_len = 100 - - alpha_reduction_factor = 0.7 - - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) return_plots = [] @@ -2509,27 +2514,9 @@ function plot_irf!(𝓂::β„³; sylvester_algorithmΒ² = isa(sylvester_algorithm, Symbol) ? sylvester_algorithm : sylvester_algorithm[1], sylvester_algorithmΒ³ = (isa(sylvester_algorithm, Symbol) || length(sylvester_algorithm) < 2) ? sum(k * (k + 1) Γ· 2 for k in 1:𝓂.timings.nPast_not_future_and_mixed + 1 + 𝓂.timings.nExo) > DEFAULT_SYLVESTER_THRESHOLD ? DEFAULT_LARGE_SYLVESTER_ALGORITHM : DEFAULT_SYLVESTER_ALGORITHM : sylvester_algorithm[2]) - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) - - orig_pal = StatsPlots.palette(attributes_redux[:palette]) + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) - total_pal_len = 100 - - alpha_reduction_factor = 0.7 - - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) shocks, negative_shock, shock_size, periods_extended, shock_idx, shock_history = process_shocks_input(shocks, negative_shock, shock_size, periods, 𝓂) @@ -2692,37 +2679,10 @@ function plot_irf!(𝓂::β„³; @info "Plot with same parameters already exists. Using previous plot data to create plot." end - # 1. Keep only certain keys from each dictionary - reduced_vector = [ - Dict(k => d[k] for k in vcat(:run_id, :label, keys(DEFAULT_ARGS_AND_KWARGS_NAMES)...) if haskey(d, k)) - for d in irf_active_plot_container - ] + _, diffdict = create_reduced_vector_and_diffdict(irf_active_plot_container) - diffdict = compare_args_and_kwargs(reduced_vector) - - # 2. Group the original vector by :model_name - grouped_by_model = Dict{Any, Vector{Dict}}() - - for d in irf_active_plot_container - model = d[:model_name] - d_sub = Dict(k => d[k] for k in setdiff(keys(args_and_kwargs), keys(DEFAULT_ARGS_AND_KWARGS_NAMES)) if haskey(d, k)) - push!(get!(grouped_by_model, model, Vector{Dict}()), d_sub) - end - - model_names = [] - - for d in irf_active_plot_container - push!(model_names, d[:model_name]) - end - - model_names = unique(model_names) - - for model in model_names - if length(grouped_by_model[model]) > 1 - diffdict_grouped = compare_args_and_kwargs(grouped_by_model[model]) - diffdict = merge_by_runid(diffdict, diffdict_grouped) - end - end + # Group by model name and merge diffs for multi-model comparisons + diffdict = group_container_by_model_and_merge_diffs(irf_active_plot_container, args_and_kwargs, diffdict) # @assert haskey(diffdict, :parameters) || haskey(diffdict, :shock_names) || haskey(diffdict, :initial_state) || any(haskey.(Ref(diffdict), keys(DEFAULT_ARGS_AND_KWARGS_NAMES))) "New plot must be different from previous plot. Use the version without ! to plot." @@ -3541,19 +3501,7 @@ function plot_conditional_variance_decomposition(𝓂::β„³; opts = merge_calculation_options(tol = tol, verbose = verbose, quadratic_matrix_equation_algorithm = quadratic_matrix_equation_algorithm) - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) fevds = get_conditional_variance_decomposition(𝓂, periods = 1:periods, @@ -3600,13 +3548,7 @@ function plot_conditional_variance_decomposition(𝓂::β„³; end end - orig_pal = StatsPlots.palette(attributes_redux[:palette]) - - total_pal_len = 100 - - alpha_reduction_factor = 0.7 - - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) n_subplots = length(var_idx) pp = [] @@ -3812,19 +3754,7 @@ function plot_solution(𝓂::β„³, sylvester_algorithmΒ³ = (isa(sylvester_algorithm, Symbol) || length(sylvester_algorithm) < 2) ? sum(k * (k + 1) Γ· 2 for k in 1:𝓂.timings.nPast_not_future_and_mixed + 1 + 𝓂.timings.nExo) > DEFAULT_SYLVESTER_THRESHOLD ? DEFAULT_LARGE_SYLVESTER_ALGORITHM : DEFAULT_SYLVESTER_ALGORITHM : sylvester_algorithm[2], lyapunov_algorithm = lyapunov_algorithm) - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) state = state isa Symbol ? state : state |> Meta.parse |> replace_indices @@ -3880,9 +3810,7 @@ function plot_solution(𝓂::β„³, state_selector = state .== 𝓂.var # Clear container for new plot - while length(solution_active_plot_container) > 0 - pop!(solution_active_plot_container) - end + clear_container!(solution_active_plot_container) if any(x -> contains(string(x), "β—–"), full_NSSS) full_NSSS_decomposed = decompose_name.(full_NSSS) @@ -4001,22 +3929,9 @@ function _plot_solution_from_container(; push!(joint_states, string(apply_custom_name.(container[:state], Ref(Dict(container[:rename_dictionary]))))) end - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - attributes_redux = copy(attributes) - delete!(attributes_redux, :framestyle) + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) - orig_pal = StatsPlots.palette(attributes_redux[:palette]) - total_pal_len = 100 - alpha_reduction_factor = 0.7 - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) # Create comparison of containers to detect differences # Keep relevant keys for comparison: model_name, state, parameters, algorithm, ignore_obc, label @@ -4044,37 +3959,10 @@ function _plot_solution_from_container(; if length(solution_active_plot_container) == 0 diffdict[:label] = [solution_active_plot_container[1][:label]] else - # 1. Keep only certain keys from each dictionary - reduced_vector = [ - Dict(k => d[k] for k in vcat(:run_id, :label, keys(DEFAULT_ARGS_AND_KWARGS_NAMES)...) if haskey(d, k)) - for d in solution_active_plot_container - ] - - diffdict = compare_args_and_kwargs(reduced_vector) + _, diffdict = create_reduced_vector_and_diffdict(solution_active_plot_container) - # 2. Group the original vector by :model_name - grouped_by_model = Dict{Any, Vector{Dict}}() - - for d in solution_active_plot_container#[1:end-1] - model = d[:model_name] - d_sub = Dict(k => d[k] for k in setdiff(keys(solution_active_plot_container[end]), keys(DEFAULT_ARGS_AND_KWARGS_NAMES)) if haskey(d, k)) - push!(get!(grouped_by_model, model, Vector{Dict}()), d_sub) - end - - model_names = [] - - for d in solution_active_plot_container - push!(model_names, d[:model_name]) - end - - model_names = unique(model_names) - - for model in model_names - if length(grouped_by_model[model]) > 1 - diffdict_grouped = compare_args_and_kwargs(grouped_by_model[model]) - diffdict = merge_by_runid(diffdict, diffdict_grouped) - end - end + # Group by model name and merge diffs for multi-model comparisons + diffdict = group_container_by_model_and_merge_diffs(solution_active_plot_container, solution_active_plot_container[end], diffdict) end else # For single container, create a diffdict with just the label @@ -4812,19 +4700,7 @@ function plot_conditional_forecast(𝓂::β„³, sylvester_algorithm::Union{Symbol,Vector{Symbol},Tuple{Symbol,Vararg{Symbol}}} = DEFAULT_SYLVESTER_SELECTOR(𝓂)) # @nospecialize # reduce compile time - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) initial_state_input = copy(initial_state) @@ -4943,9 +4819,7 @@ function plot_conditional_forecast(𝓂::β„³, shocks = Matrix{Union{Nothing,Float64}}(undef,length(𝓂.exo),periods) end - while length(conditional_forecast_active_plot_container) > 0 - pop!(conditional_forecast_active_plot_container) - end + clear_container!(conditional_forecast_active_plot_container) # Create display names for variables and shocks full_variable_names_display = [(apply_custom_name(replace_indices_in_symbol(v), rename_dictionary)) for v in full_var_SS if v βˆ‰ map(x->Symbol(string(x) * "β‚β‚“β‚Ž"),𝓂.timings.exo)] @@ -5028,13 +4902,7 @@ function plot_conditional_forecast(𝓂::β„³, push!(conditional_forecast_active_plot_container, args_and_kwargs) - orig_pal = StatsPlots.palette(attributes_redux[:palette]) - - total_pal_len = 100 - - alpha_reduction_factor = 0.7 - - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) n_subplots = length(var_idx) pp = [] @@ -5272,19 +5140,7 @@ function plot_conditional_forecast!(𝓂::β„³, @assert plot_type ∈ [:compare, :stack] "plot_type must be either :compare or :stack" - gr_back = StatsPlots.backend() == StatsPlots.Plots.GRBackend() - - if !gr_back - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict(:framestyle => :box)) - else - attrbts = merge(DEFAULT_PLOT_ATTRIBUTES, Dict()) - end - - attributes = merge(attrbts, plot_attributes) - - attributes_redux = copy(attributes) - - delete!(attributes_redux, :framestyle) + gr_back, attributes, attributes_redux = setup_plot_attributes(plot_attributes) initial_state_input = copy(initial_state) @@ -5447,13 +5303,7 @@ function plot_conditional_forecast!(𝓂::β„³, # sorted_variable_names_display = sort(variable_names_display) sorted_shock_names_display = sort(shock_names_display) - orig_pal = StatsPlots.palette(attributes_redux[:palette]) - - total_pal_len = 100 - - alpha_reduction_factor = 0.7 - - pal = mapreduce(x -> StatsPlots.coloralpha.(orig_pal, alpha_reduction_factor ^ x), vcat, 0:(total_pal_len Γ· length(orig_pal)) - 1) |> StatsPlots.palette + pal = create_extended_palette(attributes_redux) args_and_kwargs = Dict(:run_id => length(conditional_forecast_active_plot_container) + 1, :model_name => 𝓂.model_name, @@ -5508,37 +5358,10 @@ function plot_conditional_forecast!(𝓂::β„³, @info "Plot with same parameters already exists. Using previous plot data to create plot." end - # 1. Keep only certain keys from each dictionary - reduced_vector = [ - Dict(k => d[k] for k in vcat(:run_id, :label, keys(DEFAULT_ARGS_AND_KWARGS_NAMES)...) if haskey(d, k)) - for d in conditional_forecast_active_plot_container - ] - - diffdict = compare_args_and_kwargs(reduced_vector) - - # 2. Group the original vector by :model_name - grouped_by_model = Dict{Any, Vector{Dict}}() - - for d in conditional_forecast_active_plot_container - model = d[:model_name] - d_sub = Dict(k => d[k] for k in setdiff(keys(args_and_kwargs), keys(DEFAULT_ARGS_AND_KWARGS_NAMES)) if haskey(d, k)) - push!(get!(grouped_by_model, model, Vector{Dict}()), d_sub) - end - - model_names = [] + _, diffdict = create_reduced_vector_and_diffdict(conditional_forecast_active_plot_container) - for d in conditional_forecast_active_plot_container - push!(model_names, d[:model_name]) - end - - model_names = unique(model_names) - - for model in model_names - if length(grouped_by_model[model]) > 1 - diffdict_grouped = compare_args_and_kwargs(grouped_by_model[model]) - diffdict = merge_by_runid(diffdict, diffdict_grouped) - end - end + # Group by model name and merge diffs for multi-model comparisons + diffdict = group_container_by_model_and_merge_diffs(conditional_forecast_active_plot_container, args_and_kwargs, diffdict) annotate_ss = Vector{Pair{String, Any}}[]