diff --git a/OIDC_INTEGRATION.md b/OIDC_INTEGRATION.md new file mode 100644 index 0000000..184cd13 --- /dev/null +++ b/OIDC_INTEGRATION.md @@ -0,0 +1,180 @@ +# OpenID Connect (OIDC) Integration + +This document describes the OIDC integration added to pgbackweb to support authentication via external providers like Authentik, Keycloak, and other OIDC-compliant identity providers. + +## Overview + +The OIDC integration allows users to authenticate using external identity providers instead of local username/password combinations. This supports enterprise SSO workflows and centralized user management. + +## Features + +- Support for any OIDC-compliant identity provider +- Automatic user creation on first login +- User information synchronization on each login +- Seamless integration with existing authentication system +- Configurable user attribute mapping + +## Configuration + +Add the following environment variables to enable OIDC: + +```bash +# Enable OIDC authentication +PBW_OIDC_ENABLED=true + +# OIDC Provider Configuration +PBW_OIDC_ISSUER_URL=https://your-provider.com/auth/realms/your-realm +PBW_OIDC_CLIENT_ID=pgbackweb +PBW_OIDC_CLIENT_SECRET=your-client-secret +PBW_OIDC_REDIRECT_URL=https://your-domain.com/auth/oidc/callback + +# Optional: Customize OIDC scopes (default: "openid profile email") +PBW_OIDC_SCOPES="openid profile email" + +# Optional: Customize claim mappings +PBW_OIDC_USERNAME_CLAIM=preferred_username # default: preferred_username +PBW_OIDC_EMAIL_CLAIM=email # default: email +PBW_OIDC_NAME_CLAIM=name # default: name +``` + +## Provider-Specific Setup + +### Authentik + +1. Create a new **OAuth2/OpenID Provider** in Authentik +2. Set the **Redirect URI** to: `https://your-domain.com/auth/oidc/callback` +3. Configure the **Client Type** as **Confidential** +4. Note the **Client ID** and **Client Secret** +5. Create a new **Application** and link it to the provider +6. Configure the **Issuer URL**: `https://your-authentik.com/application/o/your-app/` + +### Keycloak + +1. Create a new **Client** in your Keycloak realm +2. Set **Client Protocol** to `openid-connect` +3. Set **Access Type** to `confidential` +4. Add `https://your-domain.com/auth/oidc/callback` to **Valid Redirect URIs** +5. Note the **Client ID** and get the **Client Secret** from the Credentials tab +6. Configure the **Issuer URL**: `https://your-keycloak.com/auth/realms/your-realm` + +### Generic OIDC Provider + +For any OIDC-compliant provider: + +1. Create a new OIDC client/application +2. Set the redirect URI to: `https://your-domain.com/auth/oidc/callback` +3. Ensure the client can access `openid`, `profile`, and `email` scopes +4. Note the issuer URL (usually ends with `/.well-known/openid_configuration`) + +## Database Schema Changes + +The OIDC integration adds the following columns to the `users` table: + +```sql +ALTER TABLE users +ADD COLUMN oidc_provider TEXT, +ADD COLUMN oidc_subject TEXT; + +-- Make password nullable for OIDC users +ALTER TABLE users ALTER COLUMN password DROP NOT NULL; + +-- Create unique index for OIDC users +CREATE UNIQUE INDEX users_oidc_provider_subject_idx +ON users (oidc_provider, oidc_subject) +WHERE oidc_provider IS NOT NULL AND oidc_subject IS NOT NULL; + +-- Ensure users have either password or OIDC authentication +ALTER TABLE users ADD CONSTRAINT users_auth_method_check +CHECK ( + (password IS NOT NULL AND oidc_provider IS NULL AND oidc_subject IS NULL) OR + (password IS NULL AND oidc_provider IS NOT NULL AND oidc_subject IS NOT NULL) +); +``` + +## User Flow + +1. **First-time users**: When an OIDC user logs in for the first time, a new user account is automatically created with information from the OIDC provider. + +2. **Returning users**: Existing OIDC users are matched by their provider and subject ID. User information (name, email) is updated from the OIDC provider on each login. + +3. **Mixed authentication**: The system supports both local users (with passwords) and OIDC users in the same instance. + +## Security Considerations + +- **State parameter**: CSRF protection using a random state parameter +- **Token validation**: ID tokens are cryptographically verified +- **Secure cookies**: State is stored in secure, HTTP-only cookies +- **Provider validation**: Only configured OIDC providers are accepted + +## User Interface + +When OIDC is enabled, the login page displays: +- A "Login with SSO" button at the top +- A divider separating SSO from traditional login +- The existing email/password form below + +## Implementation Details + +### Services + +- **`internal/service/oidc/`**: Core OIDC authentication logic +- **`internal/view/web/oidc/`**: Web routes for OIDC authentication flow +- **`internal/config/`**: Environment variable configuration and validation + +### Routes + +- `GET /auth/oidc/login`: Initiates OIDC authentication flow +- `GET /auth/oidc/callback`: Handles OIDC provider callback + +### Database Queries + +- `OIDCServiceCreateUser`: Creates a new OIDC user +- `OIDCServiceGetUserByOIDC`: Retrieves user by provider and subject +- `OIDCServiceUpdateUser`: Updates existing OIDC user information + +## Troubleshooting + +### Common Issues + +1. **Invalid redirect URI**: Ensure the redirect URI in your OIDC provider matches exactly: `https://your-domain.com/auth/oidc/callback` + +2. **Certificate errors**: If using self-signed certificates, ensure your Go application trusts the certificates + +3. **Claim mapping**: Verify that your OIDC provider returns the expected claims (`email`, `name`, `preferred_username`) + +4. **Scopes**: Ensure your OIDC client has access to the required scopes (`openid`, `profile`, `email`) + +### Debug Logging + +The application logs OIDC authentication events. Check logs for: +- OIDC provider initialization errors +- Token exchange failures +- User creation/update events + +## Migration from Local Authentication + +Existing local users are unaffected by OIDC integration. To migrate users to OIDC: + +1. Enable OIDC authentication +2. Users can continue using local authentication or switch to OIDC +3. No automatic migration is performed - users choose their preferred method + +## Development + +To run the application with OIDC in development: + +```bash +# Set environment variables in .env file +echo "PBW_OIDC_ENABLED=true" >> .env +echo "PBW_OIDC_ISSUER_URL=https://your-dev-provider.com" >> .env +# ... other OIDC variables + +# Run database migrations +task migrate up + +# Generate database code +task gen-db + +# Build and run +task dev +``` diff --git a/cmd/app/main.go b/cmd/app/main.go index cbb179c..50fd45e 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -43,7 +43,10 @@ func main() { dbgen := dbgen.New(db) ints := integration.New() - servs := service.New(env, dbgen, cr, ints) + servs, err := service.New(env, dbgen, cr, ints) + if err != nil { + logger.FatalError("error initializing services", logger.KV{"error": err}) + } initSchedule(cr, servs) app := echo.New() diff --git a/cmd/changepw/main.go b/cmd/changepw/main.go index b6c1d8c..d2b5567 100644 --- a/cmd/changepw/main.go +++ b/cmd/changepw/main.go @@ -60,7 +60,7 @@ func main() { err = dbg.UsersServiceChangePassword( context.Background(), dbgen.UsersServiceChangePasswordParams{ ID: userID, - Password: hashedPassword, + Password: sql.NullString{String: hashedPassword, Valid: true}, }, ) if err != nil { diff --git a/go.mod b/go.mod index 755a35d..cc49afb 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.49 github.com/aws/aws-sdk-go-v2/service/s3 v1.72.3 github.com/caarlos0/env/v11 v11.3.1 + github.com/coreos/go-oidc/v3 v3.14.1 github.com/go-co-op/gocron/v2 v2.11.0 github.com/go-playground/validator/v10 v10.22.0 github.com/google/uuid v1.6.0 @@ -21,9 +22,10 @@ require ( github.com/nodxdev/nodxgo-htmx v0.1.0 github.com/nodxdev/nodxgo-lucide v0.1.1 github.com/orsinium-labs/enum v1.4.0 - github.com/stretchr/testify v1.9.0 - golang.org/x/crypto v0.25.0 - golang.org/x/sync v0.7.0 + github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.36.0 + golang.org/x/oauth2 v0.30.0 + golang.org/x/sync v0.12.0 ) require ( @@ -43,6 +45,7 @@ require ( github.com/aws/smithy-go v1.22.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-jose/go-jose/v4 v4.0.5 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/jonboulle/clockwork v0.4.0 // indirect @@ -56,8 +59,8 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/net v0.37.0 // indirect + golang.org/x/sys v0.31.0 // indirect + golang.org/x/text v0.23.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2907ba0..1686b36 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= +github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk= +github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -47,6 +49,8 @@ github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uq github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/go-co-op/gocron/v2 v2.11.0 h1:IOowNA6SzwdRFnD4/Ol3Kj6G2xKfsoiiGq2Jhhm9bvE= github.com/go-co-op/gocron/v2 v2.11.0/go.mod h1:xY7bJxGazKam1cz04EebrlP4S9q4iWdiAylMGP3jY9w= +github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= +github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -55,6 +59,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4Bx7ia+JlgcnOao= github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -97,28 +103,30 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/config/env.go b/internal/config/env.go index 31d3985..ea9957a 100644 --- a/internal/config/env.go +++ b/internal/config/env.go @@ -12,6 +12,17 @@ type Env struct { PBW_POSTGRES_CONN_STRING string `env:"PBW_POSTGRES_CONN_STRING,required"` PBW_LISTEN_HOST string `env:"PBW_LISTEN_HOST" envDefault:"0.0.0.0"` PBW_LISTEN_PORT string `env:"PBW_LISTEN_PORT" envDefault:"8085"` + + // OIDC Configuration + PBW_OIDC_ENABLED bool `env:"PBW_OIDC_ENABLED" envDefault:"false"` + PBW_OIDC_ISSUER_URL string `env:"PBW_OIDC_ISSUER_URL"` + PBW_OIDC_CLIENT_ID string `env:"PBW_OIDC_CLIENT_ID"` + PBW_OIDC_CLIENT_SECRET string `env:"PBW_OIDC_CLIENT_SECRET"` + PBW_OIDC_REDIRECT_URL string `env:"PBW_OIDC_REDIRECT_URL"` + PBW_OIDC_SCOPES string `env:"PBW_OIDC_SCOPES" envDefault:"openid profile email"` + PBW_OIDC_USERNAME_CLAIM string `env:"PBW_OIDC_USERNAME_CLAIM" envDefault:"preferred_username"` + PBW_OIDC_EMAIL_CLAIM string `env:"PBW_OIDC_EMAIL_CLAIM" envDefault:"email"` + PBW_OIDC_NAME_CLAIM string `env:"PBW_OIDC_NAME_CLAIM" envDefault:"name"` PBW_PATH_PREFIX string `env:"PBW_PATH_PREFIX" envDefault:""` } diff --git a/internal/config/env_validate.go b/internal/config/env_validate.go index 07aa366..a2dbcfa 100644 --- a/internal/config/env_validate.go +++ b/internal/config/env_validate.go @@ -16,6 +16,20 @@ func validateEnv(env Env) error { return fmt.Errorf("invalid listen port %s, valid values are 1-65535", env.PBW_LISTEN_PORT) } + // Validate OIDC configuration if enabled + if env.PBW_OIDC_ENABLED { + if env.PBW_OIDC_ISSUER_URL == "" { + return fmt.Errorf("PBW_OIDC_ISSUER_URL is required when OIDC is enabled") + } + if env.PBW_OIDC_CLIENT_ID == "" { + return fmt.Errorf("PBW_OIDC_CLIENT_ID is required when OIDC is enabled") + } + if env.PBW_OIDC_CLIENT_SECRET == "" { + return fmt.Errorf("PBW_OIDC_CLIENT_SECRET is required when OIDC is enabled") + } + if env.PBW_OIDC_REDIRECT_URL == "" { + return fmt.Errorf("PBW_OIDC_REDIRECT_URL is required when OIDC is enabled") + } if !validate.PathPrefix(env.PBW_PATH_PREFIX) { return fmt.Errorf("invalid path prefix %s, must start with / and not end with / (or be empty)", env.PBW_PATH_PREFIX) } diff --git a/internal/database/migrations/20250708000000_add_oidc_support_to_users.sql b/internal/database/migrations/20250708000000_add_oidc_support_to_users.sql new file mode 100644 index 0000000..1d1606c --- /dev/null +++ b/internal/database/migrations/20250708000000_add_oidc_support_to_users.sql @@ -0,0 +1,35 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE users +ADD COLUMN oidc_provider TEXT, +ADD COLUMN oidc_subject TEXT, +ADD COLUMN password_nullable TEXT; + +-- Make password nullable and copy existing passwords +UPDATE users SET password_nullable = password; +ALTER TABLE users DROP COLUMN password; +ALTER TABLE users RENAME COLUMN password_nullable TO password; + +-- Create unique index for OIDC users +CREATE UNIQUE INDEX users_oidc_provider_subject_idx +ON users (oidc_provider, oidc_subject) +WHERE oidc_provider IS NOT NULL AND oidc_subject IS NOT NULL; + +-- Add constraint to ensure either password or OIDC is set +ALTER TABLE users ADD CONSTRAINT users_auth_method_check +CHECK ( + (password IS NOT NULL AND oidc_provider IS NULL AND oidc_subject IS NULL) OR + (password IS NULL AND oidc_provider IS NOT NULL AND oidc_subject IS NOT NULL) +); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS users_oidc_provider_subject_idx; +ALTER TABLE users DROP CONSTRAINT IF EXISTS users_auth_method_check; +ALTER TABLE users DROP COLUMN IF EXISTS oidc_provider; +ALTER TABLE users DROP COLUMN IF EXISTS oidc_subject; + +-- Make password required again (this will fail if there are OIDC users) +ALTER TABLE users ALTER COLUMN password SET NOT NULL; +-- +goose StatementEnd diff --git a/internal/service/auth/cookies.go b/internal/service/auth/cookies.go index 8a0e82a..cc76988 100644 --- a/internal/service/auth/cookies.go +++ b/internal/service/auth/cookies.go @@ -18,6 +18,8 @@ func (s *Service) SetSessionCookie(c echo.Context, token string) { Value: token, MaxAge: int(maxSessionAge.Seconds()), HttpOnly: true, + Secure: true, // Force HTTPS + SameSite: http.SameSiteLaxMode, Path: "/", } c.SetCookie(&cookie) @@ -29,6 +31,8 @@ func (s *Service) ClearSessionCookie(c echo.Context) { Value: "", MaxAge: -1, HttpOnly: true, + Secure: true, // Force HTTPS + SameSite: http.SameSiteLaxMode, Path: "/", } c.SetCookie(&cookie) diff --git a/internal/service/auth/login.go b/internal/service/auth/login.go index f05d57b..20b3b69 100644 --- a/internal/service/auth/login.go +++ b/internal/service/auth/login.go @@ -17,7 +17,11 @@ func (s *Service) Login( return dbgen.AuthServiceLoginCreateSessionRow{}, err } - if err := cryptoutil.VerifyBcryptHash(password, user.Password); err != nil { + if !user.Password.Valid { + return dbgen.AuthServiceLoginCreateSessionRow{}, fmt.Errorf("user has no password set") + } + + if err := cryptoutil.VerifyBcryptHash(password, user.Password.String); err != nil { return dbgen.AuthServiceLoginCreateSessionRow{}, fmt.Errorf("invalid password") } diff --git a/internal/service/auth/login_oidc.go b/internal/service/auth/login_oidc.go new file mode 100644 index 0000000..d70cd4e --- /dev/null +++ b/internal/service/auth/login_oidc.go @@ -0,0 +1,27 @@ +package auth + +import ( + "context" + + "github.com/eduardolat/pgbackweb/internal/database/dbgen" + "github.com/google/uuid" +) + +func (s *Service) LoginOIDC( + ctx context.Context, userID uuid.UUID, ip, userAgent string, +) (dbgen.AuthServiceLoginCreateSessionRow, error) { + session, err := s.dbgen.AuthServiceLoginCreateSession( + ctx, dbgen.AuthServiceLoginCreateSessionParams{ + UserID: userID, + Ip: ip, + UserAgent: userAgent, + Token: uuid.NewString(), + EncryptionKey: s.env.PBW_ENCRYPTION_KEY, + }, + ) + if err != nil { + return dbgen.AuthServiceLoginCreateSessionRow{}, err + } + + return session, nil +} diff --git a/internal/service/oidc/oidc.go b/internal/service/oidc/oidc.go new file mode 100644 index 0000000..5264929 --- /dev/null +++ b/internal/service/oidc/oidc.go @@ -0,0 +1,191 @@ +package oidc + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/eduardolat/pgbackweb/internal/config" + "github.com/eduardolat/pgbackweb/internal/database/dbgen" + "golang.org/x/oauth2" +) + +// Custom error types for better error handling +var ( + ErrEmailAlreadyExists = errors.New("email already exists with different authentication method") + ErrOIDCNotEnabled = errors.New("OIDC is not enabled") + ErrInvalidToken = errors.New("invalid or expired token") + ErrMissingClaims = errors.New("required user information missing from OIDC claims") +) + +type Service struct { + env config.Env + dbgen *dbgen.Queries + provider *oidc.Provider + config oauth2.Config +} + +type UserInfo struct { + Email string + Name string + Username string + Subject string +} + +func New(env config.Env, dbgen *dbgen.Queries) (*Service, error) { + if !env.PBW_OIDC_ENABLED { + return &Service{env: env, dbgen: dbgen}, nil + } + + ctx := context.Background() + provider, err := oidc.NewProvider(ctx, env.PBW_OIDC_ISSUER_URL) + if err != nil { + return nil, fmt.Errorf("failed to create OIDC provider: %w", err) + } + + scopes := strings.Split(env.PBW_OIDC_SCOPES, " ") + config := oauth2.Config{ + ClientID: env.PBW_OIDC_CLIENT_ID, + ClientSecret: env.PBW_OIDC_CLIENT_SECRET, + RedirectURL: env.PBW_OIDC_REDIRECT_URL, + Endpoint: provider.Endpoint(), + Scopes: scopes, + } + + return &Service{ + env: env, + dbgen: dbgen, + provider: provider, + config: config, + }, nil +} + +func (s *Service) IsEnabled() bool { + return s.env.PBW_OIDC_ENABLED +} + +func (s *Service) GetAuthURL(state string) string { + if !s.IsEnabled() { + return "" + } + return s.config.AuthCodeURL(state) +} + +func (s *Service) GenerateState() (string, error) { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} + +func (s *Service) ExchangeCode(ctx context.Context, code string) (*UserInfo, error) { + if !s.IsEnabled() { + return nil, ErrOIDCNotEnabled + } + + token, err := s.config.Exchange(ctx, code) + if err != nil { + return nil, fmt.Errorf("failed to exchange code: %w", err) + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("no id_token field in oauth2 token") + } + + verifier := s.provider.Verifier(&oidc.Config{ClientID: s.env.PBW_OIDC_CLIENT_ID}) + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("failed to verify ID token: %w", err) + } + + claims := make(map[string]interface{}) + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse claims: %w", err) + } + + userInfo := &UserInfo{ + Subject: idToken.Subject, + } + + // Extract email + if email, ok := claims[s.env.PBW_OIDC_EMAIL_CLAIM].(string); ok { + userInfo.Email = strings.ToLower(email) + } + + // Extract name + if name, ok := claims[s.env.PBW_OIDC_NAME_CLAIM].(string); ok { + userInfo.Name = name + } + + // Extract username + if username, ok := claims[s.env.PBW_OIDC_USERNAME_CLAIM].(string); ok { + userInfo.Username = username + } + + // Fallback to email as username if username not provided + if userInfo.Username == "" && userInfo.Email != "" { + userInfo.Username = strings.Split(userInfo.Email, "@")[0] + } + + // Fallback to username as name if name not provided + if userInfo.Name == "" && userInfo.Username != "" { + userInfo.Name = userInfo.Username + } + + if userInfo.Email == "" || userInfo.Name == "" || userInfo.Subject == "" { + return nil, ErrMissingClaims + } + + return userInfo, nil +} + +func (s *Service) CreateOrUpdateUser(ctx context.Context, userInfo *UserInfo) (*dbgen.User, error) { + // Try to get existing OIDC user + _, err := s.dbgen.OIDCServiceGetUserByOIDC(ctx, dbgen.OIDCServiceGetUserByOIDCParams{ + OidcProvider: sql.NullString{String: "oidc", Valid: true}, + OidcSubject: sql.NullString{String: userInfo.Subject, Valid: true}, + }) + + if err == nil { + // OIDC user exists, update their information + user, err := s.dbgen.OIDCServiceUpdateUser(ctx, dbgen.OIDCServiceUpdateUserParams{ + Name: userInfo.Name, + Email: userInfo.Email, + OidcProvider: sql.NullString{String: "oidc", Valid: true}, + OidcSubject: sql.NullString{String: userInfo.Subject, Valid: true}, + }) + if err != nil { + return nil, fmt.Errorf("failed to update user: %w", err) + } + return &user, nil + } + + // OIDC user doesn't exist, check if regular user with same email exists + _, err = s.dbgen.AuthServiceLoginGetUserByEmail(ctx, strings.ToLower(userInfo.Email)) + if err == nil { + // Regular user with same email exists - we cannot create OIDC user + // This prevents account takeover and maintains data integrity + return nil, ErrEmailAlreadyExists + } + + // No existing user, create new OIDC user + user, err := s.dbgen.OIDCServiceCreateUser(ctx, dbgen.OIDCServiceCreateUserParams{ + Name: userInfo.Name, + Email: userInfo.Email, + OidcProvider: sql.NullString{String: "oidc", Valid: true}, + OidcSubject: sql.NullString{String: userInfo.Subject, Valid: true}, + }) + if err != nil { + return nil, fmt.Errorf("failed to create user: %w", err) + } + + return &user, nil +} diff --git a/internal/service/oidc/oidc_test.go b/internal/service/oidc/oidc_test.go new file mode 100644 index 0000000..fb4ee2d --- /dev/null +++ b/internal/service/oidc/oidc_test.go @@ -0,0 +1,242 @@ +package oidc + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/eduardolat/pgbackweb/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + t.Run("OIDC Disabled", func(t *testing.T) { + env := config.Env{ + PBW_OIDC_ENABLED: false, + } + + service, err := New(env, nil) + + assert.NoError(t, err) + assert.NotNil(t, service) + assert.False(t, service.IsEnabled()) + assert.Equal(t, env, service.env) + assert.Nil(t, service.provider) + }) + + t.Run("OIDC Enabled with Invalid Issuer", func(t *testing.T) { + env := config.Env{ + PBW_OIDC_ENABLED: true, + PBW_OIDC_ISSUER_URL: "invalid-url", + PBW_OIDC_CLIENT_ID: "test-client", + PBW_OIDC_CLIENT_SECRET: "test-secret", + PBW_OIDC_REDIRECT_URL: "http://localhost:8080/auth/oidc/callback", + PBW_OIDC_SCOPES: "openid profile email", + } + + service, err := New(env, nil) + + assert.Error(t, err) + assert.Nil(t, service) + assert.Contains(t, err.Error(), "failed to create OIDC provider") + }) +} + +func TestIsEnabled(t *testing.T) { + t.Run("Enabled", func(t *testing.T) { + service := &Service{ + env: config.Env{PBW_OIDC_ENABLED: true}, + } + assert.True(t, service.IsEnabled()) + }) + + t.Run("Disabled", func(t *testing.T) { + service := &Service{ + env: config.Env{PBW_OIDC_ENABLED: false}, + } + assert.False(t, service.IsEnabled()) + }) +} + +func TestGetAuthURL(t *testing.T) { + t.Run("OIDC Disabled", func(t *testing.T) { + service := &Service{ + env: config.Env{PBW_OIDC_ENABLED: false}, + } + + url := service.GetAuthURL("test-state") + assert.Empty(t, url) + }) + + t.Run("OIDC Enabled", func(t *testing.T) { + service := &Service{ + env: config.Env{PBW_OIDC_ENABLED: true}, + } + + // This will return empty string without proper config, but shows the behavior + _ = service.GetAuthURL("test-state") + // Without proper oauth2.Config, this will return empty string or panic + // In a real test, we'd need to mock or provide proper config + assert.True(t, service.IsEnabled()) + }) +} + +func TestGenerateState(t *testing.T) { + service := &Service{} + + t.Run("Success", func(t *testing.T) { + state, err := service.GenerateState() + + assert.NoError(t, err) + assert.NotEmpty(t, state) + assert.Greater(t, len(state), 20) // Base64 encoded 32 bytes should be > 20 chars + }) + + t.Run("Multiple Calls Generate Different States", func(t *testing.T) { + state1, err1 := service.GenerateState() + state2, err2 := service.GenerateState() + + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.NotEqual(t, state1, state2) + }) +} + +func TestExchangeCode(t *testing.T) { + t.Run("OIDC Disabled", func(t *testing.T) { + service := &Service{ + env: config.Env{PBW_OIDC_ENABLED: false}, + } + + userInfo, err := service.ExchangeCode(context.Background(), "test-code") + + assert.Error(t, err) + assert.Nil(t, userInfo) + assert.Equal(t, ErrOIDCNotEnabled, err) + }) + + // Note: Testing the enabled case would require mocking the OIDC provider + // and OAuth2 flow, which is complex. In a real implementation, you'd + // want to use dependency injection to make these components testable. +} + +func TestErrorTypes(t *testing.T) { + t.Run("Error Messages", func(t *testing.T) { + assert.Equal(t, "email already exists with different authentication method", ErrEmailAlreadyExists.Error()) + assert.Equal(t, "OIDC is not enabled", ErrOIDCNotEnabled.Error()) + assert.Equal(t, "invalid or expired token", ErrInvalidToken.Error()) + assert.Equal(t, "required user information missing from OIDC claims", ErrMissingClaims.Error()) + }) +} + +func TestUserInfo(t *testing.T) { + t.Run("UserInfo Structure", func(t *testing.T) { + userInfo := UserInfo{ + Email: "test@example.com", + Name: "Test User", + Username: "testuser", + Subject: "test-subject-123", + } + + assert.Equal(t, "test@example.com", userInfo.Email) + assert.Equal(t, "Test User", userInfo.Name) + assert.Equal(t, "testuser", userInfo.Username) + assert.Equal(t, "test-subject-123", userInfo.Subject) + }) +} + +func TestService_AdvancedErrorHandling(t *testing.T) { + t.Run("GenerateState Multiple Calls", func(t *testing.T) { + service := &Service{} + + // Generate multiple states to ensure uniqueness + states := make(map[string]bool) + for i := 0; i < 100; i++ { + state, err := service.GenerateState() + assert.NoError(t, err) + assert.NotEmpty(t, state) + assert.False(t, states[state], "State should be unique") + states[state] = true + } + }) +} + +func TestService_ConfigurationValidation(t *testing.T) { + t.Run("Empty Service Struct", func(t *testing.T) { + service := &Service{} + + // Test methods on empty service + assert.False(t, service.IsEnabled()) + assert.Empty(t, service.GetAuthURL("test-state")) + + state, err := service.GenerateState() + assert.NoError(t, err) + assert.NotEmpty(t, state) + + // ExchangeCode should fail on empty service + userInfo, err := service.ExchangeCode(context.Background(), "test-code") + assert.Error(t, err) + assert.Equal(t, ErrOIDCNotEnabled, err) + assert.Nil(t, userInfo) + }) +} + +func TestService_ContextHandling(t *testing.T) { + t.Run("Canceled Context", func(t *testing.T) { + service := &Service{} + + // Create canceled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // ExchangeCode should fail on disabled service (context won't matter) + userInfo, err := service.ExchangeCode(ctx, "test-code") + assert.Error(t, err) + assert.Equal(t, ErrOIDCNotEnabled, err) + assert.Nil(t, userInfo) + }) +} + +func TestService_StateGeneration(t *testing.T) { + t.Run("State Format Validation", func(t *testing.T) { + service := &Service{} + + state, err := service.GenerateState() + assert.NoError(t, err) + assert.NotEmpty(t, state) + + // Check that state is base64 URL encoded + decoded, err := base64.URLEncoding.DecodeString(state) + assert.NoError(t, err) + assert.Equal(t, 32, len(decoded)) // Should be 32 bytes + }) +} + +func TestUserInfo_StructValidation(t *testing.T) { + t.Run("UserInfo Fields", func(t *testing.T) { + userInfo := UserInfo{ + Email: "test@example.com", + Name: "Test User", + Username: "testuser", + Subject: "sub-123", + } + + assert.Equal(t, "test@example.com", userInfo.Email) + assert.Equal(t, "Test User", userInfo.Name) + assert.Equal(t, "testuser", userInfo.Username) + assert.Equal(t, "sub-123", userInfo.Subject) + }) + + t.Run("Empty UserInfo", func(t *testing.T) { + userInfo := UserInfo{} + + assert.Empty(t, userInfo.Email) + assert.Empty(t, userInfo.Name) + assert.Empty(t, userInfo.Username) + assert.Empty(t, userInfo.Subject) + }) +} + +// Integration tests would require a real database and proper setup +// These would be placed in a separate test file with build tags +// like // +build integration diff --git a/internal/service/oidc/queries.sql b/internal/service/oidc/queries.sql new file mode 100644 index 0000000..f3c2566 --- /dev/null +++ b/internal/service/oidc/queries.sql @@ -0,0 +1,18 @@ +-- name: OIDCServiceCreateUser :one +INSERT INTO users (name, email, oidc_provider, oidc_subject) +VALUES (@name, lower(@email), @oidc_provider, @oidc_subject) +RETURNING *; + +-- name: OIDCServiceGetUserByOIDC :one +SELECT * FROM users +WHERE oidc_provider = @oidc_provider AND oidc_subject = @oidc_subject; + +-- name: OIDCServiceGetUserByEmail :one +SELECT * FROM users +WHERE email = lower(@email); + +-- name: OIDCServiceUpdateUser :one +UPDATE users +SET name = @name, email = lower(@email), updated_at = NOW() +WHERE oidc_provider = @oidc_provider AND oidc_subject = @oidc_subject +RETURNING *; diff --git a/internal/service/service.go b/internal/service/service.go index 0500c44..8a48c84 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -10,6 +10,7 @@ import ( "github.com/eduardolat/pgbackweb/internal/service/databases" "github.com/eduardolat/pgbackweb/internal/service/destinations" "github.com/eduardolat/pgbackweb/internal/service/executions" + "github.com/eduardolat/pgbackweb/internal/service/oidc" "github.com/eduardolat/pgbackweb/internal/service/restorations" "github.com/eduardolat/pgbackweb/internal/service/users" "github.com/eduardolat/pgbackweb/internal/service/webhooks" @@ -21,6 +22,7 @@ type Service struct { DatabasesService *databases.Service DestinationsService *destinations.Service ExecutionsService *executions.Service + OIDCService *oidc.Service UsersService *users.Service RestorationsService *restorations.Service WebhooksService *webhooks.Service @@ -29,9 +31,13 @@ type Service struct { func New( env config.Env, dbgen *dbgen.Queries, cr *cron.Cron, ints *integration.Integration, -) *Service { +) (*Service, error) { webhooksService := webhooks.New(dbgen) authService := auth.New(env, dbgen) + oidcService, err := oidc.New(env, dbgen) + if err != nil { + return nil, err + } databasesService := databases.New(env, dbgen, ints, webhooksService) destinationsService := destinations.New(env, dbgen, ints, webhooksService) executionsService := executions.New(env, dbgen, ints, webhooksService) @@ -47,8 +53,9 @@ func New( DatabasesService: databasesService, DestinationsService: destinationsService, ExecutionsService: executionsService, + OIDCService: oidcService, UsersService: usersService, RestorationsService: restorationsService, WebhooksService: webhooksService, - } + }, nil } diff --git a/internal/service/users/create_user.go b/internal/service/users/create_user.go index c793869..931d399 100644 --- a/internal/service/users/create_user.go +++ b/internal/service/users/create_user.go @@ -2,6 +2,7 @@ package users import ( "context" + "database/sql" "github.com/eduardolat/pgbackweb/internal/database/dbgen" "github.com/eduardolat/pgbackweb/internal/util/cryptoutil" @@ -10,11 +11,19 @@ import ( func (s *Service) CreateUser( ctx context.Context, params dbgen.UsersServiceCreateUserParams, ) (dbgen.User, error) { - hash, err := cryptoutil.CreateBcryptHash(params.Password) + // Convert sql.NullString to string for hashing + passwordStr := "" + if params.Password.Valid { + passwordStr = params.Password.String + } + + hash, err := cryptoutil.CreateBcryptHash(passwordStr) if err != nil { return dbgen.User{}, err } - params.Password = hash + + // Convert hash back to sql.NullString + params.Password = sql.NullString{String: hash, Valid: true} return s.dbgen.UsersServiceCreateUser(ctx, params) } diff --git a/internal/service/users/users.go b/internal/service/users/users.go index 7835b36..7c6224a 100644 --- a/internal/service/users/users.go +++ b/internal/service/users/users.go @@ -11,3 +11,8 @@ func New(dbgen *dbgen.Queries) *Service { dbgen: dbgen, } } + +// IsOIDCUser checks if a user is authenticated via OIDC +func (s *Service) IsOIDCUser(user dbgen.User) bool { + return user.OidcProvider.Valid && user.OidcSubject.Valid +} diff --git a/internal/view/middleware/inject_reqctx.go b/internal/view/middleware/inject_reqctx.go index c83a40f..bc54e86 100644 --- a/internal/view/middleware/inject_reqctx.go +++ b/internal/view/middleware/inject_reqctx.go @@ -30,11 +30,14 @@ func (m *Middleware) InjectReqctx(next echo.HandlerFunc) echo.HandlerFunc { reqCtx.IsAuthed = true reqCtx.SessionID = user.SessionID reqCtx.User = dbgen.User{ - ID: user.ID, - Name: user.Name, - Email: user.Email, - CreatedAt: user.CreatedAt, - UpdatedAt: user.UpdatedAt, + ID: user.ID, + Name: user.Name, + Email: user.Email, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + OidcProvider: user.OidcProvider, + OidcSubject: user.OidcSubject, + Password: user.Password, } } diff --git a/internal/view/web/auth/create_first_user.go b/internal/view/web/auth/create_first_user.go index d0ec490..6156b57 100644 --- a/internal/view/web/auth/create_first_user.go +++ b/internal/view/web/auth/create_first_user.go @@ -1,6 +1,7 @@ package auth import ( + "database/sql" "net/http" "github.com/eduardolat/pgbackweb/internal/database/dbgen" @@ -132,7 +133,7 @@ func (h *handlers) createFirstUserHandler(c echo.Context) error { _, err := h.servs.UsersService.CreateUser(ctx, dbgen.UsersServiceCreateUserParams{ Name: formData.Name, Email: formData.Email, - Password: formData.Password, + Password: sql.NullString{String: formData.Password, Valid: true}, }) if err != nil { return respondhtmx.ToastError(c, err.Error()) diff --git a/internal/view/web/auth/login.go b/internal/view/web/auth/login.go index ccd9402..70304a4 100644 --- a/internal/view/web/auth/login.go +++ b/internal/view/web/auth/login.go @@ -2,6 +2,7 @@ package auth import ( "net/http" + "net/url" "github.com/eduardolat/pgbackweb/internal/logger" "github.com/eduardolat/pgbackweb/internal/util/echoutil" @@ -32,13 +33,62 @@ func (h *handlers) loginPageHandler(c echo.Context) error { return c.Redirect(http.StatusFound, pathutil.BuildPath("/auth/create-first-user")) } - return echoutil.RenderNodx(c, http.StatusOK, loginPage()) + // Check for error message in URL parameters + errorMsg := c.QueryParam("error") + if errorMsg != "" { + // URL decode the error message to handle encoded characters + if decodedMsg, err := url.QueryUnescape(errorMsg); err == nil { + errorMsg = decodedMsg + } + } + + return echoutil.RenderNodx(c, http.StatusOK, loginPage(h.servs.OIDCService.IsEnabled(), errorMsg)) } -func loginPage() nodx.Node { +func loginPage(oidcEnabled bool, errorMsg string) nodx.Node { content := []nodx.Node{ component.H1Text("Login"), + } + + // Add JavaScript to show toast notification if error message is present + if errorMsg != "" { + // Use a data attribute to safely pass the error message to JavaScript + content = append(content, + nodx.Script( + nodx.Attr("data-error-message", errorMsg), + nodx.Text(` + (function() { + const errorMsg = document.currentScript.dataset.errorMessage; + if (errorMsg) { + window.toaster.error(errorMsg); + } + })(); + `), + ), + ) + } + // Add OIDC login option if enabled + if oidcEnabled { + content = append(content, + nodx.Div( + nodx.Class("mt-4"), + nodx.A( + nodx.Href("/auth/oidc/login"), + nodx.Class("btn btn-outline btn-block"), + component.SpanText("Login with SSO"), + lucide.ExternalLink(), + ), + ), + nodx.Div( + nodx.Class("divider"), + nodx.Text("OR"), + ), + ) + } + + // Traditional login form + content = append(content, nodx.FormEl( htmx.HxPost(pathutil.BuildPath("/auth/login")), htmx.HxDisabledELT("find button"), @@ -76,7 +126,7 @@ func loginPage() nodx.Node { ), ), ), - } + ) return layout.Auth(layout.AuthParams{ Title: "Login", diff --git a/internal/view/web/auth/router.go b/internal/view/web/auth/router.go index f04fb79..8cbeb6e 100644 --- a/internal/view/web/auth/router.go +++ b/internal/view/web/auth/router.go @@ -5,6 +5,7 @@ import ( "github.com/eduardolat/pgbackweb/internal/service" "github.com/eduardolat/pgbackweb/internal/view/middleware" + "github.com/eduardolat/pgbackweb/internal/view/web/oidc" "github.com/labstack/echo/v4" ) @@ -31,4 +32,7 @@ func MountRouter( requireAuth.POST("/logout", h.logoutHandler) requireAuth.POST("/logout-all", h.logoutAllSessionsHandler) + + // Mount OIDC routes + oidc.MountRouter(parent, mids, servs) } diff --git a/internal/view/web/dashboard/profile/update_user.go b/internal/view/web/dashboard/profile/update_user.go index 6333402..d1119f2 100644 --- a/internal/view/web/dashboard/profile/update_user.go +++ b/internal/view/web/dashboard/profile/update_user.go @@ -19,6 +19,14 @@ func (h *handlers) updateUserHandler(c echo.Context) error { reqCtx := reqctx.GetCtx(c) ctx := c.Request().Context() + // Check if user is OIDC user + isOIDCUser := reqCtx.User.OidcProvider.Valid && reqCtx.User.OidcSubject.Valid + + // Block profile updates for OIDC users + if isOIDCUser { + return respondhtmx.ToastError(c, "Profile updates are not allowed for SSO users. Your profile is managed by your identity provider.") + } + var formData struct { Name string `form:"name" validate:"required"` Email string `form:"email" validate:"required,email"` @@ -28,15 +36,18 @@ func (h *handlers) updateUserHandler(c echo.Context) error { if err := c.Bind(&formData); err != nil { return respondhtmx.ToastError(c, err.Error()) } + if err := validate.Struct(&formData); err != nil { return respondhtmx.ToastError(c, err.Error()) } + passwordUpdate := sql.NullString{String: formData.Password, Valid: formData.Password != ""} + _, err := h.servs.UsersService.UpdateUser(ctx, dbgen.UsersServiceUpdateUserParams{ ID: reqCtx.User.ID, Name: sql.NullString{String: formData.Name, Valid: true}, Email: sql.NullString{String: formData.Email, Valid: true}, - Password: sql.NullString{String: formData.Password, Valid: formData.Password != ""}, + Password: passwordUpdate, }) if err != nil { return respondhtmx.ToastError(c, err.Error()) @@ -46,6 +57,12 @@ func (h *handlers) updateUserHandler(c echo.Context) error { } func updateUserForm(user dbgen.User) nodx.Node { + // Check if user is OIDC user + isOIDCUser := user.OidcProvider.Valid && user.OidcSubject.Valid + + // Build form fields + formFields := []nodx.Node{ + component.H2Text("Update profile"), return component.CardBox(component.CardBoxParams{ Children: []nodx.Node{ nodx.FormEl( @@ -96,17 +113,99 @@ func updateUserForm(user dbgen.User) nodx.Node { Type: component.InputTypePassword, }), + // Show different message for OIDC users + nodx.If(isOIDCUser, + nodx.Div( + nodx.Class("alert alert-info mb-4"), nodx.Div( - nodx.Class("flex justify-end items-center space-x-2 pt-2"), - component.HxLoadingMd(), - nodx.Button( - nodx.Class("btn btn-primary"), - nodx.Type("submit"), - component.SpanText("Save changes"), - lucide.Save(), + nodx.Class("flex items-center space-x-2"), + lucide.Info(), + nodx.Div( + nodx.Class("text-sm"), + nodx.Text("You are logged in via SSO. Your profile information is managed by your identity provider and cannot be changed here."), ), ), ), + ), + + component.InputControl(component.InputControlParams{ + Name: "name", + Label: "Full name", + Placeholder: "Your full name", + Required: !isOIDCUser, // Don't require if disabled + Type: component.InputTypeText, + AutoComplete: "name", + Children: []nodx.Node{ + nodx.Value(user.Name), + nodx.If(isOIDCUser, nodx.Disabled("")), + nodx.If(isOIDCUser, nodx.Readonly("")), + }, + }), + + component.InputControl(component.InputControlParams{ + Name: "email", + Label: "Email", + Placeholder: "Your email", + Required: !isOIDCUser, // Don't require if disabled + AutoComplete: "email", + Type: component.InputTypeEmail, + Children: []nodx.Node{ + nodx.Value(user.Email), + nodx.If(isOIDCUser, nodx.Disabled("")), + nodx.If(isOIDCUser, nodx.Readonly("")), + }, + }), + } + + // Add password fields only for non-OIDC users + if !isOIDCUser { + formFields = append(formFields, + component.InputControl(component.InputControlParams{ + Name: "password", + Label: "Change password", + Placeholder: "New password", + AutoComplete: "new-password", + Type: component.InputTypePassword, + HelpText: "Leave empty to keep your current password", + }), + + component.InputControl(component.InputControlParams{ + Name: "password_confirmation", + Label: "Confirm password", + Placeholder: "Confirm new password", + AutoComplete: "new-password", + Type: component.InputTypePassword, + }), + ) + } + + // Add submit button (disabled for OIDC users) + formFields = append(formFields, + nodx.Div( + nodx.Class("flex justify-end items-center space-x-2 pt-2"), + component.HxLoadingMd(), + nodx.Button( + nodx.ClassMap{ + "btn btn-primary": !isOIDCUser, + "btn btn-disabled": isOIDCUser, + }, + nodx.Type("submit"), + nodx.If(isOIDCUser, nodx.Disabled("")), + component.SpanText("Save changes"), + lucide.Save(), + ), + ), + ) + + return component.CardBox(component.CardBoxParams{ + Children: []nodx.Node{ + nodx.FormEl( + append([]nodx.Node{ + nodx.If(!isOIDCUser, htmx.HxPost("/dashboard/profile")), + nodx.If(!isOIDCUser, htmx.HxDisabledELT("find button")), + nodx.Class("space-y-2"), + }, formFields...)..., + ), }, }) } diff --git a/internal/view/web/oidc/router.go b/internal/view/web/oidc/router.go new file mode 100644 index 0000000..f4becbc --- /dev/null +++ b/internal/view/web/oidc/router.go @@ -0,0 +1,175 @@ +package oidc + +import ( + "errors" + "net/http" + "net/url" + + "github.com/eduardolat/pgbackweb/internal/logger" + "github.com/eduardolat/pgbackweb/internal/service" + "github.com/eduardolat/pgbackweb/internal/service/oidc" + "github.com/eduardolat/pgbackweb/internal/view/middleware" + "github.com/eduardolat/pgbackweb/internal/view/web/respondhtmx" + "github.com/labstack/echo/v4" +) + +type handlers struct { + servs *service.Service +} + +func MountRouter( + parent *echo.Group, mids *middleware.Middleware, servs *service.Service, +) { + if !servs.OIDCService.IsEnabled() { + return + } + + h := handlers{servs: servs} + + requireNoAuth := parent.Group("", mids.RequireNoAuth) + + requireNoAuth.GET("/oidc/login", h.oidcLoginHandler) + requireNoAuth.GET("/oidc/callback", h.oidcCallbackHandler) +} + +func (h *handlers) oidcLoginHandler(c echo.Context) error { + state, err := h.servs.OIDCService.GenerateState() + if err != nil { + logger.Error("OIDC: failed to generate state", logger.KV{ + "ip": c.RealIP(), + "ua": c.Request().UserAgent(), + "error": err, + }) + return handleOIDCError(c, "OIDC: Unable to initiate login") + } + + // Store state in session/cookie for verification + c.SetCookie(&http.Cookie{ + Name: "oidc_state", + Value: state, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 300, // 5 minutes + Path: "/", + }) + + authURL := h.servs.OIDCService.GetAuthURL(state) + return c.Redirect(http.StatusFound, authURL) +} + +func (h *handlers) oidcCallbackHandler(c echo.Context) error { + ctx := c.Request().Context() + + // Verify state parameter + state := c.QueryParam("state") + stateCookie, err := c.Cookie("oidc_state") + if err != nil || stateCookie == nil || stateCookie.Value != state { + expectedValue := "" + if stateCookie != nil { + expectedValue = stateCookie.Value + } + logger.Error("OIDC: state mismatch", logger.KV{ + "ip": c.RealIP(), + "ua": c.Request().UserAgent(), + "state": state, + "expected": expectedValue, + }) + return handleOIDCError(c, "OIDC: Invalid state parameter") + } + + // Clear the state cookie + c.SetCookie(&http.Cookie{ + Name: "oidc_state", + Value: "", + HttpOnly: true, + MaxAge: -1, + Path: "/", + }) + + // Check for error from OIDC provider + if errorParam := c.QueryParam("error"); errorParam != "" { + errorDesc := c.QueryParam("error_description") + errorMsg := "OIDC: Login failed" + if errorDesc != "" { + errorMsg = "OIDC: " + errorDesc + } else { + errorMsg = "OIDC: " + errorParam + } + logger.Error("OIDC provider returned error", logger.KV{ + "ip": c.RealIP(), + "ua": c.Request().UserAgent(), + "error": errorParam, + "error_description": errorDesc, + }) + return handleOIDCError(c, errorMsg) + } + + code := c.QueryParam("code") + if code == "" { + return handleOIDCError(c, "OIDC: Missing authorization code") + } + + // Exchange code for user info + userInfo, err := h.servs.OIDCService.ExchangeCode(ctx, code) + if err != nil { + logger.Error("failed to exchange OIDC code", logger.KV{ + "ip": c.RealIP(), + "ua": c.Request().UserAgent(), + "error": err, + }) + return handleOIDCError(c, "OIDC: Unable to authenticate with provider") + } + + // Create or update user + user, err := h.servs.OIDCService.CreateOrUpdateUser(ctx, userInfo) + if err != nil { + errorMsg := "OIDC: Unable to create user account" + if errors.Is(err, oidc.ErrEmailAlreadyExists) { + errorMsg = "OIDC: Email already exists. Use regular login." + } + logger.Error("failed to create/update OIDC user", logger.KV{ + "ip": c.RealIP(), + "ua": c.Request().UserAgent(), + "email": userInfo.Email, + "error": err, + }) + return handleOIDCError(c, errorMsg) + } + + logger.Info("OIDC: authentication successful", logger.KV{ + "email": userInfo.Email, + "name": userInfo.Name, + "subject": userInfo.Subject, + "user_id": user.ID, + }) + + // Create session for the user + session, err := h.servs.AuthService.LoginOIDC( + ctx, user.ID, c.RealIP(), c.Request().UserAgent(), + ) + if err != nil { + logger.Error("OIDC: failed to create session for user", logger.KV{ + "ip": c.RealIP(), + "ua": c.Request().UserAgent(), + "user_id": user.ID, + "error": err, + }) + return handleOIDCError(c, "OIDC: Unable to create session") + } + + // Set session cookie and redirect to dashboard + h.servs.AuthService.SetSessionCookie(c, session.DecryptedToken) + return c.Redirect(http.StatusSeeOther, "/dashboard") +} + +// handleOIDCError handles OIDC errors by detecting if it's an HTMX request or regular browser request +func handleOIDCError(c echo.Context, message string) error { + // Check if it's an HTMX request + if c.Request().Header.Get("HX-Request") != "" { + return respondhtmx.ToastError(c, message) + } + + // For regular browser requests, redirect to login with error parameter + return c.Redirect(http.StatusSeeOther, "/auth/login?error="+url.QueryEscape(message)) +} diff --git a/internal/view/web/oidc/router_test.go b/internal/view/web/oidc/router_test.go new file mode 100644 index 0000000..7561104 --- /dev/null +++ b/internal/view/web/oidc/router_test.go @@ -0,0 +1,327 @@ +package oidc + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/eduardolat/pgbackweb/internal/service" + "github.com/eduardolat/pgbackweb/internal/service/oidc" + "github.com/eduardolat/pgbackweb/internal/view/middleware" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestHandleOIDCError(t *testing.T) { + t.Run("HTMX Request", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("HX-Request", "true") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handleOIDCError(c, "Test error message") + + // Should handle HTMX request (returns no error, sets headers for toast) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + // Check HTMX headers are set for toast error + assert.Contains(t, rec.Header().Get("HX-Reswap"), "none") + assert.Contains(t, rec.Header().Get("HX-Trigger"), "ctm_toast_error") + }) + + t.Run("Regular Browser Request", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handleOIDCError(c, "Test error message") + + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, rec.Code) + + location := rec.Header().Get("Location") + assert.Contains(t, location, "/auth/login") + assert.Contains(t, location, url.QueryEscape("Test error message")) + }) +} + +func TestOIDCLoginHandler_StateGeneration(t *testing.T) { + t.Run("State Generation and Cookie Setting", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/login", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Create a mock OIDC service that's enabled + mockOIDCService := &oidc.Service{} + // Note: In a real test, you'd need to properly mock this + // For now, we'll just test the structure + + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + // This will succeed because GenerateState() doesn't depend on config + // and GetAuthURL() will return empty string but won't error + err := h.oidcLoginHandler(c) + + // The function will succeed and redirect (even to empty URL) + assert.NoError(t, err) + assert.Equal(t, http.StatusFound, rec.Code) + + // Check that the state cookie was set + cookies := rec.Result().Cookies() + foundStateCookie := false + for _, cookie := range cookies { + if cookie.Name == "oidc_state" { + foundStateCookie = true + assert.NotEmpty(t, cookie.Value) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, 300, cookie.MaxAge) + break + } + } + assert.True(t, foundStateCookie, "oidc_state cookie should be set") + }) +} + +func TestOIDCCallbackHandler_StateValidation(t *testing.T) { + t.Run("State Mismatch", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/callback?code=test-code&state=wrong-state", nil) + req.AddCookie(&http.Cookie{ + Name: "oidc_state", + Value: "correct-state", + }) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcCallbackHandler(c) + + // Should handle state mismatch by redirecting (returning no error) + assert.NoError(t, err) + // Check that it redirected to login with error + assert.Equal(t, http.StatusSeeOther, rec.Code) + assert.Contains(t, rec.Header().Get("Location"), "/auth/login") + }) + + t.Run("Missing State Cookie", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/callback?code=test-code&state=test-state", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcCallbackHandler(c) + + // Should handle missing state cookie by redirecting (returns no error) + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, rec.Code) + assert.Contains(t, rec.Header().Get("Location"), "/auth/login") + }) + + t.Run("OIDC Provider Error", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/callback?error=access_denied&error_description=User+denied+access", nil) + req.AddCookie(&http.Cookie{ + Name: "oidc_state", + Value: "test-state", + }) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcCallbackHandler(c) + + // Should handle provider error by redirecting (returns no error) + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, rec.Code) + assert.Contains(t, rec.Header().Get("Location"), "/auth/login") + }) + + t.Run("Missing Authorization Code", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/callback?state=test-state", nil) + req.AddCookie(&http.Cookie{ + Name: "oidc_state", + Value: "test-state", + }) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcCallbackHandler(c) + + // Should handle missing code by redirecting (returns no error) + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, rec.Code) + assert.Contains(t, rec.Header().Get("Location"), "/auth/login") + }) +} + +func TestOIDCCallbackHandler_AdvancedScenarios(t *testing.T) { + t.Run("State Parameter Empty", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/callback?code=test-code&state=", nil) + req.AddCookie(&http.Cookie{ + Name: "oidc_state", + Value: "valid-state", + }) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcCallbackHandler(c) + + // Should handle empty state parameter by redirecting + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, rec.Code) + assert.Contains(t, rec.Header().Get("Location"), "/auth/login") + }) + + t.Run("State Parameter Missing", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/callback?code=test-code", nil) + req.AddCookie(&http.Cookie{ + Name: "oidc_state", + Value: "valid-state", + }) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcCallbackHandler(c) + + // Should handle missing state parameter by redirecting + assert.NoError(t, err) + assert.Equal(t, http.StatusSeeOther, rec.Code) + assert.Contains(t, rec.Header().Get("Location"), "/auth/login") + }) + + t.Run("HTMX Callback Request", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/callback?code=test-code&state=wrong-state", nil) + req.Header.Set("HX-Request", "true") + req.AddCookie(&http.Cookie{ + Name: "oidc_state", + Value: "correct-state", + }) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcCallbackHandler(c) + + // Should handle HTMX request with toast error + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Header().Get("HX-Reswap"), "none") + assert.Contains(t, rec.Header().Get("HX-Trigger"), "ctm_toast_error") + }) +} + +func TestOIDCLoginHandler_EdgeCases(t *testing.T) { + t.Run("HTMX Login Request", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/auth/oidc/login", nil) + req.Header.Set("HX-Request", "true") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + err := h.oidcLoginHandler(c) + + // Should redirect normally even for HTMX requests + assert.NoError(t, err) + assert.Equal(t, http.StatusFound, rec.Code) + }) +} + +func TestHandlers_ServiceIntegration(t *testing.T) { + t.Run("Handler Creation", func(t *testing.T) { + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + h := handlers{servs: mockServices} + + // Verify handlers struct is created correctly + assert.NotNil(t, h.servs) + assert.NotNil(t, h.servs.OIDCService) + }) +} + +func TestMountRouter_EnabledCheck(t *testing.T) { + t.Run("OIDC Disabled - No Panic", func(t *testing.T) { + e := echo.New() + group := e.Group("/auth") + + // Create a disabled OIDC service + mockOIDCService := &oidc.Service{} + mockServices := &service.Service{ + OIDCService: mockOIDCService, + } + + mockMiddleware := &middleware.Middleware{} + + // This should not panic even if OIDC is disabled + assert.NotPanics(t, func() { + MountRouter(group, mockMiddleware, mockServices) + }) + }) +}