diff --git a/pyproject.toml b/pyproject.toml
index 3ad22b6..44a49a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,7 @@ click = "*"
 psycopg2-binary = "*"
 PyYAML = "*"
 transitions = "*"
+pydantic = "^2.9.2"
 
 [tool.poetry.dev-dependencies]
 ipykernel = "*"
diff --git a/spiderexpress/cli.py b/spiderexpress/cli.py
index 37a5e98..15c03b5 100644
--- a/spiderexpress/cli.py
+++ b/spiderexpress/cli.py
@@ -11,13 +11,14 @@
 - Refine verbs/commands for the CLI
 - find a mechanism for stopping/starting collections
 """
+
+import sys
 from importlib.metadata import entry_points
 from pathlib import Path
 
-from loguru import logger as log
 import click
 import yaml
-import sys
+from loguru import logger as log
 
 from .spider import CONNECTOR_GROUP, STRATEGY_GROUP, Spider
 from .types import Configuration
@@ -34,22 +35,21 @@ def cli(ctx):
 @cli.command()
 @click.argument("config", type=click.Path(path_type=Path, exists=True))
 @click.option("-v", "--verbose", count=True)
-@click.option("-l", "--logfile", type=click.Path(dir_okay=False, writable=True, path_type=str))
+@click.option(
+    "-l", "--logfile", type=click.Path(dir_okay=False, writable=True, path_type=str)
+)
 @click.pass_context
 def start(ctx: click.Context, config: Path, verbose: int, logfile: str):
     """start a job"""
-    logging_level = max(50 - (10 * verbose), 0)  # Allows logging level to be between 0 and 50.
+    logging_level = max(
+        50 - (10 * verbose), 0
+    )  # Allows logging level to be between 0 and 50.
     logging_configuration = {
-        "handlers": [
-            {
-                "sink": logfile or sys.stdout,
-                "level": logging_level
-            }
-        ],
-        "extra": {}
+        "handlers": [{"sink": logfile or sys.stdout, "level": logging_level}],
+        "extra": {},
     }
     log.configure(**logging_configuration)
-    log.debug(f"Starting logging with verbosity {logging_level}.")
+    log.debug(f"Starting logging with verbosity {logging_level}.")
 
     ctx.obj.start(config)
 
@@ -62,7 +62,13 @@ def create(config: str, interactive: bool):
 
     if interactive:
         for key, description in [
-            ("seeds", "add seeds?"),
+            ("project_name", "Name of your project?"),
+            ("db_url", "URL of your database?"),
+            ("max_iteration", "How many iterations should be done?"),
+            (
+                "empty_seeds",
+                "What should happen if seeds are empty? Can be 'stop' or 'retry'",
+            ),
+            ),
             ("seed_file", "do you wish to read a file for seeds?"),
         ]:
             args[key] = click.prompt(description)
@@ -76,7 +82,7 @@
 @cli.command()
 def list():  # pylint: disable=W0622
     """list all plugins"""
-    click.echo("--- connectors ---", color="blue")
+    click.echo("--- connectors ---")
     for connector in entry_points(group=CONNECTOR_GROUP):
         click.echo(connector.name)
     click.echo("--- strategies ---")
diff --git a/spiderexpress/connectors/csv.py b/spiderexpress/connectors/csv.py
index 9a5e508..a874941 100644
--- a/spiderexpress/connectors/csv.py
+++ b/spiderexpress/connectors/csv.py
@@ -1,13 +1,15 @@
 """A CSV-reading, network-rippin' connector for your testing purposes."""
+
 import dataclasses
 from typing import Dict, List, Optional, Tuple, Union
 
 import pandas as pd
 
-from spiderexpress.types import PlugIn, fromdict
+from spiderexpress.types import PlugIn, from_dict
 
 _cache = {}
 
+
 @dataclasses.dataclass
 class CSVConnectorConfiguration:
     """Configuration items for the csv_connector."""
@@ -23,7 +25,7 @@ def csv_connector(
 ) -> Tuple[pd.DataFrame, pd.DataFrame]:
     """The CSV connector!"""
     if isinstance(configuration, dict):
-        configuration = fromdict(CSVConnectorConfiguration, configuration)
+        configuration = from_dict(CSVConnectorConfiguration, configuration)
 
     if configuration.cache:
         if configuration.edge_list_location not in _cache:
