Skip to content

Commit

Permalink
feat: Ensure packages always have modules at the root (#1589)
Browse files Browse the repository at this point in the history
Followup to #1587.

Ensures that rust `Package`s only contain module-rooted hugrs. Returns
errors at construction times if the condition is not met.
This required fixing `hugr-cli`, as it should be able to load both
packages and arbitrary hugrs.
Added `Package::from_hugr{,s}` methods that try to wrap the hugrs if
possible. We'll need these for tket2.

Packages are not on the spec yet, their description should include this
restriction. See #1388.

This PR does only modify the (unpublished) API introduced in #1587, so
I'm not marking it as breaking.
  • Loading branch information
aborgna-q authored Oct 21, 2024
1 parent ddca29c commit d349eee
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 33 deletions.
62 changes: 54 additions & 8 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity};
use clio::Input;
use derive_more::{Display, Error, From};
use hugr::package::{PackageEncodingError, PackageValidationError};
use hugr::extension::ExtensionRegistry;
use hugr::package::PackageValidationError;
use hugr::Hugr;
use std::{ffi::OsString, path::PathBuf};

pub mod extensions;
Expand Down Expand Up @@ -40,11 +42,8 @@ pub enum CliError {
#[display("Error reading from path: {_0}")]
InputFile(std::io::Error),
/// Error parsing input.
#[display("Error parsing input: {_0}")]
Parse(serde_json::Error),
/// Error loading a package.
#[display("Error parsing package: {_0}")]
Package(PackageEncodingError),
Parse(serde_json::Error),
#[display("Error validating HUGR: {_0}")]
/// Errors produced by the `validate` subcommand.
Validate(PackageValidationError),
Expand Down Expand Up @@ -74,10 +73,57 @@ pub struct HugrArgs {
pub extensions: Vec<PathBuf>,
}

/// A simple enum containing either a package or a single hugr.
///
/// This is required since `Package`s can only contain module-rooted hugrs.
#[derive(Debug, Clone, PartialEq)]
pub enum PackageOrHugr {
/// A package with module-rooted HUGRs and some required extensions.
Package(Package),
/// An arbitrary HUGR.
Hugr(Hugr),
}

impl PackageOrHugr {
/// Returns the list of hugrs in the package.
pub fn into_hugrs(self) -> Vec<Hugr> {
match self {
PackageOrHugr::Package(pkg) => pkg.modules,
PackageOrHugr::Hugr(hugr) => vec![hugr],
}
}

/// Validates the package or hugr.
///
/// Updates the extension registry with any new extensions defined in the package.
pub fn update_validate(
&mut self,
reg: &mut ExtensionRegistry,
) -> Result<(), PackageValidationError> {
match self {
PackageOrHugr::Package(pkg) => pkg.update_validate(reg),
PackageOrHugr::Hugr(hugr) => hugr.update_validate(reg).map_err(Into::into),
}
}
}

impl AsRef<[Hugr]> for PackageOrHugr {
fn as_ref(&self) -> &[Hugr] {
match self {
PackageOrHugr::Package(pkg) => &pkg.modules,
PackageOrHugr::Hugr(hugr) => std::slice::from_ref(hugr),
}
}
}

impl HugrArgs {
/// Read either a package or a single hugr from the input.
pub fn get_package(&mut self) -> Result<Package, CliError> {
let pkg = Package::from_json_reader(&mut self.input)?;
Ok(pkg)
pub fn get_package_or_hugr(&mut self) -> Result<PackageOrHugr, CliError> {
let val: serde_json::Value = serde_json::from_reader(&mut self.input)?;
if let Ok(hugr) = serde_json::from_value::<Hugr>(val.clone()) {
return Ok(PackageOrHugr::Hugr(hugr));
}
let pkg = serde_json::from_value::<Package>(val.clone())?;
Ok(PackageOrHugr::Package(pkg))
}
}
2 changes: 1 addition & 1 deletion hugr-cli/src/mermaid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl MermaidArgs {
let hugrs = if self.validate {
self.hugr_args.validate()?.0
} else {
self.hugr_args.get_package()?.modules
self.hugr_args.get_package_or_hugr()?.into_hugrs()
};

for hugr in hugrs {
Expand Down
6 changes: 3 additions & 3 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl HugrArgs {
/// Returns the validated modules and the extension registry the modules
/// were validated against.
pub fn validate(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
let mut package = self.get_package()?;
let mut package = self.get_package_or_hugr()?;

let mut reg: ExtensionRegistry = if self.no_std {
hugr::extension::PRELUDE_REGISTRY.to_owned()
Expand All @@ -60,8 +60,8 @@ impl HugrArgs {
.map_err(PackageValidationError::Extension)?;
}

package.validate(&mut reg)?;
Ok((package.modules, reg))
package.update_validate(&mut reg)?;
Ok((package.into_hugrs(), reg))
}

/// Test whether a `level` message should be output.
Expand Down
40 changes: 27 additions & 13 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr::builder::DFGBuilder;
use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder};
use hugr::types::Type;
use hugr::{
builder::{Container, Dataflow},
Expand All @@ -31,6 +31,29 @@ fn val_cmd(mut cmd: Command) -> Command {
cmd
}

// path to the fully serialized float extension
const FLOAT_EXT_FILE: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../specification/std_extensions/arithmetic/float/types.json"
);

/// A test package, containing a module-rooted HUGR.
#[fixture]
fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {
let mut module = ModuleBuilder::new();
let df = module
.define_function("test", Signature::new_endo(id_type))
.unwrap();
let [i] = df.input_wires_arr();
df.finish_with_outputs([i]).unwrap();
let hugr = module.hugr().clone(); // unvalidated

let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
Package::new(vec![hugr], vec![float_ext]).unwrap()
}

/// A DFG-rooted HUGR.
#[fixture]
fn test_hugr(#[default(BOOL_T)] id_type: Type) -> Hugr {
let mut df = DFGBuilder::new(Signature::new_endo(id_type)).unwrap();
Expand Down Expand Up @@ -169,12 +192,6 @@ fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) {
.stderr(contains(" Extension 'arithmetic.float.types' not found"));
}

// path to the fully serialized float extension
const FLOAT_EXT_FILE: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../specification/std_extensions/arithmetic/float/types.json"
);

#[rstest]
fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
val_cmd.write_stdin(float_hugr_string);
Expand All @@ -186,15 +203,12 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
val_cmd.assert().success().stderr(contains(VALID_PRINT));
}
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
let package = Package::new(vec![test_hugr], vec![float_ext]);
serde_json::to_string(&package).unwrap()
fn package_string(#[with(FLOAT64_TYPE)] test_package: Package) -> String {
serde_json::to_string(&test_package).unwrap()
}

#[rstest]
fn test_package(package_string: String, mut val_cmd: Command) {
fn test_package_validation(package_string: String, mut val_cmd: Command) {
// package with float extension and hugr that uses floats can validate
val_cmd.write_stdin(package_string);
val_cmd.arg("-");
Expand Down
Loading

0 comments on commit d349eee

Please sign in to comment.