From b8c2a5cda0035a9f2f5fff8757706cb8e076a163 Mon Sep 17 00:00:00 2001 From: gyuwonMoon <78714820+MoonGyu1@users.noreply.github.com> Date: Mon, 7 Aug 2023 11:32:50 +0900 Subject: [PATCH] Remove panic from crdt.Counter (#598) To prevent stopping the server in the production environment, the panic methods were replaced with error methods. Modify test codes for counter As the panic methods in the counter were replaced, the related test codes also were modified. --- api/converter/from_bytes.go | 11 +++- api/converter/from_pb.go | 15 ++++- api/converter/to_bytes.go | 6 +- api/converter/to_pb.go | 6 +- pkg/document/crdt/counter.go | 94 +++++++++++++++++----------- pkg/document/crdt/counter_test.go | 95 +++++++++++++++++------------ pkg/document/json/counter.go | 4 +- pkg/document/json/object.go | 12 +++- pkg/document/operations/increase.go | 4 +- 9 files changed, 159 insertions(+), 88 deletions(-) diff --git a/api/converter/from_bytes.go b/api/converter/from_bytes.go index 04662aded..3c18f9971 100644 --- a/api/converter/from_bytes.go +++ b/api/converter/from_bytes.go @@ -267,12 +267,19 @@ func fromJSONCounter(pbCnt *api.JSONElement_Counter) (*crdt.Counter, error) { if err != nil { return nil, err } + counterValue, err := crdt.CounterValueFromBytes(counterType, pbCnt.Value) + if err != nil { + return nil, err + } - counter := crdt.NewCounter( + counter, err := crdt.NewCounter( counterType, - crdt.CounterValueFromBytes(counterType, pbCnt.Value), + counterValue, createdAt, ) + if err != nil { + return nil, err + } counter.SetMovedAt(movedAt) counter.SetRemovedAt(removedAt) diff --git a/api/converter/from_pb.go b/api/converter/from_pb.go index f63585a2d..868aefd59 100644 --- a/api/converter/from_pb.go +++ b/api/converter/from_pb.go @@ -737,11 +737,20 @@ func fromElement(pbElement *api.JSONElementSimple) (crdt.Element, error) { if err != nil { return nil, err } - return crdt.NewCounter( + counterValue, err := crdt.CounterValueFromBytes(counterType, pbElement.Value) + if err != nil { + return nil, err + } + + counter, err := crdt.NewCounter( counterType, - crdt.CounterValueFromBytes(counterType, pbElement.Value), + counterValue, createdAt, - ), nil + ) + if err != nil { + return nil, err + } + return counter, nil case api.ValueType_VALUE_TYPE_TREE: return BytesToTree(pbElement.Value) } diff --git a/api/converter/to_bytes.go b/api/converter/to_bytes.go index 7313e9f51..42d3ce711 100644 --- a/api/converter/to_bytes.go +++ b/api/converter/to_bytes.go @@ -158,11 +158,15 @@ func toCounter(counter *crdt.Counter) (*api.JSONElement, error) { if err != nil { return nil, err } + counterValue, err := counter.Bytes() + if err != nil { + return nil, err + } return &api.JSONElement{ Body: &api.JSONElement_Counter_{Counter: &api.JSONElement_Counter{ Type: pbCounterType, - Value: counter.Bytes(), + Value: counterValue, CreatedAt: ToTimeTicket(counter.CreatedAt()), MovedAt: ToTimeTicket(counter.MovedAt()), RemovedAt: ToTimeTicket(counter.RemovedAt()), diff --git a/api/converter/to_pb.go b/api/converter/to_pb.go index 49d736ff0..b4d22e570 100644 --- a/api/converter/to_pb.go +++ b/api/converter/to_pb.go @@ -435,11 +435,15 @@ func toJSONElementSimple(elem crdt.Element) (*api.JSONElementSimple, error) { if err != nil { return nil, err } + counterValue, err := elem.Bytes() + if err != nil { + return nil, err + } return &api.JSONElementSimple{ Type: pbCounterType, CreatedAt: ToTimeTicket(elem.CreatedAt()), - Value: elem.Bytes(), + Value: counterValue, }, nil case *crdt.Tree: bytes, err := TreeToBytes(elem) diff --git a/pkg/document/crdt/counter.go b/pkg/document/crdt/counter.go index 0bf7c5159..15335b6de 100644 --- a/pkg/document/crdt/counter.go +++ b/pkg/document/crdt/counter.go @@ -18,11 +18,15 @@ package crdt import ( "encoding/binary" + "errors" "fmt" "github.com/yorkie-team/yorkie/pkg/document/time" ) +// ErrUnsupportedType returned when the given type is not supported. +var ErrUnsupportedType = errors.New("unsupported type") + // CounterType represents any type that can be used as a counter. type CounterType int @@ -33,16 +37,16 @@ const ( ) // CounterValueFromBytes parses the given bytes into value. -func CounterValueFromBytes(counterType CounterType, value []byte) interface{} { +func CounterValueFromBytes(counterType CounterType, value []byte) (interface{}, error) { switch counterType { case IntegerCnt: val := int32(binary.LittleEndian.Uint32(value)) - return int(val) + return int(val), nil case LongCnt: - return int64(binary.LittleEndian.Uint64(value)) + return int64(binary.LittleEndian.Uint64(value)), nil + default: + return nil, ErrUnsupportedType } - - panic("unsupported type") } // Counter represents changeable number data type. @@ -55,39 +59,47 @@ type Counter struct { } // NewCounter creates a new instance of Counter. -func NewCounter(valueType CounterType, value interface{}, createdAt *time.Ticket) *Counter { +func NewCounter(valueType CounterType, value interface{}, createdAt *time.Ticket) (*Counter, error) { switch valueType { case IntegerCnt: + intValue, err := castToInt(value) + if err != nil { + return nil, err + } return &Counter{ valueType: IntegerCnt, - value: castToInt(value), + value: intValue, createdAt: createdAt, - } + }, nil case LongCnt: + longValue, err := castToLong(value) + if err != nil { + return nil, err + } return &Counter{ valueType: LongCnt, - value: castToLong(value), + value: longValue, createdAt: createdAt, - } + }, nil + default: + return nil, ErrUnsupportedType } - - panic("unsupported type") } // Bytes creates an array representing the value. -func (p *Counter) Bytes() []byte { +func (p *Counter) Bytes() ([]byte, error) { switch val := p.value.(type) { case int32: bytes := [4]byte{} binary.LittleEndian.PutUint32(bytes[:], uint32(val)) - return bytes[:] + return bytes[:], nil case int64: bytes := [8]byte{} binary.LittleEndian.PutUint64(bytes[:], uint64(val)) - return bytes[:] + return bytes[:], nil + default: + return nil, ErrUnsupportedType } - - panic("unsupported type") } // Marshal returns the JSON encoding of the value. @@ -146,20 +158,28 @@ func (p *Counter) ValueType() CounterType { // than MinInt32, Counter's value type can be changed Integer to Long. // Because in golang, int can be either int32 or int64. // So we need to assert int to int32. -func (p *Counter) Increase(v *Primitive) *Counter { +func (p *Counter) Increase(v *Primitive) (*Counter, error) { if !p.IsNumericType() || !v.IsNumericType() { - panic("unsupported type") + return nil, ErrUnsupportedType } switch p.valueType { case IntegerCnt: - p.value = p.value.(int32) + castToInt(v.value) + intValue, err := castToInt(v.value) + if err != nil { + return nil, err + } + p.value = p.value.(int32) + intValue case LongCnt: - p.value = p.value.(int64) + castToLong(v.value) + longValue, err := castToLong(v.value) + if err != nil { + return nil, err + } + p.value = p.value.(int64) + longValue default: - panic("unsupported type") + return nil, ErrUnsupportedType } - return p + return p, nil } // IsNumericType checks for numeric types. @@ -169,37 +189,37 @@ func (p *Counter) IsNumericType() bool { } // castToInt casts numeric type to int32. -func castToInt(value interface{}) int32 { +func castToInt(value interface{}) (int32, error) { switch val := value.(type) { case int32: - return val + return val, nil case int64: - return int32(val) + return int32(val), nil case int: - return int32(val) + return int32(val), nil case float32: - return int32(val) + return int32(val), nil case float64: - return int32(val) + return int32(val), nil default: - panic("unsupported type") + return 0, ErrUnsupportedType } } // castToLong casts numeric type to int64. -func castToLong(value interface{}) int64 { +func castToLong(value interface{}) (int64, error) { switch val := value.(type) { case int64: - return val + return val, nil case int32: - return int64(val) + return int64(val), nil case int: - return int64(val) + return int64(val), nil case float32: - return int64(val) + return int64(val), nil case float64: - return int64(val) + return int64(val), nil default: - panic("unsupported type") + return 0, ErrUnsupportedType } } diff --git a/pkg/document/crdt/counter_test.go b/pkg/document/crdt/counter_test.go index de3345dfb..d83a2b066 100644 --- a/pkg/document/crdt/counter_test.go +++ b/pkg/document/crdt/counter_test.go @@ -30,81 +30,94 @@ import ( func TestCounter(t *testing.T) { t.Run("new counter test", func(t *testing.T) { - intCntWithInt32Value := crdt.NewCounter(crdt.IntegerCnt, int32(math.MaxInt32), time.InitialTicket) + intCntWithInt32Value, err := crdt.NewCounter(crdt.IntegerCnt, int32(math.MaxInt32), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithInt32Value.ValueType()) - intCntWithInt64Value := crdt.NewCounter(crdt.IntegerCnt, int64(math.MaxInt32+1), time.InitialTicket) + intCntWithInt64Value, err := crdt.NewCounter(crdt.IntegerCnt, int64(math.MaxInt32+1), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithInt64Value.ValueType()) - intCntWithIntValue := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + intCntWithIntValue, err := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithIntValue.ValueType()) - intCntWithDoubleValue := crdt.NewCounter(crdt.IntegerCnt, 0.5, time.InitialTicket) + intCntWithDoubleValue, err := crdt.NewCounter(crdt.IntegerCnt, 0.5, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.IntegerCnt, intCntWithDoubleValue.ValueType()) - intCntWithUnsupportedValue := func() { crdt.NewCounter(crdt.IntegerCnt, "", time.InitialTicket) } - assert.Panics(t, intCntWithUnsupportedValue) + _, err = crdt.NewCounter(crdt.IntegerCnt, "", time.InitialTicket) + assert.ErrorIs(t, err, crdt.ErrUnsupportedType) - longCntWithInt32Value := crdt.NewCounter(crdt.LongCnt, int32(math.MaxInt32), time.InitialTicket) + longCntWithInt32Value, err := crdt.NewCounter(crdt.LongCnt, int32(math.MaxInt32), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithInt32Value.ValueType()) - longCntWithInt64Value := crdt.NewCounter(crdt.LongCnt, int64(math.MaxInt32+1), time.InitialTicket) + longCntWithInt64Value, err := crdt.NewCounter(crdt.LongCnt, int64(math.MaxInt32+1), time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithInt64Value.ValueType()) - longCntWithIntValue := crdt.NewCounter(crdt.LongCnt, math.MaxInt32+1, time.InitialTicket) + longCntWithIntValue, err := crdt.NewCounter(crdt.LongCnt, math.MaxInt32+1, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithIntValue.ValueType()) - longCntWithDoubleValue := crdt.NewCounter(crdt.LongCnt, 0.5, time.InitialTicket) + longCntWithDoubleValue, err := crdt.NewCounter(crdt.LongCnt, 0.5, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, crdt.LongCnt, longCntWithDoubleValue.ValueType()) - longCntWithUnsupportedValue := func() { crdt.NewCounter(crdt.LongCnt, "", time.InitialTicket) } - assert.Panics(t, longCntWithUnsupportedValue) + _, err = crdt.NewCounter(crdt.LongCnt, "", time.InitialTicket) + assert.ErrorIs(t, err, crdt.ErrUnsupportedType) }) t.Run("increase test", func(t *testing.T) { var x = 5 var y int64 = 10 var z = 3.14 - integer := crdt.NewCounter(crdt.IntegerCnt, x, time.InitialTicket) - long := crdt.NewCounter(crdt.LongCnt, y, time.InitialTicket) - double := crdt.NewCounter(crdt.IntegerCnt, z, time.InitialTicket) + integer, err := crdt.NewCounter(crdt.IntegerCnt, x, time.InitialTicket) + assert.NoError(t, err) + long, err := crdt.NewCounter(crdt.LongCnt, y, time.InitialTicket) + assert.NoError(t, err) + double, err := crdt.NewCounter(crdt.IntegerCnt, z, time.InitialTicket) + assert.NoError(t, err) integerOperand := crdt.NewPrimitive(x, time.InitialTicket) longOperand := crdt.NewPrimitive(y, time.InitialTicket) doubleOperand := crdt.NewPrimitive(z, time.InitialTicket) // normal process test - integer.Increase(integerOperand) - integer.Increase(longOperand) - integer.Increase(doubleOperand) + _, err = integer.Increase(integerOperand) + assert.NoError(t, err) + _, err = integer.Increase(longOperand) + assert.NoError(t, err) + _, err = integer.Increase(doubleOperand) + assert.NoError(t, err) assert.Equal(t, integer.Marshal(), "23") - long.Increase(integerOperand) - long.Increase(longOperand) - long.Increase(doubleOperand) + _, err = long.Increase(integerOperand) + assert.NoError(t, err) + _, err = long.Increase(longOperand) + assert.NoError(t, err) + _, err = long.Increase(doubleOperand) + assert.NoError(t, err) assert.Equal(t, long.Marshal(), "28") - double.Increase(integerOperand) - double.Increase(longOperand) - double.Increase(doubleOperand) + _, err = double.Increase(integerOperand) + assert.NoError(t, err) + _, err = double.Increase(longOperand) + assert.NoError(t, err) + _, err = double.Increase(doubleOperand) + assert.NoError(t, err) assert.Equal(t, double.Marshal(), "21") // error process test - // TODO: it should be modified to error check - // when 'Remove panic from server code (#50)' is completed. - unsupportedTypePanicTest := func() { - r := recover() - assert.NotNil(t, r) - assert.Equal(t, r, "unsupported type") + unsupportedTypeErrorTest := func(v interface{}) { + _, err = crdt.NewCounter(crdt.IntegerCnt, v, time.InitialTicket) + assert.ErrorIs(t, err, crdt.ErrUnsupportedType) } - unsupportedTest := func(v interface{}) { - defer unsupportedTypePanicTest() - crdt.NewCounter(crdt.IntegerCnt, v, time.InitialTicket) - } - unsupportedTest("str") - unsupportedTest(true) - unsupportedTest([]byte{2}) - unsupportedTest(gotime.Now()) + unsupportedTypeErrorTest("str") + unsupportedTypeErrorTest(true) + unsupportedTypeErrorTest([]byte{2}) + unsupportedTypeErrorTest(gotime.Now()) assert.Equal(t, integer.Marshal(), "23") assert.Equal(t, long.Marshal(), "28") @@ -112,11 +125,13 @@ func TestCounter(t *testing.T) { }) t.Run("Counter value overflow test", func(t *testing.T) { - integer := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + integer, err := crdt.NewCounter(crdt.IntegerCnt, math.MaxInt32, time.InitialTicket) + assert.NoError(t, err) assert.Equal(t, integer.ValueType(), crdt.IntegerCnt) operand := crdt.NewPrimitive(1, time.InitialTicket) - integer.Increase(operand) + _, err = integer.Increase(operand) + assert.NoError(t, err) assert.Equal(t, integer.ValueType(), crdt.IntegerCnt) assert.Equal(t, integer.Marshal(), strconv.FormatInt(math.MinInt32, 10)) }) diff --git a/pkg/document/json/counter.go b/pkg/document/json/counter.go index 668ebd23c..0ae72809a 100644 --- a/pkg/document/json/counter.go +++ b/pkg/document/json/counter.go @@ -71,7 +71,9 @@ func (p *Counter) Increase(v interface{}) *Counter { panic("unsupported type") } - p.Counter.Increase(primitive) + if _, err := p.Counter.Increase(primitive); err != nil { + panic(err) + } p.context.Push(operations.NewIncrease( p.CreatedAt(), diff --git a/pkg/document/json/object.go b/pkg/document/json/object.go index 2cf157480..11b44bf24 100644 --- a/pkg/document/json/object.go +++ b/pkg/document/json/object.go @@ -75,14 +75,22 @@ func (p *Object) SetNewCounter(k string, t crdt.CounterType, n interface{}) *Cou v := p.setInternal(k, func(ticket *time.Ticket) crdt.Element { switch t { case crdt.IntegerCnt: + counter, err := crdt.NewCounter(crdt.IntegerCnt, n, ticket) + if err != nil { + panic(err) + } return NewCounter( p.context, - crdt.NewCounter(crdt.IntegerCnt, n, ticket), + counter, ) case crdt.LongCnt: + counter, err := crdt.NewCounter(crdt.LongCnt, n, ticket) + if err != nil { + panic(err) + } return NewCounter( p.context, - crdt.NewCounter(crdt.LongCnt, n, ticket), + counter, ) default: panic("unsupported type") diff --git a/pkg/document/operations/increase.go b/pkg/document/operations/increase.go index cff99aeea..5b3d126a6 100644 --- a/pkg/document/operations/increase.go +++ b/pkg/document/operations/increase.go @@ -51,7 +51,9 @@ func (o *Increase) Execute(root *crdt.Root) error { } value := o.value.(*crdt.Primitive) - cnt.Increase(value) + if _, err := cnt.Increase(value); err != nil { + return err + } return nil }