Skip to content

Commit

Permalink
feat: Add Package definition on hugr-core
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Oct 16, 2024
1 parent 6cb0dcd commit 29c4c9b
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 91 deletions.
3 changes: 2 additions & 1 deletion hugr-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ categories = ["compilers"]
[dependencies]
clap = { workspace = true, features = ["derive"] }
clap-verbosity-flag.workspace = true
hugr-core = { path = "../hugr-core", version = "0.13.1" }
derive_more = { workspace = true, features = ["display", "error", "from"] }
hugr = { path = "../hugr", version = "0.13.1" }
serde_json.workspace = true
serde.workspace = true
thiserror.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion hugr-cli/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Dump standard extensions in serialized form.
use clap::Parser;
use hugr_core::extension::ExtensionRegistry;
use hugr::extension::ExtensionRegistry;
use std::{io::Write, path::PathBuf};

/// Dump the standard extensions.
Expand Down
54 changes: 17 additions & 37 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity};
use clio::Input;
use hugr_core::{Extension, Hugr};
use derive_more::{Display, Error, From};
use hugr::package::{PackageLoadError, PackageValidationError};
use std::{ffi::OsString, path::PathBuf};
use thiserror::Error;

pub mod extensions;
pub mod mermaid;
pub mod validate;

// TODO: Deprecated re-export. Remove on a breaking release.
pub use hugr::package::Package;

/// CLI arguments.
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
Expand All @@ -30,18 +33,21 @@ pub enum CliArgs {
}

