Skip to content

Commit

Permalink
finished first version of DNNF projected model enumeration
Browse files Browse the repository at this point in the history
  • Loading branch information
SHildebrandt committed Jan 22, 2024
1 parent 66bec63 commit ef1a693
Show file tree
Hide file tree
Showing 5 changed files with 339 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
import org.logicng.formulas.Variable;
import org.logicng.knowledgecompilation.dnnf.functions.DnnfFunction;

import java.util.Objects;
import java.util.SortedSet;

/**
* A DNNF - Decomposable Negation Normal Form.
* @version 2.0.0
* @since 2.0.0
* @since 2.5.0
*/
public final class Dnnf {

Expand Down Expand Up @@ -79,4 +80,21 @@ public Formula formula() {
public SortedSet<Variable> getOriginalVariables() {
return this.originalVariables;
}

@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final Dnnf dnnf = (Dnnf) o;
return Objects.equals(this.originalVariables, dnnf.originalVariables) && Objects.equals(this.formula, dnnf.formula);
}

@Override
public int hashCode() {
return Objects.hash(this.originalVariables, this.formula);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,103 +29,151 @@
package org.logicng.knowledgecompilation.dnnf.functions;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;
import static java.util.Collections.singleton;

import org.logicng.datastructures.Assignment;
import org.logicng.formulas.Formula;
import org.logicng.formulas.Literal;
import org.logicng.formulas.Variable;
import org.logicng.knowledgecompilation.dnnf.datastructures.Dnnf;
import org.logicng.util.CollectionHelper;
import org.logicng.util.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
* A DNNF function which counts models.
* A DNNF function for (projected) model enumeration.
* <p>
* This function does not require a deterministic DNNF.
* <p>
* The enumeration requires a set of variables which the formula is projected
* to. Variables which are not known by the DNNF will be added negatively to
* all models.
* <p>
* Note that an enumeration over {@code n} variables can have up to {@code 2^n}
* models, so the number of variables must be chosen carefully.
* @version 2.5.0
* @since 2.5.0
*/
public final class DnnfModelEnumerationFunction implements DnnfFunction<List<Assignment>> {

private final SortedSet<Variable> variables;

/**
* Creates a new DNNF model enumeration function over the given set of
* variables.
* @param variables the variables, must not be {@code null}
*/
public DnnfModelEnumerationFunction(final Collection<Variable> variables) {
this.variables = variables == null ? null : new TreeSet<>(variables);
this.variables = new TreeSet<>(variables);
}

@Override
public List<Assignment> apply(final SortedSet<Variable> originalVariables, final Formula formula) {
final Formula projectedFormula = this.variables == null ? formula : new DnnfProjectionFunction(this.variables).apply(originalVariables, formula).formula();
final Set<Set<Literal>> partialModels = applyRec(projectedFormula);
final SortedSet<Variable> globalDontCares = originalVariables.stream()
.filter(v -> !formula.variables().contains(v) && (this.variables == null || this.variables.contains(v)))
final Dnnf projectedDnnf = this.variables.equals(originalVariables)
? new Dnnf(this.variables, formula)
: new DnnfProjectionFunction(this.variables).apply(originalVariables, formula);
final Formula projectedFormula = projectedDnnf.formula();
final SortedSet<Variable> projectedVariables = projectedDnnf.getOriginalVariables();
final Pair<List<Variable>, Map<Variable, Integer>> mapping = createVarMapping(projectedVariables);
final List<Variable> intToVar = mapping.first();
final Map<Variable, Integer> varToInt = mapping.second();

final Set<BitSet> partialModels = computePartialModels(projectedFormula, varToInt);

final SortedSet<Variable> globalDontCares = projectedVariables.stream()
.filter(v -> !projectedFormula.variables().contains(v))
.collect(Collectors.toCollection(TreeSet::new));
final List<Variable> invalidVars = this.variables == null
? emptyList()
: this.variables.stream().filter(v -> !originalVariables.contains(v)).collect(Collectors.toList());
final Set<Set<Literal>> partialDontCareModels = getCartesianProduct(globalDontCares);
final Set<Set<Literal>> partialModelsWithGlobalDontCares = combineDisjointModels(Arrays.asList(partialModels, partialDontCareModels));
final Set<Set<Literal>> result = expandModelsWithMissingVars(partialModelsWithGlobalDontCares, this.variables);
addInvalidVars(result, invalidVars);
final List<Assignment> resultAssignments = new ArrayList<>();
for (final Set<Literal> model : result) {
resultAssignments.add(new Assignment(model));
}
return resultAssignments;
final Set<BitSet> partialDontCareModels = getCartesianProduct(globalDontCares.stream().map(varToInt::get).collect(Collectors.toList()), projectedVariables.size());
final Set<BitSet> partialModelsWithGlobalDontCares = combineDisjointModels(Arrays.asList(partialModels, partialDontCareModels));

final Set<BitSet> expandedModels = expandModelsWithMissingVars(partialModelsWithGlobalDontCares, projectedVariables.size());

final SortedSet<Variable> invalidVars = CollectionHelper.difference(this.variables, originalVariables, TreeSet::new);
return translateBitSetsToAssignments(expandedModels, intToVar, invalidVars);
}

private Set<Set<Literal>> applyRec(final Formula formula) {
private static Set<BitSet> computePartialModels(final Formula formula, final Map<Variable, Integer> varToInt) {
switch (formula.type()) {
case FALSE:
return new HashSet<>();
case LITERAL:
case TRUE:
return singleton(formula.literals());
return singleton(new BitSet(varToInt.size() * 2));
case LITERAL:
final BitSet newBitSet = new BitSet(varToInt.size() * 2);
for (final Literal lit : formula.literals()) {
newBitSet.set(2 * varToInt.get(lit.variable()) + (lit.phase() ? 0 : 1));
}
return singleton(newBitSet);
case AND:
final List<Set<Set<Literal>>> opResults = new ArrayList<>();
final List<Set<BitSet>> opResults = new ArrayList<>();
for (final Formula op : formula) {
opResults.add(applyRec(op));
opResults.add(computePartialModels(op, varToInt));
}
return combineDisjointModels(opResults);
case OR:
final Set<Set<Literal>> allModels = new HashSet<>();
final Set<BitSet> allModels = new HashSet<>();
for (final Formula op : formula) {
allModels.addAll(applyRec(op));
allModels.addAll(computePartialModels(op, varToInt));
}
return allModels;
default:
throw new IllegalArgumentException("Unexpected formula type: " + formula.type());
}
}

private Set<Set<Literal>> expandModelsWithMissingVars(final Set<Set<Literal>> partialModels, final SortedSet<Variable> allVariables) {
final Set<Set<Literal>> result = new HashSet<>();
for (final Set<Literal> partialModel : partialModels) {
final Set<Variable> missingVariables = new HashSet<>(allVariables);
for (final Literal lit : partialModel) {
missingVariables.remove(lit.variable());
}
private static Set<BitSet> expandModelsWithMissingVars(final Set<BitSet> partialModels, final int numVars) {
final Set<BitSet> result = new HashSet<>();
for (final BitSet partialModel : partialModels) {
final List<Integer> missingVariables = findMissingVars(partialModel, numVars);
if (missingVariables.isEmpty()) {
result.add(partialModel);
} else {
result.addAll(combineDisjointModels(Arrays.asList(singleton(partialModel), getCartesianProduct(missingVariables))));
result.addAll(combineDisjointModels(Arrays.asList(singleton(partialModel), getCartesianProduct(missingVariables, numVars))));
}
}
return result;
}

private static void addInvalidVars(final Set<Set<Literal>> models, final List<Variable> invalidVars) {
final List<Literal> negated = invalidVars.stream().map(Variable::negate).collect(Collectors.toList());
for (final Set<Literal> model : models) {
model.addAll(negated);
private static List<Integer> findMissingVars(final BitSet partialModel, final int numVars) {
final int cardinality = partialModel.cardinality();
if (cardinality == numVars) {
return emptyList();
}
final List<Integer> missing = new ArrayList<>(numVars - cardinality);
for (int i = 0; i < numVars; i++) {
if (!partialModel.get(2 * i) && !partialModel.get(2 * i + 1)) {
missing.add(i);
}
}
return missing;
}

private static List<Assignment> translateBitSetsToAssignments(final Set<BitSet> models, final List<Variable> intToVar, final SortedSet<Variable> invalidVars) {
final List<Assignment> result = new ArrayList<>(models.size());
for (final BitSet model : models) {
final Assignment assignment = new Assignment();
for (int i = model.nextSetBit(0); i >= 0; i = model.nextSetBit(i + 1)) {
final Variable variable = intToVar.get(i / 2);
assignment.addLiteral(i % 2 == 0 ? variable : variable.negate());
}
for (final Variable invalidVar : invalidVars) {
assignment.addLiteral(invalidVar.negate());
}
result.add(assignment);
}
return result;
}

/**
Expand All @@ -141,51 +189,59 @@ private static void addInvalidVars(final Set<Set<Literal>> models, final List<Va
* @param modelLists the list of sets of models
* @return the combined model-set list
*/
private static Set<Set<Literal>> combineDisjointModels(final List<Set<Set<Literal>>> modelLists) {
Set<Set<Literal>> currentModels = modelLists.get(0);
private static Set<BitSet> combineDisjointModels(final List<Set<BitSet>> modelLists) {
Set<BitSet> currentModels = modelLists.get(0);
for (int i = 1; i < modelLists.size(); i++) {
final Set<Set<Literal>> additionalModels = modelLists.get(i);
final Set<Set<Literal>> newModels = new HashSet<>();
for (final Set<Literal> currentModel : currentModels) {
for (final Set<Literal> additionalModel : additionalModels) {
newModels.add(setAdd(currentModel, additionalModel));
final Set<BitSet> additionalModels = modelLists.get(i);
final Set<BitSet> newModels = new HashSet<>();
for (final BitSet currentModel : currentModels) {
for (final BitSet additionalModel : additionalModels) {
newModels.add(combineBitSets(currentModel, additionalModel));
}
}
currentModels = newModels;
}
return currentModels;
}

private static BitSet combineBitSets(final BitSet currentModel, final BitSet additionalModel) {
final BitSet copiedBitSet = (BitSet) currentModel.clone();
copiedBitSet.or(additionalModel);
return copiedBitSet;
}

/**
* Returns the Cartesian product for the given variables, i.e. all
* combinations of literals are generated with each variable occurring
* positively and negatively.
* @param variables the variables, must not be {@code null}
* @return the Cartesian product
*/
private static Set<Set<Literal>> getCartesianProduct(final Collection<Variable> variables) {
Set<Set<Literal>> result = singleton(emptySet());
for (final Variable var : variables) {
final Set<Set<Literal>> extended = new HashSet<>(result.size() * 2);
for (final Set<Literal> literals : result) {
extended.add(extendedByLiteral(literals, var));
extended.add(extendedByLiteral(literals, var.negate()));
private static Set<BitSet> getCartesianProduct(final Collection<Integer> variables, final int numVars) {
Set<BitSet> result = singleton(new BitSet(numVars * 2));
for (final int var : variables) {
final Set<BitSet> extended = new HashSet<>(result.size() * 2);
for (final BitSet current : result) {
extended.add(extendedByLiteral(current, 2 * var));
extended.add(extendedByLiteral(current, 2 * var + 1));
}
result = extended;
}
return result;
}

private static Set<Literal> extendedByLiteral(final Set<Literal> literals, final Literal lit) {
final Set<Literal> extended = new HashSet<>(literals);
extended.add(lit);
private static BitSet extendedByLiteral(final BitSet current, final int lit) {
final BitSet extended = (BitSet) current.clone();
extended.set(lit);
return extended;
}

private static Set<Literal> setAdd(final Set<Literal> first, final Set<Literal> second) {
final Set<Literal> result = new HashSet<>();
result.addAll(first);
result.addAll(second);
return result;
private static Pair<List<Variable>, Map<Variable, Integer>> createVarMapping(final SortedSet<Variable> originalVariables) {
final List<Variable> intToVariable = new ArrayList<>(originalVariables);
final Map<Variable, Integer> variableToInt = new HashMap<>();
for (int i = 0; i < intToVariable.size(); i++) {
variableToInt.put(intToVariable.get(i), i);
}
return new Pair<>(intToVariable, variableToInt);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,52 +32,75 @@
import org.logicng.formulas.Literal;
import org.logicng.formulas.Variable;
import org.logicng.knowledgecompilation.dnnf.datastructures.Dnnf;
import org.logicng.util.CollectionHelper;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;

/**
* A DNNF function which counts models.
* A DNNF function which projects the DNNF to a given set of variables.
* <p>
* Not the resulting DNNF does <b>NOT</b> have deterministic property, and
* thus, it must not be used for model counting or other DNNF operations
* requiring a d-DNNF.
* @version 2.5.0
* @since 2.5.0
*/
public final class DnnfProjectionFunction implements DnnfFunction<Dnnf> {

private final SortedSet<Variable> variables;

/**
* Creates a new DNNF function to project a DNNF to the given set of
* variables.
* @param variables the variables to project to
*/
public DnnfProjectionFunction(final Collection<Variable> variables) {
this.variables = new TreeSet<>(variables);
}

@Override
public Dnnf apply(final SortedSet<Variable> originalVariables, final Formula formula) {
return new Dnnf(originalVariables, applyRec(formula));
return new Dnnf(CollectionHelper.intersection(this.variables, originalVariables, TreeSet::new), applyRec(formula, new HashMap<>()));
}

private Formula applyRec(final Formula formula) {
private Formula applyRec(final Formula formula, final Map<Formula, Formula> cache) {
final Formula cached = cache.get(formula);
if (cached != null) {
return cached;
}
final Formula result;
switch (formula.type()) {
case TRUE:
case FALSE:
return formula;
result = formula;
break;
case LITERAL:
return this.variables.contains(((Literal) formula).variable()) ? formula : formula.factory().verum();
result = this.variables.contains(((Literal) formula).variable()) ? formula : formula.factory().verum();
break;
case OR:
final List<Formula> newOrOps = new ArrayList<>();
for (final Formula op : formula) {
newOrOps.add(applyRec(op));
newOrOps.add(applyRec(op, cache));
}
return formula.factory().or(newOrOps);
result = formula.factory().or(newOrOps);
break;
case AND:
final List<Formula> newAndOps = new ArrayList<>();
for (final Formula op : formula) {
newAndOps.add(applyRec(op));
newAndOps.add(applyRec(op, cache));
}
return formula.factory().and(newAndOps);
result = formula.factory().and(newAndOps);
break;
default:
throw new IllegalArgumentException("Unexpected formula type: " + formula.type());
}
cache.put(formula, result);
return result;
}
}
Loading

0 comments on commit ef1a693

Please sign in to comment.