Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 42 additions & 20 deletions pytm/pytm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,7 +1278,6 @@ def get_table(db, klass):
db.close()



class Controls:
"""Controls implemented by/on and Element"""

Expand Down Expand Up @@ -2004,52 +2003,75 @@ def to_serializable(val):

@to_serializable.register(TM)
def ts_tm(obj):
return serialize(obj, nested=True)
result = serialize(obj, nested=True, ignore=(
"_sf", "_duplicate_ignored_attrs", "_threats", "_elements", "assumptions"))
result["elements"] = [e for e in obj._elements if isinstance(e, (Actor, Asset))]
result["assumptions"] = list(obj.assumptions)
return result


@to_serializable.register(Controls)
@to_serializable.register(Data)
@to_serializable.register(Finding)
def _(obj):
return serialize(obj, nested=False)


@to_serializable.register(Threat)
@to_serializable.register(Element)
def _(obj):
result = serialize(obj, nested=False, ignore=["target"])
result["target"] = [v.__name__ for v in obj.target]
return result


@to_serializable.register(Finding)
def _(obj):
return serialize(obj, nested=False, ignore=["element"])


@to_serializable.register(Element)
def ts_element(obj):
return serialize(obj, nested=False)
result = serialize(obj, nested=False, ignore=("_is_drawn", "uuid", "levels", "sourceFiles", "assumptions", "findings"))
result["levels"] = list(obj.levels)
result["sourceFiles"] = list(obj.sourceFiles)
result["assumptions"] = list(obj.assumptions)
result["findings"] = [v.id for v in obj.findings]
return result


def serialize(obj, nested=False):
@to_serializable.register(Actor)
@to_serializable.register(Asset)
def _(obj):
# Note that we use the ts_element function defined for the Element class
result = ts_element(obj)
result["__class__"] = obj.__class__.__name__
return result


def serialize(obj, nested=False, ignore=None):
"""Used if *obj* is an instance of TM, Element, Threat or Finding."""
klass = obj.__class__
result = {}
if isinstance(obj, (Actor, Asset)):
result["__class__"] = klass.__name__
if ignore is None:
ignore = []

for i in dir(obj):
if (
i.startswith("__")
or callable(getattr(klass, i, {}))
or (
isinstance(obj, TM)
and i in ("_sf", "_duplicate_ignored_attrs", "_threats")
)
or (isinstance(obj, Element) and i in ("_is_drawn", "uuid"))
or (isinstance(obj, Finding) and i == "element")
or i in ignore
):
continue
value = getattr(obj, i)
if isinstance(obj, TM) and i == "_elements":
value = [e for e in value if isinstance(e, (Actor, Asset))]
if value is not None:
if isinstance(value, (Element, Data)):
value = value.name
elif isinstance(obj, Threat) and i == "target":
value = [v.__name__ for v in value]
elif i in ("levels", "sourceFiles", "assumptions"):
value = list(value)
elif (
not nested
and not isinstance(value, str)
and isinstance(value, Iterable)
):
value = [v.id if isinstance(v, Finding) else v.name for v in value]
value = [v.name for v in value]
result[i.lstrip("_")] = value
return result

Expand Down