diff --git a/pkg/specter/unitloading.go b/pkg/specter/unitloading.go index 50826f1..be9c310 100644 --- a/pkg/specter/unitloading.go +++ b/pkg/specter/unitloading.go @@ -77,6 +77,11 @@ func UnwrapUnit[T any](unit Unit) (value T, ok bool) { return v, ok } +func UnwrapUnitSafe[T any](unit Unit) T { + t, _ := UnwrapUnit[T](unit) + return t +} + func (w *WrappingUnit) ID() UnitID { return w.id } diff --git a/pkg/specter/unitloading_test.go b/pkg/specter/unitloading_test.go index dfd8d44..30f6c41 100644 --- a/pkg/specter/unitloading_test.go +++ b/pkg/specter/unitloading_test.go @@ -447,3 +447,30 @@ func TestUnwrapUnit(t *testing.T) { }) } } + +func TestUnwrapUnitSafe(t *testing.T) { + + type testCase struct { + name string + when specter.Unit + then string + } + tests := []testCase{ + { + name: "Unwrap of non wrapped unit should return zero value", + when: testutils.NewUnitStub("id", "kind", specter.Source{}), + then: "", + }, + { + name: "Unwrap of a wrapped unit should return the value", + when: specter.UnitOf("hello", "id", "kind", specter.Source{}), + then: "hello", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotValue := specter.UnwrapUnitSafe[string](tt.when) + assert.Equal(t, tt.then, gotValue) + }) + } +}