Skip to content

Commit

Permalink
feat: allow specifying optional algorithms during verification (#32)
Browse files Browse the repository at this point in the history
This allows an optional `algorithms` keyword argument to be specified for `validate!` and `with_valid_jwt` methods that can contain a list of algorithms to accept. If the JWT is not signed with any of the supplied algorithms, the validation fails.
  • Loading branch information
tanmaykm authored Jun 19, 2024
1 parent bce27d2 commit a84e31e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 37 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ for k in keys(signingkeyset.keys)
end
```

The `alg` method on a JWK returns the algorithm used for the key.

```julia
julia> JWTs.alg(keyset.keys["7978a91347261a291bd71dcab4a464be7d279666"])
"RS256"
```

## Tokens

**JWT** represents a JSON Web Token containing the payload at the minimum. When signed, it holds the header (with key id and algorithm used) and signature too. The parts are stored in encoded form.
Expand Down Expand Up @@ -97,6 +104,13 @@ julia> kid(jwt)
"4Fytp3LfBhriD0eZ-k3aNS042bDiCZXg6bQNJmYoaE"
```

The `alg` method shows the algorithm used to sign a JWT.

```julia
julia> alg(jwt)
"RS256"
```

## Validation

To validate a JWT against a key, call the `validate!` method, passing a key set and the key id to use.
Expand Down Expand Up @@ -128,3 +142,10 @@ julia> with_valid_jwt(jwt2, keyset) do valid_jwt
...
"email" => "user@example.com"
```

Both `validate!` and `with_valid_jwt` methods can optionally take an `algorithms` argument, which is a list of algorithms to validate against. If the JWT's algorithm is not in the list, the validation will fail.

```julia
julia> validate!(jwt, keyset, keyname; algorithms=["RS256"])
true
```
33 changes: 24 additions & 9 deletions src/JWTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,27 +152,33 @@ show(io::IO, jwt::JWT) = print(io, issigned(jwt) ? join([jwt.header, jwt.payload
Validate the JWT using the keys in the keyset.
The JWT must be signed. An exception is thrown otherwise.
The keyset must contain the key id from the JWT header. A KeyError is thrown otherwise.
The optional `algorithms` parameter can be used to specify the algorithms to use for validation.
Returns `true` if the JWT is valid, `false` otherwise.
"""
validate!(jwt::JWT, keyset::JWKSet) = validate!(jwt, keyset, kid(jwt))
function validate!(jwt::JWT, keyset::JWKSet, kid::String)
validate!(jwt::JWT, keyset::JWKSet; algorithms::Vector{String}=String[]) = validate!(jwt, keyset, kid(jwt); algorithms=algorithms)
function validate!(jwt::JWT, keyset::JWKSet, kid::String; algorithms::Vector{String}=String[])
isverified(jwt) && (return isvalid(jwt))
(kid in keys(keyset.keys)) || refresh!(keyset)
validate!(jwt, keyset.keys[kid])
validate!(jwt, keyset.keys[kid]; algorithms=algorithms)
end
function validate!(jwt::JWT, key::JWK)
function validate!(jwt::JWT, key::JWK; algorithms::Vector{String}=String[])
isverified(jwt) && (return isvalid(jwt))
issigned(jwt) || throw(ArgumentError("jwt is not signed"))

data = jwt.header * "." * jwt.payload
sigbytes = base64decode(urldec(jwt.signature))

jwt.verified = true

# Check that the (optional) `alg` header claim matches the algorithm of the validation key
alg_jwt = alg(jwt)
valid_alg = alg_jwt === nothing || alg_jwt == alg(key)
if !isempty(algorithms)
alg_matched = alg_jwt === nothing ? alg(key) : alg_jwt
if !(alg_matched in algorithms)
return false
end
end
jwt.valid = valid_alg && if key isa JWKRSA
try
MbedTLS.verify(key.key, key.kind, MbedTLS.digest(key.kind, data), sigbytes) == 0
Expand Down Expand Up @@ -362,13 +368,22 @@ Arguments:
Keyword arguments:
- `kid`: The key id to use for validation. If not specified, the `kid` from the JWT header is used.
- `algorithms`: Ensure validation with one of the listed algorithms. Not enforced by deault.
"""
with_valid_jwt(f::Function, jwt::String, keyset::JWKSet; kid::Union{Nothing,String}=nothing) = with_valid_jwt(f, JWT(jwt), keyset; kid=kid)
function with_valid_jwt(f::Function, jwt::JWT, keyset::JWKSet; kid::Union{Nothing,String}=nothing)
function with_valid_jwt(f::Function, jwt::String, keyset::JWKSet;
kid::Union{Nothing,String}=nothing,
algorithms::Vector{String}=String[],
)
with_valid_jwt(f, JWT(jwt), keyset; kid=kid, algorithms=algorithms)
end
function with_valid_jwt(f::Function, jwt::JWT, keyset::JWKSet;
kid::Union{Nothing,String}=nothing,
algorithms::Vector{String}=String[],
)
if isnothing(kid)
validate!(jwt, keyset)
validate!(jwt, keyset; algorithms=algorithms)
else
validate!(jwt, keyset, kid)
validate!(jwt, keyset, kid; algorithms=algorithms)
end

isvalid(jwt) || throw(ArgumentError("invalid jwt"))
Expand Down
70 changes: 42 additions & 28 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function test_in_mem_keyset(template)
end
end

function test_signing_keys(keyset, signingkeyset)
function test_signing_keys(keyset, signingkeyset, algorithms::Vector{String})
for k in keys(keyset.keys)
for d in test_payload_data
jwt = JWT(; payload=d)
Expand All @@ -97,7 +97,17 @@ function test_signing_keys(keyset, signingkeyset)
@test issigned(jwt2)
@test !isverified(jwt2)
@test isvalid(jwt2) === nothing
@test validate!(jwt, keyset, k)
# test with valid algos
@test validate!(jwt, keyset, k; algorithms=algorithms)

# test with invalid algos
jwt_check = JWT(; jwt=string(jwt))
@test !validate!(jwt_check, keyset, k; algorithms=["invalidalgo"])

# test without specifying algos
jwt_check = JWT(; jwt=string(jwt))
@test validate!(jwt_check, keyset, k; algorithms=String[])

@test issigned(jwt)
@test isvalid(jwt)
@test isverified(jwt)
Expand All @@ -111,15 +121,15 @@ function test_signing_keys(keyset, signingkeyset)
@test !isverified(jwt2)
@test isvalid(jwt2) === nothing
invalidkey = findfirst(x -> x != keyset.keys[k], keyset.keys)
@test !validate!(jwt2, keyset, invalidkey)
@test !validate!(jwt2, keyset, invalidkey; algorithms=algorithms)
@test issigned(jwt2)
@test !isvalid(jwt2)
@test isverified(jwt2)
end
end
end

function test_signing_asymmetric_keys(keyset_url)
function test_signing_asymmetric_keys(keyset_url, algorithms::Vector{String})
print_header("signing asymmetric keys")
keyset = JWKSet(keyset_url)
refresh!(keyset)
Expand All @@ -131,16 +141,16 @@ function test_signing_asymmetric_keys(keyset_url)
end
signingkeyset.keys[k] = JWKRSA(signingkeyset.keys[k].kind, MbedTLS.parse_keyfile(keyfile))
end
test_signing_keys(keyset, signingkeyset)
test_signing_keys(keyset, signingkeyset, algorithms)
end

function test_signing_symmetric_keys(keyset_url)
function test_signing_symmetric_keys(keyset_url, algorithms::Vector{String})
print_header("signing symmetric keys")
keyset = test_and_get_keyset(keyset_url)
test_signing_keys(keyset, keyset)
test_signing_keys(keyset, keyset, algorithms)
end

function test_with_valid_jwt(keyset_url)
function test_with_valid_jwt(keyset_url, algorithms::Vector{String})
print_header("with_valid_jwt do block")

keyset = JWKSet(keyset_url)
Expand All @@ -151,7 +161,7 @@ function test_with_valid_jwt(keyset_url)
key = first(keys(keyset.keys))
sign!(jwt, keyset, key)

with_valid_jwt(jwt, keyset) do jwt3
with_valid_jwt(jwt, keyset; algorithms=algorithms) do jwt3
@test isvalid(jwt3)
@test claims(jwt3) == d
end
Expand All @@ -167,24 +177,28 @@ function test_with_valid_jwt(keyset_url)
end
end

test_and_get_keyset("https://www.googleapis.com/oauth2/v3/certs")
test_signing_symmetric_keys("file://" * joinpath(@__DIR__, "keys", "oct", "jwkkey.json"))
test_in_mem_keyset(joinpath(@__DIR__, "keys", "oct", "jwkkey.json"))
test_signing_asymmetric_keys("file://" * joinpath(@__DIR__, "keys", "rsa", "jwkkey.json"))
test_with_valid_jwt("file://" * joinpath(@__DIR__, "keys", "oct", "jwkkey.json"))

@testset "alg" begin
rsakey = MbedTLS.parse_keyfile(joinpath(@__DIR__, "keys", "rsa", "rsakey1.private.pem"))
@test JWTs.alg(JWKRSA(MbedTLS.MD_SHA256, rsakey)) == "RS256"
@test JWTs.alg(JWKRSA(MbedTLS.MD_SHA384, rsakey)) == "RS384"
@test JWTs.alg(JWKRSA(MbedTLS.MD_SHA, rsakey)) == "RS512"

@test JWTs.alg(JWKSymmetric(MbedTLS.MD_SHA256, UInt8[])) == "HS256"
@test JWTs.alg(JWKSymmetric(MbedTLS.MD_SHA384, UInt8[])) == "HS384"
@test JWTs.alg(JWKSymmetric(MbedTLS.MD_SHA, UInt8[])) == "HS512"

for kind in (MbedTLS.MD_SHA1, MbedTLS.MD_SHA224)
@test_throws ArgumentError JWTs.alg(JWKRSA(kind, rsakey))
@test_throws ArgumentError JWTs.alg(JWKSymmetric(kind, UInt8[]))
@testset "JWTs" begin
@testset "signing" begin
test_and_get_keyset("https://www.googleapis.com/oauth2/v3/certs")
test_signing_symmetric_keys("file://" * joinpath(@__DIR__, "keys", "oct", "jwkkey.json"), ["HS256", "HS384", "HS512"])
test_in_mem_keyset(joinpath(@__DIR__, "keys", "oct", "jwkkey.json"))
test_signing_asymmetric_keys("file://" * joinpath(@__DIR__, "keys", "rsa", "jwkkey.json"), ["RS256"])
test_with_valid_jwt("file://" * joinpath(@__DIR__, "keys", "oct", "jwkkey.json"), ["HS256", "HS384", "HS512"])
end

@testset "alg" begin
rsakey = MbedTLS.parse_keyfile(joinpath(@__DIR__, "keys", "rsa", "rsakey1.private.pem"))
@test JWTs.alg(JWKRSA(MbedTLS.MD_SHA256, rsakey)) == "RS256"
@test JWTs.alg(JWKRSA(MbedTLS.MD_SHA384, rsakey)) == "RS384"
@test JWTs.alg(JWKRSA(MbedTLS.MD_SHA, rsakey)) == "RS512"

@test JWTs.alg(JWKSymmetric(MbedTLS.MD_SHA256, UInt8[])) == "HS256"
@test JWTs.alg(JWKSymmetric(MbedTLS.MD_SHA384, UInt8[])) == "HS384"
@test JWTs.alg(JWKSymmetric(MbedTLS.MD_SHA, UInt8[])) == "HS512"

for kind in (MbedTLS.MD_SHA1, MbedTLS.MD_SHA224)
@test_throws ArgumentError JWTs.alg(JWKRSA(kind, rsakey))
@test_throws ArgumentError JWTs.alg(JWKSymmetric(kind, UInt8[]))
end
end
end

2 comments on commit a84e31e

@tanmaykm
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/109345

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" a84e31e342c38af801b4c8b621e9a892ffdd39fe
git push origin v0.3.0

Please sign in to comment.