Skip to content

Commit

Permalink
Closes Bears-R-Us#3372: Add logistic to random number generators (Bea…
Browse files Browse the repository at this point in the history
…rs-R-Us#3605)

* Closes Bears-R-Us#3372: Add logistic to random number generators

This PR (closes Bears-R-Us#3372) adds logistic to generators. This doesn't add multi-dim bc it uses `uniformStreamPerElem`

* added test, docs, and helper

---------

Co-authored-by: Tess Hayes <stress-tess@users.noreply.github.com>
  • Loading branch information
stress-tess and stress-tess authored Aug 8, 2024
1 parent 7c828f8 commit baffa9f
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 40 deletions.
42 changes: 42 additions & 0 deletions PROTO_tests/tests/random_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from collections import Counter
from itertools import product

import numpy as np
import pytest
Expand Down Expand Up @@ -194,6 +195,30 @@ def test_choice_flags(self):
current = rng.choice(a, size, replace, p)
assert np.allclose(previous.to_list(), current.to_list())

def test_logistic(self):
scal = 2
arr = ak.arange(5)

for loc, scale in product([scal, arr], [scal, arr]):
rng = ak.random.default_rng(17)
num_samples = 5
log_sample = rng.logistic(loc=loc, scale=scale, size=num_samples).to_list()

rng = ak.random.default_rng(17)
assert rng.logistic(loc=loc, scale=scale, size=num_samples).to_list() == log_sample

def test_lognormal(self):
scal = 2
arr = ak.arange(5)

for mean, sigma in product([scal, arr], [scal, arr]):
rng = ak.random.default_rng(17)
num_samples = 5
log_sample = rng.lognormal(mean=mean, sigma=sigma, size=num_samples).to_list()

rng = ak.random.default_rng(17)
assert rng.lognormal(mean=mean, sigma=sigma, size=num_samples).to_list() == log_sample

def test_normal(self):
rng = ak.random.default_rng(17)
both_scalar = rng.normal(loc=10, scale=2, size=10).to_list()
Expand Down Expand Up @@ -394,6 +419,23 @@ def test_exponential_hypothesis_testing(self, method):
)
assert ks_res.pvalue > 0.05

def test_logistic_hypothesis_testing(self):
# I tested this many times without a set seed, but with no seed
# it's expected to fail one out of every ~20 runs given a pval limit of 0.05.
rng = np.random.default_rng(34)
num_samples = 10**4
mu = rng.uniform(0, 10)
scale = rng.uniform(0, 10)

sample = rng.logistic(loc=mu, scale=scale, size=num_samples)
sample_list = sample.tolist()

# second goodness of fit test against the distribution with proper mean and std
good_fit_res = sp_stats.goodness_of_fit(
sp_stats.logistic, sample_list, known_params={"loc": mu, "scale": scale}
)
assert good_fit_res.pvalue > 0.05

