Skip to content

Commit

Permalink
added docstrings and typehints to models (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
taylorwalton authored Jul 11, 2023
1 parent 9db52fe commit f6d3a87
Show file tree
Hide file tree
Showing 9 changed files with 444 additions and 160 deletions.
73 changes: 55 additions & 18 deletions backend/app/models/agents.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,62 @@
# from datetime import datetime
from datetime import datetime

# from loguru import logger
# from sqlalchemy.dialects.postgresql import JSONB # Add this line
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Integer
from sqlalchemy import String

from app import db
from app import ma


# Class for agent metadata which stores the agent ID, IP address, hostname, OS, last seen timestamp,
# and boolean for critical assest.
# Path: backend\app\models.py
class AgentMetadata(db.Model):
id = db.Column(db.Integer, primary_key=True)
agent_id = db.Column(db.String(100))
ip_address = db.Column(db.String(100))
os = db.Column(db.String(100))
hostname = db.Column(db.String(100))
critical_asset = db.Column(db.Boolean, default=False)
last_seen = db.Column(db.DateTime)

def __init__(self, agent_id, ip_address, os, hostname, critical_asset, last_seen):
"""
Class for agent metadata which stores the agent ID, IP address, hostname, OS, last seen timestamp,
and boolean for critical asset. This class inherits from SQLAlchemy's Model class.
"""

id: Column[Integer] = db.Column(db.Integer, primary_key=True)
agent_id: Column[String] = db.Column(db.String(100))
ip_address: Column[String] = db.Column(db.String(100))
os: Column[String] = db.Column(db.String(100))
hostname: Column[String] = db.Column(db.String(100))
critical_asset: Column[Boolean] = db.Column(db.Boolean, default=False)
last_seen: Column[DateTime] = db.Column(db.DateTime)

def __init__(
self,
agent_id: str,
ip_address: str,
os: str,
hostname: str,
critical_asset: bool,
last_seen: datetime,
):
"""
Initialize a new instance of the AgentMetadata class.
:param agent_id: Unique ID for the agent.
:param ip_address: IP address of the agent.
:param os: Operating system of the agent.
:param hostname: Hostname of the agent.
:param critical_asset: Boolean value indicating if the agent is a critical asset.
:param last_seen: Timestamp of when the agent was last seen.
"""
self.agent_id = agent_id
self.ip_address = ip_address
self.os = os
self.hostname = hostname
self.critical_asset = critical_asset
self.last_seen = last_seen

def __repr__(self):
def __repr__(self) -> str:
"""
Returns a string representation of the AgentMetadata instance.
:return: A string representation of the agent ID.
"""
return f"<AgentMetadata {self.agent_id}>"

def mark_as_critical(self):
Expand All @@ -53,8 +82,16 @@ def commit_wazuh_agent_to_db(self):


class AgentMetadataSchema(ma.Schema):
"""
Schema for serializing and deserializing instances of the AgentMetadata class.
"""

class Meta:
fields = (
"""
Meta class defines the fields to be serialized/deserialized.
"""

fields: tuple = (
"id",
"agent_id",
"ip_address",
Expand All @@ -65,5 +102,5 @@ class Meta:
)


agent_metadata_schema = AgentMetadataSchema()
agent_metadatas_schema = AgentMetadataSchema(many=True)
agent_metadata_schema: AgentMetadataSchema = AgentMetadataSchema()
agent_metadatas_schema: AgentMetadataSchema = AgentMetadataSchema(many=True)
53 changes: 40 additions & 13 deletions backend/app/models/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,62 @@
from datetime import datetime
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.dialects.postgresql import TEXT # Add this line

from app import db
from app import ma
from sqlalchemy.dialects.postgresql import TEXT # Add this line

# Class for artifacts collected which stores the artifact name, artificat results (json), hostname

# Path: backend\app\models.py
class Artifact(db.Model):
id = db.Column(db.Integer, primary_key=True)
artifact_name = db.Column(db.String(100))
artifact_results = db.Column(TEXT)
hostname = db.Column(db.String(100))

def __init__(self, artifact_name, artifact_results, hostname):
"""
Class for artifacts collected which stores the artifact name, artifact results (json), and hostname.
This class inherits from SQLAlchemy's Model class.
"""

id: Column[Integer] = db.Column(db.Integer, primary_key=True)
artifact_name: Column[String] = db.Column(db.String(100))
artifact_results: Column[TEXT] = db.Column(TEXT)
hostname: Column[String] = db.Column(db.String(100))

def __init__(self, artifact_name: str, artifact_results: str, hostname: str):
"""
Initialize a new instance of the Artifact class.
:param artifact_name: The name of the artifact.
:param artifact_results: The results of the artifact, stored as a JSON string.
:param hostname: The hostname where the artifact was collected.
"""
self.artifact_name = artifact_name
self.artifact_results = artifact_results
self.hostname = hostname

def __repr__(self):
def __repr__(self) -> str:
"""
Returns a string representation of the Artifact instance.
:return: A string representation of the artifact name.
"""
return f"<Artifact {self.artifact_name}>"


class ArtifactSchema(ma.Schema):
"""
Schema for serializing and deserializing instances of the Artifact class.
"""

class Meta:
fields = (
"""
Meta class defines the fields to be serialized/deserialized.
"""

fields: tuple = (
"id",
"artifact_name",
"artifact_results",
"hostname",
)


artifact_schema = ArtifactSchema()
artifacts_schema = ArtifactSchema(many=True)
artifact_schema: ArtifactSchema = ArtifactSchema()
artifacts_schema: ArtifactSchema = ArtifactSchema(many=True)
51 changes: 39 additions & 12 deletions backend/app/models/cases.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,61 @@
from datetime import datetime
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String

from app import db
from app import ma

# Class for cases which stores the case ID, case name, list of agents

# Path: backend\app\models.py
class Case(db.Model):
id = db.Column(db.Integer, primary_key=True)
case_id = db.Column(db.Integer)
case_name = db.Column(db.String(100))
agents = db.Column(db.String(1000))

def __init__(self, case_id, case_name, agents):
"""
Class for cases which stores the case ID, case name, and a list of agents.
This class inherits from SQLAlchemy's Model class.
"""

id: Column[Integer] = db.Column(db.Integer, primary_key=True)
case_id: Column[Integer] = db.Column(db.Integer)
case_name: Column[String] = db.Column(db.String(100))
agents: Column[String] = db.Column(db.String(1000))

def __init__(self, case_id: int, case_name: str, agents: str):
"""
Initialize a new instance of the Case class.
:param case_id: The ID of the case.
:param case_name: The name of the case.
:param agents: A comma-separated string of agents associated with the case.
"""
self.case_id = case_id
self.case_name = case_name
self.agents = agents

def __repr__(self):
def __repr__(self) -> str:
"""
Returns a string representation of the Case instance.
:return: A string representation of the case ID.
"""
return f"<Case {self.case_id}>"


class CaseSchema(ma.Schema):
"""
Schema for serializing and deserializing instances of the Case class.
"""

class Meta:
fields = (
"""
Meta class defines the fields to be serialized/deserialized.
"""

fields: tuple = (
"id",
"case_id",
"case_name",
"agents",
)


case_schema = CaseSchema()
cases_schema = CaseSchema(many=True)
case_schema: CaseSchema = CaseSchema()
cases_schema: CaseSchema = CaseSchema(many=True)
42 changes: 21 additions & 21 deletions backend/app/models/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any
from typing import Dict

import grpc
import pika
Expand All @@ -17,10 +19,8 @@

from app.models.models import Connectors

# from werkzeug.utils import secure_filename


def dynamic_import(module_name, class_name):
def dynamic_import(module_name: str, class_name: str) -> Any:
"""
This function dynamically imports a module and returns a specific class from it.
Expand All @@ -43,10 +43,10 @@ class Connector(ABC):
:param attributes: A dictionary of attributes necessary for the connector to connect to the service or system.
"""

attributes: dict
attributes: Dict[str, Any]

@abstractmethod
def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
This abstract method should be implemented by all subclasses of Connector. It is meant to verify the
connection to the service or system the connector is designed to connect to.
Expand All @@ -56,7 +56,7 @@ def verify_connection(self):
pass

@staticmethod
def get_connector_info_from_db(connector_name):
def get_connector_info_from_db(connector_name: str) -> Dict[str, Any]:
"""
This method retrieves connector information from the database.
Expand Down Expand Up @@ -89,10 +89,10 @@ class WazuhIndexerConnector(Connector):
:param connector_name: A string that specifies the name of the connector.
"""

def __init__(self, connector_name):
def __init__(self, connector_name: str):
super().__init__(attributes=self.get_connector_info_from_db(connector_name))

def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
This method verifies the connection to the Wazuh indexer service.
Expand Down Expand Up @@ -131,10 +131,10 @@ class GraylogConnector(Connector):
:param connector_name: A string that specifies the name of the connector.
"""

def __init__(self, connector_name):
def __init__(self, connector_name: str):
super().__init__(attributes=self.get_connector_info_from_db(connector_name))

def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
Verifies the connection to Graylog service.
Expand Down Expand Up @@ -177,10 +177,10 @@ class WazuhManagerConnector(Connector):
:param connector_name: A string that specifies the name of the connector.
"""

def __init__(self, connector_name):
def __init__(self, connector_name: str):
super().__init__(attributes=self.get_connector_info_from_db(connector_name))

def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
Verifies the connection to Wazuh manager service.
Expand Down Expand Up @@ -215,7 +215,7 @@ def verify_connection(self):
)
return {"connectionSuccessful": False, "authToken": None}

def get_auth_token(self):
def get_auth_token(self) -> str:
"""
Returns the authentication token for the Wazuh manager service.
Expand All @@ -232,10 +232,10 @@ class ShuffleConnector(Connector):
:param connector_name: A string that specifies the name of the connector.
"""

def __init__(self, connector_name):
def __init__(self, connector_name: str):
super().__init__(attributes=self.get_connector_info_from_db(connector_name))

def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
Verifies the connection to Shuffle service.
Expand Down Expand Up @@ -278,10 +278,10 @@ class DfirIrisConnector(Connector):
:param connector_name: A string that specifies the name of the connector.
"""

def __init__(self, connector_name):
def __init__(self, connector_name: str):
super().__init__(attributes=self.get_connector_info_from_db(connector_name))

def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
Verifies the connection to DFIR IRIS service.
Expand Down Expand Up @@ -326,10 +326,10 @@ class VelociraptorConnector(Connector):
connector_name (str): The name of the connector.
"""

def __init__(self, connector_name):
def __init__(self, connector_name: str):
super().__init__(attributes=self.get_connector_info_from_db(connector_name))

def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
Verifies the connection to Velociraptor service.
Expand Down Expand Up @@ -391,10 +391,10 @@ class RabbitMQConnector(Connector):
connector_name (str): The name of the connector.
"""

def __init__(self, connector_name):
def __init__(self, connector_name: str):
super().__init__(attributes=self.get_connector_info_from_db(connector_name))

def verify_connection(self):
def verify_connection(self) -> Dict[str, Any]:
"""
Verifies the connection to RabbitMQ service.
"""
Expand Down
Loading

0 comments on commit f6d3a87

Please sign in to comment.