Skip to content

Commit f68d20d

Browse files
committed
Merge remote-tracking branch 'origin/bfws-v1.0-beta1_goal_directed' into bfws-v1.0-beta1
2 parents 97299e0 + 0de6174 commit f68d20d

File tree

3 files changed

+89
-36
lines changed

3 files changed

+89
-36
lines changed

src/search/drivers/sbfws/iw_run.hxx

+81-33
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11

22
#pragma once
33

4+
#include <stdio.h>
5+
#include <stdlib.h>
6+
47
#include <iomanip>
58
#include <unordered_set>
69

@@ -192,9 +195,10 @@ public:
192195
//!
193196
unsigned _max_width;
194197

195-
Config(int bound, bool complete, bool mark_negative) :
196-
_bound(bound), _complete(complete), _mark_negative(mark_negative), _max_width(1) {}
198+
bool _use_goal_directed_info;
197199

200+
Config(int bound, bool complete, bool mark_negative, unsigned max_width, bool goal_directed_info) :
201+
_bound(bound), _complete(complete), _mark_negative(mark_negative), _max_width(max_width), _use_goal_directed_info(goal_directed_info) {}
198202

199203
};
200204

@@ -203,7 +207,9 @@ protected:
203207
Config _config;
204208

205209
//! _all_paths[i] contains all paths in the simulation that reach a node that satisfies goal atom 'i'.
206-
// std::vector<std::vector<NodePT>> _all_paths;
210+
// std::vector<std::vector<NodePT>> _all_paths;
211+
212+
std::vector<NodePT> _optimal_paths;
207213

208214
//! '_unreached' contains the indexes of all those goal atoms that have yet not been reached.
209215
// TODO REMOVE
@@ -239,6 +245,7 @@ public:
239245
Base(model, OpenListT(), ClosedListT()),
240246
_config(config),
241247
// _all_paths(model.num_subgoals()),
248+
_optimal_paths(model.num_subgoals()),
242249
_unreached(),
243250
_in_seed(model.num_subgoals(), false),
244251
// _visited(),
@@ -267,14 +274,13 @@ public:
267274

268275

