Skip to content

Commit

Permalink
graph_traversal: allow passing multiple sources to bfs()
Browse files Browse the repository at this point in the history
  • Loading branch information
yut23 committed Dec 17, 2024
1 parent 24338f8 commit ae0e6bc
Showing 1 changed file with 75 additions and 13 deletions.
88 changes: 75 additions & 13 deletions aoc_lib/src/graph_traversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ concept Heuristic = requires(Func heuristic, const Key &key) {
template <class T>
concept FuncPassed = !std::same_as<T, optional_func>;

template <class T>
concept AnySourceCollection =
util::concepts::any_iterable_collection<T, typename T::value_type>;

} // namespace detail

/**
Expand All @@ -141,19 +145,21 @@ concept FuncPassed = !std::same_as<T, optional_func>;
* visited as a tree). `use_seen` should only be set to false if the graph has
* no cycles.
*
* Returns the distance from the source to the first target found, or -1 if not
* found.
* Returns the distance from the source(s) to the first target found, or -1 if
* not found.
*/
template <bool use_seen = true, class Key,
template <bool use_seen = true, detail::AnySourceCollection ASC,
class Key = ASC::value_type,
detail::ProcessNeighbors<Key> ProcessNeighbors,
detail::IsTarget<Key> IsTarget = detail::optional_func,
detail::Visit<Key> Visit = detail::optional_func>
int bfs(const Key &source, ProcessNeighbors &&process_neighbors,
int bfs(const ASC &sources, ProcessNeighbors &&process_neighbors,
IsTarget &&is_target, Visit &&visit) {
static_assert(detail::FuncPassed<IsTarget> || detail::FuncPassed<Visit>,
"is_target and visit must not both be defaulted");
using visit_ret_t = typename detail::visit_invoke_result<Key, Visit>::type;
detail::maybe_unordered_set<Key> queue = {source};
detail::maybe_unordered_set<Key> queue = {std::begin(sources),
std::end(sources)};
detail::maybe_unordered_set<Key> next_queue{};
detail::maybe_unordered_set<Key> seen{};

Expand Down Expand Up @@ -189,6 +195,32 @@ int bfs(const Key &source, ProcessNeighbors &&process_neighbors,
return -1;
}

// specialization for std::initializer_list, so deduction of `bfs({source,
// source}, ...)` works
template <bool use_seen = true, class Key,
detail::ProcessNeighbors<Key> ProcessNeighbors,
detail::IsTarget<Key> IsTarget = detail::optional_func,
detail::Visit<Key> Visit = detail::optional_func>
int bfs(const std::initializer_list<Key> &sources,
ProcessNeighbors &&process_neighbors, IsTarget &&is_target,
Visit &&visit) {
// explicitly specify the ASC template argument to avoid recursion
return bfs<use_seen, std::initializer_list<Key>>(
sources, std::forward<ProcessNeighbors>(process_neighbors),
std::forward<IsTarget>(is_target), std::forward<Visit>(visit));
}

template <bool use_seen = true, class Key,
detail::ProcessNeighbors<Key> ProcessNeighbors,
detail::IsTarget<Key> IsTarget = detail::optional_func,
detail::Visit<Key> Visit = detail::optional_func>
int bfs(const Key &source, ProcessNeighbors &&process_neighbors,
IsTarget &&is_target, Visit &&visit) {
return bfs<use_seen, std::initializer_list<Key>>(
{source}, std::forward<ProcessNeighbors>(process_neighbors),
std::forward<IsTarget>(is_target), std::forward<Visit>(visit));
}

/**
* Generic BFS on an arbitrary graph, with no duplicate checking.
*
Expand Down Expand Up @@ -422,12 +454,10 @@ struct tarjan_entry {
* Components are returned in topological order, along with a set of the
* directed edges between the components.
*/
template <class Key,
util::concepts::any_iterable_collection<Key> SourceCollection,
template <detail::AnySourceCollection ASC, class Key = ASC::value_type,
detail::ProcessNeighbors<Key> ProcessNeighbors>
std::pair<std::vector<std::vector<Key>>, std::set<std::pair<int, int>>>
tarjan_scc(const SourceCollection &sources,
ProcessNeighbors &&process_neighbors) {
tarjan_scc(const ASC &sources, ProcessNeighbors &&process_neighbors) {
int index = 0;
std::stack<Key> S{};
std::vector<std::vector<Key>> components{};
Expand Down Expand Up @@ -517,12 +547,20 @@ tarjan_scc(const SourceCollection &sources,
return {std::move(components), std::move(reversed_links)};
}

template <class Key, detail::ProcessNeighbors<Key> ProcessNeighbors>
std::pair<std::vector<std::vector<Key>>, std::set<std::pair<int, int>>>
tarjan_scc(const std::initializer_list<Key> &sources,
ProcessNeighbors &&process_neighbors) {
return tarjan_scc<std::initializer_list<Key>>(
sources, std::forward<ProcessNeighbors>(process_neighbors));
}

template <class Key, detail::ProcessNeighbors<Key> ProcessNeighbors>
std::pair<std::vector<std::vector<Key>>, std::set<std::pair<int, int>>>
tarjan_scc(const Key &source, ProcessNeighbors &&process_neighbors) {
const std::initializer_list<Key> sources = {source};
return tarjan_scc<Key>(sources,
std::forward<ProcessNeighbors>(process_neighbors));
return tarjan_scc(sources,
std::forward<ProcessNeighbors>(process_neighbors));
}

/**
Expand Down Expand Up @@ -804,6 +842,8 @@ void _lint_helper_template(
std::function<bool(const Key &, const Key &, int)> visit_with_parent_bool,
std::function<int(const Key &, const Key &)> get_distance,
std::function<int(const Key &)> heuristic) {
const std::vector<Key> sources_vec{source, source};

bfs<true>(source, process_neighbors, is_target, {});
bfs<true>(source, process_neighbors, {}, visit);
bfs<true>(source, process_neighbors, {}, visit_bool);
Expand All @@ -815,6 +855,28 @@ void _lint_helper_template(
bfs<false>(source, process_neighbors, is_target, visit);
bfs<false>(source, process_neighbors, is_target, visit_bool);

bfs<true>({source, source}, process_neighbors, is_target, {});
bfs<true>({source, source}, process_neighbors, {}, visit);
bfs<true>({source, source}, process_neighbors, {}, visit_bool);
bfs<true>({source, source}, process_neighbors, is_target, visit);
bfs<true>({source, source}, process_neighbors, is_target, visit_bool);
bfs<false>({source, source}, process_neighbors, is_target, {});
bfs<false>({source, source}, process_neighbors, {}, visit);
bfs<false>({source, source}, process_neighbors, {}, visit_bool);
bfs<false>({source, source}, process_neighbors, is_target, visit);
bfs<false>({source, source}, process_neighbors, is_target, visit_bool);

bfs<true>(sources_vec, process_neighbors, is_target, {});
bfs<true>(sources_vec, process_neighbors, {}, visit);
bfs<true>(sources_vec, process_neighbors, {}, visit_bool);
bfs<true>(sources_vec, process_neighbors, is_target, visit);
bfs<true>(sources_vec, process_neighbors, is_target, visit_bool);
bfs<false>(sources_vec, process_neighbors, is_target, {});
bfs<false>(sources_vec, process_neighbors, {}, visit);
bfs<false>(sources_vec, process_neighbors, {}, visit_bool);
bfs<false>(sources_vec, process_neighbors, is_target, visit);
bfs<false>(sources_vec, process_neighbors, is_target, visit_bool);

bfs_manual_dedupe(source, process_neighbors, is_target, {});
bfs_manual_dedupe(source, process_neighbors, {}, visit);
bfs_manual_dedupe(source, process_neighbors, {}, visit_bool);
Expand Down Expand Up @@ -847,8 +909,8 @@ void _lint_helper_template(
topo_sort(source, process_neighbors);

tarjan_scc(source, process_neighbors);
const std::vector<Key> sources{source, source};
tarjan_scc<Key>(sources, process_neighbors);
tarjan_scc({source, source}, process_neighbors);
tarjan_scc(sources_vec, process_neighbors);

longest_path_dag(source, process_neighbors, get_distance, is_target);

Expand Down

0 comments on commit ae0e6bc

Please sign in to comment.