Skip to content

Commit

Permalink
Merge pull request #146 from VaiTon/events
Browse files Browse the repository at this point in the history
Add events function to helpers.jl
  • Loading branch information
oxinabox authored May 14, 2024
2 parents e0dcbb3 + 18ca3aa commit a2b97d4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
20 changes: 20 additions & 0 deletions src/Deserialization/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,23 @@ function steps(tbl; kwargs...)

return steps
end

"""
events(logger)
Returns a list of all the events serialized by `logger`.
`logger` can be a `TBLogger` or the path of a valid TensorBoard logdir.
You should call this function only if you are interested in the events in an array-like
structure. If you need to iterate over the events, use `map_events` instead.
"""
function events(tbl; kwargs...)
events = []

map_events(tbl; kwargs...) do ev
push!(events, ev)
end

return events
end
10 changes: 9 additions & 1 deletion test/test_TBLogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ test_log_dir = "test_logs/"
close.(values(tbl3.all_files))
tbl4 = TBLogger(test_log_dir*"run_2", tb_overwrite)
@test !isfile(test_log_dir*"run_2/testfile")

# check custom file prefix
tbl5 = TBLogger(test_log_dir*"run_3"; time = 0, prefix = "test.")
@test isfile(test_log_dir*"run_3/test.events.out.tfevents.0.$(gethostname())")
Expand Down Expand Up @@ -124,3 +124,11 @@ end
@test length(tbl.all_files) == 1

end

@testset "events" begin
tbl = TBLogger(test_log_dir*"run", tb_overwrite)
@test length(TensorBoardLogger.events(tbl)) == 1 # creation event

TensorBoardLogger.log_value(tbl, "test", 1.0)
@test length(TensorBoardLogger.events(tbl)) == 2 # creation event + log_value
end

0 comments on commit a2b97d4

Please sign in to comment.