@@ -61,9 +63,11 @@ def csv_connector(
 
     return (
         edge_return,
-        nodes.loc[nodes.name.isin(node_ids), :]
-        if nodes is not None
-        else pd.DataFrame(),
+        (
+            nodes.loc[nodes.name.isin(node_ids), :]
+            if nodes is not None
+            else pd.DataFrame()
+        ),
     )
diff --git a/spiderexpress/model.py b/spiderexpress/model.py
index cdea87a..f383bec 100644
--- a/spiderexpress/model.py
+++ b/spiderexpress/model.py
@@ -4,13 +4,13 @@
 """
 
 import datetime
-from typing import Any, Callable, Dict, List, Tuple, Type
+from typing import Dict
 
 import sqlalchemy as sql
 from loguru import logger as log
-from sqlalchemy import orm
+from sqlalchemy import JSON, orm
 
-# pylint: disable=R0903
+# pylint: disable=R0903, W0622
 
 mapper_registry = orm.registry()
@@ -24,6 +24,8 @@
 class Base(orm.DeclarativeBase):
     """Base class for all models."""
 
+    type_annotation_map = {Dict: JSON}
+
     def __repr__(self):
         props = [
             f"{key}={value}"
@@ -68,177 +70,174 @@ class TaskList(Base):
     finished_at: orm.Mapped[datetime.datetime] = orm.mapped_column(nullable=True)
 
 
-def create_factory(
-    cls: Type[Any], spec_fixed: List[sql.Column], spec_variadic: Dict[str, Any]
-) -> Callable:
-    """Create a factory function for a given class."""
+class RawDataStore(Base):
+    """Table for raw data storage.
 
-    log.info(
-        f"Creating factory for {cls.__name__} with {spec_fixed} and {spec_variadic}"
-    )
+    Attributes:
+        id: Primary key for the table.
+        connector_id: Identifier for the connector.
+        output_type: Type of the output data.
+        created_at: Timestamp when the data was created.
+        data: The raw data stored in JSON format.
+ """ - def _(data: Dict[str, Any]) -> Type[Any]: - return cls( - **{ - key: data.get(key) - for key in [column.name for column in spec_fixed] - + list(spec_variadic.keys()) - } - ) + __tablename__ = "raw_data_store" - return _ + id: orm.Mapped[str] = orm.mapped_column(primary_key=True) + connector_id: orm.Mapped[str] = orm.mapped_column(index=True) + output_type: orm.Mapped[str] = orm.mapped_column(index=True) + created_at: orm.Mapped[datetime.datetime] = orm.mapped_column( + index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc) + ) + data: orm.Mapped[Dict] = orm.mapped_column(insert_default={}) -def create_raw_edge_table( - name: str, spec_variadic: Dict[str, str] -) -> Tuple[sql.Table, Type["RawEdge"], Callable]: - """Create an edge table dynamically. +class LayerDenseEdges(Base): + """Table for dense data storage.""" - parameters: - name: name of the table - spec_variadic: dict of variadic columns + __tablename__ = "layer_dense_edges" - returns: - table: the table - """ - spec_fixed = [ - sql.Column("id", sql.Integer, primary_key=True, index=True, autoincrement=True), - sql.Column("source", sql.Text, index=True, unique=False), - sql.Column("target", sql.Text, index=True, unique=False), - sql.Column("iteration", sql.Integer, index=True, unique=False), - ] - - table = sql.Table( - name, - Base.metadata, - *spec_fixed, - *[ - sql.Column(key, type_lookup.get(value)) - for key, value in spec_variadic.items() - ], + id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True) + source: orm.Mapped[str] = orm.mapped_column(index=True) + target: orm.Mapped[str] = orm.mapped_column(index=True) + edge_type: orm.Mapped[str] = orm.mapped_column(index=True) + layer_id: orm.Mapped[str] = orm.mapped_column(index=True) + created_at: orm.Mapped[datetime.datetime] = orm.mapped_column( + index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc) ) + data: orm.Mapped[Dict] = orm.mapped_column(insert_default={}) + + +def insert_layer_dense_edge(session, layer_id, edge_type, data): + """Insert a dense edge into the database.""" + source = data.get("source") + target = data.get("target") + id = f"{layer_id}:{source}-{target}" + + dense_edge = LayerDenseEdges( + id=id, + source=source, + target=target, + edge_type=edge_type, + layer_id=layer_id, + data=data, + ) + session.merge(dense_edge) + session.commit() + log.debug(f"Inserted dense edge from {source} to {target} in layer {layer_id}") - class RawEdge: - """Unaggregated, raw edge.""" - - def __repr__(self): - return f"""""" - mapper_registry.map_imperatively(RawEdge, table) +class LayerDenseNodes(Base): + """Table for dense data storage.""" - return table, RawEdge, create_factory(RawEdge, spec_fixed, spec_variadic) + __tablename__ = "layer_dense_nodes" + id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True) + name: orm.Mapped[str] = orm.mapped_column() + layer_id: orm.Mapped[str] = orm.mapped_column(index=True) + node_type: orm.Mapped[str] = orm.mapped_column(index=True) + created_at: orm.Mapped[datetime.datetime] = orm.mapped_column( + index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc) + ) + data: orm.Mapped[Dict] = orm.mapped_column() -def create_aggregated_edge_table( - name: str, spec_variadic: Dict[str, str] -) -> Tuple[sql.Table, Type["AggEdge"], Callable]: - """Create an aggregated edge table dynamically. 
-    parameters:
-        name: name of the table
-        spec_variadic: dict of variadic columns
+def insert_layer_dense_node(session, layer_id, node_type, data):
+    """Insert a dense node into the database."""
+    name = data.get("name")
+    id = f"{layer_id}:{name}"
 
-    returns:
-        table: the table
-    """
-    spec_fixed = [
-        sql.Column("source", sql.Text, primary_key=True, index=True),
-        sql.Column("target", sql.Text, primary_key=True, index=True),
-        sql.Column(
-            "iteration", sql.Integer, primary_key=True, index=True, unique=False
-        ),
-        sql.Column("weight", sql.Integer),
-    ]
-
-    table = sql.Table(
-        name,
-        Base.metadata,
-        *spec_fixed,
-        *[sql.Column(key, sql.Integer) for key, value in spec_variadic.items()],
+    dense_node = LayerDenseNodes(
+        id=id, name=name, layer_id=layer_id, node_type=node_type, data=data
     )
+    session.merge(dense_node)
+    session.commit()
+    log.debug(f"Inserted dense node {name} in layer {layer_id}")
 
-    class AggEdge:
-        """Aggregated edge."""
 
-        def __repr__(self):
-            return f""""""
+class LayerSparseEdges(Base):
+    """Table for sparse data storage."""
 
-    mapper_registry.map_imperatively(AggEdge, table)
+    __tablename__ = "layer_sparse_store"
 
-    return table, AggEdge, create_factory(AggEdge, spec_fixed, spec_variadic)
+    id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
+    layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+    source: orm.Mapped[str] = orm.mapped_column(index=True)
+    target: orm.Mapped[str] = orm.mapped_column(index=True)
+    edge_type: orm.Mapped[str] = orm.mapped_column(index=True)
+    weight: orm.Mapped[float] = orm.mapped_column()
+    created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+        index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
+    )
+    data: orm.Mapped[Dict] = orm.mapped_column()
+
+
+def insert_layer_sparse_edge(session, layer_id, edge_type, data):
+    """Insert a sparse edge into the database."""
+    source = data.get("source")
+    target = data.get("target")
+    weight = data.get("weight")
+    id = f"{layer_id}:{source}-{target}"
+
+    sparse_edge = LayerSparseEdges(
+        id=id,
+        source=source,
+        target=target,
+        weight=weight,
+        edge_type=edge_type,
+        layer_id=layer_id,
+        data=data,
+    )
+    session.merge(sparse_edge)
+    session.commit()
+    log.debug(f"Inserted sparse edge from {source} to {target} in layer {layer_id}")
 
-def create_node_table(
-    name: str, spec_variadic: Dict[str, str]
-) -> Tuple[sql.Table, Type["Node"], Callable]:
-    """Create a node table dynamically.
+class LayerSparseNodes(Base):
+    """Table for sparse data storage."""
 
-    parameters:
-        name: name of the table
-        spec_variadic: dict of variadic columns
+    __tablename__ = "layer_sparse_nodes"
 
-    returns:
-        table: the table
-    """
-    spec_fixed = [
-        sql.Column("name", sql.Text, primary_key=True, index=True),
-        sql.Column("iteration", sql.Integer, index=True, unique=False),
-    ]
-
-    table = sql.Table(
-        name,
-        Base.metadata,
-        *spec_fixed,
-        *[
-            sql.Column(key, type_lookup.get(value))
-            for key, value in spec_variadic.items()
-        ],
+    id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
+    layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+    name: orm.Mapped[str] = orm.mapped_column()
+    node_type: orm.Mapped[str] = orm.mapped_column(index=True)
+    created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+        index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
    )
+    data: orm.Mapped[Dict] = orm.mapped_column()
 
-    class Node:
-        """Node."""
 
-        def __repr__(self):
-            return f""""""
+def insert_layer_sparse_node(session, layer_id, node_type, data):
+    """Insert a sparse node into the database."""
+    name = data.get("name")
+    id = f"{layer_id}:{name}"
 
-    mapper_registry.map_imperatively(Node, table)
+    sparse_node = LayerSparseNodes(
+        id=id, name=name, layer_id=layer_id, node_type=node_type, data=data
+    )
+    session.merge(sparse_node)
+    session.commit()
+    log.debug(f"Inserted sparse node {name} in layer {layer_id}")
 
-    return table, Node, create_factory(Node, spec_fixed, spec_variadic)
 
+class SamplerStateStore(Base):
+    """Table for storing the state of the sampler."""
 
-def create_sampler_state_table(
-    name: str, spec_variadic: Dict[str, str]
-) -> Tuple[sql.Table, Type["SamplerState"], Callable]:
-    """Create a sampler state table dynamically."""
+    __tablename__ = "sampler_state_store"
 
-    table = sql.Table(
-        name,
-        Base.metadata,
-        sql.Column("id", sql.Integer, primary_key=True, index=True, autoincrement=True),
-        *[
-            sql.Column(key, type_lookup.get(value))
-            for key, value in spec_variadic.items()
-        ],
+    id: orm.Mapped[int] = orm.mapped_column(primary_key=True, autoincrement=True)
+    iteration: orm.Mapped[int] = orm.mapped_column(index=True)
+    layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+    data: orm.Mapped[Dict] = orm.mapped_column()
+    created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+        insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
    )
 
-    class SamplerState:
-        """Sampler state."""
-
-    mapper_registry.map_imperatively(SamplerState, table)
 
-    return table, SamplerState, create_factory(SamplerState, [], spec_variadic)
+def insert_sampler_state(session, layer_id, iteration, data):
+    """Insert the state of the sampler into the database."""
+    sampler_state = SamplerStateStore(iteration=iteration, layer_id=layer_id, data=data)
+    session.add(sampler_state)
+    session.commit()
+    log.debug(f"Inserted sampler state for layer {layer_id} at iteration {iteration}")
diff --git a/spiderexpress/plugin_manager.py b/spiderexpress/plugin_manager.py
index e626ff9..4af73e0 100644
--- a/spiderexpress/plugin_manager.py
+++ b/spiderexpress/plugin_manager.py
@@ -46,7 +46,7 @@ def get_plugin(spec: PlugInSpec, group: str) -> Callable:
 def _(spec: str, group: str) -> Callable:
     plugin = _access_entry_point(spec, group)
     if not plugin:
-        raise ValueError(f"{ spec } could not be found in { group }")
+        raise ValueError(f"{spec} could not be found in {group}")
     return functools.partial(
         plugin.callable, configuration=plugin.default_configuration
     )
@@ -56,7 +56,7 @@ def _(spec: str, group: str) -> Callable:
 def _(spec: dict, group: str) -> Callable:
     if len(spec.keys()) > 1:
         log.warning(
-            "Requesting specification has more than one type."
+            f"Requested specification {spec} has more than one type. "
             "Using the first instance found"
         )
     for name, configuration in spec.items():
diff --git a/spiderexpress/router.py b/spiderexpress/router.py
index 5a7c4ee..ee0348b 100644
--- a/spiderexpress/router.py
+++ b/spiderexpress/router.py
@@ -2,8 +2,13 @@
 
 """
 
+
 import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
+
+from loguru import logger as log
+
+from spiderexpress.types import RouterSpec
 
 
 class RouterValidationError(ValueError):
@@ -11,39 +16,76 @@ class RouterValidationError(ValueError):
 
 
 class Router:
-    """Creates an edge router from the given specification.
+    r"""Creates an edge router from the given specification.
 
     Arguments:
         name (str): the layer name to bind to
         spec (Dict[str, Union[str, List[TargetSpec]]]): how data should be routed
 
     Raises:
-        ValueError: if the specification is malformed
+        RouterValidationError: if the specification is malformed
+
+    Example:
+        spec = {
+            "source": "handle",
+            "target": [
+                {
+                    "field": "text",
+                    "pattern": r"https://www\.twitter\.com/(\w+)",
+                    "dispatch_with": "test",
+                    "type": "twitter-url",
+                }
+            ],
+            "view_count": "view_count",
+        }
+        input_data = {
+            "handle": "Tony",
+            "text": "Check this out: https://www.twitter.com/ernie",
+            "view_count": 123,
+        }
+        router = Router("test", spec)
+        result = router.parse(input_data)
+        print(result)
+        # Output: [{
+        #     'source': 'Tony', 'target': 'ernie', 'view_count': 123,
+        #     'dispatch_with': 'test', 'type': 'twitter-url'
+        # }]
     """
 
-    def __init__(self, name: str, spec: Dict[str, Any], context: Optional[Dict] = None):
+    TARGET = "target"
+    SOURCE = "source"
+
+    def __init__(self, name: str, spec: RouterSpec, context: Optional[Dict] = None):
         # Store the layer name
         self.name = name
         # Validate the spec against the rule set
         Router.validate_spec(name, spec, context)
-        self.spec = spec.copy()
+        self.spec: RouterSpec = spec
 
     @classmethod
     def validate_spec(cls, name, spec, context):
         """Validates a spec in a context."""
-        if "to" not in spec:
-            raise RouterValidationError(f"{name}: Key 'to' is missing from {spec}.")
-        if not isinstance(spec.get("to"), list):
+        # pylint: disable=R0912
+        if Router.TARGET not in spec:
+            raise RouterValidationError(
+                f"{name}: Key {Router.TARGET} is missing from {spec}."
+            )
+        if Router.SOURCE not in spec:
             raise RouterValidationError(
-                f"{name}: 'to' is not a list but {spec.get('to')}."
+                f"{name}: Key {Router.SOURCE} is missing from {spec}."
             )
-        for target_spec in spec.get("to"):
-            mandatory_fields = ["field", "dispatch_with"]
-            for field in mandatory_fields:
-                if target_spec.get(field) is None:
-                    raise RouterValidationError(
-                        f"{name}: '{ field }' cannot be None in {target_spec}"
-                    )
+        if not isinstance(spec.get(Router.TARGET), list):
+            raise RouterValidationError(
+                f"{name}: '{Router.TARGET}' is not a list but '{spec.get(Router.TARGET)}'."
+            )
+        if isinstance(spec.get(Router.TARGET), list):
+            for target_spec in spec.get(Router.TARGET):
+                mandatory_fields = ["field", "dispatch_with"]
+                for field in mandatory_fields:
+                    if target_spec.get(field) is None:
+                        raise RouterValidationError(
+                            f"{name}: '{field}' cannot be None in {target_spec}"
+                        )
 
         if context is None:
             return
@@ -54,42 +96,45 @@ def validate_spec(cls, name, spec, context):
         this_connector = connectors.get(name)
         for _, data_column_name in spec.items():
             if data_column_name not in this_connector:
-                raise RouterValidationError(
-                    f"{ name }: { data_column_name } not found."
-                )
-        for target_spec in spec.get("to"):
+                raise RouterValidationError(f"{name}: {data_column_name} not found.")
+        for target_spec in spec.get(Router.TARGET):
             field = target_spec.get("field")
             if field not in connectors.get(name).get("columns"):
                 raise RouterValidationError(
                     f"{name}: reference to {field} not found in " f"context."
                 )
 
-    def parse(self, input_data):
+    def parse(self, input_data) -> List[Dict[str, Any]]:
         """Parses data with the given spec and emits edges."""
         ret = []
         constant = {}
+
+        log.debug(f"Router '{self.name}' parsing {input_data}")
+
         # First we calculate all constants
         for edge_key, spec in self.spec.items():
             if isinstance(spec, str):
                 constant[edge_key] = input_data.get(spec)
-        for directive in self.spec.get("to", []):
-            value = input_data.get(directive.get("field"))
-            # Add further constants if there are some defined in the spec
-            local_constant = {
-                **{
-                    key: value
-                    for key, value in directive.items()
-                    if key not in ["field", "pattern"]
-                },
-                **constant,
-            }
-            if "pattern" not in directive:
-                # Simply get the value and return a
-                ret.append({"to": value, **local_constant})
-                continue
-            # Get all matches from the string and return an edge for each
-            matches = re.findall(directive.get("pattern"), value)
-            for match in matches:
-                ret.append({"to": match, **local_constant})
-
-        return ret
+        if isinstance(self.spec.get(Router.TARGET), list):
+            for directive in self.spec.get(Router.TARGET, []):
+                value = input_data.get(directive.get("field"))
+                # Add further constants if there are some defined in the spec
+                local_constant = {
+                    **{
+                        key: value
+                        for key, value in directive.items()
+                        if key not in ["field", "pattern"]
+                    },
+                    **constant,
+                }
+                if "pattern" not in directive:
+                    # Simply get the value and return a single edge
+                    ret.append({Router.TARGET: value, **local_constant})
+                    continue
+                # Get all matches from the string and return an edge for each
+                matches = re.findall(directive.get("pattern"), value)
+                for match in matches:
+                    ret.append({Router.TARGET: match, **local_constant})
+            return ret
+
+        return [constant]
diff --git a/spiderexpress/spider.py b/spiderexpress/spider.py
index deb29fd..e0f8a47 100644
--- a/spiderexpress/spider.py
+++ b/spiderexpress/spider.py
@@ -14,7 +14,7 @@
 
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import pandas as pd
 import sqlalchemy as sql
@@ -26,15 +26,20 @@
 from spiderexpress.model import (
     AppMetaData,
     Base,
+    LayerDenseEdges,
+    LayerDenseNodes,
+    SamplerStateStore,
     SeedList,
     TaskList,
-    create_aggregated_edge_table,
-    create_node_table,
-    create_raw_edge_table,
-    create_sampler_state_table,
+    insert_layer_dense_edge,
+    insert_layer_dense_node,
+    insert_layer_sparse_edge,
+    insert_layer_sparse_node,
+    insert_sampler_state,
 )
-from spiderexpress.plugin_manager import get_plugin, get_table_configuration
-from spiderexpress.types import Configuration, Connector, Strategy
+from spiderexpress.plugin_manager import get_plugin
+from spiderexpress.router import Router
+from spiderexpress.types import Configuration, Connector, Strategy, from_dict
 
 # pylint: disable=W0613,E1101,C0103
@@ -44,9 +49,6 @@
 
 MAX_RETRIES = 3
 
-factory_registry = {}
-class_registry = {}
-
 
 class Spider:
     """This is spiderexpress' Spider.
@@ -70,7 +72,7 @@ class Spider:
         "idle",
         {
             "name": "starting",
-            "on_enter": ["open_database", "load_plugins"],
+            "on_enter": ["open_database"],
         },
         {
             "name": "gathering",
@@ -88,6 +90,9 @@ class Spider:
             "name": "stopping",
             "on_enter": "close_database",
         },
+        {
+            "name": "stopped",
+        },
     ]
     """List of states the spider can be in.
@@ -147,6 +152,11 @@ class Spider:
             "dest": "stopping",
             # "conditions": "iteration_limit_reached",
         },
+        {
+            "trigger": "end",
+            "source": "stopping",
+            "dest": "stopped",
+        },
     ]
     """List of transitions the spider can make."""
@@ -167,14 +177,14 @@ def __init__(
         self.retry_count = 0
 
         # set the loaded configuration to None, as it is not loaded yet
-        self.configuration: Optional[Configuration] = None
+        self.configuration: Optional[Configuration] = configuration
         self.connector: Optional[Connector] = None
         self.strategy: Optional[Strategy] = None
         self._cache_: Optional[orm.Session] = None
         self.appstate: Optional[AppMetaData] = None
 
     @property
-    def task_buffer(self):
+    def task_buffer(self) -> List[TaskList]:
         """Returns the task buffer, which is a list of tasks that are
         currently being processed."""
         return self._cache_.query(TaskList).filter(TaskList.status == "new").all()
@@ -191,8 +201,12 @@ def is_gathering_not_done(self):
 
     def _conditional_advance(self, *args) -> None:
         """Advances the state machine when the current state is done."""
+        # pylint: disable=R0911
         if self.state == "idle":
             return
+        if self.state == "stopping":
+            self.trigger("end")
+            return
         if self.state == "gathering":
             if self.may_sample():
                 log.debug("Advancing from gathering to sampling")
@@ -215,7 +229,9 @@ def _conditional_advance(self, *args) -> None:
 
         targets = self.machine.get_triggers(self.state)
 
-        log.debug(f"Advancing from {self.state} and I can trigger {', '.join(targets)}")
+        log.debug(
+            f"Advancing from {self.state} and I can trigger {', '.join(targets) or 'nothing'}."
+        )
 
         for target in targets:
             if self.trigger(target) is True:
@@ -235,13 +251,7 @@ def load_config(self, config_file: Path) -> None:
             )
 
             with config_file.open("r", encoding="utf8") as file:
-                self.configuration = yaml.full_load(file)
-        else:
-            self.configuration = (
-                self.configuration
-                if isinstance(self.configuration, Configuration)
-                else Configuration(**self.configuration)
-            )
+                self.configuration = from_dict(Configuration, yaml.safe_load(file))
 
     def is_config_valid(self):
         """Asserts that the configuration is valid."""
@@ -273,37 +283,6 @@ def open_database(self, *args) -> None:
 
         self._cache_ = orm.Session(engine)
 
-        _, Node, node_factory = create_node_table(
-            self.configuration.node_table["name"],
-            self.configuration.node_table["columns"],
-        )
-        _, RawEdge, raw_edge_factory = create_raw_edge_table(
-            self.configuration.edge_raw_table["name"],
-            self.configuration.edge_raw_table["columns"],
-        )
-        _, AggEdge, agg_edge_factory = create_aggregated_edge_table(
-            self.configuration.edge_agg_table["name"],
-            self.configuration.edge_agg_table["columns"],
-        )
-
-        class_registry["node"] = Node
-        class_registry["raw_edge"] = RawEdge
-        class_registry["agg_edge"] = AggEdge
-
-        factory_registry["node"] = node_factory
-        factory_registry["raw_edge"] = raw_edge_factory
-        factory_registry["agg_edge"] = agg_edge_factory
-
-        strategy_name = (
-            self.configuration.strategy
-            if isinstance(self.configuration.strategy, str)
-            else list(self.configuration.strategy.keys())[0]
-        )
-
-        _, SamplerState, _ = create_sampler_state_table(  # pylint: disable=W0612
-            strategy_name, get_table_configuration(strategy_name, STRATEGY_GROUP)
-        )
-        class_registry["sampler_state"] = SamplerState
 
         Base.metadata.create_all(engine)
 
         appstate = self._cache_.get(AppMetaData, "1")
@@ -332,24 +311,27 @@ def initialize_seeds(self):
 
         log.debug(f"Copying seeds to database: {', '.join(self.configuration.seeds)}.")
 
         # with self._cache_.begin():
-        for seed in self.configuration.seeds:
-            if self._cache_.get(SeedList, seed) is None:
-                _seed = SeedList(id=seed, iteration=0, status="new")
-                self._cache_.add(_seed)
-                self._add_task(_seed)
-
-    def _add_task(self, task: Union[SeedList, str], parent: Optional[TaskList] = None):
+        for layer, seeds in self.configuration.seeds.items():
+            for seed in seeds:
+                if self._cache_.get(SeedList, seed) is None:
+                    _seed = SeedList(id=seed, iteration=0, status="new")
+                    self._cache_.add(_seed)
+                    self._add_task(_seed, layer=layer)
+
+    def _add_task(
+        self, task: Union[SeedList, str], layer: str, parent: Optional[TaskList] = None
+    ):
         """Adds a task to the task buffer."""
         if not self._cache_:
             raise ValueError("Cache is not present.")
 
-        if not isinstance(task, (SeedList, class_registry["node"], str)):
+        if not isinstance(task, (SeedList, LayerDenseNodes, str)):
             raise ValueError(
-                "Task must be a seed, a node or a node-identifier, but is {type(task).__name__}"
+                f"Task must be a seed, a node or a node-identifier, but is {type(task).__name__}"
             )
         node_id = (
             task.name
-            if isinstance(task, (class_registry["node"]))
+            if isinstance(task, LayerDenseNodes)
             else task if isinstance(task, str) else task.id
         )
@@ -362,7 +344,7 @@ def _add_task(self, task: Union[SeedList, str], parent: Optional[TaskList] = Non
             node_id=node_id,
             status="new",
             initiated_at=datetime.now(),
-            connector="stub_value",
+            connector=layer,
             parent_task_id=parent.id if parent else None,
         )
         self._cache_.add(new_task)
@@ -409,8 +391,8 @@ def retry_with_unused_seeds(self):
 
         candidates = (
             self._cache_.execute(
-                sql.select(class_registry["node"].name).where(
-                    class_registry["node"].name.not_in(candidate_nodes_names)
class_registry["node"].name.not_in(candidate_nodes_names) + sql.select(LayerDenseNodes.name).where( + LayerDenseNodes.name.not_in(candidate_nodes_names) ) ) .scalars() @@ -446,15 +428,15 @@ def gather_node_data(self): if len(self.task_buffer) == 0: return - task: TaskList = self.task_buffer.pop(0) + task = self.task_buffer.pop(0) log.debug(f"Attempting to gather data for {task.node_id}.") # Begin transaction with the cache - node_info = self._cache_.get(class_registry["node"], task.node_id) + node_info = self._cache_.get(LayerDenseNodes, task.node_id) if node_info is None: - self._dispatch_connector_for_node_(task) + self._dispatch_connector_for_node_(task, task.connector) # Mark the node as done seed = self._cache_.get(SeedList, task.node_id) @@ -479,137 +461,135 @@ def iteration_limit_reached(self): def sample_network(self): """Samples the network.""" + # pylint: disable=R0914 if not self._cache_: raise ValueError("Cache is not present.") - - log.debug("Attempting to sample the network.") - - aggregation_spec = self.configuration.edge_agg_table["columns"] - aggregation_funcs = { - "count": sql.func.count, - "max": sql.func.max, - "min": sql.func.min, - "sum": sql.func.sum, - "avg": sql.func.avg, - } - - aggregations = [ - sql.func.count().label( # pylint: disable=E1102 # not-callable, but it is :shrug: - "weight" - ), - *[ - aggregation_funcs[aggregation]( - getattr(class_registry["raw_edge"], column) - ).label(column) - for column, aggregation in aggregation_spec.items() - ], - ] - sql_statement = ( - self._cache_.query( - class_registry["raw_edge"].source, - class_registry["raw_edge"].target, - *aggregations, + for layer_id, layer_config in self.configuration.layers.items(): + # Get data for the layer from the dense data stores + edges = pd.read_sql( + self._cache_.query( + LayerDenseEdges.source, + LayerDenseEdges.target, + sql.func.count("*").label("weight"), # pylint: disable=E1102 + ) + .where(LayerDenseEdges.layer_id == layer_id) + .group_by(LayerDenseEdges.source, LayerDenseEdges.target) + .statement, + self._cache_.connection(), ) - .group_by( - class_registry["raw_edge"].source, class_registry["raw_edge"].target + nodes = pd.json_normalize( + pd.read_sql( + self._cache_.query(LayerDenseNodes.data) + .where(LayerDenseNodes.layer_id == layer_id) + .statement, + self._cache_.connection(), + ).data + ) + sampler_state = pd.json_normalize( + pd.read_sql( + sql.select(SamplerStateStore.data).where( + SamplerStateStore.layer_id == layer_id + ), + self._cache_.connection(), + ).data ) - .statement - ) - - log.debug(f"Aggregation query: {sql_statement}") - - edges = pd.read_sql( - sql_statement, - self._cache_.connection(), - ) - nodes = pd.read_sql( - self._cache_.query(class_registry["node"]).statement, - self._cache_.connection(), - ) - sampler_state = pd.read_sql( - sql.select(class_registry["sampler_state"]), self._cache_.connection() - ) - - log.info( - f"""Sampling from { - len(edges) - } edges with { - len(nodes) - } nodes while the sampler's state is { len(sampler_state) } long.""" - ) - - log.info(f"That's the current state of affairs: { sampler_state }") - new_seeds, new_edges, _, new_sampler_state = self.strategy( - edges, nodes, sampler_state - ) + log.debug( + f""" + Sampling layer {layer_id} with {len(edges)} edges and {len(nodes)} nodes. 
+    Edges to sample:
+{edges}
 
