Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(api): better user passwords handling #25

Merged
merged 1 commit into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 81 additions & 31 deletions internal/auth/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ CREATE TABLE IF NOT EXISTS users (
locked_until DATETIME, -- If set, user is locked until this time.
created_at DATETIME NOT NULL, -- User creation timestamp.
updated_at DATETIME NOT NULL, -- Timestamp for the last update.
password_changed_at DATETIME, -- Tracks last password change time.
previous_passwords TEXT -- JSON array of hashed previous passwords.
password_changed_at DATETIME -- Tracks last password change time.
);

CREATE TABLE IF NOT EXISTS tokens (
Expand Down Expand Up @@ -57,12 +56,23 @@ CREATE TABLE IF NOT EXISTS audit_logs (
created_at DATETIME NOT NULL, -- Timestamp when the action occurred.
FOREIGN KEY (user_id) REFERENCES users (id) -- Foreign key linking to users.
);

CREATE TABLE IF NOT EXISTS password_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
password_hash TEXT NOT NULL,
created_at DATETIME NOT NULL,
FOREIGN KEY (user_id) REFERENCES users (id)
);

CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);
CREATE INDEX IF NOT EXISTS idx_tokens_jti ON tokens(jti);
CREATE INDEX IF NOT EXISTS idx_tokens_expires_at ON tokens(expires_at);
CREATE INDEX IF NOT EXISTS idx_audit_logs_user_id ON audit_logs(user_id);
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);`
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_password_history_user_id ON password_history(user_id);
`

