Skip to content

Commit

Permalink
Merge pull request #1306 from Mv77/plumbing/labeled-dist-of-fun
Browse files Browse the repository at this point in the history
[WIP] Labeled dist-of-fun
  • Loading branch information
alanlujan91 authored Jul 26, 2023
2 parents fc891e4 + 349182b commit 023e8cd
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
2 changes: 1 addition & 1 deletion HARK/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray:

if len(kwargs):
f_query = func(self.dataset, **kwargs)
ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.pmv)
ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)

return ldd

Expand Down
39 changes: 39 additions & 0 deletions HARK/tests/test_distribution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import numpy as np
import xarray as xr

from HARK.distribution import (
Bernoulli,
Expand Down Expand Up @@ -603,3 +604,41 @@ def test_combine_labeled_dist(self):
np.concatenate([de.expected(), abc.expected()]),
)
)


class labeled_transition_tests(unittest.TestCase):
def setUp(self) -> None:
return super().setUp()

def test_expectation_transformation(self):
# Create a basic labeled distribution
base_dist = DiscreteDistributionLabeled(
pmv=np.array([0.5, 0.5]),
atoms=np.array([[1.0, 2.0], [3.0, 4.0]]),
var_names=["a", "b"],
)

# Define a transition function
def transition(shocks, state):
state_new = {}
state_new["m"] = state["m"] * shocks["a"]
state_new["n"] = state["n"] * shocks["b"]
return state_new

m = xr.DataArray(np.linspace(0, 10, 11), name="m", dims=("grid",))
n = xr.DataArray(np.linspace(0, -10, 11), name="n", dims=("grid",))
state_grid = xr.Dataset({"m": m, "n": n})

# Evaluate labeled transformation

# Direct expectation
exp1 = base_dist.expected(transition, state=state_grid)
# Expectation after transformation
new_state_dstn = base_dist.dist_of_func(transition, state=state_grid)
# TODO: needs a cluncky identity function with an extra argument because
# DDL.expected() behavior is very different with and without kwargs.
# Fix!
exp2 = new_state_dstn.expected(lambda x, unused: x, unused=0)

assert np.all(exp1["m"] == exp2["m"]).item()
assert np.all(exp1["n"] == exp2["n"]).item()

0 comments on commit 023e8cd

Please sign in to comment.