-        new_edges["iteration"] = self.appstate.iteration
+    Sampler state:
+{sampler_state}
+"""
+            )
 
-        if len(new_seeds) == 0:
-            log.warning("Found no new seeds.")
-        elif self.retry_count > 0:
-            self.retry_count = 0
+            sampler: Strategy = get_plugin(layer_config.sampler, STRATEGY_GROUP)
+            new_seeds, sparse_edges, sparse_nodes, new_sampler_state = sampler(
+                edges, nodes, sampler_state
+            )
 
-        for seed in new_seeds:
-            if seed is None:
-                continue
-            if self._cache_.get(SeedList, seed) is None:
-                _seed = SeedList(
-                    id=seed, iteration=self.appstate.iteration + 1, status="new"
+            log.info(f"That's the current state of affairs:\n\n{new_sampler_state}")
+
+            sparse_edges["iteration"] = self.appstate.iteration
+            if len(new_seeds) == 0:
+                log.warning("Found no new seeds.")
+            elif self.retry_count > 0:
+                self.retry_count = 0
+
+            for seed in new_seeds:
+                if seed is None:
+                    continue
+                if self._cache_.get(SeedList, seed) is None:
+                    _seed = SeedList(
+                        id=seed, iteration=self.appstate.iteration + 1, status="new"
+                    )
+                    self._cache_.add(_seed)
+                    self._add_task(_seed, layer=layer_id)
+
+            for edge in sparse_edges.to_dict(orient="records"):
+                if edge["source"] is not None and edge["target"] is not None:
+                    insert_layer_sparse_edge(self._cache_, layer_id, "test", data=edge)
+            for node in sparse_nodes.to_dict(orient="records"):
+                insert_layer_sparse_node(self._cache_, layer_id, "test", node)
+            for state in new_sampler_state.to_dict(orient="records"):
+                insert_sampler_state(
+                    self._cache_, layer_id, self.appstate.iteration, state
                 )
-                self._cache_.add(_seed)
-                self._add_task(_seed)
 
-        self._cache_.add_all(
-            [
-                factory_registry["agg_edge"](edge)
-                for edge in new_edges.to_dict(orient="records")
-                if edge["source"] is not None and edge["target"] is not None
-            ]
-        )
-        for state in new_sampler_state.to_dict(orient="records"):
-            self._cache_.merge(class_registry["sampler_state"](**state))
-        self._cache_.commit()
-
-    def load_plugins(self, *args):
-        """Loads the plug-ins."""
-        self.strategy = get_plugin(self.configuration.strategy, STRATEGY_GROUP)
-        self.connector = get_plugin(self.configuration.connector, CONNECTOR_GROUP)
+            self._cache_.commit()
 
     # section: private methods
 
-    def _dispatch_connector_for_node_(self, node: TaskList):
-        if not self.configuration or not self.connector:
+    def _dispatch_connector_for_node_(self, node: TaskList, layer: str):
+        # pylint: disable=R0914
+        if not self.configuration:
             raise ValueError("Configuration or Connector are not present")
+        # Get the connector for the layer
+        layer_configuration = self.configuration.layers[layer]
+        connector_spec = layer_configuration.connector
+        connector = get_plugin(connector_spec, CONNECTOR_GROUP)
 
-        edges, nodes = self.connector([node.node_id])
+        log.debug(f"Requesting data for {node.node_id} from {connector_spec}.")
 
-        log.debug(f"edges:\n{edges}\n\nnodes:{nodes}\n")
+        raw_edges, nodes = connector([node.node_id])
 
-        if len(edges) > 0:
-            log.info(
-                f"""Persisting {
-                    len(edges)
-                } edges for node {
-                    node.node_id
-                } in iteration
-                #{
-                    self.appstate.iteration
-                }."""
-            )
-            edges["iteration"] = self.appstate.iteration
-            for edge in edges.to_dict(orient="records"):
-                self._cache_.merge(factory_registry["raw_edge"](edge))
-
-            if self.configuration.eager is True and node.parent_task_id is None:
-                # We add only new task if the parent_task_id is None to avoid snowballing
-                # the entire population before we even begin sampling.
-                for target in edges["target"].unique().tolist():
-                    self._add_task(target, parent=node)
-            self._cache_.commit()
+        routers = layer_configuration.routers
+        for router_definition in routers:
+            for router_name, router_spec in router_definition.items():
+
+                log.debug(
+                    f"Routing data with {router_name} and this spec: {router_spec}."
+                )
+
+                router = Router(router_name, router_spec)
+                for raw_edge in raw_edges.to_dict(orient="records"):
+                    edges = router.parse(raw_edge)
+                    for edge in edges:
+                        edge["iteration"] = self.appstate.iteration
+                        insert_layer_dense_edge(
+                            self._cache_, edge.get("dispatch_with"), router_name, edge
+                        )
+
+                        log.debug(f"Inserted edge: {edge}")
+
+                    if (
+                        layer_configuration.eager is True
+                        and node.parent_task_id is None
+                    ):
+                        # We only add new tasks if the parent_task_id is None to avoid
+                        # snowballing the entire population before we even begin sampling.
+                        targets = {edge.get("target") for edge in edges}
+                        for target in targets:
+                            self._add_task(
+                                target,
+                                parent=node,
+                                layer=router_spec.get("dispatch_with"),
+                            )
+                    self._cache_.commit()
 
         if len(nodes) > 0:
             nodes["iteration"] = self.appstate.iteration
             for _node in nodes.to_dict(orient="records"):
-                self._cache_.merge(factory_registry["node"](_node))
+                insert_layer_dense_node(self._cache_, "test", "rest", _node)
             self._cache_.commit()
diff --git a/spiderexpress/strategies/random.py b/spiderexpress/strategies/random.py
index 6268dcd..4376a44 100644
--- a/spiderexpress/strategies/random.py
+++ b/spiderexpress/strategies/random.py
@@ -15,28 +15,32 @@ def random_strategy(
 ):
     """Random sampling strategy."""
     # split the edges table into edges _inside_ and _outside_ of the known network
+    is_first_round = state.empty
+    if is_first_round:
+        state = pd.DataFrame({"node_id": edges.source.unique()})
     mask = edges.target.isin(state.node_id)
-    edges_inward = edges.loc[mask, :]
     edges_outward = edges.loc[~mask, :]
 
     # select 10 edges to follow
     if len(edges_outward) < configuration["n"]:
-        edges_sampled = edges_outward
+        sparse_edges = edges_outward
     else:
-        edges_sampled = edges_outward.sample(n=configuration["n"], replace=False)
+        sparse_edges = edges_outward.sample(n=configuration["n"], replace=False)
 
-    new_seeds = edges_sampled.target  # select target node names as seeds for the
+    new_seeds = (
+        sparse_edges.target.unique()
+    )  # select target node names as seeds for the
     # next layer
-    edges_to_add = pd.concat([edges_inward, edges_sampled])  # add edges inside the
-    # known network as well as the sampled edges to the known network
-    new_nodes = nodes.loc[nodes.name.isin(new_seeds), :]
-
-    return new_seeds, edges_to_add, new_nodes
+    sparse_nodes = nodes.loc[nodes.name.isin(new_seeds), :]
+    new_state = pd.DataFrame({"node_id": new_seeds})
+    if is_first_round:
+        new_state = pd.concat([state, new_state])
+    return new_seeds, sparse_edges, sparse_nodes, new_state
 
 
 random = PlugIn(
     callable=random_strategy,
-    tables={"node_": "Text"},
+    tables={"state": {"node_id": "Text"}},
     metadata={},
     default_configuration={"n": 10},
 )
diff --git a/spiderexpress/strategies/spikyball.py b/spiderexpress/strategies/spikyball.py
index 063f4cd..ad6cfa7 100644
--- a/spiderexpress/strategies/spikyball.py
+++ b/spiderexpress/strategies/spikyball.py
@@ -13,7 +13,7 @@
 import pandas as pd
 from loguru import logger as log
 
-from ..types import PlugIn, fromdict
+from ..types import PlugIn, from_dict
 
 
 @dataclass
@@ -319,23 +319,30 @@ def spikyball_strategy(
     """
     if isinstance(configuration, dict):
