Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Jul 21, 2023
1 parent e91d5bd commit d7a696b
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

from pywhy_stats import Methods, PValueResult, independence_test

rng = np.random.default_rng(0)


def test_independence_test():
# Replace these arrays with appropriate test data for your specific case
X = np.random.standard_normal(size=(20, 3)) # 10 samples, 3 features for X
Y = np.random.standard_normal(size=(20, 3)) # 10 samples, 3 features for Y
condition_on = np.random.standard_normal(
size=(20, 3)
) # 10 samples, 3 features for condition_on
X = rng.standard_normal(size=(20, 3)) # 10 samples, 3 features for X
Y = rng.standard_normal(size=(20, 3)) # 10 samples, 3 features for Y
condition_on = rng.standard_normal(size=(20, 3)) # 10 samples, 3 features for condition_on

# Test case for unconditional independence test
result = independence_test(X, Y, method=Methods.FISHERZ)
Expand All @@ -28,9 +28,11 @@ def test_independence_test():
independence_test(X, Y, method="unsupported_method")

# Test invalid data is caught with testing for Gaussianity
n_samples = 2000
X = np.random.uniform(low=0.0, high=10, size=(n_samples, 3)) # 10 samples, 3 features for X
Y = np.random.uniform(size=(n_samples, 3)) # 10 samples, 3 features for X
condition_on = np.random.uniform(size=(n_samples, 3)) # 10 samples, 3 features for condition_on
n_samples = 1000
X1 = rng.uniform(low=-5, high=0, size=(n_samples // 2, 3)) # 10 samples, 3 features for X
X2 = rng.normal(loc=10, scale=1, size=(n_samples // 2, 3)) # 10 samples, 3 features for X
X = np.concatenate((X1, X2), axis=0)
Y = rng.uniform(size=(n_samples, 3)) # 10 samples, 3 features for X
condition_on = rng.uniform(size=(n_samples, 3)) # 10 samples, 3 features for condition_on
with pytest.warns(UserWarning, match="The provided data does not seem to be Gaussian"):
independence_test(X, Y, condition_on, method=Methods.FISHERZ)

0 comments on commit d7a696b

Please sign in to comment.