/// Error type for the CLI.
#[derive(Debug, Error)]
#[error(transparent)]
#[derive(Debug, Display, Error, From)]
#[non_exhaustive]
pub enum CliError {
/// Error reading input.
#[error("Error reading from path: {0}")]
InputFile(#[from] std::io::Error),
#[display("Error reading from path: {_0}")]
InputFile(std::io::Error),
/// Error parsing input.
#[error("Error parsing input: {0}")]
Parse(#[from] serde_json::Error),
#[display("Error parsing input: {_0}")]
Parse(serde_json::Error),
/// Error loading a package.
#[display("Error parsing package: {_0}")]
Package(PackageLoadError),
#[display("Error validating HUGR: {_0}")]
/// Errors produced by the `validate` subcommand.
Validate(#[from] validate::ValError),
Validate(PackageValidationError),
}

/// Validate and visualise a HUGR file.
Expand All @@ -68,36 +74,10 @@ pub struct HugrArgs {
pub extensions: Vec<PathBuf>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Package of module HUGRs and extensions.
/// The HUGRs are validated against the extensions.
pub struct Package {
/// Module HUGRs included in the package.
pub modules: Vec<Hugr>,
/// Extensions to validate against.
pub extensions: Vec<Extension>,
}

impl Package {
/// Create a new package.
pub fn new(modules: Vec<Hugr>, extensions: Vec<Extension>) -> Self {
Self {
modules,
extensions,
}
}
}

impl HugrArgs {
/// Read either a package or a single hugr from the input.
pub fn get_package(&mut self) -> Result<Package, CliError> {
let val: serde_json::Value = serde_json::from_reader(&mut self.input)?;
// read either a package or a single hugr
if let Ok(p) = serde_json::from_value::<Package>(val.clone()) {
Ok(p)
} else {
let hugr: Hugr = serde_json::from_value(val)?;
Ok(Package::new(vec![hugr], vec![]))
}
let pkg = Package::from_json_reader(&mut self.input)?;
Ok(pkg)
}
}
2 changes: 1 addition & 1 deletion hugr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use clap_verbosity_flag::Level;
fn main() {
match CliArgs::parse() {
CliArgs::Validate(args) => run_validate(args),
CliArgs::GenExtensions(args) => args.run_dump(&hugr_core::std_extensions::STD_REG),
CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG),
CliArgs::Mermaid(mut args) => args.run_print().unwrap(),
CliArgs::External(_) => {
// TODO: Implement support for external commands.
Expand Down
2 changes: 1 addition & 1 deletion hugr-cli/src/mermaid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::io::Write;

use clap::Parser;
use clio::Output;
use hugr_core::HugrView;
use hugr::HugrView;

/// Dump the standard extensions.
#[derive(Parser, Debug)]
Expand Down
51 changes: 10 additions & 41 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

use clap::Parser;
use clap_verbosity_flag::Level;
use hugr_core::{extension::ExtensionRegistry, Extension, Hugr};
use thiserror::Error;
use hugr::package::PackageValidationError;
use hugr::{extension::ExtensionRegistry, Extension, Hugr};

use crate::{CliError, HugrArgs, Package};
use crate::{CliError, HugrArgs};

/// Validate and visualise a HUGR file.
#[derive(Parser, Debug)]
Expand All @@ -19,18 +19,6 @@ pub struct ValArgs {
pub hugr_args: HugrArgs,
}

/// Error type for the CLI.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ValError {
/// Error validating HUGR.
#[error("Error validating HUGR: {0}")]
Validate(#[from] hugr_core::hugr::ValidationError),
/// Error registering extension.
#[error("Error registering extension: {0}")]
ExtReg(#[from] hugr_core::extension::ExtensionRegistryError),
}

/// String to print when validation is successful.
pub const VALID_PRINT: &str = "HUGR valid!";

Expand All @@ -50,49 +38,30 @@ impl ValArgs {
}
}

impl Package {
/// Validate the package against an extension registry.
///
/// `reg` is updated with any new extensions.
///
/// Returns the validated modules.
pub fn validate(mut self, reg: &mut ExtensionRegistry) -> Result<Vec<Hugr>, ValError> {
// register packed extensions
for ext in self.extensions {
reg.register_updated(ext)?;
}

for hugr in self.modules.iter_mut() {
hugr.update_validate(reg)?;
}

Ok(self.modules)
}
}

impl HugrArgs {
/// Load the package and validate against an extension registry.
///
/// Returns the validated modules and the extension registry the modules
/// were validated against.
pub fn validate(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
let package = self.get_package()?;
let mut package = self.get_package()?;

let mut reg: ExtensionRegistry = if self.no_std {
hugr_core::extension::PRELUDE_REGISTRY.to_owned()
hugr::extension::PRELUDE_REGISTRY.to_owned()
} else {
hugr_core::std_extensions::STD_REG.to_owned()
hugr::std_extensions::STD_REG.to_owned()
};

// register external extensions
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext).map_err(ValError::ExtReg)?;
reg.register_updated(ext)
.map_err(PackageValidationError::Extension)?;
}

let modules = package.validate(&mut reg)?;
Ok((modules, reg))
package.validate(&mut reg)?;
Ok((package.modules, reg))
}

/// Test whether a `level` message should be output.
Expand Down
14 changes: 7 additions & 7 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr_cli::{validate::VALID_PRINT, Package};
use hugr_core::builder::DFGBuilder;
use hugr_core::types::Type;
use hugr_core::{
use hugr::builder::DFGBuilder;
use hugr::types::Type;
use hugr::{
builder::{Container, Dataflow},
extension::prelude::{BOOL_T, QB_T},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
types::Signature,
Hugr,
};
use hugr_cli::{validate::VALID_PRINT, Package};
use predicates::{prelude::*, str::contains};
use rstest::{fixture, rstest};

Expand Down Expand Up @@ -128,7 +128,7 @@ fn test_bad_json(mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error parsing input"));
.stderr(contains("Error parsing package"));
}

#[rstest]
Expand All @@ -139,7 +139,7 @@ fn test_bad_json_silent(mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error parsing input").not());
.stderr(contains("Error parsing package").not());
}

#[rstest]
Expand Down Expand Up @@ -188,7 +188,7 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr_core::Extension = serde_json::from_reader(rdr).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
let package = Package::new(vec![test_hugr], vec![float_ext]);
serde_json::to_string(&package).unwrap()
}
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ serde = { workspace = true, features = ["derive", "rc"] }
serde_yaml = { workspace = true, optional = true }
typetag = { workspace = true }
smol_str = { workspace = true, features = ["serde"] }
derive_more = { workspace = true, features = ["display", "from"] }
derive_more = { workspace = true, features = ["display", "error", "from"] }
itertools = { workspace = true }
html-escape = { workspace = true }
bitvec = { workspace = true, features = ["serde"] }
Expand Down
27 changes: 27 additions & 0 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ impl ExtensionRegistry {
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
///
/// Avoids cloning the extension unless required. For a reference version see
/// [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(
&mut self,
extension: Extension,
Expand All @@ -107,6 +110,30 @@ impl ExtensionRegistry {
}
}

/// Registers a new extension to the registry, keeping most up to date if
/// extension exists.
///
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept. Returns a reference
/// to the registered extension if successful.
///
/// Clones the extension if required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(
&mut self,
extension: &Extension,
) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension.clone();
}
Ok(prev.into_mut())
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension.clone())),
}
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod hugr;
pub mod import;
pub mod macros;
pub mod ops;
pub mod package;
pub mod std_extensions;
pub mod types;
pub mod utils;
Expand Down
Loading

0 comments on commit 29c4c9b

Please sign in to comment.