-        configuration = fromdict(SpikyBallConfiguration, configuration)
+        configuration = from_dict(SpikyBallConfiguration, configuration)
+    first_round = state.empty
+    if first_round:
+        state = pd.DataFrame({"node_id": edges.source.unique()})
 
     e_in, e_out = filter_edges(edges, state.node_id.tolist())
-    seeds, e_sampled = sample_edges(
+    seeds, sparse_edges = sample_edges(
         e_out,
         nodes,
         configuration.sampler,
         configuration.layer_max_size,
     )
-    state = pd.concat([state, pd.DataFrame({"node_id": seeds})])
+    if first_round:
+        state = pd.concat([state, pd.DataFrame({"node_id": seeds})])
+    else:
+        state = pd.DataFrame({"node_id": seeds})
+    sparse_nodes = nodes.loc[nodes.name.isin(seeds), :]
 
-    return seeds, pd.concat([e_in, e_sampled]), e_out, state
+    return seeds, sparse_edges, sparse_nodes, state
 
 
 spikyball = PlugIn(
     callable=spikyball_strategy,
     default_configuration={},
-    tables={"node_id": "Text"},
+    tables={"state": {"node_id": "Text"}},
     metadata={},
 )
diff --git a/spiderexpress/types.py b/spiderexpress/types.py
index c208f47..ee088e8 100644
--- a/spiderexpress/types.py
+++ b/spiderexpress/types.py
@@ -1,19 +1,17 @@
 # pylint: disable=R
-"""Type definitions for spiderexpress
+"""Type definitions for spiderexpress' DSL.
 
 Philipp Kessling
 Leibniz-Institute for Media Research, 2022
-
 """
-
-
-from dataclasses import dataclass, fields, is_dataclass
+import json
+from dataclasses import field, fields, is_dataclass
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
 
 import pandas as pd
-import yaml
+from pydantic.dataclasses import dataclass
 
 Connector = Callable[[List[str]], Tuple[pd.DataFrame, pd.DataFrame]]
 """Connector Interface
@@ -29,7 +27,19 @@
     [pd.DataFrame, pd.DataFrame, pd.DataFrame],
     Tuple[List[str], pd.DataFrame, pd.DataFrame, pd.DataFrame],
 ]
