Skip to content

Commit

Permalink
new: fully support parameterized db object (#70)
Browse files Browse the repository at this point in the history
* new: fully support parameterized `db` object

* fix: `hosts`

* fix: docstring

* new: support `use_gpu` algorithm parameter

* new: `test_multiple_graph_sessions`
  • Loading branch information
aMahanna authored Jan 6, 2025
1 parent 9f59085 commit bd47753
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 75 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,16 @@ import nx_arangodb as nxadb

G = nxadb.Graph(name="MyGraph")

# Option 1: Use Global Config
nx.config.backends.arangodb.use_gpu = False

nx.pagerank(G)
nx.betweenness_centrality(G)
# ...

nx.config.backends.arangodb.use_gpu = True

# Option 2: Use Local Config
nx.pagerank(G, use_gpu=False)
nx.betweenness_centrality(G, use_gpu=False)
```

<p align="center">
Expand Down
10 changes: 1 addition & 9 deletions _nx_arangodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,7 @@ def get_info():
for key in info_keys:
del d[key]

d["default_config"] = {
"host": None,
"username": None,
"password": None,
"db_name": None,
"read_parallelism": None,
"read_batch_size": None,
"use_gpu": True,
}
d["default_config"] = {"use_gpu": True}

return d

Expand Down
7 changes: 5 additions & 2 deletions doc/algorithms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ You can also force-run algorithms on CPU even if ``nx-cugraph`` is installed:
G = nxadb.Graph(name="MyGraph")
# Option 1: Use Global Config
nx.config.backends.arangodb.use_gpu = False
nx.pagerank(G)
nx.betweenness_centrality(G)
# ...
nx.config.backends.arangodb.use_gpu = True
# Option 2: Use Local Config
nx.pagerank(G, use_gpu=False)
nx.betweenness_centrality(G, use_gpu=False)
.. image:: ../_static/dispatch.png
   :align: center
Expand Down
6 changes: 1 addition & 5 deletions doc/nx_arangodb.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@
"outputs": [],
"source": [
"# 5. Run an algorithm (CPU)\n",
"nx.config.backends.arangodb.use_gpu = False # Optional\n",
"\n",
"res = nx.pagerank(G)"
"res = nx.pagerank(G, use_gpu=False)"
]
},
{
Expand Down Expand Up @@ -357,8 +355,6 @@
"source": [
"# 4. Run an algorithm (GPU)\n",
"# See *Package Installation* to install nx-cugraph ^\n",
"nx.config.backends.arangodb.use_gpu = True\n",
"\n",
"res = nx.pagerank(G)"
]
},
Expand Down
16 changes: 16 additions & 0 deletions nx_arangodb/classes/dict/adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def adjlist_outer_dict_factory(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
graph_type: str,
Expand All @@ -115,6 +117,8 @@ def adjlist_outer_dict_factory(
db,
graph,
default_node_type,
read_parallelism,
read_batch_size,
edge_type_key,
edge_type_func,
graph_type,
Expand Down Expand Up @@ -1454,6 +1458,12 @@ class AdjListOuterDict(UserDict[str, AdjListInnerDict]):
symmetrize_edges_if_directed : bool
Whether to add the reverse edge if the graph is directed.
read_parallelism : int
The number of parallel threads to use for reading data in _fetch_all.
read_batch_size : int
The number of documents to read in each batch in _fetch_all.
Example
-------
>>> g = nxadb.Graph(name="MyGraph")
Expand All @@ -1467,6 +1477,8 @@ def __init__(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
graph_type: str,
Expand All @@ -1489,6 +1501,8 @@ def __init__(
self.edge_type_key = edge_type_key
self.edge_type_func = edge_type_func
self.default_node_type = default_node_type
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(
db,
graph,
Expand Down Expand Up @@ -1853,6 +1867,8 @@ def _fetch_all(self) -> None:
is_directed=True,
is_multigraph=self.is_multigraph,
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
read_parallelism=self.read_parallelism,
read_batch_size=self.read_batch_size,
)

# Even if the Graph is undirected,
Expand Down
27 changes: 25 additions & 2 deletions nx_arangodb/classes/dict/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,20 @@


def node_dict_factory(
db: StandardDatabase, graph: Graph, default_node_type: str
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
) -> Callable[..., NodeDict]:
"""Factory function for creating a NodeDict."""
return lambda: NodeDict(db, graph, default_node_type)
return lambda: NodeDict(
db,
graph,
default_node_type,
read_parallelism,
read_batch_size,
)


def node_attr_dict_factory(
Expand Down Expand Up @@ -250,6 +260,12 @@ class NodeDict(UserDict[str, NodeAttrDict]):
default_node_type : str
The default node type for the graph.
read_parallelism : int
The number of parallel threads to use for reading data in _fetch_all.
read_batch_size : int
The number of documents to read in each batch in _fetch_all.
Example
-------
>>> G = nxadb.Graph("MyGraph")
Expand All @@ -262,6 +278,8 @@ def __init__(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
*args: Any,
**kwargs: Any,
):
Expand All @@ -271,6 +289,9 @@ def __init__(
self.db = db
self.graph = graph
self.default_node_type = default_node_type
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size

self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph)

self.FETCHED_ALL_DATA = False
Expand Down Expand Up @@ -472,6 +493,8 @@ def _fetch_all(self):
is_directed=False, # not used
is_multigraph=False, # not used
symmetrize_edges_if_directed=False, # not used
read_parallelism=self.read_parallelism,
read_batch_size=self.read_batch_size,
)

for node_id, node_data in node_dict.items():
Expand Down
23 changes: 12 additions & 11 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def get_arangodb_graph(
is_directed: bool,
is_multigraph: bool,
symmetrize_edges_if_directed: bool,
read_parallelism: int,
read_batch_size: int,
) -> Tuple[
NodeDict,
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
Expand Down Expand Up @@ -142,11 +144,10 @@ def get_arangodb_graph(
if not load_adj_dict and not load_coo:
metagraph["edgeCollections"] = {}

config = nx.config.backends.arangodb
assert config.db_name
assert config.host
assert config.username
assert config.password
hosts = adb_graph._conn._hosts
hosts = hosts.split(",") if type(hosts) is str else hosts
db_name = adb_graph._conn._db_name
username, password = adb_graph._conn._auth

(
node_dict,
Expand All @@ -157,20 +158,20 @@ def get_arangodb_graph(
vertex_ids_to_index,
edge_values,
) = NetworkXLoader.load_into_networkx(
config.db_name,
database=db_name,
metagraph=metagraph,
hosts=[config.host],
username=config.username,
password=config.password,
hosts=hosts,
username=username,
password=password,
load_adj_dict=load_adj_dict,
load_coo=load_coo,
load_all_vertex_attributes=load_all_vertex_attributes,
load_all_edge_attributes=load_all_edge_attributes,
is_directed=is_directed,
is_multigraph=is_multigraph,
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
parallelism=config.read_parallelism,
batch_size=config.read_batch_size,
parallelism=read_parallelism,
batch_size=read_batch_size,
)

return (
Expand Down
59 changes: 27 additions & 32 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,13 @@ def __init__(
self.use_nxcg_cache = True
self.nxcg_graph = None

self.edge_type_key = edge_type_key
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size

# Does not apply to undirected graphs
self.symmetrize_edges = symmetrize_edges

self.edge_type_key = edge_type_key

# TODO: Consider this
# if not self.__graph_name:
# if incoming_graph_data is not None:
Expand All @@ -227,8 +229,8 @@ def __init__(

self._loaded_incoming_graph_data = False
if self.graph_exists_in_db:
self._set_factory_methods()
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
self._set_factory_methods(read_parallelism, read_batch_size)
self.__set_arangodb_backend_config()

if overwrite_graph:
logger.info("Overwriting graph...")
Expand Down Expand Up @@ -284,7 +286,7 @@ def __init__(
# Init helper methods #
#######################

def _set_factory_methods(self) -> None:
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
"""Set the factory methods for the graph, _node, and _adj dictionaries.
The ArangoDB CRUD operations are handled by the modified dictionaries.
Expand All @@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None:
"""

base_args = (self.db, self.adb_graph)

node_args = (*base_args, self.default_node_type)
adj_args = (
*node_args,
self.edge_type_key,
self.edge_type_func,
self.__class__.__name__,
node_args_with_read = (*node_args, read_parallelism, read_batch_size)

adj_args = (self.edge_type_key, self.edge_type_func, self.__class__.__name__)
adj_inner_args = (*node_args, *adj_args)
adj_outer_args = (
*node_args_with_read,
*adj_args,
self.symmetrize_edges,
)

self.graph_attr_dict_factory = graph_dict_factory(*base_args)

self.node_dict_factory = node_dict_factory(*node_args)
self.node_dict_factory = node_dict_factory(*node_args_with_read)
self.node_attr_dict_factory = node_attr_dict_factory(*base_args)

self.edge_attr_dict_factory = edge_attr_dict_factory(*base_args)
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_args)
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(
*adj_args, self.symmetrize_edges
)

def __set_arangodb_backend_config(
self, read_parallelism: int, read_batch_size: int
) -> None:
if not all([self._host, self._username, self._password, self._db_name]):
m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
raise OSError(m)
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_inner_args)
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(*adj_outer_args)

def __set_arangodb_backend_config(self) -> None:
config = nx.config.backends.arangodb
config.host = self._host
config.username = self._username
config.password = self._password
config.db_name = self._db_name
config.read_parallelism = read_parallelism
config.read_batch_size = read_batch_size
config.use_gpu = True # Only used by default if nx-cugraph is available

def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None:
Expand All @@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
self._edge_collections_attributes.add("_id")

def __set_db(self, db: Any = None) -> None:
self._host = os.getenv("DATABASE_HOST")
self._hosts = os.getenv("DATABASE_HOST", "").split(",")
self._username = os.getenv("DATABASE_USERNAME")
self._password = os.getenv("DATABASE_PASSWORD")
self._db_name = os.getenv("DATABASE_NAME")
Expand All @@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None:
m = "arango.database.StandardDatabase"
raise TypeError(m)

db.version()
db.version() # make sure the connection is valid
self.__db = db
self._db_name = db.name
self._hosts = db._conn._hosts
self._username, self._password = db._conn._auth
return

if not all([self._host, self._username, self._password, self._db_name]):
if not all([self._hosts, self._username, self._password, self._db_name]):
m = "Database environment variables not set. Can't connect to the database"
logger.warning(m)
self.__db = None
return

self.__db = ArangoClient(hosts=self._host, request_timeout=None).db(
self.__db = ArangoClient(hosts=self._hosts, request_timeout=None).db(
self._db_name, self._username, self._password, verify=True
)

Expand Down
4 changes: 2 additions & 2 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def __init__(
# Init helper methods #
#######################

def _set_factory_methods(self) -> None:
super()._set_factory_methods()
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
super()._set_factory_methods(read_parallelism, read_batch_size)
self.edge_key_dict_factory = edge_key_dict_factory(
self.db,
self.adb_graph,
Expand Down
Loading

0 comments on commit bd47753

Please sign in to comment.