Skip to content

Commit

Permalink
new: test_multiple_graph_sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Jan 6, 2025
1 parent ed34107 commit fe554ae
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
20 changes: 14 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
import os
import sys
from io import StringIO
from typing import Any
from typing import Any, Dict

import networkx as nx
import pytest
Expand All @@ -15,6 +13,8 @@

logger.setLevel(logging.INFO)

con: Dict[str, Any]
client: ArangoClient
db: StandardDatabase
run_gpu_tests: bool

Expand All @@ -30,6 +30,7 @@ def pytest_addoption(parser: Any) -> None:


def pytest_configure(config: Any) -> None:
global con
con = {
"url": config.getoption("url"),
"username": config.getoption("username"),
Expand All @@ -43,10 +44,11 @@ def pytest_configure(config: Any) -> None:
print("Password: " + con["password"])
print("Database: " + con["dbName"])

global client
client = ArangoClient(hosts=con["url"])

global db
db = ArangoClient(hosts=con["url"]).db(
con["dbName"], con["username"], con["password"], verify=True
)
db = client.db(con["dbName"], con["username"], con["password"], verify=True)

print("Version: " + db.version())
print("----------------------------------------")
Expand Down Expand Up @@ -99,6 +101,12 @@ def load_two_relation_graph() -> None:
)


def get_db(db_name: str) -> StandardDatabase:
global con
global client
return client.db(db_name, con["username"], con["password"], verify=True)


def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
G = nx.Graph()
G.add_edge(1, 2, my_custom_weight=1)
Expand Down
35 changes: 34 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from nx_arangodb.classes.dict.graph import GRAPH_FIELD
from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict

from .conftest import create_grid_graph, create_line_graph, db, run_gpu_tests
from .conftest import create_grid_graph, create_line_graph, db, get_db, run_gpu_tests

G_NX: nx.Graph = nx.karate_club_graph()
G_NX_digraph = nx.DiGraph(G_NX)
Expand Down Expand Up @@ -88,6 +88,39 @@ def test_adb_graph_init(graph_cls: type[nxadb.Graph]) -> None:
G.name = "RenamedTestGraph"


def test_multiple_graph_sessions():
db_1_name = "test_db_1"
db_2_name = "test_db_2"

db.delete_database(db_1_name, ignore_missing=True)
db.delete_database(db_2_name, ignore_missing=True)

db.create_database(db_1_name)
db.create_database(db_2_name)

db_1 = get_db(db_1_name)
db_2 = get_db(db_2_name)

G_1 = nxadb.Graph(name="TestGraph", db=db_1)
G_2 = nxadb.Graph(name="TestGraph", db=db_2)

G_1.add_node(1, foo="bar")
G_1.add_node(2)
G_1.add_edge(1, 2)

G_2.add_node(1)
G_2.add_node(2)
G_2.add_node(3)
G_2.add_edge(1, 2)
G_2.add_edge(2, 3)

res_1 = nx.pagerank(G_1)
res_2 = nx.pagerank(G_2)

assert len(res_1) == 2
assert len(res_2) == 3


def test_load_graph_from_nxadb():
graph_name = "KarateGraph"

Expand Down

0 comments on commit fe554ae

Please sign in to comment.