-"""Strategy Interface"""
+"""Strategy Interface.
+
+Args:
+    edges (pd.DataFrame): the edge table
+    nodes (pd.DataFrame): the node table
+    state (pd.DataFrame): the last state of the sampler
+
+Returns:
+    (List[str]): A list of new seed nodes
+    (pd.DataFrame): Update for the sparse edge table
+    (pd.DataFrame): Update for the sparse node table
+    (pd.DataFrame): Update for the state table
+"""
 
 PlugInSpec = Union[str, Dict[str, Union[str, Dict[str, Union[str, int]]]]]
 """Plug-In Definition Notation.
@@ -38,48 +48,78 @@
 
 ColumnSpec = Dict[str, Union[str, Dict[str, str]]]
 """Column Specification."""
 
+StopCondition = Literal["stop", "retry"]
+
+
+@dataclass()
+class Configuration:
+    """Configuration-File Wrapper.
+
+    Attributes:
+        seeds (Dict[str, List[str]], optional): A dictionary of seeds per layer.
+            Defaults to None.
+        seed_file (str, optional): A JSON-file containing seeds per layer.
+            Used when `seeds` is not provided. Defaults to None.
+        project_name (str, optional): The name of the project. Defaults to "spider".
+        db_url (str, optional): The database url. Defaults to None.
+        db_schema (str, optional): The database schema. Defaults to None.
+        empty_seeds (StopCondition, optional): What to do if the seeds are empty.
+            Defaults to "stop".
+        layers (Dict[str, Layer], optional): A mapping of layer names to layer
+            configurations. Defaults to an empty dict.
+        max_iteration (int, optional): The maximum number of iterations. Defaults to 10000.
+    """
+
+    seeds: Optional[Dict[str, List[str]]] = None
+    seed_file: Optional[str] = None
+    project_name: str = "spider"
+    db_url: Optional[str] = None
+    db_schema: Optional[str] = None
+    empty_seeds: StopCondition = "stop"
+    layers: Dict[str, "Layer"] = field(default_factory=dict)
+    max_iteration: int = 10000
+
+    def __post_init__(self) -> None:
+        """Validates and completes the configuration after initialization."""
+        if self.seeds is None:
+            if self.seed_file is None:
+                raise ValueError("Either seeds or seed_file must be provided.")
+            _seed_file = Path(self.seed_file)
+            if not _seed_file.exists():
+                raise FileNotFoundError(f"Seed file {_seed_file.resolve()} not found.")
+            with _seed_file.open("r", encoding="utf8") as file:
+                self.seeds = json.load(file)
+        self.db_url = self.db_url or f"sqlite:///{self.project_name}.db"
+        if self.db_url.startswith("sqlite") and self.db_schema is not None:
+            raise ValueError("SQLite does not support schemas.")
+        self.empty_seeds = (
+            self.empty_seeds if self.empty_seeds in ["stop", "retry"] else "stop"
+        )
+
+
+@dataclass
+class FieldSpec:
+    """Field Specification"""
+
+    field: str
+    dispatch_with: str
+    regex: Optional[str] = None
+
+
+@dataclass
+class RouterSpec:
+    """Router Configuration"""
+
+    source: str
+    target: List[FieldSpec]
+
+
+@dataclass
+class Layer:
+    """Layer Configuration"""
 
-class Configuration(yaml.YAMLObject):
-    """Configuration-File Wrapper"""
-
-    yaml_tag = "!spiderexpress:Configuration"
-
-    def __init__(
-        self,
-        seeds: Optional[List[str]] = None,
-        seed_file: Optional[str] = None,
-        project_name: str = "spider",
-        db_url: Optional[str] = None,
-        db_schema: Optional[str] = None,
-        empty_seeds: str = "stop",
-        eager: bool = True,
-        edge_raw_table: Optional[ColumnSpec] = None,
-        edge_agg_table: Optional[ColumnSpec] = None,
-        node_table: Optional[ColumnSpec] = None,
-        strategy: PlugInSpec = "spikyball",
-        connector: PlugInSpec = "telegram",
-        max_iteration: int = 10000,
-        batch_size: int = 150,
-    ) -> None:
-        if seed_file is not None:
-            _seed_file = Path(seed_file)
-            if _seed_file.exists():
-                with _seed_file.open("r", encoding="utf8") as file:
-                    self.seeds = list(file.readlines())
-        else:
-            self.seeds = seeds
-        self.strategy = strategy
-        self.connector = connector
-        self.project_name = project_name
-        self.db_url = db_url or f"sqlite:///{project_name}.db"
-        self.db_schema = db_schema
-        self.edge_raw_table = edge_raw_table or {"name": "edge_raw", "columns": {}}
-        self.edge_agg_table = edge_agg_table or {"name": "edge_agg", "columns": {}}
-        self.node_table = node_table or {"name": "node", "columns": {}}
-        self.max_iteration = max_iteration
-        self.batch_size = batch_size
-        self.empty_seeds = empty_seeds if empty_seeds in ["stop", "retry"] else "stop"
-        self.eager = eager
+    connector: Dict
+    routers: List[Dict[str, Dict]]
+    sampler: Dict
+    eager: bool = False
 
 
 @dataclass
@@ -93,7 +133,7 @@ class ConfigurationItem:
 
 T = TypeVar("T")
 
-def fromdict(cls: Type[T], dictionary: dict) -> T:
+def from_dict(cls: Type[T], dictionary: dict) -> T:
     """convert a dictionary to a dataclass
 
     warning:
