diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..fe84e1f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +# disable diffs on go.mod and go.sum so GitHub won't render the text changes +go.mod -diff +go.sum -diff diff --git a/eventstore_test.go b/eventstore_test.go index 3442140..c617f43 100644 --- a/eventstore_test.go +++ b/eventstore_test.go @@ -7,6 +7,7 @@ import ( "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" + "github.com/synadia-labs/rita/codec" "github.com/synadia-labs/rita/id" "github.com/synadia-labs/rita/testutil" "github.com/synadia-labs/rita/types" @@ -233,14 +234,14 @@ func TestEventStoreWithRegistry(t *testing.T) { nc, _ := nats.Connect(srv.ClientURL()) - tr, err := types.NewRegistry(map[string]*types.Type{ - "order-placed": { - Init: func() any { return &OrderPlaced{} }, + tr, err := types.NewInMemRegistry(map[string]types.Type{ + "order-placed": types.InMemType{ + InitFn: func() any { return &OrderPlaced{} }, }, - "order-shipped": { - Init: func() any { return &OrderShipped{} }, + "order-shipped": &types.InMemType{ + InitFn: func() any { return &OrderShipped{} }, }, - }) + }, codec.Default) is.NoErr(err) r, err := New(t.Context(), nc, TypeRegistry(tr)) diff --git a/go.mod b/go.mod index 7f5a140..4256376 100644 --- a/go.mod +++ b/go.mod @@ -1,25 +1,31 @@ module github.com/synadia-labs/rita -go 1.24 +go 1.25.0 require ( github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 - github.com/nats-io/nats-server/v2 v2.11.6 - github.com/nats-io/nats.go v1.44.0 + github.com/nats-io/nats-server/v2 v2.11.9 + github.com/nats-io/nats.go v1.45.0 github.com/nats-io/nuid v1.0.1 + github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 + github.com/synadia-io/schema-registry-sdk/go/schemaregistry v0.0.0-20250915095715-75c30f0c9054 github.com/vmihailenco/msgpack/v5 v5.4.1 - google.golang.org/protobuf v1.36.6 + google.golang.org/protobuf v1.36.9 ) require ( - github.com/google/go-tpm v0.9.5 // indirect + github.com/antithesishq/antithesis-sdk-go v0.5.0 // indirect + github.com/google/go-tpm v0.9.6 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/minio/highwayhash v1.0.3 // indirect - github.com/nats-io/jwt/v2 v2.7.4 // indirect + github.com/nats-io/jwt/v2 v2.8.0 // indirect github.com/nats-io/nkeys v0.4.11 // indirect + github.com/stretchr/testify v1.10.0 // indirect + github.com/synadia-labs/schema-registry v0.0.0-20250912022759-6e8f74fd7d2e // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - golang.org/x/crypto v0.40.0 // indirect - golang.org/x/sys v0.34.0 // indirect - golang.org/x/time v0.12.0 // indirect + golang.org/x/crypto v0.42.0 // indirect + golang.org/x/sys v0.36.0 // indirect + golang.org/x/text v0.29.0 // indirect + golang.org/x/time v0.13.0 // indirect ) diff --git a/go.sum b/go.sum index 536dcff..7989684 100644 --- a/go.sum +++ b/go.sum @@ -1,43 +1,53 @@ -github.com/antithesishq/antithesis-sdk-go v0.4.3-default-no-op h1:+OSa/t11TFhqfrX0EOSqQBDJ0YlpmK0rDSiB19dg9M0= -github.com/antithesishq/antithesis-sdk-go v0.4.3-default-no-op/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/antithesishq/antithesis-sdk-go v0.5.0 h1:cudCFF83pDDANcXFzkQPUHHedfnnIbUO3JMr9fqwFJs= +github.com/antithesishq/antithesis-sdk-go v0.5.0/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= +github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU= -github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/go-tpm v0.9.6 h1:Ku42PT4LmjDu1H5C5ISWLlpI1mj+Zq7sPGKoRw2XROA= +github.com/google/go-tpm v0.9.6/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/minio/highwayhash v1.0.3 h1:kbnuUMoHYyVl7szWjSxJnxw11k2U709jqFPPmIUyD6Q= github.com/minio/highwayhash v1.0.3/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ= -github.com/nats-io/jwt/v2 v2.7.4 h1:jXFuDDxs/GQjGDZGhNgH4tXzSUK6WQi2rsj4xmsNOtI= -github.com/nats-io/jwt/v2 v2.7.4/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA= -github.com/nats-io/nats-server/v2 v2.11.6 h1:4VXRjbTUFKEB+7UoaKL3F5Y83xC7MxPoIONOnGgpkHw= -github.com/nats-io/nats-server/v2 v2.11.6/go.mod h1:2xoztlcb4lDL5Blh1/BiukkKELXvKQ5Vy29FPVRBUYs= -github.com/nats-io/nats.go v1.44.0 h1:ECKVrDLdh/kDPV1g0gAQ+2+m2KprqZK5O/eJAyAnH2M= -github.com/nats-io/nats.go v1.44.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= +github.com/nats-io/jwt/v2 v2.8.0 h1:K7uzyz50+yGZDO5o772eRE7atlcSEENpL7P+b74JV1g= +github.com/nats-io/jwt/v2 v2.8.0/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA= +github.com/nats-io/nats-server/v2 v2.11.9 h1:k7nzHZjUf51W1b08xiQih63Rdxh0yr5O4K892Mx5gQA= +github.com/nats-io/nats-server/v2 v2.11.9/go.mod h1:1MQgsAQX1tVjpf3Yzrk3x2pzdsZiNL/TVP3Amhp3CR8= +github.com/nats-io/nats.go v1.45.0 h1:/wGPbnYXDM0pLKFjZTX+2JOw9TQPoIgTFrUaH97giwA= +github.com/nats-io/nats.go v1.45.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0= github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ= +github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/synadia-io/schema-registry-sdk/go/schemaregistry v0.0.0-20250915095715-75c30f0c9054 h1:0avKvAQjiEP1sH3uWjq4PYuzEmF6EpombFEEf3Ogr0w= +github.com/synadia-io/schema-registry-sdk/go/schemaregistry v0.0.0-20250915095715-75c30f0c9054/go.mod h1:JO62cktLMaW5JaSBflMxlRsBcojnAp2IkbMmWMZlw0c= +github.com/synadia-labs/schema-registry v0.0.0-20250912022759-6e8f74fd7d2e h1:M5vE1D6jmI6HCp4o8XVLXyhsH4AVS2GNU7cPmdAUl5s= +github.com/synadia-labs/schema-registry v0.0.0-20250912022759-6e8f74fd7d2e/go.mod h1:Tmrimtwm23ar97KaRI5YVbfJSyf3ugPoLT0nz3qZAos= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= +golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= +golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= +google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rita.go b/rita.go index 942af80..b2673dc 100644 --- a/rita.go +++ b/rita.go @@ -27,7 +27,7 @@ type RitaOption interface { } // TypeRegistry sets an explicit type registry. -func TypeRegistry(types *types.Registry) RitaOption { +func TypeRegistry(types types.Registry) RitaOption { return ritaOption(func(o *Rita) error { o.types = types return nil @@ -65,7 +65,7 @@ type Rita struct { id id.ID clock clock.Clock - types *types.Registry + types types.Registry } // UnpackEvent unpacks an Event from a NATS message. diff --git a/types/inmem_registry.go b/types/inmem_registry.go new file mode 100644 index 0000000..06cdecf --- /dev/null +++ b/types/inmem_registry.go @@ -0,0 +1,173 @@ +package types + +import ( + "fmt" + "reflect" + + "github.com/synadia-labs/rita/codec" +) + +// InMemRegistry is used for transparently marshaling and unmarshaling messages +// and values from their native types to their network/storage representation. +type InMemRegistry struct { + // Codec for marshaling and unmarshaling a values. + codec codec.Codec + + // Index of types. + types map[string]Type + + // Reflection type to the type name. + rtypes map[reflect.Type]string +} + +func (r *InMemRegistry) Codec() codec.Codec { + return r.codec +} + +type InMemType struct { + InitFn func() any +} + +func (t InMemType) Init() func() any { + return t.InitFn +} + +func (r *InMemRegistry) validate(name string, typ Type) error { + if name == "" { + return fmt.Errorf("%w: missing name", ErrTypeNotValid) + } + + if err := validateTypeName(name); err != nil { + return err + } + + if typ.Init() == nil { + return fmt.Errorf("%w: %s: missing init func", ErrTypeNotValid, name) + } + // Ensure the initialize value is not nil. + v := typ.Init()() + if v == nil { + return fmt.Errorf("%w: %s: init func returns nil", ErrTypeNotValid, name) + } + + // Get the Go type in order to transparently serialize to the correct name. + rt := reflect.TypeOf(v) + + // Ensure the initialize type is a pointer so that deserialization works. + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("%w: %s: init func must return a pointer value", ErrTypeNotValid, name) + } + + // Ensure that the pointer value is a struct type. + if rt.Elem().Kind() != reflect.Struct { + return fmt.Errorf("%w: %s: value type must be a struct", ErrTypeNotValid, name) + } + + // Ensure [de]serialization works in the base case. + b, err := r.codec.Marshal(v) + if err != nil { + return fmt.Errorf("%w: %s: failed to marshal with codec: %s", ErrTypeNotValid, name, err) + } + + err = r.codec.Unmarshal(b, v) + if err != nil { + return fmt.Errorf("%w: %s: failed to unmarshal with codec: %s", ErrTypeNotValid, name, err) + } + + return nil +} + +func (r *InMemRegistry) addType(name string, typ Type) { + r.types[name] = typ + + // Initialize a value, reflect the type to index. + v := typ.Init()() + rt := reflect.TypeOf(v) + + r.rtypes[rt] = name + r.rtypes[rt.Elem()] = name +} + +// Init a value given the registered name of the type. +func (r *InMemRegistry) Init(t string) (any, error) { + x, ok := r.types[t] + if !ok { + return nil, fmt.Errorf("%w: %s", ErrTypeNotRegistered, t) + } + + v := x.Init()() + return v, nil +} + +// Lookup returns the registered name of the type given a value. +func (r *InMemRegistry) Lookup(v any) (string, error) { + rt := reflect.TypeOf(v) + t, ok := r.rtypes[rt] + if !ok { + return "", fmt.Errorf("%w: %s", ErrNoTypeForStruct, rt) + } + + return t, nil +} + +// Marshal serializes the value to a byte slice. This call +// validates the type is registered and delegates to the codec. +func (r *InMemRegistry) Marshal(v any) ([]byte, error) { + _, err := r.Lookup(v) + if err != nil { + return nil, err + } + + b, err := r.codec.Marshal(v) + if err != nil { + return b, fmt.Errorf("%T: marshal error: %w", v, err) + } + return b, nil +} + +// Unmarshal deserializes a byte slice into the value. This call +// validates the type is registered and delegates to the codec. +func (r *InMemRegistry) Unmarshal(b []byte, v any) error { + _, err := r.Lookup(v) + if err != nil { + return err + } + + err = r.codec.Unmarshal(b, v) + if err != nil { + return fmt.Errorf("%T: unmarshal error: %w", v, err) + } + return nil +} + +// UnmarshalType initializes a new value for the registered type, +// unmarshals the byte slice, and returns it. +func (r *InMemRegistry) UnmarshalType(b []byte, t string) (any, error) { + v, err := r.Init(t) + if err != nil { + return nil, err + } + err = r.Unmarshal(b, v) + if err != nil { + return nil, err + } + return v, nil +} + +func NewInMemRegistry(types map[string]Type, c codec.Codec) (Registry, error) { + r := &InMemRegistry{ + codec: c, + types: make(map[string]Type), + rtypes: make(map[reflect.Type]string), + } + + for n, t := range types { + err := r.validate(n, t) + if err != nil { + return nil, err + } + r.addType(n, t) + } + + return r, nil +} diff --git a/types/registry_test.go b/types/inmem_registry_test.go similarity index 79% rename from types/registry_test.go rename to types/inmem_registry_test.go index a556add..2d14d38 100644 --- a/types/registry_test.go +++ b/types/inmem_registry_test.go @@ -11,7 +11,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -func TestNewRegistry(t *testing.T) { +func TestNewInMemRegistry(t *testing.T) { // Base case. type A struct{} @@ -44,11 +44,11 @@ func TestNewRegistry(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - _, err := NewRegistry(map[string]*Type{ - "a": { - Init: test.Init, + _, err := NewInMemRegistry(map[string]Type{ + "a": InMemType{ + InitFn: test.Init, }, - }) + }, codec.Default) if err != nil && !test.Err { t.Errorf("unexpected error: %s", err) } else if err == nil && test.Err { @@ -61,9 +61,9 @@ func TestNewRegistry(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) { is := testutil.NewIs(t) - ty := map[string]*Type{ - "a": { - Init: func() any { + ty := map[string]Type{ + "a": InMemType{ + InitFn: func() any { return &pb.A{} }, }, @@ -71,7 +71,7 @@ func TestMarshalUnmarshal(t *testing.T) { for _, c := range codec.Codecs { t.Run(c.Name(), func(t *testing.T) { - rt, err := NewRegistry(ty, Codec(c.Name())) + rt, err := NewInMemRegistry(ty, c) is.NoErr(err) v1 := pb.A{ @@ -110,11 +110,11 @@ func TestMarshalUnmarshal(t *testing.T) { func BenchmarkInit(b *testing.B) { type T struct{} - r, _ := NewRegistry(map[string]*Type{ - "a": { - Init: func() any { return &T{} }, + r, _ := NewInMemRegistry(map[string]Type{ + "a": InMemType{ + InitFn: func() any { return &T{} }, }, - }) + }, codec.Default) b.ResetTimer() @@ -126,11 +126,11 @@ func BenchmarkInit(b *testing.B) { func BenchmarkLookup(b *testing.B) { type T struct{} - r, _ := NewRegistry(map[string]*Type{ - "a": { - Init: func() any { return &T{} }, + r, _ := NewInMemRegistry(map[string]Type{ + "a": InMemType{ + InitFn: func() any { return &T{} }, }, - }) + }, codec.Default) v := &T{} diff --git a/types/nats_registry.go b/types/nats_registry.go new file mode 100644 index 0000000..33c920a --- /dev/null +++ b/types/nats_registry.go @@ -0,0 +1,312 @@ +package types + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "reflect" + "strings" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/santhosh-tekuri/jsonschema/v6" + "github.com/synadia-labs/rita/codec" + + "github.com/synadia-io/schema-registry-sdk/go/schemaregistry" +) + +var ( + _ Type = (*NatsType)(nil) + _ Registry = (*NatsRegistry)(nil) +) + +// NatsRegistry is used for transparently marshaling and unmarshaling messages +// and values from their native types to their network/storage representation. +type NatsRegistry struct { + ctx context.Context + logger *slog.Logger + // Codec for marshaling and unmarshaling a values. + codec codec.Codec + + // Schema registry for storing and retrieving schemas + registry schemaregistry.Registry + + // JetStream for KV operations + js jetstream.JetStream + + // In-memory stores for type information + types map[string]Type // Maps schema name -> Type interface + rtypes map[string]string // Reverse: maps Go type string -> schema name + schemas map[string]*jsonschema.Schema // Compiled schemas for validation +} + +type NatsType struct { + InitFn func() any + Name string + Description string + DocPath string +} + +func (t NatsType) Init() func() any { + return t.InitFn +} + +func (r *NatsRegistry) Codec() codec.Codec { + return r.codec +} + +// Init a value given the registered name of the type. +func (r *NatsRegistry) Init(t string) (any, error) { + typ, ok := r.types[t] + if !ok { + return nil, fmt.Errorf("%w: %s", ErrTypeNotRegistered, t) + } + + initFn := typ.Init() + if initFn == nil { + return nil, fmt.Errorf("%w: %s has nil init function", ErrTypeNotValid, t) + } + + v := initFn() + return v, nil +} + +// Lookup returns the registered name of the type given a value. +func (r *NatsRegistry) Lookup(v any) (string, error) { + rt := reflect.TypeOf(v) + fullTypeName := rt.String() + + schemaName, ok := r.rtypes[fullTypeName] + if ok { + return schemaName, nil + } + + return "", fmt.Errorf("%w: %s", ErrNoTypeForStruct, rt) +} + +// Marshal serializes the value to a byte slice. This call +// validates the type is registered and delegates to the codec. +func (r *NatsRegistry) Marshal(v any) ([]byte, error) { + schemaName, err := r.Lookup(v) + if err != nil { + return nil, err + } + + b, err := r.codec.Marshal(v) + if err != nil { + return b, fmt.Errorf("%T: marshal error: %w", v, err) + } + + // Validate against schema if we have one + if schema, ok := r.schemas[schemaName]; ok { + // For validation, we need JSON representation + jsonData, err := r.codec.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to marshal for validation: %w", err) + } + + switch r.codec { + case codec.JSON: + inst, err := jsonschema.UnmarshalJSON(strings.NewReader(string(jsonData))) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal for validation: %w", err) + } + + if err := schema.Validate(inst); err != nil { + return nil, fmt.Errorf("validation failed: %w", err) + } + default: + if r.logger != nil { + r.logger.Warn("validation skipped: provided codec not supported", slog.String("codec", r.codec.Name())) + } + } + } + + return b, nil +} + +// Unmarshal deserializes a byte slice into the value. This call +// validates the type is registered and delegates to the codec. +func (r *NatsRegistry) Unmarshal(b []byte, v any) error { + schemaName, err := r.Lookup(v) + if err != nil { + return err + } + + err = r.codec.Unmarshal(b, v) + if err != nil { + return fmt.Errorf("%T: unmarshal error: %w", v, err) + } + + if schema, ok := r.schemas[schemaName]; ok { + data, err := r.codec.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal for validation: %w", err) + } + + switch r.codec { + case codec.JSON: + inst, err := jsonschema.UnmarshalJSON(strings.NewReader(string(data))) + if err != nil { + return fmt.Errorf("failed to unmarshal for validation: %w", err) + } + + if err := schema.Validate(inst); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + default: + if r.logger != nil { + r.logger.Warn("validation skipped: provided codec not supported", slog.String("codec", r.codec.Name())) + } + } + } + + return nil +} + +// UnmarshalType initializes a new value for the registered type, +// unmarshals the byte slice, and returns it. +func (r *NatsRegistry) UnmarshalType(b []byte, t string) (any, error) { + v, err := r.Init(t) + if err != nil { + return nil, err + } + err = r.Unmarshal(b, v) + if err != nil { + return nil, err + } + return v, nil +} + +func NewNatsRegistry(ctx context.Context, logger *slog.Logger, types map[string]Type, c codec.Codec, nc *nats.Conn) (Registry, error) { + js, err := jetstream.New(nc) + if err != nil { + return nil, fmt.Errorf("creating jetstream context: %w", err) + } + + r := &NatsRegistry{ + ctx: ctx, + logger: logger, + codec: c, + registry: schemaregistry.NewSchemaRegistry(nc), + js: js, + types: make(map[string]Type), + rtypes: make(map[string]string), + schemas: make(map[string]*jsonschema.Schema), + } + + // Process each type registration + for schemaName, t := range types { + // Validate basic requirements + if schemaName == "" { + return nil, fmt.Errorf("%w: missing schema name", ErrTypeNotValid) + } + + if err := validateTypeName(schemaName); err != nil { + return nil, err + } + + // Store the type in memory + r.types[schemaName] = t + + // Get the init function from the type + initFn := t.Init() + + natsType, ok := t.(NatsType) + if !ok { + // For non-NatsType (e.g., InMemType), we still need to register it + if initFn == nil { + return nil, fmt.Errorf("%w: %s: init func is nil", ErrTypeNotValid, schemaName) + } + + v := initFn() + if v == nil { + return nil, fmt.Errorf("%w: %s: init func returns nil", ErrTypeNotValid, schemaName) + } + + // Store type mapping in memory + r.storeTypeMapping(schemaName, v) + + continue + } + + // Process NatsType with schema registration + if natsType.DocPath != "" { + // Read schema definition from file + schemaDefinition, err := os.ReadFile(natsType.DocPath) + if err != nil { + return nil, fmt.Errorf("reading schema file %s: %w", natsType.DocPath, err) + } + + // Try to add schema to registry - this may fail if registry service isn't running + _, err = r.registry.Add(ctx, schemaregistry.AddRequest{ + Name: schemaName, + Definition: string(schemaDefinition), + Format: schemaregistry.FormatJSONSchema, + CompatPolicy: schemaregistry.CompatNone, + Description: natsType.Description, + Metadata: map[string]string{}, + }) + if err != nil && errors.Is(err, nats.ErrNoResponders) { + logger.Debug("schema registry service not available, skipping registration", slog.String("schema", schemaName)) + } else if err != nil { + return nil, fmt.Errorf("failed to add schema %s: %w", schemaName, err) + } + } + + // Compile and store the schema if provided + if natsType.DocPath != "" { + schemaData, err := os.ReadFile(natsType.DocPath) + if err == nil { + c := jsonschema.NewCompiler() + doc, err := jsonschema.UnmarshalJSON(strings.NewReader(string(schemaData))) + if err != nil { + return nil, fmt.Errorf("failed to parse schema JSON %s: %w", natsType.DocPath, err) + } + if err := c.AddResource(schemaName, doc); err != nil { + return nil, fmt.Errorf("failed to add schema resource %s: %w", schemaName, err) + } + schema, err := c.Compile(schemaName) + if err != nil { + return nil, fmt.Errorf("failed to compile schema %s: %w", schemaName, err) + } + r.schemas[schemaName] = schema + if logger != nil { + logger.Debug("compiled schema", "schema", schemaName, "path", natsType.DocPath) + } + } + } + + // Store type mapping in memory using the already stored InitFn + if initFn != nil { + v := initFn() + if v != nil { + r.storeTypeMapping(schemaName, v) + } + } + } + + return r, nil +} + +// storeTypeMapping stores the Go type metadata in memory +func (r *NatsRegistry) storeTypeMapping(schemaName string, v any) { + rt := reflect.TypeOf(v) + fullTypeName := rt.String() + + // Store reverse mapping + r.rtypes[fullTypeName] = schemaName + + // Handle pointer/non-pointer types + if rt.Kind() != reflect.Ptr { + // Also store the pointer version + pointerType := "*" + fullTypeName + r.rtypes[pointerType] = schemaName + } else { + // Also store the non-pointer version + elemType := rt.Elem() + r.rtypes[elemType.String()] = schemaName + } +} diff --git a/types/nats_registry_test.go b/types/nats_registry_test.go new file mode 100644 index 0000000..2ad318b --- /dev/null +++ b/types/nats_registry_test.go @@ -0,0 +1,700 @@ +package types + +import ( + "context" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/synadia-labs/rita/codec" + "github.com/synadia-labs/rita/testutil" +) + +// cleanupSchemasBucket removes the schemas bucket for test cleanup +func cleanupSchemasBucket(t *testing.T, nc *nats.Conn) { + js, err := jetstream.New(nc) + if err != nil { + return + } + + // Try to delete the schema bucket - ignore errors if it doesn't exist + _ = js.DeleteKeyValue(context.Background(), "schemas") +} + +// Test types that work with JSON Schema +type Person struct { + Name string `json:"name"` + Age int `json:"age"` + Email string `json:"email,omitempty"` +} + +type Order struct { + ID string `json:"id"` + CustomerID string `json:"customer_id"` + Total float64 `json:"total"` + Items []Item `json:"items"` +} + +type Item struct { + ProductID string `json:"product_id"` + Quantity int `json:"quantity"` + Price float64 `json:"price"` +} + +func TestNatsTypeInit(t *testing.T) { + is := testutil.NewIs(t) + + // Test that NatsType.Init() returns the InitFn + nt := NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + } + + initFn := nt.Init() + is.True(initFn != nil) + + // Test that calling the function creates new instances + v1 := initFn() + v2 := initFn() + is.True(v1 != v2) // Different instances + is.True(v1 != nil) + is.True(v2 != nil) +} + +func TestNewNatsRegistry(t *testing.T) { + is := testutil.NewIs(t) + + // Create test schema files + tmpDir := t.TempDir() + + // Person schema + personSchemaPath := filepath.Join(tmpDir, "person.json") + personSchema := `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": 1 + }, + "age": { + "type": "integer", + "minimum": 0, + "maximum": 150 + }, + "email": { + "type": "string", + "format": "email" + } + }, + "required": ["name", "age"] + }` + err := os.WriteFile(personSchemaPath, []byte(personSchema), 0644) + is.NoErr(err) + + // Invalid schema + invalidSchemaPath := filepath.Join(tmpDir, "invalid.json") + err = os.WriteFile(invalidSchemaPath, []byte(`{invalid json`), 0644) + is.NoErr(err) + + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + is.NoErr(err) + + // Clean up schemas bucket after test + t.Cleanup(func() { + cleanupSchemasBucket(t, nc) + }) + + // Not serializable type + type NotSerializable struct { + C chan int + } + + tests := map[string]struct { + Types map[string]Type + Err bool + }{ + "valid-nats-type": { + Types: map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + Description: "A person entity", + DocPath: personSchemaPath, + }, + }, + Err: false, + }, + "no-init": { + Types: map[string]Type{ + "person": NatsType{ + InitFn: nil, + Name: "Person", + }, + }, + Err: false, // Lazy loading doesn't validate during registration + }, + "non-pointer": { + Types: map[string]Type{ + "person": NatsType{ + InitFn: func() any { return Person{} }, + Name: "Person", + }, + }, + Err: false, // Lazy loading doesn't validate during registration + }, + "not-serializable": { + Types: map[string]Type{ + "not-serializable": NatsType{ + InitFn: func() any { return &NotSerializable{} }, + Name: "NotSerializable", + }, + }, + Err: false, // Lazy loading doesn't validate during registration + }, + "invalid-schema-path": { + Types: map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: "/non/existent/path.json", + }, + }, + Err: true, + }, + "invalid-schema-content": { + Types: map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: invalidSchemaPath, + }, + }, + Err: true, + }, + "empty-name": { + Types: map[string]Type{ + "": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + }, + }, + Err: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + _, err := NewNatsRegistry(context.Background(), slog.New(slog.DiscardHandler), test.Types, codec.Default, nc) + if err != nil && !test.Err { + t.Errorf("unexpected error: %s", err) + } else if err == nil && test.Err { + t.Errorf("expected error") + } + }) + } +} + +func TestNatsMarshalUnmarshal(t *testing.T) { + is := testutil.NewIs(t) + + // Create test schema file + tmpDir := t.TempDir() + orderSchemaPath := filepath.Join(tmpDir, "order.json") + orderSchema := `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "id": {"type": "string"}, + "customer_id": {"type": "string"}, + "total": {"type": "number"}, + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "product_id": {"type": "string"}, + "quantity": {"type": "integer"}, + "price": {"type": "number"} + } + } + } + }, + "required": ["id", "customer_id", "total", "items"] + }` + err := os.WriteFile(orderSchemaPath, []byte(orderSchema), 0644) + is.NoErr(err) + + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + is.NoErr(err) + + types := map[string]Type{ + "order": NatsType{ + InitFn: func() any { + return &Order{} + }, + Name: "Order", + DocPath: orderSchemaPath, + }, + } + + // Only test with JSON and msgpack codecs since binary and protobuf have specific requirements + compatibleCodecs := []codec.Codec{codec.JSON, codec.MsgPack} + + for _, c := range compatibleCodecs { + t.Run(c.Name(), func(t *testing.T) { + rt, err := NewNatsRegistry(context.Background(), slog.New(slog.DiscardHandler), types, c, nc) + is.NoErr(err) + + t.Cleanup(func() { + cleanupSchemasBucket(t, nc) + }) + + order1 := Order{ + ID: "order-123", + CustomerID: "customer-456", + Total: 99.99, + Items: []Item{ + {ProductID: "prod-1", Quantity: 2, Price: 29.99}, + {ProductID: "prod-2", Quantity: 1, Price: 40.01}, + }, + } + + // Test lookup + name, err := rt.Lookup(&order1) + is.NoErr(err) + is.Equal(name, "order") + + // Test marshal + b, err := rt.Marshal(&order1) + is.NoErr(err) + + // Test init - now works with factory functions in memory + x, err := rt.Init("order") + is.NoErr(err) + is.True(x != nil) + _, ok := x.(*Order) + is.True(ok) + + // Test unmarshal with pre-existing instance + order2 := &Order{} + err = rt.Unmarshal(b, order2) + is.NoErr(err) + is.Equal(order1.ID, order2.ID) + is.Equal(order1.CustomerID, order2.CustomerID) + is.Equal(order1.Total, order2.Total) + is.Equal(len(order1.Items), len(order2.Items)) + }) + } +} + +func TestNatsRegistryOperations(t *testing.T) { + is := testutil.NewIs(t) + + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + is.NoErr(err) + + // Create test schema file + tmpDir := t.TempDir() + personSchemaPath := filepath.Join(tmpDir, "person.json") + personSchema := `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "email": {"type": "string"} + } + }` + err = os.WriteFile(personSchemaPath, []byte(personSchema), 0644) + is.NoErr(err) + + types := map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: personSchemaPath, + }, + } + + r, err := NewNatsRegistry(context.Background(), slog.New(slog.DiscardHandler), types, codec.Default, nc) + is.NoErr(err) + + // Test Init - should work now that we keep factory functions in memory + v, err := r.Init("person") + is.NoErr(err) + is.True(v != nil) + _, ok := v.(*Person) + is.True(ok) + + // Test Init with unknown type + _, err = r.Init("unknown") + is.True(err != nil) + + // Test Lookup with existing instance + p := &Person{Name: "test", Age: 25} + name, err := r.Lookup(p) + is.NoErr(err) + is.Equal(name, "person") + + // Test Lookup with unregistered type + type UnknownType struct{} + _, err = r.Lookup(&UnknownType{}) + is.True(err != nil) + + // Test Marshal/Unmarshal + p1 := &Person{Name: "John Doe", Age: 30, Email: "john@example.com"} + b, err := r.Marshal(p1) + is.NoErr(err) + + p2 := &Person{} + err = r.Unmarshal(b, p2) + is.NoErr(err) + is.Equal(p2.Name, "John Doe") + is.Equal(p2.Age, 30) + is.Equal(p2.Email, "john@example.com") + + // Test UnmarshalType - now works since Init is implemented + v3, err := r.UnmarshalType(b, "person") + is.NoErr(err) + p3 := v3.(*Person) + is.Equal(p3.Name, "John Doe") + is.Equal(p3.Age, 30) + is.Equal(p3.Email, "john@example.com") + + // Test Codec + is.Equal(r.Codec(), codec.Default) +} + +func TestNatsRegistryPointerNonPointer(t *testing.T) { + is := testutil.NewIs(t) + + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + is.NoErr(err) + + // Create minimal schema for testing + tmpDir := t.TempDir() + schemaPath := filepath.Join(tmpDir, "person.json") + err = os.WriteFile(schemaPath, []byte(`{"type": "object"}`), 0644) + is.NoErr(err) + + types := map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: schemaPath, + }, + } + + r, err := NewNatsRegistry(context.Background(), slog.New(slog.DiscardHandler), types, codec.Default, nc) + is.NoErr(err) + + // Create a value + p := &Person{Name: "test", Age: 25} + + // Test that both pointer and value types can be looked up + name, err := r.Lookup(p) + is.NoErr(err) + is.Equal(name, "person") + + // Test with value (dereference) + name, err = r.Lookup(*p) + is.NoErr(err) + is.Equal(name, "person") +} + +func TestNatsRegistryWithComplexSchema(t *testing.T) { + is := testutil.NewIs(t) + + // Create a more complex schema with validation rules + tmpDir := t.TempDir() + schemaPath := filepath.Join(tmpDir, "complex.json") + complexSchema := `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": 2, + "maxLength": 100, + "pattern": "^[a-zA-Z ]+$" + }, + "age": { + "type": "integer", + "minimum": 18, + "maximum": 100 + }, + "email": { + "type": "string", + "format": "email" + } + }, + "required": ["name", "age"], + "additionalProperties": false + }` + err := os.WriteFile(schemaPath, []byte(complexSchema), 0644) + is.NoErr(err) + + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + is.NoErr(err) + + types := map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + Description: "Person with validation", + DocPath: schemaPath, + }, + } + + r, err := NewNatsRegistry(context.Background(), slog.New(slog.DiscardHandler), types, codec.Default, nc) + is.NoErr(err) + + // Test with valid data + p1 := &Person{Name: "John Doe", Age: 30, Email: "john@example.com"} + b, err := r.Marshal(p1) + is.NoErr(err) + + p2 := &Person{} + err = r.Unmarshal(b, p2) + is.NoErr(err) + is.Equal(p1.Name, p2.Name) +} + +func TestNatsRegistrySchemaValidation(t *testing.T) { + is := testutil.NewIs(t) + + // Test that NATS registry can work without schema files + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + is.NoErr(err) + + // NatsType without DocPath should work + types := map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + // No DocPath - should still work + }, + } + + r, err := NewNatsRegistry(context.Background(), nil, types, codec.Default, nc) + is.NoErr(err) + + // Init should work now with factory functions in memory + v, err := r.Init("person") + is.NoErr(err) + is.True(v != nil) + _, ok := v.(*Person) + is.True(ok) + + p1 := &Person{Name: "Test", Age: 25} + b, err := r.Marshal(p1) + is.NoErr(err) + + p2 := &Person{} + err = r.Unmarshal(b, p2) + is.NoErr(err) + is.Equal(p1.Name, p2.Name) +} + +func BenchmarkNatsInit(b *testing.B) { + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, _ := nats.Connect(srv.ClientURL()) + + // Clean up schemas bucket after benchmark + b.Cleanup(func() { + cleanupSchemasBucket(&testing.T{}, nc) + }) + + // Create minimal schema + tmpDir := b.TempDir() + schemaPath := filepath.Join(tmpDir, "person.json") + _ = os.WriteFile(schemaPath, []byte(`{"type": "object"}`), 0644) + + r, _ := NewNatsRegistry(context.Background(), nil, map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: schemaPath, + }, + }, codec.Default, nc) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = r.Init("person") + } +} + +func BenchmarkNatsLookup(b *testing.B) { + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, _ := nats.Connect(srv.ClientURL()) + + // Clean up schemas bucket after benchmark + b.Cleanup(func() { + cleanupSchemasBucket(&testing.T{}, nc) + }) + + // Create minimal schema + tmpDir := b.TempDir() + schemaPath := filepath.Join(tmpDir, "person.json") + _ = os.WriteFile(schemaPath, []byte(`{"type": "object"}`), 0644) + + r, _ := NewNatsRegistry(context.Background(), nil, map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: schemaPath, + }, + }, codec.Default, nc) + + v := &Person{} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = r.Lookup(v) + } +} + +func BenchmarkNatsMarshalUnmarshal(b *testing.B) { + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, _ := nats.Connect(srv.ClientURL()) + + // Clean up schemas bucket after benchmark + b.Cleanup(func() { + cleanupSchemasBucket(&testing.T{}, nc) + }) + + // Create minimal schema + tmpDir := b.TempDir() + schemaPath := filepath.Join(tmpDir, "person.json") + _ = os.WriteFile(schemaPath, []byte(`{"type": "object"}`), 0644) + + r, _ := NewNatsRegistry(context.Background(), nil, map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: schemaPath, + }, + }, codec.Default, nc) + + p := &Person{Name: "John Doe", Age: 30, Email: "john@example.com"} + data, _ := r.Marshal(p) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + target := &Person{} + _ = r.Unmarshal(data, target) + } +} + +func TestNatsRegistryDataValidation(t *testing.T) { + is := testutil.NewIs(t) + + // Create test schema file with validation rules + tmpDir := t.TempDir() + personSchemaPath := filepath.Join(tmpDir, "person.json") + personSchema := `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": 2, + "maxLength": 50 + }, + "age": { + "type": "integer", + "minimum": 0, + "maximum": 150 + }, + "email": { + "type": "string", + "format": "email" + } + }, + "required": ["name", "age"], + "additionalProperties": false + }` + err := os.WriteFile(personSchemaPath, []byte(personSchema), 0644) + is.NoErr(err) + + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + is.NoErr(err) + + types := map[string]Type{ + "person": NatsType{ + InitFn: func() any { return &Person{} }, + Name: "Person", + DocPath: personSchemaPath, + }, + } + + r, err := NewNatsRegistry(context.Background(), slog.New(slog.DiscardHandler), types, codec.Default, nc) + is.NoErr(err) + + // Test valid data passes validation + validPerson := &Person{Name: "John Doe", Age: 30, Email: "john@example.com"} + _, err = r.Marshal(validPerson) + is.NoErr(err) + + // Test invalid age (negative) + invalidAge := &Person{Name: "John", Age: -5} + _, err = r.Marshal(invalidAge) + is.True(err != nil) // Should fail validation + if err != nil { + is.True(strings.Contains(err.Error(), "validation failed")) + } + + // Note: Go's int defaults to 0, which satisfies minimum:0 constraint + // So we can't test "missing" age field this way + + // Test invalid email format + invalidEmail := &Person{Name: "John", Age: 25, Email: "not-an-email"} + _, err = r.Marshal(invalidEmail) + is.True(err != nil) // Should fail validation + if err != nil { + is.True(strings.Contains(err.Error(), "validation failed")) + } + + // Test name too short + shortName := &Person{Name: "J", Age: 25} + _, err = r.Marshal(shortName) + is.True(err != nil) // Should fail validation + if err != nil { + is.True(strings.Contains(err.Error(), "validation failed")) + } +} diff --git a/types/registry.go b/types/registry.go index b00b504..8883597 100644 --- a/types/registry.go +++ b/types/registry.go @@ -3,7 +3,6 @@ package types import ( "errors" "fmt" - "reflect" "regexp" "github.com/synadia-labs/rita/codec" @@ -19,204 +18,25 @@ var ( nameRegex = regexp.MustCompile(`^[\w-]+(\.[\w-]+)*$`) ) -func validateTypeName(n string) error { - if !nameRegex.MatchString(n) { - return fmt.Errorf("%w: name %q has invalid characters", ErrTypeNotValid, n) - } - return nil -} - -type Type struct { - Init func() any +type Type interface { + Init() func() any // TODO: support schema? // Schema } -type registryOption func(o *Registry) error - -func (f registryOption) addOption(o *Registry) error { - return f(o) -} - -// RegistryOption models a option when creating a type registry. -type RegistryOption interface { - addOption(o *Registry) error -} - -// Codec is a registry option to define the desired serialization codec. -func Codec(name string) RegistryOption { - return registryOption(func(o *Registry) error { - c, ok := codec.Codecs[name] - if !ok { - return fmt.Errorf("%w: %s", codec.ErrCodecNotRegistered, name) - } - - o.codec = c - return nil - }) -} - -// Registry is used for transparently marshaling and unmarshaling messages -// and values from their native types to their network/storage representation. -type Registry struct { - // Codec for marshaling and unmarshaling a values. - codec codec.Codec - - // Index of types. - types map[string]*Type - - // Reflection type to the type name. - rtypes map[reflect.Type]string -} - -func (r *Registry) Codec() codec.Codec { - return r.codec +type Registry interface { + Codec() codec.Codec + Init(t string) (any, error) + Lookup(v any) (string, error) + Marshal(v any) ([]byte, error) + Unmarshal(b []byte, v any) error + UnmarshalType(b []byte, t string) (any, error) } -func (r *Registry) validate(name string, typ *Type) error { - if name == "" { - return fmt.Errorf("%w: missing name", ErrTypeNotValid) - } - - if err := validateTypeName(name); err != nil { - return err - } - - if typ.Init == nil { - return fmt.Errorf("%w: %s: init func is nil", ErrTypeNotValid, name) - } - - // Ensure the initialize value is not nil. - v := typ.Init() - if v == nil { - return fmt.Errorf("%w: %s: init func returns nil", ErrTypeNotValid, name) - } - - // Get the Go type in order to transparently serialize to the correct name. - rt := reflect.TypeOf(v) - - // Ensure the initialize type is a pointer so that deserialization works. - if rt.Kind() != reflect.Ptr { - return fmt.Errorf("%w: %s: init func must return a pointer value", ErrTypeNotValid, name) - } - - // Ensure that the pointer value is a struct type. - if rt.Elem().Kind() != reflect.Struct { - return fmt.Errorf("%w: %s: value type must be a struct", ErrTypeNotValid, name) - } - - // Ensure [de]serialization works in the base case. - b, err := r.codec.Marshal(v) - if err != nil { - return fmt.Errorf("%w: %s: failed to marshal with codec: %s", ErrTypeNotValid, name, err) - } - - err = r.codec.Unmarshal(b, v) - if err != nil { - return fmt.Errorf("%w: %s: failed to unmarshal with codec: %s", ErrTypeNotValid, name, err) - } - - return nil -} - -func (r *Registry) addType(name string, typ *Type) { - r.types[name] = typ - - // Initialize a value, reflect the type to index. - v := typ.Init() - rt := reflect.TypeOf(v) - - r.rtypes[rt] = name - r.rtypes[rt.Elem()] = name -} - -// Init a value given the registered name of the type. -func (r *Registry) Init(t string) (any, error) { - x, ok := r.types[t] - if !ok { - return nil, fmt.Errorf("%w: %s", ErrTypeNotRegistered, t) - } - - v := x.Init() - return v, nil -} - -// Lookup returns the registered name of the type given a value. -func (r *Registry) Lookup(v any) (string, error) { - rt := reflect.TypeOf(v) - t, ok := r.rtypes[rt] - if !ok { - return "", fmt.Errorf("%w: %s", ErrNoTypeForStruct, rt) - } - - return t, nil -} - -// Marshal serializes the value to a byte slice. This call -// validates the type is registered and delegates to the codec. -func (r *Registry) Marshal(v any) ([]byte, error) { - _, err := r.Lookup(v) - if err != nil { - return nil, err - } - - b, err := r.codec.Marshal(v) - if err != nil { - return b, fmt.Errorf("%T: marshal error: %w", v, err) - } - return b, nil -} - -// Unmarshal deserializes a byte slice into the value. This call -// validates the type is registered and delegates to the codec. -func (r *Registry) Unmarshal(b []byte, v any) error { - _, err := r.Lookup(v) - if err != nil { - return err - } - - err = r.codec.Unmarshal(b, v) - if err != nil { - return fmt.Errorf("%T: unmarshal error: %w", v, err) +func validateTypeName(n string) error { + if !nameRegex.MatchString(n) { + return fmt.Errorf("%w: name %q has invalid characters", ErrTypeNotValid, n) } return nil } - -// UnmarshalType initializes a new value for the registered type, -// unmarshals the byte slice, and returns it. -func (r *Registry) UnmarshalType(b []byte, t string) (any, error) { - v, err := r.Init(t) - if err != nil { - return nil, err - } - err = r.Unmarshal(b, v) - if err != nil { - return nil, err - } - return v, nil -} - -func NewRegistry(types map[string]*Type, opts ...RegistryOption) (*Registry, error) { - r := &Registry{ - codec: codec.Default, - types: make(map[string]*Type), - rtypes: make(map[reflect.Type]string), - } - - for _, f := range opts { - if err := f.addOption(r); err != nil { - return nil, err - } - } - - for n, t := range types { - err := r.validate(n, t) - if err != nil { - return nil, err - } - r.addType(n, t) - } - - return r, nil -}