def test_legacy_randint(self):
testArray = ak.random.randint(0, 10, 5)
assert isinstance(testArray, ak.pdarray)
Expand Down
122 changes: 96 additions & 26 deletions arkouda/random/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,9 @@ def exponential(self, scale=1.0, size=None, method="zig"):
pdarray
Drawn samples from the parameterized exponential distribution.
"""
if isinstance(scale, pdarray):
if (scale < 0).any():
raise ValueError("scale cannot be less then 0")
elif _val_isinstance_of_union(scale, numeric_scalars):
if scale < 0:
raise ValueError("scale cannot be less then 0")
else:
raise TypeError("scale must be either a float scalar or pdarray")
_, scale = float_array_or_scalar_helper("exponential", "scale", scale, size)
if (scale < 0).any() if isinstance(scale, pdarray) else scale < 0:
raise TypeError("scale must be non-negative.")
return scale * self.standard_exponential(size, method=method)

def standard_exponential(self, size=None, method="zig"):
Expand Down Expand Up @@ -297,6 +292,78 @@ def integers(self, low, high=None, size=None, dtype=akint64, endpoint=False):
self._state += full_size
return create_pdarray(rep_msg)

def logistic(self, loc=0.0, scale=1.0, size=None):
r"""
Draw samples from a logistic distribution.
Samples are drawn from a logistic distribution with specified parameters,
loc (location or mean, also median), and scale (>0).
Parameters
----------
loc: float or pdarray of floats, optional
Parameter of the distribution. Default of 0.
scale: float or pdarray of floats, optional
Parameter of the distribution. Must be non-negative. Default is 1.
size: numeric_scalars, optional
Output shape. Default is None, in which case a single value is returned.
Notes
-----
The probability density for the Logistic distribution is
.. math::
P(x) = \frac{e^{-(x - \mu)/s}}{s( 1 + e^{-(x - \mu)/s})^2}
where :math:`\mu` is the location and :math:`s` is the scale.
The Logistic distribution is used in Extreme Value problems where it can act
as a mixture of Gumbel distributions, in Epidemiology, and by the World Chess Federation (FIDE)
where it is used in the Elo ranking system, assuming the performance of each player
is a logistically distributed random variable.
Returns
-------
pdarray
Pdarray of floats (unless size=None, in which case a single float is returned).
See Also
--------
normal
Examples
--------
>>> ak.random.default_rng(17).logistic(3, 2.5, 3)
array([1.1319566682702642 -7.1665150633720014 7.7208667145173608])
"""
if size is None:
# delegate to numpy when return size is 1
return self._np_generator.logistic(loc=loc, scale=scale, size=size)

is_single_mu, mu = float_array_or_scalar_helper("logistic", "loc", loc, size)
is_single_scale, scale = float_array_or_scalar_helper("logistic", "scale", scale, size)
if (scale < 0).any() if isinstance(scale, pdarray) else scale < 0:
raise TypeError("scale must be non-negative.")

rep_msg = generic_msg(
cmd="logisticGenerator",
args={
"name": self._name_dict[akdtype("float64")],
"mu": mu,
"is_single_mu": is_single_mu,
"scale": scale,
"is_single_scale": is_single_scale,
"size": size,
"has_seed": self._seed is not None,
"state": self._state,
},
)
# we only generate one val using the generator in the symbol table
self._state += 1
return create_pdarray(rep_msg)

def lognormal(self, mean=0.0, sigma=1.0, size=None, method="zig"):
r"""
Draw samples from a log-normal distribution with specified mean,
Expand Down Expand Up @@ -622,24 +689,9 @@ def poisson(self, lam=1.0, size=None):
# delegate to numpy when return size is 1
return self._np_generator.poisson(lam, size)

if _val_isinstance_of_union(lam, numeric_scalars):
is_single_lambda = True
if not _val_isinstance_of_union(lam, float_scalars):
lam = float(lam)
if lam < 0:
raise TypeError("lambda must be >=0")
elif isinstance(lam, pdarray):
is_single_lambda = False
if size != lam.size:
raise TypeError("array of lambdas must have same size as return size")
if lam.dtype != akfloat64:
from arkouda.numeric import cast as akcast

lam = akcast(lam, akfloat64)
if (lam < 0).any():
raise TypeError("all lambdas must be >=0")
else:
raise TypeError("poisson only accepts a pdarray or float scalar for lam")
is_single_lambda, lam = float_array_or_scalar_helper("poisson", "lam", lam, size)
if (lam < 0).any() if isinstance(lam, pdarray) else lam < 0:
raise TypeError("lam must be non-negative.")

rep_msg = generic_msg(
cmd="poissonGenerator",
Expand Down Expand Up @@ -758,3 +810,21 @@ def default_rng(seed=None):
).split()[1]

return Generator(name_dict, seed if has_seed else None, state=state)


def float_array_or_scalar_helper(func_name, var_name, var, size):
if _val_isinstance_of_union(var, numeric_scalars):
is_scalar = True
if not _val_isinstance_of_union(var, float_scalars):
var = float(var)
elif isinstance(var, pdarray):
is_scalar = False
if size != var.size:
raise TypeError(f"array of {var_name} must have same size as return size")
if var.dtype != akfloat64:
from arkouda.numeric import cast as akcast

var = akcast(var, akfloat64)
else:
raise TypeError(f"{func_name} only accepts a pdarray or float scalar for {var_name}")
return is_scalar, var
4 changes: 4 additions & 0 deletions pydoc/usage/random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ integers
---------
.. autofunction:: arkouda.random.Generator.integers

logistic
---------
.. autofunction:: arkouda.random.Generator.logistic

lognormal
---------
.. autofunction:: arkouda.random.Generator.lognormal
Expand Down
35 changes: 34 additions & 1 deletion src/RandMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,38 @@ module RandMsg
return MsgTuple.error("choice does not support the bigint dtype");
}

