Skip to content

Commit

Permalink
require type checking on CustomConst (#325)
Browse files Browse the repository at this point in the history
* require type checking on `CustomConst`

this now makes it impossible for use of static values of types from resources
that do not provide binary `CustomConst`.

Is this what we want?

* address review comments

string custom type check error
  • Loading branch information
ss2165 authored Aug 2, 2023
1 parent 13c0094 commit b16eb21
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 21 deletions.
60 changes: 54 additions & 6 deletions src/extensions/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::collections::HashMap;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

use crate::ops::constant::typecheck::CustomCheckFail;
use crate::ops::constant::CustomConst;
use crate::resource::{OpDef, ResourceSet, TypeDef};
use crate::types::type_param::TypeArg;
Expand Down Expand Up @@ -93,6 +94,15 @@ pub enum Constant {
Quaternion(cgmath::Quaternion<f64>),
}

impl Constant {
fn rotation_type(&self) -> Type {
match self {
Constant::Angle(_) => Type::Angle,
Constant::Quaternion(_) => Type::Quaternion,
}
}
}

#[typetag::serde]
impl CustomConst for Constant {
fn name(&self) -> SmolStr {
Expand All @@ -103,12 +113,24 @@ impl CustomConst for Constant {
.into()
}

fn custom_type(&self) -> CustomType {
let t: Type = match self {
Constant::Angle(_) => Type::Angle,
Constant::Quaternion(_) => Type::Quaternion,
};
t.custom_type()
fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFail> {
let self_typ = self.rotation_type();

if &self_typ.custom_type() == typ {
Ok(())
} else {
Err(CustomCheckFail::new(
"Rotation constant type mismatch.".into(),
))
}
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
if let Some(other) = other.as_any().downcast_ref::<Constant>() {
self == other
} else {
false
}
}
}

Expand Down Expand Up @@ -320,4 +342,30 @@ mod test {
))
);
}

#[test]
fn test_type_check() {
let resource = resource();

let custom_type = resource
.types()
.get("angle")
.unwrap()
.instantiate_concrete([])
.unwrap();

let custom_value = Constant::Angle(AngleValue::F64(0.0));

// correct type
custom_value.check_custom_type(&custom_type).unwrap();

let wrong_custom_type = resource
.types()
.get("quat")
.unwrap()
.instantiate_concrete([])
.unwrap();
let res = custom_value.check_custom_type(&wrong_custom_type);
assert!(res.is_err());
}
}
13 changes: 6 additions & 7 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use self::typecheck::{typecheck_const, ConstTypeError};
use self::typecheck::{typecheck_const, ConstTypeError, CustomCheckFail};

use super::OpTag;
use super::{OpName, OpTrait, StaticTag};
Expand Down Expand Up @@ -134,7 +134,7 @@ pub enum ConstValue {
Tuple(Vec<ConstValue>),
/// An opaque constant value, with cached type
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Opaque((CustomType, Box<dyn CustomConst>)),
Opaque((Box<dyn CustomConst>,)),
}

impl PartialEq for dyn CustomConst {
Expand All @@ -159,7 +159,7 @@ impl ConstValue {
match self {
Self::Int(value) => format!("const:int{value}"),
Self::F64(f) => format!("const:float:{f}"),
Self::Opaque((_, v)) => format!("const:{}", v.name()),
Self::Opaque((v,)) => format!("const:{}", v.name()),
Self::Sum(tag, val) => {
format!("const:sum:{{tag:{tag}, val:{}}}", val.name())
}
Expand Down Expand Up @@ -200,7 +200,7 @@ impl ConstValue {

impl<T: CustomConst> From<T> for ConstValue {
fn from(v: T) -> Self {
Self::Opaque((v.custom_type(), Box::new(v)))
Self::Opaque((Box::new(v),))
}
}

Expand All @@ -215,9 +215,8 @@ pub trait CustomConst:
/// An identifier for the constant.
fn name(&self) -> SmolStr;

/// Returns the type of the constant.
// TODO it would be good to ensure that this is a *classic* CustomType not a linear one!
fn custom_type(&self) -> CustomType;
/// Check the value is a valid instance of the provided type.
fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFail>;

/// Compare two constants for equality, using downcasting and comparing the definitions.
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
Expand Down
25 changes: 17 additions & 8 deletions src/ops/constant/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ pub enum ConstIntError {
IntWidthInvalid(HugrIntWidthStore),
}

/// Struct for custom type check fails.
#[derive(Clone, Debug, PartialEq, Error)]
#[error("Error when checking custom type.")]
pub struct CustomCheckFail(String);

impl CustomCheckFail {
/// Creates a new [`CustomCheckFail`].
pub fn new(message: String) -> Self {
Self(message)
}
}

/// Errors that arise from typechecking constants
#[derive(Clone, Debug, PartialEq, Error)]
pub enum ConstTypeError {
Expand All @@ -50,12 +62,12 @@ pub enum ConstTypeError {
/// Tag for a sum value exceeded the number of variants
#[error("Tag of Sum value is invalid")]
InvalidSumTag,
/// A mismatch between the type expected and the actual type of the constant
#[error("Type mismatch for const - expected {0}, found {1:?}")]
TypeMismatch(ClassicType, ClassicType),
/// A mismatch between the type expected and the value.
#[error("Value {1:?} does not match expected type {0}")]
ValueCheckFail(ClassicType, ConstValue),
/// Error when checking a custom value.
#[error("Custom value type check error: {0:?}")]
CustomCheckFail(#[from] CustomCheckFail),
}

lazy_static! {
Expand Down Expand Up @@ -140,11 +152,8 @@ pub(super) fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(),
}
}
(Container::Sum(_), _) => Err(ConstTypeError::ValueCheckFail(ty.clone(), tm.clone())),
(Container::Opaque(ty), ConstValue::Opaque((ty_act, _val))) => {
if ty_act != ty {
return Err(ConstTypeError::ValueCheckFail(typ.clone(), val.clone()));
}
Ok(())
(Container::Opaque(ty), ConstValue::Opaque((val,))) => {
val.check_custom_type(ty).map_err(ConstTypeError::from)
}
_ => Err(ConstTypeError::Unimplemented(ty.clone())),
},
Expand Down

0 comments on commit b16eb21

Please sign in to comment.