Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Jan 6, 2025
1 parent e9822a7 commit 2f8c8dd
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));

std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
Expand All @@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
}

Expand Down
7 changes: 4 additions & 3 deletions gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);

// Check if graph product works
DecisionTreeFactor product = graph.product();
DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
}

Expand All @@ -117,7 +117,7 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);

// Normalize newFactor by max for comparison with expected
auto denominator = newFactor.max(newFactor.size());
auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();

newFactor = newFactor / denominator;

Expand All @@ -131,7 +131,8 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
denominator = expectedFactor.max(expectedFactor.size());
denominator =
expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
// Ensure denominator is correct.
expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor));
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/tests/testCSP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);

// Just for fun, create the product and check it
DecisionTreeFactor product = csp.product();
DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
// product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct, product));
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/tests/testScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));

// Do brute force product and output that to file
DecisionTreeFactor product = s.product();
DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
// product.dot("scheduling", false);

// Do exact inference
Expand Down

0 comments on commit 2f8c8dd

Please sign in to comment.