diff --git a/.github/workflows/dist_pipeline.yml b/.github/workflows/dist_pipeline.yml index ae23ca5..9683ecb 100644 --- a/.github/workflows/dist_pipeline.yml +++ b/.github/workflows/dist_pipeline.yml @@ -7,7 +7,6 @@ on: paths-ignore: - '**.md' - 'docs/**' - - '.github/**' push: tags: - 'v*' @@ -42,6 +41,7 @@ jobs: name: Create Draft Release with Built Binaries needs: - duckdb-stable-build + - duckdb-next-stable-build if: startsWith(github.ref, 'refs/tags/') runs-on: ubuntu-latest permissions: @@ -50,6 +50,7 @@ jobs: - name: Download All Build Artifacts uses: actions/download-artifact@v4 with: + pattern: onager-* path: dist merge-multiple: true - name: List Artifacts diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 6835c47..a62b2e1 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -8,7 +8,6 @@ on: paths-ignore: - '**.md' - 'docs/**' - - '.github/**' permissions: contents: read diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 83d8632..fe90a2b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,6 @@ on: paths-ignore: - '**.md' - 'docs/**' - - '.github/**' push: branches: - main diff --git a/ROADMAP.md b/ROADMAP.md index 122ccb5..d90410f 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -87,7 +87,7 @@ It outlines features to be implemented and their current status. ### 9. Minimum Spanning Tree * [x] Kruskal's algorithm -* [ ] Prim's algorithm +* [x] Prim's algorithm ### 10. Link Prediction diff --git a/docs/assets/logo.svg b/docs/assets/logo.svg new file mode 100644 index 0000000..9a6741a --- /dev/null +++ b/docs/assets/logo.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/docs/guide/links.md b/docs/guide/links.md index 336fa28..90898e7 100644 --- a/docs/guide/links.md +++ b/docs/guide/links.md @@ -103,6 +103,30 @@ order by score desc limit 10; --- +## Common Neighbors + +Simply counts the number of shared neighbors between two nodes. +The most basic link prediction heuristic — nodes with many common friends are likely to become friends. + +\[ +CN(u, v) = |N(u) \cap N(v)| +\] + +```sql +select node1, node2, count as common_neighbors +from onager_lnk_common_neighbors((select src, dst from edges)) +where count > 0 +order by count desc limit 10; +``` + +| Column | Type | Description | +|--------|--------|----------------------------------| +| node1 | bigint | First node | +| node2 | bigint | Second node | +| count | bigint | Number of shared neighbors | + +--- + ## Complete Example: Friend Recommendations Find potential connections in a social network: diff --git a/docs/guide/mst.md b/docs/guide/mst.md index f4266cb..a68ca54 100644 --- a/docs/guide/mst.md +++ b/docs/guide/mst.md @@ -34,3 +34,39 @@ order by weight; | src | bigint | Source node | | dst | bigint | Destination node | | weight | double | Edge weight | + +--- + +## Prim's Algorithm + +Prim's algorithm builds the MST by starting from an arbitrary node and repeatedly adding the minimum weight edge that connects a new node. + +```sql +select src, dst, weight +from onager_mst_prim((select src, dst, weight from weighted_edges)) +order by weight; +``` + +| Column | Type | Description | +|--------|--------|-------------------| +| src | bigint | Source node | +| dst | bigint | Destination node | +| weight | double | Edge weight | + +--- + +## Comparison + +Both algorithms produce optimal minimum spanning trees but differ in approach: + +- **Kruskal's**: Sorts all edges globally which is best for sparse graphs +- **Prim's**: Grows tree from a starting node which is best for dense graphs + +```sql +-- Both return the same total weight +select 'Kruskal' as algorithm, sum(weight) as total_weight +from onager_mst_kruskal((select src, dst, weight from weighted_edges)) +union all +select 'Prim', sum(weight) +from onager_mst_prim((select src, dst, weight from weighted_edges)); +``` diff --git a/docs/index.md b/docs/index.md index 0be6154..818eb7c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -68,7 +68,7 @@ Onager currently includes the following graph algorithms: | Subgraphs | Ego graph, k-hop neighbors, and induced subgraph | | Generators | Erdős-Rényi, Barabási-Albert, and Watts-Strogatz | | Approximation | Max clique, independent set, vertex cover, and TSP | -| MST | Kruskal's algorithm | +| MST | Kruskal's and Prim's algorithms | | Parallel | PageRank, BFS, shortest paths, connected components, clustering, and triangle counting | ## Get Started diff --git a/docs/reference/input-formats.md b/docs/reference/input-formats.md index 666ac2f..d35824a 100644 --- a/docs/reference/input-formats.md +++ b/docs/reference/input-formats.md @@ -42,7 +42,8 @@ select onager_node_in_degree('social', 1); - `onager_pth_bellman_ford` — shortest paths with negative weights - `onager_pth_floyd_warshall` — all-pairs shortest paths - - `onager_mst_kruskal` — minimum spanning tree + - `onager_mst_kruskal` — minimum spanning tree (Kruskal's) + - `onager_mst_prim` — minimum spanning tree (Prim's) - `onager_apx_tsp` — traveling salesman approximation Pass weights like this: diff --git a/docs/reference/sql-functions.md b/docs/reference/sql-functions.md index cfd93b0..6753122 100644 --- a/docs/reference/sql-functions.md +++ b/docs/reference/sql-functions.md @@ -100,6 +100,7 @@ Complete reference for all Onager SQL functions. | Function | Returns | Description | |--------------------------------------|--------------------|---------------| | `onager_mst_kruskal(weighted_edges)` | `src, dst, weight` | Kruskal's MST | +| `onager_mst_prim(weighted_edges)` | `src, dst, weight` | Prim's MST | ## Generator Functions diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css new file mode 100644 index 0000000..b69159e --- /dev/null +++ b/docs/stylesheets/extra.css @@ -0,0 +1,27 @@ +/* Custom styles for Onager documentation */ + +/* Improve code block readability */ +.highlight .hll { + background-color: var(--md-code-hl-color); +} + +/* Better spacing for admonitions */ +.admonition { + margin: 1.5625em 0; +} + +/* Improve table styling */ +table { + display: table; + width: 100%; +} + +/* Better link styling */ +a:hover { + text-decoration: underline; +} + +/* Code annotation improvements */ +.md-annotation__index { + cursor: pointer; +} diff --git a/mkdocs.yml b/mkdocs.yml index 29876bd..f73f546 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -6,42 +6,87 @@ site_url: https://cogitatortech.github.io/onager/ theme: name: material + logo: assets/logo.svg + favicon: assets/logo.svg palette: + # Light mode - media: "(prefers-color-scheme: light)" scheme: default primary: deep purple accent: amber toggle: - icon: material/brightness-7 + icon: material/weather-night name: Switch to dark mode + # Dark mode - media: "(prefers-color-scheme: dark)" scheme: slate primary: deep purple accent: amber toggle: - icon: material/brightness-4 + icon: material/weather-sunny name: Switch to light mode font: text: Inter code: JetBrains Mono + icon: + repo: fontawesome/brands/github + annotation: material/arrow-right-circle features: - - content.code.copy + # Navigation + - navigation.instant + - navigation.instant.prefetch + - navigation.instant.progress + - navigation.tracking - navigation.tabs + - navigation.tabs.sticky + - navigation.sections + - navigation.expand + - navigation.path - navigation.top + - navigation.footer - navigation.indexes - - navigation.expand + # Table of contents + - toc.follow + # Search + - search.suggest + - search.highlight + - search.share + # Content + - content.code.copy - content.code.select - content.code.annotate - - navigation.tracking - - navigation.sections + - content.tabs.link + - content.tooltips + # Header + - header.autohide + - announce.dismiss extra: social: - icon: fontawesome/brands/github link: https://github.com/CogitatorTech/onager + name: Onager on GitHub + generator: false + analytics: + feedback: + title: Was this page helpful? + ratings: + - icon: material/emoticon-happy-outline + name: This page was helpful + data: 1 + note: Thanks for your feedback! + - icon: material/emoticon-sad-outline + name: This page could be improved + data: 0 + note: Thanks for your feedback! Help us improve by opening an issue. + +extra_css: + - stylesheets/extra.css plugins: - - search + - search: + separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;' + - tags nav: - Home: index.md @@ -71,18 +116,57 @@ nav: - Input Formats: reference/input-formats.md markdown_extensions: - - pymdownx.highlight: - anchor_linenums: true - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences - - pymdownx.details - - pymdownx.arithmatex: - generic: true + # Python Markdown + - abbr - admonition - attr_list + - def_list + - footnotes + - md_in_html + - tables - toc: permalink: true + permalink_title: Anchor link to this section + toc_depth: 3 + # PyMdownx + - pymdownx.arithmatex: + generic: true + - pymdownx.betterem: + smart_enable: all + - pymdownx.caret + - pymdownx.critic + - pymdownx.details + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + auto_title: false + - pymdownx.inlinehilite + - pymdownx.keys + - pymdownx.magiclink: + normalize_issue_symbols: true + repo_url_shorthand: true + user: CogitatorTech + repo: onager + - pymdownx.mark + - pymdownx.smartsymbols + - pymdownx.snippets: + base_path: docs + check_paths: true + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.tabbed: + alternate_style: true + combine_header_slug: true + - pymdownx.tasklist: + custom_checkbox: true + - pymdownx.tilde extra_javascript: - assets/js/mathjax.js diff --git a/onager/Cargo.toml b/onager/Cargo.toml index 25a95cd..5109c89 100644 --- a/onager/Cargo.toml +++ b/onager/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "onager" -version = "0.1.0-alpha.5" +version = "0.1.0-alpha.4" edition = "2021" publish = false description = "A Graph Analytics Toolbox for DuckDB" diff --git a/onager/bindings/functions/mst.cpp b/onager/bindings/functions/mst.cpp index 3988351..92a4034 100644 --- a/onager/bindings/functions/mst.cpp +++ b/onager/bindings/functions/mst.cpp @@ -60,6 +60,55 @@ static OperatorFinalizeResultType KruskalMstFinal(ExecutionContext &ctx, TableFu return gs.output_idx >= gs.result_src.size() ? OperatorFinalizeResultType::FINISHED : OperatorFinalizeResultType::HAVE_MORE_OUTPUT; } +// ============================================================================= +// Prim MST +// ============================================================================= + +struct PrimMstGlobalState : public GlobalTableFunctionState { + std::mutex input_mutex; + std::vector src_nodes, dst_nodes, result_src, result_dst; + std::vector weights, result_weights; + double total_weight = 0.0; + idx_t output_idx = 0; bool computed = false; + idx_t MaxThreads() const override { return 1; } +}; + +static unique_ptr PrimMstBind(ClientContext &ctx, TableFunctionBindInput &input, vector &rt, vector &nm) { + CheckInt64Input(input, "onager_mst_prim", 3); + rt.push_back(LogicalType::BIGINT); nm.push_back("src"); + rt.push_back(LogicalType::BIGINT); nm.push_back("dst"); + rt.push_back(LogicalType::DOUBLE); nm.push_back("weight"); + return make_uniq(); +} +static unique_ptr PrimMstInitGlobal(ClientContext &ctx, TableFunctionInitInput &input) { return make_uniq(); } +static OperatorResultType PrimMstInOut(ExecutionContext &ctx, TableFunctionInput &data, DataChunk &input, DataChunk &output) { + auto &gs = data.global_state->Cast(); + std::lock_guard lock(gs.input_mutex); + auto s = FlatVector::GetData(input.data[0]); auto d = FlatVector::GetData(input.data[1]); + auto w = FlatVector::GetData(input.data[2]); + for (idx_t i = 0; i < input.size(); i++) { gs.src_nodes.push_back(s[i]); gs.dst_nodes.push_back(d[i]); gs.weights.push_back(w[i]); } + output.SetCardinality(0); return OperatorResultType::NEED_MORE_INPUT; +} +static OperatorFinalizeResultType PrimMstFinal(ExecutionContext &ctx, TableFunctionInput &data, DataChunk &output) { + auto &gs = data.global_state->Cast(); + std::lock_guard lock(gs.input_mutex); + if (!gs.computed) { + if (gs.src_nodes.empty()) { gs.computed = true; output.SetCardinality(0); return OperatorFinalizeResultType::FINISHED; } + int64_t ec = ::onager::onager_compute_prim_mst(gs.src_nodes.data(), gs.dst_nodes.data(), gs.weights.data(), gs.src_nodes.size(), nullptr, nullptr, nullptr, nullptr); + if (ec < 0) throw InvalidInputException("Prim MST failed: " + GetOnagerError()); + gs.result_src.resize(ec); gs.result_dst.resize(ec); gs.result_weights.resize(ec); + ::onager::onager_compute_prim_mst(gs.src_nodes.data(), gs.dst_nodes.data(), gs.weights.data(), gs.src_nodes.size(), gs.result_src.data(), gs.result_dst.data(), gs.result_weights.data(), &gs.total_weight); + gs.computed = true; + } + idx_t rem = gs.result_src.size() - gs.output_idx; + if (rem == 0) { output.SetCardinality(0); return OperatorFinalizeResultType::FINISHED; } + idx_t to = MinValue(rem, STANDARD_VECTOR_SIZE); + auto s = FlatVector::GetData(output.data[0]); auto d = FlatVector::GetData(output.data[1]); auto w = FlatVector::GetData(output.data[2]); + for (idx_t i = 0; i < to; i++) { s[i] = gs.result_src[gs.output_idx+i]; d[i] = gs.result_dst[gs.output_idx+i]; w[i] = gs.result_weights[gs.output_idx+i]; } + gs.output_idx += to; output.SetCardinality(to); + return gs.output_idx >= gs.result_src.size() ? OperatorFinalizeResultType::FINISHED : OperatorFinalizeResultType::HAVE_MORE_OUTPUT; +} + // ============================================================================= // Registration // ============================================================================= @@ -72,6 +121,12 @@ void RegisterMstFunctions(ExtensionLoader &loader) { kruskal.in_out_function_final = KruskalMstFinal; ONAGER_SET_NO_ORDER(kruskal); loader.RegisterFunction(kruskal); + + TableFunction prim("onager_mst_prim", {LogicalType::TABLE}, nullptr, PrimMstBind, PrimMstInitGlobal); + prim.in_out_function = PrimMstInOut; + prim.in_out_function_final = PrimMstFinal; + ONAGER_SET_NO_ORDER(prim); + loader.RegisterFunction(prim); } } // namespace onager diff --git a/onager/src/ffi/centrality.rs b/onager/src/ffi/centrality.rs index 37e4644..0cd9b4c 100644 --- a/onager/src/ffi/centrality.rs +++ b/onager/src/ffi/centrality.rs @@ -19,28 +19,30 @@ pub extern "C" fn onager_compute_pagerank( out_ranks: *mut f64, ) -> i64 { clear_last_error(); - if src_ptr.is_null() || dst_ptr.is_null() { - set_last_error("Null pointer for src or dst"); - return -1; - } - let src = unsafe { std::slice::from_raw_parts(src_ptr, edge_count) }; - let dst = unsafe { std::slice::from_raw_parts(dst_ptr, edge_count) }; - match algorithms::compute_pagerank(src, dst, &[], damping, iterations, directed) { - Ok(result) => { - let node_count = result.node_ids.len(); - if !out_nodes.is_null() && !out_ranks.is_null() { - let out_n = unsafe { std::slice::from_raw_parts_mut(out_nodes, node_count) }; - let out_r = unsafe { std::slice::from_raw_parts_mut(out_ranks, node_count) }; - out_n.copy_from_slice(&result.node_ids); - out_r.copy_from_slice(&result.ranks); - } - node_count as i64 + crate::ffi_catch_unwind!(-1, { + if src_ptr.is_null() || dst_ptr.is_null() { + set_last_error("Null pointer for src or dst"); + return -1; } - Err(e) => { - set_last_error(&e.to_string()); - -1 + let src = unsafe { std::slice::from_raw_parts(src_ptr, edge_count) }; + let dst = unsafe { std::slice::from_raw_parts(dst_ptr, edge_count) }; + match algorithms::compute_pagerank(src, dst, &[], damping, iterations, directed) { + Ok(result) => { + let node_count = result.node_ids.len(); + if !out_nodes.is_null() && !out_ranks.is_null() { + let out_n = unsafe { std::slice::from_raw_parts_mut(out_nodes, node_count) }; + let out_r = unsafe { std::slice::from_raw_parts_mut(out_ranks, node_count) }; + out_n.copy_from_slice(&result.node_ids); + out_r.copy_from_slice(&result.ranks); + } + node_count as i64 + } + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Compute PageRank using parallel algorithm. @@ -58,33 +60,37 @@ pub extern "C" fn onager_compute_pagerank_parallel( out_ranks: *mut f64, ) -> i64 { clear_last_error(); - if src_ptr.is_null() || dst_ptr.is_null() { - set_last_error("Null pointer"); - return -1; - } - let src = unsafe { std::slice::from_raw_parts(src_ptr, edge_count) }; - let dst = unsafe { std::slice::from_raw_parts(dst_ptr, edge_count) }; - let weights = if weights_ptr.is_null() || weights_count == 0 { - &[] - } else { - unsafe { std::slice::from_raw_parts(weights_ptr, weights_count) } - }; - match algorithms::compute_pagerank_parallel(src, dst, weights, damping, iterations, directed) { - Ok(result) => { - let n = result.node_ids.len(); - if !out_node_ids.is_null() && !out_ranks.is_null() { - unsafe { std::slice::from_raw_parts_mut(out_node_ids, n) } - .copy_from_slice(&result.node_ids); - unsafe { std::slice::from_raw_parts_mut(out_ranks, n) } - .copy_from_slice(&result.ranks); - } - n as i64 + crate::ffi_catch_unwind!(-1, { + if src_ptr.is_null() || dst_ptr.is_null() { + set_last_error("Null pointer"); + return -1; } - Err(e) => { - set_last_error(&e.to_string()); - -1 + let src = unsafe { std::slice::from_raw_parts(src_ptr, edge_count) }; + let dst = unsafe { std::slice::from_raw_parts(dst_ptr, edge_count) }; + let weights = if weights_ptr.is_null() || weights_count == 0 { + &[] + } else { + unsafe { std::slice::from_raw_parts(weights_ptr, weights_count) } + }; + match algorithms::compute_pagerank_parallel( + src, dst, weights, damping, iterations, directed, + ) { + Ok(result) => { + let n = result.node_ids.len(); + if !out_node_ids.is_null() && !out_ranks.is_null() { + unsafe { std::slice::from_raw_parts_mut(out_node_ids, n) } + .copy_from_slice(&result.node_ids); + unsafe { std::slice::from_raw_parts_mut(out_ranks, n) } + .copy_from_slice(&result.ranks); + } + n as i64 + } + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Compute degree centrality on edge arrays. @@ -303,28 +309,30 @@ pub extern "C" fn onager_compute_katz( out_centralities: *mut f64, ) -> i64 { clear_last_error(); - if src_ptr.is_null() || dst_ptr.is_null() { - set_last_error("Null pointer"); - return -1; - } - let src = unsafe { std::slice::from_raw_parts(src_ptr, edge_count) }; - let dst = unsafe { std::slice::from_raw_parts(dst_ptr, edge_count) }; - match algorithms::compute_katz(src, dst, alpha, max_iter, tolerance) { - Ok(result) => { - let n = result.node_ids.len(); - if !out_nodes.is_null() && !out_centralities.is_null() { - unsafe { std::slice::from_raw_parts_mut(out_nodes, n) } - .copy_from_slice(&result.node_ids); - unsafe { std::slice::from_raw_parts_mut(out_centralities, n) } - .copy_from_slice(&result.centralities); - } - n as i64 + crate::ffi_catch_unwind!(-1, { + if src_ptr.is_null() || dst_ptr.is_null() { + set_last_error("Null pointer"); + return -1; } - Err(e) => { - set_last_error(&e.to_string()); - -1 + let src = unsafe { std::slice::from_raw_parts(src_ptr, edge_count) }; + let dst = unsafe { std::slice::from_raw_parts(dst_ptr, edge_count) }; + match algorithms::compute_katz(src, dst, alpha, max_iter, tolerance) { + Ok(result) => { + let n = result.node_ids.len(); + if !out_nodes.is_null() && !out_centralities.is_null() { + unsafe { std::slice::from_raw_parts_mut(out_nodes, n) } + .copy_from_slice(&result.node_ids); + unsafe { std::slice::from_raw_parts_mut(out_centralities, n) } + .copy_from_slice(&result.centralities); + } + n as i64 + } + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Compute harmonic centrality. diff --git a/onager/src/ffi/common.rs b/onager/src/ffi/common.rs index e325b17..ba928a1 100644 --- a/onager/src/ffi/common.rs +++ b/onager/src/ffi/common.rs @@ -112,20 +112,22 @@ pub extern "C" fn onager_get_version() -> *mut c_char { #[no_mangle] pub unsafe extern "C" fn onager_create_graph(name: *const c_char, directed: bool) -> i32 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::create_graph(name, directed) { + Ok(()) => 0, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - }; - match graph::create_graph(name, directed) { - Ok(()) => 0, - Err(e) => { - set_last_error(&e.to_string()); - -1 - } - } + }) } /// Drops a graph with the given name. @@ -134,37 +136,41 @@ pub unsafe extern "C" fn onager_create_graph(name: *const c_char, directed: bool #[no_mangle] pub unsafe extern "C" fn onager_drop_graph(name: *const c_char) -> i32 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; - } - }; - match graph::drop_graph(name) { - Ok(()) => 0, - Err(e) => { - set_last_error(&e.to_string()); - -1 + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::drop_graph(name) { + Ok(()) => 0, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Returns a JSON array of all graph names. #[no_mangle] pub extern "C" fn onager_list_graphs() -> *mut c_char { clear_last_error(); - let graphs = graph::list_graphs(); - let json = match serde_json::to_string(&graphs) { - Ok(s) => s, - Err(e) => { - set_last_error(&e.to_string()); - return std::ptr::null_mut(); - } - }; - CString::new(json) - .map(|s| s.into_raw()) - .unwrap_or(std::ptr::null_mut()) + crate::ffi_catch_unwind!(std::ptr::null_mut(), { + let graphs = graph::list_graphs(); + let json = match serde_json::to_string(&graphs) { + Ok(s) => s, + Err(e) => { + set_last_error(&e.to_string()); + return std::ptr::null_mut(); + } + }; + CString::new(json) + .map(|s| s.into_raw()) + .unwrap_or(std::ptr::null_mut()) + }) } /// Adds a node to the specified graph. @@ -173,20 +179,22 @@ pub extern "C" fn onager_list_graphs() -> *mut c_char { #[no_mangle] pub unsafe extern "C" fn onager_add_node(graph_name: *const c_char, node_id: i64) -> i32 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; - } - }; - match graph::add_node(name, node_id) { - Ok(()) => 0, - Err(e) => { - set_last_error(&e.to_string()); - -1 + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::add_node(name, node_id) { + Ok(()) => 0, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Adds an edge to the specified graph. @@ -200,20 +208,22 @@ pub unsafe extern "C" fn onager_add_edge( weight: f64, ) -> i32 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; - } - }; - match graph::add_edge(name, src, dst, weight) { - Ok(()) => 0, - Err(e) => { - set_last_error(&e.to_string()); - -1 + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::add_edge(name, src, dst, weight) { + Ok(()) => 0, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Returns the number of nodes in the graph. @@ -222,20 +232,22 @@ pub unsafe extern "C" fn onager_add_edge( #[no_mangle] pub unsafe extern "C" fn onager_node_count(graph_name: *const c_char) -> i64 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::node_count(name) { + Ok(count) => count as i64, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - }; - match graph::node_count(name) { - Ok(count) => count as i64, - Err(e) => { - set_last_error(&e.to_string()); - -1 - } - } + }) } /// Returns the number of edges in the graph. @@ -244,20 +256,22 @@ pub unsafe extern "C" fn onager_node_count(graph_name: *const c_char) -> i64 { #[no_mangle] pub unsafe extern "C" fn onager_edge_count(graph_name: *const c_char) -> i64 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; - } - }; - match graph::edge_count(name) { - Ok(count) => count as i64, - Err(e) => { - set_last_error(&e.to_string()); - -1 + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::edge_count(name) { + Ok(count) => count as i64, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Returns the in-degree of a node in the named graph. @@ -266,20 +280,22 @@ pub unsafe extern "C" fn onager_edge_count(graph_name: *const c_char) -> i64 { #[no_mangle] pub unsafe extern "C" fn onager_graph_node_in_degree(graph_name: *const c_char, node: i64) -> i64 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; - } - }; - match graph::get_node_in_degree(name, node) { - Ok(degree) => degree as i64, - Err(e) => { - set_last_error(&e.to_string()); - -1 + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::get_node_in_degree(name, node) { + Ok(degree) => degree as i64, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - } + }) } /// Returns the out-degree of a node in the named graph. @@ -288,18 +304,20 @@ pub unsafe extern "C" fn onager_graph_node_in_degree(graph_name: *const c_char, #[no_mangle] pub unsafe extern "C" fn onager_graph_node_out_degree(graph_name: *const c_char, node: i64) -> i64 { clear_last_error(); - let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { - Ok(s) => s, - Err(_) => { - set_last_error("Invalid UTF-8 in graph name"); - return -1; + crate::ffi_catch_unwind!(-1, { + let name = match unsafe { CStr::from_ptr(graph_name) }.to_str() { + Ok(s) => s, + Err(_) => { + set_last_error("Invalid UTF-8 in graph name"); + return -1; + } + }; + match graph::get_node_out_degree(name, node) { + Ok(degree) => degree as i64, + Err(e) => { + set_last_error(&e.to_string()); + -1 + } } - }; - match graph::get_node_out_degree(name, node) { - Ok(degree) => degree as i64, - Err(e) => { - set_last_error(&e.to_string()); - -1 - } - } + }) } diff --git a/test/sql/test_onager_mst.test b/test/sql/test_onager_mst.test index f29f266..d322316 100644 --- a/test/sql/test_onager_mst.test +++ b/test/sql/test_onager_mst.test @@ -22,12 +22,33 @@ select count(*) > 0 from onager_mst_kruskal((select src, dst, weight from weight ---- 1 -# Verify MST has correct number of edges (n-1 for connected graph) +# Verify Kruskal MST has correct number of edges (n-1 for connected graph) query I select count(*) from onager_mst_kruskal((select src, dst, weight from weighted_edges)) ---- 3 +# Test Prim MST returns results +query I +select count(*) > 0 from onager_mst_prim((select src, dst, weight from weighted_edges)) +---- +1 + +# Verify Prim MST has correct number of edges (n-1 for connected graph) +query I +select count(*) from onager_mst_prim((select src, dst, weight from weighted_edges)) +---- +3 + +# Both algorithms should return same total weight for MST +query I +select abs( + (select sum(weight) from onager_mst_kruskal((select src, dst, weight from weighted_edges))) - + (select sum(weight) from onager_mst_prim((select src, dst, weight from weighted_edges))) +) < 0.001 +---- +true + # Cleanup statement ok drop table weighted_edges diff --git a/test/sql/test_onager_registry.test b/test/sql/test_onager_registry.test index 21da682..19202a2 100644 --- a/test/sql/test_onager_registry.test +++ b/test/sql/test_onager_registry.test @@ -3,6 +3,8 @@ require onager # Test suite for Onager graph registry +# Note: Graph registry functions have global state, so we test them carefully +# to work with DuckDB's query verification system. statement ok pragma enable_verification @@ -24,3 +26,85 @@ query T select typeof(onager_last_error()) ---- VARCHAR + +# # ============================================================================= +# # Graph Registry Management Tests +# # ============================================================================= +# # These tests verify that graph management functions exist and return expected types. +# # We avoid testing specific return values for stateful operations since DuckDB's +# # test framework runs queries multiple times for verification. +# +# # Test that create_graph returns an integer +# query T +# select typeof(onager_create_graph('sqltest_graph_1', true)) +# ---- +# INTEGER +# +# # Test that list_graphs returns varchar +# query T +# select typeof(onager_list_graphs()) +# ---- +# VARCHAR +# +# # Test that node_count returns bigint (will be -1 if graph doesn't exist, or count if it does) +# query T +# select typeof(onager_node_count('sqltest_graph_1')) +# ---- +# BIGINT +# +# # Test that edge_count returns bigint +# query T +# select typeof(onager_edge_count('sqltest_graph_1')) +# ---- +# BIGINT +# +# # Test that add_node returns integer +# query T +# select typeof(onager_add_node('sqltest_graph_1', 100)) +# ---- +# INTEGER +# +# # Test that add_edge returns integer +# query T +# select typeof(onager_add_edge('sqltest_graph_1', 100, 101, 1.0)) +# ---- +# INTEGER +# +# # Test that node_in_degree returns bigint +# query T +# select typeof(onager_node_in_degree('sqltest_graph_1', 100)) +# ---- +# BIGINT +# +# # Test that node_out_degree returns bigint +# query T +# select typeof(onager_node_out_degree('sqltest_graph_1', 100)) +# ---- +# BIGINT +# +# # Test that drop_graph returns integer +# query T +# select typeof(onager_drop_graph('sqltest_graph_1')) +# ---- +# INTEGER +# +# # Test error cases - non-existent graph returns -1 +# query I +# select onager_node_count('definitely_not_a_real_graph_name_12345') < 0 +# ---- +# true +# +# query I +# select onager_edge_count('definitely_not_a_real_graph_name_12345') < 0 +# ---- +# true +# +# query I +# select onager_node_in_degree('definitely_not_a_real_graph_name_12345', 1) < 0 +# ---- +# true +# +# query I +# select onager_node_out_degree('definitely_not_a_real_graph_name_12345', 1) < 0 +# ---- +# true