@@ -107,12 +147,12 @@ def fromdict(cls: Type[T], dictionary: dict) -> T:
     returns:
         the dataclass with values from the dictionary
     """
-    fieldtypes: Dict[str, Type] = {f.name: f.type for f in fields(cls)}
+    field_types = {f.name: f.type for f in fields(cls)}
     return cls(
         **{
             key: (
-                fromdict(fieldtypes[key], value)
-                if isinstance(value, dict) and is_dataclass(fieldtypes[key])
+                from_dict(field_types[key], value)
+                if isinstance(value, dict) and is_dataclass(field_types[key])
                 else value
             )
             for key, value in dictionary.items()
diff --git a/tests/stubs/seeds.json b/tests/stubs/seeds.json
new file mode 100644
index 0000000..947e29a
--- /dev/null
+++ b/tests/stubs/seeds.json
@@ -0,0 +1 @@
+{"test": ["1", "2", "3"]}
diff --git a/tests/stubs/sevens_grader_random_test.pe.yml b/tests/stubs/sevens_grader_random_test.pe.yml
index 9b4b697..1cc726d 100644
--- a/tests/stubs/sevens_grader_random_test.pe.yml
+++ b/tests/stubs/sevens_grader_random_test.pe.yml
@@ -1,31 +1,26 @@
-!spiderexpress:Configuration
-batch_size: 150
-connector:
-  csv:
-    node_list_location: tests/stubs/7th_graders/nodes.csv
-    edge_list_location: tests/stubs/7th_graders/edges.csv
-    mode: out
-db_url: sqlite:///
+db_url: sqlite:/// # sevens_grader_random_test.db
 db_schema:
-eager: false
 empty_seeds: stop
-edge_table_name: edge_list
 max_iteration: 10000
-node_table:
-  name: sevens_grader_nodes
-  columns: {}
-edge_raw_table:
-  name: sevens_grader_edge_raw
-  columns:
-    layer: Integer
-edge_agg_table:
-  name: sevens_grader_edge_agg
-  columns:
-    layer: Integer
+layers:
+  test:
+    eager: false
+    connector:
+      csv:
+        node_list_location: tests/stubs/7th_graders/nodes.csv
+        edge_list_location: tests/stubs/7th_graders/edges.csv
+        mode: out
+    routers:
+      - all: # This is the name of the router and should be the type of edge.
+          source: source # This is the field that is mapped to the source columns.
+          target:
+            - field: target # This is the field that is mapped to the target columns.
+              dispatch_with: test # This is the name of the layer to dispatch to.
+    sampler:
+      random:
+        n: 5
 project_name: spider
 seeds:
-  - "1"
-  - "13"
-strategy:
-  random:
-    n: 5
+  test:
+    - "1"
+    - "13"
diff --git a/tests/stubs/sevens_grader_spikyball_test.pe.yml b/tests/stubs/sevens_grader_spikyball_test.pe.yml
index c1628be..1d281d2 100644
--- a/tests/stubs/sevens_grader_spikyball_test.pe.yml
+++ b/tests/stubs/sevens_grader_spikyball_test.pe.yml
@@ -1,41 +1,37 @@
-!spiderexpress:Configuration
-batch_size: 150
-connector:
-  csv:
-    node_list_location: tests/stubs/7th_graders/nodes.csv
-    edge_list_location: tests/stubs/7th_graders/edges.csv
-    mode: out
-db_url: sqlite:///
+db_url: sqlite:/// # sevens_grader_spikyball_test.db
 db_schema:
-eager: false
 empty_seeds: stop
-edge_table_name: edge_list
 max_iteration: 10000
-node_table:
-  name: sevens_grader_spkyball_nodes
-  columns: {}
-edge_raw_table:
-  name: sevens_grader_ed_spkyballge_raw
-  columns:
-    layer: Integer
-edge_agg_table:
-  name: sevens_grader_ed_spkyballge_agg
-  columns:
-    layer: max
+layers:
+  test:
+    eager: false
+    connector:
+      csv:
+        node_list_location: tests/stubs/7th_graders/nodes.csv
+        edge_list_location: tests/stubs/7th_graders/edges.csv
+        mode: out
+    routers:
+      - all: # This is the name of the router and should be the type of edge.
+          source: source # This is the field that is mapped to the source columns.
+          target:
+            - field: target # This is the field that is mapped to the target columns.
+              dispatch_with: test # This is the name of the layer to dispatch to.
+    sampler:
+      spikyball:
+        layer_max_size: 5
+        sampler:
+          source_node_probability:
+            coefficient: 1
+            weights: { }
+          target_node_probability:
+            coefficient: 1
+            weights: { }
+          edge_probability:
+            coefficient: 1
+            weights: { }
+
 project_name: spider
 seeds:
-  - "1"
-  - "13"
-strategy:
-  spikyball:
-    layer_max_size: 5
-    sampler:
-      source_node_probability:
-        coefficient: 1
-        weights: {}
-      target_node_probability:
-        coefficient: 1
-        weights: {}
-      edge_probability:
-        coefficient: 1
-        weights: {}
+  test:
+    - "1"
+    - "13"
diff --git a/tests/test_config.py b/tests/test_config.py
index 0184da0..3d307b4 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,57 +1,82 @@
 """test suite for spiderexpress.Configuration"""
-from pathlib import Path
+
+from typing import Dict
 
 import pytest
 import yaml
-from pytest import skip
 
 from spiderexpress import Configuration
+from spiderexpress.types import from_dict
 
 # pylint: disable=W0621
 
 
-@pytest.fixture
-def configuration():
-    """Creates a configuration object."""
-    return Configuration(
-        ["a", "b"],
-        None,
-        "test",
-        "memory",
-    )
-
-
-def test_initializer(configuration):
-    """Should instantiate a configuration object."""
-
-    assert configuration.project_name == "test"
-    assert configuration.db_url == "memory"
-    assert configuration.seeds == ["a", "b"]
-
-
-def test_fields():
-    """should do something"""
-    skip()
-
-
-def test_serialization(configuration, tmpdir: Path):
-    """Should write out and re-read the configuration."""
-    temp_conf = tmpdir / "test.pe.yml"
-
-    with temp_conf.open("w") as file:
-        yaml.dump(configuration, file)
-
-    assert temp_conf.exists()
-
-    with temp_conf.open("r") as file:
-        configuration_2 = yaml.full_load(file)
-
-    for key, value in configuration.__dict__.items():
-        assert value == configuration_2.__dict__[key]
-
-
-def test_seed_seedfile():
-    """
-    It should have either a seed file or a seed list, should throw otherwise.
-    """
-    skip()
+def test_parse_configuration_from_file():
+    """Should parse a configuration from a YAML file."""
+    with open(
+        "tests/stubs/sevens_grader_random_test.pe.yml", "r", encoding="utf8"
+    ) as file:
+        config = yaml.safe_load(file)
+        assert from_dict(Configuration, config) is not None
+
+
+def test_parse_configuration_from_dict():
+    """Should parse a configuration from a dictionary."""
+    config = from_dict(Configuration, {"project_name": "test", "seeds": {}})
+    assert config is not None
+    assert config.project_name == "test"
+    assert config.db_url == "sqlite:///test.db"
+    assert config.max_iteration == 10000
+    assert config.empty_seeds == "stop"
+    assert config.layers == {}
+    assert config.seeds == {}
+
+
+def test_fail_to_open_seeds_from_file():
+    """Should fail to parse a configuration from a file."""
+    with pytest.raises(FileNotFoundError):
+        from_dict(Configuration, {"seed_file": "non_existent_file"})
+
+
+def test_open_seeds_from_file():
+    """Should parse a configuration from a file with a seed file."""
+    config = from_dict(Configuration, {"seed_file": "tests/stubs/seeds.json"})
+    assert config is not None
+    assert config.seeds == {"test": ["1", "2", "3"]}
+
+
+@pytest.mark.parametrize(
+    ["configuration"],
+    [
+        pytest.param(
+            {
+                "layers": {
+                    "test": {
+                        "connector": {"csv": {"n": 1}},
+                        "routers": [],
+                        "sampler": {"random": {"n": 1}},
+                    }
+                },
+                "seeds": {"test": ["1", "13"]},
+            },
+            id="empty_router",
+        ),
+    ],
+)
+def test_parse_layer_configuration(configuration: Dict):
+    """Should parse a layer configuration."""
+    config = from_dict(Configuration, configuration)
+    assert config is not None
+    assert config.project_name == "spider"
+    assert config.db_url == "sqlite:///spider.db"
+    assert config.max_iteration == 10000
+    assert config.empty_seeds == "stop"
+    assert config.seeds == {"test": ["1", "13"]}
+
+
+def test_layers_must_have_a_router_configuration():
+    """Should fail to parse a configuration without a router configuration."""
+    with pytest.raises(ValueError):
+        from_dict(
+            Configuration, {"layers": {"test": {}}, "seeds": {"test": ["1", "13"]}}
+        )
diff --git a/tests/test_model.py b/tests/test_model.py
index 14a5132..02a9f30 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,18 +1,12 @@
 """Test suite for the model module."""
+
 from datetime import datetime
 
 import pytest
 import sqlalchemy as sql
 from sqlalchemy.orm import Session
 
-from spiderexpress.model import (
-    AppMetaData,
-    Base,
-    SeedList,
-    create_aggregated_edge_table,
-    create_node_table,
-    create_raw_edge_table,
-)
+from spiderexpress.model import AppMetaData, Base, SeedList
 
 # pylint: disable=W0621
@@ -44,87 +38,6 @@ def create_tables(connection):
     Base.metadata.drop_all(connection)
 
 
-def test_create_raw_edge_table_with_session(session, create_tables):
-    """Test the creation of a raw edge table."""
-    _, RawEdge, edge_factory = create_raw_edge_table("raw_edges", {"weight": "Integer"})
-
-    create_tables()
-
-    edge = edge_factory({"source": "a", "target": "b", "weight": 1, "view_count": 1})
-    session.add(edge)
-    session.commit()
-
-    assert session.query(RawEdge).count() == 1
-
-
-def test_create_raw_edge_table():
-    """Test the creation of a raw edge table."""
-    table, RawEdge, _ = create_raw_edge_table("raw_edges_2", {"weight": "Integer"})
-
-    assert table.name == "raw_edges_2"
-    assert len(table.columns) == 5
-
-    assert hasattr(RawEdge, "source")
-    assert hasattr(RawEdge, "target")
-    assert hasattr(RawEdge, "weight")
-
-
-def test_create_aggregated_edge_table_with_session(session, create_tables):
-    """Test the creation of an aggregated edge table."""
-    _, Edge, edge_factory = create_aggregated_edge_table(
-        "agg_edges", {"view_count": "Integer"}
-    )
-
-    create_tables()
-
-    edge = edge_factory(
-        {"source": "a", "target": "b", "weight": 1, "view_count": 1, "iteration": 0}
-    )
-    session.add(edge)
-    session.commit()
-
-    assert session.query(Edge).count() == 1
-
-
-def test_create_aggregated_edge_table():
-    """Test the creation of an aggregated edge table."""
-    table, Edge, _ = create_aggregated_edge_table(
-        "agg_edges_2", {"view_count": "Integer"}
-    )
-
-    assert table.name == "agg_edges_2"
-    assert len(table.columns) == 5
-
-    assert hasattr(Edge, "source")
-    assert hasattr(Edge, "target")
-    assert hasattr(Edge, "view_count")
-
-
-def test_create_node_table_session(session, create_tables):
-    """Test the creation of a node table."""
-
-    _, Node, node_factory = create_node_table("nodes", {"subscriber_count": "Integer"})
-
-    create_tables()
-
-    node = node_factory({"name": "a", "subscriber_count": 1})
-    session.add(node)
-    session.commit()
-
-    assert session.query(Node).count() == 1
-
-
-def test_create_node_table_with():
-    """Test the creation of a node table."""
-    table, Node, _ = create_node_table("nodes2", {"subscriber_count": "Integer"})
-
-    assert table.name == "nodes2"
-    assert len(table.columns) == 3
-
-    assert hasattr(Node, "name")
-    assert hasattr(Node, "subscriber_count")
-
-
 def test_app_state_table(session, create_tables):
     """Test the creation of a node table."""
diff --git a/tests/test_router.py b/tests/test_router.py
index 36bb51f..54a8d2d 100644
--- a/tests/test_router.py
+++ b/tests/test_router.py
@@ -1,4 +1,5 @@
 """Test suite for spiderexpress' multi-layer router."""