269276
//! Mark all atoms in the path to some goal. 'seed_nodes' contains all nodes satisfying some subgoal.
270-
std::vector<bool> compute_relevant_w1_atoms(const std::vector<NodePT>& seed_nodes) const {
277+
void mark_atoms_in_path_to_subgoal(const std::vector<NodePT>& seed_nodes, std::vector<bool>& atoms) const {
271278
const AtomIndex& index = Problem::getInstance().get_tuple_index();
272279
std::unordered_set<NodePT> all_visited;
273-
std::vector<bool> seen_atoms(index.size());
280+
assert(atoms.size() == index.size());
274281

275282
for (NodePT node:seed_nodes) {
276283

277-
278284
NodePT root = node;
279285
// We ignore s0
280286
while (node->has_parent()) {
@@ -287,44 +293,87 @@ public:
287293
for (unsigned var = 0; var < state.numAtoms(); ++var) {
288294
if (state.getValue(var) == 0) continue; // TODO THIS WON'T GENERALIZE WELL TO FSTRIPS DOMAINS
289295
AtomIdx atom = index.to_index(var, state.getValue(var));
290-
seen_atoms[atom] = true;
296+
atoms[atom] = true;
291297
}
292298

293299
node = node->parent;
294300
}
295301
}
296-
return seen_atoms;
302+
}
303+
304+
305+
std::vector<bool> compute_R(const StateT& seed) {
306+
if (_config._use_goal_directed_info) {
307+
return compute_goal_directed_R(seed);
308+
} else {
309+
return compute_R_IW1(seed);
310+
}
297311
}
298312

299313
std::vector<bool> compute_R_IW1(const StateT& seed) {
314+
LPT_INFO("cout", "IW Simulation - Computing blind R");
300315
_config._max_width = 1;
301316
_config._bound = -1; // No bound
302-
std::vector<NodePT> w1_seed_nodes;
303-
compute_R(seed, w1_seed_nodes);
304-
305-
// auto rset = compute_relevant_w1_atoms(w1_seed_nodes);
306-
auto rset = _evaluator.reached_atoms();
307-
LPT_INFO("cout", "IW Simulation - |R_{IW(1)}| = " << std::count(rset.begin(), rset.end(), true)); // TODO REMOVE THIS, IT'S EXPENSIVE
308-
return rset;
317+
std::vector<NodePT> seed_nodes;
318+
_compute_R(seed, seed_nodes);
319+
320+
LPT_INFO("cout", "IW Simulation - Number of seed nodes: " << seed_nodes.size());
321+
std::vector<bool> rel_blind = _evaluator.reached_atoms();
322+
LPT_INFO("cout", "IW Simulation - Blind |R| = " << std::count(rel_blind.begin(), rel_blind.end(), true));
323+
return rel_blind;
309324
}
325+
326+
std::vector<bool> compute_goal_directed_R(const StateT& seed) {
327+
LPT_INFO("cout", "IW Simulation - Computing goal-directed R");
328+
const AtomIndex& index = Problem::getInstance().get_tuple_index();
329+
_config._max_width = 2;
330+
_config._bound = -1; // No bound
331+
std::vector<NodePT> seed_nodes;
332+
_compute_R(seed, seed_nodes);
333+
310334

335+
LPT_INFO("cout", "IW Simulation - Number of seed nodes: " << seed_nodes.size());
336+
337+
std::vector<bool> rel_goal_directed(index.size(), false);
338+
mark_atoms_in_path_to_subgoal(seed_nodes, rel_goal_directed);
339+
LPT_INFO("cout", "IW Simulation - Goal-directed |R| = " << std::count(rel_goal_directed.begin(), rel_goal_directed.end(), true));
340+
return rel_goal_directed;
341+
}
311342

312-
std::vector<AtomIdx> compute_R(const StateT& seed, std::vector<NodePT>& w1_seed_nodes) {
343+
344+
std::vector<AtomIdx> _compute_R(const StateT& seed, std::vector<NodePT>& seed_nodes) {
313345

314346
_config._complete = false;
315347

316-
bool all_reached_before_bound = _run(seed);
348+
_run(seed);
317349

318-
if (all_reached_before_bound) {
319-
for (const auto& n:_w1_nodes) {
320-
if (n->satisfies_subgoal) w1_seed_nodes.push_back(n);
321-
}
322-
} else {
323-
w1_seed_nodes = _w1_nodes;
350+
LPT_INFO("cout", "IW Simulation - Num unreached subgoals: " << _unreached.size() << " / " << this->_model.num_subgoals());
351+
if (!_unreached.empty()) {
352+
LPT_INFO("cout", "Some subgoals not reached during the simulation. ABORTING");
353+
exit(1);
324354
}
325355

356+
// std::vector<NodePT> w1_goal_reaching_nodes;
357+
// std::vector<NodePT> w2_goal_reaching_nodes;
358+
// std::vector<NodePT> wgt2_goal_reaching_nodes;
359+
360+
361+
/*
362+
for (unsigned subgoal_idx = 0; subgoal_idx < _all_paths.size(); ++subgoal_idx) {
363+
const std::vector<NodePT>& paths = _all_paths[subgoal_idx];
364+
assert(_in_seed[subgoal_idx] || !paths.empty());
365+
seed_nodes.insert(seed_nodes.end(), paths.begin(), paths.end());
366+
}
367+
*/
368+
369+
for (unsigned subgoal_idx = 0; subgoal_idx < _optimal_paths.size(); ++subgoal_idx) {
370+
if (!_in_seed[subgoal_idx]) {
371+
assert(_optimal_paths[subgoal_idx] != nullptr);
372+
seed_nodes.push_back(_optimal_paths[subgoal_idx]);
373+
}
374+
}
375+
326376

327-
LPT_INFO("cout", "IW Simulation - Num unreached subgoals: " << _unreached.size() << " / " << this->_model.num_subgoals());
328377
/*
329378
LPT_INFO("cout", "IW Simulation - Number of novelty-1 nodes: " << _w1_nodes.size());
330379
LPT_INFO("cout", "IW Simulation - Number of novelty=1 nodes expanded in the simulation: " << _w1_nodes_expanded);
@@ -356,18 +405,14 @@ public:
356405
}
357406

358407
bool _run(const StateT& seed) {
359-
mark_seed_subgoals(seed);
360-
361408
NodePT n = std::make_shared<NodeT>(seed, _generated++);
409+
mark_seed_subgoals(n);
362410

363411
auto nov =_evaluator.evaluate(*n);
364412
_unused(nov);
365413
assert(nov==1);
366414
// LPT_INFO("cout", "IW Simulation - Seed node: " << *n);
367415

368-
369-
// if (process_node(n)) return;
370-
371416
this->_open.insert(n);
372417

373418
unsigned accepted = 1; // (the root node counts as accepted)
@@ -440,6 +485,7 @@ protected:
440485
if (this->_model.goal(state, subgoal_idx)) {
441486
node->satisfies_subgoal = true;
442487
// _all_paths[subgoal_idx].push_back(node);
488+
if (!_optimal_paths[subgoal_idx]) _optimal_paths[subgoal_idx] = node;
443489
it = _unreached.erase(it);
444490
} else {
445491
++it;
@@ -456,15 +502,17 @@ protected:
456502
for (unsigned i = 0; i < this->_model.num_subgoals(); ++i) {
457503
if (!_in_seed[i] && this->_model.goal(state, i)) {
458504
node->satisfies_subgoal = true;
459-
// _all_paths[i].push_back(node);
505+
if (!_optimal_paths[i]) _optimal_paths[i] = node;
506+
_unreached.erase(i);
460507
}
461508
}
462-
return false;
509+
return _unreached.empty();
510+
//return false; // return false so we don't interrupt the processing
463511
}
464512

465-
void mark_seed_subgoals(const StateT& seed) {
513+
void mark_seed_subgoals(const NodePT& node) {
466514
for (unsigned i = 0; i < this->_model.num_subgoals(); ++i) {
467-
if (this->_model.goal(seed, i)) {
515+
if (this->_model.goal(node->state, i)) {
468516
_in_seed[i] = true;
469517
} else {
470518
_unreached.insert(i);

src/search/drivers/sbfws/sgbfs.hxx

+6-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,12 @@ public:
191191
_wgr_novelty_evaluators(),
192192
_unsat_goal_atoms_heuristic(_problem),
193193
_mark_negative_propositions(config.mark_negative_propositions),
194-
_simconfig(c.getOption<int>("sim.bound", 10000), config.complete_simulation, config.mark_negative_propositions),
194+
_simconfig(c.getOption<int>("sim.bound", 10000),
195+
config.complete_simulation,
196+
config.mark_negative_propositions,
197+
config.simulation_width,
198+
c.getOption<bool>("goal_directed", false)
199+
),
195200
_stats(stats),
196201
_aptk_rpg(nullptr),
197202
_rstype(config.relevant_set_type)

src/utils/static.cxx

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ StaticExtension::load_static_extension(const std::string& name, const ProblemInf
3232
else extension = new Arity3Function(Serializer::deserializeArity3Map(filename));
3333

3434
} else if (arity == 4) {
35-
if (type == SymbolData::Type::PREDICATE) extension = new Arity4Function(Serializer::deserializeArity4Map(filename));
36-
else extension = new Arity4Predicate(Serializer::deserializeArity4Set(filename));
35+
if (type == SymbolData::Type::PREDICATE) extension = new Arity4Predicate(Serializer::deserializeArity4Set(filename));
36+
else extension = new Arity4Function(Serializer::deserializeArity4Map(filename));
3737

3838

3939
} else WORK_IN_PROGRESS("Such high symbol arities have not yet been implemented");

0 commit comments

Comments
 (0)