Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions internal/api/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,90 @@ func TestE2EHooks(t *testing.T) {
}
})
}

t.Run("AMRStringArrayUnmarshalling", func(t *testing.T) {
defer inst.HookRecorder.CustomizeAccessToken.ClearCalls()

// Setup hook that returns amr as array of strings
var claimsIn M
hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("content-type", "application/json")
w.WriteHeader(http.StatusOK)

err := json.NewDecoder(r.Body).Decode(&claimsIn)
require.NoError(t, err)

// Modify amr to be array of strings instead of objects
claimsOut := copyMap(t, claimsIn)
claimsOut["claims"].(M)["amr"] = []string{"password", "totp"}

err = json.NewEncoder(w).Encode(claimsOut)
require.NoError(t, err)
})

inst.HookRecorder.CustomizeAccessToken.ClearCalls()
inst.HookRecorder.CustomizeAccessToken.SetHandler(hr)

// Get token with modified amr
req := &api.PasswordGrantParams{
Email: string(currentUser.Email),
Password: defaultPassword,
}

res := new(api.AccessTokenResponse)
err := e2eapi.Do(ctx, http.MethodPost, inst.APIServer.URL+"/token?grant_type=password", req, res)
require.NoError(t, err)
require.True(t, len(res.Token) > 0)

// Verify hook was called
{
calls := inst.HookRecorder.CustomizeAccessToken.GetCalls()
require.Equal(t, 1, len(calls))
}

// Parse token to verify it can be unmarshalled
p := jwt.NewParser(jwt.WithValidMethods(globalCfg.JWT.ValidMethods))
token, err := p.ParseWithClaims(
res.Token,
&api.AccessTokenClaims{},
func(token *jwt.Token) (any, error) {
if kid, ok := token.Header["kid"]; ok {
if kidStr, ok := kid.(string); ok {
return conf.FindPublicKeyByKid(kidStr, &globalCfg.JWT)
}
}
if alg, ok := token.Header["alg"]; ok {
if alg == jwt.SigningMethodHS256.Name {
return []byte(globalCfg.JWT.Secret), nil
}
}
return nil, fmt.Errorf("missing kid")
})
require.NoError(t, err, "Token should parse successfully even with string array amr")

fmt.Println("token hereee", res.Token)
// Verify claims were unmarshalled correctly
claims, ok := token.Claims.(*api.AccessTokenClaims)
require.True(t, ok, "Claims should be AccessTokenClaims type")
require.NotNil(t, claims.AuthenticationMethodReference, "AMR should not be nil")
require.Len(t, claims.AuthenticationMethodReference, 2, "AMR should have 2 entries")
require.Equal(t, "password", claims.AuthenticationMethodReference[0].Method)
require.Equal(t, "totp", claims.AuthenticationMethodReference[1].Method)

// Call /user endpoint with the token to verify it works end-to-end
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "/user", nil)
require.NoError(t, err)

httpRes, err := inst.DoAuth(httpReq, res.Token)
require.NoError(t, err, "Should be able to call /user endpoint with token containing string array amr")
require.Equal(t, http.StatusOK, httpRes.StatusCode, "/user endpoint should return 200 OK")

// Verify we got user data back
var userData models.User
err = json.NewDecoder(httpRes.Body).Decode(&userData)
require.NoError(t, err, "Should be able to decode user response")
require.Equal(t, currentUser.ID, userData.ID, "Should get the correct user")
})
})

