Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 19 additions & 28 deletions .github/workflows/continuous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,39 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- name: Install Rust stable
uses: actions-rs/toolchain@v1
- name: Install rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
profile: minimal
components: rustfmt
toolchain: stable
components: cargo, rustfmt
- name: Run cargo fmt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
run: |
cargo fmt --all -- --check --color always

clippy:
name: Analyzing code with Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- name: Install Rust stable
uses: actions-rs/toolchain@v1
- name: Install rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
profile: minimal
components: clippy
toolchain: stable
components: cargo, clippy
- name: Run cargo clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --workspace -- -D warnings
run: |
cargo clippy --all-targets --all-features --workspace -- -D warnings

tests:
name: Tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- name: Install Rust stable
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
override: true
- name: Run tests
uses: actions-rs/cargo@v1
- name: Install rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
command: test
args: --verbose --workspace
toolchain: stable
components: cargo
- name: Run cargo test
run: |
cargo test --verbose --workspace
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[package]
name = "relational_types"
description = "Manage relations between objects"
version = "2.1.0"
authors = ["Hove <team.coretools@kisio.org>", "Guillaume Pinot <texitoi@texitoi.eu>"]
edition = "2018"
version = "2.1.1"
authors = ["Hove <core@hove.com>", "Guillaume Pinot <texitoi@texitoi.eu>"]
edition = "2021"
license = "MIT"
homepage = "https://github.com/hove-io/relational_types"
repository = "https://github.com/hove-io/relational_types"
Expand All @@ -19,9 +19,9 @@ members = [
]

[dependencies]
derivative = "1"
derivative = "2"
relational_types_procmacro = { version = "2", path = "./relational_types_procmacro/", optional = true }
thiserror = "1"
thiserror = "2"
typed_index_collection = { git = "https://github.com/hove-io/typed_index_collection", tag = "v2"}

[features]
Expand Down
5 changes: 3 additions & 2 deletions relational_types_procmacro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ keywords = ["macro", "floyd_marshall"]
proc-macro = true

[dependencies]
syn = "0.11.11"
quote = "0.3.15"
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0.94"
134 changes: 65 additions & 69 deletions relational_types_procmacro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
#![deny(missing_docs)]

//! Custom derive for GetCorresponding. See `relational_types` for the documentation.
//! Custom derive for GetCorresponding. See `relational_types` for the documentation.

#![recursion_limit = "128"]

extern crate proc_macro;
use quote::*;

use proc_macro::TokenStream;
use quote::quote;
use std::collections::{HashMap, HashSet};
use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Expr, Field, Fields, Ident, Lit, Meta,
PathArguments, Type,
};

/// Generation of the `GetCorresponding` trait implementation.
#[proc_macro_derive(GetCorresponding, attributes(get_corresponding))]
pub fn get_corresponding(input: TokenStream) -> TokenStream {
let s = input.to_string();
let ast = syn::parse_derive_input(&s).unwrap();
let ast = parse_macro_input!(input as DeriveInput);
let gen = impl_get_corresponding(&ast);
gen.parse().unwrap()
gen.into()
}

