From adb1aefbbadf315779a9d446a079ea031bbbf9bc Mon Sep 17 00:00:00 2001 From: Tom Cole Date: Thu, 19 Dec 2024 13:31:50 -0500 Subject: [PATCH] fix: better error handling when using SQL for resoruce object mgmt A number of cases, such as bad filter construction, resulted in either incorrect errors, panics, or no corrective behaviro at all. This now causes dependent or chained operations like filter specification to store any error recorded in the resource handle and percolate that to the caller of the resource operation. --- errors/messages.go | 2 ++ i18n/languages/messages_en.txt | 2 ++ resources/create.go | 18 +++++++++++++++ resources/defs.go | 11 +++++++++ resources/delete.go | 4 ++++ resources/filters.go | 41 ++++++++++++++++++++++++++++------ resources/insert.go | 4 ++++ resources/read.go | 7 ++++++ resources/update.go | 4 ++++ server/auth/users_sqldb.go | 10 ++++----- server/dsns/dsn_sqldb.go | 26 +++++++++++---------- tools/buildver.txt | 2 +- 12 files changed, 106 insertions(+), 25 deletions(-) diff --git a/errors/messages.go b/errors/messages.go index e7858540..008065af 100644 --- a/errors/messages.go +++ b/errors/messages.go @@ -71,6 +71,7 @@ var ErrInvalidDuration = Message("invalid.duration") var ErrInvalidEndPointString = Message("endpoint") var ErrInvalidField = Message("field.for.type") var ErrInvalidFileMode = Message("file.mode") +var ErrInvalidFilter = Message("invalid.filter") var ErrInvalidFormatVerb = Message("format.spec") var ErrInvalidFunctionArgument = Message("func.arg") var ErrInvalidFunctionCall = Message("func.call") @@ -197,6 +198,7 @@ var ErrRestClientClosed = Message("rest.closed") var ErrReturnValueCount = Message("func.return.count") var ErrServerAlreadyRunning = Message("server.running") var ErrServerError = Message("server.error") +var ErrSQLInjection = Message("sql.injection") var ErrStackUnderflow = Message("stack.underflow") var ErrStructInitTuple = Message("struct.init.tuple") var ErrSymbolNotExported = Message("symbol.not.exported") diff --git a/i18n/languages/messages_en.txt b/i18n/languages/messages_en.txt index e6491b85..0d57b671 100644 --- a/i18n/languages/messages_en.txt +++ b/i18n/languages/messages_en.txt @@ -184,6 +184,7 @@ invalid.blockquote=invalid block quote invalid.cache.item=internal error; invalid cache item type invalid.catch.set=invalid catch set {{index}} invalid.duration=invalid duration string +invalid.filter=invalid SQL filter invalid.named.return.values=Invalid use of named and non-named return values invalid.struct.or.package=invalid structure or package invalid.unwrap=invalid unwrap of non-interface value @@ -252,6 +253,7 @@ server.error=internal server error server.running=server already running as pid slice.index=invalid slice index spacing=invalid spacing value +sql.injection=possible SQL injection violation stack.underflow=stack underflow statement=missing statement statement.not.found=unexpected token diff --git a/resources/create.go b/resources/create.go index 552a6767..114e1687 100644 --- a/resources/create.go +++ b/resources/create.go @@ -12,6 +12,10 @@ import ( func (r *ResHandle) Create() error { var err error + if r.Err != nil { + return r.Err + } + if r.Database == nil { return errors.New("database not open") } @@ -33,6 +37,10 @@ func (r *ResHandle) Create() error { func (r *ResHandle) CreateIf() error { var err error + if r.Err != nil { + return r.Err + } + if r.Database == nil { return errors.New("database not open") } @@ -52,3 +60,13 @@ func (r *ResHandle) CreateIf() error { return err } + +// This resets the state of a resource handle to a known state before +// beginning a chain of operations. For example, this ressts the error +// state such that any subsuent operations (like applying filters) will +// result in a new error state that can be detected by the caller. +func (r *ResHandle) Begin() *ResHandle { + r.Err = nil + + return r +} diff --git a/resources/defs.go b/resources/defs.go index de02fdc1..f83134f9 100644 --- a/resources/defs.go +++ b/resources/defs.go @@ -68,6 +68,10 @@ type ResHandle struct { // select operations. This list is used to consturct the ORDER BY // clause. OrderList []int + + // Error(s) generated during resource specification, filtering, etc. + // are reported here. + Err error } // Filter is an object describing a single comparison used in creating @@ -86,6 +90,7 @@ type Filter struct { const ( EqualsOperator = " = " NotEqualsOperator = " <> " + InvalidOperator = " !error " ) const ( @@ -95,3 +100,9 @@ const ( SQLFloatType = "float" SQLDoubleType = "double" ) + +var invalidFilterError = &Filter{ + Name: "", + Value: "", + Operator: InvalidOperator, +} diff --git a/resources/delete.go b/resources/delete.go index 0401da8b..7dc3977c 100644 --- a/resources/delete.go +++ b/resources/delete.go @@ -21,6 +21,10 @@ func (r *ResHandle) Delete(filters ...*Filter) (int64, error) { count int64 ) + if r.Err != nil { + return 0, r.Err + } + if r.Database == nil { return 0, ErrDatabaseNotOpen } diff --git a/resources/filters.go b/resources/filters.go index a1904064..97d2392c 100644 --- a/resources/filters.go +++ b/resources/filters.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/tucats/ego/errors" "github.com/tucats/ego/util" ) @@ -21,16 +22,27 @@ import ( // only supported operators are EqualsOperator and NotEqualsOperator. A panic // is also generated if the column name specified does not exist in the // database table for the resource object type. -func (r ResHandle) newFilter(name, operator string, value interface{}) *Filter { +func (r *ResHandle) newFilter(name, operator string, value interface{}) *Filter { + if r.Err != nil { + return invalidFilterError + } + if !util.InList(operator, EqualsOperator, NotEqualsOperator) { - // @tomcole need better handling of this - panic("unknown or unimplemented filter operator: " + operator) + r.Err = errors.ErrInvalidFilter.Context(operator) + + return invalidFilterError } for _, column := range r.Columns { if strings.EqualFold(column.Name, name) { switch actual := value.(type) { case string: + if strings.Contains(actual, "'") { + r.Err = errors.ErrSQLInjection.Context(actual) + + return invalidFilterError + } + return &Filter{ Name: column.SQLName, Value: "'" + actual + "'", @@ -38,24 +50,39 @@ func (r ResHandle) newFilter(name, operator string, value interface{}) *Filter { } case uuid.UUID: + text := actual.String() + if strings.Contains(text, "'") { + r.Err = errors.ErrSQLInjection.Context(text) + + return invalidFilterError + } + return &Filter{ Name: column.SQLName, - Value: "'" + actual.String() + "'", + Value: "'" + text + "'", Operator: operator, } default: + text := fmt.Sprintf("%v", actual) + if strings.Contains(text, "'") { + r.Err = errors.ErrSQLInjection.Context(text) + + return invalidFilterError + } + return &Filter{ Name: column.SQLName, - Value: fmt.Sprintf("%v", actual), + Value: text, Operator: operator, } } } } - // @tomcole need better error handling - panic("attempt to create filter on non-existent column " + name + " for " + r.Name) + r.Err = errors.ErrInvalidColumnName.Context(name) + + return nil } // Equals creates a resource filter used for a read, update, or delete diff --git a/resources/insert.go b/resources/insert.go index 1867d45f..7e8c2e6e 100644 --- a/resources/insert.go +++ b/resources/insert.go @@ -5,6 +5,10 @@ import "github.com/tucats/ego/app-cli/ui" func (r *ResHandle) Insert(v interface{}) error { var err error + if r.Err != nil { + return r.Err + } + sql := r.insertSQL() items := r.explode(v) diff --git a/resources/read.go b/resources/read.go index 2bb77819..2c9863b3 100644 --- a/resources/read.go +++ b/resources/read.go @@ -29,6 +29,10 @@ func (r *ResHandle) Read(filters ...*Filter) ([]interface{}, error) { return nil, ErrDatabaseNotOpen } + if r.Err != nil { + return nil, r.Err + } + sql := r.readRowSQL() for index, filter := range filters { @@ -109,6 +113,9 @@ func (r *ResHandle) Read(filters ...*Filter) ([]interface{}, error) { // The default key is the "id" column, but this can be overridden // using the SetIDField() method. func (r *ResHandle) ReadOne(key interface{}) (interface{}, error) { + // Reset the deferred error state for a fresh start. + r.Err = nil + keyField := r.PrimaryKey() if keyField == "" { return nil, errors.ErrNotFound diff --git a/resources/update.go b/resources/update.go index d21271c2..9c11b00d 100644 --- a/resources/update.go +++ b/resources/update.go @@ -17,6 +17,10 @@ import ( func (r *ResHandle) Update(v interface{}, filters ...*Filter) error { var err error + if r.Err != nil { + return r.Err + } + sql := r.updateSQL() for index, filter := range filters { diff --git a/server/auth/users_sqldb.go b/server/auth/users_sqldb.go index 750cffd1..27e430bc 100644 --- a/server/auth/users_sqldb.go +++ b/server/auth/users_sqldb.go @@ -69,7 +69,7 @@ func NewDatabaseService(connStr, defaultUser, defaultPassword string) (userIOSer func (pg *databaseService) ListUsers() map[string]defs.User { r := map[string]defs.User{} - rowSet, err := pg.userHandle.Read() + rowSet, err := pg.userHandle.Begin().Read() if err != nil { ui.Log(ui.ServerLogger, "Database error: %v", err) @@ -105,7 +105,7 @@ func (pg *databaseService) ReadUser(name string, doNotLog bool) (defs.User, erro return user, nil } - rowSet, err := pg.userHandle.Read(pg.userHandle.Equals("name", name)) + rowSet, err := pg.userHandle.Begin().Read(pg.userHandle.Equals("name", name)) if err != nil { ui.Log(ui.ServerLogger, "Database error: %v", err) @@ -141,10 +141,10 @@ func (pg *databaseService) WriteUser(user defs.User) error { _, err := pg.ReadUser(user.Name, false) if err == nil { action = "updated in" - err = pg.userHandle.Update(user, pg.userHandle.Equals("name", user.Name)) + err = pg.userHandle.Begin().Update(user, pg.userHandle.Equals("name", user.Name)) } else { action = "added to" - err = pg.userHandle.Insert(user) + err = pg.userHandle.Begin().Insert(user) } if err != nil { @@ -165,7 +165,7 @@ func (pg *databaseService) DeleteUser(name string) error { // Make sure the item no longer exists in the short-term cache. caches.Delete(caches.AuthCache, name) - count, err := pg.userHandle.Delete(pg.userHandle.Equals("name", name)) + count, err := pg.userHandle.Begin().Delete(pg.userHandle.Equals("name", name)) if err != nil { ui.Log(ui.ServerLogger, "Database error: %v", err) diff --git a/server/dsns/dsn_sqldb.go b/server/dsns/dsn_sqldb.go index e354041e..0f29cd83 100644 --- a/server/dsns/dsn_sqldb.go +++ b/server/dsns/dsn_sqldb.go @@ -72,7 +72,7 @@ func (pg *databaseService) ListDSNS(user string) (map[string]defs.DSN, error) { r := map[string]defs.DSN{} // Specify the sort info (ordered by the DSN name) and read the data. - iArray, err := pg.dsnHandle.Sort("name").Read() + iArray, err := pg.dsnHandle.Begin().Sort("name").Read() if err != nil { return r, err } @@ -101,7 +101,7 @@ func (pg *databaseService) ReadDSN(user, name string, doNotLog bool) (defs.DSN, ) if item, found = caches.Find(caches.DSNCache, name); !found { - item, err = pg.dsnHandle.ReadOne(name) + item, err = pg.dsnHandle.Begin().ReadOne(name) if err != nil { if !doNotLog { ui.Log(ui.AuthLogger, "No dsn record for %s", name) @@ -135,7 +135,7 @@ func (pg *databaseService) WriteDSN(user string, dsname defs.DSN) error { caches.Delete(caches.DSNCache, dsname.Name) - items, err := pg.dsnHandle.Read(pg.dsnHandle.Equals("name", dsname.Name)) + items, err := pg.dsnHandle.Begin().Read(pg.dsnHandle.Equals("name", dsname.Name)) if err != nil { return err } @@ -144,9 +144,9 @@ func (pg *databaseService) WriteDSN(user string, dsname defs.DSN) error { action = "added to" dsname.ID = uuid.NewString() - err = pg.dsnHandle.Insert(dsname) + err = pg.dsnHandle.Begin().Insert(dsname) } else { - err = pg.dsnHandle.UpdateOne(dsname) + err = pg.dsnHandle.Begin().UpdateOne(dsname) } if err != nil { @@ -166,10 +166,10 @@ func (pg *databaseService) DeleteDSN(user, name string) error { caches.Delete(caches.DSNCache, name) - err = pg.dsnHandle.DeleteOne(name) + err = pg.dsnHandle.Begin().DeleteOne(name) if err == nil { // Delete any authentication objects for this DSN as well... - _, _ = pg.authHandle.Delete(pg.authHandle.Equals("dsn", name)) + _, _ = pg.authHandle.Begin().Delete(pg.authHandle.Equals("dsn", name)) ui.Log(ui.AuthLogger, "Deleted DSN %s from database", name) } @@ -192,7 +192,7 @@ func (pg *databaseService) Flush() error { func (pg *databaseService) initializeDatabase() error { err := pg.dsnHandle.CreateIf() if err == nil { - err = pg.authHandle.CreateIf() + err = pg.authHandle.Begin().CreateIf() } if err != nil { @@ -208,6 +208,8 @@ func (pg *databaseService) initializeDatabase() error { // the dsnauths table, which has a bit-mask of allowed operations. The // result is a bit-mapped AND of the requested and permitted actions. func (pg *databaseService) AuthDSN(user, name string, action DSNAction) bool { + pg.dsnHandle.Begin() + dsn, err := pg.ReadDSN(user, name, true) if err != nil { return false @@ -217,7 +219,7 @@ func (pg *databaseService) AuthDSN(user, name string, action DSNAction) bool { return true } - rows, err := pg.authHandle.Read( + rows, err := pg.authHandle.Begin().Read( pg.authHandle.Equals("user", user), pg.authHandle.Equals("dsn", name), ) @@ -245,7 +247,7 @@ func (pg *databaseService) GrantDSN(user, name string, action DSNAction, grant b } // Get the privilege info for this item. - rows, err := pg.authHandle.Read( + rows, err := pg.authHandle.Begin().Read( pg.authHandle.Equals("user", user), pg.authHandle.Equals("dsn", name), ) @@ -302,7 +304,7 @@ func (pg *databaseService) GrantDSN(user, name string, action DSNAction, grant b // action mask. If it did not exist before, insert it into the auth table. auth.Action = existingAction if exists { - err = pg.authHandle.Update(*auth, + err = pg.authHandle.Begin().Update(*auth, pg.authHandle.Equals("user", user), pg.authHandle.Equals("dsn", name)) } else { @@ -329,7 +331,7 @@ func (pg *databaseService) Permissions(user, name string) (map[string]DSNAction, return result, nil } - auths, err := pg.authHandle.Read(pg.authHandle.Equals("dsn", name)) + auths, err := pg.authHandle.Begin().Read(pg.authHandle.Equals("dsn", name)) if err != nil { return nil, err } diff --git a/tools/buildver.txt b/tools/buildver.txt index d8557317..028d1a92 100644 --- a/tools/buildver.txt +++ b/tools/buildver.txt @@ -1 +1 @@ -1.5-1160 +1.5-1161