+
 from typing import Any, Dict, List, Optional, Union
 
 import pytest
@@ -9,20 +10,25 @@
 @pytest.mark.parametrize(
     ["specification", "context"],
     [
-        pytest.param({"from": "column", "to": "wrong_value"}, None, id="to_is_string"),
-        pytest.param({"from": "here", "ot": "bad"}, None, id="to_is_missing"),
         pytest.param(
+            {"source": "column", "target": "wrong_value"}, None, id="target_is_string"
+        ),
+        pytest.param({"source": "here", "ot": "bad"}, None, id="target_is_missing"),
+        pytest.param(
-            {"from": "here", "to": [{"field": "yadad"}]},
+            {"source": "here", "target": [{"field": "yadad"}]},
             {"connectors": {"test": {"columns": {"there": "Text"}}}},
             id="from_column_is_missing_in_context",
         ),
         pytest.param(
-            {"from": "column", "to": [{"field": "column", "dispatch_with": "layer_1"}]},
+            {
+                "source": "column",
+                "target": [{"field": "column", "dispatch_with": "layer_1"}],
+            },
             {"connectors": {"layer_1": {"type": "something", "columns": {"column1"}}}},
             id="column_is_missing_in_context",
         ),
         pytest.param(
-            {"from": "column", "to": [{"dispatch_with": "layer_1"}]},
+            {"source": "column", "target": [{"dispatch_with": "layer_1"}]},
             {"connectors": {"layer_1": {"type": "something", "columns": {"column1"}}}},
             id="field is None",
         ),
@@ -59,14 +65,14 @@ def test_router_spec_validation(
     pytest.param(
         input_data_1,
         {
-            "from": "handle",
-            "to": [{"field": "forwarded_handle", "dispatch_with": "test"}],
+            "source": "handle",
+            "target": [{"field": "forwarded_handle", "dispatch_with": "test"}],
             "view_count": "view_count",
         },
        [
            {
-                "from": "Tony",
-                "to": "Bert",
+                "source": "Tony",
+                "target": "Bert",
                 "view_count": 123,
                 "dispatch_with": "test",
             }
@@ -76,8 +82,8 @@
     pytest.param(
         input_data_1,
         {
-            "from": "handle",
-            "to": [
+            "source": "handle",
+            "target": [
             {
                 "field": "url",
                 "pattern": r"https://www\.twitter\.com/(\w+)",
r"https://www\.twitter\.com/(\w+)", @@ -88,8 +94,8 @@ def test_router_spec_validation( }, [ { - "from": "Tony", - "to": "ernie", + "source": "Tony", + "target": "ernie", "view_count": 123, "dispatch_with": "test", } @@ -99,8 +105,8 @@ def test_router_spec_validation( pytest.param( input_data_1, { - "from": "handle", - "to": [ + "source": "handle", + "target": [ { "field": "url", "pattern": r"https://www\.twitter\.com/(\w+)", @@ -112,8 +118,8 @@ def test_router_spec_validation( }, [ { - "from": "Tony", - "to": "ernie", + "source": "Tony", + "target": "ernie", "view_count": 123, "type": "twitter-url", "dispatch_with": "test", @@ -124,8 +130,8 @@ def test_router_spec_validation( pytest.param( input_data_1, { - "from": "handle", - "to": [ + "source": "handle", + "target": [ { "field": "text", "pattern": r"https://www\.twitter\.com/(\w+)", @@ -137,15 +143,15 @@ def test_router_spec_validation( }, [ { - "from": "Tony", - "to": "ernie", + "source": "Tony", + "target": "ernie", "view_count": 123, "dispatch_with": "test", "type": "twitter-url", }, { - "from": "Tony", - "to": "bobafett", + "source": "Tony", + "target": "bobafett", "view_count": 123, "dispatch_with": "test", "type": "twitter-url", @@ -156,8 +162,8 @@ def test_router_spec_validation( pytest.param( input_data_1, { - "from": "handle", - "to": [ + "source": "handle", + "target": [ { "field": "text", "pattern": r"https://www\.twitter\.com/(\w+)", @@ -169,21 +175,21 @@ def test_router_spec_validation( }, [ { - "from": "Tony", - "to": "ernie", + "source": "Tony", + "target": "ernie", "view_count": 123, "dispatch_with": "test", "type": "twitter-url", }, { - "from": "Tony", - "to": "bobafett", + "source": "Tony", + "target": "bobafett", "view_count": 123, "dispatch_with": "test", "type": "twitter-url", }, ], - id="multiple values from regex with directive constant", + id="multiple values source regex with directive constant", ), ], ) diff --git a/tests/test_spider.py b/tests/test_spider.py index 98389b6..8fce165 100644 --- a/tests/test_spider.py +++ b/tests/test_spider.py @@ -5,6 +5,7 @@ It is loaded by click as soon as the application starts and the configuration is loaded automatically by the initializer. """ + # pylint: disable=E1101 from pathlib import Path @@ -28,18 +29,15 @@ def test_load_config(): def test_spider(): """Should instantiate a spider.""" - spider = Spider(auto_transitions=False) + spider = Spider(auto_transitions=True) assert spider is not None assert spider.is_idle() spider.start(Path("tests/stubs/sevens_grader_random_test.pe.yml")) - assert spider.is_starting() assert spider.configuration is not None - - spider.gather() - assert spider._cache_ is not None # pylint: disable=W0212 + assert spider.is_stopped() def test_spider_with_spikyball(): @@ -51,7 +49,7 @@ def test_spider_with_spikyball(): spider.start(Path("tests/stubs/sevens_grader_spikyball_test.pe.yml")) - assert spider.is_stopping() + assert spider.is_stopped() assert spider.configuration is not None diff --git a/tests/test_utils.py b/tests/test_utils.py index d2f8eb9..e996df3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,7 @@ import pytest -from spiderexpress.types import fromdict +from spiderexpress.types import from_dict from tests.conftest import MyFunkyTestClass, MyOtherFunkyTestClass @@ -31,6 +31,6 @@ ) def test_fromdict(value: object): """test fromdict()""" - ans = fromdict(type(value), asdict(value)) + ans = from_dict(type(value), asdict(value)) print(ans) assert ans == value