diff --git a/.github/workflows/python-ci-tests.yml b/.github/workflows/python-ci-tests.yml index 47f860e1..2dc51d85 100644 --- a/.github/workflows/python-ci-tests.yml +++ b/.github/workflows/python-ci-tests.yml @@ -3,10 +3,31 @@ name: cti-python-stix2 test harness on: [push, pull_request] -jobs: - build: +env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres +jobs: + test-job: runs-on: ubuntu-latest + + services: + postgres: + image: postgres:11 + # Provide the password for postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + ports: [ '5432:5432' ] + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + strategy: matrix: python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c39aaf6d..35dff666 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,22 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.4.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: check-merge-conflict - repo: https://github.com/asottile/add-trailing-comma - rev: v2.0.2 + rev: v3.1.0 hooks: - id: add-trailing-comma - repo: https://github.com/PyCQA/flake8 - rev: 3.8.4 + rev: 7.0.0 hooks: - id: flake8 name: Check project styling args: - --max-line-length=160 - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: Sort python imports (shows diff) diff --git a/setup.py b/setup.py index 80edabb4..f8048f58 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ def get_long_description(): 'requests', 'simplejson', 'stix2-patterns>=1.2.0', + 'inflection', ], project_urls={ 'Documentation': 'https://stix2.readthedocs.io/', @@ -61,5 +62,10 @@ def get_long_description(): extras_require={ 'taxii': ['taxii2-client>=2.3.0'], 'semantic': ['haversine', 'rapidfuzz'], + 'relationaldb': [ + 'sqlalchemy', + 'sqlalchemy_utils', + 'psycopg2', + ], }, ) diff --git a/stix2/base.py b/stix2/base.py index 3ff01d98..7d97b8a4 100644 --- a/stix2/base.py +++ b/stix2/base.py @@ -470,6 +470,10 @@ def _check_object_constraints(self): self._check_at_least_one_property() +class _MetaObject(_STIXBase): + pass + + def _choose_one_hash(hash_dict): if "MD5" in hash_dict: return {"MD5": hash_dict["MD5"]} diff --git a/stix2/custom.py b/stix2/custom.py index 44a03493..cef590d3 100644 --- a/stix2/custom.py +++ b/stix2/custom.py @@ -89,7 +89,7 @@ def __init__(self, **kwargs): return _CustomObservable -def _custom_extension_builder(cls, type, properties, version, base_class): +def _custom_extension_builder(cls, applies_to, type, properties, version, base_class): properties = _get_properties_dict(properties) toplevel_properties = None @@ -98,6 +98,7 @@ def _custom_extension_builder(cls, type, properties, version, base_class): # it exists. How to treat the other properties which were given depends on # the extension type. extension_type = getattr(cls, "extension_type", None) + applies_to = applies_to if extension_type: # I suppose I could also go with a plain string property, since the # value is fixed... 
but an enum property seems more true to the
@@ -128,6 +129,7 @@ class _CustomExtension(cls, base_class):
 
         _type = type
         _properties = nested_properties
+        _applies_to = applies_to
 
         if extension_type == "toplevel-property-extension":
             _toplevel_properties = toplevel_properties
diff --git a/stix2/datastore/relational_db/add_method.py b/stix2/datastore/relational_db/add_method.py
new file mode 100644
index 00000000..61d4b3ff
--- /dev/null
+++ b/stix2/datastore/relational_db/add_method.py
@@ -0,0 +1,28 @@
+import re
+
+from stix2.datastore.relational_db.utils import get_all_subclasses
+from stix2.properties import Property
+from stix2.v21.base import _STIXBase21
+
+_ALLOWABLE_CLASSES = get_all_subclasses(_STIXBase21)
+_ALLOWABLE_CLASSES.extend(get_all_subclasses(Property))
+_ALLOWABLE_CLASSES.extend([Property])
+
+
+def create_real_method_name(name, klass_name):
+    classnames = map(lambda x: x.__name__, _ALLOWABLE_CLASSES)
+    if klass_name not in classnames:
+        raise NameError(f"{klass_name} is not a _STIXBase21 or Property subclass")
+
+    split_up_klass_name = re.findall('[A-Z][^A-Z]*', klass_name)
+    return name + "_" + "_".join([x.lower() for x in split_up_klass_name])
+
+
+def add_method(cls):
+
+    def decorator(fn):
+        method_name = fn.__name__
+        fn.__name__ = create_real_method_name(fn.__name__, cls.__name__)
+        setattr(cls, method_name, fn)
+        return fn
+    return decorator
diff --git a/stix2/datastore/relational_db/database_backends/database_backend_base.py b/stix2/datastore/relational_db/database_backends/database_backend_base.py
new file mode 100644
index 00000000..bb03f55f
--- /dev/null
+++ b/stix2/datastore/relational_db/database_backends/database_backend_base.py
@@ -0,0 +1,114 @@
+from typing import Any
+
+from sqlalchemy import Boolean, Float, Integer, Text, create_engine
+from sqlalchemy_utils import create_database, database_exists, drop_database
+
+from stix2.base import (
+    _DomainObject, _MetaObject, _Observable, _RelationshipObject,
+)
+
+
+class DatabaseBackend:
+    def __init__(self, database_connection_url, force_recreate=False, **kwargs: Any):
+        self.database_connection_url = database_connection_url
+        self.database_exists = database_exists(database_connection_url)
+
+        if force_recreate:
+            if self.database_exists:
+                drop_database(database_connection_url)
+            create_database(database_connection_url)
+            self.database_exists = database_exists(database_connection_url)
+
+        self.database_connection = create_engine(database_connection_url)
+
+    def _create_schemas(self):
+        pass
+
+    @staticmethod
+    def determine_schema_name(stix_object):
+        return ""
+
+    @staticmethod
+    def determine_stix_type(stix_object):
+        if isinstance(stix_object, _DomainObject):
+            return "sdo"
+        elif isinstance(stix_object, _Observable):
+            return "sco"
+        elif isinstance(stix_object, _RelationshipObject):
+            return "sro"
+        elif isinstance(stix_object, _MetaObject):
+            return "common"
+
+    def _create_database(self):
+        if self.database_exists:
+            drop_database(self.database_connection.url)
+        create_database(self.database_connection.url)
+        self.database_exists = database_exists(self.database_connection.url)
+
+    @staticmethod
+    def schema_for(stix_class):
+        return ""
+
+    @staticmethod
+    def schema_for_core():
+        return ""
+
+    # Subclasses must implement the next four methods.
+
+    @staticmethod
+    def determine_sql_type_for_property():  # noqa: F811
+        pass
+
+    @staticmethod
+    def determine_sql_type_for_binary_property():  # noqa: F811
+        pass
+
+    @staticmethod
+    def determine_sql_type_for_hex_property():  # noqa: F811
+        pass
+
+    @staticmethod
+    def determine_sql_type_for_timestamp_property():  # noqa: F811
+        pass
+ + @staticmethod + def determine_sql_type_for_kill_chain_phase(): # noqa: F811 + return None + + @staticmethod + def determine_sql_type_for_boolean_property(): # noqa: F811 + return Boolean + + @staticmethod + def determine_sql_type_for_float_property(): # noqa: F811 + return Float + + @staticmethod + def determine_sql_type_for_integer_property(): # noqa: F811 + return Integer + + @staticmethod + def determine_sql_type_for_reference_property(): # noqa: F811 + return Text + + @staticmethod + def determine_sql_type_for_string_property(): # noqa: F811 + return Text + + @staticmethod + def determine_sql_type_for_key_as_int(): # noqa: F811 + return Integer + + @staticmethod + def determine_sql_type_for_key_as_id(): # noqa: F811 + return Text + + @staticmethod + def array_allowed(): + return False + + def generate_value(self, stix_type, value): + sql_type = stix_type.determine_sql_type(self) + if sql_type == self.determine_sql_type_for_string_property(): + return value + elif sql_type == self.determine_sql_type_for_hex_property(): + return bytes.fromhex(value) diff --git a/stix2/datastore/relational_db/database_backends/postgres_backend.py b/stix2/datastore/relational_db/database_backends/postgres_backend.py new file mode 100644 index 00000000..ca501dfb --- /dev/null +++ b/stix2/datastore/relational_db/database_backends/postgres_backend.py @@ -0,0 +1,66 @@ +import os +from typing import Any + +from sqlalchemy import TIMESTAMP, LargeBinary, Text +from sqlalchemy.schema import CreateSchema + +from stix2.base import ( + _DomainObject, _MetaObject, _Observable, _RelationshipObject, +) +from stix2.datastore.relational_db.utils import schema_for + +from .database_backend_base import DatabaseBackend + + +class PostgresBackend(DatabaseBackend): + default_database_connection_url = \ + f"postgresql://{os.getenv('POSTGRES_USER', 'postgres')}:" + \ + f"{os.getenv('POSTGRES_PASSWORD', 'postgres')}@" + \ + f"{os.getenv('POSTGRES_IP_ADDRESS', '0.0.0.0')}:" + \ + f"{os.getenv('POSTGRES_PORT', '5432')}/postgres" + + def __init__(self, database_connection_url=default_database_connection_url, force_recreate=False, **kwargs: Any): + super().__init__(database_connection_url, force_recreate=force_recreate, **kwargs) + + def _create_schemas(self): + with self.database_connection.begin() as trans: + trans.execute(CreateSchema("common", if_not_exists=True)) + trans.execute(CreateSchema("sdo", if_not_exists=True)) + trans.execute(CreateSchema("sco", if_not_exists=True)) + trans.execute(CreateSchema("sro", if_not_exists=True)) + + @staticmethod + def determine_schema_name(stix_object): + if isinstance(stix_object, _DomainObject): + return "sdo" + elif isinstance(stix_object, _Observable): + return "sco" + elif isinstance(stix_object, _RelationshipObject): + return "sro" + elif isinstance(stix_object, _MetaObject): + return "common" + + @staticmethod + def schema_for(stix_class): + return schema_for(stix_class) + + @staticmethod + def schema_for_core(): + return "common" + + @staticmethod + def determine_sql_type_for_binary_property(): # noqa: F811 + return PostgresBackend.determine_sql_type_for_string_property() + + @staticmethod + def determine_sql_type_for_hex_property(): # noqa: F811 + # return LargeBinary + return PostgresBackend.determine_sql_type_for_string_property() + + @staticmethod + def determine_sql_type_for_timestamp_property(): # noqa: F811 + return TIMESTAMP(timezone=True) + + @staticmethod + def array_allowed(): + return True diff --git a/stix2/datastore/relational_db/input_creation.py 
b/stix2/datastore/relational_db/input_creation.py new file mode 100644 index 00000000..6bc04aa2 --- /dev/null +++ b/stix2/datastore/relational_db/input_creation.py @@ -0,0 +1,538 @@ + +from sqlalchemy import insert + +from stix2.datastore.relational_db.add_method import add_method +from stix2.datastore.relational_db.utils import ( + SCO_COMMON_PROPERTIES, SDO_COMMON_PROPERTIES, canonicalize_table_name, +) +from stix2.properties import ( + BinaryProperty, BooleanProperty, DictionaryProperty, + EmbeddedObjectProperty, EnumProperty, ExtensionsProperty, FloatProperty, + HashesProperty, HexProperty, IDProperty, IntegerProperty, ListProperty, + Property, ReferenceProperty, StringProperty, TimestampProperty, +) +from stix2.utils import STIXdatetime +from stix2.v21.common import KillChainPhase + + +@add_method(Property) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + pass + + +@add_method(BinaryProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +@add_method(BooleanProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +def instance_in_valid_types(cls, valid_types): + for v in valid_types: + if isinstance(v, cls): + return True + return False + + +def is_valid_type(cls, valid_types): + return cls in valid_types or instance_in_valid_types(cls, valid_types) + + +@add_method(DictionaryProperty) +def generate_insert_information(self, dictionary_name, stix_object, **kwargs): # noqa: F811 + bindings = dict() + data_sink = kwargs.get("data_sink") + table_name = kwargs.get("table_name") + schema_name = kwargs.get("schema_name") + foreign_key_value = kwargs.get("foreign_key_value") + insert_statements = list() + + table = data_sink.tables_dictionary[ + canonicalize_table_name( + table_name + "_" + dictionary_name, + schema_name, + ) + ] + + # binary, boolean, float, hex, + # integer, string, timestamp + valid_types = stix_object._properties[dictionary_name].valid_types + for name, value in stix_object[dictionary_name].items(): + bindings = dict() + if "id" in stix_object: + bindings["id"] = stix_object["id"] + elif foreign_key_value: + bindings["id"] = foreign_key_value + if not valid_types or len(self.valid_types) == 1: + value_binding = "value" + elif isinstance(value, int) and is_valid_type(IntegerProperty, valid_types): + value_binding = "integer_value" + elif isinstance(value, str) and is_valid_type(StringProperty, valid_types): + value_binding = "string_value" + elif isinstance(value, bool) and is_valid_type(BooleanProperty, valid_types): + value_binding = "boolean_value" + elif isinstance(value, float) and is_valid_type(FloatProperty, valid_types): + value_binding = "float_value" + elif isinstance(value, STIXdatetime) and is_valid_type(TimestampProperty, valid_types): + value_binding = "timestamp_value" + else: + value_binding = "string_value" + + bindings["name"] = name + bindings[value_binding] = value + + insert_statements.append(insert(table).values(bindings)) + + return insert_statements + + +@add_method(EmbeddedObjectProperty) +def generate_insert_information(self, name, stix_object, is_list=False, foreign_key_value=None, is_extension=False, **kwargs): # noqa: F811 + data_sink = kwargs.get("data_sink") + schema_name = kwargs.get("schema_name") + level = kwargs.get("level") + return generate_insert_for_sub_object( + data_sink, stix_object[name], self.type.__name__, schema_name, + level=level+1 if is_list 
else level, + is_embedded_object=True, + is_extension=is_extension, + parent_table_name=kwargs.get("parent_table_name"), + foreign_key_value=foreign_key_value, + ) + + +@add_method(EnumProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +@add_method(ExtensionsProperty) +def generate_insert_information(self, name, stix_object, data_sink=None, table_name=None, schema_name=None, parent_table_name=None, **kwargs): # noqa: F811 + input_statements = list() + for ex_name, ex in stix_object["extensions"].items(): + # ignore new extensions - they have no properties + if ex.extension_type is None or not ex.extension_type.startswith("new"): + if ex_name.startswith("extension-definition"): + ex_name = ex_name[0:30] + ex_name = ex_name.replace("extension-definition-", "ext_def") + bindings = { + "id": stix_object["id"], + "ext_table_name": canonicalize_table_name(ex_name, schema_name), + } + ex_table = data_sink.tables_dictionary[canonicalize_table_name(table_name + "_" + "extensions", schema_name)] + input_statements.append(insert(ex_table).values(bindings)) + input_statements.extend( + generate_insert_for_sub_object( + data_sink, ex, ex_name, schema_name, stix_object["id"], + parent_table_name=parent_table_name, + is_extension=True, + ), + ) + return input_statements + + +@add_method(FloatProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +@add_method(HexProperty) +def generate_insert_information(self, name, stix_object, data_sink, **kwargs): # noqa: F811 + return {name: data_sink.db_backend.generate_value(self, stix_object[name])} + + +def generate_insert_for_hashes( + data_sink, name, stix_object, table_name, schema_name, foreign_key_value=None, + is_embedded_object=False, **kwargs, +): + bindings = {"id": foreign_key_value} + table_name = canonicalize_table_name(table_name + "_" + name, schema_name) + table = data_sink.tables_dictionary[table_name] + insert_statements = list() + for hash_name, hash_value in stix_object["hashes"].items(): + + bindings["hash_name"] = hash_name + bindings["hash_value"] = hash_value + insert_statements.append(insert(table).values(bindings)) + return insert_statements + + +@add_method(HashesProperty) +def generate_insert_information( # noqa: F811 + self, name, stix_object, data_sink=None, table_name=None, schema_name=None, + is_embedded_object=False, foreign_key_value=None, is_list=False, **kwargs, +): + return generate_insert_for_hashes( + data_sink, name, stix_object, table_name, schema_name, + is_embedded_object=is_embedded_object, is_list=is_list, foreign_key_value=foreign_key_value, + ) + + +@add_method(IDProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +@add_method(IntegerProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +@add_method(ListProperty) +def generate_insert_information( # noqa: F811 + self, name, stix_object, data_sink=None, level=0, is_extension=False, + foreign_key_value=None, schema_name=None, **kwargs, +): + db_backend = data_sink.db_backend + table_name = kwargs.get("table_name") + if isinstance(self.contained, ReferenceProperty): + insert_statements = list() + + table = data_sink.tables_dictionary[canonicalize_table_name(table_name + "_" + name, schema_name)] + for idx, item in enumerate(stix_object[name]): + bindings = { + "id": 
stix_object["id"] if id in stix_object else foreign_key_value, + "ref_id": item, + } + insert_statements.append(insert(table).values(bindings)) + return insert_statements + elif self.contained == KillChainPhase: + insert_statements = list() + table = data_sink.tables_dictionary[canonicalize_table_name(table_name + "_" + name, schema_name)] + + for idx, item in enumerate(stix_object[name]): + bindings = { + "id": stix_object["id"] if id in stix_object else foreign_key_value, + "kill_chain_name": item["kill_chain_name"], + "phase_name": item["phase_name"], + } + insert_statements.append(insert(table).values(bindings)) + return insert_statements + elif isinstance(self.contained, EnumProperty): + insert_statements = list() + table = data_sink.tables_dictionary[canonicalize_table_name(table_name + "_" + name, schema_name)] + + for idx, item in enumerate(stix_object[name]): + bindings = { + "id": stix_object["id"] if id in stix_object else foreign_key_value, + name: item, + } + insert_statements.append(insert(table).values(bindings)) + return insert_statements + elif isinstance(self.contained, EmbeddedObjectProperty): + insert_statements = list() + for value in stix_object[name]: + next_id = data_sink.next_id() + table = data_sink.tables_dictionary[canonicalize_table_name(table_name + "_" + name, schema_name)] + bindings = { + "id": foreign_key_value, + "ref_id": next_id, + } + insert_statements.append(insert(table).values(bindings)) + insert_statements.extend( + generate_insert_for_sub_object( + data_sink, + value, + table_name + "_" + name + "_" + self.contained.type.__name__, + schema_name, + next_id, + level, + True, + is_extension=is_extension, + ), + ) + return insert_statements + else: + if db_backend.array_allowed(): + if isinstance(self.contained, HexProperty): + return {name: [data_sink.db_backend.generate_value(self.contained, x) for x in stix_object[name]]} + else: + return {name: stix_object[name]} + + else: + insert_statements = list() + table = data_sink.tables_dictionary[ + canonicalize_table_name( + table_name + "_" + name, + schema_name, + ) + ] + for elem in stix_object[name]: + bindings = { + "id": stix_object["id"], + name: bytes.fromhex(elem) if isinstance(self.contained, HexProperty) else elem, + } + insert_statements.append(insert(table).values(bindings)) + return insert_statements + + +@add_method(ReferenceProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +@add_method(StringProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +@add_method(TimestampProperty) +def generate_insert_information(self, name, stix_object, **kwargs): # noqa: F811 + return {name: stix_object[name]} + + +def derive_column_name(prop): + contained_property = prop.contained + if isinstance(contained_property, ReferenceProperty): + return "ref_id" + elif isinstance(contained_property, StringProperty): + return "value" + + +def generate_insert_for_array_in_table(table, values, foreign_key_value, column_name="ref_id"): + insert_statements = list() + for idx, item in enumerate(values): + bindings = { + "id": foreign_key_value, + column_name: item, + } + insert_statements.append(insert(table).values(bindings)) + return insert_statements + + +def generate_insert_for_external_references(data_sink, stix_object): + insert_statements = list() + next_id = None + object_table = data_sink.tables_dictionary["common.external_references"] + for er in 
stix_object["external_references"]: + bindings = {"id": stix_object["id"]} + for prop in ["source_name", "description", "url", "external_id"]: + if prop in er: + bindings[prop] = er[prop] + if "hashes" in er: + next_id = data_sink.next_id() + bindings["hash_ref_id"] = next_id + # else: + # # hash_ref_id is non-NULL, so -1 means there are no hashes + # bindings["hash_ref_id"] = -1 + er_insert_statement = insert(object_table).values(bindings) + insert_statements.append(er_insert_statement) + + if "hashes" in er: + insert_statements.extend( + generate_insert_for_hashes(data_sink, "hashes", er, "external_references", "common", foreign_key_value=next_id), + ) + + return insert_statements + + +def generate_insert_for_granular_markings(data_sink, granular_markings_table, stix_object): + db_backend = data_sink.db_backend + insert_statements = list() + granular_markings = stix_object["granular_markings"] + for idx, granular_marking in enumerate(granular_markings): + bindings = {"id": stix_object["id"]} + lang_property_value = granular_marking.get("lang") + if lang_property_value: + bindings["lang"] = lang_property_value + marking_ref_value = granular_marking.get("marking_ref") + if marking_ref_value: + bindings["marking_ref"] = marking_ref_value + if db_backend.array_allowed(): + bindings["selectors"] = granular_marking.get("selectors") + insert_statements.append(insert(granular_markings_table).values(bindings)) + else: + next_id = data_sink.next_id() + bindings["selectors"] = next_id + insert_statements.append(insert(granular_markings_table).values(bindings)) + table = data_sink.tables_dictionary[ + canonicalize_table_name( + granular_markings_table.name + "_selector", + db_backend.schema_for_core(), + ) + ] + for sel in granular_marking.get("selectors"): + selector_bindings = {"id": next_id, "selector": sel} + insert_statements.append(insert(table).values(selector_bindings)) + + return insert_statements + + +# def generate_insert_for_extensions(extensions, foreign_key_value, type_name, core_properties): +# sql_bindings_tuples = list() +# for name, ex in extensions.items(): +# sql_bindings_tuples.extend( +# generate_insert_for_subtype_extension( +# name, +# ex, +# foreign_key_value, +# type_name, +# core_properties, +# ), +# ) +# return sql_bindings_tuples + + +def generate_insert_for_core(data_sink, stix_object, core_properties, stix_type_name, schema_name): + db_backend = data_sink.db_backend + if stix_type_name in ["sdo", "sro", "common"]: + core_table = data_sink.tables_dictionary[db_backend.schema_for_core() + "." + "core_sdo"] + else: + core_table = data_sink.tables_dictionary[db_backend.schema_for_core() + "." 
+ "core_sco"] + insert_statements = list() + core_bindings = {} + + child_table_properties = ["object_marking_refs", "granular_markings", "external_references", "type"] + if "labels" in core_properties and not db_backend.array_allowed(): + child_table_properties.append("labels") + + for prop_name, value in stix_object.items(): + if prop_name in core_properties: + # stored in separate tables below, skip here + if prop_name not in child_table_properties: + core_bindings[prop_name] = value + + core_insert_statement = insert(core_table).values(core_bindings) + insert_statements.append(core_insert_statement) + + if "labels" in stix_object and "labels" in child_table_properties: + label_table_name = canonicalize_table_name(core_table.name + "_labels", data_sink.db_backend.schema_for_core()) + labels_table = data_sink.tables_dictionary[label_table_name] + insert_statements.extend( + generate_insert_for_array_in_table( + labels_table, + stix_object["labels"], + stix_object["id"], + column_name="label", + ), + ) + + if "object_marking_refs" in stix_object: + object_marking_table_name = canonicalize_table_name( + "object_marking_refs", + data_sink.db_backend.schema_for_core(), + ) + if stix_type_name != "sco": + object_markings_ref_table = data_sink.tables_dictionary[object_marking_table_name + "_sdo"] + else: + object_markings_ref_table = data_sink.tables_dictionary[object_marking_table_name + "_sco"] + insert_statements.extend( + generate_insert_for_array_in_table( + object_markings_ref_table, + stix_object["object_marking_refs"], + stix_object["id"], + ), + ) + + # Granular markings + if "granular_markings" in stix_object: + granular_marking_table_name = canonicalize_table_name( + "granular_marking", + data_sink.db_backend.schema_for_core(), + ) + if stix_type_name != "sco": + granular_marking_table = data_sink.tables_dictionary[granular_marking_table_name + "_sdo"] + else: + granular_marking_table = data_sink.tables_dictionary[granular_marking_table_name + "_sco"] + granular_input_statements = generate_insert_for_granular_markings( + data_sink, + granular_marking_table, + stix_object, + ) + insert_statements.extend(granular_input_statements) + + return insert_statements + + +def generate_insert_for_sub_object( + data_sink, stix_object, type_name, schema_name, foreign_key_value=None, + is_embedded_object=False, is_list=False, parent_table_name=None, level=0, + is_extension=False, +): + insert_statements = list() + bindings = dict() + if "id" in stix_object: + bindings["id"] = stix_object["id"] + elif foreign_key_value: + bindings["id"] = foreign_key_value + if parent_table_name and (not is_extension or level > 0): + type_name = parent_table_name + "_" + type_name + if type_name.startswith("extension-definition"): + type_name = type_name[0:30] + type_name = type_name.replace("extension-definition-", "ext_def") + sub_insert_statements = list() + for name, prop in stix_object._properties.items(): + if name in stix_object: + result = prop.generate_insert_information( + name, + stix_object, + data_sink=data_sink, + table_name=type_name if isinstance(prop, (DictionaryProperty, ListProperty)) else parent_table_name, + schema_name=schema_name, + foreign_key_value=foreign_key_value, + is_embedded_object=is_embedded_object, + is_list=is_list, + level=level+1, + is_extension=is_extension, + parent_table_name=type_name, + ) + if isinstance(result, dict): + bindings.update(result) + elif isinstance(result, list): + sub_insert_statements.extend(result) + else: + raise ValueError("wrong type" + result) + if 
foreign_key_value: + bindings["id"] = foreign_key_value + object_table = data_sink.tables_dictionary[canonicalize_table_name(type_name, schema_name)] + insert_statements.append(insert(object_table).values(bindings)) + insert_statements.extend(sub_insert_statements) + return insert_statements + + +def generate_insert_for_object(data_sink, stix_object, stix_type_name, schema_name, level=0): + insert_statements = list() + bindings = dict() + if stix_type_name == "sco": + core_properties = SCO_COMMON_PROPERTIES + elif stix_type_name in ["sdo", "sro", "common"]: + core_properties = SDO_COMMON_PROPERTIES + else: + core_properties = list() + type_name = stix_object["type"] + if core_properties: + insert_statements.extend(generate_insert_for_core(data_sink, stix_object, core_properties, stix_type_name, schema_name)) + if "id" in stix_object: + foreign_key_value = stix_object["id"] + else: + foreign_key_value = None + sub_insert_statements = list() + for name, prop in stix_object._properties.items(): + if (name == 'id' or name not in core_properties) and name != "type" and name in stix_object: + result = prop.generate_insert_information( + name, stix_object, + data_sink=data_sink, + table_name=type_name, + schema_name=schema_name, + parent_table_name=type_name, + level=level, + foreign_key_value=foreign_key_value, + ) + if isinstance(result, dict): + bindings.update(result) + elif isinstance(result, list): + sub_insert_statements.extend(result) + else: + raise ValueError("wrong type" + result) + + object_table = data_sink.tables_dictionary[canonicalize_table_name(type_name, schema_name)] + insert_statements.append(insert(object_table).values(bindings)) + insert_statements.extend(sub_insert_statements) + + if "external_references" in stix_object: + insert_statements.extend(generate_insert_for_external_references(data_sink, stix_object)) + + return insert_statements diff --git a/stix2/datastore/relational_db/query.py b/stix2/datastore/relational_db/query.py new file mode 100644 index 00000000..637b66e9 --- /dev/null +++ b/stix2/datastore/relational_db/query.py @@ -0,0 +1,673 @@ +import inspect + +import sqlalchemy as sa + +import stix2 +from stix2.datastore import DataSourceError +from stix2.datastore.relational_db.utils import ( + canonicalize_table_name, schema_for, table_name_for, +) +import stix2.properties +import stix2.utils + + +def _check_support(stix_id): + """ + Misc support checks for the relational data source. May be better to error + out up front and say a type is not supported, than die with some cryptic + SQLAlchemy or other error later. This runs for side-effects (raises + an exception) and doesn't return anything. + + :param stix_id: A STIX ID. The basis for reading an object, used to + determine support + """ + # language-content has a complicated structure in its "contents" + # property, which is not currently supported for storage in a + # relational database. 
+ stix_type = stix2.utils.get_type_from_id(stix_id) + if stix_type in ("language-content",): + raise DataSourceError(f"Reading {stix_type} objects is not supported.") + + +def _tables_for(stix_class, metadata): + """ + Get the core and type-specific tables for the given class + + :param stix_class: A class for a STIX object type + :param metadata: SQLAlchemy Metadata object containing all the table + information + :return: A (core_table, type_table) 2-tuple as SQLAlchemy Table objects + """ + # Info about the type-specific table + type_table_name = table_name_for(stix_class) + type_schema_name = schema_for(stix_class) + type_table = metadata.tables[f"{type_schema_name}.{type_table_name}"] + + # Some fixed info about core tables + if type_schema_name == "sco": + core_table_name = "common.core_sco" + else: + # for SROs and SMOs too? + core_table_name = "common.core_sdo" + + core_table = metadata.tables[core_table_name] + + return core_table, type_table + + +def _stix2_class_for(stix_id): + """ + Find the class for the STIX type indicated by the given STIX ID. + + :param stix_id: A STIX ID + """ + stix_type = stix2.utils.get_type_from_id(stix_id) + stix_class = stix2.registry.class_for_type( + # TODO: give user control over STIX version used? + stix_type, stix_version=stix2.DEFAULT_VERSION, + ) + + return stix_class + + +def _read_simple_properties(stix_id, core_table, type_table, conn): + """ + Read "simple" property values, i.e. those which don't need tables other + than the core/type-specific tables: they're stored directly in columns of + those tables. These two tables are joined and must have a defined foreign + key constraint between them. + + :param stix_id: A STIX ID + :param core_table: A core table + :param type_table: A type-specific table + :param conn: An SQLAlchemy DB connection + :return: A mapping containing the properties and values read + """ + # Both core and type-specific tables have "id"; let's not duplicate that + # in the result set columns. Is there a better way to do this? + type_cols_except_id = ( + col for col in type_table.c if col.key != "id" + ) + + core_type_select = sa.select(core_table, *type_cols_except_id) \ + .join(type_table) \ + .where(core_table.c.id == stix_id) + + # Should be at most one matching row + obj_dict = conn.execute(core_type_select).mappings().first() + + return obj_dict + + +def _read_simple_array(fk_id, elt_column_name, array_table, conn): + """ + Read array elements from a given table. + + :param fk_id: A foreign key value used to find the correct array elements + :param elt_column_name: The name of the table column which contains the + array elements + :param array_table: A SQLAlchemy Table object containing the array data + :param conn: An SQLAlchemy DB connection + :return: The array, as a list + """ + stmt = sa.select(array_table.c[elt_column_name]).where(array_table.c.id == fk_id) + refs = conn.scalars(stmt).all() + return refs + + +def _read_hashes(fk_id, hashes_table, conn): + """ + Read hashes from a table. 
+ + :param fk_id: A foreign key value used to filter table rows + :param hashes_table: An SQLAlchemy Table object + :param conn: An SQLAlchemy DB connection + :return: The hashes as a dict, or None if no hashes were found + """ + stmt = sa.select(hashes_table.c.hash_name, hashes_table.c.hash_value).where( + hashes_table.c.id == fk_id, + ) + + results = conn.execute(stmt) + hashes = dict(results.all()) or None + return hashes + + +def _read_external_references(stix_id, metadata, conn): + """ + Read external references from some fixed tables in the common schema. + + :param stix_id: A STIX ID used to filter table rows + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :return: The external references, as a list of dicts + """ + ext_refs_table = metadata.tables["common.external_references"] + ext_refs_hashes_table = metadata.tables["common.external_references_hashes"] + ext_refs = [] + + ext_refs_columns = (col for col in ext_refs_table.c if col.key != "id") + stmt = sa.select(*ext_refs_columns).where(ext_refs_table.c.id == stix_id) + ext_refs_results = conn.execute(stmt) + for ext_ref_mapping in ext_refs_results.mappings(): + # make a dict; we will need to modify this mapping + ext_ref_dict = dict(ext_ref_mapping) + hash_ref_id = ext_ref_dict.pop("hash_ref_id") + + hashes_dict = _read_hashes(hash_ref_id, ext_refs_hashes_table, conn) + if hashes_dict: + ext_ref_dict["hashes"] = hashes_dict + + ext_refs.append(ext_ref_dict) + + return ext_refs + + +def _read_object_marking_refs(stix_id, stix_type_class, metadata, conn): + """ + Read object marking refs from one of a couple special tables in the common + schema. + + :param stix_id: A STIX ID, used to filter table rows + :param stix_type_class: STIXTypeClass enum value, used to determine whether + to read the table for SDOs or SCOs + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :return: The references as a list of strings + """ + + marking_table_name = "object_marking_refs_" + if stix_type_class is stix2.utils.STIXTypeClass.SCO: + marking_table_name += "sco" + else: + marking_table_name += "sdo" + + # The SCO/SDO object_marking_refs tables are mostly identical; they just + # have different foreign key constraints (to different core tables). + marking_table = metadata.tables["common." + marking_table_name] + + stmt = sa.select(marking_table.c.ref_id).where(marking_table.c.id == stix_id) + refs = conn.scalars(stmt).all() + + return refs + + +def _read_granular_markings(stix_id, stix_type_class, metadata, conn, db_backend): + """ + Read granular markings from one of a couple special tables in the common + schema. + + :param stix_id: A STIX ID, used to filter table rows + :param stix_type_class: STIXTypeClass enum value, used to determine whether + to read the table for SDOs or SCOs + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database + :return: Granular markings as a list of dicts + """ + + marking_table_name = "granular_marking_" + if stix_type_class is stix2.utils.STIXTypeClass.SCO: + marking_table_name += "sco" + else: + marking_table_name += "sdo" + + marking_table = metadata.tables["common." 
+ marking_table_name] + + if db_backend.array_allowed(): + # arrays allowed: everything combined in the same table + stmt = sa.select( + marking_table.c.lang, + marking_table.c.marking_ref, + marking_table.c.selectors, + ).where(marking_table.c.id == stix_id) + + marking_dicts = conn.execute(stmt).mappings().all() + + else: + # arrays not allowed: selectors are in their own table + stmt = sa.select( + marking_table.c.lang, + marking_table.c.marking_ref, + marking_table.c.selectors, + ).where(marking_table.c.id == stix_id) + + marking_dicts = list(conn.execute(stmt).mappings()) + + for idx, marking_dict in enumerate(marking_dicts): + # make a mutable shallow-copy of the row mapping + marking_dicts[idx] = marking_dict = dict(marking_dict) + selector_id = marking_dict.pop("selectors") + + selector_table_name = f"{marking_table.fullname}_selector" + selector_table = metadata.tables[selector_table_name] + + selectors = _read_simple_array( + selector_id, + "selector", + selector_table, + conn + ) + marking_dict["selectors"] = selectors + + return marking_dicts + + +def _read_kill_chain_phases(stix_id, type_table, metadata, conn): + """ + Read kill chain phases from a table. + + :param stix_id: A STIX ID used to filter table rows + :param type_table: A "parent" table whose name is used to compute the + kill chain phases table name + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :return: Kill chain phases as a list of dicts + """ + + kill_chain_phases_table = metadata.tables[type_table.fullname + "_kill_chain_phase"] + stmt = sa.select( + kill_chain_phases_table.c.kill_chain_name, + kill_chain_phases_table.c.phase_name, + ).where(kill_chain_phases_table.c.id == stix_id) + + kill_chain_phases = conn.execute(stmt).mappings().all() + return kill_chain_phases + + +def _read_dictionary_property(stix_id, type_table, prop_name, prop_instance, metadata, conn): + """ + Read a dictionary from a table. + + :param stix_id: A STIX ID, used to filter table rows + :param type_table: A "parent" table whose name is used to compute the name + of the dictionary table + :param prop_name: The dictionary property name + :param prop_instance: The dictionary property instance + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :return: The dictionary, or None if no dictionary entries were found + """ + dict_table_name = f"{type_table.fullname}_{prop_name}" + dict_table = metadata.tables[dict_table_name] + + if len(prop_instance.valid_types) == 1: + stmt = sa.select( + dict_table.c.name, dict_table.c.value, + ).where( + dict_table.c.id == stix_id, + ) + + results = conn.execute(stmt) + dict_value = dict(results.all()) + + else: + # In this case, we get one column per valid type + type_cols = (col for col in dict_table.c if col.key not in ("id", "name")) + stmt = sa.select(dict_table.c.name, *type_cols).where(dict_table.c.id == stix_id) + results = conn.execute(stmt) + + dict_value = {} + for row in results: + key, *type_values = row + # Exactly one of the type columns should be non-None; get that one + non_null_values = (v for v in type_values if v is not None) + first_non_null_value = next(non_null_values, None) + if first_non_null_value is None: + raise DataSourceError( + f'In dictionary table {dict_table.fullname}, key "{key}"' + " did not map to a non-null value", + ) + + dict_value[key] = first_non_null_value + + # DictionaryProperty doesn't like empty dicts. 
+ dict_value = dict_value or None + + return dict_value + + +def _read_embedded_object(obj_id, parent_table, embedded_type, metadata, conn): + """ + Read an embedded object from the database. + + :param obj_id: An ID value used to identify a particular embedded object, + used to filter table rows + :param parent_table: A "parent" table whose name is used to compute the + name of the embedded object table + :param embedded_type: The Python class used to represent the embedded + type (a _STIXBase subclass) + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :return: An instance of embedded_type + """ + + embedded_table_name = canonicalize_table_name( + f"{parent_table.name}_{embedded_type.__name__}", + parent_table.schema, + ) + embedded_table = metadata.tables[embedded_table_name] + + # The PK column in this case is a bookkeeping column and does not + # correspond to an actual embedded object property. So don't select + # that one. + non_id_cols = (col for col in embedded_table.c if col.key != "id") + + stmt = sa.select(*non_id_cols).where(embedded_table.c.id == obj_id) + mapping_row = conn.execute(stmt).mappings().first() + + if mapping_row is None: + obj = None + + else: + obj_dict = dict(mapping_row) + + for prop_name, prop_instance in embedded_type._properties.items(): + if prop_name not in obj_dict: + prop_value = _read_complex_property_value( + obj_id, + prop_name, + prop_instance, + embedded_table, + metadata, + conn, + ) + + if prop_value is not None: + obj_dict[prop_name] = prop_value + + obj = embedded_type(**obj_dict, allow_custom=True) + + return obj + + +def _read_embedded_object_list(fk_id, join_table, embedded_type, metadata, conn): + """ + Read a list of embedded objects from database tables. + + :param fk_id: A foreign key ID used to filter rows from the join table, + which acts to find relevant embedded objects + :param join_table: An SQLAlchemy Table object which is the required join + table + :param embedded_type: The Python class used to represent the list element + embedded type (a _STIXBase subclass) + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :return: A list of instances of embedded_type + """ + + embedded_table_name = canonicalize_table_name( + f"{join_table.name}_{embedded_type.__name__}", + join_table.schema, + ) + embedded_table = metadata.tables[embedded_table_name] + + stmt = sa.select(embedded_table).join(join_table).where(join_table.c.id == fk_id) + results = conn.execute(stmt) + obj_list = [] + for result_mapping in results.mappings(): + obj_dict = dict(result_mapping) + obj_id = obj_dict.pop("id") + + for prop_name, prop_instance in embedded_type._properties.items(): + if prop_name not in obj_dict: + prop_value = _read_complex_property_value( + obj_id, + prop_name, + prop_instance, + embedded_table, + metadata, + conn, + ) + + if prop_value is not None: + obj_dict[prop_name] = prop_value + + obj = embedded_type(**obj_dict, allow_custom=True) + obj_list.append(obj) + + return obj_list + + +def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, metadata, conn): + """ + Read property values which require auxiliary tables to store. These are + idiosyncratic and just require a lot of special cases. This function has + no special support for top-level common properties, so it is more + general-purpose, suitable for any sort of object (top level or embedded). 
+ + :param obj_id: An ID of the owning object. Would be a STIX ID for a + top-level object, but could also be something else for sub-objects. + Used as a foreign key value in queries, so we only get values for this + object. + :param prop_name: The name of the property to read + :param prop_instance: A Property (subclass) instance with property + config information + :param obj_table: The table for the owning object. Mainly used for its + name; auxiliary table names are based on this + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :return: The property value + """ + + prop_value = None + + if isinstance(prop_instance, stix2.properties.ListProperty): + + if isinstance(prop_instance.contained, stix2.properties.ReferenceProperty): + ref_table_name = f"{obj_table.fullname}_{prop_name}" + ref_table = metadata.tables[ref_table_name] + prop_value = _read_simple_array(obj_id, "ref_id", ref_table, conn) + + elif isinstance(prop_instance.contained, ( + # Most of these list-of-simple-type cases would occur when array + # columns are disabled. + stix2.properties.BinaryProperty, + stix2.properties.BooleanProperty, + stix2.properties.EnumProperty, + stix2.properties.HexProperty, + stix2.properties.IntegerProperty, + stix2.properties.FloatProperty, + stix2.properties.StringProperty, + stix2.properties.TimestampProperty, + )): + array_table_name = f"{obj_table.fullname}_{prop_name}" + array_table = metadata.tables[array_table_name] + prop_value = _read_simple_array( + obj_id, + prop_name, + array_table, + conn + ) + + elif isinstance(prop_instance.contained, stix2.properties.EmbeddedObjectProperty): + join_table_name = f"{obj_table.fullname}_{prop_name}" + join_table = metadata.tables[join_table_name] + prop_value = _read_embedded_object_list( + obj_id, + join_table, + prop_instance.contained.type, + metadata, + conn, + ) + + elif inspect.isclass(prop_instance.contained) and issubclass(prop_instance.contained, stix2.KillChainPhase): + prop_value = _read_kill_chain_phases(obj_id, obj_table, metadata, conn) + + else: + raise DataSourceError( + f'Not implemented: read "{prop_name}" property value' + f" of type list-of {prop_instance.contained}", + ) + + elif isinstance(prop_instance, stix2.properties.HashesProperty): + hashes_table_name = f"{obj_table.fullname}_{prop_name}" + hashes_table = metadata.tables[hashes_table_name] + prop_value = _read_hashes(obj_id, hashes_table, conn) + + elif isinstance(prop_instance, stix2.properties.ExtensionsProperty): + # TODO: add support for extensions + pass + + elif isinstance(prop_instance, stix2.properties.DictionaryProperty): + # ExtensionsProperty/HashesProperty subclasses DictionaryProperty, so + # this must come after those + prop_value = _read_dictionary_property(obj_id, obj_table, prop_name, prop_instance, metadata, conn) + + elif isinstance(prop_instance, stix2.properties.EmbeddedObjectProperty): + prop_value = _read_embedded_object( + obj_id, + obj_table, + prop_instance.type, + metadata, + conn, + ) + + else: + raise DataSourceError( + f'Not implemented: read "{prop_name}" property value' + f" of type {prop_instance.__class__}", + ) + + return prop_value + + +def _read_complex_top_level_property_value( + stix_id, + stix_type_class, + prop_name, + prop_instance, + type_table, + metadata, + conn, + db_backend +): + """ + Read property values which require auxiliary tables to store. These + require a lot of special cases. 
This function has additional support for + reading top-level common properties, which use special fixed tables. + + :param stix_id: STIX ID of an object to read + :param stix_type_class: The kind of object (SCO, SDO, etc). Which DB + tables to read can depend on this. + :param prop_name: The name of the property to read + :param prop_instance: A Property (subclass) instance with property + config information + :param type_table: The non-core base table used for this STIX type. Mainly + used for its name; auxiliary table names are based on this + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database + :return: The property value + """ + + # Common properties: these use a fixed set of tables for all STIX objects + if prop_name == "external_references": + prop_value = _read_external_references(stix_id, metadata, conn) + + elif prop_name == "object_marking_refs": + prop_value = _read_object_marking_refs( + stix_id, + stix_type_class, + metadata, + conn + ) + + elif prop_name == "granular_markings": + prop_value = _read_granular_markings( + stix_id, + stix_type_class, + metadata, + conn, + db_backend + ) + + # Will apply when array columns are unsupported/disallowed by the backend + elif prop_name == "labels": + label_table = metadata.tables[ + f"common.core_{stix_type_class.name.lower()}_labels" + ] + prop_value = _read_simple_array(stix_id, "label", label_table, conn) + + else: + # Other properties use specific table patterns depending on property type + prop_value = _read_complex_property_value( + stix_id, + prop_name, + prop_instance, + type_table, + metadata, + conn + ) + + return prop_value + + +def read_object(stix_id, metadata, conn, db_backend): + """ + Read a STIX object from the database, identified by a STIX ID. + + :param stix_id: A STIX ID + :param metadata: SQLAlchemy Metadata object containing all the table + information + :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database + :return: A STIX object + """ + _check_support(stix_id) + + stix_class = _stix2_class_for(stix_id) + + if not stix_class: + stix_type = stix2.utils.get_type_from_id(stix_id) + raise DataSourceError("Can't find registered class for type: " + stix_type) + + core_table, type_table = _tables_for(stix_class, metadata) + + if type_table.schema == "common": + # Applies to extension-definition SMO, whose data is stored in the + # common schema; it does not get its own. This type class is used to + # determine which common tables to use; its markings are + # in the *_sdo tables. + stix_type_class = stix2.utils.STIXTypeClass.SDO + else: + stix_type_class = stix2.utils.to_enum(type_table.schema, stix2.utils.STIXTypeClass) + + simple_props = _read_simple_properties(stix_id, core_table, type_table, conn) + if simple_props is None: + # could not find anything for the given ID! 
+        return None
+
+    obj_dict = dict(simple_props)
+    obj_dict["type"] = stix_class._type
+
+    for prop_name, prop_instance in stix_class._properties.items():
+        if prop_name not in obj_dict:
+            prop_value = _read_complex_top_level_property_value(
+                stix_id,
+                stix_type_class,
+                prop_name,
+                prop_instance,
+                type_table,
+                metadata,
+                conn,
+                db_backend
+            )
+
+            if prop_value is not None:
+                obj_dict[prop_name] = prop_value
+
+    stix_obj = stix_class(**obj_dict, allow_custom=True)
+    return stix_obj
diff --git a/stix2/datastore/relational_db/relational_db.py b/stix2/datastore/relational_db/relational_db.py
new file mode 100644
index 00000000..ee0a312f
--- /dev/null
+++ b/stix2/datastore/relational_db/relational_db.py
@@ -0,0 +1,243 @@
+from sqlalchemy import MetaData, delete
+from sqlalchemy.schema import CreateTable, Sequence
+
+from stix2.base import _STIXBase
+from stix2.datastore import DataSink, DataSource, DataStoreMixin
+from stix2.datastore.relational_db.input_creation import (
+    generate_insert_for_object,
+)
+from stix2.datastore.relational_db.query import read_object
+from stix2.datastore.relational_db.table_creation import create_table_objects
+from stix2.datastore.relational_db.utils import canonicalize_table_name
+from stix2.parsing import parse
+
+
+def _add(store, stix_data, allow_custom=True, version="2.1"):
+    """Add STIX objects to the relational data store.
+
+    Recursive function; breaks down STIX Bundles and lists.
+
+    Args:
+        store: A RelationalDBStore or RelationalDBSink object.
+        stix_data (list OR dict OR STIX object): STIX objects to be added
+        allow_custom (bool): Whether to allow custom properties as well as
+            unknown custom objects. Note that unknown custom objects cannot
+            be parsed into STIX objects, and will be returned as is.
+            Default: True.
+        version (str): Which STIX2 version to lock the parser to. (e.g. "2.0",
+            "2.1"). If None, the library makes the best effort to figure
+            out the spec representation of the object.
+
+    """
+    if isinstance(stix_data, list):
+        # STIX objects are in a list- recurse on each object
+        for stix_obj in stix_data:
+            _add(store, stix_obj, allow_custom, version)
+
+    elif stix_data["type"] == "bundle":
+        # adding a json bundle - so just grab STIX objects
+        for stix_obj in stix_data.get("objects", []):
+            _add(store, stix_obj, allow_custom, version)
+
+    else:
+        # Adding a single non-bundle object
+        if isinstance(stix_data, _STIXBase):
+            stix_obj = stix_data
+        else:
+            stix_obj = parse(stix_data, allow_custom, version)
+
+        store.insert_object(stix_obj)
+
+
+class RelationalDBStore(DataStoreMixin):
+    def __init__(
+        self, db_backend, allow_custom=True, version=None,
+        instantiate_database=True, print_sql=False, *stix_object_classes,
+    ):
+        """
+        Initialize this store.
+
+        Args:
+            db_backend: A database backend object, e.g. PostgresBackend.
+                Database creation/recreation (force_recreate) is handled by
+                the backend itself.
+            allow_custom: Whether custom content is allowed when processing
+                dict content to be added to the store
+            version: TODO: unused so far
+            instantiate_database: Whether tables, etc should be created in the
+                database (only necessary the first time)
+            print_sql: Whether to print the generated DDL when the tables are
+                created
+            *stix_object_classes: STIX object classes to map into table schemas
+                (and ultimately database tables, if instantiation is desired).
+                This can be used to limit which table schemas are created, if
+                one is only working with a subset of STIX types.  If not given,
+                auto-detect all classes and create table schemas for all of
+                them.
+        """
+
+        self.metadata = MetaData()
+        create_table_objects(
+            self.metadata, db_backend, stix_object_classes,
+        )
+
+        super().__init__(
+            source=RelationalDBSource(
+                db_backend,
+                metadata=self.metadata,
+                allow_custom=allow_custom,
+            ),
+            sink=RelationalDBSink(
+                db_backend,
+                print_sql=print_sql,
+                allow_custom=allow_custom,
+                version=version,
+                instantiate_database=instantiate_database,
+                metadata=self.metadata,
+            ),
+        )
+
+
+class RelationalDBSink(DataSink):
+    def __init__(
+        self, db_backend, allow_custom=True, version=None,
+        instantiate_database=True, print_sql=False, *stix_object_classes, metadata=None,
+    ):
+        """
+        Initialize this sink.  Only one of stix_object_classes and metadata
+        should be given: if the latter is given, assume table schemas are
+        already created.
+
+        Args:
+            db_backend: A database backend object, which wraps an SQLAlchemy
+                engine for a database
+            allow_custom: Whether custom content is allowed when processing
+                dict content to be added to the sink
+            version: TODO: unused so far
+            instantiate_database: Whether the database, tables, etc should be
+                created (only necessary the first time)
+            print_sql: Whether to print the generated DDL when the tables are
+                created
+            *stix_object_classes: STIX object classes to map into table schemas
+                (and ultimately database tables, if instantiation is desired).
+                This can be used to limit which table schemas are created, if
+                one is only working with a subset of STIX types.  If not given,
+                auto-detect all classes and create table schemas for all of
+                them.  If metadata is given, the table data therein is used and
+                this argument is ignored.
+            metadata: SQLAlchemy MetaData object containing table information.
+                Only applicable when this class is instantiated via a store,
+                so that table information can be constructed once and shared
+                between source and sink.
+        """
+        super(RelationalDBSink, self).__init__()
+
+        self.db_backend = db_backend
+
+        if metadata:
+            self.metadata = metadata
+        else:
+            self.metadata = MetaData()
+            create_table_objects(
+                self.metadata, db_backend, stix_object_classes,
+            )
+        self.sequence = Sequence("my_general_seq", metadata=self.metadata, start=1, schema="common")
+
+        self.allow_custom = allow_custom
+
+        self.tables_dictionary = dict()
+        for t in self.metadata.tables.values():
+            self.tables_dictionary[canonicalize_table_name(t.name, t.schema)] = t
+
+        if instantiate_database:
+            if not self.db_backend.database_exists:
+                self.db_backend._create_database()
+            # else:
+            #     self.clear_tables()
+            self.db_backend._create_schemas()
+            self._instantiate_database(print_sql)
+
+    def _instantiate_database(self, print_sql=False):
+        self.metadata.create_all(self.db_backend.database_connection)
+        if print_sql:
+            for t in self.metadata.tables.values():
+                print(CreateTable(t).compile(self.db_backend.database_connection))
+
+    def add(self, stix_data, version=None):
+        _add(self, stix_data, self.allow_custom)
+    add.__doc__ = _add.__doc__
+
+    def insert_object(self, stix_object):
+        schema_name = self.db_backend.determine_schema_name(stix_object)
+        stix_type_name = self.db_backend.determine_stix_type(stix_object)
+        with self.db_backend.database_connection.begin() as trans:
+            statements = generate_insert_for_object(self, stix_object, stix_type_name, schema_name)
+            for stmt in statements:
+                print("executing: ", stmt)
+                trans.execute(stmt)
+            trans.commit()
+
+    def clear_tables(self):
+        tables = list(reversed(self.metadata.sorted_tables))
+        with self.db_backend.database_connection.begin() as trans:
+            for table in tables:
+                delete_stmt = delete(table)
+                print(f'delete_stmt: {delete_stmt}')
+                trans.execute(delete_stmt)
+
+    def next_id(self):
+        with self.db_backend.database_connection.begin() as trans:
+            return trans.execute(self.sequence)
+
+
+class RelationalDBSource(DataSource):
+    def __init__(
+        self, db_backend, allow_custom, *stix_object_classes, metadata=None,
+    ):
+        """
+        Initialize this source.  Only one of stix_object_classes and metadata
+        should be given: if the latter is given, assume table schemas are
+        already created.  Instances of this class do not create the actual
+        database tables; see the store/sink for that.
+
+        Args:
+            db_backend: A database backend object
+            allow_custom: TODO: unused so far
+            *stix_object_classes: STIX object classes to map into table schemas.
+                This can be used to limit which schemas are created, if one is
+                only working with a subset of STIX types.  If not given,
+                auto-detect all classes and create schemas for all of them.
+                If metadata is given, the table data therein is used and this
+                argument is ignored.
+            metadata: SQLAlchemy MetaData object containing table information.
+                Only applicable when this class is instantiated via a store,
+                so that table information can be constructed once and shared
+                between source and sink.
+ """ + super().__init__() + + self.db_backend = db_backend + + self.allow_custom = allow_custom + + if metadata: + self.metadata = metadata + else: + self.metadata = MetaData() + create_table_objects( + self.metadata, db_backend, stix_object_classes, + ) + + def get(self, stix_id, version=None, _composite_filters=None): + with self.db_backend.database_connection.connect() as conn: + stix_obj = read_object( + stix_id, + self.metadata, + conn, + self.db_backend, + ) + + return stix_obj + + def all_versions(self, stix_id, version=None, _composite_filters=None): + pass + + def query(self, query=None): + pass diff --git a/stix2/datastore/relational_db/relational_db_testing.py b/stix2/datastore/relational_db/relational_db_testing.py new file mode 100644 index 00000000..a6376fd2 --- /dev/null +++ b/stix2/datastore/relational_db/relational_db_testing.py @@ -0,0 +1,323 @@ +import datetime as dt + +from database_backends.postgres_backend import PostgresBackend +import pytz + +import stix2 +from stix2.datastore.relational_db.relational_db import RelationalDBStore +import stix2.properties + +directory_stix_object = stix2.Directory( + path="/foo/bar/a", + path_enc="latin1", + ctime="1980-02-23T05:43:28.2678Z", + atime="1991-06-09T18:06:33.915Z", + mtime="2000-06-28T13:06:09.5827Z", + contains_refs=[ + "file--8903b558-40e3-43e2-be90-b341c12ff7ae", + "directory--e0604d0c-bab3-4487-b350-87ac1a3a195c", + ], + object_marking_refs=[ + "marking-definition--1b3eec29-5376-4837-bd93-73203e65d73c", + ], +) + +s = stix2.v21.Software( + id="software--28897173-7314-4eec-b1cf-2c625b635bf6", + name="Word", + cpe="cpe:2.3:a:microsoft:word:2000:*:*:*:*:*:*:*", + swid="com.acme.rms-ce-v4-1-5-0", + version="2002", + languages=["c", "lisp"], + vendor="Microsoft", +) + + +def windows_registry_key_example(): + v1 = stix2.v21.WindowsRegistryValueType( + name="Foo", + data="qwerty", + data_type="REG_SZ", + ) + v2 = stix2.v21.WindowsRegistryValueType( + name="Bar", + data="Fred", + data_type="REG_SZ", + ) + w = stix2.v21.WindowsRegistryKey( + key="hkey_local_machine\\system\\bar\\foo", + values=[v1, v2], + ) + return w + + +def malware_with_all_required_properties(): + ref1 = stix2.v21.ExternalReference( + source_name="veris", + external_id="0001AA7F-C601-424A-B2B8-BE6C9F5164E7", + hashes={ + "SHA-256": "6db12788c37247f2316052e142f42f4b259d6561751e5f401a1ae2a6df9c674b", + "MD5": "3773a88f65a5e780c8dff9cdc3a056f3", + }, + url="https://github.com/vz-risk/VCDB/blob/master/data/json/0001AA7F-C601-424A-B2B8-BE6C9F5164E7.json", + ) + ref2 = stix2.v21.ExternalReference( + source_name="ACME Threat Intel", + description="Threat report", + url="http://www.example.com/threat-report.pdf", + ) + now = dt.datetime(2016, 5, 12, 8, 17, 27, tzinfo=pytz.utc) + + malware = stix2.v21.Malware( + external_references=[ref1, ref2], + type="malware", + id="malware--9c4638ec-f1de-4ddb-abf4-1b760417654e", + created=now, + modified=now, + name="Cryptolocker", + is_family=False, + labels=["foo", "bar"], + ) + return malware + + +def file_example_with_PDFExt_Object(): + f = stix2.v21.File( + name="qwerty.dll", + magic_number_hex="504B0304", + extensions={ + "pdf-ext": stix2.v21.PDFExt( + version="1.7", + document_info_dict={ + "Title": "Sample document", + "Author": "Adobe Systems Incorporated", + "Creator": "Adobe FrameMaker 5.5.3 for Power Macintosh", + "Producer": "Acrobat Distiller 3.01 for Power Macintosh", + "CreationDate": "20070412090123-02", + }, + pdfid0="DFCE52BD827ECF765649852119D", + pdfid1="57A1E0F9ED2AE523E313C", + ), + }, + ) + return f + + 
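Reviewer sketch (not part of the diff): how the store, sink, and source above compose end to end. This assumes PostgresBackend lives under stix2/datastore/relational_db/database_backends/ as the PR's layout suggests, and that a local PostgreSQL instance is reachable; the connection URL and the Identity object are illustrative only.

```python
# Minimal usage sketch under the assumptions stated above.
import stix2
from stix2.datastore.relational_db.relational_db import RelationalDBStore
from stix2.datastore.relational_db.database_backends.postgres_backend import (
    PostgresBackend,  # assumed module path, mirroring this PR's layout
)

store = RelationalDBStore(
    PostgresBackend("postgresql://localhost/stix-data-sink", force_recreate=True),
    True,   # allow_custom
    None,   # version (unused so far, per the docstring above)
    True,   # instantiate_database: create schemas/tables on first use
)
identity = stix2.v21.Identity(name="ACME, Inc.", identity_class="organization")
store.add(identity)            # routed to RelationalDBSink.add()
print(store.get(identity.id))  # routed to RelationalDBSource.get()
```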
+def extension_definition_insert():
+    return stix2.ExtensionDefinition(
+        created_by_ref="identity--8a5fb7e4-aabe-4635-8972-cbcde1fa4792",
+        name="test",
+        schema="a schema",
+        version="1.2.3",
+        extension_types=["property-extension", "new-sdo", "new-sro"],
+        object_marking_refs=[
+            "marking-definition--caa0d913-5db8-4424-aae0-43e770287d30",
+            "marking-definition--122a27a0-b96f-46bc-8fcd-f7a159757e77",
+        ],
+        granular_markings=[
+            {
+                "lang": "en_US",
+                "selectors": ["name", "schema"],
+            },
+            {
+                "marking_ref": "marking-definition--50902d70-37ae-4f85-af68-3f4095493b42",
+                "selectors": ["name", "schema"],
+            },
+        ],
+    )
+
+
+def dictionary_test():
+    return stix2.File(
+        spec_version="2.1",
+        name="picture.jpg",
+        defanged=True,
+        ctime="1980-02-23T05:43:28.2678Z",
+        extensions={
+            "raster-image-ext": {
+                "exif_tags": {
+                    "Make": "Nikon",
+                    "Model": "D7000",
+                    "XResolution": 4928,
+                    "YResolution": 3264,
+                },
+            },
+        },
+    )
+
+
+def kill_chain_test():
+    return stix2.AttackPattern(
+        spec_version="2.1",
+        id="attack-pattern--0c7b5b88-8ff7-4a4d-aa9d-feb398cd0061",
+        created="2016-05-12T08:17:27.000Z",
+        modified="2016-05-12T08:17:27.000Z",
+        name="Spear Phishing",
+        kill_chain_phases=[
+            {
+                "kill_chain_name": "lockheed-martin-cyber-kill-chain",
+                "phase_name": "reconnaissance",
+            },
+        ],
+        external_references=[
+            {
+                "source_name": "capec",
+                "external_id": "CAPEC-163",
+            },
+        ],
+        granular_markings=[
+            {
+                "lang": "en_US",
+                "selectors": ["kill_chain_phases"],
+            },
+            {
+                "marking_ref": "marking-definition--50902d70-37ae-4f85-af68-3f4095493b42",
+                "selectors": ["external_references"],
+            },
+        ],
+    )
+
+
+@stix2.CustomObject(
+    'x-custom-type',
+    properties=[
+        ("phases", stix2.properties.ListProperty(stix2.KillChainPhase)),
+        ("something_else", stix2.properties.IntegerProperty()),
+    ],
+)
+class CustomClass:
+    pass
+
+
+def custom_obj():
+    obj = CustomClass(
+        phases=[
+            {
+                "kill_chain_name": "chain name",
+                "phase_name": "the phase name",
+            },
+        ],
+        something_else=5,
+    )
+    return obj
+
+
+@stix2.CustomObject(
+    "test-object", [
+        ("prop_name", stix2.properties.ListProperty(stix2.properties.BinaryProperty())),
+    ],
+    "extension-definition--15de9cdb-3515-4271-8479-8141154c5647",
+    is_sdo=True,
+)
+class TestClass:
+    pass
+
+
+def test_binary_list():
+    return TestClass(prop_name=["AREi", "7t3M"])
+
+
+@stix2.CustomObject(
+    "test2-object", [
+        (
+            "prop_name", stix2.properties.ListProperty(
+                stix2.properties.HexProperty(),
+            ),
+        ),
+    ],
+    "extension-definition--15de9cdb-4567-4271-8479-8141154c5647",
+    is_sdo=True,
+)
+class Test2Class:
+    pass
+
+
+def test_hex_list():
+    return Test2Class(
+        prop_name=["1122", "fedc"],
+    )
+
+
+@stix2.CustomObject(
+    "test3-object", [
+        (
+            "prop_name",
+            stix2.properties.DictionaryProperty(
+                valid_types=[
+                    stix2.properties.IntegerProperty,
+                    stix2.properties.FloatProperty,
+                    stix2.properties.StringProperty,
+                ],
+            ),
+        ),
+        (
+            "list_of_timestamps",
+            stix2.properties.ListProperty(stix2.properties.TimestampProperty()),
+        ),
+    ],
+    "extension-definition--15de9cdb-1234-4271-8479-8141154c5647",
+    is_sdo=True,
+)
+class Test3Class:
+    pass
+
+
+def test_dictionary():
+    return Test3Class(
+        prop_name={"a": 1, "b": 2.3, "c": "foo"},
+        list_of_timestamps=["2016-05-12T08:17:27.000Z", "2024-05-12T08:17:27.000Z"],
+    )
+
+
+def main():
+    store = RelationalDBStore(
+        PostgresBackend("postgresql://localhost/stix-data-sink", force_recreate=True),
+        True,
+        None,
+        True,
+        print_sql=True,
+    )
+
+    if store.sink.db_backend.database_exists:
+
+        td = test_dictionary()
+
store.add(td) + + th = test_hex_list() + + store.add(th) + + tb = test_binary_list() + + store.add(tb) + + co = custom_obj() + + store.add(co) + + pdf_file = file_example_with_PDFExt_Object() + store.add(pdf_file) + + ap = kill_chain_test() + store.add(ap) + + store.add(directory_stix_object) + + store.add(s) + + store.add(extension_definition_insert()) + + dict_example = dictionary_test() + store.add(dict_example) + + malware = malware_with_all_required_properties() + store.add(malware) + + # read_obj = store.get(directory_stix_object.id) + # print(read_obj) + else: + print("database does not exist") + + +if __name__ == '__main__': + main() diff --git a/stix2/datastore/relational_db/table_creation.py b/stix2/datastore/relational_db/table_creation.py new file mode 100644 index 00000000..32965f02 --- /dev/null +++ b/stix2/datastore/relational_db/table_creation.py @@ -0,0 +1,914 @@ +# from collections import OrderedDict + +from sqlalchemy import ( # create_engine,; insert, + ARRAY, CheckConstraint, Column, ForeignKey, Integer, Table, Text, + UniqueConstraint, +) + +from stix2.datastore.relational_db.add_method import add_method +from stix2.datastore.relational_db.utils import ( + SCO_COMMON_PROPERTIES, SDO_COMMON_PROPERTIES, canonicalize_table_name, + determine_column_name, determine_sql_type_from_stix, flat_classes, + get_stix_object_classes, +) +from stix2.properties import ( + BinaryProperty, BooleanProperty, DictionaryProperty, + EmbeddedObjectProperty, EnumProperty, ExtensionsProperty, FloatProperty, + HashesProperty, HexProperty, IDProperty, IntegerProperty, ListProperty, + ObjectReferenceProperty, Property, ReferenceProperty, StringProperty, + TimestampProperty, TypeProperty, +) +from stix2.v21.base import _Extension +from stix2.v21.common import KillChainPhase + + +def create_array_column(property_name, contained_sql_type, optional): + return Column( + property_name, + ARRAY(contained_sql_type), + CheckConstraint(f"{property_name} IS NULL or array_length({property_name}, 1) IS NOT NULL"), + nullable=optional, + ) + + +def create_array_child_table(metadata, db_backend, parent_table_name, table_name_suffix, property_name, contained_sql_type): + schema_name = db_backend.schema_for_core() + columns = [ + Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + canonicalize_table_name(parent_table_name, schema_name) + ".id", + ondelete="CASCADE", + ), + nullable=False, + ), + Column( + property_name, + contained_sql_type, + nullable=False, + ), + ] + return Table(parent_table_name + table_name_suffix, metadata, *columns, schema=schema_name) + + +def derive_column_name(prop): + contained_property = prop.contained + if isinstance(contained_property, ReferenceProperty): + return "ref_id" + elif isinstance(contained_property, StringProperty): + return "value" + + +def create_object_markings_refs_table(metadata, db_backend, sco_or_sdo): + return create_ref_table( + metadata, + db_backend, + {"marking-definition"}, + "object_marking_refs_" + sco_or_sdo, + "common.core_" + sco_or_sdo + ".id", + "common", + 0, + ) + + +def create_ref_table(metadata, db_backend, specifics, table_name, foreign_key_name, schema_name, auth_type=0): + columns = list() + columns.append( + Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + foreign_key_name, + ondelete="CASCADE", + ), + nullable=False, + ), + ) + columns.append(ref_column("ref_id", specifics, db_backend, auth_type)) + return Table(table_name, metadata, *columns, schema=schema_name) + + +def 
create_hashes_table(name, metadata, db_backend, schema_name, table_name, key_type=Text, level=1):
+    columns = list()
+    # special case, perhaps because it's a single embedded object with hashes, and not a list of embedded objects;
+    # making this a foreign key to the parent table's primary key does not seem to work
+
+    columns.append(
+        Column(
+            "id",
+            key_type,
+            # ForeignKey(
+            #     canonicalize_table_name(table_name, schema_name) + (".hash_ref_id" if table_name == "external_references" else ".id"),
+            #     ondelete="CASCADE",
+            # ),
+
+            nullable=False,
+        ),
+    )
+    columns.append(
+        Column(
+            "hash_name",
+            db_backend.determine_sql_type_for_string_property(),
+            nullable=False,
+        ),
+    )
+    columns.append(
+        Column(
+            "hash_value",
+            db_backend.determine_sql_type_for_string_property(),
+            nullable=False,
+        ),
+    )
+    return Table(
+        canonicalize_table_name(table_name + "_" + name),
+        metadata,
+        *columns,
+        UniqueConstraint("id", "hash_name"),
+        schema=schema_name,
+    )
+
+
+def create_kill_chain_phases_table(name, metadata, db_backend, schema_name, table_name):
+    columns = list()
+    columns.append(
+        Column(
+            "id",
+            db_backend.determine_sql_type_for_key_as_id(),
+            ForeignKey(
+                canonicalize_table_name(table_name, schema_name) + ".id",
+                ondelete="CASCADE",
+            ),
+            nullable=False,
+        ),
+    )
+    columns.append(
+        Column(
+            "kill_chain_name",
+            db_backend.determine_sql_type_for_string_property(),
+            nullable=False,
+        ),
+    )
+    columns.append(
+        Column(
+            "phase_name",
+            db_backend.determine_sql_type_for_string_property(),
+            nullable=False,
+        ),
+    )
+    return Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name)
+
+
+def create_granular_markings_table(metadata, db_backend, sco_or_sdo):
+    schema_name = db_backend.schema_for_core()
+    tables = list()
+    columns = [
+        Column(
+            "id",
+            db_backend.determine_sql_type_for_key_as_id(),
+            ForeignKey(canonicalize_table_name("core_" + sco_or_sdo, schema_name) + ".id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        Column("lang", db_backend.determine_sql_type_for_string_property()),
+        Column(
+            "marking_ref",
+            db_backend.determine_sql_type_for_reference_property(),
+            CheckConstraint(
+                "marking_ref ~ '^marking-definition--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'",
+                # noqa: E131
+            ),
+        ),
+    ]
+    if db_backend.array_allowed():
+        columns.append(create_array_column("selectors", db_backend.determine_sql_type_for_string_property(), False))
+
+    else:
+        columns.append(
+            Column(
+                "selectors",
+                db_backend.determine_sql_type_for_key_as_int(),
+                unique=True,
+            ),
+        )
+
+        child_columns = [
+            Column(
+                "id",
+                db_backend.determine_sql_type_for_key_as_int(),
+                ForeignKey(
+                    canonicalize_table_name("granular_marking_" + sco_or_sdo, schema_name) + ".selectors",
+                    ondelete="CASCADE",
+                ),
+                nullable=False,
+            ),
+            Column(
+                "selector",
+                db_backend.determine_sql_type_for_string_property(),
+                nullable=False,
+            ),
+        ]
+        tables.append(
+            Table(
+                canonicalize_table_name("granular_marking_" + sco_or_sdo + "_" + "selector"),
+                metadata, *child_columns, schema=schema_name,
+            ),
+        )
+    tables.append(
+        Table(
+            "granular_marking_" + sco_or_sdo,
+            metadata,
+            *columns,
+            CheckConstraint(
+                """(lang IS NULL AND marking_ref IS NOT NULL)
+                   OR
+                   (lang IS NOT NULL AND marking_ref IS NULL)""",
+            ),
+            schema=schema_name,
+        ),
+    )
+    return tables
+
+
+def create_external_references_tables(metadata, db_backend):
+    columns = [
+        Column(
+            "id",
+            db_backend.determine_sql_type_for_key_as_id(),
+            ForeignKey("common.core_sdo" + ".id",
ondelete="CASCADE"), + CheckConstraint( + "id ~ '^[a-z][a-z0-9-]+[a-z0-9]--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 + ), + ), + Column("source_name", db_backend.determine_sql_type_for_string_property()), + Column("description", db_backend.determine_sql_type_for_string_property()), + Column("url", db_backend.determine_sql_type_for_string_property()), + Column("external_id", db_backend.determine_sql_type_for_string_property()), + # all such keys are generated using the global sequence. + Column("hash_ref_id", db_backend.determine_sql_type_for_key_as_int(), autoincrement=False), + ] + return [ + Table("external_references", metadata, *columns, schema="common"), + create_hashes_table("hashes", metadata, db_backend, "common", "external_references", Integer), + ] + + +def create_core_table(metadata, db_backend, stix_type_name): + tables = list() + table_name = "core_" + stix_type_name + columns = [ + Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + CheckConstraint( + "id ~ '^[a-z][a-z0-9-]+[a-z0-9]--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 + ), + primary_key=True, + ), + Column("spec_version", db_backend.determine_sql_type_for_string_property(), default="2.1"), + ] + if stix_type_name == "sdo": + sdo_columns = [ + Column( + "created_by_ref", + db_backend.determine_sql_type_for_reference_property(), + CheckConstraint( + "created_by_ref ~ '^identity--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 + ), + ), + Column("created", db_backend.determine_sql_type_for_timestamp_property()), + Column("modified", db_backend.determine_sql_type_for_timestamp_property()), + Column("revoked", db_backend.determine_sql_type_for_boolean_property()), + Column("confidence", db_backend.determine_sql_type_for_integer_property()), + Column("lang", db_backend.determine_sql_type_for_string_property()), + ] + columns.extend(sdo_columns) + if db_backend.array_allowed(): + columns.append(create_array_column("labels", db_backend.determine_sql_type_for_string_property(), True)) + else: + tables.append( + create_array_child_table( + metadata, + db_backend, + table_name, + "_labels", + "label", + db_backend.determine_sql_type_for_string_property(), + ), + ) + else: + columns.append(Column("defanged", db_backend.determine_sql_type_for_boolean_property(), default=False)) + + tables.append( + Table( + table_name, + metadata, + *columns, + schema=db_backend.schema_for_core(), + ), + ) + return tables + + +@add_method(Property) +def determine_sql_type(self, db_backend): # noqa: F811 + pass + + +@add_method(KillChainPhase) +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_kill_chain_phase() + + +@add_method(BinaryProperty) +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_binary_property() + + +@add_method(BooleanProperty) +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_boolean_property() + + +@add_method(FloatProperty) +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_float_property() + + +@add_method(HexProperty) +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_hex_property() + + +@add_method(IntegerProperty) +def determine_sql_type(self, db_backend): # noqa: F811 + return 
db_backend.determine_sql_type_for_integer_property()
+
+
+@add_method(ReferenceProperty)
+def determine_sql_type(self, db_backend):  # noqa: F811
+    return db_backend.determine_sql_type_for_reference_property()
+
+
+@add_method(StringProperty)
+def determine_sql_type(self, db_backend):  # noqa: F811
+    return db_backend.determine_sql_type_for_string_property()
+
+
+@add_method(TimestampProperty)
+def determine_sql_type(self, db_backend):  # noqa: F811
+    return db_backend.determine_sql_type_for_timestamp_property()
+
+
+# ----------------------------- generate_table_information methods ----------------------------
+
+@add_method(KillChainPhase)
+def generate_table_information(  # noqa: F811
+    self, name, db_backend, metadata, schema_name, table_name, is_extension=False, is_list=False,
+    **kwargs,
+):
+    level = kwargs.get("level")
+    return generate_object_table(
+        self.type, db_backend, metadata, schema_name, table_name, is_extension, True, is_list,
+        parent_table_name=table_name, level=level + 1 if is_list else level,
+    )
+
+
+@add_method(Property)
+def generate_table_information(self, name, db_backend, **kwargs):  # noqa: F811
+    pass
+
+
+@add_method(BinaryProperty)
+def generate_table_information(self, name, db_backend, **kwargs):  # noqa: F811
+    return Column(
+        name,
+        self.determine_sql_type(db_backend),
+        CheckConstraint(
+            # this regular expression might accept or reject some legal base64 strings
+            f"{name} ~ " + "'^[-A-Za-z0-9+/]*={0,3}$'",
+        ),
+        nullable=not self.required,
+    )
+
+
+@add_method(BooleanProperty)
+def generate_table_information(self, name, db_backend, **kwargs):  # noqa: F811
+    return Column(
+        name,
+        self.determine_sql_type(db_backend),
+        nullable=not self.required,
+        default=self._fixed_value if hasattr(self, "_fixed_value") else None,
+    )
+
+
+@add_method(DictionaryProperty)
+def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, is_extension=False, **kwargs):  # noqa: F811
+    columns = list()
+
+    columns.append(
+        Column(
+            "id",
+            db_backend.determine_sql_type_for_key_as_id(),
+            ForeignKey(canonicalize_table_name(table_name, schema_name) + ".id", ondelete="CASCADE"),
+        ),
+    )
+    columns.append(
+        Column(
+            "name",
+            db_backend.determine_sql_type_for_string_property(),
+            nullable=False,
+        ),
+    )
+    if self.valid_types:
+        if len(self.valid_types) == 1:
+            if not isinstance(self.valid_types[0], ListProperty):
+                columns.append(
+                    Column(
+                        "value",
+                        # it's a class
+                        determine_sql_type_from_stix(self.valid_types[0], db_backend),
+                        nullable=False,
+                    ),
+                )
+            else:
+                contained_class = self.valid_types[0].contained
+                columns.append(
+                    create_array_column(
+                        "value",
+                        contained_class.determine_sql_type(db_backend),
+                        False,
+                    ),
+                )
+        else:
+            for column_type in self.valid_types:
+                sql_type = determine_sql_type_from_stix(column_type, db_backend)
+                columns.append(
+                    Column(
+                        determine_column_name(column_type),
+                        sql_type,
+                    ),
+                )
+    else:
+        columns.append(
+            Column(
+                "value",
+                db_backend.determine_sql_type_for_string_property(),
+                nullable=False,
+            ),
+        )
+    return [
+        Table(
+            canonicalize_table_name(table_name + "_" + name),
+            metadata,
+            *columns,
+            UniqueConstraint("id", "name"),
+            schema=schema_name,
+        ),
+    ]
+
+
+@add_method(EmbeddedObjectProperty)
+def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, is_extension=False, is_list=False, **kwargs):  # noqa: F811
+    level = kwargs.get("level")
+    return generate_object_table(
+        self.type, db_backend, metadata, schema_name, table_name, is_extension, True, is_list,
parent_table_name=table_name, level=level+1 if is_list else level, + ) + + +@add_method(EnumProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + enum_re = "|".join(self.allowed) + return Column( + name, + self.determine_sql_type(db_backend), + CheckConstraint( + f"{name} ~ '^{enum_re}$'", + ), + nullable=not self.required, + ) + + +@add_method(ExtensionsProperty) +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, **kwargs): # noqa: F811 + columns = list() + columns.append( + Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey(canonicalize_table_name(table_name, schema_name) + ".id", ondelete="CASCADE"), + nullable=False, + ), + ) + columns.append( + Column( + "ext_table_name", + db_backend.determine_sql_type_for_string_property(), + nullable=False, + ), + ) + return [Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name)] + + +@add_method(FloatProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + return Column( + name, + self.determine_sql_type(db_backend), + nullable=not self.required, + default=self._fixed_value if hasattr(self, "_fixed_value") else None, + ) + + +@add_method(HashesProperty) +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, is_extension=False, **kwargs): # noqa: F811 + level = kwargs.get("level") + if kwargs.get("is_embedded_object"): + if not kwargs.get("is_list") or level == 0: + key_type = Text + else: + key_type = Integer + else: + key_type = Text + return [ + create_hashes_table( + name, + metadata, + db_backend, + schema_name, + table_name, + key_type=key_type, + level=level, + ), + ] + + +@add_method(HexProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + return Column( + name, + db_backend.determine_sql_type_for_hex_property(), + nullable=not self.required, + ) + + +@add_method(IDProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + schema_name = kwargs.get('schema_name') + table_name = kwargs.get("table_name") + core_table = kwargs.get("core_table") + # if schema_name == "common": + # return Column( + # name, + # Text, + # CheckConstraint( + # f"{name} ~ '^{table_name}" + "--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", + # # noqa: E131 + # ), + # primary_key=True, + # nullable=not (self.required), + # ) + # else: + if schema_name: + foreign_key_column = f"common.core_{core_table}.id" + else: + foreign_key_column = f"core_{core_table}.id" + return Column( + name, + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey(foreign_key_column, ondelete="CASCADE"), + CheckConstraint( + f"{name} ~ '^{table_name}" + "--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", + # noqa: E131 + ), + primary_key=True, + nullable=not (self.required), + ) + + +@add_method(IntegerProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + return Column( + name, + self.determine_sql_type(db_backend), + nullable=not self.required, + default=self._fixed_value if hasattr(self, "_fixed_value") else None, + ) + + +@add_method(ListProperty) +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, **kwargs): # noqa: F811 + is_extension = kwargs.get('is_extension') + is_embedded_object = kwargs.get('is_embedded_object') + tables = list() 
+    # handle more complex embedded object before deciding if the ARRAY type is usable
+    if isinstance(self.contained, EmbeddedObjectProperty):
+        columns = list()
+        columns.append(
+            Column(
+                "id",
+                db_backend.determine_sql_type_for_key_as_id(),
+                ForeignKey(
+                    canonicalize_table_name(table_name, schema_name) + ".id",
+                    ondelete="CASCADE",
+                ),
+            ),
+        )
+        columns.append(
+            Column(
+                "ref_id",
+                db_backend.determine_sql_type_for_key_as_int(),
+                primary_key=True,
+                nullable=False,
+                # all such keys are generated using the global sequence.
+                autoincrement=False,
+            ),
+        )
+        tables.append(Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name))
+        tables.extend(
+            self.contained.generate_table_information(
+                name,
+                db_backend,
+                metadata,
+                schema_name,
+                canonicalize_table_name(table_name + "_" + name, None),
+                # if sub_table_needed else canonicalize_table_name(table_name, None),
+                is_extension,
+                parent_table_name=table_name,
+                is_list=True,
+                level=kwargs.get("level"),
+            ),
+        )
+        return tables
+    elif isinstance(self.contained, ReferenceProperty):
+        return [
+            create_ref_table(
+                metadata,
+                db_backend,
+                self.contained.specifics,
+                canonicalize_table_name(table_name + "_" + name),
+                canonicalize_table_name(table_name, schema_name) + ".id",
+                schema_name,
+            ),
+        ]
+    elif ((
+        isinstance(
+            self.contained,
+            (BinaryProperty, BooleanProperty, StringProperty, IntegerProperty, FloatProperty, HexProperty, TimestampProperty),
+        ) and
+        not db_backend.array_allowed()
+    ) or
+            isinstance(self.contained, EnumProperty)):
+        columns = list()
+        if is_embedded_object:
+            id_type = db_backend.determine_sql_type_for_key_as_int()
+        else:
+            id_type = db_backend.determine_sql_type_for_key_as_id()
+        columns.append(
+            Column(
+                "id",
+                id_type,
+                ForeignKey(
+                    canonicalize_table_name(table_name, schema_name) + ".id",
+                    ondelete="CASCADE",
+                ),
+                nullable=False,
+            ),
+        )
+        columns.append(self.contained.generate_table_information(name, db_backend))
+        tables.append(Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name))
+        return tables
+    elif self.contained == KillChainPhase:
+        tables.append(create_kill_chain_phases_table(name, metadata, db_backend, schema_name, table_name))
+        return tables
+    else:
+        # if ARRAY is not allowed, it is handled by a previous if clause
+        if isinstance(self.contained, Property):
+            return create_array_column(name, self.contained.determine_sql_type(db_backend), not self.required)
+
+
+def ref_column(name, specifics, db_backend, auth_type=0):
+    if specifics:
+        types = "|".join(specifics)
+        if auth_type == 0:
+            constraint = \
+                CheckConstraint(
+                    f"{name} ~ '^({types})" +
+                    "--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'",
+                )
+        else:
+            constraint = \
+                CheckConstraint(
+                    f"(NOT({name} ~ '^({types})')) AND ({name} ~ " +
+                    "'--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$')",
+                )
+        return Column(name, db_backend.determine_sql_type_for_reference_property(), constraint)
+    else:
+        return Column(
+            name,
+            db_backend.determine_sql_type_for_reference_property(),
+            nullable=False,
+        )
+
+
+@add_method(ObjectReferenceProperty)
+def generate_table_information(self, name, db_backend, **kwargs):  # noqa: F811
+    table_name = kwargs.get('table_name')
+    raise ValueError(f"Property {name} in {table_name} is of type ObjectReferenceProperty, which is for STIX 2.0 only")
+
+
+@add_method(ReferenceProperty)
+def generate_table_information(self, name,
db_backend, **kwargs): # noqa: F811 + return ref_column(name, self.specifics, db_backend, self.auth_type) + + +@add_method(StringProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + return Column( + name, + db_backend.determine_sql_type_for_string_property(), + nullable=not self.required, + default=self._fixed_value if hasattr(self, "_fixed_value") else None, + ) + + +@add_method(TimestampProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + return Column( + name, + self.determine_sql_type(db_backend), + # CheckConstraint( + # f"{name} ~ '^{enum_re}$'" + # ), + nullable=not (self.required), + ) + + +@add_method(TypeProperty) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + return Column( + name, + db_backend.determine_sql_type_for_string_property(), + nullable=not self.required, + default=self._fixed_value if hasattr(self, "_fixed_value") else None, + ) + + +def generate_object_table( + stix_object_class, db_backend, metadata, schema_name, foreign_key_name=None, + is_extension=False, is_embedded_object=False, is_list=False, parent_table_name=None, level=0, +): + properties = stix_object_class._properties + if hasattr(stix_object_class, "_type"): + table_name = stix_object_class._type + else: + table_name = stix_object_class.__name__ + # avoid long table names + if table_name.startswith("extension-definition"): + table_name = table_name[0:30] + table_name = table_name.replace("extension-definition-", "ext_def") + if parent_table_name: + table_name = parent_table_name + "_" + table_name + if is_embedded_object: + core_properties = list() + elif schema_name in ["sdo", "sro", "common"]: + core_properties = SDO_COMMON_PROPERTIES + elif schema_name == "sco": + core_properties = SCO_COMMON_PROPERTIES + else: + core_properties = list() + columns = list() + tables = list() + if schema_name == "sco": + core_table = "sco" + else: + # sro, smo common properties are the same as sdo's + core_table = "sdo" + for name, prop in properties.items(): + # type is never a column since it is implicit in the table + if (name == 'id' or name not in core_properties) and name != 'type': + col = prop.generate_table_information( + name, + db_backend, + metadata=metadata, + schema_name=schema_name, + table_name=table_name, + is_extension=is_extension, + is_embedded_object=is_embedded_object, + is_list=is_list, + level=level, + parent_table_name=parent_table_name, + core_table=core_table, + ) + if col is not None and isinstance(col, Column): + columns.append(col) + if col is not None and isinstance(col, list): + tables.extend(col) + if is_extension and not is_embedded_object: + columns.append( + Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + # no Foreign Key because it could be for different tables + primary_key=True, + ), + ) + if foreign_key_name: + if level == 0: + if is_extension and not is_embedded_object: + column = Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + canonicalize_table_name(foreign_key_name, schema_name) + ".id", + ondelete="CASCADE", + ), + ) + elif is_embedded_object: + column = Column( + "id", + db_backend.determine_sql_type_for_key_as_int() if is_list else db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + canonicalize_table_name(foreign_key_name, schema_name) + (".ref_id" if is_list else ".id"), + ondelete="CASCADE", + ), + # if it is a not list, then it is a single embedded object, and the primary key is unique + 
primary_key=not is_list, + ) + elif level > 0 and is_embedded_object: + column = Column( + "id", + db_backend.determine_sql_type_for_key_as_int() if (is_embedded_object and is_list) else db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + canonicalize_table_name(foreign_key_name, schema_name) + (".ref_id" if (is_embedded_object and is_list) else ".id"), + ondelete="CASCADE", + ), + primary_key=True, + nullable=False, + ) + else: + column = Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + canonicalize_table_name(foreign_key_name, schema_name) + ".id", + ondelete="CASCADE", + ), + ) + columns.append(column) + + # all_tables = [Table(canonicalize_table_name(table_name), metadata, *columns, schema=schema_name)] + # all_tables.extend(tables) + # return all_tables + + tables.append(Table(canonicalize_table_name(table_name), metadata, *columns, schema=schema_name)) + return tables + + +def add_tables(new_tables, tables): + if isinstance(new_tables, list): + tables.extend(new_tables) + else: + tables.append(new_tables) + + +def create_core_tables(metadata, db_backend): + tables = list() + add_tables(create_core_table(metadata, db_backend, "sdo"), tables) + add_tables(create_granular_markings_table(metadata, db_backend, "sdo"), tables) + add_tables(create_core_table(metadata, db_backend, "sco"), tables) + add_tables(create_granular_markings_table(metadata, db_backend, "sco"), tables) + add_tables(create_object_markings_refs_table(metadata, db_backend, "sdo"), tables) + add_tables(create_object_markings_refs_table(metadata, db_backend, "sco"), tables) + tables.extend(create_external_references_tables(metadata, db_backend)) + return tables + + +def create_table_objects(metadata, db_backend, stix_object_classes): + if stix_object_classes: + # If classes are given, allow some flexibility regarding lists of + # classes vs single classes + stix_object_classes = flat_classes(stix_object_classes) + + else: + # If no classes given explicitly, discover them automatically + stix_object_classes = get_stix_object_classes() + + tables = create_core_tables(metadata, db_backend) + + for stix_class in stix_object_classes: + + schema_name = db_backend.schema_for(stix_class) + is_extension = issubclass(stix_class, _Extension) + + tables.extend( + generate_object_table( + stix_class, + db_backend, + metadata, + schema_name, + is_extension=is_extension, + ), + ) + + return tables diff --git a/stix2/datastore/relational_db/utils.py b/stix2/datastore/relational_db/utils.py new file mode 100644 index 00000000..65257426 --- /dev/null +++ b/stix2/datastore/relational_db/utils.py @@ -0,0 +1,171 @@ +from collections.abc import Iterable, Mapping + +import inflection + +from stix2.properties import ( + BinaryProperty, BooleanProperty, FloatProperty, HexProperty, + IntegerProperty, Property, ReferenceProperty, StringProperty, + TimestampProperty, +) +from stix2.v21.base import ( + _DomainObject, _Extension, _MetaObject, _Observable, _RelationshipObject, +) + +# Helps us know which data goes in core, and which in a type-specific table. +SCO_COMMON_PROPERTIES = { + "id", + "type", + "spec_version", + "object_marking_refs", + "granular_markings", + "defanged", +} + +# Helps us know which data goes in core, and which in a type-specific table. 
+SDO_COMMON_PROPERTIES = {
+    "id",
+    "type",
+    "spec_version",
+    "object_marking_refs",
+    "granular_markings",
+    "created",
+    "modified",
+    "created_by_ref",
+    "revoked",
+    "labels",
+    "confidence",
+    "lang",
+    "external_references",
+}
+
+
+def canonicalize_table_name(table_name, schema_name=None):
+    if schema_name:
+        full_name = schema_name + "." + table_name
+    else:
+        full_name = table_name
+    full_name = full_name.replace("-", "_")
+    return inflection.underscore(full_name)
+
+
+_IGNORE_OBJECTS = ["language-content"]
+
+
+def get_all_subclasses(cls):
+    all_subclasses = []
+
+    for subclass in cls.__subclasses__():
+        # This code might be useful if we decide that some objects just cannot have their tables
+        # automatically generated
+
+        # if hasattr(subclass, "_type") and subclass._type in _IGNORE_OBJECTS:
+        #     print(f'It is currently not possible to create a table for {subclass._type}')
+        #     return []
+        # else:
+        all_subclasses.append(subclass)
+        all_subclasses.extend(get_all_subclasses(subclass))
+    return all_subclasses
+
+
+def get_stix_object_classes():
+    yield from get_all_subclasses(_DomainObject)
+    yield from get_all_subclasses(_RelationshipObject)
+    yield from get_all_subclasses(_Observable)
+    yield from get_all_subclasses(_MetaObject)
+    # Non-object extensions (property or toplevel-property only)
+    for ext_cls in get_all_subclasses(_Extension):
+        if ext_cls.extension_type not in (
+            "new-sdo", "new-sco", "new-sro",
+        ):
+            yield ext_cls
+
+
+def schema_for(stix_class):
+
+    if issubclass(stix_class, _DomainObject):
+        schema_name = "sdo"
+    elif issubclass(stix_class, _RelationshipObject):
+        schema_name = "sro"
+    elif issubclass(stix_class, _Observable):
+        schema_name = "sco"
+    elif issubclass(stix_class, _MetaObject):
+        schema_name = "common"
+    elif issubclass(stix_class, _Extension):
+        schema_name = getattr(stix_class, "_applies_to", "sco")
+    else:
+        schema_name = None
+
+    return schema_name
+
+
+def table_name_for(stix_type_or_class):
+    if isinstance(stix_type_or_class, str):
+        table_name = stix_type_or_class
+    else:
+        # A _STIXBase subclass
+        table_name = getattr(stix_type_or_class, "_type", stix_type_or_class.__name__)
+
+    # Applies to registered extension-definition style extensions only.
+    # Their "_type" attribute is actually set to the extension definition ID,
+    # rather than a STIX type.
+    if table_name.startswith("extension-definition"):
+        table_name = table_name[0:30]
+        table_name = table_name.replace("extension-definition-", "ext_def")
+
+    table_name = canonicalize_table_name(table_name)
+    return table_name
+
+
+def flat_classes(class_or_classes):
+    if isinstance(class_or_classes, Iterable) and not isinstance(
+        # Try to generically detect STIX objects, which are iterable, but we
+        # don't want to iterate through those.
+ class_or_classes, Mapping, + ): + for class_ in class_or_classes: + yield from flat_classes(class_) + else: + yield class_or_classes + + +def is_class_or_instance(cls_or_inst, cls): + return cls_or_inst == cls or isinstance(cls_or_inst, cls) + + +def determine_sql_type_from_stix(cls_or_inst, db_backend): # noqa: F811 + if is_class_or_instance(cls_or_inst, BinaryProperty): + return db_backend.determine_sql_type_for_binary_property() + elif is_class_or_instance(cls_or_inst, BooleanProperty): + return db_backend.determine_sql_type_for_boolean_property() + elif is_class_or_instance(cls_or_inst, FloatProperty): + return db_backend.determine_sql_type_for_float_property() + elif is_class_or_instance(cls_or_inst, HexProperty): + return db_backend.determine_sql_type_for_hex_property() + elif is_class_or_instance(cls_or_inst, IntegerProperty): + return db_backend.determine_sql_type_for_integer_property() + elif is_class_or_instance(cls_or_inst, StringProperty): + return db_backend.determine_sql_type_for_string_property() + elif is_class_or_instance(cls_or_inst, ReferenceProperty): + return db_backend.determine_sql_type_for_reference_property() + elif is_class_or_instance(cls_or_inst, TimestampProperty): + return db_backend.determine_sql_type_for_timestamp_property() + elif is_class_or_instance(cls_or_inst, Property): + return db_backend.determine_sql_type_for_integer_property() + + +def determine_column_name(cls_or_inst): # noqa: F811 + if is_class_or_instance(cls_or_inst, BinaryProperty): + return "binary_value" + elif is_class_or_instance(cls_or_inst, BooleanProperty): + return "boolean_value" + elif is_class_or_instance(cls_or_inst, FloatProperty): + return "float_value" + elif is_class_or_instance(cls_or_inst, HexProperty): + return "hex_value" + elif is_class_or_instance(cls_or_inst, IntegerProperty): + return "integer_value" + elif is_class_or_instance(cls_or_inst, StringProperty) or is_class_or_instance(cls_or_inst, ReferenceProperty): + return "string_value" + elif is_class_or_instance(cls_or_inst, TimestampProperty): + return "timestamp_value" diff --git a/stix2/environment.py b/stix2/environment.py index eab2ba9e..76d7be0d 100644 --- a/stix2/environment.py +++ b/stix2/environment.py @@ -191,7 +191,7 @@ def creator_of(self, obj): def object_similarity( obj1, obj2, prop_scores={}, ds1=None, ds2=None, ignore_spec_version=False, versioning_checks=False, - max_depth=1, **weight_dict + max_depth=1, **weight_dict, ): """This method returns a measure of how similar the two objects are. @@ -236,14 +236,14 @@ def object_similarity( """ return object_similarity( obj1, obj2, prop_scores, ds1, ds2, ignore_spec_version, - versioning_checks, max_depth, **weight_dict + versioning_checks, max_depth, **weight_dict, ) @staticmethod def object_equivalence( obj1, obj2, prop_scores={}, threshold=70, ds1=None, ds2=None, ignore_spec_version=False, versioning_checks=False, - max_depth=1, **weight_dict + max_depth=1, **weight_dict, ): """This method returns a true/false value if two objects are semantically equivalent. 
Internally, it calls the object_similarity function and compares it against the given @@ -294,13 +294,13 @@ def object_equivalence( """ return object_equivalence( obj1, obj2, prop_scores, threshold, ds1, ds2, - ignore_spec_version, versioning_checks, max_depth, **weight_dict + ignore_spec_version, versioning_checks, max_depth, **weight_dict, ) @staticmethod def graph_similarity( ds1, ds2, prop_scores={}, ignore_spec_version=False, - versioning_checks=False, max_depth=1, **weight_dict + versioning_checks=False, max_depth=1, **weight_dict, ): """This method returns a similarity score for two given graphs. Each DataStore can contain a connected or disconnected graph and the @@ -347,14 +347,14 @@ def graph_similarity( """ return graph_similarity( ds1, ds2, prop_scores, ignore_spec_version, - versioning_checks, max_depth, **weight_dict + versioning_checks, max_depth, **weight_dict, ) @staticmethod def graph_equivalence( ds1, ds2, prop_scores={}, threshold=70, ignore_spec_version=False, versioning_checks=False, - max_depth=1, **weight_dict + max_depth=1, **weight_dict, ): """This method returns a true/false value if two graphs are semantically equivalent. Internally, it calls the graph_similarity function and compares it against the given @@ -403,5 +403,5 @@ def graph_equivalence( """ return graph_equivalence( ds1, ds2, prop_scores, threshold, ignore_spec_version, - versioning_checks, max_depth, **weight_dict + versioning_checks, max_depth, **weight_dict, ) diff --git a/stix2/equivalence/graph/__init__.py b/stix2/equivalence/graph/__init__.py index 1f46fd3e..1beee4aa 100644 --- a/stix2/equivalence/graph/__init__.py +++ b/stix2/equivalence/graph/__init__.py @@ -11,7 +11,7 @@ def graph_equivalence( ds1, ds2, prop_scores={}, threshold=70, ignore_spec_version=False, versioning_checks=False, - max_depth=1, **weight_dict + max_depth=1, **weight_dict, ): """This method returns a true/false value if two graphs are semantically equivalent. Internally, it calls the graph_similarity function and compares it against the given @@ -60,7 +60,7 @@ def graph_equivalence( """ similarity_result = graph_similarity( ds1, ds2, prop_scores, ignore_spec_version, - versioning_checks, max_depth, **weight_dict + versioning_checks, max_depth, **weight_dict, ) if similarity_result >= threshold: return True @@ -69,7 +69,7 @@ def graph_equivalence( def graph_similarity( ds1, ds2, prop_scores={}, ignore_spec_version=False, - versioning_checks=False, max_depth=1, **weight_dict + versioning_checks=False, max_depth=1, **weight_dict, ): """This method returns a similarity score for two given graphs. Each DataStore can contain a connected or disconnected graph and the @@ -147,7 +147,7 @@ def graph_similarity( result = object_similarity( object1, object2, iprop_score, ds1, ds2, ignore_spec_version, versioning_checks, - max_depth, **weights + max_depth, **weights, ) if object1_id not in results: diff --git a/stix2/equivalence/object/__init__.py b/stix2/equivalence/object/__init__.py index 2b16a346..00534881 100644 --- a/stix2/equivalence/object/__init__.py +++ b/stix2/equivalence/object/__init__.py @@ -14,7 +14,7 @@ def object_equivalence( obj1, obj2, prop_scores={}, threshold=70, ds1=None, ds2=None, ignore_spec_version=False, - versioning_checks=False, max_depth=1, **weight_dict + versioning_checks=False, max_depth=1, **weight_dict, ): """This method returns a true/false value if two objects are semantically equivalent. 
Internally, it calls the object_similarity function and compares it against the given
@@ -65,7 +65,7 @@ def object_equivalence(
     """
     similarity_result = object_similarity(
         obj1, obj2, prop_scores, ds1, ds2, ignore_spec_version,
-        versioning_checks, max_depth, **weight_dict
+        versioning_checks, max_depth, **weight_dict,
     )
     if similarity_result >= threshold:
         return True
@@ -75,7 +75,7 @@ def object_equivalence(
 def object_similarity(
     obj1, obj2, prop_scores={}, ds1=None, ds2=None,
     ignore_spec_version=False, versioning_checks=False,
-    max_depth=1, **weight_dict
+    max_depth=1, **weight_dict,
 ):
     """This method returns a measure of similarity depending on how
     similar the two objects are.
diff --git a/stix2/properties.py b/stix2/properties.py
index be1c4b69..43612826 100644
--- a/stix2/properties.py
+++ b/stix2/properties.py
@@ -133,7 +133,7 @@ class Property(object):
 
     Subclasses can also define the following functions:
 
-    - ``def clean(self, value, allow_custom) -> (any, has_custom):``
+    - ``def clean(self, value, allow_custom, strict) -> (any, has_custom):``
         - Return a value that is valid for this property, and enforce and
           detect value customization. If ``value`` is not valid for this
          property, you may attempt to transform it first. If ``value`` is not
@@ -148,7 +148,9 @@ class Property(object):
         mean there actually are any). The method must return an appropriate
         value for has_custom. Customization may not be applicable/possible for
         a property. In that case, allow_custom can be ignored, and
-        has_custom must be returned as False.
+        has_custom must be returned as False. strict is a True/False flag
+        that is used in the dictionary property. If strict is True,
+        properties like StringProperty will not be lenient in their clean method.
 
     - ``def default(self):``
         - provide a default value for this property.
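A sketch of the clean() contract the docstring above describes, using a hypothetical PortProperty; the class name and the port-range rule are illustrative, not part of this PR:

```python
# Hypothetical property honoring the documented
# clean(value, allow_custom, strict) -> (cleaned_value, has_custom) contract.
from stix2.properties import Property

class PortProperty(Property):  # illustrative name, not in the library
    def clean(self, value, allow_custom=False, strict=False):
        if strict and not isinstance(value, int):
            # strict mode: no lenient coercion (mirrors the strict checks
            # added to StringProperty/IntegerProperty in this diff)
            raise ValueError("must be an integer.")
        value = int(value)  # lenient mode may coerce, e.g. "80" -> 80
        if not 0 <= value <= 65535:
            raise ValueError("must be between 0 and 65535.")
        # customization is not applicable here, so has_custom is always False
        return value, False
```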
@@ -170,7 +172,7 @@ class Property(object): """ - def _default_clean(self, value, allow_custom=False): + def _default_clean(self, value, allow_custom=False, strict=False): if value != self._fixed_value: raise ValueError("must equal '{}'.".format(self._fixed_value)) return value, False @@ -191,7 +193,7 @@ def __init__(self, required=False, fixed=None, default=None): if default: self.default = default - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): return value, False @@ -224,7 +226,7 @@ def __init__(self, contained, **kwargs): super(ListProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): try: iter(value) except TypeError: @@ -237,7 +239,10 @@ def clean(self, value, allow_custom): has_custom = False if isinstance(self.contained, Property): for item in value: - valid, temp_custom = self.contained.clean(item, allow_custom) + try: + valid, temp_custom = self.contained.clean(item, allow_custom, strict=strict) + except TypeError: + valid, temp_custom = self.contained.clean(item, allow_custom) result.append(valid) has_custom = has_custom or temp_custom @@ -275,9 +280,11 @@ class StringProperty(Property): def __init__(self, **kwargs): super(StringProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom=False): - if not isinstance(value, str): - value = str(value) + def clean(self, value, allow_custom=False, strict=False): + if strict and not isinstance(value, str): + raise ValueError("Must be a string.") + + value = str(value) return value, False @@ -296,7 +303,7 @@ def __init__(self, type, spec_version=DEFAULT_VERSION): self.spec_version = spec_version super(IDProperty, self).__init__() - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): _validate_id(value, self.spec_version, self.required_prefix) return value, False @@ -311,7 +318,10 @@ def __init__(self, min=None, max=None, **kwargs): self.max = max super(IntegerProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): + if strict and not isinstance(value, int): + raise ValueError("must be an integer.") + try: value = int(value) except Exception: @@ -335,7 +345,10 @@ def __init__(self, min=None, max=None, **kwargs): self.max = max super(FloatProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): + if strict and not isinstance(value, float): + raise ValueError("must be a float.") + try: value = float(value) except Exception: @@ -356,7 +369,10 @@ class BooleanProperty(Property): _trues = ['true', 't', '1', 1, True] _falses = ['false', 'f', '0', 0, False] - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): + + if strict and not isinstance(value, bool): + raise ValueError("must be a boolean value.") if isinstance(value, str): value = value.lower() @@ -379,7 +395,7 @@ def __init__(self, precision="any", precision_constraint="exact", **kwargs): super(TimestampProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): return parse_into_datetime( value, self.precision, self.precision_constraint, ), False @@ -387,36 +403,136 @@ def clean(self, value, allow_custom=False): class DictionaryProperty(Property): - def __init__(self, 
spec_version=DEFAULT_VERSION, **kwargs): + def __init__(self, valid_types=None, spec_version=DEFAULT_VERSION, **kwargs): self.spec_version = spec_version + self.valid_types = self._normalize_valid_types(valid_types or []) + super(DictionaryProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom=False): + def _normalize_valid_types(self, valid_types): + """ + Normalize valid_types to a list of property instances. Also ensure any + property types given are supported for type enforcement. + + :param valid_types: A single or iterable of Property instances or + subclasses + :return: A list of Property instances, or None if none were given + """ + simple_types = ( + BinaryProperty, BooleanProperty, FloatProperty, HexProperty, + IntegerProperty, StringProperty, TimestampProperty, + ReferenceProperty, EnumProperty, + ) + + # Normalize single prop instances/classes to lists + try: + iter(valid_types) + except TypeError: + valid_types = [valid_types] + + prop_instances = [] + for valid_type in valid_types: + if inspect.isclass(valid_type): + # Note: this will fail as of this writing with EnumProperty + # ReferenceProperty, ListProperty. Callers must instantiate + # those with suitable settings themselves. + prop_instance = valid_type() + + else: + prop_instance = valid_type + + # ListProperty's element type must be one of the supported + # simple types. + if isinstance(prop_instance, ListProperty): + if not isinstance(prop_instance.contained, simple_types): + raise ValueError( + "DictionaryProperty does not support lists of type: " + + type(prop_instance.contained).__name__, + ) + + elif not isinstance(prop_instance, simple_types): + raise ValueError( + "DictionaryProperty does not support value type: " + + type(prop_instance).__name__, + ) + + prop_instances.append(prop_instance) + + return prop_instances or None + + def _check_dict_key(self, k): + if self.spec_version == '2.0': + if len(k) < 3: + raise DictionaryKeyError(k, "shorter than 3 characters") + elif len(k) > 256: + raise DictionaryKeyError(k, "longer than 256 characters") + elif self.spec_version == '2.1': + if len(k) > 250: + raise DictionaryKeyError(k, "longer than 250 characters") + if not re.match(r"^[a-zA-Z0-9_-]+$", k): + msg = ( + "contains characters other than lowercase a-z, " + "uppercase A-Z, numerals 0-9, hyphen (-), or " + "underscore (_)" + ) + raise DictionaryKeyError(k, msg) + + def clean(self, value, allow_custom=False, strict=False): try: dictified = _get_dict(value) except ValueError: raise ValueError("The dictionary property must contain a dictionary") - for k in dictified.keys(): - if self.spec_version == '2.0': - if len(k) < 3: - raise DictionaryKeyError(k, "shorter than 3 characters") - elif len(k) > 256: - raise DictionaryKeyError(k, "longer than 256 characters") - elif self.spec_version == '2.1': - if len(k) > 250: - raise DictionaryKeyError(k, "longer than 250 characters") - if not re.match(r"^[a-zA-Z0-9_-]+$", k): - msg = ( - "contains characters other than lowercase a-z, " - "uppercase A-Z, numerals 0-9, hyphen (-), or " - "underscore (_)" - ) - raise DictionaryKeyError(k, msg) + + has_custom = False + for k, v in dictified.items(): + + self._check_dict_key(k) + + if self.valid_types: + for type_ in self.valid_types: + try: + # ReferenceProperty at least, does check for + # customizations, so we must propagate that + dictified[k], temp_custom = type_.clean( + value=v, + allow_custom=allow_custom, + # Ignore the passed-in value and fix this to True; + # we need strict cleaning to disambiguate 
value + # types here. + strict=True, + ) + except CustomContentError: + # Need to propagate these, not treat as a type error + raise + except Exception: + # clean failed; value must not conform to type_ + # Should be a narrower exception type here, but I don't + # know if it's safe to assume any particular exception + # types... + pass + else: + # clean succeeded; should check the has_custom flag + # just in case. But if allow_custom is False, I expect + # one of the valid_types property instances would have + # already raised an exception. + has_custom = has_custom or temp_custom + if has_custom and not allow_custom: + raise CustomContentError(f'Custom content detected in key "{k}"') + + break + + else: + # clean failed for all properties! + raise ValueError( + f"Invalid value: {v!r}", + ) + + # else: no valid types given, so we skip the validity check if len(dictified) < 1: raise ValueError("must not be empty.") - return dictified, False + return dictified, has_custom class HashesProperty(DictionaryProperty): @@ -434,7 +550,7 @@ def __init__(self, spec_hash_names, spec_version=DEFAULT_VERSION, **kwargs): if alg: self.__alg_to_spec_name[alg] = spec_hash_name - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): # ignore the has_custom return value here; there is no customization # of DictionaryProperties. clean_dict, _ = super().clean(value, allow_custom) @@ -482,7 +598,7 @@ def clean(self, value, allow_custom): class BinaryProperty(Property): - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): try: base64.b64decode(value) except (binascii.Error, TypeError): @@ -492,8 +608,10 @@ def clean(self, value, allow_custom=False): class HexProperty(Property): - def clean(self, value, allow_custom=False): - if not re.match(r"^([a-fA-F0-9]{2})+$", value): + def clean(self, value, allow_custom=False, strict=False): + if isinstance(value, (bytes, bytearray)): + value = value.hex() + elif not re.match(r"^([a-fA-F0-9]{2})+$", value): raise ValueError("must contain an even number of hexadecimal characters") return value, False @@ -541,7 +659,7 @@ def __init__(self, valid_types=None, invalid_types=None, spec_version=DEFAULT_VE super(ReferenceProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): if isinstance(value, _STIXBase): value = value.id value = str(value) @@ -576,13 +694,13 @@ def clean(self, value, allow_custom): if auth_type == self._WHITELIST: type_ok = is_stix_type( - obj_type, self.spec_version, *generics + obj_type, self.spec_version, *generics, ) or obj_type in specifics else: type_ok = ( not is_stix_type( - obj_type, self.spec_version, *generics + obj_type, self.spec_version, *generics, ) and obj_type not in specifics ) or obj_type in blacklist_exceptions @@ -619,7 +737,7 @@ def clean(self, value, allow_custom): class SelectorProperty(Property): - def clean(self, value, allow_custom=False): + def clean(self, value, allow_custom=False, strict=False): if not SELECTOR_REGEX.match(value): raise ValueError("must adhere to selector syntax.") return value, False @@ -640,7 +758,7 @@ def __init__(self, type, **kwargs): self.type = type super(EmbeddedObjectProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): if isinstance(value, dict): value = self.type(allow_custom=allow_custom, **value) elif not isinstance(value, self.type): @@ -668,8 
+786,9 @@ def __init__(self, allowed, **kwargs): self.allowed = allowed super(EnumProperty, self).__init__(**kwargs) - def clean(self, value, allow_custom): - cleaned_value, _ = super(EnumProperty, self).clean(value, allow_custom) + def clean(self, value, allow_custom=False, strict=False): + + cleaned_value, _ = super(EnumProperty, self).clean(value, allow_custom, strict) if cleaned_value not in self.allowed: raise ValueError("value '{}' is not valid for this enumeration.".format(cleaned_value)) @@ -689,25 +808,20 @@ def __init__(self, allowed, **kwargs): allowed = [allowed] self.allowed = allowed - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): cleaned_value, _ = super(OpenVocabProperty, self).clean( - value, allow_custom, + value, allow_custom, strict, ) - # Disabled: it was decided that enforcing this is too strict (might - # break too much user code). Revisit when we have the capability for - # more granular config settings when creating objects. - # - # has_custom = cleaned_value not in self.allowed - # - # if not allow_custom and has_custom: - # raise CustomContentError( - # "custom value in open vocab: '{}'".format(cleaned_value), - # ) + # Customization enforcement is disabled: it was decided that enforcing + # it is too strict (might break too much user code). Strict mode, on + # the other hand, does lock this down: when strict, we always throw an + # exception if a value isn't in the vocab list, and never report + # anything as "custom". + if strict and cleaned_value not in self.allowed: + raise ValueError("not in vocab: " + cleaned_value) - has_custom = False - - return cleaned_value, has_custom + return cleaned_value, False class PatternProperty(StringProperty): @@ -722,7 +836,7 @@ def __init__(self, spec_version=DEFAULT_VERSION, *args, **kwargs): self.spec_version = spec_version super(ObservableProperty, self).__init__(*args, **kwargs) - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): try: dictified = _get_dict(value) # get deep copy since we are going to modify the dict and might @@ -770,7 +884,7 @@ class ExtensionsProperty(DictionaryProperty): def __init__(self, spec_version=DEFAULT_VERSION, required=False): super(ExtensionsProperty, self).__init__(spec_version=spec_version, required=required) - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): try: dictified = _get_dict(value) # get deep copy since we are going to modify the dict and might @@ -836,7 +950,7 @@ def __init__(self, spec_version=DEFAULT_VERSION, *args, **kwargs): self.spec_version = spec_version super(STIXObjectProperty, self).__init__(*args, **kwargs) - def clean(self, value, allow_custom): + def clean(self, value, allow_custom=False, strict=False): # Any STIX Object (SDO, SRO, or Marking Definition) can be added to # a bundle with no further checks.
stix2_classes = {'_DomainObject', '_RelationshipObject', 'MarkingDefinition'} diff --git a/stix2/test/test_properties.py b/stix2/test/test_properties.py index 5116a685..4637e39f 100644 --- a/stix2/test/test_properties.py +++ b/stix2/test/test_properties.py @@ -181,6 +181,12 @@ def test_string_property(): assert prop.clean(1) assert prop.clean([1, 2, 3]) + with pytest.raises(ValueError): + prop.clean(1, strict=True) + + result = prop.clean("foo", strict=True) + assert result == ("foo", False) + def test_type_property(): prop = TypeProperty('my-type') @@ -244,6 +250,15 @@ def test_integer_property_invalid(value): int_prop.clean(value) + +def test_integer_property_strict(): + int_prop = IntegerProperty() + with pytest.raises(ValueError): + int_prop.clean("123", strict=True) + + result = int_prop.clean(123, strict=True) + assert result == (123, False) + + @pytest.mark.parametrize( "value", [ 2, @@ -253,8 +268,8 @@ def test_integer_property_invalid(value): ], ) def test_float_property_valid(value): - int_prop = FloatProperty() - assert int_prop.clean(value) is not None + float_prop = FloatProperty() + assert float_prop.clean(value) is not None @pytest.mark.parametrize( @@ -264,9 +279,18 @@ def test_float_property_invalid(value): ], ) def test_float_property_invalid(value): - int_prop = FloatProperty() + float_prop = FloatProperty() with pytest.raises(ValueError): - int_prop.clean(value) + float_prop.clean(value) + + +def test_float_property_strict(): + float_prop = FloatProperty() + with pytest.raises(ValueError): + float_prop.clean("1.323", strict=True) + + result = float_prop.clean(1.323, strict=True) + assert result == (1.323, False) @pytest.mark.parametrize( @@ -308,6 +332,15 @@ def test_boolean_property_invalid(value): bool_prop.clean(value) + +def test_boolean_property_strict(): + bool_prop = BooleanProperty() + with pytest.raises(ValueError): + bool_prop.clean("true", strict=True) + + result = bool_prop.clean(True, strict=True) + assert result == (True, False) + + @pytest.mark.parametrize( "value", [ '2017-01-01T12:34:56Z', @@ -368,6 +401,16 @@ def test_enum_property_invalid(): enum_prop.clean('z', True) + +def test_enum_property_strict(): + enum_prop = EnumProperty(['1', '2', '3']) + with pytest.raises(ValueError): + enum_prop.clean(1, strict=True) + + result = enum_prop.clean(1, strict=False) + assert result == ("1", False) + + @pytest.mark.xfail( reason="Temporarily disabled custom open vocab enforcement", strict=True, ) @@ -391,6 +434,27 @@ def test_openvocab_property(vocab): assert ov_prop.clean("d", True) == ("d", True) + +def test_openvocab_property_strict(): + ov_prop = OpenVocabProperty(["1", "2", "3"]) + with pytest.raises(ValueError): + ov_prop.clean(1, allow_custom=False, strict=True) + + with pytest.raises(ValueError): + ov_prop.clean(1, allow_custom=True, strict=True) + + result = ov_prop.clean("1", allow_custom=False, strict=True) + assert result == ("1", False) + + result = ov_prop.clean(1, allow_custom=True, strict=False) + assert result == ("1", False) + + result = ov_prop.clean(1, allow_custom=False, strict=False) + assert result == ("1", False) + + result = ov_prop.clean("foo", allow_custom=False, strict=False) + assert result == ("foo", False) + + @pytest.mark.parametrize( "value", [ {"sha256": "6db12788c37247f2316052e142f42f4b259d6561751e5f401a1ae2a6df9c674b"},
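Taken together, the properties.py changes and the tests above pin down the new contract for clean(value, allow_custom, strict): by default cleaning coerces loosely (the string "123" becomes the integer 123), while strict=True refuses coercion and raises ValueError for anything that is not already the right type. DictionaryProperty leans on strict cleaning internally to decide which of its valid_types a value conforms to. A minimal sketch of that behavior, assuming only the signatures introduced in this diff:

    import stix2.properties

    int_prop = stix2.properties.IntegerProperty()
    int_prop.clean("123")                 # lax: coerces, returns (123, False)
    # int_prop.clean("123", strict=True)  # strict: raises ValueError

    # DictionaryProperty cleans each value against its valid_types with
    # strict=True, so a value only matches a type it genuinely conforms to:
    dict_prop = stix2.properties.DictionaryProperty(
        valid_types=[
            stix2.properties.IntegerProperty,
            stix2.properties.StringProperty,
        ],
    )
    dict_prop.clean({"a": 1, "b": "two"})  # ({'a': 1, 'b': 'two'}, False)
    # dict_prop.clean({"a": [1]})          # matches no valid type: ValueError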
diff --git a/stix2/test/test_workbench.py b/stix2/test/test_workbench.py index 84f97a59..dc3f66f5 100644 --- a/stix2/test/test_workbench.py +++ b/stix2/test/test_workbench.py @@ -32,7 +32,7 @@ def test_workbench_environment(): # Create a STIX object ind = create( - Indicator, id=constants.INDICATOR_ID, **constants.INDICATOR_KWARGS + Indicator, id=constants.INDICATOR_ID, **constants.INDICATOR_KWARGS, ) save(ind) @@ -50,7 +50,7 @@ def test_workbench_environment(): def test_workbench_get_all_attack_patterns(): mal = AttackPattern( - id=constants.ATTACK_PATTERN_ID, **constants.ATTACK_PATTERN_KWARGS + id=constants.ATTACK_PATTERN_ID, **constants.ATTACK_PATTERN_KWARGS, ) save(mal) @@ -70,7 +70,7 @@ def test_workbench_get_all_campaigns(): def test_workbench_get_all_courses_of_action(): coa = CourseOfAction( - id=constants.COURSE_OF_ACTION_ID, **constants.COURSE_OF_ACTION_KWARGS + id=constants.COURSE_OF_ACTION_ID, **constants.COURSE_OF_ACTION_KWARGS, ) save(coa) @@ -114,7 +114,7 @@ def test_workbench_get_all_infrastructures(): def test_workbench_get_all_intrusion_sets(): ins = IntrusionSet( - id=constants.INTRUSION_SET_ID, **constants.INTRUSION_SET_KWARGS + id=constants.INTRUSION_SET_ID, **constants.INTRUSION_SET_KWARGS, ) save(ins) @@ -161,7 +161,7 @@ def test_workbench_get_all_notes(): def test_workbench_get_all_observed_data(): od = ObservedData( - id=constants.OBSERVED_DATA_ID, **constants.OBSERVED_DATA_KWARGS + id=constants.OBSERVED_DATA_ID, **constants.OBSERVED_DATA_KWARGS, ) save(od) @@ -190,7 +190,7 @@ def test_workbench_get_all_reports(): def test_workbench_get_all_threat_actors(): thr = ThreatActor( - id=constants.THREAT_ACTOR_ID, **constants.THREAT_ACTOR_KWARGS + id=constants.THREAT_ACTOR_ID, **constants.THREAT_ACTOR_KWARGS, ) save(thr) @@ -210,7 +210,7 @@ def test_workbench_get_all_tools(): def test_workbench_get_all_vulnerabilities(): vuln = Vulnerability( - id=constants.VULNERABILITY_ID, **constants.VULNERABILITY_KWARGS + id=constants.VULNERABILITY_ID, **constants.VULNERABILITY_KWARGS, ) save(vuln) diff --git a/stix2/test/v20/test_environment.py b/stix2/test/v20/test_environment.py index b6f3a0b3..8dd4eca7 100644 --- a/stix2/test/v20/test_environment.py +++ b/stix2/test/v20/test_environment.py @@ -459,7 +459,7 @@ def test_semantic_check_with_versioning(ds, ds2): }, ], object_marking_refs=[stix2.v20.TLP_WHITE], - ) + ), ) ds.add(ind) score = stix2.equivalence.object.reference_check(ind.id, INDICATOR_ID, ds, ds2, **weights) diff --git a/stix2/test/v20/test_granular_markings.py b/stix2/test/v20/test_granular_markings.py index ae2da3b0..e20d5685 100644 --- a/stix2/test/v20/test_granular_markings.py +++ b/stix2/test/v20/test_granular_markings.py @@ -15,7 +15,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -28,7 +28,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) @@ -47,7 +47,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -60,7 +60,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -73,7 +73,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": TLP_RED.id, }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), TLP_RED, ), @@ -91,7 +91,7 @@ def test_add_marking_mark_multiple_selector_one_refs(data): def test_add_marking_mark_multiple_selector_multiple_refs():
before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -104,7 +104,7 @@ def test_add_marking_mark_multiple_selector_multiple_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description", "name"]) @@ -120,7 +120,7 @@ def test_add_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -129,7 +129,7 @@ def test_add_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0]], ["name"]) @@ -145,7 +145,7 @@ def test_add_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -154,7 +154,7 @@ def test_add_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0]], ["description"]) @@ -391,7 +391,7 @@ def test_get_markings_positional_arguments_combinations(data): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[1]], ), @@ -407,7 +407,7 @@ def test_get_markings_positional_arguments_combinations(data): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[1]], ), @@ -426,7 +426,7 @@ def test_remove_marking_remove_multiple_selector_one_ref(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, MARKING_IDS[0], ["description", "modified"]) assert "granular_markings" not in before @@ -440,7 +440,7 @@ def test_remove_marking_mark_one_selector_from_multiple_ones(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = Malware( granular_markings=[ @@ -449,7 +449,7 @@ def test_remove_marking_mark_one_selector_from_multiple_ones(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, [MARKING_IDS[0]], ["modified"]) for m in before["granular_markings"]: @@ -468,7 +468,7 @@ def test_remove_marking_mark_one_selector_markings_from_multiple_ones(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = Malware( granular_markings=[ @@ -481,7 +481,7 @@ def test_remove_marking_mark_one_selector_markings_from_multiple_ones(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, [MARKING_IDS[0]], ["modified"]) for m in before["granular_markings"]: @@ -500,7 +500,7 @@ def test_remove_marking_mark_mutilple_selector_multiple_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description", "modified"]) assert "granular_markings" not in before @@ -514,7 +514,7 @@ def test_remove_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = Malware( granular_markings=[ @@ -527,7 +527,7 @@ def test_remove_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, 
[MARKING_IDS[0]], ["modified"]) for m in before["granular_markings"]: @@ -542,7 +542,7 @@ def test_remove_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, [MARKING_IDS[0]], ["description"]) assert "granular_markings" not in before @@ -572,7 +572,7 @@ def test_remove_marking_not_present(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) with pytest.raises(MarkingNotFoundError): markings.remove_markings(before, [MARKING_IDS[1]], ["description"]) @@ -594,7 +594,7 @@ def test_remove_marking_not_present(): "marking_ref": MARKING_IDS[3], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), dict( granular_markings=[ @@ -611,7 +611,7 @@ def test_remove_marking_not_present(): "marking_ref": MARKING_IDS[3], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), ] @@ -844,14 +844,14 @@ def test_create_sdo_with_invalid_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) assert str(excinfo.value) == "Selector foo in Malware is not valid!" def test_set_marking_mark_one_selector_multiple_refs(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -864,7 +864,7 @@ def test_set_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.set_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) for m in before["granular_markings"]: @@ -879,7 +879,7 @@ def test_set_marking_mark_multiple_selector_one_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -888,7 +888,7 @@ def test_set_marking_mark_multiple_selector_one_refs(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.set_markings(before, [MARKING_IDS[0]], ["description", "modified"]) for m in before["granular_markings"]: @@ -897,7 +897,7 @@ def test_set_marking_mark_multiple_selector_one_refs(): def test_set_marking_mark_multiple_selector_multiple_refs_from_none(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -910,7 +910,7 @@ def test_set_marking_mark_multiple_selector_multiple_refs_from_none(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.set_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description", "modified"]) for m in before["granular_markings"]: @@ -925,7 +925,7 @@ def test_set_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -938,7 +938,7 @@ def test_set_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[2], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.set_markings(before, [MARKING_IDS[1], MARKING_IDS[2]], ["description"]) @@ -962,7 +962,7 @@ def test_set_marking_bad_selector(marking): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -971,7 +971,7 @@ def test_set_marking_bad_selector(marking): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) with pytest.raises(InvalidSelectorError): @@ -988,7 +988,7 @@ def test_set_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( 
granular_markings=[ @@ -997,7 +997,7 @@ def test_set_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.set_markings(before, [MARKING_IDS[0]], ["description"]) for m in before["granular_markings"]: @@ -1020,7 +1020,7 @@ def test_set_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[2], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), dict( granular_markings=[ @@ -1037,7 +1037,7 @@ def test_set_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[2], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), ] @@ -1099,7 +1099,7 @@ def test_set_marking_on_id_property(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) assert "id" in malware["granular_markings"][0]["selectors"] diff --git a/stix2/test/v20/test_object_markings.py b/stix2/test/v20/test_object_markings.py index 6bd2269d..88b3f14b 100644 --- a/stix2/test/v20/test_object_markings.py +++ b/stix2/test/v20/test_object_markings.py @@ -26,7 +26,7 @@ Malware(**MALWARE_KWARGS), Malware( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -34,7 +34,7 @@ MALWARE_KWARGS, dict( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -42,7 +42,7 @@ Malware(**MALWARE_KWARGS), Malware( object_marking_refs=[TLP_AMBER.id], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), TLP_AMBER, ), @@ -60,12 +60,12 @@ def test_add_markings_one_marking(data): def test_add_markings_multiple_marking(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], None) @@ -76,7 +76,7 @@ def test_add_markings_multiple_marking(): def test_add_markings_combination(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1]], @@ -90,7 +90,7 @@ def test_add_markings_combination(): "marking_ref": MARKING_IDS[3], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, MARKING_IDS[0], None) @@ -114,7 +114,7 @@ def test_add_markings_combination(): ) def test_add_markings_bad_markings(data): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) with pytest.raises(exceptions.InvalidValueError): before = markings.add_markings(before, data, None) @@ -274,14 +274,14 @@ def test_get_markings_object_and_granular_combinations(data): ( Malware( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware(**MALWARE_KWARGS), ), ( dict( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MALWARE_KWARGS, ), @@ -306,33 +306,33 @@ def test_remove_markings_object_level(data): ( Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware( object_marking_refs=[MARKING_IDS[1]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[2]], ), ( dict( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), dict( object_marking_refs=[MARKING_IDS[1]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[2]], ), ( Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware( object_marking_refs=[MARKING_IDS[1]], 
- **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], TLP_AMBER], ), @@ -350,7 +350,7 @@ def test_remove_markings_multiple(data): def test_remove_markings_bad_markings(): before = Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) with pytest.raises(MarkingNotFoundError) as excinfo: markings.remove_markings(before, [MARKING_IDS[4]], None) @@ -362,14 +362,14 @@ def test_remove_markings_bad_markings(): ( Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware(**MALWARE_KWARGS), ), ( dict( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MALWARE_KWARGS, ), @@ -533,14 +533,14 @@ def test_is_marked_object_and_granular_combinations(): ( Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware(**MALWARE_KWARGS), ), ( dict( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MALWARE_KWARGS, ), @@ -557,11 +557,11 @@ def test_is_marked_no_markings(data): def test_set_marking(): before = Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( object_marking_refs=[MARKING_IDS[4], MARKING_IDS[5]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.set_markings(before, [MARKING_IDS[4], MARKING_IDS[5]], None) @@ -585,11 +585,11 @@ def test_set_marking(): def test_set_marking_bad_input(data): before = Malware( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) with pytest.raises(exceptions.InvalidValueError): before = markings.set_markings(before, data, None) diff --git a/stix2/test/v20/test_utils.py b/stix2/test/v20/test_utils.py index f61369b9..e3b15279 100644 --- a/stix2/test/v20/test_utils.py +++ b/stix2/test/v20/test_utils.py @@ -201,7 +201,7 @@ def test_deduplicate(stix_objs1): def test_find_property_index(object, tuple_to_find, expected_index): assert stix2.serialization.find_property_index( object, - *tuple_to_find + *tuple_to_find, ) == expected_index diff --git a/stix2/test/v20/test_versioning.py b/stix2/test/v20/test_versioning.py index d3973f03..ace1cc82 100644 --- a/stix2/test/v20/test_versioning.py +++ b/stix2/test/v20/test_versioning.py @@ -43,7 +43,7 @@ def test_making_new_version_with_embedded_object(): "source_name": "capec", "external_id": "CAPEC-163", }], - **CAMPAIGN_MORE_KWARGS + **CAMPAIGN_MORE_KWARGS, ) campaign_v2 = campaign_v1.new_version( diff --git a/stix2/test/v21/test_campaign.py b/stix2/test/v21/test_campaign.py index edc7d777..5fc8e4bb 100644 --- a/stix2/test/v21/test_campaign.py +++ b/stix2/test/v21/test_campaign.py @@ -21,7 +21,7 @@ def test_campaign_example(): campaign = stix2.v21.Campaign( - **CAMPAIGN_MORE_KWARGS + **CAMPAIGN_MORE_KWARGS, ) assert campaign.serialize(pretty=True) == EXPECTED diff --git a/stix2/test/v21/test_datastore_relational_db.py b/stix2/test/v21/test_datastore_relational_db.py new file mode 100644 index 00000000..316303db --- /dev/null +++ b/stix2/test/v21/test_datastore_relational_db.py @@ -0,0 +1,963 @@ +import contextlib +import datetime +import json +import os + +import pytest + +import stix2 +from stix2.datastore import DataSourceError +from 
stix2.datastore.relational_db.database_backends.postgres_backend import ( + PostgresBackend, +) +from stix2.datastore.relational_db.relational_db import RelationalDBStore +import stix2.properties +import stix2.registry +import stix2.v21 + +_DB_CONNECT_URL = f"postgresql://{os.getenv('POSTGRES_USER', 'postgres')}:{os.getenv('POSTGRES_PASSWORD', 'postgres')}@0.0.0.0:5432/postgres" + +store = RelationalDBStore( + PostgresBackend(_DB_CONNECT_URL, True), + True, + None, + False, +) + +# Artifacts +basic_artifact_dict = { + "type": "artifact", + "spec_version": "2.1", + "id": "artifact--cb37bcf8-9846-5ab4-8662-75c1bf6e63ee", + "mime_type": "image/jpeg", + "payload_bin": "VGhpcyBpcyBhIHBsYWNlaG9sZGVyIGZvciBhIHNhZmUgbWFsd2FyZSBiaW5hcnkh", +} + +encrypted_artifact_dict = { + "type": "artifact", + "spec_version": "2.1", + "id": "artifact--3857f78d-7d16-5092-99fe-ecff58408b02", + "mime_type": "application/zip", + "payload_bin": "VGhpcyBpcyBhIHBsYWNlaG9sZGVyIGZvciBhbiB1bnNhZmUgbWFsd2FyZSBiaW5hcnkh", + "hashes": { + "MD5": "6b885a1e1d42c0ca66e5f8a17e5a5d29", + "SHA-256": "3eea3c4819e9d387ff6809f13dde5426b9466285b7d923016b2842a13eb2888b", + }, + "encryption_algorithm": "mime-type-indicated", + "decryption_key": "My voice is my passport", +} + + +def test_basic_artifact(): + artifact_stix_object = stix2.parse(basic_artifact_dict) + store.add(artifact_stix_object) + read_obj = json.loads(store.get(artifact_stix_object['id']).serialize()) + + for attrib in basic_artifact_dict.keys(): + assert basic_artifact_dict[attrib] == read_obj[attrib] + + +def test_encrypted_artifact(): + artifact_stix_object = stix2.parse(encrypted_artifact_dict) + store.add(artifact_stix_object) + read_obj = json.loads(store.get(artifact_stix_object['id']).serialize()) + + for attrib in encrypted_artifact_dict.keys(): + assert encrypted_artifact_dict[attrib] == read_obj[attrib] + + +# Autonomous System +as_dict = { + "type": "autonomous-system", + "spec_version": "2.1", + "id": "autonomous-system--f822c34b-98ae-597f-ade5-27dc241e8c74", + "number": 15139, + "name": "Slime Industries", + "rir": "ARIN", +} + + +def test_autonomous_system(): + as_obj = stix2.parse(as_dict) + store.add(as_obj) + read_obj = json.loads(store.get(as_obj['id']).serialize()) + + for attrib in as_dict.keys(): + assert as_dict[attrib] == read_obj[attrib] + + +# Directory +directory_dict = { + "type": "directory", + "spec_version": "2.1", + "id": "directory--17c909b1-521d-545d-9094-1a08ddf46b05", + "ctime": "2018-11-23T08:17:27.000Z", + "mtime": "2018-11-23T08:17:27.000Z", + "path": "C:\\Windows\\System32", + "path_enc": "cGF0aF9lbmM", + "contains_refs": [ + "directory--94c0a9b0-520d-545d-9094-1a08ddf46b05", + "file--95c0a9b0-520d-545d-9094-1a08ddf46b05", + ], +} + + +def test_directory(): + directory_obj = stix2.parse(directory_dict) + store.add(directory_obj) + read_obj = json.loads(store.get(directory_obj['id']).serialize()) + + for attrib in directory_dict.keys(): + if attrib == "ctime" or attrib == "mtime": # convert both into stix2 date format for consistency + assert stix2.utils.parse_into_datetime(directory_dict[attrib]) == stix2.utils.parse_into_datetime(read_obj[attrib]) + continue + assert directory_dict[attrib] == read_obj[attrib] + + +# Domain Name +domain_name_dict = { + "type": "domain-name", + "spec_version": "2.1", + "id": "domain-name--3c10e93f-798e-5a26-a0c1-08156efab7f5", + "value": "example.com", +} + + +def test_domain_name(): + domain_name_obj = stix2.parse(domain_name_dict) + store.add(domain_name_obj) + read_obj = 
json.loads(store.get(domain_name_obj['id']).serialize()) + + for attrib in domain_name_dict.keys(): + assert domain_name_dict[attrib] == read_obj[attrib] + + +# Email Address +email_addr_dict = { + "type": "email-addr", + "spec_version": "2.1", + "id": "email-addr--2d77a846-6264-5d51-b586-e43822ea1ea3", + "value": "john@example.com", + "display_name": "John Doe", + "belongs_to_ref": "user-account--0d5b424b-93b8-5cd8-ac36-306e1789d63c", +} + + +def test_email_addr(): + email_addr_stix_object = stix2.parse(email_addr_dict) + store.add(email_addr_stix_object) + read_obj = json.loads(store.get(email_addr_stix_object['id']).serialize()) + + for attrib in email_addr_dict.keys(): + assert email_addr_dict[attrib] == read_obj[attrib] + + +# Email Message +email_msg_dict = { + "type": "email-message", + "spec_version": "2.1", + "id": "email-message--8c57a381-2a17-5e61-8754-5ef96efb286c", + "from_ref": "email-addr--9b7e29b3-fd8d-562e-b3f0-8fc8134f5dda", + "sender_ref": "email-addr--9b7e29b3-fd8d-562e-b3f0-8fc8134f5eeb", + "to_refs": ["email-addr--d1b3bf0c-f02a-51a1-8102-11aba7959868"], + "cc_refs": [ + "email-addr--d2b3bf0c-f02a-51a1-8102-11aba7959868", + "email-addr--d3b3bf0c-f02a-51a1-8102-11aba7959868", + ], + "bcc_refs": [ + "email-addr--d4b3bf0c-f02a-51a1-8102-11aba7959868", + "email-addr--d5b3bf0c-f02a-51a1-8102-11aba7959868", + ], + "message_id": "message01", + "is_multipart": False, + "date": "2004-04-19T12:22:23.000Z", + "subject": "Did you see this?", + "received_lines": [ + "from mail.example.com ([198.51.100.3]) by smtp.gmail.com with ESMTPSA id \ + q23sm23309939wme.17.2016.07.19.07.20.32 (version=TLS1_2 cipher=ECDHE-RSA-AES128-GCM-SHA256 \ + bits=128/128); Tue, 19 Jul 2016 07:20:40 -0700 (PDT)", + ], + "additional_header_fields": { + "Reply-To": [ + "steve@example.com", + "jane@example.com", + ], + }, + "body": "message body", + "raw_email_ref": "artifact--cb37bcf8-9846-5ab4-8662-75c1bf6e63ee", +} + +multipart_email_msg_dict = { + "type": "email-message", + "spec_version": "2.1", + "id": "email-message--ef9b4b7f-14c8-5955-8065-020e0316b559", + "is_multipart": True, + "received_lines": [ + "from mail.example.com ([198.51.100.3]) by smtp.gmail.com with ESMTPSA id \ + q23sm23309939wme.17.2016.07.19.07.20.32 (version=TLS1_2 cipher=ECDHE-RSA-AES128-GCM-SHA256 \ + bits=128/128); Tue, 19 Jul 2016 07:20:40 -0700 (PDT)", + ], + "content_type": "multipart/mixed", + "date": "2016-06-19T14:20:40.000Z", + "from_ref": "email-addr--89f52ea8-d6ef-51e9-8fce-6a29236436ed", + "to_refs": ["email-addr--d1b3bf0c-f02a-51a1-8102-11aba7959868"], + "cc_refs": ["email-addr--e4ee5301-b52d-59cd-a8fa-8036738c7194"], + "subject": "Check out this picture of a cat!", + "additional_header_fields": { + "Content-Disposition": ["inline"], + "X-Mailer": ["Mutt/1.5.23"], + "X-Originating-IP": ["198.51.100.3"], + }, + "body_multipart": [ + { + "content_type": "text/plain; charset=utf-8", + "content_disposition": "inline", + "body": "Cats are funny!", + }, + { + "content_type": "image/png", + "content_disposition": "attachment; filename=\"tabby.png\"", + "body_raw_ref": "artifact--4cce66f8-6eaa-53cb-85d5-3a85fca3a6c5", + }, + { + "content_type": "application/zip", + "content_disposition": "attachment; filename=\"tabby_pics.zip\"", + "body_raw_ref": "file--6ce09d9c-0ad3-5ebf-900c-e3cb288955b5", + }, + ], +} + + +def test_email_msg(): + email_msg_stix_object = stix2.parse(email_msg_dict) + store.add(email_msg_stix_object) + read_obj = json.loads(store.get(email_msg_stix_object['id']).serialize()) + + for attrib in 
email_msg_dict.keys(): + if attrib == "date": + assert stix2.utils.parse_into_datetime(email_msg_dict[attrib]) == stix2.utils.parse_into_datetime( + read_obj[attrib], + ) + continue + assert email_msg_dict[attrib] == read_obj[attrib] + + +def test_multipart_email_msg(): + multipart_email_msg_stix_object = stix2.parse(multipart_email_msg_dict) + store.add(multipart_email_msg_stix_object) + read_obj = json.loads(store.get(multipart_email_msg_stix_object['id']).serialize()) + + for attrib in multipart_email_msg_dict.keys(): + if attrib == "date": + assert stix2.utils.parse_into_datetime(multipart_email_msg_dict[attrib]) == stix2.utils.parse_into_datetime( + read_obj[attrib], + ) + continue + assert multipart_email_msg_dict[attrib] == read_obj[attrib] + + +# File +# errors when adding magic_number_hex to the store, so ignoring for now +file_dict = { + "type": "file", + "spec_version": "2.1", + "id": "file--66156fad-2a7d-5237-bbb4-ba1912887cfe", + "hashes": { + "SHA-256": "ceafbfd424be2ca4a5f0402cae090dda2fb0526cf521b60b60077c0f622b285a", + }, + "parent_directory_ref": "directory--93c0a9b0-520d-545d-9094-1a08ddf46b05", + "name": "qwerty.dll", + "size": 25536, + "name_enc": "windows-1252", + "magic_number_hex": "a1b2c3", + "mime_type": "application/msword", + "ctime": "2018-11-23T08:17:27.000Z", + "mtime": "2018-11-23T08:17:27.000Z", + "atime": "2018-11-23T08:17:27.000Z", + "contains_refs": [ + "file--77156fad-2a0d-5237-bba4-ba1912887cfe", + ], + "content_ref": "artifact--cb37bcf8-9846-5ab4-8662-75c1bf6e63ee", +} + + +def test_file(): + file_stix_object = stix2.parse(file_dict) + store.add(file_stix_object) + read_obj = json.loads(store.get(file_stix_object['id']).serialize()) + + for attrib in file_dict.keys(): + if attrib == "ctime" or attrib == "mtime" or attrib == "atime": + assert stix2.utils.parse_into_datetime(file_dict[attrib]) == stix2.utils.parse_into_datetime(read_obj[attrib]) + continue + assert file_dict[attrib] == read_obj[attrib] + + +# ipv4 ipv6 +ipv4_dict = { + "type": "ipv4-addr", + "spec_version": "2.1", + "id": "ipv4-addr--ff26c255-6336-5bc5-b98d-13d6226742dd", + "value": "198.51.100.3", +} + +ipv6_dict = { + "type": "ipv6-addr", + "spec_version": "2.1", + "id": "ipv6-addr--1e61d36c-a26c-53b7-a80f-2a00161c96b1", + "value": "2001:0db8:85a3:0000:0000:8a2e:0370:7334", +} + + +def test_ipv4(): + ipv4_stix_object = stix2.parse(ipv4_dict) + store.add(ipv4_stix_object) + read_obj = store.get(ipv4_stix_object['id']) + + for attrib in ipv4_dict.keys(): + assert ipv4_dict[attrib] == read_obj[attrib] + + +def test_ipv6(): + ipv6_stix_object = stix2.parse(ipv6_dict) + store.add(ipv6_stix_object) + read_obj = store.get(ipv6_stix_object['id']) + + for attrib in ipv6_dict.keys(): + assert ipv6_dict[attrib] == read_obj[attrib] + + +# Mutex +mutex_dict = { + "type": "mutex", + "spec_version": "2.1", + "id": "mutex--fba44954-d4e4-5d3b-814c-2b17dd8de300", + "name": "__CLEANSWEEP__", +} + + +def test_mutex(): + mutex_stix_object = stix2.parse(mutex_dict) + store.add(mutex_stix_object) + read_obj = store.get(mutex_stix_object['id']) + + for attrib in mutex_dict.keys(): + assert mutex_dict[attrib] == read_obj[attrib]
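A recurring idiom in these round-trip tests: timestamp-valued properties are compared through stix2.utils.parse_into_datetime() rather than as raw strings, since the store may hand the value back in a different textual form than the input. Isolated (the second timestamp string is an illustrative variant of the first, which comes from the file dict above):

    import stix2.utils

    # Both strings denote the same instant, so parsing yields equal values
    # even though a plain string comparison would fail.
    a = stix2.utils.parse_into_datetime("2018-11-23T08:17:27.000Z")
    b = stix2.utils.parse_into_datetime("2018-11-23T08:17:27Z")
    assert a == b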
"2018-11-23T08:17:27.000Z", + "end": "2018-11-23T08:18:27.000Z", + "is_active": False, + "src_port": 1000, + "dst_port": 1000, + "protocols": [ + "ipv4", + "tcp", + ], + "src_byte_count": 147600, + "dst_byte_count": 147600, + "src_packets": 100, + "dst_packets": 100, + "src_payload_ref": "artifact--3857f78d-7d16-5092-99fe-ecff58408b02", + "dst_payload_ref": "artifact--3857f78d-7d16-5092-99fe-ecff58408b03", + "encapsulates_refs": [ + "network-traffic--53e0bf48-2eee-5c03-8bde-ed7049d2c0a3", + "network-traffic--53e0bf48-2eee-5c03-8bde-ed7049d2c0a4", + ], + "encapsulated_by_ref": "network-traffic--53e0bf48-2eee-5c03-8bde-ed7049d2c0a5", +} + + +def test_network_traffic(): + network_traffic_stix_object = stix2.parse(network_traffic_dict) + store.add(network_traffic_stix_object) + read_obj = store.get(network_traffic_stix_object['id']) + + for attrib in network_traffic_dict.keys(): + if attrib == "start" or attrib == "end": + assert stix2.utils.parse_into_datetime(network_traffic_dict[attrib]) == stix2.utils.parse_into_datetime(read_obj[attrib]) + continue + assert network_traffic_dict[attrib] == read_obj[attrib] + + +# Process +process_dict = { + "type": "process", + "spec_version": "2.1", + "id": "process--f52a906a-0dfc-40bd-92f1-e7778ead38a9", + "is_hidden": False, + "pid": 1221, + "created_time": "2016-01-20T14:11:25.55Z", + "cwd": "/tmp/", + "environment_variables": { + "ENVTEST": "/path/to/bin", + }, + "command_line": "./gedit-bin --new-window", + "opened_connection_refs": [ + "network-traffic--53e0bf48-2eee-5c03-8bde-ed7049d2c0a3", + ], + "creator_user_ref": "user-account--cb37bcf8-9846-5ab4-8662-75c1bf6e63ee", + "image_ref": "file--e04f22d1-be2c-59de-add8-10f61d15fe20", + "parent_ref": "process--f52a906a-1dfc-40bd-92f1-e7778ead38a9", + "child_refs": [ + "process--ff2a906a-1dfc-40bd-92f1-e7778ead38a9", + "process--fe2a906a-1dfc-40bd-92f1-e7778ead38a9", + ], +} + + +def test_process(): + process_stix_object = stix2.parse(process_dict) + store.add(process_stix_object) + read_obj = json.loads(store.get(process_stix_object['id']).serialize()) + + for attrib in process_dict.keys(): + if attrib == "created_time": + assert stix2.utils.parse_into_datetime(process_dict[attrib]) == stix2.utils.parse_into_datetime(read_obj[attrib]) + continue + assert process_dict[attrib] == read_obj[attrib] + + +# Software +software_dict = { + "type": "software", + "spec_version": "2.1", + "id": "software--a1827f6d-ca53-5605-9e93-4316cd22a00a", + "name": "Word", + "cpe": "cpe:2.3:a:microsoft:word:2000:*:*:*:*:*:*:*", + "version": "2002", + "vendor": "Microsoft", +} + + +def test_software(): + software_stix_object = stix2.parse(software_dict) + store.add(software_stix_object) + read_obj = json.loads(store.get(software_stix_object['id']).serialize()) + + for attrib in software_dict.keys(): + assert software_dict[attrib] == read_obj[attrib] + + +# URL +url_dict = { + "type": "url", + "id": "url--a5477287-23ac-5971-a010-5c287877fa60", + "value": "https://example.com/research/index.html", +} + + +def test_url(): + url_stix_object = stix2.parse(url_dict) + store.add(url_stix_object) + read_obj = json.loads(store.get(url_stix_object['id']).serialize()) + + for attrib in url_dict.keys(): + assert url_dict[attrib] == read_obj[attrib] + + +# User Account +user_account_dict = { + "type": "user-account", + "spec_version": "2.1", + "id": "user-account--0d5b424b-93b8-5cd8-ac36-306e1789d63c", + "user_id": "1001", + "credential": "password", + "account_login": "jdoe", + "account_type": "unix", + "display_name": "John Doe", + 
"is_service_account": False, + "is_privileged": False, + "can_escalate_privs": True, + "is_disabled": False, + "account_created": "2016-01-20T12:31:12Z", + "account_expires": "2018-01-20T12:31:12Z", + "credential_last_changed": "2016-01-20T14:27:43Z", + "account_first_login": "2016-01-20T14:26:07Z", + "account_last_login": "2016-07-22T16:08:28Z", +} + + +def test_user_account(): + user_account_stix_object = stix2.parse(user_account_dict) + store.add(user_account_stix_object) + read_obj = json.loads(store.get(user_account_stix_object['id']).serialize()) + + for attrib in user_account_dict.keys(): + if attrib == "account_created" or attrib == "account_expires" \ + or attrib == "credential_last_changed" or attrib == "account_first_login" \ + or attrib == "account_last_login": + assert stix2.utils.parse_into_datetime(user_account_dict[attrib]) == stix2.utils.parse_into_datetime( + read_obj[attrib], + ) + continue + assert user_account_dict[attrib] == read_obj[attrib] + + +# Windows Registry +windows_registry_dict = { + "type": "windows-registry-key", + "spec_version": "2.1", + "id": "windows-registry-key--2ba37ae7-2745-5082-9dfd-9486dad41016", + "key": "hkey_local_machine\\system\\bar\\foo", + "values": [ + { + "name": "Foo", + "data": "qwerty", + "data_type": "REG_SZ", + }, + { + "name": "Bar", + "data": "42", + "data_type": "REG_DWORD", + }, + ], + "modified_time": "2018-01-20T12:31:12Z", + "creator_user_ref": "user-account--0d5b424b-93b8-5cd8-ac36-306e1789d63c", + "number_of_subkeys": 2, +} + + +def test_windows_registry(): + windows_registry_stix_object = stix2.parse(windows_registry_dict) + store.add(windows_registry_stix_object) + read_obj = json.loads(store.get(windows_registry_stix_object['id']).serialize()) + + for attrib in windows_registry_dict.keys(): + if attrib == "modified_time": + assert stix2.utils.parse_into_datetime(windows_registry_dict[attrib]) == stix2.utils.parse_into_datetime( + read_obj[attrib], + ) + continue + assert windows_registry_dict[attrib] == read_obj[attrib] + + +# x509 Certificate +basic_x509_certificate_dict = { + "type": "x509-certificate", + "spec_version": "2.1", + "id": "x509-certificate--463d7b2a-8516-5a50-a3d7-6f801465d5de", + "issuer": "C=ZA, ST=Western Cape, L=Cape Town, O=Thawte Consulting cc, OU=Certification \ + Services Division, CN=Thawte Server CA/emailAddress=server-certs@thawte.com", + "validity_not_before": "2016-03-12T12:00:00Z", + "validity_not_after": "2016-08-21T12:00:00Z", + "subject": "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, \ + CN=www.freesoft.org/emailAddress=baccala@freesoft.org", + "serial_number": "36:f7:d4:32:f4:ab:70:ea:d3:ce:98:6e:ea:99:93:49:32:0a:b7:06", +} + +extensions_x509_certificate_dict = { + "type": "x509-certificate", + "spec_version": "2.1", + "id": "x509-certificate--b595eaf0-0b28-5dad-9e8e-0fab9c1facc9", + "issuer": "C=ZA, ST=Western Cape, L=Cape Town, O=Thawte Consulting cc, OU=Certification \ + Services Division, CN=Thawte Server CA/emailAddress=server-certs@thawte.com", + "validity_not_before": "2016-03-12T12:00:00Z", + "validity_not_after": "2016-08-21T12:00:00Z", + "subject": "C=US, ST=Maryland, L=Pasadena, O=Brent Baccala, OU=FreeSoft, \ + CN=www.freesoft.org/emailAddress=baccala@freesoft.org", + "serial_number": "02:08:87:83:f2:13:58:1f:79:52:1e:66:90:0a:02:24:c9:6b:c7:dc", + "x509_v3_extensions": { + "basic_constraints": "critical,CA:TRUE, pathlen:0", + "name_constraints": "permitted;IP:192.168.0.0/255.255.0.0", + "policy_constraints": "requireExplicitPolicy:3", + "key_usage": 
"critical, keyCertSign", + "extended_key_usage": "critical,codeSigning,1.2.3.4", + "subject_key_identifier": "hash", + "authority_key_identifier": "keyid,issuer", + "subject_alternative_name": "email:my@other.address,RID:1.2.3.4", + "issuer_alternative_name": "issuer:copy", + "crl_distribution_points": "URI:http://myhost.com/myca.crl", + "inhibit_any_policy": "2", + "private_key_usage_period_not_before": "2016-03-12T12:00:00Z", + "private_key_usage_period_not_after": "2018-03-12T12:00:00Z", + "certificate_policies": "1.2.4.5, 1.1.3.4", + }, +} + + +def test_basic_x509_certificate(): + basic_x509_certificate_stix_object = stix2.parse(basic_x509_certificate_dict) + store.add(basic_x509_certificate_stix_object) + read_obj = json.loads(store.get(basic_x509_certificate_stix_object['id']).serialize()) + + for attrib in basic_x509_certificate_dict.keys(): + if attrib == "validity_not_before" or attrib == "validity_not_after": + assert stix2.utils.parse_into_datetime( + basic_x509_certificate_dict[attrib], + ) == stix2.utils.parse_into_datetime(read_obj[attrib]) + continue + assert basic_x509_certificate_dict[attrib] == read_obj[attrib] + + +def test_x509_certificate_with_extensions(): + extensions_x509_certificate_stix_object = stix2.parse(extensions_x509_certificate_dict) + store.add(extensions_x509_certificate_stix_object) + read_obj = json.loads(store.get(extensions_x509_certificate_stix_object['id']).serialize()) + + for attrib in extensions_x509_certificate_dict.keys(): + if attrib == "validity_not_before" or attrib == "validity_not_after": + assert stix2.utils.parse_into_datetime( + extensions_x509_certificate_dict[attrib], + ) == stix2.utils.parse_into_datetime(read_obj[attrib]) + continue + assert extensions_x509_certificate_dict[attrib] == read_obj[attrib] + + +def test_source_get_not_exists(): + obj = store.get("identity--00000000-0000-0000-0000-000000000000") + assert obj is None + + +def test_source_no_registration(): + with pytest.raises(DataSourceError): + # error, since no registered class can be found + store.get("doesnt-exist--a9e52398-3312-4377-90c2-86d49446c0d0") + + +def _unregister(reg_section, stix_type, ext_id=None): + """ + Unregister a class from the stix2 library's registry. + + :param reg_section: A registry section; depends on the kind of + class which was registered + :param stix_type: A STIX type + :param ext_id: An extension-definition ID, if applicable. A second + unregistration will occur in the extensions section of the registry if + given. + """ + # We ought to have a library function for this... + del stix2.registry.STIX2_OBJ_MAPS["2.1"][reg_section][stix_type] + if ext_id: + del stix2.registry.STIX2_OBJ_MAPS["2.1"]["extensions"][ext_id] + + +@contextlib.contextmanager +def _register_object(*args, **kwargs): + """ + A contextmanager which can register a class for an SDO/SRO and ensure it is + unregistered afterword. + + :param args: Positional args to a @CustomObject decorator + :param kwargs: Keyword args to a @CustomObject decorator + :return: The registered class + """ + @stix2.CustomObject(*args, **kwargs) + class TestClass: + pass + + try: + yield TestClass + except: + ext_id = kwargs.get("extension_name") + if not ext_id and len(args) >= 3: + ext_id = args[2] + + _unregister("objects", TestClass._type, ext_id) + + raise + + +@contextlib.contextmanager +def _register_observable(*args, **kwargs): + """ + A contextmanager which can register a class for an SCO and ensure it is + unregistered afterword. 
+ + +@contextlib.contextmanager +def _register_observable(*args, **kwargs): + """ + A contextmanager which can register a class for an SCO and ensure it is + unregistered afterward. + + :param args: Positional args to a @CustomObservable decorator + :param kwargs: Keyword args to a @CustomObservable decorator + :return: The registered class + """ + @stix2.CustomObservable(*args, **kwargs) + class TestClass: + pass + + try: + yield TestClass + finally: + ext_id = kwargs.get("extension_name") + if not ext_id and len(args) >= 4: + ext_id = args[3] + + _unregister("observables", TestClass._type, ext_id) + + +# "Base" properties used to derive property variations for testing (e.g. in a +# list, in a dictionary, in an embedded object, etc). Also includes sample +# values used to create test objects. The keys here are used to parameterize a +# fixture below. Parameterizing fixtures via simple strings makes for more +# understandable unit test output, although it can be kind of awkward in the +# implementation (can require long if-then chains checking the parameter +# strings). +_TEST_PROPERTIES = { + "binary": (stix2.properties.BinaryProperty(), "Af9J"), + "boolean": (stix2.properties.BooleanProperty(), True), + "float": (stix2.properties.FloatProperty(), 1.23), + "hex": (stix2.properties.HexProperty(), "a1b2c3"), + "integer": (stix2.properties.IntegerProperty(), 1), + "string": (stix2.properties.StringProperty(), "test"), + "timestamp": ( + stix2.properties.TimestampProperty(), + datetime.datetime.now(tz=datetime.timezone.utc), + ), + "ref": ( + stix2.properties.ReferenceProperty("SDO"), + "identity--ec83b570-0743-4179-a5e3-66fd2fae4711", + ), + "enum": ( + stix2.properties.EnumProperty(["value1", "value2"]), + "value1", + ), +} + + +@pytest.fixture(params=_TEST_PROPERTIES.keys()) +def base_property_value(request): + """Produce basic property instances and test values.""" + + base = _TEST_PROPERTIES.get(request.param) + if not base: + pytest.fail("Unrecognized base property: " + request.param) + + return base + + +@pytest.fixture( + params=[ + "base", + "list-of", + "dict-of", + # The following two test nesting lists inside dicts and vice versa + "dict-list-of", + "list-dict-of", + "subobject", + "list-of-subobject-prop", + "list-of-subobject-class", + ], +) +def property_variation_value(request, base_property_value): + """ + Produce property variations (and corresponding value variations) based on a + base property instance and value. E.g. in a list, in a sub-object, etc. + """ + base_property, prop_value = base_property_value + + class Embedded(stix2.v21._STIXBase21): + """ + Used for property variations where the property is embedded in a + sub-object. + """ + _properties = { + "embedded": base_property, + } + + if request.param == "base": + prop_variation = base_property + prop_variation_value = prop_value + + elif request.param == "list-of": + prop_variation = stix2.properties.ListProperty(base_property) + prop_variation_value = [prop_value] + + elif request.param == "dict-of": + prop_variation = stix2.properties.DictionaryProperty( + valid_types=base_property, + ) + # key name doesn't matter here + prop_variation_value = {"key": prop_value} + + elif request.param == "dict-list-of": + prop_variation = stix2.properties.DictionaryProperty( + valid_types=stix2.properties.ListProperty(base_property), + ) + # key name doesn't matter here + prop_variation_value = {"key": [prop_value]} + + elif request.param == "list-dict-of": + # These seem to all fail... perhaps there is no intent to support + # this?
+ pytest.xfail("ListProperty(DictionaryProperty) not supported?") + + # prop_variation = stix2.properties.ListProperty( + # stix2.properties.DictionaryProperty(valid_types=type(base_property)) + # ) + # key name doesn't matter here + # prop_variation_value = [{"key": prop_value}] + + elif request.param == "subobject": + prop_variation = stix2.properties.EmbeddedObjectProperty(Embedded) + prop_variation_value = {"embedded": prop_value} + + elif request.param == "list-of-subobject-prop": + # list-of-embedded values via EmbeddedObjectProperty + prop_variation = stix2.properties.ListProperty( + stix2.properties.EmbeddedObjectProperty(Embedded), + ) + prop_variation_value = [{"embedded": prop_value}] + + elif request.param == "list-of-subobject-class": + # Skip all of these since we know the data sink currently chokes on it + pytest.xfail("Data sink doesn't yet support ListProperty(<_STIXBase subclass>)") + + # list-of-embedded values using the embedded class directly + # prop_variation = stix2.properties.ListProperty(Embedded) + # prop_variation_value = [{"embedded": prop_value}] + + else: + pytest.fail("Unrecognized property variation: " + request.param) + + return prop_variation, prop_variation_value + + +@pytest.fixture(params=["sdo", "sco", "sro"]) +def object_variation(request, property_variation_value): + """ + Create and register a custom class variation (SDO, SCO, etc), then + instantiate it and produce the resulting object. + """ + + property_instance, property_value = property_variation_value + + # Fixed extension ID for everything + ext_id = "extension-definition--15de9cdb-3515-4271-8479-8141154c5647" + + if request.param == "sdo": + @stix2.CustomObject( + "test-object", [ + ("prop_name", property_instance), + ], + ext_id, + is_sdo=True, + ) + class TestClass: + pass + + elif request.param == "sro": + @stix2.CustomObject( + "test-object", [ + ("prop_name", property_instance), + ], + ext_id, + is_sdo=False, + ) + class TestClass: + pass + + elif request.param == "sco": + @stix2.CustomObservable( + "test-object", [ + ("prop_name", property_instance), + ], + ["prop_name"], + ext_id, + ) + class TestClass: + pass + + else: + pytest.fail("Unrecognized object variation: " + request.param) + + try: + instance = TestClass(prop_name=property_value) + yield instance + finally: + reg_section = "observables" if request.param == "sco" else "objects" + _unregister(reg_section, TestClass._type, ext_id) + + +def test_property(object_variation): + """ + Try to more exhaustively test many different property configurations: + ensure schemas can be created and values can be stored and retrieved. 
+ """ + rdb_store = RelationalDBStore( + PostgresBackend(_DB_CONNECT_URL, True), + True, + None, + True, + True, + type(object_variation), + ) + + rdb_store.add(object_variation) + read_obj = rdb_store.get(object_variation["id"]) + + assert read_obj == object_variation + + +def test_dictionary_property_complex(): + """ + Test a dictionary property with multiple valid_types + """ + with _register_object( + "test-object", [ + ( + "prop_name", + stix2.properties.DictionaryProperty( + valid_types=[ + stix2.properties.IntegerProperty, + stix2.properties.FloatProperty, + stix2.properties.StringProperty, + ], + ), + ), + ], + "extension-definition--15de9cdb-3515-4271-8479-8141154c5647", + is_sdo=True, + ) as cls: + + obj = cls( + prop_name={"a": 1, "b": 2.3, "c": "foo"}, + ) + + rdb_store = RelationalDBStore( + PostgresBackend(_DB_CONNECT_URL, True), + True, + None, + True, + True, + cls, + ) + + rdb_store.add(obj) + read_obj = rdb_store.get(obj["id"]) + assert read_obj == obj + + +def test_extension_definition(): + obj = stix2.ExtensionDefinition( + created_by_ref="identity--8a5fb7e4-aabe-4635-8972-cbcde1fa4792", + labels=["label1", "label2"], + name="test", + schema="a schema", + version="1.2.3", + extension_types=["property-extension", "new-sdo", "new-sro"], + object_marking_refs=[ + "marking-definition--caa0d913-5db8-4424-aae0-43e770287d30", + "marking-definition--122a27a0-b96f-46bc-8fcd-f7a159757e77", + ], + granular_markings=[ + { + "lang": "en_US", + "selectors": ["name", "schema"], + }, + { + "marking_ref": "marking-definition--50902d70-37ae-4f85-af68-3f4095493b42", + "selectors": ["name", "schema"], + }, + ], + ) + + store.add(obj) + read_obj = store.get(obj["id"]) + assert read_obj == obj diff --git a/stix2/test/v21/test_environment.py b/stix2/test/v21/test_environment.py index 51ca15a5..502f2b77 100644 --- a/stix2/test/v21/test_environment.py +++ b/stix2/test/v21/test_environment.py @@ -973,7 +973,7 @@ def test_semantic_check_with_versioning(ds, ds2): }, ], object_marking_refs=[stix2.v21.TLP_WHITE], - ) + ), ) ds.add(ind) score = stix2.equivalence.object.reference_check(ind.id, INDICATOR_ID, ds, ds2, **weights) @@ -1146,7 +1146,7 @@ def test_depth_limiting(): } prop_scores1 = {} env1 = stix2.equivalence.graph.graph_similarity( - mem_store1, mem_store2, prop_scores1, **custom_weights + mem_store1, mem_store2, prop_scores1, **custom_weights, ) assert round(env1) == 38 @@ -1159,7 +1159,7 @@ def test_depth_limiting(): # Switching parameters prop_scores2 = {} env2 = stix2.equivalence.graph.graph_similarity( - mem_store2, mem_store1, prop_scores2, **custom_weights + mem_store2, mem_store1, prop_scores2, **custom_weights, ) assert round(env2) == 38 diff --git a/stix2/test/v21/test_granular_markings.py b/stix2/test/v21/test_granular_markings.py index ff8fe26d..2f5bae8c 100644 --- a/stix2/test/v21/test_granular_markings.py +++ b/stix2/test/v21/test_granular_markings.py @@ -14,7 +14,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -27,7 +27,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) @@ -46,7 +46,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -59,7 +59,7 @@ def 
diff --git a/stix2/test/v21/test_environment.py b/stix2/test/v21/test_environment.py index 51ca15a5..502f2b77 100644 --- a/stix2/test/v21/test_environment.py +++ b/stix2/test/v21/test_environment.py @@ -973,7 +973,7 @@ def test_semantic_check_with_versioning(ds, ds2): }, ], object_marking_refs=[stix2.v21.TLP_WHITE], - ) + ), ) ds.add(ind) score = stix2.equivalence.object.reference_check(ind.id, INDICATOR_ID, ds, ds2, **weights) @@ -1146,7 +1146,7 @@ def test_depth_limiting(): } prop_scores1 = {} env1 = stix2.equivalence.graph.graph_similarity( - mem_store1, mem_store2, prop_scores1, **custom_weights + mem_store1, mem_store2, prop_scores1, **custom_weights, ) assert round(env1) == 38 @@ -1159,7 +1159,7 @@ # Switching parameters prop_scores2 = {} env2 = stix2.equivalence.graph.graph_similarity( - mem_store2, mem_store1, prop_scores2, **custom_weights + mem_store2, mem_store1, prop_scores2, **custom_weights, ) assert round(env2) == 38 diff --git a/stix2/test/v21/test_granular_markings.py b/stix2/test/v21/test_granular_markings.py index ff8fe26d..2f5bae8c 100644 --- a/stix2/test/v21/test_granular_markings.py +++ b/stix2/test/v21/test_granular_markings.py @@ -14,7 +14,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -27,7 +27,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) @@ -46,7 +46,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -59,7 +59,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -72,7 +72,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): "marking_ref": TLP_RED.id, }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), TLP_RED, ), @@ -90,7 +90,7 @@ def test_add_marking_mark_multiple_selector_one_refs(data): def test_add_marking_mark_multiple_selector_multiple_refs(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -103,7 +103,7 @@ def test_add_marking_mark_multiple_selector_multiple_refs(): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description", "name"]) @@ -113,7 +113,7 @@ def test_add_marking_mark_multiple_selector_multiple_refs(): def test_add_marking_mark_multiple_selector_multiple_refs_mixed(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -134,7 +134,7 @@ def test_add_marking_mark_multiple_selector_multiple_refs_mixed(): "lang": MARKING_LANGS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1], MARKING_LANGS[0], MARKING_LANGS[1]], ["description", "name"]) @@ -150,7 +150,7 @@ def test_add_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -159,7 +159,7 @@ def test_add_marking_mark_another_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0]], ["name"]) @@ -175,7 +175,7 @@ def test_add_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( granular_markings=[ @@ -184,7 +184,7 @@ def test_add_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0]], ["description"]) @@ -513,7 +513,7 @@ def test_get_markings_multiple_selectors_with_options(data): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[1]], ), @@ -529,7 +529,7 @@ def test_get_markings_multiple_selectors_with_options(data): "marking_ref": MARKING_IDS[1], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[1]], ), @@ -548,7 +548,7 @@ def test_remove_marking_remove_multiple_selector_one_ref(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, MARKING_IDS[0], ["description", "modified"]) assert "granular_markings" not in before @@ -562,7 +562,7 @@ def test_remove_marking_mark_one_selector_from_multiple_ones(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = Malware( granular_markings=[ @@ -571,7 +571,7 @@ def test_remove_marking_mark_one_selector_from_multiple_ones(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.remove_markings(before, [MARKING_IDS[0]], ["modified"]) for m in before["granular_markings"]: @@ -590,7 +590,7 @@ def test_remove_marking_mark_one_selector_from_multiple_ones_mixed(): "lang": MARKING_LANGS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = Malware( granular_markings=[ @@ -603,7 +603,7 @@
                 "lang": MARKING_LANGS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_LANGS[0]], ["modified"])
     for m in before["granular_markings"]:
@@ -622,7 +622,7 @@ def test_remove_marking_mark_one_selector_markings_from_multiple_ones():
                 "marking_ref": MARKING_IDS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = Malware(
         granular_markings=[
@@ -635,7 +635,7 @@ def test_remove_marking_mark_one_selector_markings_from_multiple_ones():
                 "marking_ref": MARKING_IDS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.remove_markings(before, [MARKING_IDS[0]], ["modified"])
     for m in before["granular_markings"]:
@@ -654,7 +654,7 @@ def test_remove_marking_mark_mutilple_selector_multiple_refs():
                 "marking_ref": MARKING_IDS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description", "modified"])
     assert "granular_markings" not in before
@@ -668,7 +668,7 @@ def test_remove_marking_mark_another_property_same_marking():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = Malware(
         granular_markings=[
@@ -681,7 +681,7 @@ def test_remove_marking_mark_another_property_same_marking():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.remove_markings(before, [MARKING_IDS[0]], ["modified"])
     for m in before["granular_markings"]:
@@ -696,7 +696,7 @@ def test_remove_marking_mark_same_property_same_marking():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.remove_markings(before, [MARKING_IDS[0]], ["description"])
     assert "granular_markings" not in before
@@ -726,7 +726,7 @@ def test_remove_marking_not_present():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     with pytest.raises(MarkingNotFoundError):
         markings.remove_markings(before, [MARKING_IDS[1]], ["description"])
@@ -752,7 +752,7 @@ def test_remove_marking_not_present():
                 "lang": MARKING_LANGS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     ),
     dict(
         granular_markings=[
@@ -773,7 +773,7 @@ def test_remove_marking_not_present():
                 "lang": MARKING_LANGS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     ),
 ]
 
@@ -1008,14 +1008,14 @@ def test_create_sdo_with_invalid_marking():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     assert str(excinfo.value) == "Selector foo in Malware is not valid!"
 
 
 def test_set_marking_mark_one_selector_multiple_refs():
     before = Malware(
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
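These hunks only add trailing commas, but the calls they touch document the granular-marking contract: add_markings and remove_markings take a marking ref (or list of refs) plus a list of selectors naming the properties being marked, and a selector that does not name a real property raises InvalidSelectorError. A compact illustration of that contract (the marking id is a placeholder value; the calls are the same ones exercised above):

import stix2
from stix2 import markings
from stix2.exceptions import InvalidSelectorError

MARKING_ID = "marking-definition--613f2e26-407d-48c7-9eca-b8e91df99dc9"

malware = stix2.v21.Malware(name="Cryptolocker", is_family=False)

# Mark only the "name" property with one marking ref.
marked = markings.add_markings(malware, [MARKING_ID], ["name"])
assert marked.granular_markings[0].selectors == ["name"]

# A selector must name an actual property on the object.
try:
    markings.add_markings(malware, [MARKING_ID], ["foo"])
except InvalidSelectorError:
    pass  # "Selector foo in Malware is not valid!"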
@@ -1028,7 +1028,7 @@ def test_set_marking_mark_one_selector_multiple_refs():
                 "marking_ref": MARKING_IDS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.set_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"])
     for m in before["granular_markings"]:
@@ -1037,7 +1037,7 @@
 
 def test_set_marking_mark_one_selector_multiple_lang_refs():
     before = Malware(
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
@@ -1050,7 +1050,7 @@ def test_set_marking_mark_one_selector_multiple_lang_refs():
                 "lang": MARKING_LANGS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.set_markings(before, [MARKING_LANGS[0], MARKING_LANGS[1]], ["description"])
     for m in before["granular_markings"]:
@@ -1065,7 +1065,7 @@ def test_set_marking_mark_multiple_selector_one_refs():
                 "marking_ref": MARKING_IDS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
@@ -1074,7 +1074,7 @@ def test_set_marking_mark_multiple_selector_one_refs():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.set_markings(before, [MARKING_IDS[0]], ["description", "modified"])
     for m in before["granular_markings"]:
@@ -1093,7 +1093,7 @@ def test_set_marking_mark_multiple_mixed_markings():
                 "lang": MARKING_LANGS[2],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
@@ -1106,7 +1106,7 @@ def test_set_marking_mark_multiple_mixed_markings():
                 "lang": MARKING_LANGS[3],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.set_markings(before, [MARKING_IDS[2], MARKING_LANGS[3]], ["description", "modified"])
     for m in before["granular_markings"]:
@@ -1115,7 +1115,7 @@
 
 def test_set_marking_mark_multiple_selector_multiple_refs_from_none():
     before = Malware(
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
@@ -1128,7 +1128,7 @@ def test_set_marking_mark_multiple_selector_multiple_refs_from_none():
                 "marking_ref": MARKING_IDS[1],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     before = markings.set_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description", "modified"])
     for m in before["granular_markings"]:
@@ -1143,7 +1143,7 @@ def test_set_marking_mark_another_property_same_marking():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
@@ -1156,7 +1156,7 @@ def test_set_marking_mark_another_property_same_marking():
                 "marking_ref": MARKING_IDS[2],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
 
     before = markings.set_markings(before, [MARKING_IDS[1], MARKING_IDS[2]], ["description"])
@@ -1180,7 +1180,7 @@ def test_set_marking_bad_selector(marking):
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
@@ -1189,7 +1189,7 @@ def test_set_marking_bad_selector(marking):
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
 
     with pytest.raises(InvalidSelectorError):
@@ -1206,7 +1206,7 @@ def test_set_marking_mark_same_property_same_marking():
                 "marking_ref": MARKING_IDS[0],
             },
         ],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         granular_markings=[
@@ -1215,7 +1215,7 @@
"marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.set_markings(before, [MARKING_IDS[0]], ["description"]) for m in before["granular_markings"]: @@ -1238,7 +1238,7 @@ def test_set_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[2], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), dict( granular_markings=[ @@ -1255,7 +1255,7 @@ def test_set_marking_mark_same_property_same_marking(): "marking_ref": MARKING_IDS[2], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), ] @@ -1317,7 +1317,7 @@ def test_set_marking_on_id_property(): "marking_ref": MARKING_IDS[0], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) assert "id" in malware["granular_markings"][0]["selectors"] diff --git a/stix2/test/v21/test_object_markings.py b/stix2/test/v21/test_object_markings.py index bb1c4ab0..63273fb7 100644 --- a/stix2/test/v21/test_object_markings.py +++ b/stix2/test/v21/test_object_markings.py @@ -25,7 +25,7 @@ Malware(**MALWARE_KWARGS), Malware( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -33,7 +33,7 @@ MALWARE_KWARGS, dict( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MARKING_IDS[0], ), @@ -41,7 +41,7 @@ Malware(**MALWARE_KWARGS), Malware( object_marking_refs=[TLP_AMBER.id], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), TLP_AMBER, ), @@ -59,12 +59,12 @@ def test_add_markings_one_marking(data): def test_add_markings_multiple_marking(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], None) @@ -75,7 +75,7 @@ def test_add_markings_multiple_marking(): def test_add_markings_combination(): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) after = Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1]], @@ -89,7 +89,7 @@ def test_add_markings_combination(): "marking_ref": MARKING_IDS[3], }, ], - **MALWARE_KWARGS + **MALWARE_KWARGS, ) before = markings.add_markings(before, MARKING_IDS[0], None) @@ -113,7 +113,7 @@ def test_add_markings_combination(): ) def test_add_markings_bad_markings(data): before = Malware( - **MALWARE_KWARGS + **MALWARE_KWARGS, ) with pytest.raises(exceptions.InvalidValueError): before = markings.add_markings(before, data, None) @@ -273,14 +273,14 @@ def test_get_markings_object_and_granular_combinations(data): ( Malware( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware(**MALWARE_KWARGS), ), ( dict( object_marking_refs=[MARKING_IDS[0]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), MALWARE_KWARGS, ), @@ -305,33 +305,33 @@ def test_remove_markings_object_level(data): ( Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware( object_marking_refs=[MARKING_IDS[1]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[2]], ), ( dict( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), dict( object_marking_refs=[MARKING_IDS[1]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], MARKING_IDS[2]], ), ( Malware( object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), Malware( object_marking_refs=[MARKING_IDS[1]], - **MALWARE_KWARGS + **MALWARE_KWARGS, ), [MARKING_IDS[0], TLP_AMBER], ), @@ -349,7 +349,7 @@ 
@@ -349,7 +349,7 @@ def test_remove_markings_multiple(data):
 def test_remove_markings_bad_markings():
     before = Malware(
         object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     with pytest.raises(MarkingNotFoundError) as excinfo:
         markings.remove_markings(before, [MARKING_IDS[4]], None)
@@ -361,14 +361,14 @@ def test_remove_markings_bad_markings():
         (
             Malware(
                 object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
-                **MALWARE_KWARGS
+                **MALWARE_KWARGS,
             ),
             Malware(**MALWARE_KWARGS),
         ),
         (
             dict(
                 object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
-                **MALWARE_KWARGS
+                **MALWARE_KWARGS,
             ),
             MALWARE_KWARGS,
         ),
@@ -532,14 +532,14 @@ def test_is_marked_object_and_granular_combinations():
         (
             Malware(
                 object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
-                **MALWARE_KWARGS
+                **MALWARE_KWARGS,
             ),
             Malware(**MALWARE_KWARGS),
         ),
         (
             dict(
                 object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
-                **MALWARE_KWARGS
+                **MALWARE_KWARGS,
             ),
             MALWARE_KWARGS,
         ),
@@ -556,11 +556,11 @@ def test_is_marked_no_markings(data):
 def test_set_marking():
     before = Malware(
         object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         object_marking_refs=[MARKING_IDS[4], MARKING_IDS[5]],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
 
     before = markings.set_markings(before, [MARKING_IDS[4], MARKING_IDS[5]], None)
@@ -584,11 +584,11 @@ def test_set_marking():
 def test_set_marking_bad_input(data):
     before = Malware(
         object_marking_refs=[MARKING_IDS[0]],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     after = Malware(
         object_marking_refs=[MARKING_IDS[0]],
-        **MALWARE_KWARGS
+        **MALWARE_KWARGS,
     )
     with pytest.raises(exceptions.InvalidValueError):
         before = markings.set_markings(before, data, None)
diff --git a/stix2/test/v21/test_observed_data.py b/stix2/test/v21/test_observed_data.py
index d2ccec49..91baa69e 100644
--- a/stix2/test/v21/test_observed_data.py
+++ b/stix2/test/v21/test_observed_data.py
@@ -1176,14 +1176,14 @@ def test_incorrect_socket_options():
         )
     assert "Incorrect options key" == str(excinfo.value)
 
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(Exception) as excinfo:
         stix2.v21.SocketExt(
             is_listening=True,
             address_family="AF_INET",
             socket_type="SOCK_STREAM",
             options={"SO_RCVTIMEO": '100'},
         )
-    assert "Options value must be an integer" == str(excinfo.value)
+    assert "Dictionary Property does not support this value's type" in str(excinfo.value)
 
 
 def test_network_traffic_tcp_example():
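The updated assertion reflects the new validation path: SocketExt's options is now a DictionaryProperty(valid_types=[IntegerProperty]) (see the stix2/v21/observables.py hunks later in this diff), so the type error comes from the generic dictionary validation rather than from SocketExt-specific code. In practice, a sketch of the resulting behavior (the test above uses pytest.raises(Exception), so the surfaced exception type may be wrapped):

import stix2

# Integer values clean fine under valid_types=[IntegerProperty].
sock = stix2.v21.SocketExt(
    is_listening=True,
    address_family="AF_INET",
    socket_type="SOCK_STREAM",
    options={"SO_RCVTIMEO": 100},
)
assert sock.options["SO_RCVTIMEO"] == 100

# A string value is now rejected by the dictionary-level validation.
try:
    stix2.v21.SocketExt(
        is_listening=True,
        address_family="AF_INET",
        socket_type="SOCK_STREAM",
        options={"SO_RCVTIMEO": "100"},
    )
except Exception as e:
    assert "Dictionary Property does not support this value's type" in str(e)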
diff --git a/stix2/test/v21/test_properties.py b/stix2/test/v21/test_properties.py
index 4fb84ec0..7b9a41e1 100644
--- a/stix2/test/v21/test_properties.py
+++ b/stix2/test/v21/test_properties.py
@@ -6,23 +6,15 @@
     ExtraPropertiesError, ParseError,
 )
 from stix2.properties import (
-    DictionaryProperty, EmbeddedObjectProperty, ExtensionsProperty,
-    HashesProperty, IDProperty, ListProperty, ObservableProperty,
-    ReferenceProperty, STIXObjectProperty, StringProperty,
+    DictionaryProperty, EmbeddedObjectProperty, EnumProperty,
+    ExtensionsProperty, HashesProperty, IDProperty, IntegerProperty,
+    ListProperty, ObservableProperty, ReferenceProperty, STIXObjectProperty,
+    StringProperty,
 )
 from stix2.v21.common import MarkingProperty
 
 from . import constants
 
-
-def test_dictionary_property():
-    p = DictionaryProperty(StringProperty)
-
-    assert p.clean({'spec_version': '2.1'})
-    with pytest.raises(ValueError):
-        p.clean({})
-
-
 ID_PROP = IDProperty('my-type', spec_version="2.1")
 MY_ID = 'my-type--232c9d3f-49fc-4440-bb01-607f638778e7'
@@ -388,6 +380,116 @@ class NewObj():
     assert test_obj.property1[0]['foo'] == 'bar'
 
 
+def test_dictionary_property():
+    p = DictionaryProperty()
+
+    result = p.clean({'spec_version': '2.1'})
+    assert result == ({'spec_version': '2.1'}, False)
+
+    with pytest.raises(ValueError):
+        p.clean({}, False)
+
+
+def test_dictionary_property_values_str():
+    p = DictionaryProperty(valid_types=[StringProperty], spec_version='2.1')
+    result = p.clean({'x': '123'}, False)
+    assert result == ({'x': '123'}, False)
+
+    q = DictionaryProperty(valid_types=[StringProperty], spec_version='2.1')
+    with pytest.raises(ValueError):
+        assert q.clean({'x': [123]}, False)
+
+
+def test_dictionary_property_values_str_single():
+    # singles should be treated as length-one lists
+    p = DictionaryProperty(valid_types=StringProperty, spec_version='2.1')
+    result = p.clean({'x': '123'}, False)
+    assert result == ({'x': '123'}, False)
+
+    with pytest.raises(ValueError):
+        assert p.clean({'x': [123]}, False)
+
+
+def test_dictionary_property_values_int():
+    p = DictionaryProperty(valid_types=[IntegerProperty], spec_version='2.1')
+    result = p.clean({'x': 123}, False)
+    assert result == ({'x': 123}, False)
+
+    q = DictionaryProperty(valid_types=[IntegerProperty], spec_version='2.1')
+    with pytest.raises(ValueError):
+        assert q.clean({'x': [123]}, False)
+
+
+def test_dictionary_property_values_stringlist():
+    p = DictionaryProperty(valid_types=[ListProperty(StringProperty)], spec_version='2.1')
+    result = p.clean({'x': ['abc', 'def']}, False)
+    assert result == ({'x': ['abc', 'def']}, False)
+
+    q = DictionaryProperty(valid_types=[ListProperty(StringProperty)], spec_version='2.1')
+    with pytest.raises(ValueError):
+        assert q.clean({'x': [123]})
+
+    r = DictionaryProperty(valid_types=[StringProperty, IntegerProperty], spec_version='2.1')
+    with pytest.raises(ValueError):
+        assert r.clean({'x': [123, 456]})
+
+
+def test_dictionary_property_values_list():
+    p = DictionaryProperty(valid_types=[StringProperty, IntegerProperty], spec_version='2.1')
+    result = p.clean({'x': 123}, False)
+    assert result == ({'x': 123}, False)
+
+    q = DictionaryProperty(valid_types=[StringProperty, IntegerProperty], spec_version='2.1')
+    result = q.clean({'x': '123'}, False)
+    assert result == ({'x': '123'}, False)
+
+    r = DictionaryProperty(valid_types=[StringProperty, IntegerProperty], spec_version='2.1')
+    with pytest.raises(ValueError):
+        assert r.clean({'x': ['abc', 'def']}, False)
+
+
+def test_dictionary_property_ref_custom():
+    p = DictionaryProperty(
+        valid_types=ReferenceProperty(valid_types="SDO"), spec_version="2.1",
+    )
+
+    result = p.clean({"key": "identity--a2ac7670-f88f-424a-b3be-28f612f943f9"}, allow_custom=False)
+    assert result == ({"key": "identity--a2ac7670-f88f-424a-b3be-28f612f943f9"}, False)
+
+    with pytest.raises(ValueError):
+        p.clean({"key": "software--a2ac7670-f88f-424a-b3be-28f612f943f9"}, allow_custom=False)
+
+    with pytest.raises(ValueError):
+        p.clean({"key": "software--a2ac7670-f88f-424a-b3be-28f612f943f9"}, allow_custom=True)
+
+    pfoo = DictionaryProperty(
+        valid_types=ReferenceProperty(valid_types=["SDO", "foo"]), spec_version="2.1",
+    )
+
+    with pytest.raises(CustomContentError):
+        pfoo.clean({"key": "foo--a2ac7670-f88f-424a-b3be-28f612f943f9"}, allow_custom=False)
+
+    result = pfoo.clean({"key": "foo--a2ac7670-f88f-424a-b3be-28f612f943f9"}, allow_custom=True)
+    assert result == ({"key": "foo--a2ac7670-f88f-424a-b3be-28f612f943f9"}, True)
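Taken together, these new tests pin down the revised DictionaryProperty.clean contract: it returns a (cleaned_value, has_custom) pair, validates each value against the declared valid_types, and only lets custom content through ReferenceProperty-valued entries when allow_custom=True. A condensed sketch of that contract, using only behavior asserted above:

from stix2.properties import (
    DictionaryProperty, IntegerProperty, StringProperty,
)

prop = DictionaryProperty(
    valid_types=[StringProperty, IntegerProperty], spec_version="2.1",
)

# clean() returns the cleaned dict plus a has_custom flag.
value, has_custom = prop.clean({"x": "abc", "y": 123}, False)
assert value == {"x": "abc", "y": 123}
assert has_custom is False

# Values outside valid_types are rejected with ValueError.
try:
    prop.clean({"x": ["abc", "def"]}, False)
except ValueError:
    pass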
"foo--a2ac7670-f88f-424a-b3be-28f612f943f9"}, allow_custom=False) + + result = pfoo.clean({"key": "foo--a2ac7670-f88f-424a-b3be-28f612f943f9"}, allow_custom=True) + assert result == ({"key": "foo--a2ac7670-f88f-424a-b3be-28f612f943f9"}, True) + + +def test_dictionary_property_values_strict_clean(): + prop = DictionaryProperty( + valid_types=[EnumProperty(["value1", "value2"]), IntegerProperty], + ) + + result = prop.clean({"key": "value1"}, allow_custom=False) + assert result == ({"key": "value1"}, False) + + result = prop.clean({"key": 123}, allow_custom=False) + assert result == ({"key": 123}, False) + + with pytest.raises(ValueError): + # IntegerProperty normally cleans "123" to 123, but can't when used + # in a DictionaryProperty. + prop.clean({"key": "123"}, allow_custom=False) + + @pytest.mark.parametrize( "key", [ "a", diff --git a/stix2/test/v21/test_utils.py b/stix2/test/v21/test_utils.py index 33e7ea49..6dad23e8 100644 --- a/stix2/test/v21/test_utils.py +++ b/stix2/test/v21/test_utils.py @@ -205,7 +205,7 @@ def test_deduplicate(stix_objs1): def test_find_property_index(object, tuple_to_find, expected_index): assert stix2.serialization.find_property_index( object, - *tuple_to_find + *tuple_to_find, ) == expected_index diff --git a/stix2/test/v21/test_versioning.py b/stix2/test/v21/test_versioning.py index c7b6f119..98e01383 100644 --- a/stix2/test/v21/test_versioning.py +++ b/stix2/test/v21/test_versioning.py @@ -48,7 +48,7 @@ def test_making_new_version_with_embedded_object(): "source_name": "capec", "external_id": "CAPEC-163", }], - **CAMPAIGN_MORE_KWARGS + **CAMPAIGN_MORE_KWARGS, ) campaign_v2 = campaign_v1.new_version( diff --git a/stix2/v20/base.py b/stix2/v20/base.py index b5437ca6..45698e3e 100644 --- a/stix2/v20/base.py +++ b/stix2/v20/base.py @@ -1,7 +1,8 @@ """Base classes for STIX 2.0 type definitions.""" from ..base import ( - _DomainObject, _Extension, _Observable, _RelationshipObject, _STIXBase, + _DomainObject, _Extension, _MetaObject, _Observable, _RelationshipObject, + _STIXBase, ) @@ -23,3 +24,7 @@ class _DomainObject(_DomainObject, _STIXBase20): class _RelationshipObject(_RelationshipObject, _STIXBase20): pass + + +class _MetaObject(_MetaObject, _STIXBase20): + pass diff --git a/stix2/v20/common.py b/stix2/v20/common.py index feaa3efc..b5a9ca5a 100644 --- a/stix2/v20/common.py +++ b/stix2/v20/common.py @@ -11,7 +11,7 @@ SelectorProperty, StringProperty, TimestampProperty, TypeProperty, ) from ..utils import NOW, _get_dict -from .base import _STIXBase20 +from .base import _MetaObject, _STIXBase20 from .vocab import HASHING_ALGORITHM @@ -111,7 +111,7 @@ def clean(self, value, allow_custom=False): raise ValueError("must be a Statement, TLP Marking or a registered marking.") -class MarkingDefinition(_STIXBase20, _MarkingsMixin): +class MarkingDefinition(_MetaObject, _MarkingsMixin): """For more detailed information on this object's properties, see `the STIX 2.0 specification `__. """ diff --git a/stix2/v20/observables.py b/stix2/v20/observables.py index 2b6c81ca..d0bb51c5 100644 --- a/stix2/v20/observables.py +++ b/stix2/v20/observables.py @@ -818,5 +818,5 @@ def CustomExtension(type='x-custom-observable-ext', properties=None): """Decorator for custom extensions to STIX Cyber Observables. 
""" def wrapper(cls): - return _custom_extension_builder(cls, type, properties, '2.0', _Extension) + return _custom_extension_builder(cls, "sco", type, properties, '2.0', _Extension) return wrapper diff --git a/stix2/v20/sro.py b/stix2/v20/sro.py index 1372a5e5..13d5e136 100644 --- a/stix2/v20/sro.py +++ b/stix2/v20/sro.py @@ -39,7 +39,7 @@ class Relationship(_RelationshipObject): # Explicitly define the first three kwargs to make readable Relationship declarations. def __init__( self, source_ref=None, relationship_type=None, - target_ref=None, **kwargs + target_ref=None, **kwargs, ): # Allow (source_ref, relationship_type, target_ref) as positional args. if source_ref and not kwargs.get('source_ref'): diff --git a/stix2/v21/base.py b/stix2/v21/base.py index 3878b791..fb3a1966 100644 --- a/stix2/v21/base.py +++ b/stix2/v21/base.py @@ -1,7 +1,8 @@ """Base classes for STIX 2.1 type definitions.""" from ..base import ( - _DomainObject, _Extension, _Observable, _RelationshipObject, _STIXBase, + _DomainObject, _Extension, _MetaObject, _Observable, _RelationshipObject, + _STIXBase, ) @@ -29,6 +30,10 @@ def __init__(self, **kwargs): class _Extension(_Extension, _STIXBase21): extension_type = None + def __init__(self, applies_to="sco", **kwargs): + super(_Extension, self).__init__(**kwargs) + self._applies_to = applies_to + class _DomainObject(_DomainObject, _STIXBase21): pass @@ -36,3 +41,7 @@ class _DomainObject(_DomainObject, _STIXBase21): class _RelationshipObject(_RelationshipObject, _STIXBase21): pass + + +class _MetaObject(_MetaObject, _STIXBase21): + pass diff --git a/stix2/v21/common.py b/stix2/v21/common.py index 55c4a05d..218070ff 100644 --- a/stix2/v21/common.py +++ b/stix2/v21/common.py @@ -14,7 +14,7 @@ TypeProperty, ) from ..utils import NOW, _get_dict -from .base import _STIXBase21 +from .base import _MetaObject, _STIXBase21 from .vocab import EXTENSION_TYPE, HASHING_ALGORITHM @@ -79,7 +79,7 @@ def _check_object_constraints(self): self._check_at_least_one_property(['lang', 'marking_ref']) -class LanguageContent(_STIXBase21): +class LanguageContent(_MetaObject): """For more detailed information on this object's properties, see `the STIX 2.1 specification `__. """ @@ -107,7 +107,7 @@ class LanguageContent(_STIXBase21): ]) -class ExtensionDefinition(_STIXBase21): +class ExtensionDefinition(_MetaObject): """For more detailed information on this object's properties, see `the STIX 2.1 specification `__. """ @@ -140,11 +140,11 @@ class ExtensionDefinition(_STIXBase21): ]) -def CustomExtension(type='x-custom-ext', properties=None): +def CustomExtension(type='x-custom-ext', properties=None, applies_to="sco"): """Custom STIX Object Extension decorator. """ def wrapper(cls): - return _custom_extension_builder(cls, type, properties, '2.1', _Extension) + return _custom_extension_builder(cls, applies_to, type, properties, '2.1', _Extension) return wrapper @@ -190,7 +190,7 @@ def clean(self, value, allow_custom=False): raise ValueError("must be a Statement, TLP Marking or a registered marking.") -class MarkingDefinition(_STIXBase21, _MarkingsMixin): +class MarkingDefinition(_MetaObject, _MarkingsMixin): """For more detailed information on this object's properties, see `the STIX 2.1 specification `__. 
""" diff --git a/stix2/v21/observables.py b/stix2/v21/observables.py index f4a4be0f..6fa777ed 100644 --- a/stix2/v21/observables.py +++ b/stix2/v21/observables.py @@ -181,7 +181,7 @@ class EmailMessage(_Observable): ('message_id', StringProperty()), ('subject', StringProperty()), ('received_lines', ListProperty(StringProperty)), - ('additional_header_fields', DictionaryProperty(spec_version='2.1')), + ('additional_header_fields', DictionaryProperty(valid_types=[ListProperty(StringProperty)], spec_version='2.1')), ('body', StringProperty()), ('body_multipart', ListProperty(EmbeddedObjectProperty(type=EmailMIMEComponent))), ('raw_email_ref', ReferenceProperty(valid_types='artifact', spec_version='2.1')), @@ -245,7 +245,7 @@ class PDFExt(_Extension): _properties = OrderedDict([ ('version', StringProperty()), ('is_optimized', BooleanProperty()), - ('document_info_dict', DictionaryProperty(spec_version='2.1')), + ('document_info_dict', DictionaryProperty(valid_types=[StringProperty], spec_version='2.1')), ('pdfid0', StringProperty()), ('pdfid1', StringProperty()), ]) @@ -261,7 +261,7 @@ class RasterImageExt(_Extension): ('image_height', IntegerProperty()), ('image_width', IntegerProperty()), ('bits_per_pixel', IntegerProperty()), - ('exif_tags', DictionaryProperty(spec_version='2.1')), + ('exif_tags', DictionaryProperty(valid_types=[StringProperty, IntegerProperty], spec_version='2.1')), ]) @@ -468,7 +468,7 @@ class HTTPRequestExt(_Extension): ('request_method', StringProperty(required=True)), ('request_value', StringProperty(required=True)), ('request_version', StringProperty()), - ('request_header', DictionaryProperty(spec_version='2.1')), + ('request_header', DictionaryProperty(valid_types=[ListProperty(StringProperty)], spec_version='2.1')), ('message_body_length', IntegerProperty()), ('message_body_data_ref', ReferenceProperty(valid_types='artifact', spec_version='2.1')), ]) @@ -496,7 +496,7 @@ class SocketExt(_Extension): ('address_family', EnumProperty(NETWORK_SOCKET_ADDRESS_FAMILY, required=True)), ('is_blocking', BooleanProperty()), ('is_listening', BooleanProperty()), - ('options', DictionaryProperty(spec_version='2.1')), + ('options', DictionaryProperty(valid_types=[IntegerProperty], spec_version='2.1')), ('socket_type', EnumProperty(NETWORK_SOCKET_TYPE)), ('socket_descriptor', IntegerProperty(min=0)), ('socket_handle', IntegerProperty()), @@ -550,7 +550,7 @@ class NetworkTraffic(_Observable): ('dst_byte_count', IntegerProperty(min=0)), ('src_packets', IntegerProperty(min=0)), ('dst_packets', IntegerProperty(min=0)), - ('ipfix', DictionaryProperty(spec_version='2.1')), + ('ipfix', DictionaryProperty(valid_types=[StringProperty, IntegerProperty], spec_version='2.1')), ('src_payload_ref', ReferenceProperty(valid_types='artifact', spec_version='2.1')), ('dst_payload_ref', ReferenceProperty(valid_types='artifact', spec_version='2.1')), ('encapsulates_refs', ListProperty(ReferenceProperty(valid_types='network-traffic', spec_version='2.1'))), @@ -634,7 +634,7 @@ class Process(_Observable): ('created_time', TimestampProperty()), ('cwd', StringProperty()), ('command_line', StringProperty()), - ('environment_variables', DictionaryProperty(spec_version='2.1')), + ('environment_variables', DictionaryProperty(valid_types=[StringProperty], spec_version='2.1')), ('opened_connection_refs', ListProperty(ReferenceProperty(valid_types='network-traffic', spec_version='2.1'))), ('creator_user_ref', ReferenceProperty(valid_types='user-account', spec_version='2.1')), ('image_ref', 
diff --git a/stix2/v21/sro.py b/stix2/v21/sro.py
index bf636c3d..8ef1582c 100644
--- a/stix2/v21/sro.py
+++ b/stix2/v21/sro.py
@@ -46,7 +46,7 @@ class Relationship(_RelationshipObject):
     # Explicitly define the first three kwargs to make readable Relationship declarations.
     def __init__(
         self, source_ref=None, relationship_type=None,
-        target_ref=None, **kwargs
+        target_ref=None, **kwargs,
    ):
         # Allow (source_ref, relationship_type, target_ref) as positional args.
         if source_ref and not kwargs.get('source_ref'):
diff --git a/tox.ini b/tox.ini
index bc7bd5b6..c74a1230 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,6 +11,10 @@ deps =
     rapidfuzz
     haversine
     medallion
+    sqlalchemy
+    sqlalchemy_utils
+    psycopg2
+
 commands =
     python -m pytest --cov=stix2 stix2/test/ --cov-report term-missing -W ignore::stix2.exceptions.STIXDeprecationWarning
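With sqlalchemy, sqlalchemy_utils, and psycopg2 added to the tox deps, the relational-db tests run under the same pytest command as the rest of the suite. A small pre-flight sketch for running them locally; the connection URL is an assumption, so point it at whichever PostgreSQL instance you have:

import sqlalchemy
from sqlalchemy_utils import database_exists

# Hypothetical local URL; psycopg2 is the driver sqlalchemy
# uses for postgresql:// URLs.
db_url = "postgresql://postgres:postgres@localhost:5432/postgres"

engine = sqlalchemy.create_engine(db_url)
print("database reachable:", database_exists(db_url))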