Skip to content

Commit

Permalink
Merge pull request #598 from robthew/master
Browse files Browse the repository at this point in the history
Update add_method to check class names
  • Loading branch information
rpiazza authored Apr 16, 2024
2 parents 4abcf3d + 68497bb commit 0efa195
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 18 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/python-ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ jobs:
run: |
tox
- name: Upload coverage information to Codecov
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v4.2.0
with:
fail_ci_if_error: true # optional (default = false)
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false # optional (default = false)
verbose: true # optional (default = false)
17 changes: 11 additions & 6 deletions stix2/datastore/relational_db/add_method.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import re
from stix2.v21.base import _STIXBase21
from stix2.datastore.relational_db.utils import get_all_subclasses
from stix2.properties import Property

# _ALLOWABLE_CLASSES = get_all_subclasses(_STIXBase21)
#
#
# _ALLOWABLE_CLASSES.extend(get_all_subclasses(Property))
_ALLOWABLE_CLASSES = get_all_subclasses(_STIXBase21)
_ALLOWABLE_CLASSES.extend(get_all_subclasses(Property))
_ALLOWABLE_CLASSES.extend([Property])


def create_real_method_name(name, klass_name):
# if klass_name not in _ALLOWABLE_CLASSES:
# raise NameError
classnames = map(lambda x: x.__name__, _ALLOWABLE_CLASSES)
if klass_name not in classnames:
raise NameError

split_up_klass_name = re.findall('[A-Z][^A-Z]*', klass_name)
return name + "_" + "_".join([x.lower() for x in split_up_klass_name])


def add_method(cls):

def decorator(fn):
method_name = fn.__name__
fn.__name__ = create_real_method_name(fn.__name__, cls.__name__)
Expand Down
14 changes: 7 additions & 7 deletions stix2/datastore/relational_db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,22 @@ def canonicalize_table_name(table_name, schema_name=None):
return inflection.underscore(full_name)


def _get_all_subclasses(cls):
def get_all_subclasses(cls):
all_subclasses = []

for subclass in cls.__subclasses__():
all_subclasses.append(subclass)
all_subclasses.extend(_get_all_subclasses(subclass))
all_subclasses.extend(get_all_subclasses(subclass))
return all_subclasses


def get_stix_object_classes():
yield from _get_all_subclasses(_DomainObject)
yield from _get_all_subclasses(_RelationshipObject)
yield from _get_all_subclasses(_Observable)
yield from _get_all_subclasses(_MetaObject)
yield from get_all_subclasses(_DomainObject)
yield from get_all_subclasses(_RelationshipObject)
yield from get_all_subclasses(_Observable)
yield from get_all_subclasses(_MetaObject)
# Non-object extensions (property or toplevel-property only)
for ext_cls in _get_all_subclasses(_Extension):
for ext_cls in get_all_subclasses(_Extension):
if ext_cls.extension_type not in (
"new-sdo", "new-sco", "new-sro",
):
Expand Down
4 changes: 1 addition & 3 deletions stix2/test/v21/test_datastore_relational_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
True,
None,
False,
False
)

# Artifacts
Expand Down Expand Up @@ -280,7 +281,6 @@ def test_multipart_email_msg():
def test_file():
file_stix_object = stix2.parse(file_dict)
store.add(file_stix_object)
read_obj = store.get(file_stix_object['id'])
read_obj = json.loads(store.get(file_stix_object['id']).serialize())

for attrib in file_dict.keys():
Expand Down Expand Up @@ -418,7 +418,6 @@ def test_network_traffic():
def test_process():
process_stix_object = stix2.parse(process_dict)
store.add(process_stix_object)
read_obj = store.get(process_stix_object['id'])
read_obj = json.loads(store.get(process_stix_object['id']).serialize())

for attrib in process_dict.keys():
Expand Down Expand Up @@ -446,7 +445,6 @@ def test_process():
def test_software():
software_stix_object = stix2.parse(software_dict)
store.add(software_stix_object)
read_obj = store.get(software_stix_object['id'])
read_obj = json.loads(store.get(software_stix_object['id']).serialize())

for attrib in software_dict.keys():
Expand Down

0 comments on commit 0efa195

Please sign in to comment.