t.Run("SendEmail", func(t *testing.T) {
Expand Down
37 changes: 37 additions & 0 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,43 @@ end; $$ language plpgsql;`,
"user_metadata": nil,
},
shouldError: false,
}, {
desc: "Modify amr to be array of strings",
uri: "pg-functions://postgres/auth/custom_access_token_amr_strings",
hookFunctionSQL: `
create or replace function custom_access_token_amr_strings(input jsonb)
returns jsonb as $$
declare
result jsonb;
begin
input := jsonb_set(input, '{claims,amr}', '["password", "mfa"]'::jsonb);
result := jsonb_build_object('claims', input->'claims');
return result;
end; $$ language plpgsql;`,
expectedClaims: map[string]interface{}{
"amr": []interface{}{"password", "mfa"},
},
shouldError: false,
}, {
desc: "Modify amr to be array of objects",
uri: "pg-functions://postgres/auth/custom_access_token_amr_objects",
hookFunctionSQL: `
create or replace function custom_access_token_amr_objects(input jsonb)
returns jsonb as $$
declare
result jsonb;
begin
input := jsonb_set(input, '{claims,amr}', '[{"method": "password"}, {"method": "mfa"}]'::jsonb);
result := jsonb_build_object('claims', input->'claims');
return result;
end; $$ language plpgsql;`,
expectedClaims: map[string]interface{}{
"amr": []interface{}{
map[string]interface{}{"method": "password"},
map[string]interface{}{"method": "mfa"},
},
},
shouldError: false,
},
}
for _, c := range cases {
Expand Down
49 changes: 47 additions & 2 deletions internal/tokens/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tokens

import (
"context"
"encoding/json"
"fmt"
mathRand "math/rand"
"net/http"
Expand All @@ -26,6 +27,47 @@ import (

const retryLoopDuration = 5.0

// AMRClaim supports unmarshalling AMR as either strings or AMREntry objects.
type AMRClaim []models.AMREntry

// UnmarshalJSON accepts either an array of strings or AMREntry objects.
func (a *AMRClaim) UnmarshalJSON(data []byte) error {
// Handle null explicitly - null cannot be unmarshaled into a slice
if len(data) > 0 {
trimmed := strings.TrimSpace(string(data))
if trimmed == "null" {
*a = AMRClaim{}
return nil
}
}

var rawItems []json.RawMessage
if err := json.Unmarshal(data, &rawItems); err != nil {
return err
}

entries := make([]models.AMREntry, 0, len(rawItems))
for _, item := range rawItems {
var method string
if err := json.Unmarshal(item, &method); err == nil {
entries = append(entries, models.AMREntry{
Method: method,
Timestamp: time.Now().Unix(),
})
continue
}

var entry models.AMREntry
if err := json.Unmarshal(item, &entry); err != nil {
return err
}
entries = append(entries, entry)
}

*a = entries
return nil
}

// AccessTokenClaims is a struct thats used for JWT claims
type AccessTokenClaims struct {
jwt.RegisteredClaims
Expand All @@ -35,7 +77,7 @@ type AccessTokenClaims struct {
UserMetaData map[string]interface{} `json:"user_metadata"`
Role string `json:"role"`
AuthenticatorAssuranceLevel string `json:"aal,omitempty"`
AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"`
AuthenticationMethodReference AMRClaim `json:"amr,omitempty"`
SessionId string `json:"session_id,omitempty"`
IsAnonymous bool `json:"is_anonymous"`
ClientID string `json:"client_id,omitempty"`
Expand Down Expand Up @@ -951,7 +993,10 @@ const MinimumViableTokenSchema = `{
"amr": {
"type": "array",
"items": {
"type": "object"
"anyOf": [
{"type": "string"},
{"type": "object"}
]
}
},
"session_id": {
Expand Down
67 changes: 67 additions & 0 deletions internal/tokens/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -1022,3 +1023,69 @@ func (ts *IDTokenTestSuite) TestIDTokenWithMultipleScopes() {
phoneNumber, hasPhone := claims["phone_number"]
require.False(ts.T(), hasPhone || (phoneNumber != nil && phoneNumber != ""), "phone_number claim should not be present without phone scope")
}

func TestAMRClaimUnmarshal(t *testing.T) {
t.Run("mixed string and object formats", func(t *testing.T) {
var claim AMRClaim
before := time.Now().Unix()

err := json.Unmarshal([]byte(`["password", {"method":"totp","timestamp":123,"provider":"webauthn"}]`), &claim)
require.NoError(t, err)
require.Len(t, claim, 2)

require.Equal(t, "password", claim[0].Method)
require.GreaterOrEqual(t, claim[0].Timestamp, before)
require.LessOrEqual(t, claim[0].Timestamp, time.Now().Unix())
require.Empty(t, claim[0].Provider, "string format should not have provider")

require.Equal(t, "totp", claim[1].Method)
require.Equal(t, int64(123), claim[1].Timestamp)
require.Equal(t, "webauthn", claim[1].Provider, "provider should be preserved from object format")
})

t.Run("object with provider", func(t *testing.T) {
var claim AMRClaim
err := json.Unmarshal([]byte(`[{"method":"sso","timestamp":456,"provider":"saml"}]`), &claim)
require.NoError(t, err)
require.Len(t, claim, 1)
require.Equal(t, "sso", claim[0].Method)
require.Equal(t, int64(456), claim[0].Timestamp)
require.Equal(t, "saml", claim[0].Provider, "provider should be preserved")
})

t.Run("object without provider", func(t *testing.T) {
var claim AMRClaim
err := json.Unmarshal([]byte(`[{"method":"password","timestamp":789}]`), &claim)
require.NoError(t, err)
require.Len(t, claim, 1)
require.Equal(t, "password", claim[0].Method)
require.Equal(t, int64(789), claim[0].Timestamp)
require.Empty(t, claim[0].Provider, "provider should be empty when not provided")
})

t.Run("all strings", func(t *testing.T) {
var claim AMRClaim
before := time.Now().Unix()
err := json.Unmarshal([]byte(`["password", "totp"]`), &claim)
require.NoError(t, err)
require.Len(t, claim, 2)
require.Equal(t, "password", claim[0].Method)
require.Equal(t, "totp", claim[1].Method)
require.GreaterOrEqual(t, claim[0].Timestamp, before)
require.Empty(t, claim[0].Provider)
require.Empty(t, claim[1].Provider)
})

t.Run("all objects", func(t *testing.T) {
var claim AMRClaim
err := json.Unmarshal([]byte(`[{"method":"password","timestamp":100},{"method":"totp","timestamp":200,"provider":"webauthn"}]`), &claim)
require.NoError(t, err)
require.Len(t, claim, 2)
require.Equal(t, "password", claim[0].Method)
require.Equal(t, int64(100), claim[0].Timestamp)
require.Empty(t, claim[0].Provider)
require.Equal(t, "totp", claim[1].Method)
require.Equal(t, int64(200), claim[1].Timestamp)
require.Equal(t, "webauthn", claim[1].Provider, "provider should be preserved")
})
}