Skip to content

Commit

Permalink
Merge pull request #379 from opencybersecurityalliance/refactoring-co…
Browse files Browse the repository at this point in the history
…mmand-20230719

move an exception handling out of commands.py
  • Loading branch information
subbyte authored Jul 26, 2023
2 parents cfcd2b4 + 4456e1f commit 2435dc4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 18 deletions.
15 changes: 1 addition & 14 deletions src/kestrel/codegen/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,8 @@ def assign(stmt, session):
qry = _transform_query(session.store, entity_table, transform)
else:
qry = Query(entity_table)

qry = _build_query(session.store, entity_table, qry, stmt, [])

try:
session.store.assign_query(stmt["output"], qry)
output = new_var(session.store, stmt["output"], [], stmt, session.symtable)
except InvalidAttr as e:
var_attr = str(e).split()[-1]
var_name, _, attr = var_attr.rpartition(".")
raise MissingEntityAttribute(var_name, attr) from e

return output, None
session.store.assign_query(stmt["output"], qry)


@_debug_logger
Expand All @@ -148,9 +138,6 @@ def merge(stmt, session):
entity_tables = [t for t in entity_tables if t is not None]
session.store.merge(stmt["output"], entity_tables)

output = new_var(session.store, stmt["output"], [], stmt, session.symtable)
return output, None


@_debug_logger
@_default_output
Expand Down
14 changes: 10 additions & 4 deletions src/kestrel/codegen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from collections import OrderedDict
from kestrel.codegen.relations import get_entity_id_attribute
from kestrel.exceptions import KestrelInternalError
from kestrel.exceptions import KestrelInternalError, MissingEntityAttribute


def gen_variable_summary(var_name, var_struct):
Expand Down Expand Up @@ -92,7 +92,13 @@ def get_variable_entity_count(variable):
entity_count = 0
if variable.entity_table:
entity_id_attr = get_entity_id_attribute(variable)
if entity_id_attr not in variable.store.columns(variable.entity_table):
return 0
entity_count = variable.store.count(variable.entity_table)
try:
columns = variable.store.columns(variable.entity_table)
except InvalidAttr as e:
# TODO: a better solution needed for tests/test_timestamped.py::test_timestamped_grouped_assign
table_attr = str(e).split()[-1]
table_name, _, attr = table_attr.rpartition(".")
raise MissingEntityAttribute(table_name, attr) from e
if entity_id_attr in columns:
entity_count = variable.store.count(variable.entity_table)
return entity_count

0 comments on commit 2435dc4

Please sign in to comment.