Skip to content

Commit

Permalink
Add ability to use regular expressions in Aim search queries (G-Resea…
Browse files Browse the repository at this point in the history
…rch#171)

Co-authored-by: Jonathan Giannuzzi <jgiannuzzi@users.noreply.github.com>
  • Loading branch information
dsuhinin and jgiannuzzi authored Aug 2, 2023
1 parent 607bc42 commit 7340019
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 34 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ require (
github.com/go-python/gpython v0.2.0
github.com/gofiber/fiber/v2 v2.48.0
github.com/google/uuid v1.3.0
github.com/hashicorp/golang-lru/v2 v2.0.4
github.com/hetiansu5/urlquery v1.2.7
github.com/mattn/go-sqlite3 v1.14.16
github.com/pkg/errors v0.9.1
github.com/rotisserie/eris v0.5.4
github.com/sirupsen/logrus v1.9.3
Expand Down Expand Up @@ -50,7 +52,6 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/mattn/go-sqlite3 v1.14.16 // indirect
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru/v2 v2.0.4 h1:7GHuZcgid37q8o5i3QI9KMT4nCWQQ3Kx3Ov6bb9MfK0=
github.com/hashicorp/golang-lru/v2 v2.0.4/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hetiansu5/urlquery v1.2.7 h1:jn0h+9pIRqUziSPnRdK/gJK8S5TCnk+HZZx5fRHf8K0=
Expand Down
38 changes: 38 additions & 0 deletions pkg/api/aim/query/clause.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package query

import (
"fmt"

"gorm.io/driver/postgres"
"gorm.io/gorm/clause"
)

// Regexp whether string matches regular expression
type Regexp struct {
clause.Eq
Dialector string
}

// Build builds positive statement.
func (regexp Regexp) Build(builder clause.Builder) {
builder.WriteQuoted(regexp.Column)
switch regexp.Dialector {
case postgres.Dialector{}.Name():
builder.WriteString(" ~ ")
default:
builder.WriteString(" regexp ")
}
builder.AddVar(builder, fmt.Sprintf("%s", regexp.Value))
}

// NegationBuild builds negative statement.
func (regexp Regexp) NegationBuild(builder clause.Builder) {
builder.WriteQuoted(regexp.Column)
switch regexp.Dialector {
case postgres.Dialector{}.Name():
builder.WriteString(" !~ ")
default:
builder.WriteString(" NOT regexp ")
}
builder.AddVar(builder, regexp.Value)
}
101 changes: 80 additions & 21 deletions pkg/api/aim/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,16 @@ import (
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)

const (
OperationEndsWith = "endswith"
OperationContains = "contains"
OperationStartsWith = "startswith"
)

type DefaultExpression struct {
Contains string
Expression string
}

type QueryParser struct {
Default DefaultExpression
Tables map[string]string
TzOffset int
Default DefaultExpression
Tables map[string]string
TzOffset int
Dialector string
}

type ParsedQuery interface {
Expand Down Expand Up @@ -217,7 +212,7 @@ func (pq *parsedQuery) parseAttribute(node *ast.Attribute) (any, error) {
}
attribute := string(node.Attr)
switch strings.ToLower(attribute) {
case OperationEndsWith:
case "endswith":
return callable(func(args []ast.Expr) (any, error) {
if len(args) != 1 {
return nil, errors.New("`endwith` function support exactly one argument")
Expand All @@ -231,11 +226,15 @@ func (pq *parsedQuery) parseAttribute(node *ast.Attribute) (any, error) {
if !ok {
return nil, errors.New("unsupported argument type. has to be `string` only")
}
return clause.Expr{
SQL: fmt.Sprintf(`"%s"."%s" LIKE '%%%s'`, c.Table, c.Name, arg.S),
return clause.Like{
Value: fmt.Sprintf("%%%s", arg.S),
Column: clause.Column{
Table: c.Table,
Name: c.Name,
},
}, nil
}), nil
case OperationContains:
case "contains":
return callable(func(args []ast.Expr) (any, error) {
if len(args) != 1 {
return nil, errors.New("`contains` function support exactly one argument")
Expand All @@ -249,11 +248,15 @@ func (pq *parsedQuery) parseAttribute(node *ast.Attribute) (any, error) {
if !ok {
return nil, errors.New("unsupported argument type. has to be `string` only")
}
return clause.Expr{
SQL: fmt.Sprintf(`"%s"."%s" LIKE '%%%s%%'`, c.Table, c.Name, arg.S),
return clause.Like{
Value: fmt.Sprintf("%%%s%%", arg.S),
Column: clause.Column{
Table: c.Table,
Name: c.Name,
},
}, nil
}), nil
case OperationStartsWith:
case "startswith":
return callable(func(args []ast.Expr) (any, error) {
if len(args) != 1 {
return nil, errors.New("`startwith` function support exactly one argument")
Expand All @@ -267,8 +270,12 @@ func (pq *parsedQuery) parseAttribute(node *ast.Attribute) (any, error) {
if !ok {
return nil, errors.New("unsupported argument type. has to be `string` only")
}
return clause.Expr{
SQL: fmt.Sprintf(`"%s"."%s" LIKE '%s%%'`, c.Table, c.Name, arg.S),
return clause.Like{
Value: fmt.Sprintf("%s%%", arg.S),
Column: clause.Column{
Table: c.Table,
Name: c.Name,
},
}, nil
}), nil
}
Expand Down Expand Up @@ -403,7 +410,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
case "run":
table, ok := pq.qp.Tables["runs"]
if !ok {
return nil, errors.New("unsupported name identifier \"run\"")
return nil, errors.New(`unsupported name identifier "run"`)
}
return attributeGetter(
func(attr string) (any, error) {
Expand Down Expand Up @@ -431,7 +438,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
case "experiment":
e, ok := pq.qp.Tables["experiments"]
if !ok {
return nil, errors.New("unsupported attribute \"experiment\"")
return nil, errors.New(`unsupported attribute "experiment"`)
}
return clause.Column{
Table: e,
Expand Down Expand Up @@ -554,7 +561,7 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
case "metric":
table, ok := pq.qp.Tables["metrics"]
if !ok {
return nil, errors.New("unsupported name identifier \"metric\"")
return nil, errors.New(`unsupported name identifier "metric"`)
}
return attributeGetter(
func(attr string) (any, error) {
Expand All @@ -581,6 +588,58 @@ func (pq *parsedQuery) parseName(node *ast.Name) (any, error) {
}
},
), nil
case "re":
return attributeGetter(
func(attr string) (any, error) {
switch attr {
case "match":
fallthrough
case "search":
return callable(
func(args []ast.Expr) (any, error) {
if len(args) != 2 {
return nil, errors.New("re.match function support exactly 2 arguments")
}

parsedNode, err := pq.parseNode(args[0])
if err != nil {
return nil, err
}
str, ok := parsedNode.(string)
if !ok {
return nil, errors.New("first argument type for re.match function has to be a string")
}

parsedNode, err = pq.parseNode(args[1])
if err != nil {
return nil, err
}
column, ok := parsedNode.(clause.Column)
if !ok {
return nil, errors.New(
"second argument type for re.match function has to be clause.Column",
)
}

// handle difference between `match` and `search`.
if attr == "match" {
str = fmt.Sprintf("^%s", str)
}

return Regexp{
Eq: clause.Eq{
Column: column,
Value: str,
},
Dialector: pq.qp.Dialector,
}, nil
},
), nil
default:
return nil, fmt.Errorf("unsupported re function %s", attr)
}
},
), nil
case "datetime":
return callable(
func(args []ast.Expr) (any, error) {
Expand Down
99 changes: 92 additions & 7 deletions pkg/api/aim/query/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
Expand All @@ -24,6 +25,7 @@ func TestQueryTestSuite(t *testing.T) {
func (s *QueryTestSuite) SetupTest() {
mockedDB, _, err := sqlmock.New()
assert.Nil(s.T(), err)

db, err := gorm.Open(postgres.New(postgres.Config{
Conn: mockedDB,
DriverName: "postgres",
Expand All @@ -32,7 +34,7 @@ func (s *QueryTestSuite) SetupTest() {
s.db = db
}

func (s *QueryTestSuite) Test_Ok() {
func (s *QueryTestSuite) TestPostgresDialector_Ok() {
tests := []struct {
name string
query string
Expand All @@ -48,20 +50,102 @@ func (s *QueryTestSuite) Test_Ok() {
{
name: "TestRunNameWithContainsFunction",
query: `(run.name.contains('run'))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE '%run%' AND "runs"."lifecycle_stage" <> $1) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{models.LifecycleStageDeleted},
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"%run%", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithStartWithFunction",
query: `(run.name.startswith('run'))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE 'run%' AND "runs"."lifecycle_stage" <> $1) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{models.LifecycleStageDeleted},
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"run%", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithEndWithFunction",
query: `(run.name.endswith('run'))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE '%run' AND "runs"."lifecycle_stage" <> $1) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{models.LifecycleStageDeleted},
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"%run", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithRegexpMatchFunction",
query: `(re.match('run', run.name))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" ~ $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"^run", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithRegexpSearchFunction",
query: `(re.search('run', run.name))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" ~ $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"run", models.LifecycleStageDeleted},
},
}

for _, tt := range tests {
s.T().Run(tt.name, func(T *testing.T) {
pq := QueryParser{
Default: DefaultExpression{
Contains: "run.archived",
Expression: "not run.archived",
},
Tables: map[string]string{
"runs": "runs",
"experiments": "Experiment",
},
Dialector: postgres.Dialector{}.Name(),
}
parsedQuery, err := pq.Parse(tt.query)
assert.Nil(s.T(), err)
result := parsedQuery.Filter(
s.db.Session(&gorm.Session{DryRun: true}).Model(models.Run{}),
).First(&models.Run{})
assert.Nil(s.T(), result.Error)
assert.Equal(s.T(), tt.expectedSQL, result.Statement.SQL.String())
assert.Equal(s.T(), tt.expectedVars, result.Statement.Vars)
})
}
}

func (s *QueryTestSuite) TestSqliteDialector_Ok() {
tests := []struct {
name string
query string
expectedSQL string
expectedVars []interface{}
}{
{
name: "TestRunNameWithoutFunction",
query: `(run.name == 'run')`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" = $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"run", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithContainsFunction",
query: `(run.name.contains('run'))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"%run%", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithStartWithFunction",
query: `(run.name.startswith('run'))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"run%", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithEndWithFunction",
query: `(run.name.endswith('run'))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" LIKE $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"%run", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithRegexpMatchFunction",
query: `(re.match('run', run.name))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" regexp $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"^run", models.LifecycleStageDeleted},
},
{
name: "TestRunNameWithRegexpSearchFunction",
query: `(re.search('run', run.name))`,
expectedSQL: `SELECT * FROM "runs" WHERE ("runs"."name" regexp $1 AND "runs"."lifecycle_stage" <> $2) ORDER BY "runs"."run_uuid" LIMIT 1`,
expectedVars: []interface{}{"run", models.LifecycleStageDeleted},
},
}

Expand All @@ -76,6 +160,7 @@ func (s *QueryTestSuite) Test_Ok() {
"runs": "runs",
"experiments": "Experiment",
},
Dialector: sqlite.Dialector{}.Name(),
}
parsedQuery, err := pq.Parse(tt.query)
assert.Nil(s.T(), err)
Expand Down
Loading

0 comments on commit 7340019

Please sign in to comment.