From 18ca3aa80bd7508d4f64d06260940233467413b5 Mon Sep 17 00:00:00 2001 From: VaiTon Date: Fri, 26 Apr 2024 17:51:52 +0200 Subject: [PATCH] Add events function to helpers.jl --- src/Deserialization/helpers.jl | 20 ++++++++++++++++++++ test/test_TBLogger.jl | 10 +++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/Deserialization/helpers.jl b/src/Deserialization/helpers.jl index e2f60f43..4824c852 100644 --- a/src/Deserialization/helpers.jl +++ b/src/Deserialization/helpers.jl @@ -33,3 +33,23 @@ function steps(tbl; kwargs...) return steps end + +""" + events(logger) + +Returns a list of all the events serialized by `logger`. + +`logger` can be a `TBLogger` or the path of a valid TensorBoard logdir. + +You should call this function only if you are interested in the events in an array-like +structure. If you need to iterate over the events, use `map_events` instead. +""" +function events(tbl; kwargs...) + events = [] + + map_events(tbl; kwargs...) do ev + push!(events, ev) + end + + return events +end diff --git a/test/test_TBLogger.jl b/test/test_TBLogger.jl index 9992daef..1ee1fa5e 100644 --- a/test/test_TBLogger.jl +++ b/test/test_TBLogger.jl @@ -19,7 +19,7 @@ test_log_dir = "test_logs/" close.(values(tbl3.all_files)) tbl4 = TBLogger(test_log_dir*"run_2", tb_overwrite) @test !isfile(test_log_dir*"run_2/testfile") - + # check custom file prefix tbl5 = TBLogger(test_log_dir*"run_3"; time = 0, prefix = "test.") @test isfile(test_log_dir*"run_3/test.events.out.tfevents.0.$(gethostname())") @@ -124,3 +124,11 @@ end @test length(tbl.all_files) == 1 end + +@testset "events" begin + tbl = TBLogger(test_log_dir*"run", tb_overwrite) + @test length(TensorBoardLogger.events(tbl)) == 1 # creation event + + TensorBoardLogger.log_value(tbl, "test", 1.0) + @test length(TensorBoardLogger.events(tbl)) == 2 # creation event + log_value +end