Skip to content

Commit

Permalink
Update Host Track Finding, main branch (2024.11.11.) (#768)
Browse files Browse the repository at this point in the history
* Renamed the host CKF algorithm.

Giving the class and file names longer, hopefully more
expressive names.

* Made the templated track finding function public.

Allowing clients to use it with specializations not included
in traccc::core.
  • Loading branch information
krasznaa authored Nov 11, 2024
1 parent 8827300 commit df652bf
Show file tree
Hide file tree
Showing 19 changed files with 78 additions and 57 deletions.
5 changes: 3 additions & 2 deletions benchmarks/cpu/toy_detector_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "traccc/geometry/detector.hpp"

// Traccc algorithm include(s).
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/seeding/seeding_algorithm.hpp"
#include "traccc/seeding/track_params_estimation.hpp"
Expand Down Expand Up @@ -62,7 +62,8 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
// Algorithms
traccc::seeding_algorithm sa(seeding_cfg, grid_cfg, filter_cfg, host_mr);
traccc::track_params_estimation tp(host_mr);
traccc::host::ckf_algorithm host_finding(finding_cfg);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(
finding_cfg);
traccc::fitting_algorithm<host_fitter_type> host_fitting(fitting_cfg);

for (auto _ : state) {
Expand Down
10 changes: 5 additions & 5 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ traccc_add_library( traccc_core core TYPE SHARED
"include/traccc/finding/finding_config.hpp"
"include/traccc/finding/actors/ckf_aborter.hpp"
"include/traccc/finding/actors/interaction_register.hpp"
"src/finding/find_tracks.hpp"
"include/traccc/finding/ckf_algorithm.hpp"
"src/finding/ckf_algorithm.cpp"
"src/finding/ckf_algorithm_defdet_cfield.cpp"
"src/finding/ckf_algorithm_teldet_cfield.cpp"
"include/traccc/finding/details/find_tracks.hpp"
"include/traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
"src/finding/combinatorial_kalman_filter_algorithm.cpp"
"src/finding/combinatorial_kalman_filter_algorithm_constant_field_default_detector.cpp"
"src/finding/combinatorial_kalman_filter_algorithm_constant_field_telescope_detector.cpp"
# Fitting algorithmic code
"include/traccc/fitting/kalman_filter/gain_matrix_smoother.hpp"
"include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace traccc::host {
/// This is the main host-based track finding algorithm of the project. More
/// documentation to be written later...
///
class ckf_algorithm
class combinatorial_kalman_filter_algorithm
: public algorithm<track_candidate_container_types::host(
const default_detector::host&,
const detray::bfield::const_field_t::view_t&,
Expand All @@ -44,7 +44,7 @@ class ckf_algorithm
using output_type = track_candidate_container_types::host;

/// Constructor with the algorithm's configuration
ckf_algorithm(const config_type& config);
explicit combinatorial_kalman_filter_algorithm(const config_type& config);

/// Execute the algorithm
///
Expand Down Expand Up @@ -84,6 +84,6 @@ class ckf_algorithm
/// Algorithm configuration
config_type m_config;

}; // class ckf_algorithm
}; // class combinatorial_kalman_filter_algorithm

} // namespace traccc::host
File renamed without changes.
15 changes: 0 additions & 15 deletions core/src/finding/ckf_algorithm.cpp

This file was deleted.

28 changes: 28 additions & 0 deletions core/src/finding/combinatorial_kalman_filter_algorithm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// Local include(s).
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"

// System include(s).
#include <stdexcept>

namespace traccc::host {

combinatorial_kalman_filter_algorithm::combinatorial_kalman_filter_algorithm(
const config_type& config)
: m_config{config} {

// Check the configuration.
if (m_config.min_track_candidates_per_track == 0) {
throw std::invalid_argument(
"The minimum number of track candidates per track must be at least "
"1.");
}
}

} // namespace traccc::host
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
*/

// Local include(s).
#include "find_tracks.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/finding/details/find_tracks.hpp"

// Detray include(s).
#include <detray/core/detector.hpp>
#include <detray/detectors/bfield.hpp>
#include <detray/navigation/navigator.hpp>
#include <detray/propagator/propagator.hpp>
#include <detray/propagator/rk_stepper.hpp>

namespace traccc::host {

ckf_algorithm::output_type ckf_algorithm::operator()(
combinatorial_kalman_filter_algorithm::output_type
combinatorial_kalman_filter_algorithm::operator()(
const default_detector::host& det,
const detray::bfield::const_field_t::view_t& field,
const measurement_collection_types::const_view& measurements,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
*/

// Local include(s).
#include "find_tracks.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/finding/details/find_tracks.hpp"

// Detray include(s).
#include <detray/core/detector.hpp>
#include <detray/detectors/bfield.hpp>
#include <detray/navigation/navigator.hpp>
#include <detray/propagator/propagator.hpp>
#include <detray/propagator/rk_stepper.hpp>

namespace traccc::host {

ckf_algorithm::output_type ckf_algorithm::operator()(
combinatorial_kalman_filter_algorithm::output_type
combinatorial_kalman_filter_algorithm::operator()(
const telescope_detector::host& det,
const detray::bfield::const_field_t::view_t& field,
const measurement_collection_types::const_view& measurements,
Expand Down
5 changes: 3 additions & 2 deletions examples/run/cpu/full_chain_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "traccc/clusterization/clusterization_algorithm.hpp"
#include "traccc/edm/silicon_cell_collection.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/geometry/detector.hpp"
Expand Down Expand Up @@ -63,7 +63,8 @@ class full_chain_algorithm : public algorithm<track_state_container_types::host(
using spacepoint_formation_algorithm =
traccc::host::silicon_pixel_spacepoint_formation_algorithm;
/// Track finding algorithm type
using finding_algorithm = traccc::host::ckf_algorithm;
using finding_algorithm =
traccc::host::combinatorial_kalman_filter_algorithm;
/// Track fitting algorithm type
using fitting_algorithm = traccc::fitting_algorithm<
traccc::kalman_fitter<stepper_type, navigator_type>>;
Expand Down
4 changes: 2 additions & 2 deletions examples/run/cpu/seeding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

// algorithms
#include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/seeding/seeding_algorithm.hpp"
#include "traccc/seeding/track_params_estimation.hpp"
Expand Down Expand Up @@ -129,7 +129,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
traccc::finding_config cfg(finding_opts);
cfg.propagation = propagation_config;

traccc::host::ckf_algorithm host_finding(cfg);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg);

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
Expand Down
5 changes: 3 additions & 2 deletions examples/run/cpu/seq_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// algorithms
#include "traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
#include "traccc/clusterization/clusterization_algorithm.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/seeding/seeding_algorithm.hpp"
#include "traccc/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
Expand Down Expand Up @@ -106,7 +106,8 @@ int seq_run(const traccc::opts::input_data& input_opts,
detray::constrained_step<>>;
using navigator_type =
detray::navigator<const traccc::default_detector::host>;
using finding_algorithm = traccc::host::ckf_algorithm;
using finding_algorithm =
traccc::host::combinatorial_kalman_filter_algorithm;
using fitting_algorithm = traccc::fitting_algorithm<
traccc::kalman_fitter<stepper_type, navigator_type>>;

Expand Down
4 changes: 2 additions & 2 deletions examples/run/cpu/truth_finding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "traccc/definitions/common.hpp"
#include "traccc/definitions/primitives.hpp"
#include "traccc/efficiency/finding_performance_writer.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/io/read_detector.hpp"
Expand Down Expand Up @@ -109,7 +109,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
cfg.propagation = propagation_config;

// Finding algorithm object
traccc::host::ckf_algorithm host_finding(cfg);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg);

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
Expand Down
4 changes: 2 additions & 2 deletions examples/run/cuda/seeding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "traccc/efficiency/nseed_performance_writer.hpp"
#include "traccc/efficiency/seeding_performance_writer.hpp"
#include "traccc/efficiency/track_filter.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/io/read_detector.hpp"
#include "traccc/io/read_detector_description.hpp"
Expand Down Expand Up @@ -174,7 +174,7 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
cfg.propagation = propagation_config;

// Finding algorithm object
traccc::host::ckf_algorithm host_finding(cfg);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg);
traccc::cuda::finding_algorithm<rk_stepper_type, device_navigator_type>
device_finding(cfg, mr, async_copy, stream);

Expand Down
5 changes: 3 additions & 2 deletions examples/run/cuda/seq_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "traccc/cuda/utils/stream.hpp"
#include "traccc/device/container_d2h_copy_alg.hpp"
#include "traccc/efficiency/seeding_performance_writer.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/io/read_cells.hpp"
#include "traccc/io/read_detector.hpp"
Expand Down Expand Up @@ -136,7 +136,8 @@ int seq_run(const traccc::opts::detector& detector_opts,
using device_navigator_type =
detray::navigator<const traccc::default_detector::device>;

using host_finding_algorithm = traccc::host::ckf_algorithm;
using host_finding_algorithm =
traccc::host::combinatorial_kalman_filter_algorithm;
using device_finding_algorithm =
traccc::cuda::finding_algorithm<stepper_type, device_navigator_type>;

Expand Down
7 changes: 4 additions & 3 deletions examples/run/cuda/truth_finding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "traccc/device/container_d2h_copy_alg.hpp"
#include "traccc/device/container_h2d_copy_alg.hpp"
#include "traccc/efficiency/finding_performance_writer.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/io/read_detector.hpp"
Expand Down Expand Up @@ -151,7 +151,7 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
cfg.propagation = propagation_config;

// Finding algorithm object
traccc::host::ckf_algorithm host_finding(cfg);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg);
traccc::cuda::finding_algorithm<rk_stepper_type, device_navigator_type>
device_finding(cfg, mr, async_copy, stream);

Expand Down Expand Up @@ -243,7 +243,8 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
track_state_d2h(track_states_cuda_buffer);

// CPU containers
traccc::host::ckf_algorithm::output_type track_candidates;
traccc::host::combinatorial_kalman_filter_algorithm::output_type
track_candidates;
traccc::fitting_algorithm<host_fitter_type>::output_type track_states;

if (accelerator_opts.compare_with_cpu) {
Expand Down
5 changes: 3 additions & 2 deletions examples/run/sycl/full_chain_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

// Project include(s).
#include "traccc/edm/silicon_cell_collection.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
#include "traccc/geometry/detector.hpp"
Expand Down Expand Up @@ -73,7 +73,8 @@ class full_chain_algorithm
/// Clustering algorithm type
using clustering_algorithm = clusterization_algorithm;
/// Track finding algorithm type
using finding_algorithm = traccc::host::ckf_algorithm;
using finding_algorithm =
traccc::host::combinatorial_kalman_filter_algorithm;
/// Track fitting algorithm type
using fitting_algorithm = traccc::fitting_algorithm<
traccc::kalman_fitter<stepper_type, navigator_type>>;
Expand Down
8 changes: 5 additions & 3 deletions tests/cpu/test_ckf_combinatorics_telescope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

// Project include(s).
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/io/read_measurements.hpp"
#include "traccc/io/utils.hpp"
Expand Down Expand Up @@ -132,8 +132,10 @@ TEST_P(CpuCkfCombinatoricsTelescopeTests, Run) {
cfg_limit.propagation.navigation.max_mask_tolerance = 1.f * unit<float>::mm;

// Finding algorithm object
traccc::host::ckf_algorithm host_finding(cfg_no_limit);
traccc::host::ckf_algorithm host_finding_limit(cfg_limit);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(
cfg_no_limit);
traccc::host::combinatorial_kalman_filter_algorithm host_finding_limit(
cfg_limit);

// Iterate over events
for (std::size_t i_evt = 0; i_evt < n_events; i_evt++) {
Expand Down
4 changes: 2 additions & 2 deletions tests/cpu/test_ckf_sparse_tracks_telescope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

// Project include(s).
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/fitting_algorithm.hpp"
#include "traccc/io/read_measurements.hpp"
#include "traccc/io/utils.hpp"
Expand Down Expand Up @@ -128,7 +128,7 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) {
cfg.propagation.navigation.max_mask_tolerance = 1.f * unit<float>::mm;

// Finding algorithm object
traccc::host::ckf_algorithm host_finding(cfg);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg);

// Fitting algorithm object
typename traccc::fitting_algorithm<host_fitter_type>::config_type fit_cfg;
Expand Down
4 changes: 2 additions & 2 deletions tests/cuda/test_ckf_toy_detector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "traccc/cuda/finding/finding_algorithm.hpp"
#include "traccc/device/container_d2h_copy_alg.hpp"
#include "traccc/device/container_h2d_copy_alg.hpp"
#include "traccc/finding/ckf_algorithm.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/io/read_measurements.hpp"
#include "traccc/io/utils.hpp"
#include "traccc/performance/container_comparator.hpp"
Expand Down Expand Up @@ -143,7 +143,7 @@ TEST_P(CkfToyDetectorTests, Run) {
cfg.propagation.navigation.search_window = search_window;

// Finding algorithm object
traccc::host::ckf_algorithm host_finding(cfg);
traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg);

// Finding algorithm object
traccc::cuda::finding_algorithm<rk_stepper_type, device_navigator_type>
Expand Down

0 comments on commit df652bf

Please sign in to comment.