From f7877646a4a16678f9f0ff1af9192aa0e742582f Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Sun, 9 Oct 2022 00:35:00 -0400 Subject: [PATCH] allow passing functions Co-authored-by: Yingbo Ma --- Project.toml | 2 +- src/api.jl | 36 +++++++++++++++++++++++------------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index e963016..233d96d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "VectorizedRNG" uuid = "33b4df10-0173-11e9-2a0c-851a7edac40e" authors = ["Chris Elrod "] -version = "0.2.19" +version = "0.2.20" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/src/api.jl b/src/api.jl index 44167e2..86259cb 100644 --- a/src/api.jl +++ b/src/api.jl @@ -167,7 +167,7 @@ end @inline _vload(ptr::VectorizationBase.AbstractStridedPointer, args::Vararg{Any,K}) where {K} = vload(ptr, args...) @inline _vload(x::Number, args::Vararg{Any,K}) where {K} = x -function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, α, β, γ) where {F,P,T} +function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, α, β, γ, g::G=identity) where {F,P,T,G} state = getstate(rng, Val{2}(), pick_vector_width(UInt64)) GC.@preserve x begin ptrx = zero_pointer(x); ptrβ = zero_pointer(β); ptrγ = zero_pointer(γ); @@ -179,8 +179,8 @@ function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, α, z₁, z₂ = data(zvu2) x₁ = vload(ptrx, (n,)); β₁ = _vload(ptrβ, (n,)); γ₁ = _vload(ptrγ, (n,)); x₂ = vload(ptrx, (vadd(W, n),)); β₂ = _vload(ptrβ, (vadd(W, n),)); γ₂ = _vload(ptrγ, (vadd(W, n),)); - vstore!(ptrx, α * x₁ + z₁ * γ₁ + β₁, (n,)); - vstore!(ptrx, α * x₂ + z₁ * γ₂ + β₂, (vadd(W, n),)); + vstore!(ptrx, g(α * x₁ + z₁ * γ₁ + β₁), (n,)); + vstore!(ptrx, g(α * x₂ + z₁ * γ₂ + β₂), (vadd(W, n),)); n = vadd(W2, n) end m = VectorizationBase.mask(W, N) @@ -189,19 +189,19 @@ function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, α, z₁, z₂ = data(zvu2) x₁ = vload(ptrx, (n,)); β₁ = _vload(ptrβ, (n,)); γ₁ = _vload(ptrγ, (n,)); x₂ = vload(ptrx, (vadd(W, n),), m); β₂ = _vload(ptrβ, (vadd(W, n),), m); γ₂ = _vload(ptrγ, (vadd(W, n),), m); - vstore!(ptrx, α * x₁ + z₁ * γ₁ + β₁, (n,)); - vstore!(ptrx, α * x₂ + z₂ * γ₂ + β₂, (vadd(W, n),), m); + vstore!(ptrx, g(α * x₁ + z₁ * γ₁ + β₁,), (n,)); + vstore!(ptrx, g(α * x₂ + z₂ * γ₂ + β₂,), (vadd(W, n),), m); elseif scalar_less(n, N) state, zvu1 = f(state, Val{1}(), T) (z₁,) = data(zvu1) x₁ = vload(ptrx, (n,), m); β₁ = _vload(ptrβ, (n,), m); γ₁ = _vload(ptrγ, (n,), m); - vstore!(ptrx, α * x₁ + z₁ * γ₁ + β₁, (n,), m); + vstore!(ptrx, g(α * x₁ + z₁ * γ₁ + β₁), (n,), m); end storestate!(rng, state) end # GC preserve x end -function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, ::StaticInt{0}, β, γ) where {F,P,T} +function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, ::StaticInt{0}, β, γ,g::G=identity) where {F,P,T,G} state = getstate(rng, Val{2}(), pick_vector_width(UInt64)) GC.@preserve x begin ptrx = zero_pointer(x); ptrβ = zero_pointer(β); ptrγ = zero_pointer(γ); @@ -213,8 +213,8 @@ function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, ::St (z₁,z₂) = data(zvu2) β₁ = _vload(ptrβ, (n,)); γ₁ = _vload(ptrγ, (n,)); β₂ = _vload(ptrβ, (vadd(W, n),)); γ₂ = _vload(ptrγ, (vadd(W, n),)); - vstore!(ptrx, z₁ * γ₁ + β₁, (n,)); - vstore!(ptrx, z₂ * γ₂ + β₂, (vadd(W, n),)); + vstore!(ptrx, g(z₁ * γ₁ + β₁), (n,)); + vstore!(ptrx, g(z₂ * γ₂ + β₂), (vadd(W, n),)); n = vadd(W2, n) end m = VectorizationBase.mask(W, N) @@ -223,13 +223,13 @@ function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, ::St (z₁,z₂) = data(zvu2) β₁ = _vload(ptrβ, (n,)); γ₁ = _vload(ptrγ, (n,)); β₂ = _vload(ptrβ, (vadd(W, n),), m); γ₂ = _vload(ptrγ, (vadd(W, n),), m); - vstore!(ptrx, z₁ * γ₁ + β₁, (n,)); - vstore!(ptrx, z₂ * γ₂ + β₂, (vadd(W, n),), m); + vstore!(ptrx, g(z₁ * γ₁ + β₁), (n,)); + vstore!(ptrx, g(z₂ * γ₂ + β₂), (vadd(W, n),), m); elseif scalar_less(n, N) state, zvu1 = f(state, Val{1}(), T) (z₁,) = data(zvu1) β₁ = _vload(ptrβ, (n,), m); γ₁ = _vload(ptrγ, (n,), m); - vstore!(ptrx, z₁ * γ₁ + β₁, (n,), m); + vstore!(ptrx, g(z₁ * γ₁ + β₁), (n,), m); end storestate!(rng, state) end # GC preserve @@ -238,11 +238,21 @@ function random_sample_u2!(f::F, rng::AbstractVRNG{P}, x::AbstractArray{T}, ::St x end + +function Random.rand!( + f::F, + rng::AbstractVRNG, x::AbstractArray{T}, α::Number = StaticInt{0}(), β = StaticInt{0}(), γ = StaticInt{1}() +) where {T <: Union{Float32,Float64},F} + random_sample_u2!(random_uniform, rng, x, α, β, γ, f) +end + function Random.rand!( rng::AbstractVRNG, x::AbstractArray{T}, α::Number = StaticInt{0}(), β = StaticInt{0}(), γ = StaticInt{1}() ) where {T <: Union{Float32,Float64}} - random_sample_u2!(random_uniform, rng, x, α, β, γ) + random_sample_u2!(random_uniform, rng, x, α, β, γ, identity) end + + function Random.randn!( rng::AbstractVRNG, x::AbstractArray{T}, α::Number = StaticInt{0}(), β = StaticInt{0}(), γ = StaticInt{1}() ) where {T<:Union{Float32,Float64}}