Skip to content

Commit

Permalink
Made all templated code of track fitting public.
Browse files Browse the repository at this point in the history
This way clients can have access to the full details of the
code if they want to, while also getting algorithms that would
perform fitting in one very specific way.
  • Loading branch information
krasznaa committed Nov 12, 2024
1 parent 03f61e0 commit da27835
Show file tree
Hide file tree
Showing 19 changed files with 57 additions and 49 deletions.
2 changes: 1 addition & 1 deletion benchmarks/cpu/toy_detector_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
traccc::track_params_estimation tp(host_mr);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(
finding_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fitting_cfg, host_mr);

for (auto _ : state) {

Expand Down
2 changes: 1 addition & 1 deletion core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ traccc_add_library( traccc_core core TYPE SHARED
"include/traccc/fitting/kalman_filter/kalman_fitter.hpp"
"include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
"include/traccc/fitting/kalman_filter/statistics_updater.hpp"
"src/fitting/fit_tracks.hpp"
"include/traccc/fitting/details/fit_tracks.hpp"
"include/traccc/fitting/kalman_fitting_algorithm.hpp"
"src/fitting/kalman_fitting_algorithm.cpp"
"src/fitting/kalman_fitting_algorithm_constant_field_default_detector.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_state.hpp"

// VecMem include(s).
#include <vecmem/memory/memory_resource.hpp>

namespace traccc::host::details {

/// Templated implementation of the track fitting algorithm.
Expand All @@ -19,27 +22,26 @@ namespace traccc::host::details {
/// specializations, to fit tracks on top of a specific detector type, magnetic
/// field type, and track fitting configuration.
///
/// @note The memory resource received by this function is not used thoroughly
/// for the setup of the output container. Inner vectors in the output's
/// jagged vector are created using the default memory resource.
///
/// @tparam fitter_t The fitter type used for the track fitting
///
/// @param det The detector object
/// @param field The magnetic field object
/// @param track_candidates All track candidates to fit
/// @param config The track fitting configuration
/// @param[in] fitter The fitter object to use on the track candidates
/// @param[in] track_candidates All track candidates to fit
/// @param[in] mr Memory resource to use for the output container
///
/// @return A container of the fitted track states
///
template <typename fitter_t>
track_state_container_types::host fit_tracks(
const typename fitter_t::detector_type& det,
const typename fitter_t::bfield_type& field,
fitter_t& fitter,
const track_candidate_container_types::const_view& track_candidates_view,
const typename fitter_t::config_type& config) {

// Create the fitter object.
fitter_t fitter(det, field, config);
vecmem::memory_resource& mr) {

// Output container.
track_state_container_types::host output_states;
// Create the output container.
track_state_container_types::host result{&mr};

// Iterate over the tracks,
const track_candidate_container_types::const_device track_candidates{
Expand All @@ -62,13 +64,13 @@ track_state_container_types::host fit_tracks(
fitter.fit(track_candidates.get_headers()[i], fitter_state);

// Save the results into the output container.
output_states.push_back(
result.push_back(
std::move(fitter_state.m_fit_res),
std::move(fitter_state.m_fit_actor_state.m_track_states));
}

// Return the fitted track states.
return output_states;
return result;
}

} // namespace traccc::host::details
11 changes: 10 additions & 1 deletion core/include/traccc/fitting/kalman_fitting_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
// Detray include(s).
#include <detray/detectors/bfield.hpp>

// VecMem include(s).
#include <vecmem/memory/memory_resource.hpp>

// System include(s).
#include <functional>

namespace traccc::host {

/// Kalman filter based track fitting algorithm
Expand All @@ -40,7 +46,8 @@ class kalman_fitting_algorithm
///
/// @param config The configuration object
///
kalman_fitting_algorithm(const config_type& config);
explicit kalman_fitting_algorithm(const config_type& config,
vecmem::memory_resource& mr);

/// Execute the algorithm
///
Expand Down Expand Up @@ -71,6 +78,8 @@ class kalman_fitting_algorithm
private:
/// Algorithm configuration
config_type m_config;
/// Memory resource to use in the algorithm
std::reference_wrapper<vecmem::memory_resource> m_mr;

}; // class kalman_fitting_algorithm

Expand Down
5 changes: 3 additions & 2 deletions core/src/fitting/kalman_fitting_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

namespace traccc::host {

kalman_fitting_algorithm::kalman_fitting_algorithm(const config_type& config)
: m_config(config) {}
kalman_fitting_algorithm::kalman_fitting_algorithm(const config_type& config,
vecmem::memory_resource& mr)
: m_config{config}, m_mr{mr} {}

} // namespace traccc::host
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

// Project include(s).
#include "fit_tracks.hpp"
#include "traccc/fitting/details/fit_tracks.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/fitting/kalman_fitting_algorithm.hpp"

Expand All @@ -21,18 +21,16 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()(
const detray::bfield::const_field_t::view_t& field,
const track_candidate_container_types::const_view& track_candidates) const {

// Set up the fitter type(s).
using stepper_type =
// Create the fitter object.
kalman_fitter<
detray::rk_stepper<detray::bfield::const_field_t::view_t,
traccc::default_detector::host::algebra_type,
detray::constrained_step<>>;
using navigator_type =
detray::navigator<const traccc::default_detector::host>;
using fitter_type = kalman_fitter<stepper_type, navigator_type>;
detray::constrained_step<>>,
detray::navigator<const traccc::default_detector::host>>
fitter{det, field, m_config};

// Perform the track fitting using a common, templated function.
return details::fit_tracks<fitter_type>(det, field, track_candidates,
m_config);
return details::fit_tracks(fitter, track_candidates, m_mr.get());
}

} // namespace traccc::host
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

// Project include(s).
#include "fit_tracks.hpp"
#include "traccc/fitting/details/fit_tracks.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/fitting/kalman_fitting_algorithm.hpp"

Expand All @@ -21,18 +21,16 @@ kalman_fitting_algorithm::output_type kalman_fitting_algorithm::operator()(
const detray::bfield::const_field_t::view_t& field,
const track_candidate_container_types::const_view& track_candidates) const {

// Set up the fitter type(s).
using stepper_type =
// Create the fitter object.
kalman_fitter<
detray::rk_stepper<detray::bfield::const_field_t::view_t,
traccc::telescope_detector::host::algebra_type,
detray::constrained_step<>>;
using navigator_type =
detray::navigator<const traccc::telescope_detector::host>;
using fitter_type = kalman_fitter<stepper_type, navigator_type>;
detray::constrained_step<>>,
detray::navigator<const traccc::telescope_detector::host>>
fitter{det, field, m_config};

// Perform the track fitting using a common, templated function.
return details::fit_tracks<fitter_type>(det, field, track_candidates,
m_config);
return details::fit_tracks(fitter, track_candidates, m_mr.get());
}

} // namespace traccc::host
2 changes: 1 addition & 1 deletion examples/run/cpu/full_chain_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ full_chain_algorithm::full_chain_algorithm(
m_seeding(finder_config, grid_config, filter_config, mr),
m_track_parameter_estimation(mr),
m_finding(finding_config),
m_fitting(fitting_config),
m_fitting(fitting_config, mr),
m_finder_config(finder_config),
m_grid_config(grid_config),
m_filter_config(filter_config),
Expand Down
2 changes: 1 addition & 1 deletion examples/run/cpu/seeding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
traccc::fitting_config fit_cfg;
fit_cfg.propagation = propagation_config;

traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr);

traccc::greedy_ambiguity_resolution_algorithm host_ambiguity_resolution{};

Expand Down
2 changes: 1 addition & 1 deletion examples/run/cpu/seq_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ int seq_run(const traccc::opts::input_data& input_opts,
seeding_opts.seedfilter, host_mr);
traccc::track_params_estimation tp(host_mr);
finding_algorithm finding_alg(finding_cfg);
fitting_algorithm fitting_alg(fitting_cfg);
fitting_algorithm fitting_alg(fitting_cfg, host_mr);
traccc::greedy_ambiguity_resolution_algorithm resolution_alg;

// performance writer
Expand Down
2 changes: 1 addition & 1 deletion examples/run/cpu/truth_finding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
traccc::fitting_config fit_cfg;
fit_cfg.propagation = propagation_config;

traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr);

// Seed generator
traccc::seed_generator<traccc::default_detector::host> sg(detector,
Expand Down
2 changes: 1 addition & 1 deletion examples/run/cpu/truth_fitting_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ int main(int argc, char* argv[]) {
traccc::fitting_config fit_cfg;
fit_cfg.propagation = propagation_opts;

traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr);

// Seed generator
traccc::seed_generator<host_detector_type> sg(host_det, stddevs);
Expand Down
2 changes: 1 addition & 1 deletion examples/run/cuda/seeding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
traccc::fitting_config fit_cfg;
fit_cfg.propagation = propagation_config;

traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr);
traccc::cuda::fitting_algorithm<device_fitter_type> device_fitting(
fit_cfg, mr, async_copy, stream);

Expand Down
2 changes: 1 addition & 1 deletion examples/run/cuda/seq_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ int seq_run(const traccc::opts::detector& detector_opts,
seeding_opts.seedfilter, host_mr);
traccc::track_params_estimation tp(host_mr);
host_finding_algorithm finding_alg(finding_cfg);
host_fitting_algorithm fitting_alg(fitting_cfg);
host_fitting_algorithm fitting_alg(fitting_cfg, host_mr);

traccc::cuda::clusterization_algorithm ca_cuda(mr, copy, stream,
clusterization_opts);
Expand Down
2 changes: 1 addition & 1 deletion examples/run/cuda/truth_finding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
traccc::fitting_config fit_cfg;
fit_cfg.propagation = propagation_config;

traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr);
traccc::cuda::fitting_algorithm<device_fitter_type> device_fitting(
fit_cfg, mr, async_copy, stream);

Expand Down
2 changes: 1 addition & 1 deletion examples/run/cuda/truth_fitting_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ int main(int argc, char* argv[]) {
traccc::fitting_config fit_cfg;
fit_cfg.propagation = propagation_opts;

traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr);
traccc::cuda::fitting_algorithm<device_fitter_type> device_fitting(
fit_cfg, mr, async_copy, stream);

Expand Down
2 changes: 1 addition & 1 deletion tests/cpu/test_ckf_sparse_tracks_telescope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) {
fit_cfg.propagation.navigation.overstep_tolerance =
-100.f * unit<float>::um;
fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit<float>::mm;
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm host_fitting(fit_cfg, host_mr);

// Iterate over events
for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) {
Expand Down
2 changes: 1 addition & 1 deletion tests/cpu/test_kalman_fitter_telescope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ TEST_P(KalmanFittingTelescopeTests, Run) {
fit_cfg.propagation.navigation.overstep_tolerance =
-100.f * unit<float>::um;
fit_cfg.propagation.navigation.max_mask_tolerance = 1.f * unit<float>::mm;
traccc::host::kalman_fitting_algorithm fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm fitting(fit_cfg, host_mr);

// Iterate over events
for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) {
Expand Down
2 changes: 1 addition & 1 deletion tests/cpu/test_kalman_fitter_wire_chamber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ TEST_P(KalmanFittingWireChamberTests, Run) {
static_cast<float>(mask_tolerance);
fit_cfg.propagation.navigation.search_window = search_window;
fit_cfg.ptc_hypothesis = ptc;
traccc::host::kalman_fitting_algorithm fitting(fit_cfg);
traccc::host::kalman_fitting_algorithm fitting(fit_cfg, host_mr);

// Iterate over events
for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) {
Expand Down

0 comments on commit da27835

Please sign in to comment.