From f2a883d572c6e17517df14a508a0d587fa772b57 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:20:59 -0600 Subject: [PATCH] implemented singleUserManager --- api/errors.go | 30 +++++++++++++++++ api/subscription.go | 4 --- database/db.go | 7 ---- database/memory.go | 52 +++++++++++++++--------------- database/sql.go | 78 ++++++++++++++++++++++++++++++++++----------- runner.go | 71 ++++++++++++++++++++++++++++++++++++++++- 6 files changed, 185 insertions(+), 57 deletions(-) create mode 100644 api/errors.go diff --git a/api/errors.go b/api/errors.go new file mode 100644 index 0000000..3b0214d --- /dev/null +++ b/api/errors.go @@ -0,0 +1,30 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "errors" +) + +var ( + ErrStopIter = errors.New("stop iteration") + ErrNotFound = errors.New("Item not found") + ErrExist = errors.New("Item is already exist") +) diff --git a/api/subscription.go b/api/subscription.go index 7b88ad9..2749c89 100644 --- a/api/subscription.go +++ b/api/subscription.go @@ -32,10 +32,6 @@ import ( "github.com/LiterMC/go-openbmclapi/utils" ) -var ( - ErrNotFound = errors.New("Item not found") -) - type SubscriptionManager interface { GetWebPushKey() string diff --git a/database/db.go b/database/db.go index 8bca346..b2fd003 100644 --- a/database/db.go +++ b/database/db.go @@ -20,7 +20,6 @@ package database import ( - "errors" "time" "github.com/google/uuid" @@ -28,12 +27,6 @@ import ( "github.com/LiterMC/go-openbmclapi/api" ) -var ( - ErrStopIter = errors.New("stop iteration") - ErrNotFound = errors.New("no record was found") - ErrExists = errors.New("record's key was already exists") -) - type DB interface { // Cleanup will release any release that the database created // No operation should be executed during or after cleanup diff --git a/database/memory.go b/database/memory.go index 8169cdc..394e9f1 100644 --- a/database/memory.go +++ b/database/memory.go @@ -73,7 +73,7 @@ func (m *MemoryDB) ValidJTI(jti string) (bool, error) { expire, ok := m.tokens[jti] if !ok { - return false, ErrNotFound + return false, api.ErrNotFound } if time.Now().After(expire) { return false, nil @@ -85,7 +85,7 @@ func (m *MemoryDB) AddJTI(jti string, expire time.Time) error { m.tokenMux.Lock() defer m.tokenMux.Unlock() if _, ok := m.tokens[jti]; ok { - return ErrExists + return api.ErrExist } m.tokens[jti] = expire return nil @@ -96,13 +96,13 @@ func (m *MemoryDB) RemoveJTI(jti string) error { _, ok := m.tokens[jti] m.tokenMux.RUnlock() if !ok { - return ErrNotFound + return api.ErrNotFound } m.tokenMux.Lock() defer m.tokenMux.Unlock() if _, ok := m.tokens[jti]; !ok { - return ErrNotFound + return api.ErrNotFound } delete(m.tokens, jti) return nil @@ -114,7 +114,7 @@ func (m *MemoryDB) GetFileRecord(path string) (*FileRecord, error) { record, ok := m.fileRecords[path] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -136,7 +136,7 @@ func (m *MemoryDB) RemoveFileRecord(path string) error { defer m.fileRecMux.Unlock() if _, ok := m.fileRecords[path]; !ok { - return ErrNotFound + return api.ErrNotFound } delete(m.fileRecords, path) return nil @@ -148,7 +148,7 @@ func (m *MemoryDB) ForEachFileRecord(cb func(*FileRecord) error) error { for _, v := range m.fileRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -163,7 +163,7 @@ func (m *MemoryDB) GetSubscribe(user string, client string) (*api.SubscribeRecor record, ok := m.subscribeRecords[[2]string{user, client}] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -176,7 +176,7 @@ func (m *MemoryDB) SetSubscribe(record api.SubscribeRecord) error { if record.EndPoint == "" { old, ok := m.subscribeRecords[key] if !ok { - return ErrNotFound + return api.ErrNotFound } record.EndPoint = old.EndPoint } @@ -191,7 +191,7 @@ func (m *MemoryDB) RemoveSubscribe(user string, client string) error { key := [2]string{user, client} _, ok := m.subscribeRecords[key] if !ok { - return ErrNotFound + return api.ErrNotFound } delete(m.subscribeRecords, key) return nil @@ -203,7 +203,7 @@ func (m *MemoryDB) ForEachSubscribe(cb func(*api.SubscribeRecord) error) error { for _, v := range m.subscribeRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -218,7 +218,7 @@ func (m *MemoryDB) GetEmailSubscription(user string, addr string) (*api.EmailSub record, ok := m.emailSubRecords[[2]string{user, addr}] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -229,7 +229,7 @@ func (m *MemoryDB) AddEmailSubscription(record api.EmailSubscriptionRecord) erro key := [2]string{record.User, record.Addr} if _, ok := m.emailSubRecords[key]; ok { - return ErrExists + return api.ErrExist } m.emailSubRecords[key] = &record return nil @@ -242,7 +242,7 @@ func (m *MemoryDB) UpdateEmailSubscription(record api.EmailSubscriptionRecord) e key := [2]string{record.User, record.Addr} old, ok := m.emailSubRecords[key] if ok { - return ErrNotFound + return api.ErrNotFound } _ = old m.emailSubRecords[key] = &record @@ -255,7 +255,7 @@ func (m *MemoryDB) RemoveEmailSubscription(user string, addr string) error { key := [2]string{user, addr} if _, ok := m.emailSubRecords[key]; ok { - return ErrNotFound + return api.ErrNotFound } delete(m.emailSubRecords, key) return nil @@ -267,7 +267,7 @@ func (m *MemoryDB) ForEachEmailSubscription(cb func(*api.EmailSubscriptionRecord for _, v := range m.emailSubRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -285,7 +285,7 @@ func (m *MemoryDB) ForEachUsersEmailSubscription(user string, cb func(*api.Email continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -303,7 +303,7 @@ func (m *MemoryDB) ForEachEnabledEmailSubscription(cb func(*api.EmailSubscriptio continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -318,7 +318,7 @@ func (m *MemoryDB) GetWebhook(user string, id uuid.UUID) (*api.WebhookRecord, er record, ok := m.webhookRecords[webhookMemKey{user, id}] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -338,7 +338,7 @@ func (m *MemoryDB) AddWebhook(record api.WebhookRecord) (err error) { key := webhookMemKey{record.User, record.Id} if _, ok := m.webhookRecords[key]; ok { - return ErrExists + return api.ErrExist } if record.Auth == nil { record.Auth = emptyStrPtr @@ -357,7 +357,7 @@ func (m *MemoryDB) UpdateWebhook(record api.WebhookRecord) error { key := webhookMemKey{record.User, record.Id} old, ok := m.webhookRecords[key] if ok { - return ErrNotFound + return api.ErrNotFound } if record.Auth == nil { record.Auth = old.Auth @@ -376,7 +376,7 @@ func (m *MemoryDB) UpdateEnableWebhook(user string, id uuid.UUID, enabled bool) key := webhookMemKey{user, id} old, ok := m.webhookRecords[key] if ok { - return ErrNotFound + return api.ErrNotFound } record := *old record.Enabled = enabled @@ -390,7 +390,7 @@ func (m *MemoryDB) RemoveWebhook(user string, id uuid.UUID) error { key := webhookMemKey{user, id} if _, ok := m.webhookRecords[key]; ok { - return ErrNotFound + return api.ErrNotFound } delete(m.webhookRecords, key) return nil @@ -402,7 +402,7 @@ func (m *MemoryDB) ForEachWebhook(cb func(*api.WebhookRecord) error) error { for _, v := range m.webhookRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -420,7 +420,7 @@ func (m *MemoryDB) ForEachUsersWebhook(user string, cb func(*api.WebhookRecord) continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -438,7 +438,7 @@ func (m *MemoryDB) ForEachEnabledWebhook(cb func(*api.WebhookRecord) error) erro continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err diff --git a/database/sql.go b/database/sql.go index aa8ca57..a7df6d5 100644 --- a/database/sql.go +++ b/database/sql.go @@ -283,7 +283,7 @@ func (db *SqlDB) RemoveJTI(jti string) (err error) { if _, err = db.jtiStmts.remove.ExecContext(ctx, jti); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -411,7 +411,7 @@ func (db *SqlDB) GetFileRecord(path string) (rec *FileRecord, err error) { rec.Path = path if err = db.fileRecordStmts.get.QueryRowContext(ctx, &rec.Path).Scan(&rec.Hash, &rec.Size); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -436,7 +436,7 @@ func (db *SqlDB) RemoveFileRecord(path string) (err error) { if _, err = db.fileRecordStmts.remove.ExecContext(ctx, path); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -457,7 +457,12 @@ func (db *SqlDB) ForEachFileRecord(cb func(*FileRecord) error) (err error) { if err = rows.Scan(&rec.Path, &rec.Hash, &rec.Size); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -623,7 +628,7 @@ func (db *SqlDB) GetSubscribe(user string, client string) (rec *api.SubscribeRec rec.Client = client if err = db.subscribeStmts.get.QueryRowContext(ctx, user, client).Scan(&rec.EndPoint, &rec.Keys, &rec.Scopes, &rec.ReportAt); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -661,14 +666,14 @@ func (db *SqlDB) SetSubscribe(rec api.SubscribeRecord) (err error) { } else if rec.LastReport.Valid { if _, err = tx.Stmt(db.subscribeStmts.setUpdateLastReportOnly).Exec(rec.LastReport, rec.User, rec.Client); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } } else { if _, err = tx.Stmt(db.subscribeStmts.setUpdateScopesOnly).Exec(rec.Scopes, rec.ReportAt, rec.User, rec.Client); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -685,7 +690,7 @@ func (db *SqlDB) RemoveSubscribe(user string, client string) (err error) { if _, err = db.subscribeStmts.remove.ExecContext(ctx, user, client); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -706,7 +711,12 @@ func (db *SqlDB) ForEachSubscribe(cb func(*api.SubscribeRecord) error) (err erro if err = rows.Scan(&rec.User, &rec.Client, &rec.EndPoint, &rec.Keys, &rec.Scopes, &rec.ReportAt, &rec.LastReport); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -866,7 +876,7 @@ func (db *SqlDB) GetEmailSubscription(user string, addr string) (rec *api.EmailS rec.Addr = addr if err = db.emailSubscriptionStmts.get.QueryRowContext(ctx, user, addr).Scan(&rec.Scopes, &rec.Enabled); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -909,7 +919,7 @@ func (db *SqlDB) RemoveEmailSubscription(user string, addr string) (err error) { if _, err = db.emailSubscriptionStmts.remove.ExecContext(ctx, user, addr); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -930,7 +940,12 @@ func (db *SqlDB) ForEachEmailSubscription(cb func(*api.EmailSubscriptionRecord) if err = rows.Scan(&rec.User, &rec.Addr, &rec.Scopes, &rec.Enabled); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -953,7 +968,12 @@ func (db *SqlDB) ForEachUsersEmailSubscription(user string, cb func(*api.EmailSu if err = rows.Scan(&rec.Addr, &rec.Scopes, &rec.Enabled); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -975,7 +995,12 @@ func (db *SqlDB) ForEachEnabledEmailSubscription(cb func(*api.EmailSubscriptionR if err = rows.Scan(&rec.User, &rec.Addr, &rec.Scopes); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -1153,7 +1178,7 @@ func (db *SqlDB) GetWebhook(user string, id uuid.UUID) (rec *api.WebhookRecord, rec.Id = id if err = db.webhookStmts.get.QueryRowContext(ctx, user, hex.EncodeToString(id[:])).Scan(&rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -1205,7 +1230,7 @@ func (db *SqlDB) RemoveWebhook(user string, id uuid.UUID) (err error) { if _, err = db.webhookStmts.remove.ExecContext(ctx, user, hex.EncodeToString(id[:])); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -1226,7 +1251,12 @@ func (db *SqlDB) ForEachWebhook(cb func(*api.WebhookRecord) error) (err error) { if err = rows.Scan(&rec.User, &rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -1249,7 +1279,12 @@ func (db *SqlDB) ForEachUsersWebhook(user string, cb func(*api.WebhookRecord) er if err = rows.Scan(&rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled, &rec.User); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -1271,7 +1306,12 @@ func (db *SqlDB) ForEachEnabledWebhook(cb func(*api.WebhookRecord) error) (err e if err = rows.Scan(&rec.User, &rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return diff --git a/runner.go b/runner.go index 41251c2..9d60979 100644 --- a/runner.go +++ b/runner.go @@ -111,7 +111,13 @@ func NewRunner() *Runner { } } - // r.userManager = + r.userManager = &singleUserManager{ + user: &api.User{ + Username: r.Config.Dashboard.Username, + Password: r.Config.Dashboard.Password, + Permissions: api.RootPerm, + }, + } if apiHMACKey, err := utils.LoadOrCreateHmacKey(dataDir, "server"); err != nil { log.Errorf("Cannot load HMAC key: %v", err) os.Exit(1) @@ -718,3 +724,66 @@ type subscriptionManager struct { func (s *subscriptionManager) GetWebPushKey() string { return base64.RawURLEncoding.EncodeToString(s.webpushPlg.GetPublicKey()) } + +type singleUserManager struct { + user *api.User +} + +func (m *singleUserManager) GetUsers() []*api.User { + return []*api.User{m.user} +} + +func (m *singleUserManager) GetUser(id string) *api.User { + if id == m.user.Username { + return m.user + } + return nil +} + +func (m *singleUserManager) AddUser(user *api.User) error { + if user.Username == m.user.Username { + return api.ErrExist + } + return errors.New("Not implemented") +} + +func (m *singleUserManager) RemoveUser(id string) error { + if id != m.user.Username { + return api.ErrNotFound + } + return errors.New("Not implemented") +} + +func (m *singleUserManager) ForEachUser(cb func(*api.User) error) error { + err := cb(m.user) + if err == api.ErrStopIter { + return nil + } + return err +} + +func (m *singleUserManager) UpdateUserPassword(username string, password string) error { + if username != m.user.Username { + return api.ErrNotFound + } + m.user.Password = password + return nil +} + +func (m *singleUserManager) UpdateUserPermissions(username string, permissions api.PermissionFlag) error { + if username != m.user.Username { + return api.ErrNotFound + } + m.user.Permissions = permissions + return nil +} + +func (m *singleUserManager) VerifyUserPassword(userId string, comparator func(password string) bool) error { + if userId != m.user.Username { + return errors.New("Username or password is incorrect") + } + if !comparator(m.user.Password) { + return errors.New("Username or password is incorrect") + } + return nil +}