feat: Export/import of JSON metadata #1622

Merged
merged 10 commits into from
Nov 13, 2024
63 changes: 58 additions & 5 deletions hugr-core/src/export.rs
@@ -1,7 +1,7 @@
//! Exporting HUGR graphs to their `hugr-model` representation.
use crate::{
extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc},
- hugr::IdentList,
+ hugr::{IdentList, NodeMetadataMap},
ops::{DataflowBlock, OpName, OpTrait, OpType},
types::{
type_param::{TypeArgVariable, TypeParam},
@@ -21,6 +21,8 @@ type FxIndexSet<T> = IndexSet<T, fxhash::FxBuildHasher>;

pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect";
const TERM_PARAM_TUPLE: &str = "param.tuple";
const TERM_JSON: &str = "prelude.json";
Collaborator

Perhaps these should be public? Downstream clients should not have to memorise the string.

Contributor Author

At some point yes, and there should be some extension declarations that define these term constructors. I've not made the names public yet since they'll probably change as we move closer to stabilising.

Collaborator

Their propensity to change is the precise reason they should be public now.

const META_DESCRIPTION: &str = "docs.description";

/// Export a [`Hugr`] graph to its representation in the model.
pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> {
@@ -392,14 +394,19 @@ impl<'a> Context<'a> {
let inputs = self.make_ports(node, Direction::Incoming, num_inputs);
let outputs = self.make_ports(node, Direction::Outgoing, num_outputs);

let meta = match self.hugr.get_node_metadata(node) {
Some(metadata_map) => self.export_node_metadata(metadata_map),
None => &[],
};

// Replace the placeholder node with the actual node.
*self.module.get_node_mut(node_id).unwrap() = model::Node {
operation,
inputs,
outputs,
params,
regions,
- meta: &[], // TODO: Export metadata
+ meta,
signature,
};

@@ -435,7 +442,7 @@ impl<'a> Context<'a> {
outputs: &[],
params: &[],
regions: &[],
- meta: &[], // TODO: Metadata
+ meta: &[],
signature: None,
}))
}
@@ -452,8 +459,29 @@ impl<'a> Context<'a> {
decl
});

- self.module.get_node_mut(node).unwrap().operation =
- model::Operation::DeclareOperation { decl };
+ let meta = {
+ let description = Some(opdef.description()).filter(|d| !d.is_empty());
+ let meta_len = opdef.iter_misc().len() + description.is_some() as usize;
+ let mut meta = BumpVec::with_capacity_in(meta_len, self.bump);
+
+ if let Some(description) = description {
+ let name = META_DESCRIPTION;
+ let value = self.make_term(model::Term::Str(self.bump.alloc_str(description)));
+ meta.push(model::MetaItem { name, value })
+ }
+
+ for (name, value) in opdef.iter_misc() {
+ let name = self.bump.alloc_str(name);
+ let value = self.export_json(value);
+ meta.push(model::MetaItem { name, value });
+ }
+
+ self.bump.alloc_slice_copy(&meta)
+ };
+
+ let node_data = self.module.get_node_mut(node).unwrap();
+ node_data.operation = model::Operation::DeclareOperation { decl };
+ node_data.meta = meta;

model::GlobalRef::Direct(node)
}
@@ -843,6 +871,31 @@ impl<'a> Context<'a> {

self.make_term(model::Term::ExtSet { extensions, rest })
}

pub fn export_node_metadata(
&mut self,
metadata_map: &NodeMetadataMap,
) -> &'a [model::MetaItem<'a>] {
let mut meta = BumpVec::with_capacity_in(metadata_map.len(), self.bump);

for (name, value) in metadata_map {
let name = self.bump.alloc_str(name);
let value = self.export_json(value);
meta.push(model::MetaItem { name, value });
}

meta.into_bump_slice()
}

pub fn export_json(&mut self, value: &serde_json::Value) -> model::TermId {
let value = serde_json::to_string(value).expect("json values are always serializable");
let value = self.make_term(model::Term::Str(self.bump.alloc_str(&value)));
let value = self.bump.alloc_slice_copy(&[value]);
self.make_term(model::Term::ApplyFull {
global: model::GlobalRef::Named(TERM_JSON),
args: value,
})
}
}

#[cfg(test)]
5 changes: 5 additions & 0 deletions hugr-core/src/extension/op_def.rs
@@ -435,6 +435,11 @@ impl OpDef {
self.misc.insert(k.to_string(), v)
}

/// Iterate over all miscellaneous data in the [OpDef].
pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
self.misc.iter().map(|(k, v)| (k.as_str(), v))
}

/// Set the constant folding function for this Op, which can evaluate it
/// given constant inputs.
pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
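
The new `iter_misc` returns `impl ExactSizeIterator`, which is what lets the exporter call `.len()` to size its `BumpVec` before pushing any items. A minimal standard-library sketch of that idea follows. The `Def` struct, its fields, and `meta_len` are hypothetical stand-ins for illustration (the real `misc` map holds `serde_json::Value`, replaced here by `String` to avoid external crates).

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for `OpDef`, holding only the fields needed here.
struct Def {
    description: String,
    // Assumption: plain strings instead of `serde_json::Value`.
    misc: BTreeMap<String, String>,
}

impl Def {
    /// Mapping over a map iterator preserves `ExactSizeIterator`,
    /// so callers can ask for `.len()` without consuming the iterator.
    fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &str)> {
        self.misc.iter().map(|(k, v)| (k.as_str(), v.as_str()))
    }

    /// Number of metadata items to reserve space for: one per misc entry,
    /// plus one if the description is non-empty.
    fn meta_len(&self) -> usize {
        let description = Some(self.description.as_str()).filter(|d| !d.is_empty());
        self.iter_misc().len() + description.is_some() as usize
    }
}

fn main() {
    let mut misc = BTreeMap::new();
    misc.insert("key".to_string(), "value".to_string());

    let documented = Def { description: "An op.".to_string(), misc: misc.clone() };
    let undocumented = Def { description: String::new(), misc };

    println!("{} {}", documented.meta_len(), undocumented.meta_len());
}
```

Returning a plain `impl Iterator` would still compile for the `for` loop in the exporter, but the up-front `.len()` call would not be available.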
39 changes: 39 additions & 0 deletions hugr-core/src/import.rs
@@ -26,6 +26,8 @@ use itertools::Either;
use smol_str::{SmolStr, ToSmolStr};
use thiserror::Error;

const TERM_JSON: &str = "prelude.json";

type FxIndexMap<K, V> = IndexMap<K, V, fxhash::FxBuildHasher>;

/// Error during import.
@@ -184,6 +186,14 @@ impl<'a> Context<'a> {
let node_data = self.get_node(node_id)?;
self.record_links(node, Direction::Incoming, node_data.inputs);
self.record_links(node, Direction::Outgoing, node_data.outputs);

for meta_item in node_data.meta {
// TODO: For now we expect all metadata to be JSON since this is how
// it is handled in `hugr-core`.
let value = self.import_json_value(meta_item.value)?;
self.hugr.set_metadata(node, meta_item.name, value);
}

Ok(node)
}

@@ -1200,6 +1210,35 @@ impl<'a> Context<'a> {
}
}
}

fn import_json_value(
&mut self,
term_id: model::TermId,
) -> Result<serde_json::Value, ImportError> {
let (global, args) = match self.get_term(term_id)? {
model::Term::Apply { global, args } | model::Term::ApplyFull { global, args } => {
(global, args)
}
_ => return Err(model::ModelError::TypeError(term_id).into()),
};

if global != &GlobalRef::Named(TERM_JSON) {
return Err(model::ModelError::TypeError(term_id).into());
}

let [json_arg] = args else {
return Err(model::ModelError::TypeError(term_id).into());
};

let model::Term::Str(json_str) = self.get_term(*json_arg)? else {
return Err(model::ModelError::TypeError(term_id).into());
};

let json_value =
serde_json::from_str(json_str).map_err(|_| model::ModelError::TypeError(term_id))?;

Ok(json_value)
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
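
`import_json_value` above peels the term apart in stages: it must be an application, the applied constructor must be `prelude.json`, there must be exactly one argument, and that argument must be a string. The sketch below shows the same shape check with `let ... else` and a slice pattern, using only the standard library. The `Term` enum and `json_payload` are simplified, hypothetical stand-ins rather than the `hugr-model` types, and the final `serde_json::from_str` step is left out.

```rust
/// Hypothetical, simplified stand-in for the model's term type.
#[derive(Debug, PartialEq)]
enum Term<'a> {
    Str(&'a str),
    Apply { global: &'a str, args: &'a [Term<'a>] },
    Wildcard,
}

const TERM_JSON: &str = "prelude.json";

/// Returns the raw JSON text wrapped by `(@ prelude.json "...")`,
/// or `None` if the term has any other shape.
fn json_payload<'a>(term: &Term<'a>) -> Option<&'a str> {
    // Stage 1: the term must be an application.
    let Term::Apply { global, args } = term else {
        return None;
    };
    // Stage 2: the applied constructor must be `prelude.json`.
    if *global != TERM_JSON {
        return None;
    }
    // Stage 3: the slice pattern matches only if there is exactly one
    // argument and that argument is a string.
    let [Term::Str(json)] = *args else {
        return None;
    };
    Some(*json)
}

fn main() {
    let args = [Term::Str("\"Callee\"")];
    let term = Term::Apply { global: TERM_JSON, args: &args };
    println!("{:?}", json_payload(&term));
    println!("{:?}", json_payload(&Term::Wildcard));
}
```

The real importer reports each failed stage as `ModelError::TypeError(term_id)`; `Option` is used here only to keep the sketch short.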
11 changes: 9 additions & 2 deletions hugr-core/tests/snapshots/model__roundtrip_call.snap
@@ -1,19 +1,26 @@
---
source: hugr-core/tests/model.rs
- expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))"
+ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call.edn\"))"
Collaborator

This is a shame. Is it possible to arrange things so that this cross-crate reference is removed?

Contributor Author (@zrho, Nov 11, 2024)

We could move the tests that are in hugr-model to hugr-core, but arguably they don't really belong there since they only test stuff within hugr-model. We can't move the tests in hugr-core to hugr-model since they use stuff from hugr-core. We could copy the fixtures.

Collaborator

I think that means yes, we could duplicate them. I prefer duplication, up to you.

Contributor Author

I'll merge this for now but create an issue so we can figure out if there is some elegant approach to this.

---
(hugr 0)

(declare-func example.callee
(forall ?0 ext-set)
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
- (ext arithmetic.int . ?0))
+ (ext arithmetic.int . ?0)
+ (meta doc.description (@ prelude.json "\"This is a function declaration.\""))
+ (meta doc.title (@ prelude.json "\"Callee\"")))

(define-func example.caller
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int)
(meta doc.description
(@
prelude.json
"\"This defines a function that calls the function which we declared earlier.\""))
(meta doc.title (@ prelude.json "\"Caller\""))
(dfg
[%0] [%1]
(signature
8 changes: 6 additions & 2 deletions hugr-model/src/v0/text/hugr.pest
@@ -5,8 +5,12 @@ ext_name = @{ identifier ~ ("." ~ identifier)* }
symbol = @{ identifier ~ ("." ~ identifier)+ }
tag = @{ (ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) | "0" }

- string = @{ "\"" ~ (!("\"") ~ ANY)* ~ "\"" }
- list_tail = { "." }
+ string = { "\"" ~ (string_raw | string_escape | string_unicode)* ~ "\"" }
+ string_raw = @{ (!("\\" | "\"") ~ ANY)+ }
+ string_escape = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") }
+ string_unicode = @{ "\\u" ~ "{" ~ ASCII_HEX_DIGIT+ ~ "}" }
+
+ list_tail = { "." }

module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI }
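
The revised `string` rule accepts raw characters plus the escapes `\"`, `\\`, `\n`, `\r`, `\t` and `\u{...}`. The printing side is not part of the hunks shown here, so the following is only a sketch of what an escaper targeting this grammar might look like. The helper name `escape_string` is hypothetical, and spelling other control characters as `\u{...}` is an assumption.

```rust
/// Hypothetical helper: quotes `raw` using only the escapes the grammar accepts.
fn escape_string(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len() + 2);
    out.push('"');
    for c in raw.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Assumption: remaining control characters use the `\u{...}` form.
            c if c.is_control() => out.push_str(&format!("\\u{{{:x}}}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}

fn main() {
    // JSON text for the string value `Callee`, i.e. the characters "Callee"
    // including the quotes, as a JSON serializer would produce them.
    let json_text = "\"Callee\"";
    println!("(@ prelude.json {})", escape_string(json_text));
}
```

This illustrates the two layers visible in the snapshot: the JSON text already contains its own quotes, and the text format then escapes those quotes a second time, which is why the metadata appears as `"\"Callee\""`.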

75 changes: 62 additions & 13 deletions hugr-model/src/v0/text/parse.rs
@@ -1,4 +1,4 @@
- use bumpalo::Bump;
+ use bumpalo::{collections::String as BumpString, Bump};
use pest::{
iterators::{Pair, Pairs},
Parser, RuleType,
@@ -60,7 +60,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> {
- debug_assert!(matches!(pair.as_rule(), Rule::module));
+ debug_assert_eq!(pair.as_rule(), Rule::module);
let mut inner = pair.into_inner();
let meta = self.parse_meta(&mut inner)?;

@@ -81,7 +81,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_term(&mut self, pair: Pair<'a, Rule>) -> ParseResult<TermId> {
- debug_assert!(matches!(pair.as_rule(), Rule::term));
+ debug_assert_eq!(pair.as_rule(), Rule::term);
let pair = pair.into_inner().next().unwrap();
let rule = pair.as_rule();
let mut inner = pair.into_inner();
@@ -160,9 +160,7 @@ impl<'a> ParseContext<'a> {
}

Rule::term_str => {
- // TODO: Escaping?
- let value = inner.next().unwrap().as_str();
- let value = &value[1..value.len() - 1];
+ let value = self.parse_string(inner.next().unwrap())?;
Term::Str(value)
}

@@ -218,7 +216,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_node(&mut self, pair: Pair<'a, Rule>) -> ParseResult<NodeId> {
- debug_assert!(matches!(pair.as_rule(), Rule::node));
+ debug_assert_eq!(pair.as_rule(), Rule::node);
let pair = pair.into_inner().next().unwrap();
let rule = pair.as_rule();

@@ -503,7 +501,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_region(&mut self, pair: Pair<'a, Rule>) -> ParseResult<RegionId> {
- debug_assert!(matches!(pair.as_rule(), Rule::region));
+ debug_assert_eq!(pair.as_rule(), Rule::region);
let pair = pair.into_inner().next().unwrap();
let rule = pair.as_rule();
let mut inner = pair.into_inner();
@@ -541,7 +539,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_func_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a FuncDecl<'a>> {
- debug_assert!(matches!(pair.as_rule(), Rule::func_header));
+ debug_assert_eq!(pair.as_rule(), Rule::func_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
@@ -566,7 +564,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_alias_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a AliasDecl<'a>> {
- debug_assert!(matches!(pair.as_rule(), Rule::alias_header));
+ debug_assert_eq!(pair.as_rule(), Rule::alias_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
@@ -581,7 +579,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_ctr_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a ConstructorDecl<'a>> {
- debug_assert!(matches!(pair.as_rule(), Rule::ctr_header));
+ debug_assert_eq!(pair.as_rule(), Rule::ctr_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
@@ -596,7 +594,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_op_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a OperationDecl<'a>> {
- debug_assert!(matches!(pair.as_rule(), Rule::operation_header));
+ debug_assert_eq!(pair.as_rule(), Rule::operation_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
@@ -670,7 +668,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult<LinkRef<'a>> {
- debug_assert!(matches!(pair.as_rule(), Rule::port));
+ debug_assert_eq!(pair.as_rule(), Rule::port);
let mut inner = pair.into_inner();
let link = LinkRef::Named(&inner.next().unwrap().as_str()[1..]);
Ok(link)
@@ -697,6 +695,47 @@ impl<'a> ParseContext<'a> {
unreachable!("expected a symbol");
}
}

fn parse_string(&self, token: Pair<'a, Rule>) -> ParseResult<&'a str> {
debug_assert_eq!(token.as_rule(), Rule::string);

// Any escape sequence is longer than the character it represents.
// Therefore the length of this token (minus 2 for the quotes on either
// side) is an upper bound for the length of the string.
let capacity = token.as_str().len() - 2;
let mut string = BumpString::with_capacity_in(capacity, self.bump);
let tokens = token.into_inner();

for token in tokens {
match token.as_rule() {
Rule::string_raw => string.push_str(token.as_str()),
Rule::string_escape => match token.as_str().chars().nth(1).unwrap() {
'"' => string.push('"'),
'\\' => string.push('\\'),
'n' => string.push('\n'),
'r' => string.push('\r'),
't' => string.push('\t'),
_ => unreachable!(),
},
Collaborator

What are the criteria for allowed escapes? https://www.ietf.org/rfc/rfc4627.txt says any character can be escaped (i.e. \x should be interpreted as x).

Contributor Author (@zrho, Nov 11, 2024)

https://www.ietf.org/rfc/rfc4627.txt describes how JSON does string escaping, which is not the only way. Note that we use the string to store JSON but the string is not itself a JSON string. These escape codes are closer to how Rust does escaping (which has a nicer Unicode escape sequence). More precisely, it's a subset of what Rust does (not deliberately). It's also a subset of what the WebAssembly text format does (which is very similar to Rust strings but not entirely the same). I don't care that much about which escaping system we use, so this was picked rather arbitrarily; we could adjust it to taste if you wish.

Collaborator

This is fine. I did not understand that the string is not itself a JSON string.

Rule::string_unicode => {
Collaborator

I think this unreachable is panicking on bad user input? We should throw a ParseError.

Contributor Author

The way that the pest library works, at the point we're doing the match we know that the input was valid for the string_escape rule, which is defined as follows:

string_escape  = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") }

So that case is indeed unreachable. This is similar to the other uses of unreachable in the parser and an idiomatic use of pest.

Collaborator

Thanks for the explanation. I agree that the case is unreachable. Nevertheless it is better to throw an error here, because there are two sources of truth for allowed escapes, which may diverge over time. If you feel strongly, feel free to leave this as is.

Contributor Author (@zrho, Nov 11, 2024)

When these diverge we have a bug. A parse error would indicate that there's something wrong with the input, when in truth the error is in the code. So I think the panic is appropriate.

Collaborator

A panic aborts the program; it's very user-hostile. We should aim to minimize the impact of bugs, e.g. throw an error here.

let token_str = token.as_str();
debug_assert_eq!(&token_str[0..3], r"\u{");
debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
let code_str = &token_str[3..token_str.len() - 1];
let code = u32::from_str_radix(code_str, 16).map_err(|_| {
ParseError::custom("invalid unicode escape sequence", token.as_span())
})?;
let char = std::char::from_u32(code).ok_or_else(|| {
ParseError::custom("invalid unicode code point", token.as_span())
})?;
string.push(char);
}
_ => unreachable!(),
}
}

Ok(string.into_bump_str())
}
}

/// Draw from a pest pair iterator only the pairs that match a given rule.
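
The review thread on `parse_string` weighs `unreachable!()` against returning an error for an escape the grammar should never produce. The diff shown here keeps `unreachable!()`. Purely as an illustration of the reviewer's alternative, the sketch below decodes the same escape set by hand, without pest, and reports an unknown escape as an error value. `UnescapeError` and `unescape` are hypothetical names, not part of `hugr-model`.

```rust
/// Hypothetical error type; the real parser reports a `ParseError` with a span.
#[derive(Debug, PartialEq)]
enum UnescapeError {
    UnknownEscape(char),
    Malformed,
    InvalidCodePoint(String),
}

/// Decodes the body of a string literal (without the surrounding quotes).
fn unescape(body: &str) -> Result<String, UnescapeError> {
    // Every escape sequence is longer than the character it encodes,
    // so the input length is an upper bound on the output length.
    let mut out = String::with_capacity(body.len());
    let mut chars = body.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next().ok_or(UnescapeError::Malformed)? {
            '"' => out.push('"'),
            '\\' => out.push('\\'),
            'n' => out.push('\n'),
            'r' => out.push('\r'),
            't' => out.push('\t'),
            'u' => {
                // Expect `{`, then hex digits, then `}`.
                if chars.next() != Some('{') {
                    return Err(UnescapeError::Malformed);
                }
                let mut hex = String::new();
                loop {
                    match chars.next() {
                        Some('}') => break,
                        Some(h) if h.is_ascii_hexdigit() => hex.push(h),
                        _ => return Err(UnescapeError::Malformed),
                    }
                }
                let decoded = u32::from_str_radix(&hex, 16).ok().and_then(char::from_u32);
                match decoded {
                    Some(ch) => out.push(ch),
                    None => return Err(UnescapeError::InvalidCodePoint(hex)),
                }
            }
            // An error value instead of `unreachable!()`.
            other => return Err(UnescapeError::UnknownEscape(other)),
        }
    }
    Ok(out)
}

fn main() {
    println!("{:?}", unescape(r#"\"Callee\""#));
    println!("{:?}", unescape(r"\x"));
}
```

In the pest-based parser the grammar has already rejected unknown escapes before this match runs, which is the author's argument for keeping the panic; the error branch only matters if the grammar and the match arms drift apart.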
@@ -750,6 +789,16 @@ impl ParseError {
InputLocation::Span((offset, _)) => offset,
}
}

fn custom(message: &str, span: pest::Span) -> Self {
let error = pest::error::Error::new_from_span(
Collaborator

No coverage for this function.

Contributor Author

I don't see what to test about this function. It constructs an error.

pest::error::ErrorVariant::CustomError {
message: message.to_string(),
},
span,
);
ParseError(Box::new(error))
}
}

// NOTE: `ParseError` does not implement `From<pest::error::Error<Rule>>` so that