diff --git a/pyproject.toml b/pyproject.toml
index 3ad22b6..44a49a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,7 @@ click = "*"
psycopg2-binary = "*"
PyYAML = "*"
transitions = "*"
+pydantic = "^2.9.2"
[tool.poetry.dev-dependencies]
ipykernel = "*"
diff --git a/spiderexpress/cli.py b/spiderexpress/cli.py
index 37a5e98..15c03b5 100644
--- a/spiderexpress/cli.py
+++ b/spiderexpress/cli.py
@@ -11,13 +11,14 @@
- Refine verbs/commands for the CLI
- find a mechanism for stopping/starting collections
"""
+
+import sys
from importlib.metadata import entry_points
from pathlib import Path
-from loguru import logger as log
import click
import yaml
-import sys
+from loguru import logger as log
from .spider import CONNECTOR_GROUP, STRATEGY_GROUP, Spider
from .types import Configuration
@@ -34,22 +35,21 @@ def cli(ctx):
@cli.command()
@click.argument("config", type=click.Path(path_type=Path, exists=True))
@click.option("-v", "--verbose", count=True)
-@click.option("-l", "--logfile", type=click.Path(dir_okay=False, writable=True, path_type=str))
+@click.option(
+ "-l", "--logfile", type=click.Path(dir_okay=False, writable=True, path_type=str)
+)
@click.pass_context
def start(ctx: click.Context, config: Path, verbose: int, logfile: str):
"""start a job"""
- logging_level = max(50 - (10 * verbose), 0) # Allows logging level to be between 0 and 50.
+ logging_level = max(
+ 50 - (10 * verbose), 0
+ ) # Allows logging level to be between 0 and 50.
logging_configuration = {
- "handlers": [
- {
- "sink": logfile or sys.stdout,
- "level": logging_level
- }
- ],
- "extra": {}
+ "handlers": [{"sink": logfile or sys.stdout, "level": logging_level}],
+ "extra": {},
}
log.configure(**logging_configuration)
- log.debug(f"Starting logging with verbosity {logging_level}.")
+ log.debug(f"Starting logging with verbosity {logging_level}.")
ctx.obj.start(config)
@@ -62,7 +62,13 @@ def create(config: str, interactive: bool):
if interactive:
for key, description in [
- ("seeds", "add seeds?"),
+ ("project_name", "Name of your project?"),
+ ("db_url", "URL of your database?"),
+ ("max_iteration", "How many iterations should be done?"),
+ (
+ "empty_seeds",
+ "What should happen if seeds are empty? Can be 'stop' or 'retry'",
+ ),
("seed_file", "do you wish to read a file for seeds?"),
]:
args[key] = click.prompt(description)
@@ -76,7 +82,7 @@ def create(config: str, interactive: bool):
@cli.command()
def list(): # pylint: disable=W0622
"""list all plugins"""
- click.echo("--- connectors ---", color="blue")
+ click.echo("--- connectors ---")
for connector in entry_points(group=CONNECTOR_GROUP):
click.echo(connector.name)
click.echo("--- strategies ---")
diff --git a/spiderexpress/connectors/csv.py b/spiderexpress/connectors/csv.py
index 9a5e508..a874941 100644
--- a/spiderexpress/connectors/csv.py
+++ b/spiderexpress/connectors/csv.py
@@ -1,13 +1,15 @@
"""A CSV-reading, network-rippin' connector for your testing purposes."""
+
import dataclasses
from typing import Dict, List, Optional, Tuple, Union
import pandas as pd
-from spiderexpress.types import PlugIn, fromdict
+from spiderexpress.types import PlugIn, from_dict
_cache = {}
+
@dataclasses.dataclass
class CSVConnectorConfiguration:
"""Configuration items for the csv_connector."""
@@ -23,7 +25,7 @@ def csv_connector(
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""The CSV connector!"""
if isinstance(configuration, dict):
- configuration = fromdict(CSVConnectorConfiguration, configuration)
+ configuration = from_dict(CSVConnectorConfiguration, configuration)
if configuration.cache:
if configuration.edge_list_location not in _cache:
@@ -61,9 +63,11 @@ def csv_connector(
return (
edge_return,
- nodes.loc[nodes.name.isin(node_ids), :]
- if nodes is not None
- else pd.DataFrame(),
+ (
+ nodes.loc[nodes.name.isin(node_ids), :]
+ if nodes is not None
+ else pd.DataFrame()
+ ),
)
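
For reference, the connector still accepts a plain dict for its configuration and
converts it through the renamed from_dict helper. A minimal sketch of a direct
call, using the 7th-graders test stubs; passing "cache" explicitly is an
assumption so that no default value is required on CSVConnectorConfiguration:

    from spiderexpress.connectors.csv import csv_connector

    # Request the neighbourhood of node "1" from the stub network.
    edges, nodes = csv_connector(
        ["1"],
        configuration={
            "edge_list_location": "tests/stubs/7th_graders/edges.csv",
            "node_list_location": "tests/stubs/7th_graders/nodes.csv",
            "mode": "out",
            "cache": False,
        },
    )
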
diff --git a/spiderexpress/model.py b/spiderexpress/model.py
index cdea87a..f383bec 100644
--- a/spiderexpress/model.py
+++ b/spiderexpress/model.py
@@ -4,13 +4,13 @@
"""
import datetime
-from typing import Any, Callable, Dict, List, Tuple, Type
+from typing import Dict
import sqlalchemy as sql
from loguru import logger as log
-from sqlalchemy import orm
+from sqlalchemy import JSON, orm
-# pylint: disable=R0903
+# pylint: disable=R0903, W0622
mapper_registry = orm.registry()
@@ -24,6 +24,8 @@
class Base(orm.DeclarativeBase):
"""Base class for all models."""
+ type_annotation_map = {Dict: JSON}
+
def __repr__(self):
props = [
f"{key}={value}"
@@ -68,177 +70,174 @@ class TaskList(Base):
finished_at: orm.Mapped[datetime.datetime] = orm.mapped_column(nullable=True)
-def create_factory(
- cls: Type[Any], spec_fixed: List[sql.Column], spec_variadic: Dict[str, Any]
-) -> Callable:
- """Create a factory function for a given class."""
+class RawDataStore(Base):
+ """Table for raw data storage.
- log.info(
- f"Creating factory for {cls.__name__} with {spec_fixed} and {spec_variadic}"
- )
+ Attributes:
+ id: Primary key for the table.
+ connector_id: Identifier for the connector.
+ output_type: Type of the output data.
+ created_at: Timestamp when the data was created.
+ data: The raw data stored in JSON format.
+ """
- def _(data: Dict[str, Any]) -> Type[Any]:
- return cls(
- **{
- key: data.get(key)
- for key in [column.name for column in spec_fixed]
- + list(spec_variadic.keys())
- }
- )
+ __tablename__ = "raw_data_store"
- return _
+ id: orm.Mapped[str] = orm.mapped_column(primary_key=True)
+ connector_id: orm.Mapped[str] = orm.mapped_column(index=True)
+ output_type: orm.Mapped[str] = orm.mapped_column(index=True)
+ created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+ index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
+ )
+ data: orm.Mapped[Dict] = orm.mapped_column(insert_default={})
-def create_raw_edge_table(
- name: str, spec_variadic: Dict[str, str]
-) -> Tuple[sql.Table, Type["RawEdge"], Callable]:
- """Create an edge table dynamically.
+class LayerDenseEdges(Base):
+ """Table for dense data storage."""
- parameters:
- name: name of the table
- spec_variadic: dict of variadic columns
+ __tablename__ = "layer_dense_edges"
- returns:
- table: the table
- """
- spec_fixed = [
- sql.Column("id", sql.Integer, primary_key=True, index=True, autoincrement=True),
- sql.Column("source", sql.Text, index=True, unique=False),
- sql.Column("target", sql.Text, index=True, unique=False),
- sql.Column("iteration", sql.Integer, index=True, unique=False),
- ]
-
- table = sql.Table(
- name,
- Base.metadata,
- *spec_fixed,
- *[
- sql.Column(key, type_lookup.get(value))
- for key, value in spec_variadic.items()
- ],
+ id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
+ source: orm.Mapped[str] = orm.mapped_column(index=True)
+ target: orm.Mapped[str] = orm.mapped_column(index=True)
+ edge_type: orm.Mapped[str] = orm.mapped_column(index=True)
+ layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+ created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+ index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
)
+ data: orm.Mapped[Dict] = orm.mapped_column(insert_default={})
+
+
+def insert_layer_dense_edge(session, layer_id, edge_type, data):
+ """Insert a dense edge into the database."""
+ source = data.get("source")
+ target = data.get("target")
+ id = f"{layer_id}:{source}-{target}"
+
+ dense_edge = LayerDenseEdges(
+ id=id,
+ source=source,
+ target=target,
+ edge_type=edge_type,
+ layer_id=layer_id,
+ data=data,
+ )
+ session.merge(dense_edge)
+ session.commit()
+ log.debug(f"Inserted dense edge from {source} to {target} in layer {layer_id}")
- class RawEdge:
- """Unaggregated, raw edge."""
-
- def __repr__(self):
- return f""""""
- mapper_registry.map_imperatively(RawEdge, table)
+class LayerDenseNodes(Base):
+ """Table for dense data storage."""
- return table, RawEdge, create_factory(RawEdge, spec_fixed, spec_variadic)
+ __tablename__ = "layer_dense_nodes"
+ id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
+ name: orm.Mapped[str] = orm.mapped_column()
+ layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+ node_type: orm.Mapped[str] = orm.mapped_column(index=True)
+ created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+ index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
+ )
+ data: orm.Mapped[Dict] = orm.mapped_column()
-def create_aggregated_edge_table(
- name: str, spec_variadic: Dict[str, str]
-) -> Tuple[sql.Table, Type["AggEdge"], Callable]:
- """Create an aggregated edge table dynamically.
- parameters:
- name: name of the table
- spec_variadic: dict of variadic columns
+def insert_layer_dense_node(session, layer_id, node_type, data):
+ """Insert a dense node into the database."""
+ name = data.get("name")
+ id = f"{layer_id}:{name}"
- returns:
- table: the table
- """
- spec_fixed = [
- sql.Column("source", sql.Text, primary_key=True, index=True),
- sql.Column("target", sql.Text, primary_key=True, index=True),
- sql.Column(
- "iteration", sql.Integer, primary_key=True, index=True, unique=False
- ),
- sql.Column("weight", sql.Integer),
- ]
-
- table = sql.Table(
- name,
- Base.metadata,
- *spec_fixed,
- *[sql.Column(key, sql.Integer) for key, value in spec_variadic.items()],
+ dense_node = LayerDenseNodes(
+ id=id, name=name, layer_id=layer_id, node_type=node_type, data=data
)
+ session.merge(dense_node)
+ session.commit()
+ log.debug(f"Inserted dense node {name} in layer {layer_id}")
- class AggEdge:
- """Aggregated edge."""
- def __repr__(self):
- return f""""""
+class LayerSparseEdges(Base):
+ """Table for sparse data storage."""
- mapper_registry.map_imperatively(AggEdge, table)
+ __tablename__ = "layer_sparse_store"
- return table, AggEdge, create_factory(AggEdge, spec_fixed, spec_variadic)
+ id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
+ layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+ source: orm.Mapped[str] = orm.mapped_column(index=True)
+ target: orm.Mapped[str] = orm.mapped_column(index=True)
+ edge_type: orm.Mapped[str] = orm.mapped_column(index=True)
+ weight: orm.Mapped[float] = orm.mapped_column()
+ created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+ index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
+ )
+ data: orm.Mapped[Dict] = orm.mapped_column()
+
+
+def insert_layer_sparse_edge(session, layer_id, edge_type, data):
+ """Insert a sparse edge into the database."""
+ source = data.get("source")
+ target = data.get("target")
+ weight = data.get("weight")
+ id = f"{layer_id}:{source}-{target}"
+
+ sparse_edge = LayerSparseEdges(
+ id=id,
+ source=source,
+ target=target,
+ weight=weight,
+ edge_type=edge_type,
+ layer_id=layer_id,
+ data=data,
+ )
+ session.merge(sparse_edge)
+ session.commit()
+ log.debug(f"Inserted sparse edge from {source} to {target} in layer {layer_id}")
-def create_node_table(
- name: str, spec_variadic: Dict[str, str]
-) -> Tuple[sql.Table, Type["Node"], Callable]:
- """Create a node table dynamically.
+class LayerSparseNodes(Base):
+ """Table for sparse data storage."""
- parameters:
- name: name of the table
- spec_variadic: dict of variadic columns
+ __tablename__ = "layer_sparse_nodes"
- returns:
- table: the table
- """
- spec_fixed = [
- sql.Column("name", sql.Text, primary_key=True, index=True),
- sql.Column("iteration", sql.Integer, index=True, unique=False),
- ]
-
- table = sql.Table(
- name,
- Base.metadata,
- *spec_fixed,
- *[
- sql.Column(key, type_lookup.get(value))
- for key, value in spec_variadic.items()
- ],
+ id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
+ layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+ name: orm.Mapped[str] = orm.mapped_column()
+ node_type: orm.Mapped[str] = orm.mapped_column(index=True)
+ created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+ index=True, insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
)
+ data: orm.Mapped[Dict] = orm.mapped_column()
- class Node:
- """Node."""
- def __repr__(self):
- return f""""""
+def insert_layer_sparse_node(session, layer_id, node_type, data):
+ """Insert a sparse node into the database."""
+ name = data.get("name")
+ id = f"{layer_id}:{name}"
- mapper_registry.map_imperatively(Node, table)
+ sparse_node = LayerSparseNodes(
+ id=id, name=name, layer_id=layer_id, node_type=node_type, data=data
+ )
+ session.merge(sparse_node)
+ session.commit()
+ log.debug(f"Inserted sparse node {name} in layer {layer_id}")
- return table, Node, create_factory(Node, spec_fixed, spec_variadic)
+class SamplerStateStore(Base):
+ """Table for storing the state of the sampler."""
-def create_sampler_state_table(
- name: str, spec_variadic: Dict[str, str]
-) -> Tuple[sql.Table, Type["SamplerState"], Callable]:
- """Create a sampler state table dynamically."""
+ __tablename__ = "sampler_state_store"
- table = sql.Table(
- name,
- Base.metadata,
- sql.Column("id", sql.Integer, primary_key=True, index=True, autoincrement=True),
- *[
- sql.Column(key, type_lookup.get(value))
- for key, value in spec_variadic.items()
- ],
+ id: orm.Mapped[int] = orm.mapped_column(primary_key=True, autoincrement=True)
+ iteration: orm.Mapped[int] = orm.mapped_column(index=True)
+ layer_id: orm.Mapped[str] = orm.mapped_column(index=True)
+ data: orm.Mapped[Dict] = orm.mapped_column()
+ created_at: orm.Mapped[datetime.datetime] = orm.mapped_column(
+ insert_default=lambda: datetime.datetime.now(datetime.timezone.utc)
)
- class SamplerState:
- """Sampler state."""
-
- mapper_registry.map_imperatively(SamplerState, table)
- return table, SamplerState, create_factory(SamplerState, [], spec_variadic)
+def insert_sampler_state(session, layer_id, iteration, data):
+ """Insert the state of the sampler into the database."""
+ sampler_state = SamplerStateStore(iteration=iteration, layer_id=layer_id, data=data)
+ session.add(sampler_state)
+ session.commit()
+ log.debug("Inserted sampler state for layer {layer_id} at iteration {iteration}")
diff --git a/spiderexpress/plugin_manager.py b/spiderexpress/plugin_manager.py
index e626ff9..4af73e0 100644
--- a/spiderexpress/plugin_manager.py
+++ b/spiderexpress/plugin_manager.py
@@ -46,7 +46,7 @@ def get_plugin(spec: PlugInSpec, group: str) -> Callable:
def _(spec: str, group: str) -> Callable:
plugin = _access_entry_point(spec, group)
if not plugin:
- raise ValueError(f"{ spec } could not be found in { group }")
+ raise ValueError(f"{spec} could not be found in {group}")
return functools.partial(
plugin.callable, configuration=plugin.default_configuration
)
@@ -56,7 +56,7 @@ def _(spec: str, group: str) -> Callable:
def _(spec: dict, group: str) -> Callable:
if len(spec.keys()) > 1:
log.warning(
- "Requesting specification has more than one type."
+ f"Requested specification {spec} has more than one type. "
"Using the first instance found"
)
for name, configuration in spec.items():
diff --git a/spiderexpress/router.py b/spiderexpress/router.py
index 5a7c4ee..ee0348b 100644
--- a/spiderexpress/router.py
+++ b/spiderexpress/router.py
@@ -2,8 +2,13 @@
"""
+
import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
+
+from loguru import logger as log
+
+from spiderexpress.types import RouterSpec
class RouterValidationError(ValueError):
@@ -11,39 +16,76 @@ class RouterValidationError(ValueError):
class Router:
- """Creates an edge router from the given specification.
+ r"""Creates an edge router from the given specification.
Arguments:
name (str): the layer name to bind to
spec (Dict[str, Union[str, List[TargetSpec]]]): how data should be routed
Raises:
- ValueError: if the specification is malformed
+ RouterValidationError: if the specification is malformed
+
+ Example:
+ spec = {
+ "source": "handle",
+ "target": [
+ {
+ "field": "text",
+ "pattern": r"https://www\.twitter\.com/(\w+)",
+ "dispatch_with": "test",
+ "type": "twitter-url",
+ }
+ ],
+ "view_count": "view_count",
+ }
+ input_data = {
+ "handle": "Tony",
+ "text": "Check this out: https://www.twitter.com/ernie",
+ "view_count": 123,
+ }
+ router = Router("test", spec)
+ result = router.parse(input_data)
+ print(result)
+ # Output: [{
+ 'source': 'Tony', 'target': 'ernie', 'view_count': 123,
+ 'dispatch_with': 'test', 'type': 'twitter-url'
+ }]
"""
- def __init__(self, name: str, spec: Dict[str, Any], context: Optional[Dict] = None):
+ TARGET = "target"
+ SOURCE = "source"
+
+ def __init__(self, name: str, spec: RouterSpec, context: Optional[Dict] = None):
# Store the layer name
self.name = name
# Validate the spec against the rule set
Router.validate_spec(name, spec, context)
- self.spec = spec.copy()
+ self.spec: RouterSpec = spec
@classmethod
def validate_spec(cls, name, spec, context):
"""Validates a spec in a context."""
- if "to" not in spec:
- raise RouterValidationError(f"{name}: Key 'to' is missing from {spec}.")
- if not isinstance(spec.get("to"), list):
+ # pylint: disable=R0912
+ if Router.TARGET not in spec:
+ raise RouterValidationError(
+ f"{name}: Key {Router.TARGET} is missing from {spec}."
+ )
+ if Router.SOURCE not in spec:
raise RouterValidationError(
- f"{name}: 'to' is not a list but {spec.get('to')}."
+ f"{name}: Key {Router.SOURCE} is missing from {spec}."
)
- for target_spec in spec.get("to"):
- mandatory_fields = ["field", "dispatch_with"]
- for field in mandatory_fields:
- if target_spec.get(field) is None:
- raise RouterValidationError(
- f"{name}: '{ field }' cannot be None in {target_spec}"
- )
+        if not isinstance(spec.get(Router.TARGET), list):
+            raise RouterValidationError(
+                f"{name}: '{Router.TARGET}' is not a list "
+                f"but '{spec.get(Router.TARGET)}'."
+            )
+        for target_spec in spec.get(Router.TARGET):
+            mandatory_fields = ["field", "dispatch_with"]
+            for field in mandatory_fields:
+                if target_spec.get(field) is None:
+                    raise RouterValidationError(
+                        f"{name}: '{field}' cannot be None in {target_spec}"
+                    )
if context is None:
return
@@ -54,42 +96,45 @@ def validate_spec(cls, name, spec, context):
this_connector = connectors.get(name)
for _, data_column_name in spec.items():
if data_column_name not in this_connector:
- raise RouterValidationError(
- f"{ name }: { data_column_name } not found."
- )
- for target_spec in spec.get("to"):
+ raise RouterValidationError(f"{name}: {data_column_name} not found.")
+ for target_spec in spec.get(Router.TARGET):
field = target_spec.get("field")
if field not in connectors.get(name).get("columns"):
raise RouterValidationError(
f"{name}: reference to {field} not found in " f"context."
)
- def parse(self, input_data):
+ def parse(self, input_data) -> List[Dict[str, Any]]:
"""Parses data with the given spec and emits edges."""
ret = []
constant = {}
+
+ log.debug(f"Router '{self.name}' parsing {input_data}")
+
# First we calculate all constants
for edge_key, spec in self.spec.items():
if isinstance(spec, str):
constant[edge_key] = input_data.get(spec)
- for directive in self.spec.get("to", []):
- value = input_data.get(directive.get("field"))
- # Add further constants if there are some defined in the spec
- local_constant = {
- **{
- key: value
- for key, value in directive.items()
- if key not in ["field", "pattern"]
- },
- **constant,
- }
- if "pattern" not in directive:
- # Simply get the value and return a
- ret.append({"to": value, **local_constant})
- continue
- # Get all matches from the string and return an edge for each
- matches = re.findall(directive.get("pattern"), value)
- for match in matches:
- ret.append({"to": match, **local_constant})
-
- return ret
+ if isinstance(self.spec.get(Router.TARGET), list):
+ for directive in self.spec.get(Router.TARGET, []):
+ value = input_data.get(directive.get("field"))
+ # Add further constants if there are some defined in the spec
+ local_constant = {
+ **{
+ key: value
+ for key, value in directive.items()
+ if key not in ["field", "pattern"]
+ },
+ **constant,
+ }
+ if "pattern" not in directive:
+                # Simply get the value and return a single edge
+ ret.append({Router.TARGET: value, **local_constant})
+ continue
+ # Get all matches from the string and return an edge for each
+ matches = re.findall(directive.get("pattern"), value)
+ for match in matches:
+ ret.append({Router.TARGET: match, **local_constant})
+ return ret
+
+ return [constant]
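
With the rename of 'from'/'to' to 'source'/'target', a plain-field routing spec
now looks as follows; a minimal sketch of the no-pattern path, with illustrative
values:

    from spiderexpress.router import Router

    spec = {
        "source": "source",
        "target": [{"field": "target", "dispatch_with": "test"}],
    }
    router = Router("all", spec)

    # Constants ("source") are copied onto every emitted edge, and each
    # target directive contributes its remaining keys ("dispatch_with").
    print(router.parse({"source": "1", "target": "2"}))
    # [{'target': '2', 'dispatch_with': 'test', 'source': '1'}]
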
diff --git a/spiderexpress/spider.py b/spiderexpress/spider.py
index deb29fd..e0f8a47 100644
--- a/spiderexpress/spider.py
+++ b/spiderexpress/spider.py
@@ -14,7 +14,7 @@
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
import pandas as pd
import sqlalchemy as sql
@@ -26,15 +26,20 @@
from spiderexpress.model import (
AppMetaData,
Base,
+ LayerDenseEdges,
+ LayerDenseNodes,
+ SamplerStateStore,
SeedList,
TaskList,
- create_aggregated_edge_table,
- create_node_table,
- create_raw_edge_table,
- create_sampler_state_table,
+ insert_layer_dense_edge,
+ insert_layer_dense_node,
+ insert_layer_sparse_edge,
+ insert_layer_sparse_node,
+ insert_sampler_state,
)
-from spiderexpress.plugin_manager import get_plugin, get_table_configuration
-from spiderexpress.types import Configuration, Connector, Strategy
+from spiderexpress.plugin_manager import get_plugin
+from spiderexpress.router import Router
+from spiderexpress.types import Configuration, Connector, Strategy, from_dict
# pylint: disable=W0613,E1101,C0103
@@ -44,9 +49,6 @@
MAX_RETRIES = 3
-factory_registry = {}
-class_registry = {}
-
class Spider:
"""This is spiderexpress' Spider.
@@ -70,7 +72,7 @@ class Spider:
"idle",
{
"name": "starting",
- "on_enter": ["open_database", "load_plugins"],
+ "on_enter": ["open_database"],
},
{
"name": "gathering",
@@ -88,6 +90,9 @@ class Spider:
"name": "stopping",
"on_enter": "close_database",
},
+ {
+ "name": "stopped",
+ },
]
"""List of states the spider can be in.
@@ -147,6 +152,11 @@ class Spider:
"dest": "stopping",
# "conditions": "iteration_limit_reached",
},
+ {
+ "trigger": "end",
+ "source": "stopping",
+ "dest": "stopped",
+ },
]
"""List of transitions the spider can make."""
@@ -167,14 +177,14 @@ def __init__(
self.retry_count = 0
# set the loaded configuration to None, as it is not loaded yet
- self.configuration: Optional[Configuration] = None
+ self.configuration: Optional[Configuration] = configuration
self.connector: Optional[Connector] = None
self.strategy: Optional[Strategy] = None
self._cache_: Optional[orm.Session] = None
self.appstate: Optional[AppMetaData] = None
@property
- def task_buffer(self):
+ def task_buffer(self) -> List[TaskList]:
"""Returns the task buffer, which is a list of tasks that are currently
being processed."""
return self._cache_.query(TaskList).filter(TaskList.status == "new").all()
@@ -191,8 +201,12 @@ def is_gathering_not_done(self):
def _conditional_advance(self, *args) -> None:
"""Advances the state machine when the current state is done."""
+ # pylint: disable=R0911
if self.state == "idle":
return
+ if self.state == "stopping":
+ self.trigger("end")
+ return
if self.state == "gathering":
if self.may_sample():
log.debug("Advancing from gathering to sampling")
@@ -215,7 +229,9 @@ def _conditional_advance(self, *args) -> None:
targets = self.machine.get_triggers(self.state)
- log.debug(f"Advancing from {self.state} and I can trigger {', '.join(targets)}")
+ log.debug(
+ f"Advancing from {self.state} and I can trigger {', '.join(targets) or 'nothing'}."
+ )
for target in targets:
if self.trigger(target) is True:
@@ -235,13 +251,7 @@ def load_config(self, config_file: Path) -> None:
)
with config_file.open("r", encoding="utf8") as file:
- self.configuration = yaml.full_load(file)
- else:
- self.configuration = (
- self.configuration
- if isinstance(self.configuration, Configuration)
- else Configuration(**self.configuration)
- )
+ self.configuration = from_dict(Configuration, yaml.safe_load(file))
def is_config_valid(self):
"""Asserts that the configuration is valid."""
@@ -273,37 +283,6 @@ def open_database(self, *args) -> None:
self._cache_ = orm.Session(engine)
- _, Node, node_factory = create_node_table(
- self.configuration.node_table["name"],
- self.configuration.node_table["columns"],
- )
- _, RawEdge, raw_edge_factory = create_raw_edge_table(
- self.configuration.edge_raw_table["name"],
- self.configuration.edge_raw_table["columns"],
- )
- _, AggEdge, agg_edge_factory = create_aggregated_edge_table(
- self.configuration.edge_agg_table["name"],
- self.configuration.edge_agg_table["columns"],
- )
-
- class_registry["node"] = Node
- class_registry["raw_edge"] = RawEdge
- class_registry["agg_edge"] = AggEdge
-
- factory_registry["node"] = node_factory
- factory_registry["raw_edge"] = raw_edge_factory
- factory_registry["agg_edge"] = agg_edge_factory
-
- strategy_name = (
- self.configuration.strategy
- if isinstance(self.configuration.strategy, str)
- else list(self.configuration.strategy.keys())[0]
- )
-
- _, SamplerState, _ = create_sampler_state_table( # pylint: disable=W0612
- strategy_name, get_table_configuration(strategy_name, STRATEGY_GROUP)
- )
- class_registry["sampler_state"] = SamplerState
Base.metadata.create_all(engine)
appstate = self._cache_.get(AppMetaData, "1")
@@ -332,24 +311,27 @@ def initialize_seeds(self):
log.debug(f"Copying seeds to database: {', '.join(self.configuration.seeds)}.")
# with self._cache_.begin():
- for seed in self.configuration.seeds:
- if self._cache_.get(SeedList, seed) is None:
- _seed = SeedList(id=seed, iteration=0, status="new")
- self._cache_.add(_seed)
- self._add_task(_seed)
-
- def _add_task(self, task: Union[SeedList, str], parent: Optional[TaskList] = None):
+ for layer, seeds in self.configuration.seeds.items():
+ for seed in seeds:
+ if self._cache_.get(SeedList, seed) is None:
+ _seed = SeedList(id=seed, iteration=0, status="new")
+ self._cache_.add(_seed)
+ self._add_task(_seed, layer=layer)
+
+ def _add_task(
+ self, task: Union[SeedList, str], layer: str, parent: Optional[TaskList] = None
+ ):
"""Adds a task to the task buffer."""
if not self._cache_:
raise ValueError("Cache is not present.")
- if not isinstance(task, (SeedList, class_registry["node"], str)):
+ if not isinstance(task, (SeedList, LayerDenseNodes, str)):
raise ValueError(
- "Task must be a seed, a node or a node-identifier, but is {type(task).__name__}"
+ f"Task must be a seed, a node or a node-identifier, but is {type(task).__name__}"
)
node_id = (
task.name
- if isinstance(task, (class_registry["node"]))
+ if isinstance(task, LayerDenseNodes)
else task if isinstance(task, str) else task.id
)
@@ -362,7 +344,7 @@ def _add_task(self, task: Union[SeedList, str], parent: Optional[TaskList] = Non
node_id=node_id,
status="new",
initiated_at=datetime.now(),
- connector="stub_value",
+ connector=layer,
parent_task_id=parent.id if parent else None,
)
self._cache_.add(new_task)
@@ -409,8 +391,8 @@ def retry_with_unused_seeds(self):
candidates = (
self._cache_.execute(
- sql.select(class_registry["node"].name).where(
- class_registry["node"].name.not_in(candidate_nodes_names)
+ sql.select(LayerDenseNodes.name).where(
+ LayerDenseNodes.name.not_in(candidate_nodes_names)
)
)
.scalars()
@@ -446,15 +428,15 @@ def gather_node_data(self):
if len(self.task_buffer) == 0:
return
- task: TaskList = self.task_buffer.pop(0)
+ task = self.task_buffer.pop(0)
log.debug(f"Attempting to gather data for {task.node_id}.")
# Begin transaction with the cache
- node_info = self._cache_.get(class_registry["node"], task.node_id)
+ node_info = self._cache_.get(LayerDenseNodes, task.node_id)
if node_info is None:
- self._dispatch_connector_for_node_(task)
+ self._dispatch_connector_for_node_(task, task.connector)
# Mark the node as done
seed = self._cache_.get(SeedList, task.node_id)
@@ -479,137 +461,135 @@ def iteration_limit_reached(self):
def sample_network(self):
"""Samples the network."""
+ # pylint: disable=R0914
if not self._cache_:
raise ValueError("Cache is not present.")
-
- log.debug("Attempting to sample the network.")
-
- aggregation_spec = self.configuration.edge_agg_table["columns"]
- aggregation_funcs = {
- "count": sql.func.count,
- "max": sql.func.max,
- "min": sql.func.min,
- "sum": sql.func.sum,
- "avg": sql.func.avg,
- }
-
- aggregations = [
- sql.func.count().label( # pylint: disable=E1102 # not-callable, but it is :shrug:
- "weight"
- ),
- *[
- aggregation_funcs[aggregation](
- getattr(class_registry["raw_edge"], column)
- ).label(column)
- for column, aggregation in aggregation_spec.items()
- ],
- ]
- sql_statement = (
- self._cache_.query(
- class_registry["raw_edge"].source,
- class_registry["raw_edge"].target,
- *aggregations,
+ for layer_id, layer_config in self.configuration.layers.items():
+ # Get data for the layer from the dense data stores
+ edges = pd.read_sql(
+ self._cache_.query(
+ LayerDenseEdges.source,
+ LayerDenseEdges.target,
+ sql.func.count("*").label("weight"), # pylint: disable=E1102
+ )
+ .where(LayerDenseEdges.layer_id == layer_id)
+ .group_by(LayerDenseEdges.source, LayerDenseEdges.target)
+ .statement,
+ self._cache_.connection(),
)
- .group_by(
- class_registry["raw_edge"].source, class_registry["raw_edge"].target
+ nodes = pd.json_normalize(
+ pd.read_sql(
+ self._cache_.query(LayerDenseNodes.data)
+ .where(LayerDenseNodes.layer_id == layer_id)
+ .statement,
+ self._cache_.connection(),
+ ).data
+ )
+ sampler_state = pd.json_normalize(
+ pd.read_sql(
+ sql.select(SamplerStateStore.data).where(
+ SamplerStateStore.layer_id == layer_id
+ ),
+ self._cache_.connection(),
+ ).data
)
- .statement
- )
-
- log.debug(f"Aggregation query: {sql_statement}")
-
- edges = pd.read_sql(
- sql_statement,
- self._cache_.connection(),
- )
- nodes = pd.read_sql(
- self._cache_.query(class_registry["node"]).statement,
- self._cache_.connection(),
- )
- sampler_state = pd.read_sql(
- sql.select(class_registry["sampler_state"]), self._cache_.connection()
- )
-
- log.info(
- f"""Sampling from {
- len(edges)
- } edges with {
- len(nodes)
- } nodes while the sampler's state is { len(sampler_state) } long."""
- )
-
- log.info(f"That's the current state of affairs: { sampler_state }")
- new_seeds, new_edges, _, new_sampler_state = self.strategy(
- edges, nodes, sampler_state
- )
+ log.debug(
+ f"""
+ Sampling layer {layer_id} with {len(edges)} edges and {len(nodes)} nodes.
+ Edges to sample:
+{edges}
- new_edges["iteration"] = self.appstate.iteration
+ Sampler state:
+{sampler_state}
+"""
+ )
- if len(new_seeds) == 0:
- log.warning("Found no new seeds.")
- elif self.retry_count > 0:
- self.retry_count = 0
+ sampler: Strategy = get_plugin(layer_config.sampler, STRATEGY_GROUP)
+ new_seeds, sparse_edges, sparse_nodes, new_sampler_state = sampler(
+ edges, nodes, sampler_state
+ )
- for seed in new_seeds:
- if seed is None:
- continue
- if self._cache_.get(SeedList, seed) is None:
- _seed = SeedList(
- id=seed, iteration=self.appstate.iteration + 1, status="new"
+ log.info(f"That's the current state of affairs:\n\n{new_sampler_state}")
+
+ sparse_edges["iteration"] = self.appstate.iteration
+ if len(new_seeds) == 0:
+ log.warning("Found no new seeds.")
+ elif self.retry_count > 0:
+ self.retry_count = 0
+
+ for seed in new_seeds:
+ if seed is None:
+ continue
+ if self._cache_.get(SeedList, seed) is None:
+ _seed = SeedList(
+ id=seed, iteration=self.appstate.iteration + 1, status="new"
+ )
+ self._cache_.add(_seed)
+ self._add_task(_seed, layer=layer_id)
+
+ for edge in sparse_edges.to_dict(orient="records"):
+ if edge["source"] is not None and edge["target"] is not None:
+ insert_layer_sparse_edge(self._cache_, layer_id, "test", data=edge)
+ for node in sparse_nodes.to_dict(orient="records"):
+ insert_layer_sparse_node(self._cache_, layer_id, "test", node)
+ for state in new_sampler_state.to_dict(orient="records"):
+ insert_sampler_state(
+ self._cache_, layer_id, self.appstate.iteration, state
)
- self._cache_.add(_seed)
- self._add_task(_seed)
- self._cache_.add_all(
- [
- factory_registry["agg_edge"](edge)
- for edge in new_edges.to_dict(orient="records")
- if edge["source"] is not None and edge["target"] is not None
- ]
- )
- for state in new_sampler_state.to_dict(orient="records"):
- self._cache_.merge(class_registry["sampler_state"](**state))
- self._cache_.commit()
-
- def load_plugins(self, *args):
- """Loads the plug-ins."""
- self.strategy = get_plugin(self.configuration.strategy, STRATEGY_GROUP)
- self.connector = get_plugin(self.configuration.connector, CONNECTOR_GROUP)
+ self._cache_.commit()
# section: private methods
- def _dispatch_connector_for_node_(self, node: TaskList):
- if not self.configuration or not self.connector:
+ def _dispatch_connector_for_node_(self, node: TaskList, layer: str):
+ # pylint: disable=R0914
+ if not self.configuration:
raise ValueError("Configuration or Connector are not present")
+ # Get the connector for the layer
+ layer_configuration = self.configuration.layers[layer]
+ connector_spec = layer_configuration.connector
+ connector = get_plugin(connector_spec, CONNECTOR_GROUP)
- edges, nodes = self.connector([node.node_id])
+ log.debug(f"Requesting data for {node.node_id} from {connector_spec}.")
- log.debug(f"edges:\n{edges}\n\nnodes:{nodes}\n")
+ raw_edges, nodes = connector([node.node_id])
- if len(edges) > 0:
- log.info(
- f"""Persisting {
- len(edges)
- } edges for node {
- node.node_id
- } in iteration
- #{
- self.appstate.iteration
- }."""
- )
- edges["iteration"] = self.appstate.iteration
- for edge in edges.to_dict(orient="records"):
- self._cache_.merge(factory_registry["raw_edge"](edge))
-
- if self.configuration.eager is True and node.parent_task_id is None:
- # We add only new task if the parent_task_id is None to avoid snowballing
- # the entire population before we even begin sampling.
- for target in edges["target"].unique().tolist():
- self._add_task(target, parent=node)
- self._cache_.commit()
+ routers = layer_configuration.routers
+ for router_definition in routers:
+ for router_name, router_spec in router_definition.items():
+
+ log.debug(
+ f"Routing data with {router_name} and this spec: {router_spec}."
+ )
+
+ router = Router(router_name, router_spec)
+ for raw_edge in raw_edges.to_dict(orient="records"):
+ edges = router.parse(raw_edge)
+ for edge in edges:
+ edge["iteration"] = self.appstate.iteration
+ insert_layer_dense_edge(
+ self._cache_, edge.get("dispatch_with"), router_name, edge
+ )
+
+ log.debug(f"Inserted edge: {edge}")
+
+ if (
+ layer_configuration.eager is True
+ and node.parent_task_id is None
+ ):
+                        # We only add new tasks if parent_task_id is None to avoid
+                        # snowballing the entire population before we even begin sampling.
+ targets = {edge.get("target") for edge in edges}
+ for target in targets:
+ self._add_task(
+ target,
+ parent=node,
+ layer=router_spec.get("dispatch_with"),
+ )
+ self._cache_.commit()
if len(nodes) > 0:
nodes["iteration"] = self.appstate.iteration
for _node in nodes.to_dict(orient="records"):
- self._cache_.merge(factory_registry["node"](_node))
+ insert_layer_dense_node(self._cache_, "test", "rest", _node)
self._cache_.commit()
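
With per-layer configuration in place, a run now drives itself from "idle" to
the new terminal "stopped" state; a minimal sketch mirroring the updated test
suite, assuming the stub configuration shipped with the tests:

    from pathlib import Path

    from spiderexpress.spider import Spider

    spider = Spider(auto_transitions=True)
    spider.start(Path("tests/stubs/sevens_grader_random_test.pe.yml"))
    assert spider.is_stopped()  # the run ends in the new "stopped" state
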
diff --git a/spiderexpress/strategies/random.py b/spiderexpress/strategies/random.py
index 6268dcd..4376a44 100644
--- a/spiderexpress/strategies/random.py
+++ b/spiderexpress/strategies/random.py
@@ -15,28 +15,32 @@ def random_strategy(
):
"""Random sampling strategy."""
# split the edges table into edges _inside_ and _outside_ of the known network
+ is_first_round = state.empty
+ if is_first_round:
+ state = pd.DataFrame({"node_id": edges.source.unique()})
mask = edges.target.isin(state.node_id)
- edges_inward = edges.loc[mask, :]
edges_outward = edges.loc[~mask, :]
# select 10 edges to follow
if len(edges_outward) < configuration["n"]:
- edges_sampled = edges_outward
+ sparse_edges = edges_outward
else:
- edges_sampled = edges_outward.sample(n=configuration["n"], replace=False)
+ sparse_edges = edges_outward.sample(n=configuration["n"], replace=False)
- new_seeds = edges_sampled.target # select target node names as seeds for the
+ new_seeds = (
+ sparse_edges.target.unique()
+ ) # select target node names as seeds for the
# next layer
- edges_to_add = pd.concat([edges_inward, edges_sampled]) # add edges inside the
- # known network as well as the sampled edges to the known network
- new_nodes = nodes.loc[nodes.name.isin(new_seeds), :]
-
- return new_seeds, edges_to_add, new_nodes
+ sparse_nodes = nodes.loc[nodes.name.isin(new_seeds), :]
+ new_state = pd.DataFrame({"node_id": new_seeds})
+ if is_first_round:
+ new_state = pd.concat([state, new_state])
+ return new_seeds, sparse_edges, sparse_nodes, new_state
random = PlugIn(
callable=random_strategy,
- tables={"node_": "Text"},
+ tables={"state": {"node_id": "Text"}},
metadata={},
default_configuration={"n": 10},
)
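
Strategies now receive the sampler state as a third data frame and return a
four-tuple of seeds, sparse edges, sparse nodes, and the new state. An
illustrative direct call (the configuration keyword is normally bound by the
plugin manager):

    import pandas as pd

    from spiderexpress.strategies.random import random_strategy

    edges = pd.DataFrame({"source": ["1", "1"], "target": ["2", "3"]})
    nodes = pd.DataFrame({"name": ["1", "2", "3"]})
    state = pd.DataFrame()  # an empty state marks the first round

    seeds, sparse_edges, sparse_nodes, new_state = random_strategy(
        edges, nodes, state, configuration={"n": 10}
    )
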
diff --git a/spiderexpress/strategies/spikyball.py b/spiderexpress/strategies/spikyball.py
index 063f4cd..ad6cfa7 100644
--- a/spiderexpress/strategies/spikyball.py
+++ b/spiderexpress/strategies/spikyball.py
@@ -13,7 +13,7 @@
import pandas as pd
from loguru import logger as log
-from ..types import PlugIn, fromdict
+from ..types import PlugIn, from_dict
@dataclass
@@ -319,23 +319,30 @@ def spikyball_strategy(
"""
if isinstance(configuration, dict):
- configuration = fromdict(SpikyBallConfiguration, configuration)
+ configuration = from_dict(SpikyBallConfiguration, configuration)
+ first_round = state.empty
+ if first_round:
+ state = pd.DataFrame({"node_id": edges.source.unique()})
e_in, e_out = filter_edges(edges, state.node_id.tolist())
- seeds, e_sampled = sample_edges(
+ seeds, sparse_edges = sample_edges(
e_out,
nodes,
configuration.sampler,
configuration.layer_max_size,
)
- state = pd.concat([state, pd.DataFrame({"node_id": seeds})])
+ if first_round:
+ state = pd.concat([state, pd.DataFrame({"node_id": seeds})])
+ else:
+ state = pd.DataFrame({"node_id": seeds})
+ sparse_nodes = nodes.loc[nodes.name.isin(seeds), :]
- return seeds, pd.concat([e_in, e_sampled]), e_out, state
+ return seeds, sparse_edges, sparse_nodes, state
spikyball = PlugIn(
callable=spikyball_strategy,
default_configuration={},
- tables={"node_id": "Text"},
+ tables={"state": {"node_id": "Text"}},
metadata={},
)
diff --git a/spiderexpress/types.py b/spiderexpress/types.py
index c208f47..ee088e8 100644
--- a/spiderexpress/types.py
+++ b/spiderexpress/types.py
@@ -1,19 +1,17 @@
# pylint: disable=R
-"""Type definitions for spiderexpress
+"""Type definitions for spiderexpress'DSL.
Philipp Kessling
Leibniz-Institute for Media Research, 2022
-
"""
-
-
-from dataclasses import dataclass, fields, is_dataclass
+import json
+from dataclasses import field, fields, is_dataclass
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
import pandas as pd
-import yaml
+from pydantic.dataclasses import dataclass
Connector = Callable[[List[str]], Tuple[pd.DataFrame, pd.DataFrame]]
"""Connector Interface
@@ -29,7 +27,19 @@
[pd.DataFrame, pd.DataFrame, pd.DataFrame],
Tuple[List[str], pd.DataFrame, pd.DataFrame, pd.DataFrame],
]
-"""Strategy Interface"""
+"""Strategy Interface.
+
+Args:
+ edges (pd.DataFrame): the edge table
+ nodes (pd.DataFrame): the node table
+ state (pd.DataFrame): the last state of the sampler
+
+Returns:
+    (List[str]): A list of new seed nodes
+    (pd.DataFrame): Update for the sparse edge table
+    (pd.DataFrame): Update for the sparse node table
+    (pd.DataFrame): Update for the state table
+"""
PlugInSpec = Union[str, Dict[str, Union[str, Dict[str, Union[str, int]]]]]
"""Plug-In Definition Notation.
@@ -38,48 +48,78 @@
ColumnSpec = Dict[str, Union[str, Dict[str, str]]]
"""Column Specification."""
+StopCondition = Literal["stop", "retry"]
+
+
+@dataclass()
+class Configuration:
+ """Configuration-File Wrapper.
+
+ Attributes:
+ seeds (Dict[str, List[str]], optional): A dictionary of seeds. Defaults to None.
+        seed_file (str, optional): A JSON-file containing seeds per layer,
+            used when `seeds` is not given. Defaults to None.
+ project_name (str, optional): The name of the project. Defaults to "spider".
+ db_url (str, optional): The database url. Defaults to None.
+ db_schema (str, optional): The database schema. Defaults to None.
+ empty_seeds (StopCondition, optional): What to do if the seeds are empty.
+ Defaults to "stop".
+        layers (Dict[str, Layer], optional): A mapping of layer names to layer
+            configurations. Defaults to an empty dict.
+        max_iteration (int, optional): The maximum number of iterations.
+            Defaults to 10000.
+ """
+
+ seeds: Optional[Dict[str, List[str]]] = None
+ seed_file: Optional[str] = None
+ project_name: str = "spider"
+ db_url: Optional[str] = None
+ db_schema: Optional[str] = None
+ empty_seeds: StopCondition = "stop"
+ layers: Dict[str, "Layer"] = field(default_factory=dict)
+ max_iteration: int = 10000
+
+ def __post_init__(self) -> None:
+ """Configuration-File Wrapper for SpiderExpress"""
+ if self.seeds is None:
+ if self.seed_file is None:
+ raise ValueError("Either seeds or seed_file must be provided.")
+ _seed_file = Path(self.seed_file)
+ if not _seed_file.exists():
+ raise FileNotFoundError(f"Seed file {_seed_file.resolve()} not found.")
+ with _seed_file.open("r", encoding="utf8") as file:
+ self.seeds = json.load(file)
+ self.db_url = self.db_url or f"sqlite:///{self.project_name}.db"
+ if self.db_url.startswith("sqlite") and self.db_schema is not None:
+ raise ValueError("SQLite does not support schemas.")
+ self.empty_seeds = (
+ self.empty_seeds if self.empty_seeds in ["stop", "retry"] else "stop"
+ )
+
+
+@dataclass
+class FieldSpec:
+ """Field Specification"""
+
+ field: str
+ dispatch_with: str
+    pattern: Optional[str] = None
+
+
+@dataclass
+class RouterSpec:
+ """Router Configuration"""
+
+ source: str
+ target: List[FieldSpec]
+
+
+@dataclass
+class Layer:
+ """Layer Configuration"""
-class Configuration(yaml.YAMLObject):
- """Configuration-File Wrapper"""
-
- yaml_tag = "!spiderexpress:Configuration"
-
- def __init__(
- self,
- seeds: Optional[List[str]] = None,
- seed_file: Optional[str] = None,
- project_name: str = "spider",
- db_url: Optional[str] = None,
- db_schema: Optional[str] = None,
- empty_seeds: str = "stop",
- eager: bool = True,
- edge_raw_table: Optional[ColumnSpec] = None,
- edge_agg_table: Optional[ColumnSpec] = None,
- node_table: Optional[ColumnSpec] = None,
- strategy: PlugInSpec = "spikyball",
- connector: PlugInSpec = "telegram",
- max_iteration: int = 10000,
- batch_size: int = 150,
- ) -> None:
- if seed_file is not None:
- _seed_file = Path(seed_file)
- if _seed_file.exists():
- with _seed_file.open("r", encoding="utf8") as file:
- self.seeds = list(file.readlines())
- else:
- self.seeds = seeds
- self.strategy = strategy
- self.connector = connector
- self.project_name = project_name
- self.db_url = db_url or f"sqlite:///{project_name}.db"
- self.db_schema = db_schema
- self.edge_raw_table = edge_raw_table or {"name": "edge_raw", "columns": {}}
- self.edge_agg_table = edge_agg_table or {"name": "edge_agg", "columns": {}}
- self.node_table = node_table or {"name": "node", "columns": {}}
- self.max_iteration = max_iteration
- self.batch_size = batch_size
- self.empty_seeds = empty_seeds if empty_seeds in ["stop", "retry"] else "stop"
- self.eager = eager
+    connector: Dict
+    routers: List[Dict[str, Dict]]
+    sampler: Dict
+    eager: bool = False
@dataclass
@@ -93,7 +133,7 @@ class ConfigurationItem:
T = TypeVar("T")
-def fromdict(cls: Type[T], dictionary: dict) -> T:
+def from_dict(cls: Type[T], dictionary: dict) -> T:
"""convert a dictionary to a dataclass
warning:
@@ -107,12 +147,12 @@ def fromdict(cls: Type[T], dictionary: dict) -> T:
returns:
the dataclass with values from the dictionary
"""
- fieldtypes: Dict[str, Type] = {f.name: f.type for f in fields(cls)}
+ field_types = {f.name: f.type for f in fields(cls)}
return cls(
**{
key: (
- fromdict(fieldtypes[key], value)
- if isinstance(value, dict) and is_dataclass(fieldtypes[key])
+ from_dict(field_types[key], value)
+ if isinstance(value, dict) and is_dataclass(field_types[key])
else value
)
for key, value in dictionary.items()
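
Configuration is now a pydantic dataclass that is built through from_dict, with
defaults derived in __post_init__. A minimal sketch, matching the behaviour
exercised in tests/test_config.py:

    from spiderexpress.types import Configuration, from_dict

    config = from_dict(Configuration, {"seeds": {"test": ["1", "13"]}})
    assert config.db_url == "sqlite:///spider.db"  # derived from project_name
    assert config.empty_seeds == "stop"
    assert config.layers == {}
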
diff --git a/tests/stubs/seeds.json b/tests/stubs/seeds.json
new file mode 100644
index 0000000..947e29a
--- /dev/null
+++ b/tests/stubs/seeds.json
@@ -0,0 +1 @@
+{"test": ["1", "2", "3"]}
diff --git a/tests/stubs/sevens_grader_random_test.pe.yml b/tests/stubs/sevens_grader_random_test.pe.yml
index 9b4b697..1cc726d 100644
--- a/tests/stubs/sevens_grader_random_test.pe.yml
+++ b/tests/stubs/sevens_grader_random_test.pe.yml
@@ -1,31 +1,26 @@
-!spiderexpress:Configuration
-batch_size: 150
-connector:
- csv:
- node_list_location: tests/stubs/7th_graders/nodes.csv
- edge_list_location: tests/stubs/7th_graders/edges.csv
- mode: out
-db_url: sqlite:///
+db_url: sqlite:/// # sevens_grader_random_test.db
db_schema:
-eager: false
empty_seeds: stop
-edge_table_name: edge_list
max_iteration: 10000
-node_table:
- name: sevens_grader_nodes
- columns: {}
-edge_raw_table:
- name: sevens_grader_edge_raw
- columns:
- layer: Integer
-edge_agg_table:
- name: sevens_grader_edge_agg
- columns:
- layer: Integer
+layers:
+ test:
+ eager: false
+ connector:
+ csv:
+ node_list_location: tests/stubs/7th_graders/nodes.csv
+ edge_list_location: tests/stubs/7th_graders/edges.csv
+ mode: out
+ routers:
+ - all: # This is the name of the router and should be the type of edge.
+ source: source # This is the field that is mapped to the source columns.
+ target:
+ - field: target # This is the field that is mapped to the target columns.
+ dispatch_with: test # This is the name of the layer to dispatch to.
+ sampler:
+ random:
+ n: 5
project_name: spider
seeds:
- - "1"
- - "13"
-strategy:
- random:
- n: 5
+ test:
+ - "1"
+ - "13"
diff --git a/tests/stubs/sevens_grader_spikyball_test.pe.yml b/tests/stubs/sevens_grader_spikyball_test.pe.yml
index c1628be..1d281d2 100644
--- a/tests/stubs/sevens_grader_spikyball_test.pe.yml
+++ b/tests/stubs/sevens_grader_spikyball_test.pe.yml
@@ -1,41 +1,37 @@
-!spiderexpress:Configuration
-batch_size: 150
-connector:
- csv:
- node_list_location: tests/stubs/7th_graders/nodes.csv
- edge_list_location: tests/stubs/7th_graders/edges.csv
- mode: out
-db_url: sqlite:///
+db_url: sqlite:/// # sevens_grader_random_test.db
db_schema:
-eager: false
empty_seeds: stop
-edge_table_name: edge_list
max_iteration: 10000
-node_table:
- name: sevens_grader_spkyball_nodes
- columns: {}
-edge_raw_table:
- name: sevens_grader_ed_spkyballge_raw
- columns:
- layer: Integer
-edge_agg_table:
- name: sevens_grader_ed_spkyballge_agg
- columns:
- layer: max
+layers:
+ test:
+ eager: false
+ connector:
+ csv:
+ node_list_location: tests/stubs/7th_graders/nodes.csv
+ edge_list_location: tests/stubs/7th_graders/edges.csv
+ mode: out
+ routers:
+ - all: # This is the name of the router and should be the type of edge.
+ source: source # This is the field that is mapped to the source columns.
+ target:
+ - field: target # This is the field that is mapped to the target columns.
+ dispatch_with: test # This is the name of the layer to dispatch to.
+ sampler:
+ spikyball:
+ layer_max_size: 5
+ sampler:
+        source_node_probability:
+          coefficient: 1
+          weights: {}
+        target_node_probability:
+          coefficient: 1
+          weights: {}
+        edge_probability:
+          coefficient: 1
+          weights: {}
+
project_name: spider
seeds:
- - "1"
- - "13"
-strategy:
- spikyball:
- layer_max_size: 5
- sampler:
- source_node_probability:
- coefficient: 1
- weights: {}
- target_node_probability:
- coefficient: 1
- weights: {}
- edge_probability:
- coefficient: 1
- weights: {}
+ test:
+ - "1"
+ - "13"
diff --git a/tests/test_config.py b/tests/test_config.py
index 0184da0..3d307b4 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,57 +1,82 @@
"""test suite for spiderexpress.Configuration"""
-from pathlib import Path
+
+from typing import Dict
import pytest
import yaml
-from pytest import skip
from spiderexpress import Configuration
+from spiderexpress.types import from_dict
# pylint: disable=W0621
-@pytest.fixture
-def configuration():
- """Creates a configuration object."""
- return Configuration(
- ["a", "b"],
- None,
- "test",
- "memory",
- )
-
-
-def test_initializer(configuration):
- """Should instantiate a configuration object."""
-
- assert configuration.project_name == "test"
- assert configuration.db_url == "memory"
- assert configuration.seeds == ["a", "b"]
-
-
-def test_fields():
- """should do something"""
- skip()
-
-
-def test_serialization(configuration, tmpdir: Path):
- """Should write out and re-read the configuration."""
- temp_conf = tmpdir / "test.pe.yml"
-
- with temp_conf.open("w") as file:
- yaml.dump(configuration, file)
-
- assert temp_conf.exists()
-
- with temp_conf.open("r") as file:
- configuration_2 = yaml.full_load(file)
-
- for key, value in configuration.__dict__.items():
- assert value == configuration_2.__dict__[key]
-
-
-def test_seed_seedfile():
- """
- It should have either a seed file or a seed list, should throw otherwise.
- """
- skip()
+def test_parse_configuration_from_file():
+ """Should parse a configuration from a YAML file."""
+ with open(
+ "tests/stubs/sevens_grader_random_test.pe.yml", "r", encoding="utf8"
+ ) as file:
+ config = yaml.safe_load(file)
+ assert from_dict(Configuration, config) is not None
+
+
+def test_parse_configuration_from_dict():
+ """Should parse a configuration from a file."""
+ config = from_dict(Configuration, {"project_name": "test", "seeds": {}})
+ assert config is not None
+ assert config.project_name == "test"
+ assert config.db_url == "sqlite:///test.db"
+ assert config.max_iteration == 10000
+ assert config.empty_seeds == "stop"
+ assert config.layers == {}
+ assert config.seeds == {}
+
+
+def test_fail_to_open_seeds_from_file():
+ """Should fail to parse a configuration from a file."""
+ with pytest.raises(FileNotFoundError):
+ from_dict(Configuration, {"seed_file": "non_existent_file"})
+
+
+def test_open_seeds_from_file():
+ """Should parse a configuration from a file with a seed file."""
+ config = from_dict(Configuration, {"seed_file": "tests/stubs/seeds.json"})
+ assert config is not None
+ assert config.seeds == {"test": ["1", "2", "3"]}
+
+
+@pytest.mark.parametrize(
+ ["configuration"],
+ [
+ pytest.param(
+ {
+ "layers": {
+ "test": {
+ "connector": {"csv": {"n": 1}},
+ "routers": [],
+ "sampler": {"random": {"n": 1}},
+ }
+ },
+ "seeds": {"test": ["1", "13"]},
+ },
+ id="empty_router",
+ ),
+ ],
+)
+def test_parse_layer_configuration(configuration: Dict):
+ """Should parse a layer configuration."""
+ config = from_dict(Configuration, configuration)
+ assert config is not None
+ assert config.project_name == "spider"
+ assert config.db_url == "sqlite:///spider.db"
+ assert config.max_iteration == 10000
+ assert config.empty_seeds == "stop"
+ assert config.seeds == {"test": ["1", "13"]}
+
+
+def test_layers_must_have_a_router_configuration():
+ """Should fail to parse a configuration without a router configuration."""
+ with pytest.raises(ValueError):
+ from_dict(
+ Configuration, {"layers": {"test": {}}, "seeds": {"test": ["1", "13"]}}
+ )
diff --git a/tests/test_model.py b/tests/test_model.py
index 14a5132..02a9f30 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,18 +1,12 @@
"""Test suite for the model module."""
+
from datetime import datetime
import pytest
import sqlalchemy as sql
from sqlalchemy.orm import Session
-from spiderexpress.model import (
- AppMetaData,
- Base,
- SeedList,
- create_aggregated_edge_table,
- create_node_table,
- create_raw_edge_table,
-)
+from spiderexpress.model import AppMetaData, Base, SeedList
# pylint: disable=W0621
@@ -44,87 +38,6 @@ def create_tables(connection):
Base.metadata.drop_all(connection)
-def test_create_raw_edge_table_with_session(session, create_tables):
- """Test the creation of a raw edge table."""
- _, RawEdge, edge_factory = create_raw_edge_table("raw_edges", {"weight": "Integer"})
-
- create_tables()
-
- edge = edge_factory({"source": "a", "target": "b", "weight": 1, "view_count": 1})
- session.add(edge)
- session.commit()
-
- assert session.query(RawEdge).count() == 1
-
-
-def test_create_raw_edge_table():
- """Test the creation of a raw edge table."""
- table, RawEdge, _ = create_raw_edge_table("raw_edges_2", {"weight": "Integer"})
-
- assert table.name == "raw_edges_2"
- assert len(table.columns) == 5
-
- assert hasattr(RawEdge, "source")
- assert hasattr(RawEdge, "target")
- assert hasattr(RawEdge, "weight")
-
-
-def test_create_aggregated_edge_table_with_session(session, create_tables):
- """Test the creation of an aggregated edge table."""
- _, Edge, edge_factory = create_aggregated_edge_table(
- "agg_edges", {"view_count": "Integer"}
- )
-
- create_tables()
-
- edge = edge_factory(
- {"source": "a", "target": "b", "weight": 1, "view_count": 1, "iteration": 0}
- )
- session.add(edge)
- session.commit()
-
- assert session.query(Edge).count() == 1
-
-
-def test_create_aggregated_edge_table():
- """Test the creation of an aggregated edge table."""
- table, Edge, _ = create_aggregated_edge_table(
- "agg_edges_2", {"view_count": "Integer"}
- )
-
- assert table.name == "agg_edges_2"
- assert len(table.columns) == 5
-
- assert hasattr(Edge, "source")
- assert hasattr(Edge, "target")
- assert hasattr(Edge, "view_count")
-
-
-def test_create_node_table_session(session, create_tables):
- """Test the creation of a node table."""
-
- _, Node, node_factory = create_node_table("nodes", {"subscriber_count": "Integer"})
-
- create_tables()
-
- node = node_factory({"name": "a", "subscriber_count": 1})
- session.add(node)
- session.commit()
-
- assert session.query(Node).count() == 1
-
-
-def test_create_node_table_with():
- """Test the creation of a node table."""
- table, Node, _ = create_node_table("nodes2", {"subscriber_count": "Integer"})
-
- assert table.name == "nodes2"
- assert len(table.columns) == 3
-
- assert hasattr(Node, "name")
- assert hasattr(Node, "subscriber_count")
-
-
def test_app_state_table(session, create_tables):
"""Test the creation of a node table."""
diff --git a/tests/test_router.py b/tests/test_router.py
index 36bb51f..54a8d2d 100644
--- a/tests/test_router.py
+++ b/tests/test_router.py
@@ -1,4 +1,5 @@
"""Test suite for spiderexpress' multi-layer router."""
+
from typing import Any, Dict, List, Optional, Union
import pytest
@@ -9,20 +10,25 @@
@pytest.mark.parametrize(
["specification", "context"],
[
- pytest.param({"from": "column", "to": "wrong_value"}, None, id="to_is_string"),
- pytest.param({"from": "here", "ot": "bad"}, None, id="to_is_missing"),
pytest.param(
- {"from": "here", "to": [{"field": "yadad"}]},
+ {"source": "column", "target": "wrong_value"}, None, id="target_is_string"
+ ),
+ pytest.param({"source": "here", "ot": "bad"}, None, id="target_is_missing"),
+ pytest.param(
+ {"source": "here", "target": [{"field": "yadad"}]},
{"connectors": {"test": {"columns": {"there": "Text"}}}},
id="from_column_is_missing_in_context",
),
pytest.param(
- {"from": "column", "to": [{"field": "column", "dispatch_with": "layer_1"}]},
+ {
+ "source": "column",
+ "target": [{"field": "column", "dispatch_with": "layer_1"}],
+ },
{"connectors": {"layer_1": {"type": "something", "columns": {"column1"}}}},
id="column_is_missing_in_context",
),
pytest.param(
- {"from": "column", "to": [{"dispatch_with": "layer_1"}]},
+ {"source": "column", "target": [{"dispatch_with": "layer_1"}]},
{"connectors": {"layer_1": {"type": "something", "columns": {"column1"}}}},
id="field is None",
),
@@ -59,14 +65,14 @@ def test_router_spec_validation(
pytest.param(
input_data_1,
{
- "from": "handle",
- "to": [{"field": "forwarded_handle", "dispatch_with": "test"}],
+ "source": "handle",
+ "target": [{"field": "forwarded_handle", "dispatch_with": "test"}],
"view_count": "view_count",
},
[
{
- "from": "Tony",
- "to": "Bert",
+ "source": "Tony",
+ "target": "Bert",
"view_count": 123,
"dispatch_with": "test",
}
@@ -76,8 +82,8 @@ def test_router_spec_validation(
pytest.param(
input_data_1,
{
- "from": "handle",
- "to": [
+ "source": "handle",
+ "target": [
{
"field": "url",
"pattern": r"https://www\.twitter\.com/(\w+)",
@@ -88,8 +94,8 @@ def test_router_spec_validation(
},
[
{
- "from": "Tony",
- "to": "ernie",
+ "source": "Tony",
+ "target": "ernie",
"view_count": 123,
"dispatch_with": "test",
}
@@ -99,8 +105,8 @@ def test_router_spec_validation(
pytest.param(
input_data_1,
{
- "from": "handle",
- "to": [
+ "source": "handle",
+ "target": [
{
"field": "url",
"pattern": r"https://www\.twitter\.com/(\w+)",
@@ -112,8 +118,8 @@ def test_router_spec_validation(
},
[
{
- "from": "Tony",
- "to": "ernie",
+ "source": "Tony",
+ "target": "ernie",
"view_count": 123,
"type": "twitter-url",
"dispatch_with": "test",
@@ -124,8 +130,8 @@ def test_router_spec_validation(
pytest.param(
input_data_1,
{
- "from": "handle",
- "to": [
+ "source": "handle",
+ "target": [
{
"field": "text",
"pattern": r"https://www\.twitter\.com/(\w+)",
@@ -137,15 +143,15 @@ def test_router_spec_validation(
},
[
{
- "from": "Tony",
- "to": "ernie",
+ "source": "Tony",
+ "target": "ernie",
"view_count": 123,
"dispatch_with": "test",
"type": "twitter-url",
},
{
- "from": "Tony",
- "to": "bobafett",
+ "source": "Tony",
+ "target": "bobafett",
"view_count": 123,
"dispatch_with": "test",
"type": "twitter-url",
@@ -156,8 +162,8 @@ def test_router_spec_validation(
pytest.param(
input_data_1,
{
- "from": "handle",
- "to": [
+ "source": "handle",
+ "target": [
{
"field": "text",
"pattern": r"https://www\.twitter\.com/(\w+)",
@@ -169,21 +175,21 @@ def test_router_spec_validation(
},
[
{
- "from": "Tony",
- "to": "ernie",
+ "source": "Tony",
+ "target": "ernie",
"view_count": 123,
"dispatch_with": "test",
"type": "twitter-url",
},
{
- "from": "Tony",
- "to": "bobafett",
+ "source": "Tony",
+ "target": "bobafett",
"view_count": 123,
"dispatch_with": "test",
"type": "twitter-url",
},
],
- id="multiple values from regex with directive constant",
+ id="multiple values source regex with directive constant",
),
],
)
diff --git a/tests/test_spider.py b/tests/test_spider.py
index 98389b6..8fce165 100644
--- a/tests/test_spider.py
+++ b/tests/test_spider.py
@@ -5,6 +5,7 @@
It is loaded by click as soon as the application starts and the configuration
is loaded automatically by the initializer.
"""
+
# pylint: disable=E1101
from pathlib import Path
@@ -28,18 +29,15 @@ def test_load_config():
def test_spider():
"""Should instantiate a spider."""
- spider = Spider(auto_transitions=False)
+ spider = Spider(auto_transitions=True)
assert spider is not None
assert spider.is_idle()
spider.start(Path("tests/stubs/sevens_grader_random_test.pe.yml"))
- assert spider.is_starting()
assert spider.configuration is not None
-
- spider.gather()
- assert spider._cache_ is not None # pylint: disable=W0212
+ assert spider.is_stopped()
def test_spider_with_spikyball():
@@ -51,7 +49,7 @@ def test_spider_with_spikyball():
spider.start(Path("tests/stubs/sevens_grader_spikyball_test.pe.yml"))
- assert spider.is_stopping()
+ assert spider.is_stopped()
assert spider.configuration is not None
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d2f8eb9..e996df3 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -10,7 +10,7 @@
import pytest
-from spiderexpress.types import fromdict
+from spiderexpress.types import from_dict
from tests.conftest import MyFunkyTestClass, MyOtherFunkyTestClass
@@ -31,6 +31,6 @@
)
def test_fromdict(value: object):
"""test fromdict()"""
- ans = fromdict(type(value), asdict(value))
+ ans = from_dict(type(value), asdict(value))
print(ans)
assert ans == value