Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelmckinsey1 committed Jul 11, 2024
1 parent d7d9dae commit db2fd67
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions thicket/tests/test_concat_thickets.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,38 +111,56 @@ def test_filter_concat_thickets_columns(thicket_axis_columns):
filter_multiple_and(combined_th, columns_values)


def test_filter_stats_concat_thickets_columns(thicket_axis_columns):
def test_filter_stats_concat_thickets_columns(thicket_axis_columns, intersection):
thickets, thickets_cp, combined_th = thicket_axis_columns
# columns and corresponding values to filter by

columns_values = {
("test", "test_string_column"): ["less than 20"],
("test", "test_numeric_column"): [4, 15],
}
# set string column values
less_than_20 = ["less than 20"] * 21
less_than_45 = ["less than 45"] * 25
less_than_178 = ["less than 75"] * 28
new_col = less_than_20 + less_than_45 + less_than_178
combined_th.statsframe.dataframe[("test", "test_string_column")] = new_col
# set numeric column values
combined_th.statsframe.dataframe[("test", "test_numeric_column")] = range(0, 74)

if intersection:
less_than_65 = ["less than 65"] * 18
new_col = less_than_20 + less_than_45 + less_than_65
combined_th.statsframe.dataframe[("test", "test_string_column")] = new_col
# set numeric column values
combined_th.statsframe.dataframe[("test", "test_numeric_column")] = range(0, 64)
else:
less_than_75 = ["less than 75"] * 28
new_col = less_than_20 + less_than_45 + less_than_75
combined_th.statsframe.dataframe[("test", "test_string_column")] = new_col
# set numeric column values
combined_th.statsframe.dataframe[("test", "test_numeric_column")] = range(0, 74)

check_filter_stats(combined_th, columns_values)


def test_query_concat_thickets_columns(thicket_axis_columns):
def test_query_concat_thickets_columns(thicket_axis_columns, intersection):
thickets, thickets_cp, combined_th = thicket_axis_columns
# test arguments
hnids = [
0,
1,
2,
3,
4,
5,
6,
7,
] # "0" because top-level node "RAJAPerf" will be included in query result.
if intersection:
# Shorter graph for intersection
hnids = [
0,
1,
2,
3,
4,
]
else:
hnids = [
0,
1,
2,
3,
4,
5,
6,
7,
] # "0" because top-level node "RAJAPerf" will be included in query result.
query = (
ht.QueryMatcher()
.match("*")
Expand Down

0 comments on commit db2fd67

Please sign in to comment.