Skip to content

Commit

Permalink
WrapTxToEIP712 now supports messages with opaque Any types. Like prop…
Browse files Browse the repository at this point in the history
…osals.
  • Loading branch information
Maxim committed Mar 17, 2021
1 parent 23eaa6c commit dee4071
Showing 1 changed file with 122 additions and 24 deletions.
146 changes: 122 additions & 24 deletions eip712_cosmos.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"reflect"
"strings"

codectypes "github.com/cosmos/cosmos-sdk/codec/types"
cosmtypes "github.com/cosmos/cosmos-sdk/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
Expand All @@ -18,7 +19,7 @@ import (

// WrapTxToEIP712 is an ultimate method that wraps Amino-encoded Cosmos Tx JSON data
// into an EIP712-compatible request. All messages must be of the same type.
func WrapTxToEIP712(chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.TypedData, error) {
func WrapTxToEIP712(cdc codectypes.AnyUnpacker, chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.TypedData, error) {
txData := make(map[string]interface{})
if err := json.Unmarshal(data, &txData); err != nil {
err = errors.Wrap(err, "failed to unmarshal data provided into WrapTxToEIP712")
Expand All @@ -33,8 +34,13 @@ func WrapTxToEIP712(chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.T
Salt: "0",
}

msgTypes, err := extractMsgTypes(cdc, "MsgValue", msg)
if err != nil {
return typeddata.TypedData{}, err
}

var typedData = typeddata.TypedData{
Types: extractMsgTypes("MsgValue", msg),
Types: msgTypes,
PrimaryType: "Tx",
Domain: domain,
Message: txData,
Expand All @@ -43,7 +49,7 @@ func WrapTxToEIP712(chainID uint64, msg cosmtypes.Msg, data []byte) (typeddata.T
return typedData, nil
}

func extractMsgTypes(msgTypeName string, msg cosmtypes.Msg) typeddata.Types {
func extractMsgTypes(cdc codectypes.AnyUnpacker, msgTypeName string, msg cosmtypes.Msg) (typeddata.Types, error) {
rootTypes := typeddata.Types{
"EIP712Domain": {
{
Expand Down Expand Up @@ -90,19 +96,18 @@ func extractMsgTypes(msgTypeName string, msg cosmtypes.Msg) typeddata.Types {
msgTypeName: {},
}

walkFields(rootTypes, msgTypeName, msg)
err := walkFields(cdc, rootTypes, msgTypeName, msg)
if err != nil {
return nil, err
}

return rootTypes
return rootTypes, nil
}

const typeDefPrefix = "_"

func walkFields(typeMap typeddata.Types, rootType string, in interface{}) {
defer func() {
if x := recover(); x != nil {
return
}
}()
func walkFields(cdc codectypes.AnyUnpacker, typeMap typeddata.Types, rootType string, in interface{}) (err error) {
defer doRecover(&err)

t := reflect.TypeOf(in)
v := reflect.ValueOf(in)
Expand All @@ -119,10 +124,23 @@ func walkFields(typeMap typeddata.Types, rootType string, in interface{}) {
break
}

traverseFields(typeMap, rootType, typeDefPrefix, t, v)
err = traverseFields(cdc, typeMap, rootType, typeDefPrefix, t, v)
return
}

func traverseFields(typeMap typeddata.Types, rootType string, prefix string, t reflect.Type, v reflect.Value) {
type cosmosAnyWrapper struct {
Type string `json:"type"`
Value interface{} `json:"value"`
}

func traverseFields(
cdc codectypes.AnyUnpacker,
typeMap typeddata.Types,
rootType string,
prefix string,
t reflect.Type,
v reflect.Value,
) (err error) {
n := t.NumField()

if prefix == typeDefPrefix {
Expand All @@ -143,37 +161,86 @@ func traverseFields(typeMap typeddata.Types, rootType string, prefix string, t r
}

fieldType := t.Field(i).Type
fieldName := jsonNameFromTag(t.Field(i).Tag)

if fieldType == cosmosAnyType {
any := field.Interface().(*codectypes.Any)
anyWrapper := &cosmosAnyWrapper{
Type: any.TypeUrl,
}

err = cdc.UnpackAny(any, &anyWrapper.Value)
if err != nil {
err = errors.Wrap(err, "failed to unpack Any in msg struct")
return
}

fieldType = reflect.TypeOf(anyWrapper)
field = reflect.ValueOf(anyWrapper)

// then continue as normal
}

for {
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Interface {
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()

if field.IsValid() {
field = field.Elem()
}

continue
}

if fieldType.Kind() == reflect.Interface {
fieldType = reflect.TypeOf(field.Interface())
continue
}

if field.Kind() == reflect.Ptr {
field = field.Elem()
continue
}

break
}

if fieldType.Kind() == reflect.Array {
var isCollection bool
if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice {
if field.Len() == 0 {
// skip empty collections from type mapping
continue
}

fieldType = fieldType.Elem()
field = field.Index(0)
isCollection = true
}

for {
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Interface {
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()

if field.IsValid() {
field = field.Elem()
}

continue
}

if fieldType.Kind() == reflect.Interface {
fieldType = reflect.TypeOf(field.Interface())
continue
}

if field.Kind() == reflect.Ptr {
field = field.Elem()
continue
}

break
}

fieldName := jsonNameFromTag(t.Field(i).Tag)
fieldPrefix := fmt.Sprintf("%s.%s", prefix, fieldName)

ethTyp := typToEth(fieldType)
Expand All @@ -195,23 +262,36 @@ func traverseFields(typeMap typeddata.Types, rootType string, prefix string, t r
}

if fieldType.Kind() == reflect.Struct {
var fieldTypedef string
if isCollection {
fieldTypedef = sanitizeTypedef(fieldPrefix) + "[]"
} else {
fieldTypedef = sanitizeTypedef(fieldPrefix)
}

if prefix == typeDefPrefix {
typeMap[rootType] = append(typeMap[rootType], typeddata.Type{
Name: fieldName,
Type: sanitizeTypedef(fieldPrefix),
Type: fieldTypedef,
})
} else {
typeDef := sanitizeTypedef(prefix)
typeMap[typeDef] = append(typeMap[typeDef], typeddata.Type{
Name: fieldName,
Type: sanitizeTypedef(fieldPrefix),
Type: fieldTypedef,
})
}

traverseFields(typeMap, rootType, fieldPrefix, fieldType, field)
err = traverseFields(cdc, typeMap, rootType, fieldPrefix, fieldType, field)
if err != nil {
return
}

continue
}
}

return nil
}

func jsonNameFromTag(tag reflect.StructTag) string {
Expand Down Expand Up @@ -244,10 +324,11 @@ func sanitizeTypedef(str string) string {
}

var (
hashType = reflect.TypeOf(common.Hash{})
addressType = reflect.TypeOf(common.Address{})
bigIntType = reflect.TypeOf(big.Int{})
cosmIntType = reflect.TypeOf(cosmtypes.Int{})
hashType = reflect.TypeOf(common.Hash{})
addressType = reflect.TypeOf(common.Address{})
bigIntType = reflect.TypeOf(big.Int{})
cosmIntType = reflect.TypeOf(cosmtypes.Int{})
cosmosAnyType = reflect.TypeOf(&codectypes.Any{})
)

// typToEth supports only basic types and arrays of basic types.
Expand Down Expand Up @@ -283,6 +364,11 @@ func typToEth(typ reflect.Type) string {
if len(ethName) > 0 {
return ethName + "[]"
}
case reflect.Array:
ethName := typToEth(typ.Elem())
if len(ethName) > 0 {
return ethName + "[]"
}
case reflect.Ptr:
if typ.Elem().ConvertibleTo(bigIntType) ||
typ.Elem().ConvertibleTo(cosmIntType) {
Expand All @@ -299,3 +385,15 @@ func typToEth(typ reflect.Type) string {

return ""
}

func doRecover(err *error) {
if r := recover(); r != nil {
if e, ok := r.(error); ok {
e = errors.Wrap(e, "panicked with error")
*err = e
return
}

*err = errors.Errorf("%v", r)
}
}

0 comments on commit dee4071

Please sign in to comment.