From d406eb156abaa736a94af4b3d6fd7214ee9d95a7 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Fri, 13 Dec 2024 13:10:35 +0100 Subject: [PATCH] [CP-SAT] speed up no_overlap_2d (presolve, propagation); tweak shared tree workers; improve hint preservation during presolve; remove memory contention --- ortools/sat/BUILD.bazel | 8 +- ortools/sat/cp_model_checker.cc | 35 +- ortools/sat/cp_model_lns.cc | 14 +- ortools/sat/cp_model_lns.h | 12 +- ortools/sat/cp_model_presolve.cc | 373 ++++++++++++++----- ortools/sat/cp_model_presolve.h | 22 +- ortools/sat/cp_model_solver.cc | 42 ++- ortools/sat/diffn.cc | 23 +- ortools/sat/diffn_util.cc | 286 +++++++------- ortools/sat/diffn_util.h | 24 +- ortools/sat/diffn_util_test.cc | 66 +++- ortools/sat/disjunctive.cc | 8 + ortools/sat/feasibility_jump.cc | 8 +- ortools/sat/integer_expr.cc | 4 +- ortools/sat/integer_expr.h | 108 ++---- ortools/sat/integer_expr_test.cc | 53 ++- ortools/sat/linear_programming_constraint.cc | 30 +- ortools/sat/linear_programming_constraint.h | 5 +- ortools/sat/presolve_context.cc | 126 ++++--- ortools/sat/presolve_context.h | 37 +- ortools/sat/presolve_context_test.cc | 6 +- ortools/sat/rins.cc | 21 +- ortools/sat/sat_decision.cc | 11 +- ortools/sat/sat_parameters.proto | 8 +- ortools/sat/synchronization.cc | 33 +- ortools/sat/synchronization.h | 46 ++- 26 files changed, 889 insertions(+), 520 deletions(-) diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index fd455b66e1..3753bf305a 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -196,6 +196,7 @@ cc_library( "//ortools/port:proto_utils", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2926,9 +2927,7 @@ cc_library( "//ortools/base", "//ortools/base:stl_util", "//ortools/base:strong_vector", - "//ortools/graph", "//ortools/graph:connected_components", - "//ortools/graph:minimum_spanning_tree", "//ortools/graph:strongly_connected_components", "//ortools/util:fixed_shape_binary_tree", "//ortools/util:integer_pq", @@ -3110,9 +3109,11 @@ cc_test( ":integer_base", ":util", "//ortools/base", - "//ortools/base:gmock_main", + "//ortools/base:fuzztest", + "//ortools/base:gmock", "//ortools/graph:connected_components", "//ortools/graph:strongly_connected_components", + "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3123,6 +3124,7 @@ cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_benchmark//:benchmark", + "@com_google_fuzztest//fuzztest:fuzztest_gtest_main", ], ) diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 40bca43a50..d9d8d0a141 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -17,11 +17,13 @@ #include #include #include +#include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -1420,6 +1422,7 @@ class ConstraintChecker { const auto& arg = ct.no_overlap_2d(); // Those intervals from arg.x_intervals and arg.y_intervals where both // the x and y intervals are enforced. + bool has_zero_sizes = false; std::vector enforced_rectangles; { const int num_intervals = arg.x_intervals_size(); @@ -1432,16 +1435,34 @@ class ConstraintChecker { .x_max = IntervalEnd(x.interval()), .y_min = IntervalStart(y.interval()), .y_max = IntervalEnd(y.interval())}); + const auto& rect = enforced_rectangles.back(); + if (rect.x_min == rect.x_max || rect.y_min == rect.y_max) { + has_zero_sizes = true; + } } } } - const std::vector> intersections = - FindPartialRectangleIntersectionsAlsoEmpty(enforced_rectangles); - if (!intersections.empty()) { - VLOG(1) << "Rectangles " << intersections[0].first << "(" - << enforced_rectangles[intersections[0].first] << ") and " - << intersections[0].second << "(" - << enforced_rectangles[intersections[0].second] + + std::optional> one_intersection; + if (!has_zero_sizes) { + absl::c_stable_sort(enforced_rectangles, + [](const Rectangle& a, const Rectangle& b) { + return a.x_min < b.x_min; + }); + one_intersection = FindOneIntersectionIfPresent(enforced_rectangles); + } else { + const std::vector> intersections = + FindPartialRectangleIntersections(enforced_rectangles); + if (!intersections.empty()) { + one_intersection = intersections[0]; + } + } + + if (one_intersection != std::nullopt) { + VLOG(1) << "Rectangles " << one_intersection->first << "(" + << enforced_rectangles[one_intersection->first] << ") and " + << one_intersection->second << "(" + << enforced_rectangles[one_intersection->second] << ") are not disjoint."; return false; } diff --git a/ortools/sat/cp_model_lns.cc b/ortools/sat/cp_model_lns.cc index 5a76d03b37..9c22c5950f 100644 --- a/ortools/sat/cp_model_lns.cc +++ b/ortools/sat/cp_model_lns.cc @@ -471,6 +471,7 @@ NeighborhoodGeneratorHelper::GetActiveRectangles( } std::vector results; + results.reserve(active_rectangles.size()); for (const auto& [rectangle, no_overlap_2d_constraints] : active_rectangles) { ActiveRectangle& result = results.emplace_back(); result.x_interval = rectangle.first; @@ -532,7 +533,9 @@ void RestrictAffineExpression(const LinearExpressionProto& expr, const Domain domain = ReadDomainFromProto(mutable_proto->variables(expr.vars(0))) .IntersectionWith(implied_domain); - FillDomainInProto(domain, mutable_proto->mutable_variables(expr.vars(0))); + if (!domain.IsEmpty()) { + FillDomainInProto(domain, mutable_proto->mutable_variables(expr.vars(0))); + } } struct StartEndIndex { @@ -1034,7 +1037,14 @@ std::vector> NeighborhoodGeneratorHelper::GetRoutingPaths( Neighborhood NeighborhoodGeneratorHelper::FixGivenVariables( const CpSolverResponse& base_solution, const absl::flat_hash_set& variables_to_fix) const { - Neighborhood neighborhood; + int initial_num_variables = 0; + { + absl::ReaderMutexLock domain_lock(&domain_mutex_); + + initial_num_variables = + model_proto_with_only_variables_->variables().size(); + } + Neighborhood neighborhood(initial_num_variables); // TODO(user): Maybe relax all variables in the objective when the number // is small or negligible compared to the number of variables. diff --git a/ortools/sat/cp_model_lns.h b/ortools/sat/cp_model_lns.h index d72e6a034e..f9051366f7 100644 --- a/ortools/sat/cp_model_lns.h +++ b/ortools/sat/cp_model_lns.h @@ -44,6 +44,14 @@ namespace sat { // Neighborhood returned by Neighborhood generators. struct Neighborhood { + static constexpr int kDefaultArenaSizePerVariable = 128; + + explicit Neighborhood(int num_variables_hint = 10) + : arena_buffer(kDefaultArenaSizePerVariable * num_variables_hint), + arena(std::make_unique(arena_buffer.data(), + arena_buffer.size())), + delta(*google::protobuf::Arena::Create(arena.get())) {} + // True if neighborhood generator was able to generate a neighborhood. bool is_generated = false; @@ -58,7 +66,9 @@ struct Neighborhood { // The delta will contains all variables from the initial model, potentially // with updated domains. // It can contains new variables and new constraints, and solution hinting. - CpModelProto delta; + std::vector arena_buffer; + std::unique_ptr arena; + CpModelProto& delta; // Neighborhood Id. Used to identify the neighborhood by a generator. // Currently only used by WeightedRandomRelaxationNeighborhoodGenerator. diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 77d01a7483..d9b02452e7 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -7062,6 +7063,7 @@ void CpModelPresolver::RunPropagatorsForConstraint(const ConstraintProto& ct) { std::vector variable_mapping; CreateValidModelWithSingleConstraint(ct, context_, &variable_mapping, &tmp_model_); + DCHECK_EQ(ValidateCpModel(tmp_model_, false), ""); if (!LoadModelForPresolve(tmp_model_, std::move(local_params), context_, &model, "single constraint")) { return; @@ -8488,12 +8490,20 @@ bool CpModelPresolver::PresolveOneConstraint(int c) { DetectDuplicateIntervals(c, ct->mutable_no_overlap()->mutable_intervals()); return PresolveNoOverlap(ct); - case ConstraintProto::kNoOverlap2D: - DetectDuplicateIntervals( - c, ct->mutable_no_overlap_2d()->mutable_x_intervals()); - DetectDuplicateIntervals( - c, ct->mutable_no_overlap_2d()->mutable_y_intervals()); - return PresolveNoOverlap2D(c, ct); + case ConstraintProto::kNoOverlap2D: { + const bool changed = PresolveNoOverlap2D(c, ct); + if (ct->constraint_case() == ConstraintProto::kNoOverlap2D) { + // For 2D, we don't exploit index duplication between x/y so it is not + // important to do it beforehand. Moreover in some situation + // PresolveNoOverlap2D() remove a lot of interval, so better to do it + // afterwards. + DetectDuplicateIntervals( + c, ct->mutable_no_overlap_2d()->mutable_x_intervals()); + DetectDuplicateIntervals( + c, ct->mutable_no_overlap_2d()->mutable_y_intervals()); + } + return changed; + } case ConstraintProto::kCumulative: DetectDuplicateIntervals(c, ct->mutable_cumulative()->mutable_intervals()); @@ -12473,6 +12483,8 @@ bool ModelCopy::ImportAndSimplifyConstraints( // refer to interval before them. std::vector constraints_using_intervals; + interval_mapping_.assign(in_model.constraints().size(), -1); + starting_constraint_index_ = context_->working_model->constraints_size(); for (int c = 0; c < in_model.constraints_size(); ++c) { if (active_constraints != nullptr && !active_constraints(c)) continue; @@ -13005,9 +13017,9 @@ bool ModelCopy::CopyLinMax(const ConstraintProto& ct) { // Regroup all constant terms and copy the other. int64_t max_of_fixed_terms = std::numeric_limits::min(); for (const auto& expr : ct.lin_max().exprs()) { - if (context_->IsFixed(expr)) { - max_of_fixed_terms = - std::max(max_of_fixed_terms, context_->FixedValue(expr)); + const std::optional fixed = context_->FixedValueOrNullopt(expr); + if (fixed != std::nullopt) { + max_of_fixed_terms = std::max(max_of_fixed_terms, fixed.value()); } else { // copy. if (new_ct == nullptr) { @@ -13175,9 +13187,10 @@ void ModelCopy::CopyAndMapNoOverlap(const ConstraintProto& ct) { context_->working_model->add_constraints()->mutable_no_overlap(); new_ct->mutable_intervals()->Reserve(ct.no_overlap().intervals().size()); for (const int index : ct.no_overlap().intervals()) { - const auto it = interval_mapping_.find(index); - if (it == interval_mapping_.end()) continue; - new_ct->add_intervals(it->second); + const int new_index = interval_mapping_[index]; + if (new_index != -1) { + new_ct->add_intervals(new_index); + } } } @@ -13190,12 +13203,12 @@ void ModelCopy::CopyAndMapNoOverlap2D(const ConstraintProto& ct) { new_ct->mutable_x_intervals()->Reserve(num_intervals); new_ct->mutable_y_intervals()->Reserve(num_intervals); for (int i = 0; i < num_intervals; ++i) { - const auto x_it = interval_mapping_.find(ct.no_overlap_2d().x_intervals(i)); - if (x_it == interval_mapping_.end()) continue; - const auto y_it = interval_mapping_.find(ct.no_overlap_2d().y_intervals(i)); - if (y_it == interval_mapping_.end()) continue; - new_ct->add_x_intervals(x_it->second); - new_ct->add_y_intervals(y_it->second); + const int new_x = interval_mapping_[ct.no_overlap_2d().x_intervals(i)]; + if (new_x == -1) continue; + const int new_y = interval_mapping_[ct.no_overlap_2d().y_intervals(i)]; + if (new_y == -1) continue; + new_ct->add_x_intervals(new_x); + new_ct->add_y_intervals(new_y); } } @@ -13214,10 +13227,11 @@ bool ModelCopy::CopyAndMapCumulative(const ConstraintProto& ct) { new_ct->mutable_intervals()->Reserve(num_intervals); new_ct->mutable_demands()->Reserve(num_intervals); for (int i = 0; i < num_intervals; ++i) { - const auto it = interval_mapping_.find(ct.cumulative().intervals(i)); - if (it == interval_mapping_.end()) continue; - new_ct->add_intervals(it->second); - *new_ct->add_demands() = ct.cumulative().demands(i); + const int new_index = interval_mapping_[ct.cumulative().intervals(i)]; + if (new_index != -1) { + new_ct->add_intervals(new_index); + *new_ct->add_demands() = ct.cumulative().demands(i); + } } return true; @@ -13534,18 +13548,49 @@ CpSolverStatus CpModelPresolver::InfeasibleStatus() { return CpSolverStatus::INFEASIBLE; } -void CpModelPresolver::InitializeMappingModelVariables() { - // Sync the domains. - for (int i = 0; i < context_->working_model->variables_size(); ++i) { - FillDomainInProto(context_->DomainOf(i), - context_->working_model->mutable_variables(i)); - DCHECK_GT(context_->working_model->variables(i).domain_size(), 0); +// At the end of presolve, the mapping model is initialized to contains all +// the variable from the original model + the one created during presolve +// expand. It also contains the tightened domains. +namespace { +void InitializeMappingModelVariables(absl::Span domains, + std::vector* fixed_postsolve_mapping, + CpModelProto* mapping_proto) { + // Extend the fixed mapping to take into account all newly created variable + // since the time it was constructed. + int old_num_variables = mapping_proto->variables().size(); + while (fixed_postsolve_mapping->size() < domains.size()) { + mapping_proto->add_variables(); + fixed_postsolve_mapping->push_back(old_num_variables++); + DCHECK_EQ(old_num_variables, mapping_proto->variables().size()); + } + + // Overwrite the domains. + // + // Note that if the fixed_postsolve_mapping was not null, the mapping model + // should contains the original variable domains at the time the fixed mapping + // was computed. + for (int i = 0; i < domains.size(); ++i) { + FillDomainInProto(domains[i], mapping_proto->mutable_variables( + (*fixed_postsolve_mapping)[i])); } - // Set the variables of the mapping_model. - context_->mapping_model->mutable_variables()->CopyFrom( - context_->working_model->variables()); + // Remap the mapping proto. + // We only deal with constraint here, do not touch the rest. + // + // TODO(user): Maybe we should have a real "postsolve" proto so we can + // interleave postsolve "constraint" and remapping phase. This would allow to + // do that in the middle of the presolve. But maybe this is not as impactful. + auto mapping_function = [fixed_postsolve_mapping](int* ref) { + const int image = (*fixed_postsolve_mapping)[PositiveRef(*ref)]; + CHECK_GE(image, 0); + *ref = RefIsPositive(*ref) ? image : NegatedRef(image); + }; + for (ConstraintProto& ct_ref : *mapping_proto->mutable_constraints()) { + ApplyToAllVariableIndices(mapping_function, &ct_ref); + ApplyToAllLiteralIndices(mapping_function, &ct_ref); + } } +} // namespace void CpModelPresolver::ExpandCpModelAndCanonicalizeConstraints() { const int num_constraints_before_expansion = @@ -13575,6 +13620,49 @@ void CpModelPresolver::ExpandCpModelAndCanonicalizeConstraints() { } } +namespace { + +void UpdateHintInProto(PresolveContext* context) { + CpModelProto* proto = context->working_model; + if (!proto->has_solution_hint()) return; + + // Extract the new hint information from the context. + auto* mutable_hint = proto->mutable_solution_hint(); + mutable_hint->clear_vars(); + mutable_hint->clear_values(); + const int num_vars = context->working_model->variables().size(); + for (int hinted_var = 0; hinted_var < num_vars; ++hinted_var) { + if (!context->VarHasSolutionHint(hinted_var)) continue; + + // Note the use of ClampedSolutionHint() instead of SolutionHint() below. + // This also make sure a hint of INT_MIN or INT_MAX does not overflow. + // + // TODO(user): This should no longer be necessary, as we try to do that as + // soon as we update the domains, but we still do it to be safe. + int64_t hinted_value; + + // If the variable had a hint and has a representative with a hint, we also + // hint it using the representative value as a "ground truth". + const auto relation = context->GetAffineRelation(hinted_var); + if (relation.representative != hinted_var) { + // Lets first fetch the value of the representative. + const int rep = relation.representative; + if (!context->VarHasSolutionHint(rep)) continue; + const int64_t rep_value = context->ClampedSolutionHint(rep); + + // Apply the affine relation. + hinted_value = rep_value * relation.coeff + relation.offset; + } else { + hinted_value = context->ClampedSolutionHint(hinted_var); + } + + mutable_hint->add_vars(hinted_var); + mutable_hint->add_values(hinted_value); + } +} + +} // namespace + // The presolve works as follow: // // First stage: @@ -13591,15 +13679,7 @@ void CpModelPresolver::ExpandCpModelAndCanonicalizeConstraints() { // - Everything will be remapped so that only the variables appearing in some // constraints will be kept and their index will be in [0, num_new_variables). CpSolverStatus CpModelPresolver::Presolve() { - // We copy the search strategy to the mapping_model. - for (const auto& decision_strategy : - context_->working_model->search_strategy()) { - *(context_->mapping_model->add_search_strategy()) = decision_strategy; - } - - // Initialize the initial context.working_model domains. context_->InitializeNewDomains(); - context_->LoadSolutionHint(); // If the objective is a floating point one, we scale it. // @@ -13607,7 +13687,9 @@ CpSolverStatus CpModelPresolver::Presolve() { // just need to isolate more the "dual" reduction that usually need to look at // the objective. if (context_->working_model->has_floating_point_objective()) { - if (!context_->ScaleFloatingPointObjective()) { + context_->WriteVariableDomainsToProto(); + if (!ScaleFloatingPointObjective(context_->params(), logger_, + context_->working_model)) { SOLVER_LOG(logger_, "The floating point objective cannot be scaled with enough " "precision"); @@ -13622,13 +13704,29 @@ CpSolverStatus CpModelPresolver::Presolve() { context_->working_model->objective(); } + // If there is a large proprotion of fixed variable, lets remap the model + // before we start the actual presolve. This is useful for LNS in particular. + // + // fixed_postsolve_mapping[i] will contains the original index of the variable + // that will be at position i after MaybeRemoveFixedVariables(). If the + // mapping is left empty, it will be set to the identity mapping later by + // InitializeMappingModelVariables(). + std::vector fixed_postsolve_mapping; + if (!MaybeRemoveFixedVariables(&fixed_postsolve_mapping)) { + return InfeasibleStatus(); + } + + // Initialize the initial context.working_model domains. // Initialize the objective and the constraint <-> variable graph. // // Note that we did some basic presolving during the first copy of the model. // This is important has initializing the constraint <-> variable graph can // be costly, so better to remove trivially feasible constraint for instance. + context_->InitializeNewDomains(); + context_->LoadSolutionHint(); context_->ReadObjectiveFromProto(); if (!context_->CanonicalizeObjective()) return InfeasibleStatus(); + context_->UpdateNewConstraintsVariableUsage(); context_->RegisterVariablesUsedInAssumptions(); DCHECK(context_->ConstraintVariableUsageIsConsistent()); @@ -13658,7 +13756,10 @@ CpSolverStatus CpModelPresolver::Presolve() { // filling the tightened variables. Even without presolve, we do some // trivial presolving during the initial copy of the model, and expansion // might do more. - InitializeMappingModelVariables(); + context_->WriteVariableDomainsToProto(); + InitializeMappingModelVariables(context_->AllDomains(), + &fixed_postsolve_mapping, + context_->mapping_model); // We don't want to run postsolve when the presolve is disabled, but the // expansion might have added some constraints to the mapping model. To @@ -13942,7 +14043,10 @@ CpSolverStatus CpModelPresolver::Presolve() { } // Sync the domains and initialize the mapping model variables. - InitializeMappingModelVariables(); + context_->WriteVariableDomainsToProto(); + InitializeMappingModelVariables(context_->AllDomains(), + &fixed_postsolve_mapping, + context_->mapping_model); // Remove all the unused variables from the presolved model. postsolve_mapping_->clear(); @@ -13960,7 +14064,7 @@ CpSolverStatus CpModelPresolver::Presolve() { const int r = PositiveRef(context_->GetAffineRelation(i).representative); if (mapping[r] == -1 && !context_->VariableIsNotUsedAnymore(r)) { mapping[r] = postsolve_mapping_->size(); - postsolve_mapping_->push_back(r); + postsolve_mapping_->push_back(fixed_postsolve_mapping[r]); } continue; } @@ -13979,7 +14083,8 @@ CpSolverStatus CpModelPresolver::Presolve() { // We prefer to fix them to zero if possible. ++num_unused_variables; FillDomainInProto(Domain(context_->DomainOf(i).SmallestValue()), - context_->mapping_model->mutable_variables(i)); + context_->mapping_model->mutable_variables( + fixed_postsolve_mapping[i])); continue; } @@ -13995,7 +14100,7 @@ CpSolverStatus CpModelPresolver::Presolve() { } mapping[i] = postsolve_mapping_->size(); - postsolve_mapping_->push_back(i); + postsolve_mapping_->push_back(fixed_postsolve_mapping[i]); } context_->UpdateRuleStats(absl::StrCat("presolve: ", num_unused_variables, " unused variables removed.")); @@ -14017,7 +14122,11 @@ CpSolverStatus CpModelPresolver::Presolve() { } DCHECK(context_->ConstraintVariableUsageIsConsistent()); - ApplyVariableMapping(mapping, *context_); + UpdateHintInProto(context_); + const int old_size = postsolve_mapping_->size(); + ApplyVariableMapping(absl::MakeSpan(mapping), postsolve_mapping_, + context_->working_model); + CHECK_EQ(old_size, postsolve_mapping_->size()); // Compact all non-empty constraint at the beginning. RemoveEmptyConstraints(); @@ -14054,15 +14163,19 @@ CpSolverStatus CpModelPresolver::Presolve() { return CpSolverStatus::UNKNOWN; } -void ApplyVariableMapping(const std::vector& mapping, - const PresolveContext& context) { - CpModelProto* proto = context.working_model; - +void ApplyVariableMapping(absl::Span mapping, + std::vector* reverse_mapping, + CpModelProto* proto) { // Remap all the variable/literal references in the constraints and the // enforcement literals in the variables. - auto mapping_function = [&mapping](int* ref) { - const int image = mapping[PositiveRef(*ref)]; - CHECK_GE(image, 0); + auto mapping_function = [mapping, reverse_mapping](int* ref) mutable { + const int var = PositiveRef(*ref); + int image = mapping[var]; + if (image < 0) { + // We extend the mapping if this variable is still used. + image = mapping[var] = reverse_mapping->size(); + reverse_mapping->push_back(var); + } *ref = RefIsPositive(*ref) ? image : NegatedRef(image); }; for (ConstraintProto& ct_ref : *proto->mutable_constraints()) { @@ -14082,6 +14195,23 @@ void ApplyVariableMapping(const std::vector& mapping, mapping_function(&mutable_ref); } + // Remap the symmetries. Note that we should have properly dealt with fixed + // orbit and such in FilterOrbitOnUnusedOrFixedVariables(). + if (proto->has_symmetry()) { + for (SparsePermutationProto& generator : + *proto->mutable_symmetry()->mutable_permutations()) { + for (int& var : *generator.mutable_support()) { + mapping_function(&var); + } + } + + // We clear the orbitope info (we don't really use it after presolve). + proto->mutable_symmetry()->clear_orbitopes(); + } + + // Note: For the rest of the mapping, if mapping[i] is -1, we can just ignore + // the variable instead of trying to map it. + // Remap the search decision heuristic. // Note that we delete any heuristic related to a removed variable. for (DecisionStrategyProto& strategy : *proto->mutable_search_strategy()) { @@ -14108,41 +14238,30 @@ void ApplyVariableMapping(const std::vector& mapping, new_size); } - // Remap the solution hint. Note that after remapping, we may have duplicate - // variables. For instance, identical constant variables are mapped to a - // single one. Another case is variables with the same representative. In the - // later case we only keep the representative, since the hint of the others - // might no longer be valid (the hint of non-representative variables is not - // updated). In the former case we keep only the hint of the first occurrence. + // Remap the solution hint. if (proto->has_solution_hint()) { - absl::flat_hash_set used_vars; auto* mutable_hint = proto->mutable_solution_hint(); - mutable_hint->clear_vars(); - mutable_hint->clear_values(); - const int num_vars = context.working_model->variables().size(); - for (int hinted_var = 0; hinted_var < num_vars; ++hinted_var) { - if (context.GetAffineRelation(hinted_var).representative != hinted_var) { - continue; - } - if (!context.VarHasSolutionHint(hinted_var)) continue; - int64_t hinted_value = context.SolutionHint(hinted_var); - // We always move a hint within bounds. - // This also make sure a hint of INT_MIN or INT_MAX does not overflow. - if (hinted_value < context.MinOf(hinted_var)) { - hinted_value = context.MinOf(hinted_var); - } - if (hinted_value > context.MaxOf(hinted_var)) { - hinted_value = context.MaxOf(hinted_var); - } + // Note that after remapping, we may have duplicate variables. For instance, + // identical constant variables are mapped to a single one. So we make sure + // we don't output duplicates here and just keep the first occurrence. + absl::flat_hash_set hinted_images; + int new_size = 0; + const int old_size = mutable_hint->vars().size(); + for (int i = 0; i < old_size; ++i) { + const int hinted_var = mutable_hint->vars(i); + const int64_t hinted_value = mutable_hint->values(i); const int image = mapping[hinted_var]; if (image >= 0) { - if (!used_vars.insert(image).second) continue; - mutable_hint->add_vars(image); - mutable_hint->add_values(hinted_value); + if (!hinted_images.insert(image).second) continue; + mutable_hint->set_vars(new_size, image); + mutable_hint->set_values(new_size, hinted_value); + ++new_size; } } + mutable_hint->mutable_vars()->Truncate(new_size); + mutable_hint->mutable_values()->Truncate(new_size); } // Move the variable definitions. @@ -14160,26 +14279,96 @@ void ApplyVariableMapping(const std::vector& mapping, proto->add_variables()->Swap(&proto_ref); } - // Check that all variables are used. + // Check that all variables have a non-empty domain. for (const IntegerVariableProto& v : proto->variables()) { CHECK_GT(v.domain_size(), 0); } +} - // Remap the symmetries. Note that we should have properly dealt with fixed - // orbit and such in FilterOrbitOnUnusedOrFixedVariables(). - if (proto->has_symmetry()) { - for (SparsePermutationProto& generator : - *proto->mutable_symmetry()->mutable_permutations()) { - for (int& var : *generator.mutable_support()) { - CHECK(RefIsPositive(var)); - var = mapping[var]; - CHECK_NE(var, -1); +bool CpModelPresolver::MaybeRemoveFixedVariables( + std::vector* postsolve_mapping) { + postsolve_mapping->clear(); + if (!context_->params().remove_fixed_variables_early()) return true; + if (!context_->params().cp_model_presolve()) return true; + + // This is supposed to be already called, but it is a no-opt if this was the + // case, and it comment nicely that we do require domains to be up to date + // in the context. + context_->InitializeNewDomains(); + + // Initialize the mapping to remove all fixed variables. + const int num_vars = context_->working_model->variables().size(); + std::vector mapping(num_vars, -1); + for (int i = 0; i < num_vars; ++i) { + if (context_->IsFixed(i)) continue; + mapping[i] = postsolve_mapping->size(); + postsolve_mapping->push_back(i); + } + + // Lets only do this if the proportion of fixed variables is large enough. + const int num_fixed = num_vars - postsolve_mapping->size(); + if (num_fixed < 1000 || num_fixed * 2 <= num_vars) { + postsolve_mapping->clear(); + return true; + } + + // TODO(user): Right now the copy do not remove fixed variable from the + // objective, so we do that here so that these variable should not appear + // anymore. Fix that. + if (context_->working_model->has_objective()) { + auto* objective = context_->working_model->mutable_objective(); + auto* mutable_vars = objective->mutable_vars(); + auto* mutable_coeffs = objective->mutable_coeffs(); + const int old_size = objective->vars().size(); + int64_t offset = 0; + int new_size = 0; + for (int i = 0; i < old_size; ++i) { + const int var = objective->vars(i); + const int64_t coeff = objective->coeffs(i); + if (context_->IsFixed(var)) { + offset += context_->FixedValue(var) * coeff; + continue; } + mutable_vars->Set(new_size, var); + mutable_coeffs->Set(new_size, coeff); + ++new_size; } + mutable_vars->Truncate(new_size); + mutable_coeffs->Truncate(new_size); + objective->set_offset(objective->offset() + offset); - // We clear the orbitope info (we don't really use it after presolve). - proto->mutable_symmetry()->clear_orbitopes(); + context_->ReadObjectiveFromProto(); + if (!context_->CanonicalizeObjective()) return false; + if (!PropagateObjective()) return false; + if (context_->ModelIsUnsat()) return false; + context_->WriteObjectiveToProto(); } + + // Copy the current domains into the mapping model. + // Note that we are not sure the domain where properly written. + context_->WriteVariableDomainsToProto(); + *context_->mapping_model->mutable_variables() = + context_->working_model->variables(); + + // Reset some part of the context, it will re-read the new domains below. + context_->ResetAfterCopy(); + + SOLVER_LOG(logger_, "Large number of fixed variables ", + FormatCounter(num_fixed), " / ", FormatCounter(num_vars), + ", doing a first remapping phase to go down to ", + FormatCounter(postsolve_mapping->size()), " variables."); + + // Perform the actual mapping. + // Note that this might re-add fixed variable that are still used. + const int old_size = postsolve_mapping->size(); + ApplyVariableMapping(absl::MakeSpan(mapping), postsolve_mapping, + context_->working_model); + if (postsolve_mapping->size() > old_size) { + const int new_extra = postsolve_mapping->size() - old_size; + SOLVER_LOG(logger_, "TODO: ", new_extra, + " fixed variables still required in the model!"); + } + return true; } namespace { diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index ec75d6f7cf..e295c47839 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -43,13 +43,14 @@ namespace sat { // Replaces all the instance of a variable i (and the literals referring to it) // by mapping[i]. The definition of variables i is also moved to its new index. -// Variables with a negative mapping value are ignored and it is an error if -// such variable is referenced anywhere (this is CHECKed). +// If mapping[i] < 0 the variable can be ignored if possible. If it is not +// possible, then we will use a new index for it (at the end) and the mapping +// will be updated to reflect that. // -// The image of the mapping should be dense in [0, new_num_variables), this is -// also CHECKed. -void ApplyVariableMapping(const std::vector& mapping, - const PresolveContext& context); +// The image of the mapping should be dense in [0, reverse_mapping->size()). +void ApplyVariableMapping(absl::Span mapping, + std::vector* reverse_mapping, + CpModelProto* proto); // Presolves the initial content of presolved_model. // @@ -95,10 +96,9 @@ class CpModelPresolver { // A simple helper that logs the rules applied so far and return INFEASIBLE. CpSolverStatus InfeasibleStatus(); - // At the end of presolve, the mapping model is initialized to contains all - // the variable from the original model + the one created during presolve - // expand. It also contains the tightened domains. - void InitializeMappingModelVariables(); + // If there is a large proportion of fixed variables, remap the whole proto + // before we start the presolve. + bool MaybeRemoveFixedVariables(std::vector* postsolve_mapping); // Runs the inner loop of the presolver. bool ProcessChangedVariables(std::vector* in_queue, @@ -471,7 +471,7 @@ class ModelCopy { // Temp vectors. std::vector non_fixed_variables_; std::vector non_fixed_coefficients_; - absl::flat_hash_map interval_mapping_; + std::vector interval_mapping_; int starting_constraint_index_ = 0; std::vector temp_enforcement_literals_; diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index df993391c1..66b0a1fd9b 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1214,16 +1214,16 @@ class LnsSolver : public SubSolver { shared_->response->SolutionsRepository(); if (repo.NumSolutions() > 0) { base_response.set_status(CpSolverStatus::FEASIBLE); - const SharedSolutionRepository::Solution solution = - repo.GetRandomBiasedSolution(random); - for (const int64_t value : solution.variable_values) { - base_response.add_solution(value); - } + std::shared_ptr::Solution> + solution = repo.GetRandomBiasedSolution(random); + base_response.mutable_solution()->Assign( + solution->variable_values.begin(), + solution->variable_values.end()); // Note: We assume that the solution rank is the solution internal // objective. - data.initial_best_objective = repo.GetSolution(0).rank; - data.base_objective = solution.rank; + data.initial_best_objective = repo.GetSolution(0)->rank; + data.base_objective = solution->rank; } else { base_response.set_status(CpSolverStatus::UNKNOWN); @@ -1282,8 +1282,17 @@ class LnsSolver : public SubSolver { shared_->time_limit->UpdateLocalLimit(local_time_limit); // Presolve and solve the LNS fragment. - CpModelProto lns_fragment; - CpModelProto mapping_proto; + int64_t buffer_size; + { + absl::MutexLock l(&next_arena_size_mutex_); + buffer_size = next_arena_size_; + } + std::vector arena_buffer(buffer_size); + google::protobuf::Arena arena(arena_buffer.data(), arena_buffer.size()); + CpModelProto& lns_fragment = + *google::protobuf::Arena::Create(&arena); + CpModelProto& mapping_proto = + *google::protobuf::Arena::Create(&arena); auto context = std::make_unique( &local_model, &lns_fragment, &mapping_proto); @@ -1517,9 +1526,8 @@ class LnsSolver : public SubSolver { // if we just recovered the base solution. if (data.status == CpSolverStatus::OPTIMAL || data.status == CpSolverStatus::FEASIBLE) { - const std::vector base_solution( - base_response.solution().begin(), base_response.solution().end()); - if (solution_values != base_solution) { + if (absl::MakeSpan(solution_values) != + absl::MakeSpan(base_response.solution())) { new_solution = true; shared_->response->NewSolution(solution_values, solution_info, /*model=*/nullptr); @@ -1568,6 +1576,10 @@ class LnsSolver : public SubSolver { ", #calls:", generator_->num_calls(), ", p:", fully_solved_proportion, "]"); } + { + absl::MutexLock l(&next_arena_size_mutex_); + next_arena_size_ = arena.SpaceUsed(); + } }; } @@ -1582,6 +1594,12 @@ class LnsSolver : public SubSolver { NeighborhoodGeneratorHelper* helper_; const SatParameters lns_parameters_; SharedClasses* shared_; + // This is a optimization to allocate the arena for the LNS fragment already + // at roughly the right size. We will update it with the last size of the + // latest LNS fragment. + absl::Mutex next_arena_size_mutex_; + int64_t next_arena_size_ ABSL_GUARDED_BY(next_arena_size_mutex_) = + helper_->ModelProto().SpaceUsedLong(); }; void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index bf5bd92300..7f3bdc01a8 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -71,6 +71,7 @@ IntegerVariable CreateVariableWithTightDomain( IntegerVariable CreateVariableEqualToMinOf( absl::Span exprs, Model* model) { std::vector converted; + converted.reserve(exprs.size()); for (const AffineExpression& affine : exprs) { LinearExpression e; e.offset = affine.constant; @@ -81,32 +82,34 @@ IntegerVariable CreateVariableEqualToMinOf( converted.push_back(e); } - LinearExpression target; const IntegerVariable var = CreateVariableWithTightDomain(exprs, model); + LinearExpression target; target.vars.push_back(var); target.coeffs.push_back(IntegerValue(1)); - model->Add(IsEqualToMinOf(target, converted)); + AddIsEqualToMinOf(target, std::move(converted), model); return var; } IntegerVariable CreateVariableEqualToMaxOf( absl::Span exprs, Model* model) { std::vector converted; + converted.reserve(exprs.size()); for (const AffineExpression& affine : exprs) { + // We take the negation of affine. LinearExpression e; - e.offset = affine.constant; + e.offset = -affine.constant; if (affine.var != kNoIntegerVariable) { - e.vars.push_back(affine.var); - e.coeffs.push_back(affine.coeff); + e.vars = {NegationOf(affine.var)}; + e.coeffs = {affine.coeff}; } - converted.push_back(NegationOf(e)); + converted.push_back(std::move(e)); } - LinearExpression target; const IntegerVariable var = CreateVariableWithTightDomain(exprs, model); - target.vars.push_back(NegationOf(var)); - target.coeffs.push_back(IntegerValue(1)); - model->Add(IsEqualToMinOf(target, converted)); + LinearExpression target; + target.vars = {NegationOf(var)}; + target.coeffs = {IntegerValue(1)}; + AddIsEqualToMinOf(target, std::move(converted), model); return var; } diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index 07ac566bb0..02ba333c50 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -21,7 +21,6 @@ #include #include #include -#include #include #include #include @@ -44,8 +43,6 @@ #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" #include "ortools/graph/connected_components.h" -#include "ortools/graph/graph.h" -#include "ortools/graph/minimum_spanning_tree.h" #include "ortools/graph/strongly_connected_components.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/intervals.h" @@ -124,7 +121,7 @@ CompactVectorVector GetOverlappingRectangleComponents( } std::vector> intersections = - FindPartialRectangleIntersectionsAlsoEmpty(rectangles_to_process); + FindPartialRectangleIntersections(rectangles_to_process); const int num_intersections = intersections.size(); intersections.reserve(num_intersections * 2 + 1); for (int i = 0; i < num_intersections; ++i) { @@ -1821,16 +1818,19 @@ struct Rectangle32 { int index; }; +// Requires that rectangles are sorted by x_min and that sizes on both +// dimensions are > 0. std::vector> FindPartialRectangleIntersectionsImpl( absl::Span rectangles, int y_max) { // We are going to use a sweep line algorithm to find the intersections. // First, we sort the rectangles by their x coordinates, then consider a sweep - // line that goes from the left to the right. - std::sort(rectangles.begin(), rectangles.end(), - [](const Rectangle32& a, const Rectangle32& b) { - return std::tuple(a.x_min, -a.x_max, a.index) < - std::tuple(b.x_min, -b.x_max, b.index); - }); + // line that goes from the left to the right. See the comment on the + // SweepLineIntervalTree class for more details about what we store for each + // line. + DCHECK(std::is_sorted(rectangles.begin(), rectangles.end(), + [](const Rectangle32& a, const Rectangle32& b) { + return a.x_min < b.x_min; + })); SweepLineIntervalTree interval_tree(y_max, rectangles.size()); @@ -1839,6 +1839,10 @@ std::vector> FindPartialRectangleIntersectionsImpl( std::vector> arcs; for (int rectangle_index = 0; rectangle_index < rectangles.size(); ++rectangle_index) { + DCHECK_LT(rectangles[rectangle_index].x_min, + rectangles[rectangle_index].x_max); + DCHECK_LT(rectangles[rectangle_index].y_min, + rectangles[rectangle_index].y_max); const int sweep_line_x_pos = rectangles[rectangle_index].x_min; const Rectangle32& r = rectangles[rectangle_index]; interval_pieces.clear(); @@ -1858,154 +1862,154 @@ std::vector> FindPartialRectangleIntersectionsImpl( std::vector> FindPartialRectangleIntersections( absl::Span rectangles) { + // This function preprocess the data and calls + // FindPartialRectangleIntersectionsImpl() to actually solve the problem + // using a sweep line algorithm. The preprocessing consists of the following: + // - It converts the arbitrary int64_t coordinates into a small integer by + // sorting the possible values and assigning them consecutive integers. + // - It grows zero size intervals to make them size one. This simplifies + // things considerably, since it is hard to reason about degenerated + // rectangles in the general algorithm. + // + // Note that the last point need to be done with care. Imagine the following + // example: + // +----------+ + // | | + // | +--------------+ + // | | | + // | | p,q r | + // | +----*-----*-+-+ + // | | | + // | | | + // | | | + // | +------------+ + // | | + // | | + // +----------+ + // Where p,q and r are points (ie, boxes of size 0x0) and p and q have the + // same coordinates. We replace them by the following: + // +----------+ + // | | + // | +----------------------+ + // | | | + // | | | + // | +----+-+---------------+ + // | | |p| + // | | +-+-+ + // | | |q| + // | | +-+ +-+ + // | | |r| + // | +--------------+-+---+ + // | | | + // | | | + // | | | + // | +--------------------+ + // | | + // | | + // +----------+ + // + // That is a pretty radical deformation of the original shape, but it retains + // the property of whether a pair of rectangles intersect or not. + if (rectangles.empty()) return {}; - std::vector to_sort_x; - std::vector to_sort_y; - for (const Rectangle& r : rectangles) { - DCHECK_GT(r.SizeX(), 0); - DCHECK_GT(r.SizeY(), 0); - to_sort_x.push_back(r.x_min); - to_sort_x.push_back(r.x_max); - to_sort_y.push_back(r.y_min); - to_sort_y.push_back(r.y_max); - } - gtl::STLSortAndRemoveDuplicates(&to_sort_x); - gtl::STLSortAndRemoveDuplicates(&to_sort_y); - - absl::flat_hash_map x_map; - absl::flat_hash_map y_map; - x_map.reserve(to_sort_x.size()); - y_map.reserve(to_sort_y.size()); - for (int i = 0; i < to_sort_x.size(); ++i) { - x_map[to_sort_x[i]] = i; - } - for (int i = 0; i < to_sort_y.size(); ++i) { - y_map[to_sort_y[i]] = i; - } - std::vector rectangles32; - rectangles32.reserve(rectangles.size()); + enum class Event { + kEnd = 0, + kPoint = 1, + kBegin = 2, + }; + std::vector> x_events; + std::vector> y_events; + x_events.reserve(rectangles.size() * 2); + y_events.reserve(rectangles.size() * 2); for (int i = 0; i < rectangles.size(); ++i) { const Rectangle& r = rectangles[i]; - rectangles32.push_back({.x_min = x_map[r.x_min], - .x_max = x_map[r.x_max], - .y_min = y_map[r.y_min], - .y_max = y_map[r.y_max], - .index = i}); - } - return FindPartialRectangleIntersectionsImpl(absl::MakeSpan(rectangles32), - to_sort_y.size()); -} - -std::vector> FindPartialRectangleIntersectionsAlsoEmpty( - absl::Span rectangles) { - auto first_index_no_area_it = std::find_if( - rectangles.begin(), rectangles.end(), [](const Rectangle& r) { - DCHECK_GE(r.SizeX(), 0); - DCHECK_GE(r.SizeY(), 0); - return r.SizeX() == 0 || r.SizeY() == 0; - }); - if (first_index_no_area_it == rectangles.end()) { - // Avoid copying, all rectangles have non-zero area. - return FindPartialRectangleIntersections(rectangles); - } - - // Now we need to do the boring code of special-casing all the different cases - // of rectangles with zero area. We still want to use the quasilinear - // algorithm for the subset of the input with non-zero area. - std::vector rectangles_with_area, horizontal_lines, vertical_lines, - points; - std::vector rectangles_with_area_indexes, horizontal_lines_indexes, - vertical_lines_indexes, points_indexes; - rectangles_with_area.reserve(rectangles.size()); - rectangles_with_area_indexes.reserve(rectangles.size()); - rectangles_with_area.insert(rectangles_with_area.end(), rectangles.begin(), - first_index_no_area_it); - rectangles_with_area_indexes.resize(rectangles_with_area.size()); - std::iota(rectangles_with_area_indexes.begin(), - rectangles_with_area_indexes.end(), 0); - - for (int i = first_index_no_area_it - rectangles.begin(); - i < rectangles.size(); ++i) { - if (rectangles[i].SizeX() > 0 && rectangles[i].SizeY() > 0) { - rectangles_with_area.push_back(rectangles[i]); - rectangles_with_area_indexes.push_back(i); - } else if (rectangles[i].SizeX() > 0) { - horizontal_lines.push_back(rectangles[i]); - horizontal_lines_indexes.push_back(i); - } else if (rectangles[i].SizeY() > 0) { - vertical_lines.push_back(rectangles[i]); - vertical_lines_indexes.push_back(i); + DCHECK_GE(r.SizeX(), 0); + DCHECK_GE(r.SizeY(), 0); + if (r.SizeX() == 0) { + x_events.push_back({r.x_min, Event::kPoint, i}); + } else { + x_events.push_back({r.x_min, Event::kBegin, i}); + x_events.push_back({r.x_max, Event::kEnd, i}); + } + if (r.SizeY() == 0) { + y_events.push_back({r.y_min, Event::kPoint, i}); } else { - points.push_back(rectangles[i]); - points_indexes.push_back(i); + y_events.push_back({r.y_min, Event::kBegin, i}); + y_events.push_back({r.y_max, Event::kEnd, i}); } } + std::sort(y_events.begin(), y_events.end()); - // Handle rectangles intersecting rectangles using the sweep line algorithm. - std::vector> arcs = - FindPartialRectangleIntersections(rectangles_with_area); - for (std::pair& arc : arcs) { - arc.first = rectangles_with_area_indexes[arc.first]; - arc.second = rectangles_with_area_indexes[arc.second]; + std::vector rectangles32; + rectangles32.resize(rectangles.size()); + IntegerValue prev_y = 0; + Event prev_event = Event::kEnd; + int cur_index = -1; + for (int i = 0; i < y_events.size(); ++i) { + const auto [y, event, index] = y_events[i]; + if ((prev_event != event && prev_event != Event::kEnd) || prev_y != y || + event == Event::kPoint || cur_index == -1) { + ++cur_index; + } + + switch (event) { + case Event::kBegin: + rectangles32[index].y_min = cur_index; + rectangles32[index].index = index; + break; + case Event::kEnd: + rectangles32[index].y_max = cur_index; + break; + case Event::kPoint: + rectangles32[index].y_min = cur_index; + rectangles32[index].y_max = cur_index + 1; + rectangles32[index].index = index; + break; + } + prev_event = event; + prev_y = y; } + const int max_y_index = cur_index + 1; - // Handle rectangles intersecting non-rectangles. - for (int i = 0; i < rectangles_with_area.size(); ++i) { - const int index = rectangles_with_area_indexes[i]; - const Rectangle& r = rectangles_with_area[i]; - for (int j = 0; j < vertical_lines.size(); ++j) { - const int vertical_line_index = vertical_lines_indexes[j]; - const Rectangle& vertical_line = vertical_lines[j]; - if (!r.IsDisjoint(vertical_line)) { - arcs.push_back({index, vertical_line_index}); - } + std::sort(x_events.begin(), x_events.end()); + IntegerValue prev_x = 0; + prev_event = Event::kEnd; + cur_index = -1; + for (int i = 0; i < x_events.size(); ++i) { + const auto [x, event, index] = x_events[i]; + if ((prev_event != event && prev_event != Event::kEnd) || prev_x != x || + event == Event::kPoint || cur_index == -1) { + ++cur_index; } - for (int j = 0; j < horizontal_lines.size(); ++j) { - const int horizontal_line_index = horizontal_lines_indexes[j]; - const Rectangle& horizontal_line = horizontal_lines[j]; - if (!r.IsDisjoint(horizontal_line)) { - arcs.push_back({index, horizontal_line_index}); - } - } - for (int j = 0; j < points.size(); ++j) { - const int point_index = points_indexes[j]; - const Rectangle& point = points[j]; - if (!r.IsDisjoint(point)) { - arcs.push_back({index, point_index}); - } + + switch (event) { + case Event::kBegin: + rectangles32[index].x_min = cur_index; + break; + case Event::kEnd: + rectangles32[index].x_max = cur_index; + break; + case Event::kPoint: + rectangles32[index].x_min = cur_index; + rectangles32[index].x_max = cur_index + 1; + break; } + prev_event = event; + prev_x = x; } - // Finally handle vertical lines intersecting horizontal lines. - for (int i = 0; i < horizontal_lines.size(); ++i) { - const int index = horizontal_lines_indexes[i]; - const Rectangle& r = horizontal_lines[i]; - for (int j = 0; j < vertical_lines.size(); ++j) { - const int vertical_line_index = vertical_lines_indexes[j]; - const Rectangle& vertical_line = vertical_lines[j]; - if (!r.IsDisjoint(vertical_line)) { - arcs.push_back({index, vertical_line_index}); - } + std::vector sorted_rectangles32; + sorted_rectangles32.reserve(rectangles.size()); + for (int i = 0; i < x_events.size(); ++i) { + const auto [x, event, index] = x_events[i]; + if (event == Event::kBegin || event == Event::kPoint) { + sorted_rectangles32.push_back(rectangles32[index]); } } - // Now make our graph a minimal spanning tree again. - ::util::ReverseArcListGraph<> graph; - std::vector arc_indexes; - absl::flat_hash_map> pair_by_arc_index; - for (const auto& [a, b] : arcs) { - pair_by_arc_index[arc_indexes.size()] = {a, b}; - arc_indexes.push_back(graph.AddArc(a, b)); - } - const std::vector mst_arc_indices = - BuildKruskalMinimumSpanningTreeFromSortedArcs(graph, arc_indexes); - std::vector> result; - for (const int arc_index : mst_arc_indices) { - const auto& [a, b] = pair_by_arc_index[arc_index]; - result.push_back({a, b}); - } - return result; + return FindPartialRectangleIntersectionsImpl( + absl::MakeSpan(sorted_rectangles32), max_y_index); } std::optional> FindOneIntersectionIfPresent( diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index 9dafc70082..8e3d979dd1 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -683,29 +683,27 @@ inline bool RegionIncludesOther(absl::Span region, return PavedRegionDifference({other.begin(), other.end()}, region).empty(); } -// For a given a set of N rectangles with non-zero area in `rectangles`, there -// might be up to N*(N-1)/2 pairs of rectangles that intersect one another. If -// each of these pairs describe an arc and each rectangle describe a node, the -// rectangles and their intersections describe a graph. This function returns -// the full spanning forest for this graph (ie., a spanning tree for each -// connected component). This function allows to know if a set of rectangles has -// any intersection, find an example intersection for each rectangle that has -// one, or split the rectangles into connected components according to their -// intersections. +// For a given a set of N rectangles in `rectangles`, there might be up to +// N*(N-1)/2 pairs of rectangles that intersect one another. If each of these +// pairs describe an arc and each rectangle describe a node, the rectangles and +// their intersections describe a graph. This function returns the full spanning +// forest for this graph (ie., a spanning tree for each connected component). +// This function allows to know if a set of rectangles has any intersection, +// find an example intersection for each rectangle that has one, or split the +// rectangles into connected components according to their intersections. // // The returned tuples are the arcs of the spanning forest represented by their // indices in the input vector. // +// This function works with degenerate rectangles (ie., points or lines) and +// have the same semantics for overlap as Rectangle::IsDisjoint(). +// // Note: This function runs in O(N (log N)^2) time on the input size, which // would be impossible to do if we were to return all the intersections, which // can be quadratic in number. std::vector> FindPartialRectangleIntersections( absl::Span rectangles); -// Same as above, but also correctly handles rectangles with zero area. -std::vector> FindPartialRectangleIntersectionsAlsoEmpty( - absl::Span rectangles); - // This function is faster that the FindPartialRectangleIntersections() if one // only want to know if there is at least one intersection. It is in O(N log N). // diff --git a/ortools/sat/diffn_util_test.cc b/ortools/sat/diffn_util_test.cc index f4cb6692ca..9788e1ebf1 100644 --- a/ortools/sat/diffn_util_test.cc +++ b/ortools/sat/diffn_util_test.cc @@ -34,6 +34,7 @@ #include "absl/types/span.h" #include "benchmark/benchmark.h" #include "gtest/gtest.h" +#include "ortools/base/fuzztest.h" #include "ortools/base/gmock.h" #include "ortools/base/logging.h" #include "ortools/graph/connected_components.h" @@ -41,6 +42,7 @@ #include "ortools/sat/2d_orthogonal_packing_testing.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/util.h" +#include "ortools/util/saturated_arithmetic.h" #include "ortools/util/strong_integers.h" namespace operations_research { @@ -1054,6 +1056,27 @@ TEST(FindPartialIntersections, Random) { .y_min = rec.y_min - IntegerValue(absl::Uniform(random, 0, 4)), .y_max = rec.y_max + IntegerValue(absl::Uniform(random, 0, 4))}; } + const int num_of_rectangle_with_area = rectangles.size(); + const int num_points = absl::Uniform(random, 0, 5); + for (int i = 0; i < num_points; ++i) { + const IntegerValue x = absl::Uniform(random, 0, 100); + const IntegerValue y = absl::Uniform(random, 0, 100); + rectangles.push_back({.x_min = x, .x_max = x, .y_min = y, .y_max = y}); + } + const int num_lines = absl::Uniform(random, 0, 5); + for (int i = 0; i < num_lines; ++i) { + const IntegerValue v = absl::Uniform(random, 0, 100); + const IntegerValue i1 = absl::Uniform(random, 0, 99); + const IntegerValue i2 = absl::Uniform(random, i1.value() + 1, 100); + if (absl::Bernoulli(random, 0.5)) { + rectangles.push_back( + {.x_min = i1, .x_max = i2, .y_min = v, .y_max = v}); + } else { + rectangles.push_back( + {.x_min = v, .x_max = v, .y_min = i1, .y_max = i2}); + } + } + const std::vector> naive_result = GetAllIntersections(rectangles); const std::vector> result = @@ -1071,13 +1094,18 @@ TEST(FindPartialIntersections, Random) { } // We also test FindOneIntersectionIfPresent(). - absl::c_sort(rectangles, [](const Rectangle& a, const Rectangle& b) { - return a.x_min < b.x_min; - }); - if (naive_result.empty()) { - EXPECT_EQ(FindOneIntersectionIfPresent(rectangles), std::nullopt); + std::sort(rectangles.begin(), + rectangles.begin() + num_of_rectangle_with_area, + [](const Rectangle& a, const Rectangle& b) { + return a.x_min < b.x_min; + }); + absl::Span rectangles_with_area = + absl::MakeSpan(rectangles).subspan(0, num_of_rectangle_with_area); + if (FindPartialRectangleIntersections(rectangles_with_area).empty()) { + EXPECT_EQ(FindOneIntersectionIfPresent(rectangles_with_area), + std::nullopt); } else { - auto opt_pair = FindOneIntersectionIfPresent(rectangles); + auto opt_pair = FindOneIntersectionIfPresent(rectangles_with_area); EXPECT_NE(opt_pair, std::nullopt); EXPECT_FALSE( rectangles[opt_pair->first].IsDisjoint(rectangles[opt_pair->second])); @@ -1085,6 +1113,32 @@ TEST(FindPartialIntersections, Random) { } } +void CheckFuzzedRectangles( + const std::vector>& tuples) { + std::vector rectangles; + rectangles.reserve(tuples.size()); + for (const auto& [x_min, x_size, y_min, y_size] : tuples) { + rectangles.push_back({.x_min = x_min, + .x_max = CapAdd(x_min, x_size), + .y_min = y_min, + .y_max = CapAdd(y_min, y_size)}); + } + const std::vector> result = + FindPartialRectangleIntersections(rectangles); + for (const auto& [i, j] : result) { + CHECK(!rectangles[i].IsDisjoint(rectangles[j])) << i << " " << j; + } + const std::vector> naive_result = + GetAllIntersections(rectangles); + CHECK(GraphsDefineSameConnectedComponents(naive_result, result)) + << RenderRectGraph(std::nullopt, rectangles, result); +} + +FUZZ_TEST(FindPartialIntersections, CheckFuzzedRectangles) + .WithDomains(fuzztest::VectorOf(fuzztest::TupleOf( + fuzztest::Arbitrary(), fuzztest::NonNegative(), + fuzztest::Arbitrary(), fuzztest::NonNegative()))); + void BM_FindRectangles(benchmark::State& state) { absl::BitGen random; std::vector> problems; diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index 5b6c9958cf..6ceaa51942 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -1558,6 +1558,14 @@ bool DisjunctiveEdgeFinding::Propagate() { return false; } + // Corner case: The propagation of the previous window might have made the + // current task absent even if it wasn't at the loop beginning. + if (helper_->IsAbsent(presence_lit)) { + window_.clear(); + window_end = kMinIntegerValue; + continue; + } + // Start of the next window. window_.clear(); window_.push_back({task, shifted_smin}); diff --git a/ortools/sat/feasibility_jump.cc b/ortools/sat/feasibility_jump.cc index 12514708ba..e1f2ea70fc 100644 --- a/ortools/sat/feasibility_jump.cc +++ b/ortools/sat/feasibility_jump.cc @@ -383,10 +383,10 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { new_best_solution_was_found) { if (type() == SubSolver::INCOMPLETE) { // Choose a base solution for this neighborhood. - const SharedSolutionRepository::Solution solution = - shared_response_->SolutionsRepository().GetRandomBiasedSolution( - random_); - state_->solution = solution.variable_values; + std::shared_ptr::Solution> + solution = shared_response_->SolutionsRepository() + .GetRandomBiasedSolution(random_); + state_->solution = solution->variable_values; ++state_->num_solutions_imported; } else { if (!first_time) { diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 787276ec1f..a6cbc8ab85 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -641,9 +641,9 @@ void MinPropagator::RegisterWith(GenericLiteralWatcher* watcher) { watcher->WatchUpperBound(min_var_, id); } -LinMinPropagator::LinMinPropagator(const std::vector& exprs, +LinMinPropagator::LinMinPropagator(std::vector exprs, IntegerVariable min_var, Model* model) - : exprs_(exprs), + : exprs_(std::move(exprs)), min_var_(min_var), model_(model), integer_trail_(model_->GetOrCreate()) {} diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index efbf8422a6..4dcf7c3a55 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -243,8 +243,8 @@ class MinPropagator : public PropagatorInterface { // Assumes Canonical expressions (all positive coefficients). class LinMinPropagator : public PropagatorInterface, LazyReasonInterface { public: - LinMinPropagator(const std::vector& exprs, - IntegerVariable min_var, Model* model); + LinMinPropagator(std::vector exprs, IntegerVariable min_var, + Model* model); LinMinPropagator(const LinMinPropagator&) = delete; LinMinPropagator& operator=(const LinMinPropagator&) = delete; @@ -695,80 +695,52 @@ inline std::function NewWeightedSum( }; } -// Expresses the fact that an existing integer variable is equal to the minimum -// of other integer variables. -inline std::function IsEqualToMinOf( - IntegerVariable min_var, const std::vector& vars) { - return [=](Model* model) { - for (const IntegerVariable& var : vars) { - model->Add(LowerOrEqual(min_var, var)); - } - - MinPropagator* constraint = - new MinPropagator(vars, min_var, model->GetOrCreate()); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); - }; -} - // Expresses the fact that an existing integer variable is equal to the minimum // of linear expressions. Assumes Canonical expressions (all positive // coefficients). -inline std::function IsEqualToMinOf( - const LinearExpression& min_expr, - const std::vector& exprs) { - return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); - - IntegerVariable min_var; - if (min_expr.vars.size() == 1 && - std::abs(min_expr.coeffs[0].value()) == 1 && min_expr.offset == 0) { - if (min_expr.coeffs[0].value() == 1) { - min_var = min_expr.vars[0]; - } else { - min_var = NegationOf(min_expr.vars[0]); - } +inline void AddIsEqualToMinOf(const LinearExpression& min_expr, + std::vector exprs, + Model* model) { + IntegerTrail* integer_trail = model->GetOrCreate(); + + IntegerVariable min_var; + if (min_expr.vars.size() == 1 && std::abs(min_expr.coeffs[0].value()) == 1 && + min_expr.offset == 0) { + if (min_expr.coeffs[0].value() == 1) { + min_var = min_expr.vars[0]; } else { - // Create a new variable if the expression is not just a single variable. - IntegerValue min_lb = min_expr.Min(*integer_trail); - IntegerValue min_ub = min_expr.Max(*integer_trail); - min_var = integer_trail->AddIntegerVariable(min_lb, min_ub); - - // min_var = min_expr - LinearConstraintBuilder builder(0, 0); - builder.AddLinearExpression(min_expr, 1); - builder.AddTerm(min_var, -1); - LoadLinearConstraint(builder.Build(), model); - } - for (const LinearExpression& expr : exprs) { - LinearConstraintBuilder builder(0, kMaxIntegerValue); - builder.AddLinearExpression(expr, 1); - builder.AddTerm(min_var, -1); - LoadLinearConstraint(builder.Build(), model); + min_var = NegationOf(min_expr.vars[0]); } + } else { + // Create a new variable if the expression is not just a single variable. + IntegerValue min_lb = min_expr.Min(*integer_trail); + IntegerValue min_ub = min_expr.Max(*integer_trail); + min_var = integer_trail->AddIntegerVariable(min_lb, min_ub); + + // min_var = min_expr + LinearConstraintBuilder builder(0, 0); + builder.AddLinearExpression(min_expr, 1); + builder.AddTerm(min_var, -1); + LoadLinearConstraint(builder.Build(), model); + } + for (const LinearExpression& expr : exprs) { + LinearConstraintBuilder builder(0, kMaxIntegerValue); + builder.AddLinearExpression(expr, 1); + builder.AddTerm(min_var, -1); + LoadLinearConstraint(builder.Build(), model); + } - LinMinPropagator* constraint = new LinMinPropagator(exprs, min_var, model); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); - }; + LinMinPropagator* constraint = + new LinMinPropagator(std::move(exprs), min_var, model); + constraint->RegisterWith(model->GetOrCreate()); + model->TakeOwnership(constraint); } -// Expresses the fact that an existing integer variable is equal to the maximum -// of other integer variables. -inline std::function IsEqualToMaxOf( - IntegerVariable max_var, const std::vector& vars) { - return [=](Model* model) { - std::vector negated_vars; - for (const IntegerVariable& var : vars) { - negated_vars.push_back(NegationOf(var)); - model->Add(GreaterOrEqual(max_var, var)); - } - - MinPropagator* constraint = new MinPropagator( - negated_vars, NegationOf(max_var), model->GetOrCreate()); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); - }; +ABSL_DEPRECATED("Use AddIsEqualToMinOf() instead.") +inline std::function IsEqualToMinOf( + const LinearExpression& min_expr, + const std::vector& exprs) { + return [&](Model* model) { AddIsEqualToMinOf(min_expr, exprs, model); }; } template diff --git a/ortools/sat/integer_expr_test.cc b/ortools/sat/integer_expr_test.cc index fc770f7082..faf12759d2 100644 --- a/ortools/sat/integer_expr_test.cc +++ b/ortools/sat/integer_expr_test.cc @@ -349,10 +349,31 @@ TEST(MinMaxTest, LevelZeroPropagation) { std::vector vars{model.Add(NewIntegerVariable(4, 9)), model.Add(NewIntegerVariable(2, 7)), model.Add(NewIntegerVariable(3, 8))}; + std::vector exprs; + for (const IntegerVariable var : vars) { + LinearExpression expr; + expr.vars.push_back(var); + expr.coeffs.push_back(1); + exprs.push_back(expr); + } const IntegerVariable min = model.Add(NewIntegerVariable(0, 10)); + { + LinearExpression min_expr; + min_expr.vars.push_back(min); + min_expr.coeffs.push_back(1); + model.Add(IsEqualToMinOf(min_expr, exprs)); + } const IntegerVariable max = model.Add(NewIntegerVariable(0, 10)); - model.Add(IsEqualToMinOf(min, vars)); - model.Add(IsEqualToMaxOf(max, vars)); + { + // We negate everything to get a max. + LinearExpression max_expr; + max_expr.vars.push_back(max); + max_expr.coeffs.push_back(-1); + for (LinearExpression& ref : exprs) { + ref.coeffs[0] = -ref.coeffs[0]; + } + model.Add(IsEqualToMinOf(max_expr, exprs)); + } EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); EXPECT_BOUNDS_EQ(min, 2, 7); @@ -413,30 +434,6 @@ TEST(LinMinMaxTest, LevelZeroPropagation) { EXPECT_BOUNDS_EQ(vars[2], 5, 8); } -TEST(MinTest, OnlyOnePossibleCandidate) { - Model model; - std::vector vars{model.Add(NewIntegerVariable(4, 7)), - model.Add(NewIntegerVariable(2, 9)), - model.Add(NewIntegerVariable(5, 8))}; - const IntegerVariable min = model.Add(NewIntegerVariable(0, 10)); - model.Add(IsEqualToMinOf(min, vars)); - - // So far everything is normal. - EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); - EXPECT_BOUNDS_EQ(min, 2, 7); - - // But now, if the min is known to be <= 3, the minimum variable is known! it - // has to be variable #1, so we can propagate its upper bound. - model.Add(LowerOrEqual(min, 3)); - EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); - EXPECT_BOUNDS_EQ(min, 2, 3); - EXPECT_BOUNDS_EQ(vars[1], 2, 3); - - // Test infeasibility. - model.Add(LowerOrEqual(min, 1)); - EXPECT_EQ(SatSolver::INFEASIBLE, model.GetOrCreate()->Solve()); -} - TEST(LinMinTest, OnlyOnePossibleCandidate) { Model model; std::vector vars{model.Add(NewIntegerVariable(4, 7)), @@ -453,7 +450,7 @@ TEST(LinMinTest, OnlyOnePossibleCandidate) { LinearExpression min_expr; min_expr.vars.push_back(min); min_expr.coeffs.push_back(1); - model.Add(IsEqualToMinOf(min_expr, exprs)); + AddIsEqualToMinOf(min_expr, exprs, &model); // So far everything is normal. EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); @@ -508,7 +505,7 @@ TEST(LinMinTest, OnlyOnePossibleExpr) { LinearExpression min_expr; min_expr.vars.push_back(min); min_expr.coeffs.push_back(1); - model.Add(IsEqualToMinOf(min_expr, exprs)); + AddIsEqualToMinOf(min_expr, exprs, &model); // So far everything is normal. EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index b3bd50d63f..29066e5ce9 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -1659,8 +1659,9 @@ void LinearProgrammingConstraint::AddCGCuts() { IgnoreTrivialConstraintMultipliers(&tmp_cg_multipliers_); if (tmp_cg_multipliers_.size() <= 1) continue; } - tmp_integer_multipliers_ = ScaleMultipliers( - tmp_cg_multipliers_, /*take_objective_into_account=*/false, &scaling); + ScaleMultipliers(tmp_cg_multipliers_, + /*take_objective_into_account=*/false, &scaling, + &tmp_integer_multipliers_); if (scaling != 0) { if (AddCutFromConstraints("CG", tmp_integer_multipliers_)) { ++num_added; @@ -2273,16 +2274,16 @@ void LinearProgrammingConstraint::IgnoreTrivialConstraintMultipliers( lp_multipliers->resize(new_size); } -std::vector> -LinearProgrammingConstraint::ScaleMultipliers( +void LinearProgrammingConstraint::ScaleMultipliers( absl::Span> lp_multipliers, - bool take_objective_into_account, IntegerValue* scaling) const { + bool take_objective_into_account, IntegerValue* scaling, + std::vector>* output) const { *scaling = 0; - std::vector> integer_multipliers; + output->clear(); if (lp_multipliers.empty()) { // Empty linear combinaison. - return integer_multipliers; + return; } // TODO(user): we currently do not support scaling down, so we just abort @@ -2291,7 +2292,7 @@ LinearProgrammingConstraint::ScaleMultipliers( if (ScalingCanOverflow(/*power=*/0, take_objective_into_account, lp_multipliers, overflow_cap)) { ++num_scaling_issues_; - return integer_multipliers; + return; } // Note that we don't try to scale by more than 63 since in practice the @@ -2319,16 +2320,15 @@ LinearProgrammingConstraint::ScaleMultipliers( const IntegerValue coeff(std::round(double_coeff * scaling_as_double)); if (coeff != 0) { gcd = std::gcd(gcd, std::abs(coeff.value())); - integer_multipliers.push_back({row, coeff}); + output->push_back({row, coeff}); } } if (gcd > 1) { *scaling /= gcd; - for (auto& entry : integer_multipliers) { + for (auto& entry : *output) { entry.second /= gcd; } } - return integer_multipliers; } template @@ -2611,8 +2611,8 @@ bool LinearProgrammingConstraint::PropagateExactLpReason() { IntegerValue scaling = 0; IgnoreTrivialConstraintMultipliers(&tmp_lp_multipliers_); - tmp_integer_multipliers_ = ScaleMultipliers( - tmp_lp_multipliers_, take_objective_into_account, &scaling); + ScaleMultipliers(tmp_lp_multipliers_, take_objective_into_account, &scaling, + &tmp_integer_multipliers_); if (scaling == 0) { VLOG(1) << simplex_.GetProblemStatus(); VLOG(1) << "Issue while computing the exact LP reason. Aborting."; @@ -2681,8 +2681,8 @@ bool LinearProgrammingConstraint::PropagateExactDualRay() { tmp_lp_multipliers_.push_back({row, row_factors_[row.value()] * value}); } IgnoreTrivialConstraintMultipliers(&tmp_lp_multipliers_); - tmp_integer_multipliers_ = ScaleMultipliers( - tmp_lp_multipliers_, /*take_objective_into_account=*/false, &scaling); + ScaleMultipliers(tmp_lp_multipliers_, /*take_objective_into_account=*/false, + &scaling, &tmp_integer_multipliers_); if (scaling == 0) { VLOG(1) << "Isse while computing the exact dual ray reason. Aborting."; return true; diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index 9f50ef8a69..226a3fa801 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -332,9 +332,10 @@ class LinearProgrammingConstraint : public PropagatorInterface, // will still be exact as it will work for any set of multiplier. void IgnoreTrivialConstraintMultipliers( std::vector>* lp_multipliers); - std::vector> ScaleMultipliers( + void ScaleMultipliers( absl::Span> lp_multipliers, - bool take_objective_into_account, IntegerValue* scaling) const; + bool take_objective_into_account, IntegerValue* scaling, + std::vector>* output) const; // Can we have an overflow if we scale each coefficients with // std::round(std::ldexp(coeff, power)) ? diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 85e1d86a79..2aede237b7 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -218,55 +219,55 @@ void PresolveContext::AddImplyInDomain(int b, int x, const Domain& domain) { } bool PresolveContext::DomainIsEmpty(int ref) const { - return domains[PositiveRef(ref)].IsEmpty(); + return domains_[PositiveRef(ref)].IsEmpty(); } bool PresolveContext::IsFixed(int ref) const { - DCHECK_LT(PositiveRef(ref), domains.size()); + DCHECK_LT(PositiveRef(ref), domains_.size()); DCHECK(!DomainIsEmpty(ref)); - return domains[PositiveRef(ref)].IsFixed(); + return domains_[PositiveRef(ref)].IsFixed(); } bool PresolveContext::CanBeUsedAsLiteral(int ref) const { const int var = PositiveRef(ref); - return domains[var].Min() >= 0 && domains[var].Max() <= 1; + return domains_[var].Min() >= 0 && domains_[var].Max() <= 1; } bool PresolveContext::LiteralIsTrue(int lit) const { DCHECK(CanBeUsedAsLiteral(lit)); if (RefIsPositive(lit)) { - return domains[lit].Min() == 1; + return domains_[lit].Min() == 1; } else { - return domains[PositiveRef(lit)].Max() == 0; + return domains_[PositiveRef(lit)].Max() == 0; } } bool PresolveContext::LiteralIsFalse(int lit) const { DCHECK(CanBeUsedAsLiteral(lit)); if (RefIsPositive(lit)) { - return domains[lit].Max() == 0; + return domains_[lit].Max() == 0; } else { - return domains[PositiveRef(lit)].Min() == 1; + return domains_[PositiveRef(lit)].Min() == 1; } } int64_t PresolveContext::MinOf(int ref) const { DCHECK(!DomainIsEmpty(ref)); - return RefIsPositive(ref) ? domains[PositiveRef(ref)].Min() - : -domains[PositiveRef(ref)].Max(); + return RefIsPositive(ref) ? domains_[PositiveRef(ref)].Min() + : -domains_[PositiveRef(ref)].Max(); } int64_t PresolveContext::MaxOf(int ref) const { DCHECK(!DomainIsEmpty(ref)); - return RefIsPositive(ref) ? domains[PositiveRef(ref)].Max() - : -domains[PositiveRef(ref)].Min(); + return RefIsPositive(ref) ? domains_[PositiveRef(ref)].Max() + : -domains_[PositiveRef(ref)].Min(); } int64_t PresolveContext::FixedValue(int ref) const { DCHECK(!DomainIsEmpty(ref)); CHECK(IsFixed(ref)); - return RefIsPositive(ref) ? domains[PositiveRef(ref)].Min() - : -domains[PositiveRef(ref)].Min(); + return RefIsPositive(ref) ? domains_[PositiveRef(ref)].Min() + : -domains_[PositiveRef(ref)].Min(); } int64_t PresolveContext::MinOf(const LinearExpressionProto& expr) const { @@ -311,6 +312,18 @@ int64_t PresolveContext::FixedValue(const LinearExpressionProto& expr) const { return result; } +std::optional PresolveContext::FixedValueOrNullopt( + const LinearExpressionProto& expr) const { + int64_t result = expr.offset(); + for (int i = 0; i < expr.vars_size(); ++i) { + if (expr.coeffs(i) == 0) continue; + const Domain& domain = domains_[expr.vars(i)]; + if (!domain.IsFixed()) return std::nullopt; + result += expr.coeffs(i) * domain.Min(); + } + return result; +} + Domain PresolveContext::DomainSuperSetOf( const LinearExpressionProto& expr) const { Domain result(expr.offset()); @@ -519,18 +532,18 @@ bool PresolveContext::VariableIsOnlyUsedInLinear1AndOneExtraConstraint( Domain PresolveContext::DomainOf(int ref) const { Domain result; if (RefIsPositive(ref)) { - result = domains[ref]; + result = domains_[ref]; } else { - result = domains[PositiveRef(ref)].Negation(); + result = domains_[PositiveRef(ref)].Negation(); } return result; } bool PresolveContext::DomainContains(int ref, int64_t value) const { if (!RefIsPositive(ref)) { - return domains[PositiveRef(ref)].Contains(-value); + return domains_[PositiveRef(ref)].Contains(-value); } - return domains[ref].Contains(value); + return domains_[ref].Contains(value); } bool PresolveContext::DomainContains(const LinearExpressionProto& expr, @@ -554,38 +567,38 @@ ABSL_MUST_USE_RESULT bool PresolveContext::IntersectDomainWithInternal( const int var = PositiveRef(ref); if (RefIsPositive(ref)) { - if (domains[var].IsIncludedIn(domain)) { + if (domains_[var].IsIncludedIn(domain)) { return true; } - domains[var] = domains[var].IntersectionWith(domain); + domains_[var] = domains_[var].IntersectionWith(domain); } else { const Domain temp = domain.Negation(); - if (domains[var].IsIncludedIn(temp)) { + if (domains_[var].IsIncludedIn(temp)) { return true; } - domains[var] = domains[var].IntersectionWith(temp); + domains_[var] = domains_[var].IntersectionWith(temp); } if (domain_modified != nullptr) { *domain_modified = true; } modified_domains.Set(var); - if (domains[var].IsEmpty()) { + if (domains_[var].IsEmpty()) { return NotifyThatModelIsUnsat( absl::StrCat("var #", ref, " as empty domain after intersecting with ", domain.ToString())); } if (update_hint && VarHasSolutionHint(var)) { - UpdateVarSolutionHint(var, domains[var].ClosestValue(SolutionHint(var))); + UpdateVarSolutionHint(var, domains_[var].ClosestValue(SolutionHint(var))); } #ifdef CHECK_HINT if (working_model->has_solution_hint() && HintIsLoaded() && - !domains[var].Contains(hint_[var])) { + !domains_[var].Contains(hint_[var])) { LOG(FATAL) << "Hint with value " << hint_[var] << " infeasible when changing domain of " << var << " to " - << domains[var]; + << domains_[var]; } #endif @@ -1042,10 +1055,11 @@ void PresolveContext::CanonicalizeVariable(int ref) { UpdateNewConstraintsVariableUsage(); } -bool PresolveContext::ScaleFloatingPointObjective() { - DCHECK(working_model->has_floating_point_objective()); - DCHECK(!working_model->has_objective()); - const auto& objective = working_model->floating_point_objective(); +bool ScaleFloatingPointObjective(const SatParameters& params, + SolverLogger* logger, CpModelProto* proto) { + DCHECK(proto->has_floating_point_objective()); + DCHECK(!proto->has_objective()); + const auto& objective = proto->floating_point_objective(); std::vector> terms; for (int i = 0; i < objective.vars_size(); ++i) { DCHECK(RefIsPositive(objective.vars(i))); @@ -1053,12 +1067,9 @@ bool PresolveContext::ScaleFloatingPointObjective() { } const double offset = objective.offset(); const bool maximize = objective.maximize(); - working_model->clear_floating_point_objective(); + proto->clear_floating_point_objective(); - // We need the domains up to date before scaling. - WriteVariableDomainsToProto(); - return ScaleAndSetObjective(params_, terms, offset, maximize, working_model, - logger_); + return ScaleAndSetObjective(params, terms, offset, maximize, proto, logger); } bool PresolveContext::CanonicalizeAffineVariable(int ref, int64_t coeff, @@ -1401,11 +1412,21 @@ std::string PresolveContext::AffineRelationDebugString(int ref) const { RefDebugString(r.representative), " + ", r.offset); } +void PresolveContext::ResetAfterCopy() { + domains_.clear(); + modified_domains.ClearAll(); + var_with_reduced_small_degree.ClearAll(); + var_to_constraints_.clear(); + var_to_num_linear1_.clear(); + objective_map_.clear(); + hint_.clear(); +} + // Create the internal structure for any new variables in working_model. void PresolveContext::InitializeNewDomains() { const int new_size = working_model->variables().size(); - DCHECK_GE(new_size, domains.size()); - if (domains.size() == new_size) return; + DCHECK_GE(new_size, domains_.size()); + if (domains_.size() == new_size) return; modified_domains.Resize(new_size); var_with_reduced_small_degree.Resize(new_size); @@ -1414,10 +1435,12 @@ void PresolveContext::InitializeNewDomains() { // We mark the domain as modified so we will look at these new variable during // our presolve loop. - for (int i = domains.size(); i < new_size; ++i) { + const int old_size = domains_.size(); + domains_.resize(new_size); + for (int i = old_size; i < new_size; ++i) { modified_domains.Set(i); - domains.emplace_back(ReadDomainFromProto(working_model->variables(i))); - if (domains.back().IsEmpty()) { + domains_[i] = ReadDomainFromProto(working_model->variables(i)); + if (domains_[i].IsEmpty()) { is_unsat_ = true; return; } @@ -1560,7 +1583,9 @@ bool PresolveContext::InsertVarValueEncodingInternal(int literal, int var, absl::flat_hash_map& var_map = encoding_[var]; // The code below is not 100% correct if this is not the case. - CHECK(DomainOf(var).Contains(value)); + if (!DomainOf(var).Contains(value)) { + return SetLiteralToFalse(literal); + } if (DomainOf(var).IsFixed()) { return SetLiteralToTrue(literal); } @@ -1781,7 +1806,7 @@ bool PresolveContext::HasVarValueEncoding(int ref, int64_t value, bool PresolveContext::IsFullyEncoded(int ref) const { const int var = PositiveRef(ref); - const int64_t size = domains[var].Size(); + const int64_t size = domains_[var].Size(); if (size <= 2) return true; const auto& it = encoding_.find(var); return it == encoding_.end() ? false : size <= it->second.size(); @@ -1801,7 +1826,7 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) { const int var = ref; // Returns the false literal if the value is not in the domain. - if (!domains[var].Contains(value)) { + if (!domains_[var].Contains(value)) { return GetFalseLiteral(); } @@ -1825,7 +1850,7 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) { } // Special case for fixed domains. - if (domains[var].Size() == 1) { + if (domains_[var].Size() == 1) { const int true_literal = GetTrueLiteral(); var_map[value] = SavedLiteral(true_literal); return true_literal; @@ -1834,7 +1859,7 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) { // Special case for domains of size 2. const int64_t var_min = MinOf(var); const int64_t var_max = MaxOf(var); - if (domains[var].Size() == 2) { + if (domains_[var].Size() == 2) { // Checks if the other value is already encoded. const int64_t other_value = value == var_min ? var_max : var_min; auto other_it = var_map.find(other_value); @@ -1942,6 +1967,7 @@ void PresolveContext::ReadObjectiveFromProto() { const int var = PositiveRef(ref); objective_map_[var] += RefIsPositive(ref) ? coeff : -coeff; + if (objective_map_[var] == 0) { RemoveVariableFromObjective(var); } else { @@ -2759,22 +2785,22 @@ void CreateValidModelWithSingleConstraint(const ConstraintProto& ct, auto [it, inserted] = inverse_interval_map.insert({i, mini_model->constraints_size()}); if (inserted) { - *mini_model->add_constraints() = context->working_model->constraints(i); + const ConstraintProto& itv_ct = context->working_model->constraints(i); + *mini_model->add_constraints() = itv_ct; // Now add end = start + size for the interval. This is not strictly // necessary but it makes the presolve more powerful. ConstraintProto* linear = mini_model->add_constraints(); - *linear->mutable_enforcement_literal() = ct.enforcement_literal(); + *linear->mutable_enforcement_literal() = itv_ct.enforcement_literal(); LinearConstraintProto* mutable_linear = linear->mutable_linear(); - const IntervalConstraintProto& itv = - context->working_model->constraints(i).interval(); + const IntervalConstraintProto& itv = itv_ct.interval(); mutable_linear->add_domain(0); mutable_linear->add_domain(0); AddLinearExpressionToLinearConstraint(itv.start(), 1, mutable_linear); AddLinearExpressionToLinearConstraint(itv.size(), 1, mutable_linear); AddLinearExpressionToLinearConstraint(itv.end(), -1, mutable_linear); - CanonicalizeLinearExpressionNoContext(ct.enforcement_literal(), + CanonicalizeLinearExpressionNoContext(itv_ct.enforcement_literal(), mutable_linear); } } diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index 4d3444d0d5..270372f78b 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -82,6 +82,11 @@ class SavedVariable { int ref_ = 0; }; +// If a floating point objective is present, scale it using the current domains +// and transform it to an integer_objective. +ABSL_MUST_USE_RESULT bool ScaleFloatingPointObjective( + const SatParameters& params, SolverLogger* logger, CpModelProto* proto); + // Wrap the CpModelProto we are presolving with extra data structure like the // in-memory domain of each variables and the constraint variable graph. class PresolveContext { @@ -141,6 +146,7 @@ class PresolveContext { int64_t FixedValue(int ref) const; bool DomainContains(int ref, int64_t value) const; Domain DomainOf(int ref) const; + absl::Span AllDomains() const { return domains_; } // Helper to query the state of an interval. bool IntervalIsConstant(int ct_ref) const; @@ -160,6 +166,10 @@ class PresolveContext { bool IsFixed(const LinearExpressionProto& expr) const; int64_t FixedValue(const LinearExpressionProto& expr) const; + // This is faster than testing IsFixed() + FixedValue(). + std::optional FixedValueOrNullopt( + const LinearExpressionProto& expr) const; + // Accepts any proto with two parallel vector .vars() and .coeffs(), like // LinearConstraintProto or ObjectiveProto or LinearExpressionProto but beware // that this ignore any offset. @@ -224,7 +234,7 @@ class PresolveContext { // This function takes a positive variable reference. bool DomainOfVarIsIncludedIn(int var, const Domain& domain) { - return domains[var].IsIncludedIn(domain); + return domains_[var].IsIncludedIn(domain); } // Returns true if this ref only appear in one constraint. @@ -388,6 +398,16 @@ class PresolveContext { // Creates the internal structure for any new variables in working_model. void InitializeNewDomains(); + // This is a bit hacky. Clear some fields. See call site. + // + // TODO(user): The ModelCopier should probably not depend on the full context + // it only need to read/write domains and call UpdateRuleStats(), so we might + // want to split that part out so that we can just initialize the full context + // later. Alternatively, we could just move more complex part of the context + // out, like the graph, the encoding, the affine representative, and so on to + // individual and easier to manage classes. + void ResetAfterCopy(); + // Clears the "rules" statistics. void ClearStats(); @@ -472,7 +492,6 @@ class PresolveContext { ABSL_MUST_USE_RESULT bool CanonicalizeOneObjectiveVariable(int var); ABSL_MUST_USE_RESULT bool CanonicalizeObjective(bool simplify_domain = true); void WriteObjectiveToProto() const; - ABSL_MUST_USE_RESULT bool ScaleFloatingPointObjective(); // When the objective is singleton, we can always restrict the domain of var // so that the current objective domain is non-constraining. Returns false @@ -616,6 +635,18 @@ class PresolveContext { bool HintIsLoaded() const { return hint_is_loaded_; } absl::Span SolutionHint() const { return hint_; } + // Similar to SolutionHint() but make sure the value is within the current + // bounds of the variable. + int64_t ClampedSolutionHint(int var) { + int64_t value = hint_[var]; + if (value > MaxOf(var)) { + value = MaxOf(var); + } else if (value < MinOf(var)) { + value = MinOf(var); + } + return value; + } + bool LiteralSolutionHint(int lit) const { const int var = PositiveRef(lit); return RefIsPositive(lit) ? hint_[var] : !hint_[var]; @@ -754,7 +785,7 @@ class PresolveContext { bool is_unsat_ = false; // The current domain of each variables. - std::vector domains; + std::vector domains_; // Parallel to domains. // diff --git a/ortools/sat/presolve_context_test.cc b/ortools/sat/presolve_context_test.cc index c8530b1d7d..7f67c4e46c 100644 --- a/ortools/sat/presolve_context_test.cc +++ b/ortools/sat/presolve_context_test.cc @@ -905,7 +905,8 @@ TEST(PresolveContextTest, ObjectiveScalingMinimize) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - ASSERT_TRUE(context.ScaleFloatingPointObjective()); + ASSERT_TRUE(ScaleFloatingPointObjective(context.params(), context.logger(), + &working_model)); ASSERT_TRUE(working_model.has_objective()); ASSERT_FALSE(working_model.has_floating_point_objective()); const CpObjectiveProto& obj = working_model.objective(); @@ -929,7 +930,8 @@ TEST(PresolveContextTest, ObjectiveScalingMaximize) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - ASSERT_TRUE(context.ScaleFloatingPointObjective()); + ASSERT_TRUE(ScaleFloatingPointObjective(context.params(), context.logger(), + &working_model)); ASSERT_TRUE(working_model.has_objective()); ASSERT_FALSE(working_model.has_floating_point_objective()); const CpObjectiveProto& obj = working_model.objective(); diff --git a/ortools/sat/rins.cc b/ortools/sat/rins.cc index a85f00b1dd..b33e2f5b16 100644 --- a/ortools/sat/rins.cc +++ b/ortools/sat/rins.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -73,12 +74,12 @@ std::vector GetLPRelaxationValues( return relaxation_values; } - const SharedSolutionRepository::Solution lp_solution = - lp_solutions->GetRandomBiasedSolution(random); + std::shared_ptr::Solution> + lp_solution = lp_solutions->GetRandomBiasedSolution(random); - for (int model_var = 0; model_var < lp_solution.variable_values.size(); + for (int model_var = 0; model_var < lp_solution->variable_values.size(); ++model_var) { - relaxation_values.push_back(lp_solution.variable_values[model_var]); + relaxation_values.push_back(lp_solution->variable_values[model_var]); } return relaxation_values; } @@ -207,12 +208,12 @@ ReducedDomainNeighborhood GetRinsRensNeighborhood( if (response_manager != nullptr && response_manager->SolutionsRepository().NumSolutions() > 0 && three_out_of_four(random)) { // Rins. - const std::vector solution = - response_manager->SolutionsRepository() - .GetRandomBiasedSolution(random) - .variable_values; - FillRinsNeighborhood(solution, relaxation_values, difficulty, random, - reduced_domains); + std::shared_ptr::Solution> + solution = + response_manager->SolutionsRepository().GetRandomBiasedSolution( + random); + FillRinsNeighborhood(solution->variable_values, relaxation_values, + difficulty, random, reduced_domains); reduced_domains.source_info = "rins_"; } else { // Rens. FillRensNeighborhood(relaxation_values, difficulty, random, diff --git a/ortools/sat/sat_decision.cc b/ortools/sat/sat_decision.cc index ce0d940a0a..2ab0760c60 100644 --- a/ortools/sat/sat_decision.cc +++ b/ortools/sat/sat_decision.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -199,12 +200,12 @@ bool SatDecisionPolicy::UseLsSolutionAsInitialPolarity() { // This is in term of proto variable. // TODO(user): use cp_model_mapping. But this is not needed to experiment // on pure sat problems. - std::vector solution = - ls_hints_->GetRandomBiasedSolution(*random_).variable_values; - if (solution.size() != var_polarity_.size()) return false; + std::shared_ptr solution = + ls_hints_->GetRandomBiasedSolution(*random_); + if (solution->variable_values.size() != var_polarity_.size()) return false; - for (int i = 0; i < solution.size(); ++i) { - var_polarity_[BooleanVariable(i)] = solution[i] == 1; + for (int i = 0; i < solution->variable_values.size(); ++i) { + var_polarity_[BooleanVariable(i)] = solution->variable_values[i] == 1; } return false; diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index d8185ac4aa..2db8698a75 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -23,7 +23,7 @@ option java_multiple_files = true; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 310 +// NEXT TAG: 311 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -450,6 +450,12 @@ message SatParameters { // Whether we also use the sat presolve when cp_model_presolve is true. optional bool cp_model_use_sat_presolve = 93 [default = true]; + // If cp_model_presolve is true and there is a large proportion of fixed + // variable after the first model copy, remap all the model to a dense set of + // variable before the full presolve even starts. This should help for LNS on + // large models. + optional bool remove_fixed_variables_early = 310 [default = true]; + // If true, we detect variable that are unique to a table constraint and only // there to encode a cost on each tuple. This is usually the case when a WCSP // (weighted constraint program) is encoded into CP-SAT format. diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 6676aff803..1bc56a1587 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -78,14 +78,17 @@ void SharedLPSolutionRepository::NewLPSolution( if (lp_solution.empty()) return; // Add this solution to the pool. - SharedSolutionRepository::Solution solution; - solution.variable_values = std::move(lp_solution); + auto solution = + std::make_shared::Solution>(); + solution->variable_values = std::move(lp_solution); // We always prefer to keep the solution from the last synchronize batch. - absl::MutexLock mutex_lock(&mutex_); - solution.rank = -num_synchronization_; - ++num_added_; - new_solutions_.push_back(solution); + { + absl::MutexLock mutex_lock(&mutex_); + solution->rank = -num_synchronization_; + ++num_added_; + new_solutions_.push_back(solution); + } } void SharedIncompleteSolutionManager::AddSolution( @@ -546,17 +549,21 @@ CpSolverResponse SharedResponseManager::GetResponseInternal( CpSolverResponse SharedResponseManager::GetResponse() { absl::MutexLock mutex_lock(&mutex_); - CpSolverResponse result = - solutions_.NumSolutions() == 0 - ? GetResponseInternal({}, "") - : GetResponseInternal(solutions_.GetSolution(0).variable_values, - solutions_.GetSolution(0).info); - + CpSolverResponse result; + if (solutions_.NumSolutions() == 0) { + result = GetResponseInternal({}, ""); + } else { + std::shared_ptr::Solution> + solution = solutions_.GetSolution(0); + result = GetResponseInternal(solution->variable_values, solution->info); + } // If this is true, we postsolve and copy all of our solutions. if (parameters_.fill_additional_solutions_in_response()) { std::vector temp; for (int i = 0; i < solutions_.NumSolutions(); ++i) { - temp = solutions_.GetSolution(i).variable_values; + std::shared_ptr::Solution> + solution = solutions_.GetSolution(i); + temp = solution->variable_values; for (int i = solution_postprocessors_.size(); --i >= 0;) { solution_postprocessors_[i](&temp); } diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index ce731572cc..c0d34d5d59 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -97,7 +98,7 @@ class SharedSolutionRepository { int NumSolutions() const; // Returns the solution #i where i must be smaller than NumSolutions(). - Solution GetSolution(int index) const; + std::shared_ptr GetSolution(int index) const; // Returns the rank of the best known solution. // You shouldn't call this if NumSolutions() is zero. @@ -109,7 +110,8 @@ class SharedSolutionRepository { ValueType GetVariableValueInSolution(int var_index, int solution_index) const; // Returns a random solution biased towards good solutions. - Solution GetRandomBiasedSolution(absl::BitGenRef random) const; + std::shared_ptr GetRandomBiasedSolution( + absl::BitGenRef random) const; // Add a new solution. Note that it will not be added to the pool of solution // right away. One must call Synchronize for this to happen. In order to be @@ -143,8 +145,8 @@ class SharedSolutionRepository { // Our two solutions pools, the current one and the new one that will be // merged into the current one on each Synchronize() calls. mutable std::vector tmp_indices_ ABSL_GUARDED_BY(mutex_); - std::vector solutions_ ABSL_GUARDED_BY(mutex_); - std::vector new_solutions_ ABSL_GUARDED_BY(mutex_); + std::vector> solutions_ ABSL_GUARDED_BY(mutex_); + std::vector> new_solutions_ ABSL_GUARDED_BY(mutex_); }; // Solutions coming from the LP. @@ -813,7 +815,7 @@ int SharedSolutionRepository::NumSolutions() const { } template -typename SharedSolutionRepository::Solution +std::shared_ptr::Solution> SharedSolutionRepository::GetSolution(int i) const { absl::MutexLock mutex_lock(&mutex_); ++num_queried_; @@ -824,24 +826,24 @@ template int64_t SharedSolutionRepository::GetBestRank() const { absl::MutexLock mutex_lock(&mutex_); CHECK_GT(solutions_.size(), 0); - return solutions_[0].rank; + return solutions_[0]->rank; } template ValueType SharedSolutionRepository::GetVariableValueInSolution( int var_index, int solution_index) const { absl::MutexLock mutex_lock(&mutex_); - return solutions_[solution_index].variable_values[var_index]; + return solutions_[solution_index]->variable_values[var_index]; } // TODO(user): Experiments on the best distribution. template -typename SharedSolutionRepository::Solution +std::shared_ptr::Solution> SharedSolutionRepository::GetRandomBiasedSolution( absl::BitGenRef random) const { absl::MutexLock mutex_lock(&mutex_); ++num_queried_; - const int64_t best_rank = solutions_[0].rank; + const int64_t best_rank = solutions_[0]->rank; // As long as we have solution with the best objective that haven't been // explored too much, we select one uniformly. Otherwise, we select a solution @@ -855,9 +857,9 @@ SharedSolutionRepository::GetRandomBiasedSolution( // Select all the best solution with a low enough selection count. tmp_indices_.clear(); for (int i = 0; i < solutions_.size(); ++i) { - const auto& solution = solutions_[i]; - if (solution.rank == best_rank && - solution.num_selected <= kExplorationThreshold) { + std::shared_ptr solution = solutions_[i]; + if (solution->rank == best_rank && + solution->num_selected <= kExplorationThreshold) { tmp_indices_.push_back(i); } } @@ -868,16 +870,20 @@ SharedSolutionRepository::GetRandomBiasedSolution( } else { index = tmp_indices_[absl::Uniform(random, 0, tmp_indices_.size())]; } - solutions_[index].num_selected++; + solutions_[index]->num_selected++; return solutions_[index]; } template void SharedSolutionRepository::Add(Solution solution) { if (num_solutions_to_keep_ <= 0) return; - absl::MutexLock mutex_lock(&mutex_); - ++num_added_; - new_solutions_.push_back(std::move(solution)); + std::shared_ptr solution_ptr = + std::make_shared(std::move(solution)); + { + absl::MutexLock mutex_lock(&mutex_); + ++num_added_; + new_solutions_.push_back(std::move(solution_ptr)); + } } template @@ -893,15 +899,17 @@ void SharedSolutionRepository::Synchronize() { // existing solutions. // // TODO(user): Introduce a notion of orthogonality to diversify the pool? - gtl::STLStableSortAndRemoveDuplicates(&solutions_); + gtl::STLStableSortAndRemoveDuplicates( + &solutions_, [](const std::shared_ptr& a, + const std::shared_ptr& b) { return *a < *b; }); if (solutions_.size() > num_solutions_to_keep_) { solutions_.resize(num_solutions_to_keep_); } if (!solutions_.empty()) { VLOG(2) << "Solution pool update:" << " num_solutions=" << solutions_.size() - << " min_rank=" << solutions_[0].rank - << " max_rank=" << solutions_.back().rank; + << " min_rank=" << solutions_[0]->rank + << " max_rank=" << solutions_.back()->rank; } num_synchronization_++;