inline proc logisticGenerator(mu: real, scale: real, ref rs) {
var U = rs.next(0, 1);

while U <= 0.0 {
/* Reject U == 0.0 and call again to get next value */
U = rs.next(0, 1);
}
return mu + scale * log(U / (1.0 - U));
}

proc logisticGeneratorMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const name = msgArgs["name"],
isSingleMu = msgArgs["is_single_mu"].toScalar(bool),
muStr = msgArgs["mu"].toScalar(string),
isSingleScale = msgArgs["is_single_scale"].toScalar(bool),
scaleStr = msgArgs["scale"].toScalar(string),
size = msgArgs["size"].toScalar(int),
hasSeed = msgArgs["has_seed"].toScalar(bool),
state = msgArgs["state"].toScalar(int);

var generatorEntry = st[name]: borrowed GeneratorSymEntry(real);
ref rng = generatorEntry.generator;
if state != 1 then rng.skipTo(state-1);

var logisticArr = makeDistArray(size, real);
const mu = new scalarOrArray(muStr, !isSingleMu, st),
scale = new scalarOrArray(scaleStr, !isSingleScale, st);

uniformStreamPerElem(logisticArr, rng, GenerationFunction.LogisticGenerator, hasSeed, mu=mu, scale=scale);
return st.insert(createSymEntry(logisticArr));
}

@arkouda.instantiateAndRegister
proc permutation(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws {
const name = msgArgs["name"],
Expand Down Expand Up @@ -656,7 +688,7 @@ module RandMsg
var poissonArr = makeDistArray(size, int);
const lam = new scalarOrArray(lamStr, !isSingleLam, st);

uniformStreamPerElem(poissonArr, rng, GenerationFunction.PoissonGenerator, hasSeed, lam);
uniformStreamPerElem(poissonArr, rng, GenerationFunction.PoissonGenerator, hasSeed, lam=lam);
return st.insert(createSymEntry(poissonArr));
}

Expand Down Expand Up @@ -717,6 +749,7 @@ module RandMsg
}

