From da2783591640cb19a2b8bb4a0a9d941c9eec2589 Mon Sep 17 00:00:00 2001 From: Attila Krasznahorkay Date: Mon, 4 Nov 2024 13:04:07 +0100 Subject: [PATCH] Made all templated code of track fitting public. This way clients can have access to the full details of the code if they want to, while also getting algorithms that would perform fitting in one very specific way. --- benchmarks/cpu/toy_detector_cpu.cpp | 2 +- core/CMakeLists.txt | 2 +- .../traccc/fitting/details}/fit_tracks.hpp | 30 ++++++++++--------- .../fitting/kalman_fitting_algorithm.hpp | 11 ++++++- core/src/fitting/kalman_fitting_algorithm.cpp | 5 ++-- ...orithm_constant_field_default_detector.cpp | 16 +++++----- ...ithm_constant_field_telescope_detector.cpp | 16 +++++----- examples/run/cpu/full_chain_algorithm.cpp | 2 +- examples/run/cpu/seeding_example.cpp | 2 +- examples/run/cpu/seq_example.cpp | 2 +- examples/run/cpu/truth_finding_example.cpp | 2 +- examples/run/cpu/truth_fitting_example.cpp | 2 +- examples/run/cuda/seeding_example_cuda.cpp | 2 +- examples/run/cuda/seq_example_cuda.cpp | 2 +- .../run/cuda/truth_finding_example_cuda.cpp | 2 +- .../run/cuda/truth_fitting_example_cuda.cpp | 2 +- .../cpu/test_ckf_sparse_tracks_telescope.cpp | 2 +- tests/cpu/test_kalman_fitter_telescope.cpp | 2 +- tests/cpu/test_kalman_fitter_wire_chamber.cpp | 2 +- 19 files changed, 57 insertions(+), 49 deletions(-) rename core/{src/fitting => include/traccc/fitting/details}/fit_tracks.hpp (73%) diff --git a/benchmarks/cpu/toy_detector_cpu.cpp b/benchmarks/cpu/toy_detector_cpu.cpp index b9abbf9b7..3fd92d468 100644 --- a/benchmarks/cpu/toy_detector_cpu.cpp +++ b/benchmarks/cpu/toy_detector_cpu.cpp @@ -57,7 +57,7 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) { traccc::track_params_estimation tp(host_mr); traccc::host::combinatorial_kalman_filter_algorithm host_finding( finding_cfg); - traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg, host_mr); for (auto _ : state) { diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index e55bc281c..d2d2af375 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -74,7 +74,7 @@ traccc_add_library( traccc_core core TYPE SHARED "include/traccc/fitting/kalman_filter/kalman_fitter.hpp" "include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp" "include/traccc/fitting/kalman_filter/statistics_updater.hpp" - "src/fitting/fit_tracks.hpp" + "include/traccc/fitting/details/fit_tracks.hpp" "include/traccc/fitting/kalman_fitting_algorithm.hpp" "src/fitting/kalman_fitting_algorithm.cpp" "src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp" diff --git a/core/src/fitting/fit_tracks.hpp b/core/include/traccc/fitting/details/fit_tracks.hpp similarity index 73% rename from core/src/fitting/fit_tracks.hpp rename to core/include/traccc/fitting/details/fit_tracks.hpp index 00cc6bb3d..a5407f8ff 100644 --- a/core/src/fitting/fit_tracks.hpp +++ b/core/include/traccc/fitting/details/fit_tracks.hpp @@ -11,6 +11,9 @@ #include "traccc/edm/track_candidate.hpp" #include "traccc/edm/track_state.hpp" +// VecMem include(s). +#include + namespace traccc::host::details { /// Templated implementation of the track fitting algorithm. @@ -19,27 +22,26 @@ namespace traccc::host::details { /// specializations, to fit tracks on top of a specific detector type, magnetic /// field type, and track fitting configuration. /// +/// @note The memory resource received by this function is not used thoroughly +/// for the setup of the output container. Inner vectors in the output's +/// jagged vector are created using the default memory resource. +/// /// @tparam fitter_t The fitter type used for the track fitting /// -/// @param det The detector object -/// @param field The magnetic field object -/// @param track_candidates All track candidates to fit -/// @param config The track fitting configuration +/// @param[in] fitter The fitter object to use on the track candidates +/// @param[in] track_candidates All track candidates to fit +/// @param[in] mr Memory resource to use for the output container /// /// @return A container of the fitted track states /// template track_state_container_types::host fit_tracks( - const typename fitter_t::detector_type& det, - const typename fitter_t::bfield_type& field, + fitter_t& fitter, const track_candidate_container_types::const_view& track_candidates_view, - const typename fitter_t::config_type& config) { - - // Create the fitter object. - fitter_t fitter(det, field, config); + vecmem::memory_resource& mr) { - // Output container. - track_state_container_types::host output_states; + // Create the output container. + track_state_container_types::host result{&mr}; // Iterate over the tracks, const track_candidate_container_types::const_device track_candidates{ @@ -62,13 +64,13 @@ track_state_container_types::host fit_tracks( fitter.fit(track_candidates.get_headers()[i], fitter_state); // Save the results into the output container. - output_states.push_back( + result.push_back( std::move(fitter_state.m_fit_res), std::move(fitter_state.m_fit_actor_state.m_track_states)); } // Return the fitted track states. - return output_states; + return result; } } // namespace traccc::host::details diff --git a/core/include/traccc/fitting/kalman_fitting_algorithm.hpp b/core/include/traccc/fitting/kalman_fitting_algorithm.hpp index 2924ccac0..31fafbc91 100644 --- a/core/include/traccc/fitting/kalman_fitting_algorithm.hpp +++ b/core/include/traccc/fitting/kalman_fitting_algorithm.hpp @@ -17,6 +17,12 @@ // Detray include(s). #include +// VecMem include(s). +#include + +// System include(s). +#include + namespace traccc::host { /// Kalman filter based track fitting algorithm @@ -40,7 +46,8 @@ class kalman_fitting_algorithm /// /// @param config The configuration object /// - kalman_fitting_algorithm(const config_type& config); + explicit kalman_fitting_algorithm(const config_type& config, + vecmem::memory_resource& mr); /// Execute the algorithm /// @@ -71,6 +78,8 @@ class kalman_fitting_algorithm private: /// Algorithm configuration config_type m_config; + /// Memory resource to use in the algorithm + std::reference_wrapper m_mr; }; // class kalman_fitting_algorithm diff --git a/core/src/fitting/kalman_fitting_algorithm.cpp b/core/src/fitting/kalman_fitting_algorithm.cpp index ff85039b2..ef2002837 100644 --- a/core/src/fitting/kalman_fitting_algorithm.cpp +++ b/core/src/fitting/kalman_fitting_algorithm.cpp @@ -10,7 +10,8 @@ namespace traccc::host { -kalman_fitting_algorithm::kalman_fitting_algorithm(const config_type& config) - : m_config(config) {} +kalman_fitting_algorithm::kalman_fitting_algorithm(const config_type& config, + vecmem::memory_resource& mr) + : m_config{config}, m_mr{mr} {} } // namespace traccc::host diff --git a/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp b/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp index f944fb36c..f4a3933c6 100644 --- a/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp +++ b/core/src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp @@ -6,7 +6,7 @@ */ // Project include(s). -#include "fit_tracks.hpp" +#include "traccc/fitting/details/fit_tracks.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/fitting/kalman_fitting_algorithm.hpp" @@ -21,18 +21,16 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( const detray::bfield::const_field_t::view_t& field, const track_candidate_container_types::const_view& track_candidates) const { - // Set up the fitter type(s). - using stepper_type = + // Create the fitter object. + kalman_fitter< detray::rk_stepper>; - using navigator_type = - detray::navigator; - using fitter_type = kalman_fitter; + detray::constrained_step<>>, + detray::navigator> + fitter{det, field, m_config}; // Perform the track fitting using a common, templated function. - return details::fit_tracks(det, field, track_candidates, - m_config); + return details::fit_tracks(fitter, track_candidates, m_mr.get()); } } // namespace traccc::host diff --git a/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp b/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp index 53404596d..d28fe814c 100644 --- a/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp +++ b/core/src/fitting/kalman_fitting_algorithm_constant_field_telescope_detector.cpp @@ -6,7 +6,7 @@ */ // Project include(s). -#include "fit_tracks.hpp" +#include "traccc/fitting/details/fit_tracks.hpp" #include "traccc/fitting/kalman_filter/kalman_fitter.hpp" #include "traccc/fitting/kalman_fitting_algorithm.hpp" @@ -21,18 +21,16 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()( const detray::bfield::const_field_t::view_t& field, const track_candidate_container_types::const_view& track_candidates) const { - // Set up the fitter type(s). - using stepper_type = + // Create the fitter object. + kalman_fitter< detray::rk_stepper>; - using navigator_type = - detray::navigator; - using fitter_type = kalman_fitter; + detray::constrained_step<>>, + detray::navigator> + fitter{det, field, m_config}; // Perform the track fitting using a common, templated function. - return details::fit_tracks(det, field, track_candidates, - m_config); + return details::fit_tracks(fitter, track_candidates, m_mr.get()); } } // namespace traccc::host diff --git a/examples/run/cpu/full_chain_algorithm.cpp b/examples/run/cpu/full_chain_algorithm.cpp index 349fa09d9..8756b945a 100644 --- a/examples/run/cpu/full_chain_algorithm.cpp +++ b/examples/run/cpu/full_chain_algorithm.cpp @@ -28,7 +28,7 @@ full_chain_algorithm::full_chain_algorithm( m_seeding(finder_config, grid_config, filter_config, mr), m_track_parameter_estimation(mr), m_finding(finding_config), - m_fitting(fitting_config), + m_fitting(fitting_config, mr), m_finder_config(finder_config), m_grid_config(grid_config), m_filter_config(filter_config), diff --git a/examples/run/cpu/seeding_example.cpp b/examples/run/cpu/seeding_example.cpp index 4216b8ba4..0b41e8af2 100644 --- a/examples/run/cpu/seeding_example.cpp +++ b/examples/run/cpu/seeding_example.cpp @@ -124,7 +124,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::greedy_ambiguity_resolution_algorithm host_ambiguity_resolution{}; diff --git a/examples/run/cpu/seq_example.cpp b/examples/run/cpu/seq_example.cpp index e83fbc6a4..df80a11cc 100644 --- a/examples/run/cpu/seq_example.cpp +++ b/examples/run/cpu/seq_example.cpp @@ -128,7 +128,7 @@ int seq_run(const traccc::opts::input_data& input_opts, seeding_opts.seedfilter, host_mr); traccc::track_params_estimation tp(host_mr); finding_algorithm finding_alg(finding_cfg); - fitting_algorithm fitting_alg(fitting_cfg); + fitting_algorithm fitting_alg(fitting_cfg, host_mr); traccc::greedy_ambiguity_resolution_algorithm resolution_alg; // performance writer diff --git a/examples/run/cpu/truth_finding_example.cpp b/examples/run/cpu/truth_finding_example.cpp index a02163b39..c6f7b9456 100644 --- a/examples/run/cpu/truth_finding_example.cpp +++ b/examples/run/cpu/truth_finding_example.cpp @@ -103,7 +103,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); // Seed generator traccc::seed_generator sg(detector, diff --git a/examples/run/cpu/truth_fitting_example.cpp b/examples/run/cpu/truth_fitting_example.cpp index 6463d3d26..bebe633b9 100644 --- a/examples/run/cpu/truth_fitting_example.cpp +++ b/examples/run/cpu/truth_fitting_example.cpp @@ -107,7 +107,7 @@ int main(int argc, char* argv[]) { traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); // Seed generator traccc::seed_generator sg(host_det, stddevs); diff --git a/examples/run/cuda/seeding_example_cuda.cpp b/examples/run/cuda/seeding_example_cuda.cpp index a45e2a4cf..d365eb3df 100644 --- a/examples/run/cuda/seeding_example_cuda.cpp +++ b/examples/run/cuda/seeding_example_cuda.cpp @@ -179,7 +179,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); diff --git a/examples/run/cuda/seq_example_cuda.cpp b/examples/run/cuda/seq_example_cuda.cpp index a97852f86..591a2f47f 100644 --- a/examples/run/cuda/seq_example_cuda.cpp +++ b/examples/run/cuda/seq_example_cuda.cpp @@ -166,7 +166,7 @@ int seq_run(const traccc::opts::detector& detector_opts, seeding_opts.seedfilter, host_mr); traccc::track_params_estimation tp(host_mr); host_finding_algorithm finding_alg(finding_cfg); - host_fitting_algorithm fitting_alg(fitting_cfg); + host_fitting_algorithm fitting_alg(fitting_cfg, host_mr); traccc::cuda::clusterization_algorithm ca_cuda(mr, copy, stream, clusterization_opts); diff --git a/examples/run/cuda/truth_finding_example_cuda.cpp b/examples/run/cuda/truth_finding_example_cuda.cpp index 0e6af2f84..90d968a74 100644 --- a/examples/run/cuda/truth_finding_example_cuda.cpp +++ b/examples/run/cuda/truth_finding_example_cuda.cpp @@ -155,7 +155,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts, traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_config; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); diff --git a/examples/run/cuda/truth_fitting_example_cuda.cpp b/examples/run/cuda/truth_fitting_example_cuda.cpp index c93990bb3..5adba022d 100644 --- a/examples/run/cuda/truth_fitting_example_cuda.cpp +++ b/examples/run/cuda/truth_fitting_example_cuda.cpp @@ -155,7 +155,7 @@ int main(int argc, char* argv[]) { traccc::fitting_config fit_cfg; fit_cfg.propagation = propagation_opts; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); traccc::cuda::fitting_algorithm device_fitting( fit_cfg, mr, async_copy, stream); diff --git a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp index cdb9d06d1..595f57122 100644 --- a/tests/cpu/test_ckf_sparse_tracks_telescope.cpp +++ b/tests/cpu/test_ckf_sparse_tracks_telescope.cpp @@ -136,7 +136,7 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) { fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { diff --git a/tests/cpu/test_kalman_fitter_telescope.cpp b/tests/cpu/test_kalman_fitter_telescope.cpp index 19e2b24fc..5cc3e4e80 100644 --- a/tests/cpu/test_kalman_fitter_telescope.cpp +++ b/tests/cpu/test_kalman_fitter_telescope.cpp @@ -123,7 +123,7 @@ TEST_P(KalmanFittingTelescopeTests, Run) { fit_cfg.propagation.navigation.overstep_tolerance = -100.f * unit::um; fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit::mm; - traccc::host::kalman_fitting_algorithm fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm fitting(fit_cfg, host_mr); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) { diff --git a/tests/cpu/test_kalman_fitter_wire_chamber.cpp b/tests/cpu/test_kalman_fitter_wire_chamber.cpp index 9a82e01d2..952a4e389 100644 --- a/tests/cpu/test_kalman_fitter_wire_chamber.cpp +++ b/tests/cpu/test_kalman_fitter_wire_chamber.cpp @@ -124,7 +124,7 @@ TEST_P(KalmanFittingWireChamberTests, Run) { static_cast(mask_tolerance); fit_cfg.propagation.navigation.search_window = search_window; fit_cfg.ptc_hypothesis = ptc; - traccc::host::kalman_fitting_algorithm fitting(fit_cfg); + traccc::host::kalman_fitting_algorithm fitting(fit_cfg, host_mr); // Iterate over events for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) {