fn impl_get_corresponding(ast: &syn::DeriveInput) -> quote::Tokens {
if let syn::Body::Struct(syn::VariantData::Struct(ref fields)) = ast.body {
fn impl_get_corresponding(ast: &DeriveInput) -> proc_macro2::TokenStream {
if let Data::Struct(DataStruct {
fields: Fields::Named(ref fields),
..
}) = ast.data
{
let name = &ast.ident;
let edges: Vec<_> = fields.iter().filter_map(to_edge).collect();
let edges: Vec<_> = fields.named.iter().filter_map(to_edge).collect();
let next = floyd_warshall(&edges);
let edge_to_impl = make_edge_to_get_corresponding(name, &edges);

let edges_impls = next.iter().map(|(&(from, to), &node)| {
if from == to {
quote! {
Expand All @@ -47,23 +54,18 @@ fn impl_get_corresponding(ast: &syn::DeriveInput) -> quote::Tokens {
}
}
});

quote! {
/// A trait that returns a set of objects corresponding to
/// a given type.
pub trait GetCorresponding<T: Sized> {
/// For the given self, returns the set of
/// corresponding `T` indices.
fn get_corresponding(&self, model: &#name) -> IdxSet<T>;
}
impl #name {
/// Returns the set of `U` indices corresponding to the `from` set.
pub fn get_corresponding<T, U>(&self, from: &IdxSet<T>) -> IdxSet<U>
where
IdxSet<T>: GetCorresponding<U>
{
from.get_corresponding(self)
}
/// Returns the set of `U` indices corresponding to the `from` index.
pub fn get_corresponding_from_idx<T, U>(&self, from: Idx<T>) -> IdxSet<U>
where
IdxSet<T>: GetCorresponding<U>
Expand All @@ -78,65 +80,66 @@ fn impl_get_corresponding(ast: &syn::DeriveInput) -> quote::Tokens {
}
}

fn to_edge(field: &syn::Field) -> Option<Edge> {
use syn::MetaItem::*;
use syn::NestedMetaItem::MetaItem;
use syn::PathParameters::AngleBracketed;

let ident = field.ident.as_ref()?.as_ref();
fn to_edge(field: &Field) -> Option<Edge> {
let ident = field.ident.as_ref()?.to_string();
let mut split = ident.split("_to_");
let _from_collection = split.next()?;
let _to_collection = split.next()?;
if split.next().is_some() {
return None;
}
let segment = if let syn::Ty::Path(_, ref path) = field.ty {
path.segments.last()
} else {
None

let segment = match &field.ty {
Type::Path(type_path) => type_path.path.segments.last(),
_ => None,
}?;
let (from_ty, to_ty) = if let AngleBracketed(ref data) = segment.parameters {
match (data.types.get(0), data.types.get(1), data.types.get(2)) {
(Some(from_ty), Some(to_ty), None) => Some((from_ty, to_ty)),

let (from_ty, to_ty) = if let PathArguments::AngleBracketed(data) = &segment.arguments {
match (data.args.first(), data.args.get(1), data.args.get(2)) {
(
Some(syn::GenericArgument::Type(from_ty)),
Some(syn::GenericArgument::Type(to_ty)),
None,
) => Some((from_ty, to_ty)),
_ => None,
}
} else {
None
}?;

let weight = field
.attrs
.iter()
.flat_map(|attr| match attr.value {
List(ref i, ref v) if i == "get_corresponding" => v.as_slice(),
_ => &[],
})
.map(|mi| match *mi {
MetaItem(NameValue(ref i, syn::Lit::Str(ref l, _))) => {
assert_eq!(i, "weight", "{} is not a valid attribute", i);
l.parse::<f64>()
.expect("`weight` attribute must be convertible to f64")
.filter_map(|attr| {
if let Meta::NameValue(meta) = &attr.meta {
if meta.path.is_ident("weight") {
if let Expr::Lit(expr_lit) = &meta.value {
if let Lit::Str(lit_str) = &expr_lit.lit {
return lit_str.value().parse::<f64>().ok();
}
}
}
}
_ => panic!("Only `key = \"value\"` attributes supported."),
None
})
.last()
.unwrap_or(1.);
.unwrap_or(1.0);

Edge {
ident: ident.into(),
from: from_ty.clone(),
to: to_ty.clone(),
Some(Edge {
ident,
from: (*from_ty).clone(),
to: (*to_ty).clone(),
weight,
}
.into()
})
}

fn make_edge_to_get_corresponding<'a>(
name: &syn::Ident,
name: &Ident,
edges: &'a [Edge],
) -> HashMap<(&'a syn::Ty, &'a syn::Ty), quote::Tokens> {
let mut res = HashMap::default();
) -> HashMap<(&'a Type, &'a Type), proc_macro2::TokenStream> {
let mut res = HashMap::new();
for e in edges {
let ident: quote::Ident = e.ident.as_str().into();
let ident = Ident::new(&e.ident, proc_macro2::Span::call_site());
let from = &e.from;
let to = &e.to;
res.insert(
Expand All @@ -163,11 +166,11 @@ fn make_edge_to_get_corresponding<'a>(
res
}

fn floyd_warshall(edges: &[Edge]) -> HashMap<(&Node, &Node), &Node> {
use std::f64::INFINITY;
let mut v = HashSet::<&Node>::default();
let mut dist = HashMap::<(&Node, &Node), f64>::default();
let mut next = HashMap::default();
fn floyd_warshall(edges: &[Edge]) -> HashMap<(&Type, &Type), &Type> {
let mut v = HashSet::<&Type>::new();
let mut dist = HashMap::<(&Type, &Type), f64>::new();
let mut next = HashMap::new();

for e in edges {
let from = &e.from;
let to = &e.to;
Expand All @@ -178,34 +181,27 @@ fn floyd_warshall(edges: &[Edge]) -> HashMap<(&Node, &Node), &Node> {
next.insert((from, to), to);
next.insert((to, from), from);
}

for &k in &v {
for &i in &v {
let dist_ik = match dist.get(&(i, k)) {
Some(d) => *d,
None => continue,
};
let dist_ik = *dist.get(&(i, k)).unwrap_or(&f64::INFINITY);
for &j in &v {
let dist_kj = match dist.get(&(k, j)) {
Some(d) => *d,
None => continue,
};
let dist_ij = dist.entry((i, j)).or_insert(INFINITY);
let dist_kj = *dist.get(&(k, j)).unwrap_or(&f64::INFINITY);
let dist_ij = dist.entry((i, j)).or_insert(f64::INFINITY);
if *dist_ij > dist_ik + dist_kj {
*dist_ij = dist_ik + dist_kj;
let next_ik = next[&(i, k)];
next.insert((i, j), next_ik);
next.insert((i, j), next[&(i, k)]);
}
}
}
}

next
}

struct Edge {
ident: String,
from: Node,
to: Node,
from: Type,
to: Type,
weight: f64,
}

type Node = syn::Ty;
4 changes: 2 additions & 2 deletions relational_types_procmacro_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ publish = false
autotests = false

[dev-dependencies]
pretty_assertions = "0.6"
pretty_assertions = "1.4"
trybuild = "1"
typed_index_collection = { git = "https://github.com/hove-io/typed_index_collection", tag = "v2"}
typed_index_collection = { git = "https://github.com/hove-io/typed_index_collection", tag = "v2.3.0"}
relational_types = { version = "2", path = "../" }

[[test]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod test_utils;

use relational_types::*;
use test_utils::*;
use typed_index_collection::*;
use typed_index_collection::collection::*;

#[derive(GetCorresponding)]
pub struct Model {
Expand Down
Loading