Skip to content

Commit

Permalink
Support names in selector parser
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolev-igor committed Aug 3, 2023
1 parent a89e988 commit 9261c38
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 10 deletions.
146 changes: 136 additions & 10 deletions accounts/abi/selector_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ func isDigit(c byte) bool {
}

func isAlpha(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == ' ')
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}

func isSpace(c byte) bool {
return c == ' '
}

func isIdentifierSymbol(c byte) bool {
Expand All @@ -35,7 +39,7 @@ func parseToken(unescapedSelector string, isIdent bool) (string, string, error)
}
for position < len(unescapedSelector) {
char := unescapedSelector[position]
if !(isAlpha(char) || isDigit(char) || (isIdent && isIdentifierSymbol(char))) {
if !(isAlpha(char) || isDigit(char) || (isIdent && isIdentifierSymbol(char)) || (!isIdent && isSpace(char))) {
break
}
position++
Expand All @@ -48,6 +52,32 @@ func parseIdentifier(unescapedSelector string) (string, string, error) {
}

func parseElementaryType(unescapedSelector string) (parsedType string, rest string, err error) {
parsedType, rest, err = parseToken(unescapedSelector, false)
if err != nil {
return "", "", fmt.Errorf("failed to parse elementary type: %v", err)
}
parts := strings.Split(parsedType, " ")
if len(parts) > 1 {
parsedType = parsedType[len(parts[0])+1:]
}
// handle arrays
for len(rest) > 0 && rest[0] == '[' {
parsedType = parsedType + string(rest[0])
rest = rest[1:]
for len(rest) > 0 && isDigit(rest[0]) {
parsedType = parsedType + string(rest[0])
rest = rest[1:]
}
if len(rest) == 0 || rest[0] != ']' {
return "", "", fmt.Errorf("failed to parse array: expected ']', got %c", unescapedSelector[0])
}
parsedType = parsedType + string(rest[0])
rest = rest[1:]
}
return parsedType, rest, nil
}

func parseElementaryTypeWithName(unescapedSelector string) (parsedType string, rest string, err error) {
parsedType, rest, err = parseToken(unescapedSelector, false)
if err != nil {
return "", "", fmt.Errorf("failed to parse elementary type: %v", err)
Expand Down Expand Up @@ -95,7 +125,77 @@ func parseCompositeType(unescapedSelector string) (result []interface{}, rest st
return result, rest[1:], nil
}

func parseCompositeTypeWithName(unescapedSelector string) (result []interface{}, rest string, err error) {
var name string
parts := strings.Split(unescapedSelector, " ")
if len(parts) < 2 {
return nil, "", fmt.Errorf("expected name in the beginning, got %s", unescapedSelector)
} else {
name = parts[0]
unescapedSelector = unescapedSelector[len(parts[0])+1:]
}
if len(unescapedSelector) == 0 || unescapedSelector[0] != '(' {
return nil, "", fmt.Errorf("expected '(...', got %s", unescapedSelector)
}
result = []interface{}{name}
var parsedType interface{}
var counter int64
parsedType, rest, err = parseTypeWithName(unescapedSelector[1:], counter)
if err != nil {
return nil, "", fmt.Errorf("failed to parse type: %v", err)
}
result = append(result, parsedType)
for len(rest) > 0 && rest[0] != ')' {
counter += 1
parsedType, rest, err = parseTypeWithName(rest[1:], counter)
if err != nil {
return nil, "", fmt.Errorf("failed to parse type: %v", err)
}
result = append(result, parsedType)
}
if len(rest) == 0 || rest[0] != ')' {
return nil, "", fmt.Errorf("expected ')', got '%s'", rest)
}
if len(rest) >= 3 && rest[1] == '[' && rest[2] == ']' {
return append(result, "[]"), rest[3:], nil
}
return result, rest[1:], nil
}

func parseFunctionsArgs(unescapedSelector string) (result []interface{}, rest string, err error) {
if len(unescapedSelector) == 0 || unescapedSelector[0] != '(' {
return nil, "", fmt.Errorf("expected '(...', got %s", unescapedSelector)
}
var parsedType interface{}
var counter int64
parsedType, rest, err = parseTypeWithName(unescapedSelector[1:], counter)
if err != nil {
return nil, "", fmt.Errorf("failed to parse type: %v", err)
}
result = []interface{}{parsedType}

for len(rest) > 0 && rest[0] != ')' {
counter += 1
parsedType, rest, err = parseTypeWithName(rest[1:], counter)
if err != nil {
return nil, "", fmt.Errorf("failed to parse type: %v", err)
}
result = append(result, parsedType)
}
if len(rest) == 0 || rest[0] != ')' {
return nil, "", fmt.Errorf("expected ')', got '%s'", rest)
}
if len(rest) >= 3 && rest[1] == '[' && rest[2] == ']' {
return append(result, "[]"), rest[3:], nil
}
return result, rest[1:], nil
}

func parseType(unescapedSelector string) (interface{}, string, error) {
parts := strings.Split(unescapedSelector, " ")
if len(parts) > 1 {
unescapedSelector = unescapedSelector[len(parts[0])+1:]
}
if len(unescapedSelector) == 0 {
return nil, "", errors.New("empty type")
}
Expand All @@ -106,22 +206,48 @@ func parseType(unescapedSelector string) (interface{}, string, error) {
}
}

func parseTypeWithName(unescapedSelector string, counter int64) (interface{}, string, error) {
name, rest, _ := parseIdentifier(unescapedSelector)
if len(rest) > 0 && rest[0] == ' ' {
unescapedSelector = unescapedSelector[len(name)+1:]
} else {
name = fmt.Sprintf("name%d", counter)
}
if len(unescapedSelector) == 0 {
return nil, "", errors.New("empty type")
}
if unescapedSelector[0] == '(' {
return parseCompositeTypeWithName(fmt.Sprintf("%v %v", name, unescapedSelector))
} else {
return parseElementaryTypeWithName(fmt.Sprintf("%v %v", name, unescapedSelector))
}
}

func assembleArgs(args []interface{}) (arguments []ArgumentMarshaling, err error) {
arguments = make([]ArgumentMarshaling, 0)
for i, arg := range args {
// generate dummy name to avoid unmarshal issues
name := fmt.Sprintf("name%d", i)
for _, arg := range args {
var name string
if s, ok := arg.(string); ok {
if s == "[]" {
arguments = append(arguments, ArgumentMarshaling{Name: name, Type: s, InternalType: s})
continue
}
parts := strings.Split(s, " ")
if len(parts) > 2 {
return nil, fmt.Errorf("more than 2 spaces in type declaration in selector %s", s)
} else if len(parts) == 2 {
if len(parts) < 2 {
return nil, fmt.Errorf("no name in arg %s", s)
} else {
name = parts[0]
s = parts[1]
s = s[len(name)+1:]
}
arguments = append(arguments, ArgumentMarshaling{Name: name, Type: s, InternalType: s})
} else if components, ok := arg.([]interface{}); ok {
var subArgs []ArgumentMarshaling
if len(components) < 2 {
return nil, fmt.Errorf("no name in components %s", components)
} else {
name = components[0].(string)
components = components[1:]
}
subArgs, err = assembleArgs(components)
if err != nil {
return nil, fmt.Errorf("failed to assemble components: %v", err)
Expand Down Expand Up @@ -153,7 +279,7 @@ func ParseSelector(unescapedSelector string) (m SelectorMarshaling, err error) {
if len(rest) >= 2 && rest[0] == '(' && rest[1] == ')' {
rest = rest[2:]
} else {
args, rest, err = parseCompositeType(rest)
args, rest, err = parseFunctionsArgs(rest)
if err != nil {
return SelectorMarshaling{}, fmt.Errorf("failed to parse selector '%s': %v", unescapedSelector, err)
}
Expand Down
10 changes: 10 additions & 0 deletions accounts/abi/selector_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"log"
"reflect"
"strings"
"testing"
)

Expand All @@ -29,6 +30,11 @@ func TestParseSelector(t *testing.T) {
for i, typeOrComponents := range types {
name := fmt.Sprintf("name%d", i)
if typeName, ok := typeOrComponents.(string); ok {
names := strings.Split(typeName, " ")
if len(names) > 1 {
name = names[0]
typeName = typeName[len(name)+1:]
}
result = append(result, ArgumentMarshaling{name, typeName, typeName, nil, false})
} else if components, ok := typeOrComponents.([]ArgumentMarshaling); ok {
result = append(result, ArgumentMarshaling{name, "tuple", "tuple", components, false})
Expand Down Expand Up @@ -59,6 +65,10 @@ func TestParseSelector(t *testing.T) {
mkType([][]ArgumentMarshaling{mkType("uint256", "uint256")}, "bytes32[]")},
{"singleArrayNestWithArrayAndArray((uint256[],address[2],uint8[4][][5])[],bytes32[])", "singleArrayNestWithArrayAndArray",
mkType([][]ArgumentMarshaling{mkType("uint256[]", "address[2]", "uint8[4][][5]")}, "bytes32[]")},
{"transfer(to address,amount uint256)", "transfer",
mkType("to address", "amount uint256")},
{"execute(((a address,c uint256,b uint8),(d uint8,e bytes)),(address,uint256),(uint8,(uint256,address),(uint256,address),address,address,bytes),(uint256,bytes),(uint256,bytes))", "execute",
mkType(mkType(mkType("a address", "c uint256", "b uint8"), mkType("d uint8", "e bytes")), mkType("address", "uint256"), mkType("uint8", mkType("uint256", "address"), mkType("uint256", "address"), "address", "address", "bytes"), mkType("uint256", "bytes"), mkType("uint256", "bytes"))},
}
for i, tt := range tests {
selector, err := ParseSelector(tt.input)
Expand Down

0 comments on commit 9261c38

Please sign in to comment.