use CommandMap;
registerFunction("logisticGenerator", logisticGeneratorMsg, getModuleName());
registerFunction("segmentedSample", segmentedSampleMsg, getModuleName());
registerFunction("poissonGenerator", poissonGeneratorMsg, getModuleName());
}
15 changes: 14 additions & 1 deletion src/RandUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@ module RandUtil {

enum GenerationFunction {
ExponentialGenerator,
LogisticGenerator,
NormalGenerator,
PoissonGenerator,
}

// TODO how to update this to handle randArr being a multi-dim array??
// I thought to just do the same randArr[randArr.domain.orderToIndex(i)] trick
// but im not sure how randArr.localSubdomain() will differ with multi-dim
proc uniformStreamPerElem(ref randArr: [?D] ?t, ref rng, param function: GenerationFunction, hasSeed: bool, const lam: scalarOrArray(?) = new scalarOrArray()) throws {
proc uniformStreamPerElem(ref randArr: [?D] ?t, ref rng, param function: GenerationFunction, hasSeed: bool,
const lam: scalarOrArray(?) = new scalarOrArray(),
const mu: scalarOrArray(?) = new scalarOrArray(),
const scale: scalarOrArray(?) = new scalarOrArray()) throws {
if hasSeed {
// use a fixed number of elements per stream instead of relying on number of locales or numTasksPerLoc because these
// can vary from run to run / machine to mahchine. And it's important for the same seed to give the same results
Expand Down Expand Up @@ -79,6 +83,9 @@ module RandUtil {
when GenerationFunction.ExponentialGenerator {
agg.copy(randArr[i], standardExponentialZig(realRS, uintRS));
}
when GenerationFunction.LogisticGenerator {
agg.copy(randArr[i], logisticGenerator(mu[i], scale[i], realRS));
}
when GenerationFunction.NormalGenerator {
agg.copy(randArr[i], standardNormZig(realRS, uintRS));
}
Expand All @@ -101,6 +108,9 @@ module RandUtil {
when GenerationFunction.ExponentialGenerator {
randArr[i] = standardExponentialZig(realRS, uintRS);
}
when GenerationFunction.LogisticGenerator {
randArr[i] = logisticGenerator(mu[i], scale[i], realRS);
}
when GenerationFunction.NormalGenerator {
randArr[i] = standardNormZig(realRS, uintRS);
}
Expand All @@ -124,6 +134,9 @@ module RandUtil {
when GenerationFunction.ExponentialGenerator {
rv = standardExponentialZig(realRS, uintRS);
}
when GenerationFunction.LogisticGenerator {
rv = logisticGenerator(mu[i], scale[i], realRS);
}
when GenerationFunction.NormalGenerator {
rv = standardNormZig(realRS, uintRS);
}
Expand Down
24 changes: 12 additions & 12 deletions src/registry/Commands.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -1341,51 +1341,51 @@ registerFunction('choice<bigint>', ark_choice_bigint, 'RandMsg', 503);

proc ark_permutation_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.permutation(cmd, msgArgs, st, array_dtype=int, array_nd=1);
registerFunction('permutation<int64,1>', ark_permutation_int_1, 'RandMsg', 577);
registerFunction('permutation<int64,1>', ark_permutation_int_1, 'RandMsg', 609);

proc ark_permutation_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.permutation(cmd, msgArgs, st, array_dtype=uint, array_nd=1);
registerFunction('permutation<uint64,1>', ark_permutation_uint_1, 'RandMsg', 577);
registerFunction('permutation<uint64,1>', ark_permutation_uint_1, 'RandMsg', 609);

proc ark_permutation_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.permutation(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1);
registerFunction('permutation<uint8,1>', ark_permutation_uint8_1, 'RandMsg', 577);
registerFunction('permutation<uint8,1>', ark_permutation_uint8_1, 'RandMsg', 609);

proc ark_permutation_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.permutation(cmd, msgArgs, st, array_dtype=real, array_nd=1);
registerFunction('permutation<float64,1>', ark_permutation_real_1, 'RandMsg', 577);
registerFunction('permutation<float64,1>', ark_permutation_real_1, 'RandMsg', 609);

proc ark_permutation_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.permutation(cmd, msgArgs, st, array_dtype=bool, array_nd=1);
registerFunction('permutation<bool,1>', ark_permutation_bool_1, 'RandMsg', 577);
registerFunction('permutation<bool,1>', ark_permutation_bool_1, 'RandMsg', 609);

proc ark_permutation_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.permutation(cmd, msgArgs, st, array_dtype=bigint, array_nd=1);
registerFunction('permutation<bigint,1>', ark_permutation_bigint_1, 'RandMsg', 577);
registerFunction('permutation<bigint,1>', ark_permutation_bigint_1, 'RandMsg', 609);

proc ark_shuffle_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.shuffle(cmd, msgArgs, st, array_dtype=int, array_nd=1);
registerFunction('shuffle<int64,1>', ark_shuffle_int_1, 'RandMsg', 664);
registerFunction('shuffle<int64,1>', ark_shuffle_int_1, 'RandMsg', 696);

proc ark_shuffle_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.shuffle(cmd, msgArgs, st, array_dtype=uint, array_nd=1);
registerFunction('shuffle<uint64,1>', ark_shuffle_uint_1, 'RandMsg', 664);
registerFunction('shuffle<uint64,1>', ark_shuffle_uint_1, 'RandMsg', 696);

proc ark_shuffle_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.shuffle(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1);
registerFunction('shuffle<uint8,1>', ark_shuffle_uint8_1, 'RandMsg', 664);
registerFunction('shuffle<uint8,1>', ark_shuffle_uint8_1, 'RandMsg', 696);

proc ark_shuffle_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.shuffle(cmd, msgArgs, st, array_dtype=real, array_nd=1);
registerFunction('shuffle<float64,1>', ark_shuffle_real_1, 'RandMsg', 664);
registerFunction('shuffle<float64,1>', ark_shuffle_real_1, 'RandMsg', 696);

proc ark_shuffle_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.shuffle(cmd, msgArgs, st, array_dtype=bool, array_nd=1);
registerFunction('shuffle<bool,1>', ark_shuffle_bool_1, 'RandMsg', 664);
registerFunction('shuffle<bool,1>', ark_shuffle_bool_1, 'RandMsg', 696);

proc ark_shuffle_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return RandMsg.shuffle(cmd, msgArgs, st, array_dtype=bigint, array_nd=1);
registerFunction('shuffle<bigint,1>', ark_shuffle_bigint_1, 'RandMsg', 664);
registerFunction('shuffle<bigint,1>', ark_shuffle_bigint_1, 'RandMsg', 696);

import StatsMsg;

Expand Down

0 comments on commit baffa9f

Please sign in to comment.