diff --git a/eip712_cosmos.go b/eip712_cosmos.go index 6ce553e5..8b3c22b8 100644 --- a/eip712_cosmos.go +++ b/eip712_cosmos.go @@ -8,6 +8,7 @@ import ( "reflect" "strings" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" cosmtypes "github.com/cosmos/cosmos-sdk/types" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" @@ -18,7 +19,7 @@ import ( // WrapTxToEIP712 is an ultimate method that wraps Amino-encoded Cosmos Tx JSON data // into an EIP712-compatible request. All messages must be of the same type. -func WrapTxToEIP712(chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.TypedData, error) { +func WrapTxToEIP712(cdc codectypes.AnyUnpacker, chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.TypedData, error) { txData := make(map[string]interface{}) if err := json.Unmarshal(data, &txData); err != nil { err = errors.Wrap(err, "failed to unmarshal data provided into WrapTxToEIP712") @@ -33,8 +34,13 @@ func WrapTxToEIP712(chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.T Salt: "0", } + msgTypes, err := extractMsgTypes(cdc, "MsgValue", msg) + if err != nil { + return typeddata.TypedData{}, err + } + var typedData = typeddata.TypedData{ - Types: extractMsgTypes("MsgValue", msg), + Types: msgTypes, PrimaryType: "Tx", Domain: domain, Message: txData, @@ -43,7 +49,7 @@ func WrapTxToEIP712(chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.T return typedData, nil } -func extractMsgTypes(msgTypeName string, msg cosmtypes.Msg) typeddata.Types { +func extractMsgTypes(cdc codectypes.AnyUnpacker, msgTypeName string, msg cosmtypes.Msg) (typeddata.Types, error) { rootTypes := typeddata.Types{ "EIP712Domain": { { @@ -90,19 +96,18 @@ func extractMsgTypes(msgTypeName string, msg cosmtypes.Msg) typeddata.Types { msgTypeName: {}, } - walkFields(rootTypes, msgTypeName, msg) + err := walkFields(cdc, rootTypes, msgTypeName, msg) + if err != nil { + return nil, err + } - return rootTypes + return rootTypes, nil } const typeDefPrefix = "_" -func walkFields(typeMap typeddata.Types, rootType string, in interface{}) { - defer func() { - if x := recover(); x != nil { - return - } - }() +func walkFields(cdc codectypes.AnyUnpacker, typeMap typeddata.Types, rootType string, in interface{}) (err error) { + defer doRecover(&err) t := reflect.TypeOf(in) v := reflect.ValueOf(in) @@ -119,10 +124,23 @@ func walkFields(typeMap typeddata.Types, rootType string, in interface{}) { break } - traverseFields(typeMap, rootType, typeDefPrefix, t, v) + err = traverseFields(cdc, typeMap, rootType, typeDefPrefix, t, v) + return } -func traverseFields(typeMap typeddata.Types, rootType string, prefix string, t reflect.Type, v reflect.Value) { +type cosmosAnyWrapper struct { + Type string `json:"type"` + Value interface{} `json:"value"` +} + +func traverseFields( + cdc codectypes.AnyUnpacker, + typeMap typeddata.Types, + rootType string, + prefix string, + t reflect.Type, + v reflect.Value, +) (err error) { n := t.NumField() if prefix == typeDefPrefix { @@ -143,37 +161,86 @@ func traverseFields(typeMap typeddata.Types, rootType string, prefix string, t r } fieldType := t.Field(i).Type + fieldName := jsonNameFromTag(t.Field(i).Tag) + + if fieldType == cosmosAnyType { + any := field.Interface().(*codectypes.Any) + anyWrapper := &cosmosAnyWrapper{ + Type: any.TypeUrl, + } + + err = cdc.UnpackAny(any, &anyWrapper.Value) + if err != nil { + err = errors.Wrap(err, "failed to unpack Any in msg struct") + return + } + + fieldType = reflect.TypeOf(anyWrapper) + field = reflect.ValueOf(anyWrapper) + + // then continue as normal + } for { - if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Interface { + if fieldType.Kind() == reflect.Ptr { fieldType = fieldType.Elem() if field.IsValid() { field = field.Elem() } + + continue + } + + if fieldType.Kind() == reflect.Interface { + fieldType = reflect.TypeOf(field.Interface()) continue } + + if field.Kind() == reflect.Ptr { + field = field.Elem() + continue + } + break } - if fieldType.Kind() == reflect.Array { + var isCollection bool + if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice { + if field.Len() == 0 { + // skip empty collections from type mapping + continue + } + fieldType = fieldType.Elem() field = field.Index(0) + isCollection = true } for { - if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Interface { + if fieldType.Kind() == reflect.Ptr { fieldType = fieldType.Elem() if field.IsValid() { field = field.Elem() } + continue } + + if fieldType.Kind() == reflect.Interface { + fieldType = reflect.TypeOf(field.Interface()) + continue + } + + if field.Kind() == reflect.Ptr { + field = field.Elem() + continue + } + break } - fieldName := jsonNameFromTag(t.Field(i).Tag) fieldPrefix := fmt.Sprintf("%s.%s", prefix, fieldName) ethTyp := typToEth(fieldType) @@ -195,23 +262,36 @@ func traverseFields(typeMap typeddata.Types, rootType string, prefix string, t r } if fieldType.Kind() == reflect.Struct { + var fieldTypedef string + if isCollection { + fieldTypedef = sanitizeTypedef(fieldPrefix) + "[]" + } else { + fieldTypedef = sanitizeTypedef(fieldPrefix) + } + if prefix == typeDefPrefix { typeMap[rootType] = append(typeMap[rootType], typeddata.Type{ Name: fieldName, - Type: sanitizeTypedef(fieldPrefix), + Type: fieldTypedef, }) } else { typeDef := sanitizeTypedef(prefix) typeMap[typeDef] = append(typeMap[typeDef], typeddata.Type{ Name: fieldName, - Type: sanitizeTypedef(fieldPrefix), + Type: fieldTypedef, }) } - traverseFields(typeMap, rootType, fieldPrefix, fieldType, field) + err = traverseFields(cdc, typeMap, rootType, fieldPrefix, fieldType, field) + if err != nil { + return + } + continue } } + + return nil } func jsonNameFromTag(tag reflect.StructTag) string { @@ -244,10 +324,11 @@ func sanitizeTypedef(str string) string { } var ( - hashType = reflect.TypeOf(common.Hash{}) - addressType = reflect.TypeOf(common.Address{}) - bigIntType = reflect.TypeOf(big.Int{}) - cosmIntType = reflect.TypeOf(cosmtypes.Int{}) + hashType = reflect.TypeOf(common.Hash{}) + addressType = reflect.TypeOf(common.Address{}) + bigIntType = reflect.TypeOf(big.Int{}) + cosmIntType = reflect.TypeOf(cosmtypes.Int{}) + cosmosAnyType = reflect.TypeOf(&codectypes.Any{}) ) // typToEth supports only basic types and arrays of basic types. @@ -283,6 +364,11 @@ func typToEth(typ reflect.Type) string { if len(ethName) > 0 { return ethName + "[]" } + case reflect.Array: + ethName := typToEth(typ.Elem()) + if len(ethName) > 0 { + return ethName + "[]" + } case reflect.Ptr: if typ.Elem().ConvertibleTo(bigIntType) || typ.Elem().ConvertibleTo(cosmIntType) { @@ -299,3 +385,15 @@ func typToEth(typ reflect.Type) string { return "" } + +func doRecover(err *error) { + if r := recover(); r != nil { + if e, ok := r.(error); ok { + e = errors.Wrap(e, "panicked with error") + *err = e + return + } + + *err = errors.Errorf("%v", r) + } +}