Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aggregate clause #1271

Merged
merged 16 commits into from
Nov 3, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package org.opencds.cqf.cql.engine.elm.executing;

import org.cqframework.cql.elm.visiting.ElmLibraryVisitor;
import org.hl7.elm.r1.AggregateClause;
import org.opencds.cqf.cql.engine.exception.CqlException;
import org.opencds.cqf.cql.engine.execution.State;
import org.opencds.cqf.cql.engine.execution.Variable;
import org.opencds.cqf.cql.engine.runtime.Tuple;

import java.util.List;
import java.util.Objects;

/*
CQL provides support for a limited class of recursive problems
using the aggregate clause of the query construct.
This clause is similar in function to the JavaScript .reduce() function,
in that it allows an expression to be repeatedly evaluated for each element of a list,
and that expression can access the current value of the aggregation.

https://cql.hl7.org/03-developersguide.html#aggregate-queries
*/

public class AggregateClauseEvaluator {

public static Object aggregate(AggregateClause elm, State state, ElmLibraryVisitor<Object, State> visitor, List<Object> elements) {
Objects.requireNonNull(elm, "elm can not be null");
Objects.requireNonNull(elements, "elements can not be null");
Objects.requireNonNull(state, "state can not be null");

if (elm.isDistinct()) {
elements = DistinctEvaluator.distinct(elements, state);
}

Object aggregatedValue = null;
if (elm.getStarting() != null) {
aggregatedValue = visitor.visitExpression(elm.getStarting(), state);
}

for(var e : elements) {
state.push(new Variable().withName(elm.getIdentifier()).withValue(aggregatedValue));
int pushes = 1;
if (!(e instanceof Tuple)) {
throw new CqlException("expected aggregation source to be a Tuple");
}
var tuple = (Tuple)e;
for (var p : tuple.getElements().entrySet()) {
state.push(new Variable().withName(p.getKey()).withValue(p.getValue()));
pushes++;
}

try {
aggregatedValue = visitor.visitExpression(elm.getExpression(), state);
}
finally {
while(pushes > 0) {
state.pop();
pushes--;
}
}
}

return aggregatedValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

public class FunctionRefEvaluator {

private static final Logger logger =LoggerFactory.getLogger(FunctionRefEvaluator.class);
private static final Logger logger = LoggerFactory.getLogger(FunctionRefEvaluator.class);

public static Object internalEvaluate(FunctionRef functionRef, State state, ElmLibraryVisitor<Object,State> visitor) {
ArrayList<Object> arguments = new ArrayList<>(functionRef.getOperand().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import org.cqframework.cql.elm.visiting.ElmLibraryVisitor;
import org.hl7.elm.r1.*;
import org.opencds.cqf.cql.engine.exception.CqlException;
import org.opencds.cqf.cql.engine.execution.State;
import org.opencds.cqf.cql.engine.execution.Variable;
import org.opencds.cqf.cql.engine.runtime.CqlList;
import org.opencds.cqf.cql.engine.runtime.Tuple;
import org.opencds.cqf.cql.engine.runtime.iterators.QueryIterator;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
Expand All @@ -27,26 +29,30 @@ public static Iterable<Object> ensureIterable(Object source) {
}
}

private static void evaluateLets(Query elm, State state, List<Variable> letVariables, ElmLibraryVisitor<Object, State> visitor) {
private static void evaluateLets(Query elm, State state, List<Variable> letVariables,
ElmLibraryVisitor<Object, State> visitor) {
for (int i = 0; i < elm.getLet().size(); i++) {
letVariables.get(i).setValue(visitor.visitExpression(elm.getLet().get(i).getExpression(), state));
}
}

private static boolean evaluateRelationships(Query elm, State state, ElmLibraryVisitor<Object, State> visitor) {
// TODO: This is the most naive possible implementation here, but it should perform okay with 1) caching and 2) small data sets
// TODO: This is the most naive possible implementation here, but it should
// perform okay with 1) caching and 2) small data sets
boolean shouldInclude = true;
for (org.hl7.elm.r1.RelationshipClause relationship : elm.getRelationship()) {
boolean hasSatisfyingData = false;
Iterable<Object> relatedSourceData = ensureIterable(visitor.visitExpression(relationship.getExpression(), state));
Iterable<Object> relatedSourceData = ensureIterable(
visitor.visitExpression(relationship.getExpression(), state));
for (Object relatedElement : relatedSourceData) {
state.push(new Variable().withName(relationship.getAlias()).withValue(relatedElement));
try {
Object satisfiesRelatedCondition = visitor.visitExpression(relationship.getSuchThat(), state);
if ((relationship instanceof org.hl7.elm.r1.With
|| relationship instanceof org.hl7.elm.r1.Without) && Boolean.TRUE.equals(satisfiesRelatedCondition)) {
hasSatisfyingData = true;
break; // Once we have detected satisfying data, no need to continue testing
|| relationship instanceof org.hl7.elm.r1.Without)
&& Boolean.TRUE.equals(satisfiesRelatedCondition)) {
hasSatisfyingData = true;
break; // Once we have detected satisfying data, no need to continue testing
}
} finally {
state.pop();
Expand All @@ -56,7 +62,8 @@ private static boolean evaluateRelationships(Query elm, State state, ElmLibraryV
if ((relationship instanceof org.hl7.elm.r1.With && !hasSatisfyingData)
|| (relationship instanceof org.hl7.elm.r1.Without && hasSatisfyingData)) {
shouldInclude = false;
break; // Once we have determined the row should not be included, no need to continue testing other related information
break; // Once we have determined the row should not be included, no need to continue
// testing other related information
}
}

Expand All @@ -74,24 +81,35 @@ private static boolean evaluateWhere(Query elm, State state, ElmLibraryVisitor<O
return true;
}

private static Object evaluateReturn(Query elm, State state, List<Variable> variables, List<Object> elements, ElmLibraryVisitor<Object, State> visitor) {
return elm.getReturn() != null ? visitor.visitExpression(elm.getReturn().getExpression(), state) : constructResult(state, variables, elements);
private static Object evaluateReturn(Query elm, State state, List<Variable> variables, List<Object> elements,
ElmLibraryVisitor<Object, State> visitor) {
return elm.getReturn() != null ? visitor.visitExpression(elm.getReturn().getExpression(), state)
: constructResult(state, variables, elements);
}

private static List<Object> evaluateAggregate(AggregateClause elm, State state, ElmLibraryVisitor<Object, State> visitor, List<Object> elements) {
return Collections.singletonList(AggregateClauseEvaluator.aggregate(elm, state, visitor, elements));
}

private static Object constructTuple(State state, List<Variable> variables, List<Object> elements) {
LinkedHashMap<String, Object> elementMap = new LinkedHashMap<>();
for (int i = 0; i < variables.size(); i++) {
JPercival marked this conversation as resolved.
Show resolved Hide resolved
elementMap.put(variables.get(i).getName(), variables.get(i).getValue());
}

return new Tuple(state).withElements(elementMap);
}

private static Object constructResult(State state, List<Variable> variables, List<Object> elements) {
if (variables.size() > 1) {
LinkedHashMap<String, Object> elementMap = new LinkedHashMap<>();
for (int i = 0; i < variables.size(); i++) {
elementMap.put(variables.get(i).getName(), variables.get(i).getValue());
}

return new Tuple(state).withElements(elementMap);
return constructTuple(state, variables, elements);
}

return elements.get(0);
}

public static void sortResult(Query elm, List<Object> result, State state, String alias, ElmLibraryVisitor<Object, State> visitor) {
public static void sortResult(Query elm, List<Object> result, State state, String alias,
ElmLibraryVisitor<Object, State> visitor) {

SortClause sortClause = elm.getSort();

Expand All @@ -100,7 +118,8 @@ public static void sortResult(Query elm, List<Object> result, State state, Strin
for (SortByItem byItem : sortClause.getBy()) {

if (byItem instanceof ByExpression) {
result.sort(new CqlList(state, visitor, alias, ((ByExpression) byItem).getExpression()).expressionSort);
result.sort(
new CqlList(state, visitor, alias, ((ByExpression) byItem).getExpression()).expressionSort);
} else if (byItem instanceof ByColumn) {
result.sort(new CqlList(state, ((ByColumn) byItem).getPath()).columnSort);
} else {
Expand Down Expand Up @@ -141,6 +160,9 @@ public Iterable<Object> getData() {

@SuppressWarnings("unchecked")
public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor<Object, State> visitor) {
if (elm.getAggregate() != null && elm.getReturn() != null) {
throw new CqlException("aggregate and return are mutually exclusive");
}

var sources = new ArrayList<Iterator<Object>>();
var variables = new ArrayList<Variable>();
Expand Down Expand Up @@ -188,7 +210,13 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor<
continue;
}

result.add(evaluateReturn(elm, state, variables, elements, visitor));
if (elm.getAggregate() != null) {
result.add(constructTuple(state, variables, elements));
}
else {
result.add(evaluateReturn(elm, state, variables, elements, visitor));
}

}
} finally {
while (pushCount > 0) {
Expand All @@ -201,13 +229,17 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor<
result = DistinctEvaluator.distinct(result, state);
}

if (elm.getAggregate() != null) {
result = evaluateAggregate(elm.getAggregate(), state, visitor, result);
}

sortResult(elm, result, state, null, visitor);

if ((result == null || result.isEmpty()) && !sourceIsList) {
return null;
}

return sourceIsList ? result : result.get(0);
return elm.getAggregate() != null || !sourceIsList ? result.get(0) : result;
}

private static void assignVariables(List<Variable> variables, List<Object> elements) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ public class CqlAggregateFunctionsTest extends CqlTestBase {

@Test
public void test_all_aggregate_function_tests() {
EvaluationResult evaluationResult;

evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateFunctionsTest"));
var evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateFunctionsTest"));
Object result = evaluationResult.forExpression("AllTrueAllTrue").value();
assertThat(result, is(true));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opencds.cqf.cql.engine.execution;

import org.testng.annotations.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

public class CqlAggregateQueryTest extends CqlTestBase {
@Test
void test_all_aggregate_clause_tests() {
var evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateQueryTest"));
var result = evaluationResult.forExpression("AggregateSumWithStart").value();
assertThat(result, is(16));

result = evaluationResult.forExpression("AggregateSumWithNull").value();
assertThat(result, is(15));

result = evaluationResult.forExpression("AggregateSumAll").value();
assertThat(result, is(24));

result = evaluationResult.forExpression("AggregateSumDistinct").value();
assertThat(result, is(15));

result = evaluationResult.forExpression("Multi").value();
assertThat(result, is(6));

result = evaluationResult.forExpression("MegaMulti").value();
assertThat(result, is(36));

result = evaluationResult.forExpression("MegaMultiDistinct").value();
assertThat(result, is(37));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ define SumTestQuantity: Sum({1 'ml',2 'ml',3 'ml',4 'ml',5 'ml'})
define SumTestNull: Sum({ null, 1, null })

//Variance
define VarianceTest1: Variance({ 1.0, 2.0, 3.0, 4.0, 5.0 })
define VarianceTest1: Variance({ 1.0, 2.0, 3.0, 4.0, 5.0 })
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
library CqlAggregateQueryTest

//Aggregate clause
define AggregateSumWithStart:
({ 1, 2, 3, 4, 5 }) Num
aggregate Result starting 1: Result + Num // 15 + 1 (the initial value)

define AggregateSumWithNull:
({ 1, 2, 3, 4, 5 }) Num
aggregate Result: Coalesce(Result, 0) + Num // 15 + 0 (the initial value from null)

define AggregateSumAll:
({ 1, 1, 2, 2, 2, 3, 4, 4, 5 }) Num
aggregate all Result: Coalesce(Result, 0) + Num // 24 + 0

define AggregateSumDistinct:
({ 1, 1, 2, 2, 2, 3, 4, 4, 5 }) Num
aggregate distinct Result: Coalesce(Result, 0) + Num // 15 + 0 (the initial value)


define First: {1}
define Second: {2}
define Third: {3}

define Multi:
from First X, Second Y, Third Z
aggregate Agg: Coalesce(Agg, 0) + X + Y + Z // 6

define "A": {1, 2}
define "B": {1, 2}
define "C": {1, 2}

define MegaMulti:
from "A" X, "B" Y, "C" Z
aggregate Agg starting 0: Agg + X + Y + Z // 36 -- (1+1+1)+(1+1+2)+(1+2+1)+(1+2+2)+(2+1+1)+(2+1+2)+(2+2+1)+(2+2+2)


define "1": {1, 2, 2, 1}
define "2": {1, 2, 1, 2}
define "3": {2, 1, 2, 1}

define MegaMultiDistinct:
from "1" X, "2" Y, "3" Z
aggregate distinct Agg starting 1: Agg + X + Y + Z // 37 -- 1 + (1+1+1)+(1+1+2)+(1+2+1)+(1+2+2)+(2+1+1)+(2+1+2)+(2+2+1)+(2+2+2)