type SQLiteDB struct {
db *sql.DB
Expand Down Expand Up @@ -133,12 +143,10 @@ func (s *SQLiteDB) GetUserByUsername(username string) (*models.User, error) {
// GetUserByID retrieves a user by their ID.
func (s *SQLiteDB) GetUserByID(id int64) (*models.User, error) {
var user models.User
var previousPasswords sql.NullString

err := s.db.QueryRow(`
SELECT id, username, password, role, last_login_at, last_login_ip,
failed_attempts, locked_until, password_changed_at,
previous_passwords, created_at, updated_at
SELECT id, username, password, role, last_login_at,
last_login_ip, failed_attempts, locked_until,
password_changed_at, created_at, updated_at
FROM users
WHERE id = ?
`, id).Scan(
Expand All @@ -151,21 +159,13 @@ func (s *SQLiteDB) GetUserByID(id int64) (*models.User, error) {
&user.FailedAttempts,
&user.LockedUntil,
&user.PasswordChangedAt,
&previousPasswords,
&user.CreatedAt,
&user.UpdatedAt,
)
if err != nil {
return nil, err
}

if previousPasswords.Valid {
user.PreviousPasswords = previousPasswords.String
} else {
ea, _ := json.Marshal([]string{})
user.PreviousPasswords = string(ea)
}

return &user, nil
}

Expand All @@ -184,7 +184,6 @@ func (s *SQLiteDB) UpdateUser(user *models.User) error {
return err
}

// Token methods
// CreateToken inserts a new token into the tokens table.
func (s *SQLiteDB) CreateToken(token *models.Token) error {
_, err := s.db.Exec(`
Expand Down Expand Up @@ -248,7 +247,6 @@ func (s *SQLiteDB) CleanupExpiredTokens() error {
return err
}

// Audit methods
// CreateAuditLog inserts a new audit log into the audit_logs table.
func (s *SQLiteDB) CreateAuditLog(log *models.AuditLog) error {
_, err := s.db.Exec(`
Expand Down Expand Up @@ -304,19 +302,6 @@ func (s *SQLiteDB) GetUserSessions(userID int64) ([]models.Session, error) {
return sessions, nil
}

// UpdateUserPassword updates a user's password and previous passwords.
func (s *SQLiteDB) UpdateUserPassword(user *models.User) error {
_, err := s.db.Exec(`
UPDATE users SET
password = ?,
password_changed_at = ?,
previous_passwords = ?,
updated_at = ?
WHERE id = ?
`, user.Password, user.PasswordChangedAt, user.PreviousPasswords, user.UpdatedAt, user.ID)
return err
}

// RevokeAllUserTokens revokes all active tokens for a user.
func (s *SQLiteDB) RevokeAllUserTokens(userID int64) error {
now := time.Now()
Expand Down Expand Up @@ -392,3 +377,68 @@ func (s *SQLiteDB) GetTokenByRefreshToken(refreshToken string, userID int64) (*m
}
return &token, nil
}

// AddPasswordToHistory adds a password hash to the user's password history
func (s *SQLiteDB) AddPasswordToHistory(userID int64, passwordHash string) error {
_, err := s.db.Exec(`
INSERT INTO password_history (user_id, password_hash, created_at)
VALUES (?, ?, ?)
`, userID, passwordHash, time.Now())

return err
}

// GetPasswordHistory retrieves the password history for a user
func (s *SQLiteDB) GetPasswordHistory(userID int64, limit int) ([]string, error) {
rows, err := s.db.Query(`
SELECT password_hash
FROM password_history
WHERE user_id = ?
ORDER BY created_at DESC
LIMIT ?
`, userID, limit)
if err != nil {
return nil, err
}
defer rows.Close()

var passwords []string
for rows.Next() {
var hash string
if err := rows.Scan(&hash); err != nil {
return nil, err
}
passwords = append(passwords, hash)
}

return passwords, nil
}

// CleanupOldPasswords removes old password entries keeping only the latest n entries
func (s *SQLiteDB) CleanupOldPasswords(userID int64, keep int) error {
_, err := s.db.Exec(`
DELETE FROM password_history
WHERE user_id = ?
AND id NOT IN (
SELECT id
FROM password_history
WHERE user_id = ?
ORDER BY created_at DESC
LIMIT ?
)
`, userID, userID, keep)

return err
}

// UpdateUserPassword updates a user's password and previous passwords.
func (s *SQLiteDB) UpdateUserPassword(user *models.User) error {
_, err := s.db.Exec(`
UPDATE users SET
password = ?,
password_changed_at = ?,
updated_at = ?
WHERE id = ?
`, user.Password, user.PasswordChangedAt, user.UpdatedAt, user.ID)
return err
}
12 changes: 0 additions & 12 deletions internal/auth/models/models.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package models

import (
"encoding/json"
"time"
)

Expand All @@ -24,7 +23,6 @@ type User struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
PasswordChangedAt time.Time `json:"password_changed_at"`
PreviousPasswords string `json:"-"`
}

type Token struct {
Expand Down Expand Up @@ -60,13 +58,3 @@ type Session struct {
ClientInfo string `json:"client_info"`
Active bool `json:"active"`
}

func (u *User) GetPreviousPasswords() []string {
var passwords []string
if u.PreviousPasswords != "" {
if err := json.Unmarshal([]byte(u.PreviousPasswords), &passwords); err != nil {
return []string{}
}
}
return passwords
}
27 changes: 17 additions & 10 deletions internal/auth/service/auth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,38 +151,45 @@ func (s *AuthService) ChangePassword(userID int64, oldPassword, newPassword stri
return fmt.Errorf("invalid new password: %w", err)
}

previousPasswords := user.GetPreviousPasswords()
// Get password history
previousPasswords, err := s.db.GetPasswordHistory(userID, s.passwordHistory)
if err != nil {
return err
}

// Check password history
if err := s.ValidatePasswordHistory(newPassword, previousPasswords); err != nil {
return err
}

// Hash new password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}

previousPasswords = append(previousPasswords, user.Password)
if len(previousPasswords) > s.passwordHistory {
previousPasswords = previousPasswords[len(previousPasswords)-s.passwordHistory:]
// Add current password to history before updating
if err := s.db.AddPasswordToHistory(userID, user.Password); err != nil {
return err
}

passwordHistoryJSON, err := json.Marshal(previousPasswords)
if err != nil {
// Cleanup old passwords
if err := s.db.CleanupOldPasswords(userID, s.passwordHistory); err != nil {
return err
}

now := time.Now()
user.Password = string(hashedPassword)
user.PasswordChangedAt = now
user.UpdatedAt = now
user.PreviousPasswords = string(passwordHistoryJSON)

// Revoke all existing tokens associated with the user to enforce the password change.
if err := s.db.RevokeAllUserTokens(userID); err != nil {
// Update user record with new password
if err := s.db.UpdateUserPassword(user); err != nil {
return err
}

return s.db.UpdateUserPassword(user)
// Revoke all existing tokens associated with the user to enforce the password change.
return s.db.RevokeAllUserTokens(userID)
}

// CreateUser registers a new user with the provided username, password, and role.
Expand Down
Loading