From 3832e27e252276368775d997a594c99f69b2f69e Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:09:10 +0200 Subject: [PATCH 1/8] * update API with incremental * update to accept arbitrary dt --- essm_jax/essm.py | 300 +++++++++++++++++++++++++++++------------------ 1 file changed, 186 insertions(+), 114 deletions(-) diff --git a/essm_jax/essm.py b/essm_jax/essm.py index c0d2b1f..71fb768 100644 --- a/essm_jax/essm.py +++ b/essm_jax/essm.py @@ -36,6 +36,13 @@ class FilterResult(NamedTuple): observation_cov: jax.Array # [num_timesteps, observation_size, observation_size] The covariance of p(x[t] | x[:t-1]) +class IncrementalFilterState(NamedTuple): + t: jax.Array # the time index + log_cumulative_marginal_likelihood: jax.Array # log marginal likelihood prod_t p(x[t] | x[:t-1]) + filtered_mean: jax.Array # [latent_size] mean of p(z[t+1] | x[:t]) + filtered_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t+1] | x[:t]) + + class SmoothingResult(NamedTuple): """ Represents the result of a backward smoothing pass. @@ -90,12 +97,14 @@ class ExtendedStateSpaceModel: Must be a MultivariateNormalLinearOperator. more_data_than_params: If True, the observation function has more outputs than inputs. materialise_jacobians: If True, the Jacobians are materialised as dense matrices. + dt: The time step size, default is 1. """ transition_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] observation_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] initial_state_prior: tfpd.MultivariateNormalLinearOperator more_data_than_params: bool = False materialise_jacobians: bool = False + dt: float = 1. def __post_init__(self): if not callable(self.transition_fn): @@ -165,7 +174,7 @@ def _sample_latents_op(latent, y): init = self.initial_state_prior.sample(seed=init_key) xs = ( jax.random.split(latent_key, num_time), - jnp.arange(1, num_time + 1) + t0 + jnp.arange(1, num_time + 1) * self.dt + t0 ) _, samples = lax.scan( _sample_latents_op, @@ -197,24 +206,34 @@ def _check_shapes(self, observations: jax.Array, mask: Optional[jax.Array] = Non raise ValueError('mask and observations must have the same length') def forward_simulate(self, key: jax.Array, num_time: int, - observations: jax.Array, mask: Optional[jax.Array] = None) -> SampleResult: + filter_result: Union[FilterResult | IncrementalFilterState]) -> SampleResult: """ - Simulate from the model, from the end of the forward filtering pass. + Simulate from the model, from the end of a forward filtering pass. Args: key: a PRNGKey num_time: the number of time steps to simulate - observations: [num_time, observation_size] array of observations - mask: [num_time] array of masks, True for missing observations + filter_result: the result of the forward filtering pass, or incremental filter update. Returns: sample result: num_time timesteps of forward simulated """ - filter_result = self.forward_filter(observations, mask) - initial_state_prior = tfpd.MultivariateNormalTriL( - loc=filter_result.filtered_mean[-1], - scale_tril=jnp.linalg.cholesky(_efficient_add_scalar_diag(filter_result.filtered_cov[-1], 1e-6)) - ) + if isinstance(filter_result, FilterResult): + initial_state_prior = tfpd.MultivariateNormalTriL( + loc=filter_result.filtered_mean[-1], + scale_tril=lax.linalg.cholesky(_efficient_add_scalar_diag(filter_result.filtered_cov[-1], 1e-6), + symmetrize_input=False) + ) + t0 = filter_result.t[-1] + elif isinstance(filter_result, IncrementalFilterState): + initial_state_prior = tfpd.MultivariateNormalTriL( + loc=filter_result.filtered_mean, + scale_tril=lax.linalg.cholesky(_efficient_add_scalar_diag(filter_result.filtered_cov, 1e-6), + symmetrize_input=False) + ) + t0 = filter_result.t + else: + raise ValueError('filter_result must be a FilterResult or IncrementalFilter instance.') new_essm = ExtendedStateSpaceModel( transition_fn=self.transition_fn, observation_fn=self.observation_fn, @@ -222,11 +241,142 @@ def forward_simulate(self, key: jax.Array, num_time: int, more_data_than_params=self.more_data_than_params, materialise_jacobians=self.materialise_jacobians ) - return new_essm.sample(key=key, num_time=num_time, t0=filter_result.t[-1]) + return new_essm.sample(key=key, num_time=num_time, t0=t0) + + def incremental_update(self, filter_state: IncrementalFilterState, observation: jax.Array, + mask: Optional[jax.Array] = None) -> Tuple[ + IncrementalFilterState, tfpd.MultivariateNormalLinearOperator]: + """ + Perform an incremental update of the filter state. Does not advance the time index. I.e. produces + + p(z[t] | x[:t]) from p(z[t] | x[:t-1]) and p(x[t] | x[:t-1]) + + Args: + filter_state: the current filter state + observation: [n] the observation at the current time + mask: scalar, the mask at the current time, True for missing observations + + Returns: + the updated filter state, and the marginal distribution of the observation p(x[t] | x[:t-1]) + """ + + Hop = self.get_observation_jacobian(t=filter_state.t, observation_size=observation.size) + H = Hop(filter_state.filtered_mean) + if self.materialise_jacobians: + H = H.to_dense() + + # Update step, compute p(z[t] | x[:t]) from p(z[t] | x[:t-1]) + observation_dist = self.observation_fn(filter_state.filtered_mean, filter_state.t) + R = observation_dist.covariance() + # Push-forward the prior (i.e. predictive) distribution to the observation space + x_expectation = observation_dist.mean() # [observation_size] + tmp_H_P = H @ filter_state.filtered_cov + S = tmp_H_P @ H.T + R # [observation_size, observation_size] + S_chol = lax.linalg.cholesky(S, symmetrize_input=False) # [observation_size, observation_size] + + marginal_dist = tfpd.MultivariateNormalTriL(x_expectation, S_chol) + # return_marginal = tfpd.MultivariateNormalFullCovariance(x_expectation, S) + + # Compute the log marginal likelihood p(x[t] | x[:t-1]) + log_marginal_likelihood = marginal_dist.log_prob(observation) + + # Compute the Kalman gain + # K = predict_cov @ H.T @ inv(S) = predict_cov.T @ H.T @ inv(S) + K = hpsd_solve(S, tmp_H_P, cholesky_matrix=S_chol).T # [latent_size, observation_size] + + # Update the state estimate + filtered_mean = filter_state.filtered_mean + K @ (observation - x_expectation) # [latent_size] + + # Update the state covariance using Joseph's form to ensure positive semi-definite + # tmp_factor = (I - K @ H) + if self.more_data_than_params: + tmp_factor = _efficient_add_scalar_diag((- K) @ H, 1.) + else: + tmp_factor = _efficient_add_scalar_diag(K @ (-H), 1.) + filtered_cov = tmp_factor @ filter_state.filtered_cov @ tmp_factor.T + K @ R @ K.T # [latent_size, latent_size] + + # When masked, then the filtered state is the predicted state. + if mask is not None: + filtered_mean = lax.select(*jnp.broadcast_arrays(mask, + filter_state.filtered_mean, filtered_mean)) + filtered_cov = lax.select(*jnp.broadcast_arrays(mask, + filter_state.filtered_cov, filtered_cov)) + log_marginal_likelihood = lax.select(*jnp.broadcast_arrays(mask, + jnp.zeros_like(log_marginal_likelihood), + log_marginal_likelihood)) + + log_cumulative_marginal_likelihood = filter_state.log_cumulative_marginal_likelihood + log_marginal_likelihood + return IncrementalFilterState( + t=filter_state.t, + log_cumulative_marginal_likelihood=log_cumulative_marginal_likelihood, + filtered_mean=filtered_mean, + filtered_cov=filtered_cov + ), marginal_dist + + def incremental_predict(self, filter_state: IncrementalFilterState) -> IncrementalFilterState: + """ + Perform an incremental prediction step of the filter state, advancing the time index. I.e. produces + + p(z[t+1] | x[:t]) from p(z[t] | x[:t]) + + Args: + filter_state: the current filter state + + Returns: + the predicted filter state, with time index advanced + """ + # Predict step, compute p(z[t+1] | x[:t]) + + t = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype) + + Fop = self.get_transition_jacobian(t=t) + F = Fop(filter_state.filtered_mean) + if self.materialise_jacobians: + F = F.to_dense() + + predicted_dist = self.transition_fn(filter_state.filtered_mean, t) + predicted_mean = predicted_dist.mean() # [latent_size] + Q = predicted_dist.covariance() # [latent_size, latent_size] + predicted_cov = F @ filter_state.filtered_cov @ F.T + Q # [latent_size, latent_size] + + return IncrementalFilterState( + t=t, + log_cumulative_marginal_likelihood=filter_state.log_cumulative_marginal_likelihood, + filtered_mean=predicted_mean, + filtered_cov=predicted_cov + ) + + def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) -> IncrementalFilterState: + """ + Create an incremental filter at the initial time. + + Args: + t0: the time of prior state (before the first observation) + + Returns: + the initial incremental filter state at the first possible observation time, i.e. t0+1. + """ + # Push forward initial state to create p(z[1] | z[0]) + t0 = jnp.asarray(t0, jnp.float32) + t1 = t0 + jnp.asarray(self.dt, t0.dtype) + init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) + + init_Fop = self.get_transition_jacobian(t1) + init_F = init_Fop(self.initial_state_prior.mean()) + if self.materialise_jacobians: + init_F = init_F.to_dense() + init_predicted_mean = init_predict_dist.mean() + init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() + return IncrementalFilterState( + t=t1, + log_cumulative_marginal_likelihood=jnp.asarray(0.), + filtered_mean=init_predicted_mean, + filtered_cov=init_predicted_cov + ) def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = None, marginal_likelihood_only: bool = False, - t0: Union[jax.Array, int] = 0) -> Union[FilterResult, jax.Array]: + t0: Union[jax.Array, float] = 0.) -> Union[FilterResult, jax.Array]: """ Run the forward filtering pass, computing the total marginal likelihood @@ -248,19 +398,13 @@ def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = No """ self._check_shapes(observations=observations, mask=mask) - class Carry(NamedTuple): - log_cumulative_marginal_likelihood: jax.Array # log marginal likelihood - predicted_mean: jax.Array # [latent_size] mean of p(z[t+1] | x[:t]) - predicted_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t+1] | x[:t]) - class YType(NamedTuple): observation: jax.Array # [observation_size] observation at time t mask: jax.Array # mask at time t - t: jax.Array # time index num_time = np.shape(observations)[0] - def _filter_op(carry: Carry, y: YType) -> Tuple[Carry, FilterResult]: + def _filter_op(filter_state: IncrementalFilterState, y: YType) -> Tuple[IncrementalFilterState, FilterResult]: """ A single step of the forward equations. """ @@ -268,111 +412,38 @@ def _filter_op(carry: Carry, y: YType) -> Tuple[Carry, FilterResult]: # Note: We perform update FIRST, then predict which is contrary to the usual order. # This is so that the filter results naturally align with the smoothing operation. - Hop = self.get_observation_jacobian(t=y.t, observation_size=y.observation.size) - H = Hop(carry.predicted_mean) - if self.materialise_jacobians: - H = H.to_dense() - - # Update step, compute p(z[t] | x[:t]) from p(z[t] | x[:t-1]) - observation_dist = self.observation_fn(carry.predicted_mean, y.t) - R = observation_dist.covariance() - # Push-forward the prior (i.e. predictive) distribution to the observation space - x_expectation = observation_dist.mean() # [observation_size] - tmp_H_P = H @ carry.predicted_cov - S = tmp_H_P @ H.T + R # [observation_size, observation_size] - S_chol = jnp.linalg.cholesky(S) # [observation_size, observation_size] - - marginal_dist = tfpd.MultivariateNormalTriL(x_expectation, S_chol) - - # Compute the log marginal likelihood p(x[t] | x[:t-1]) - log_marginal_likelihood = marginal_dist.log_prob(y.observation) - - # Compute the Kalman gain - # K = predict_cov @ H.T @ inv(S) = predict_cov.T @ H.T @ inv(S) - K = hpsd_solve(S, tmp_H_P, cholesky_matrix=S_chol).T # [latent_size, observation_size] - - # Update the state estimate - filtered_mean = carry.predicted_mean + K @ (y.observation - x_expectation) # [latent_size] - - # Update the state covariance using Joseph's form to ensure positive semi-definite - # tmp_factor = (I - K @ H) - if self.more_data_than_params: - tmp_factor = _efficient_add_scalar_diag((- K) @ H, 1.) - else: - tmp_factor = _efficient_add_scalar_diag(K @ (-H), 1.) - filtered_cov = tmp_factor @ carry.predicted_cov @ tmp_factor.T + K @ R @ K.T # [latent_size, latent_size] - - # When masked, then the filtered state is the predicted state. - if mask is not None: - filtered_mean = lax.select(*jnp.broadcast_arrays(y.mask, - carry.predicted_mean, filtered_mean)) - filtered_cov = lax.select(*jnp.broadcast_arrays(y.mask, - carry.predicted_cov, filtered_cov)) - log_marginal_likelihood = lax.select(*jnp.broadcast_arrays(y.mask, - jnp.zeros_like(log_marginal_likelihood), - log_marginal_likelihood)) + updated_filter_state, marginal_dist = self.incremental_update( + filter_state=filter_state, + observation=y.observation, + mask=y.mask if mask is not None else None + ) # Predict step, compute p(z[t+1] | x[:t]) - - Fop = self.get_transition_jacobian(t=y.t + 1) - F = Fop(filtered_mean) - if self.materialise_jacobians: - F = F.to_dense() - - predicted_dist = self.transition_fn(filtered_mean, y.t + 1) - predicted_mean = predicted_dist.mean() # [latent_size] - Q = predicted_dist.covariance() # [latent_size, latent_size] - predicted_cov = F @ filtered_cov @ F.T + Q # [latent_size, latent_size] - - # Update cumulative marginal likelihood - log_cumulative_marginal_likelihood = carry.log_cumulative_marginal_likelihood + log_marginal_likelihood - - return Carry( - log_cumulative_marginal_likelihood=log_cumulative_marginal_likelihood, - predicted_mean=predicted_mean, - predicted_cov=predicted_cov - ), FilterResult( - t=y.t, - log_cumulative_marginal_likelihood=log_cumulative_marginal_likelihood, - filtered_mean=filtered_mean, - filtered_cov=filtered_cov, - predicted_mean=predicted_mean, - predicted_cov=predicted_cov, - observation_mean=x_expectation, - observation_cov=S + predicted_filter_state = self.incremental_predict(filter_state=updated_filter_state) + + return predicted_filter_state, FilterResult( + t=updated_filter_state.t, + log_cumulative_marginal_likelihood=updated_filter_state.log_cumulative_marginal_likelihood, + filtered_mean=updated_filter_state.filtered_mean, + filtered_cov=updated_filter_state.filtered_cov, + predicted_mean=predicted_filter_state.filtered_mean, + predicted_cov=predicted_filter_state.filtered_cov, + observation_mean=marginal_dist.mean(), + observation_cov=marginal_dist.covariance() ) - # Push forward initial state to create p(z[1] | z[0]) - t0 = jnp.asarray(t0, jnp.int32) - t1 = t0 + jnp.asarray(1, jnp.int32) - init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) - - init_Fop = self.get_transition_jacobian(t1) - init_F = init_Fop(self.initial_state_prior.mean()) - if self.materialise_jacobians: - init_F = init_F.to_dense() - init_predicted_mean = init_predict_dist.mean() - init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() - - # init_predicted_mean = self.initial_state_prior.mean() - # init_predicted_cov = self.initial_state_prior.covariance() - init_result = Carry( - log_cumulative_marginal_likelihood=jnp.asarray(0.), - predicted_mean=init_predicted_mean, - predicted_cov=init_predicted_cov - ) + filter_state = self.create_initial_filter_state(t0=t0) if mask is None: _mask = jnp.zeros(num_time, dtype=jnp.bool_) # dummy variable (we skip the mask select) else: _mask = mask xs = YType( observation=observations, - mask=_mask, - t=jnp.arange(1, num_time + 1, dtype=jnp.int32) + t0 + mask=_mask ) final_accumulate, filter_results = lax.scan( f=_filter_op, - init=init_result, + init=filter_state, xs=xs ) if marginal_likelihood_only: @@ -393,7 +464,7 @@ def log_prob(self, observations: jax.Array, mask: Optional[jax.Array] = None) -> return self.forward_filter(observations, mask, marginal_likelihood_only=True) def posterior_marginals(self, observations: jax.Array, mask: Optional[jax.Array] = None, - t0: Union[jax.Array, int] = 0) -> Union[ + t0: Union[jax.Array, float] = 0.) -> Union[ SmoothingResult, Tuple[SmoothingResult, InitialPrior]]: """ Compute the posterior marginal distributions of the latents, p(z[t] | x[:T]). @@ -443,7 +514,8 @@ def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: # Compute the RTS smoother gain # J = y.filtered_cov @ F.T @ jnp.linalg.inv(y.predicted_cov) - predicted_cov_chol = jnp.linalg.cholesky(y.predicted_cov) # Possibly need to add a small diagonal jitter + predicted_cov_chol = lax.linalg.cholesky(y.predicted_cov, + symmetrize_input=False) # Possibly need to add a small diagonal jitter tmp_F_P = F @ y.filtered_cov J = hpsd_solve(y.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T @@ -493,7 +565,7 @@ def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: init_predicted_mean = init_predict_dist.mean() init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() - t0 = t1 - jnp.asarray(1, jnp.int32) + t0 = t1 - jnp.asarray(self.dt, t1.dtype) y = FilterResult( t=t0, log_cumulative_marginal_likelihood=jnp.asarray(0.), From 3489c36274516eaea2a447391b371ffb3450caac Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:09:19 +0200 Subject: [PATCH 2/8] * pytree utils --- essm_jax/pytee_utils.py | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 essm_jax/pytee_utils.py diff --git a/essm_jax/pytee_utils.py b/essm_jax/pytee_utils.py new file mode 100644 index 0000000..71626d9 --- /dev/null +++ b/essm_jax/pytee_utils.py @@ -0,0 +1,43 @@ +from typing import Tuple, Callable, TypeVar + +import jax +import jax.numpy as jnp +import numpy as np +from jax import lax + +PT = TypeVar('PT') + + +def pytree_unravel(example_tree: PT) -> Tuple[Callable[[PT], jax.Array], Callable[[jax.Array], PT]]: + """ + Returns functions to ravel and unravel a pytree. + + Args: + example_tree: a pytree to be unravelled + + Returns: + ravel_fun: a function to ravel the pytree + unravel_fun: a function to unravel + """ + leaf_list, tree_def = jax.tree.flatten(example_tree) + + sizes = [np.size(leaf) for leaf in leaf_list] + shapes = [np.shape(leaf) for leaf in leaf_list] + dtypes = [leaf.dtype for leaf in leaf_list] + + def ravel_fun(pytree: PT) -> jax.Array: + leaf_list, tree_def = jax.tree.flatten(pytree) + # promote types to common one + common_dtype = jnp.result_type(*dtypes) + leaf_list = [leaf.astype(common_dtype) for leaf in leaf_list] + return jnp.concatenate([lax.reshape(leaf, (size,)) for leaf, size in zip(leaf_list, sizes)]) + + def unravel_fun(flat_array: jax.Array) -> PT: + leaf_list = [] + start = 0 + for size, shape, dtype in zip(sizes, shapes, dtypes): + leaf_list.append(lax.reshape(flat_array[start:start + size], shape).astype(dtype)) + start += size + return jax.tree.unflatten(tree_def, leaf_list) + + return ravel_fun, unravel_fun From 30d48223c9c140da53348c35307ec4b575f3fb45 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:11:34 +0200 Subject: [PATCH 3/8] * update change log --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f1fa846..e668498 100644 --- a/README.md +++ b/README.md @@ -67,8 +67,7 @@ print(smooth_result) forward_samples = essm.forward_simulate( key=jax.random.PRNGKey(0), num_time=25, - observations=samples.observation, - mask=mask + filter_result=filter_result ) import pylab as plt @@ -86,9 +85,14 @@ plt.legend() plt.show() ``` +## Online Filtering + +Take a look at [examples](./docs/examples) to learn how to do online filtering, for interactive application. + # Change Log 13 August 2024: Initial release 1.0.0. +14 August 2024: 1.0.1 released. Added sparse util. Add incremental API for online filtering. Arbitrary dt. ## Star History From cb0ac8146de6656d78a0b4048f2a292f838cc731 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 14 Aug 2024 13:11:46 +0200 Subject: [PATCH 4/8] * add sparse utils --- essm_jax/sparse.py | 68 +++++++++++++++++ essm_jax/tests/test_essm.py | 133 ++++++++++++++++++++++++++++++++-- essm_jax/tests/test_sparse.py | 31 ++++++++ 3 files changed, 226 insertions(+), 6 deletions(-) create mode 100644 essm_jax/sparse.py create mode 100644 essm_jax/tests/test_sparse.py diff --git a/essm_jax/sparse.py b/essm_jax/sparse.py new file mode 100644 index 0000000..4b83001 --- /dev/null +++ b/essm_jax/sparse.py @@ -0,0 +1,68 @@ +from typing import Tuple, NamedTuple + +import jax +import jax.numpy as jnp +import numpy as np + + +class SparseRepresentation(NamedTuple): + shape: Tuple[int, ...] + rows: jax.Array + cols: jax.Array + vals: jax.Array + + +def create_sparse_rep(m: np.ndarray) -> SparseRepresentation: + """ + Creates a sparse rep from matrix m. Use in linear models with materialise_jacobian=False for 2x speed up. + + Args: + m: [N,M] matrix + + Returns: + sparse rep + """ + rows, cols = np.where(m) + sort_indices = np.lexsort((cols, rows)) + rows = rows[sort_indices] + cols = cols[sort_indices] + return SparseRepresentation( + shape=np.shape(m), + rows=jnp.asarray(rows), + cols=jnp.asarray(cols), + vals=jnp.asarray(m[rows, cols]) + ) + + +def to_dense(m: SparseRepresentation, out: jax.Array | None = None) -> jax.Array: + """ + Form dense matrix. + + Args: + m: sparse rep + out: output buffer + + Returns: + out + M + """ + if out is None: + out = jnp.zeros(m.shape, m.vals.dtype) + + return out.at[m.rows, m.cols].add(m.vals, unique_indices=True, indices_are_sorted=True) + + +def matvec_sparse(m: SparseRepresentation, v: jax.Array, out: jax.Array | None = None) -> jax.Array: + """ + Compute matmul for sparse rep. Speeds up large sparse linear models by about 2x. + + Args: + m: sparse rep + v: vec + out: output buffer to add to. + + Returns: + out + M @ v + """ + if out is None: + out = jnp.zeros(m.shape[0]) + return out.at[m.rows].add(m.vals * v[m.cols], unique_indices=True, indices_are_sorted=True) diff --git a/essm_jax/tests/test_essm.py b/essm_jax/tests/test_essm.py index c8d3dc2..da7f0cf 100644 --- a/essm_jax/tests/test_essm.py +++ b/essm_jax/tests/test_essm.py @@ -1,6 +1,11 @@ import time import jax +import pytest + +from essm_jax.sparse import create_sparse_rep, matvec_sparse + +jax.config.update('jax_enable_x64', True) import numpy as np import tensorflow_probability.substrates.jax as tfp from jax import numpy as jnp @@ -217,19 +222,21 @@ def observation_fn(z, t): ) sample = essm.sample(jax.random.PRNGKey(0), 1000) - filter_fn = jax.jit(lambda: essm.forward_filter(sample.observation)).lower().compile() - filter_jvp_fn = jax.jit(lambda: essm_jvp.forward_filter(sample.observation)).lower().compile() + filter_fn = jax.jit( + lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() + filter_jvp_fn = jax.jit( + lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() t0 = time.time() filter_results = filter_fn() - filter_results.t.block_until_ready() + filter_results.block_until_ready() t1 = time.time() dt1 = t1 - t0 print(f"Time for essm: {t1 - t0}") t0 = time.time() filter_results_jvp = filter_jvp_fn() - filter_results_jvp.t.block_until_ready() + filter_results_jvp.block_until_ready() t1 = time.time() dt2 = t1 - t0 print(f"Time for essm_jvp: {t1 - t0}") @@ -282,8 +289,7 @@ def observation_fn(z, t): forward_samples = essm.forward_simulate( key=jax.random.PRNGKey(0), num_time=25, - observations=samples.observation, - mask=mask + filter_result=filter_result ) try: @@ -313,3 +319,118 @@ def test__efficienct_add_scalar_diag(): A = jnp.eye(100) c = 1. assert jnp.all(_efficient_add_scalar_diag(A, c) == A + c * jnp.eye(100)) + + +def test_incremental_filtering(): + def transition_fn(z, t): + mean = z + z * jnp.sin(2 * jnp.pi * t / 10) + cov = 0.1 * jnp.eye(np.size(z)) + return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) + + def observation_fn(z, t): + mean = z + cov = t * 0.01 * jnp.eye(np.size(z)) + return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) + + n = 1 + + initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n)) + + essm = ExtendedStateSpaceModel( + transition_fn=transition_fn, + observation_fn=observation_fn, + initial_state_prior=initial_state_prior, + materialise_jacobians=False, # Fast + more_data_than_params=False # if observation is bigger than latent we can speed it up. + ) + samples = essm.sample(jax.random.PRNGKey(0), 100) + + filter_result = essm.forward_filter(samples.observation) + + filter_state = essm.create_initial_filter_state() + + import pylab as plt + + for i in range(100): + filter_state, _ = essm.incremental_update(filter_state, samples.observation[i]) + plt.scatter(filter_state.t, filter_state.filtered_mean, c='black') + assert filter_state.t == filter_result.t[i] + filter_state = essm.incremental_predict(filter_state) + # print(i, np.abs( + # filter_state.log_cumulative_marginal_likelihood - filter_result.log_cumulative_marginal_likelihood[i])) + # print(i, np.max(np.abs(filter_state.filtered_mean - filter_result.predicted_mean[i]))) + # print(i, np.max(np.abs(filter_state.filtered_cov - filter_result.predicted_cov[i]))) + # print(i, np.max(np.abs(filter_state.filtered_cov))) + np.testing.assert_allclose(filter_state.log_cumulative_marginal_likelihood, + filter_result.log_cumulative_marginal_likelihood[i], atol=1e-5) + np.testing.assert_allclose(filter_state.filtered_mean, filter_result.predicted_mean[i], atol=1e-5) + np.testing.assert_allclose(filter_state.filtered_cov, filter_result.predicted_cov[i], atol=1e-5) + + plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent') + plt.legend() + plt.show() + + +@pytest.mark.parametrize('use_sparse', [False, True]) +def test_performance_sparse(use_sparse: bool): + # Show that using sparse rep speeds up linear system + n = 128 + k = 10 + m = np.zeros((n, n)) + rows = np.random.randint(n, size=k) + cols = np.random.randint(n, size=k) + m[rows, cols] += 1. + + if use_sparse: + m = create_sparse_rep(m) + else: + m = jnp.asarray(m) + + def transition_fn(z, t): + if use_sparse: + mean = matvec_sparse(m, z) + else: + mean = m @ z + scale = jnp.ones(np.size(z)) + return tfpd.MultivariateNormalDiag(mean, scale) + + def observation_fn(z, t): + mean = z + scale = jnp.ones(np.size(z)) + return tfpd.MultivariateNormalDiag(mean, scale) + + initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n)) + + essm = ExtendedStateSpaceModel( + transition_fn=transition_fn, + observation_fn=observation_fn, + initial_state_prior=initial_state_prior, + materialise_jacobians=True + ) + + essm_jvp = ExtendedStateSpaceModel( + transition_fn=transition_fn, + observation_fn=observation_fn, + initial_state_prior=initial_state_prior, + materialise_jacobians=False + ) + + sample = essm.sample(jax.random.PRNGKey(0), 512) + filter_fn = jax.jit( + lambda: essm.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() + filter_jvp_fn = jax.jit( + lambda: essm_jvp.forward_filter(sample.observation, marginal_likelihood_only=True)).lower().compile() + + t0 = time.time() + filter_results = filter_fn() + filter_results.block_until_ready() + t1 = time.time() + dt1 = t1 - t0 + print(f"Time for essm(use_sparse={use_sparse}): {t1 - t0}") + + t0 = time.time() + filter_results_jvp = filter_jvp_fn() + filter_results_jvp.block_until_ready() + t1 = time.time() + dt2 = t1 - t0 + print(f"Time for essm_jvp(use_sparse={use_sparse}): {t1 - t0}") diff --git a/essm_jax/tests/test_sparse.py b/essm_jax/tests/test_sparse.py new file mode 100644 index 0000000..32ef1c2 --- /dev/null +++ b/essm_jax/tests/test_sparse.py @@ -0,0 +1,31 @@ +import jax +import numpy as np +from jax import numpy as jnp + +jax.config.update('jax_enable_x64', True) + +from essm_jax.sparse import create_sparse_rep, matvec_sparse, to_dense + + +def test_sparse_rep(): + m = np.asarray([[1., 0, 0], + [-1., 2., 0.], + [0., 0., 5.]]) + rep = create_sparse_rep(m) + v = jnp.asarray([1, 1, 1]) + np.testing.assert_allclose(matvec_sparse(rep, v), m @ v) + + m = np.random.normal(size=(100, 100)) + v = np.random.normal(size=100) + + rep = create_sparse_rep(m) + np.testing.assert_allclose(matvec_sparse(rep, v), m @ v) + + +def test_to_dense(): + m = np.asarray([[1., 0, 0], + [-1., 2., 0.], + [0., 0., 5.]]) + rep = create_sparse_rep(m) + M = to_dense(rep) + np.testing.assert_allclose(M, m) From bf755c10d671612731c31ccea6f3f10498b6a908 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:36:05 +0300 Subject: [PATCH 5/8] * Add dt * Add examples --- README.md | 2 +- ...excitable_damped_harmonic_oscillator.ipynb | 317 ++++++++++++++++++ docs/examples/online_filtering.ipynb | 196 +++++++++++ essm_jax/essm.py | 168 +++++----- essm_jax/tests/test_essm.py | 44 +-- 5 files changed, 624 insertions(+), 103 deletions(-) create mode 100644 docs/examples/excitable_damped_harmonic_oscillator.ipynb create mode 100644 docs/examples/online_filtering.ipynb diff --git a/README.md b/README.md index e668498..af2d27e 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ from essm_jax.essm import ExtendedStateSpaceModel tfpd = tfp.distributions -def transition_fn(z, t): +def transition_fn(z, t, t_next): mean = z + jnp.sin(2 * jnp.pi * t / 10 * z) cov = 0.1 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) diff --git a/docs/examples/excitable_damped_harmonic_oscillator.ipynb b/docs/examples/excitable_damped_harmonic_oscillator.ipynb new file mode 100644 index 0000000..7188573 --- /dev/null +++ b/docs/examples/excitable_damped_harmonic_oscillator.ipynb @@ -0,0 +1,317 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ef3eaff3797489e6", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:28:24.816871897Z", + "start_time": "2024-08-14T22:28:23.676057239Z" + } + }, + "outputs": [], + "source": [ + "import dataclasses\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import pylab as plt\n", + "import tensorflow_probability.substrates.jax as tfp\n", + "\n", + "from essm_jax.essm import ExtendedStateSpaceModel\n", + "\n", + "tfpd = tfp.distributions\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's define the dampled harmonic oscillator\n", + "\n", + "$$m \\ddot x= - c \\dot x - k x + f(t)$$\n", + "$$\\iff \\ddot x = - 2 \\zeta \\omega \\dot x - \\omega^2 x + \\frac{f(t)}{m}$$\n", + "\n", + "with $\\omega = \\frac{k}{m} \\in \\mathbb{R}^+$, $\\zeta = \\frac{c}{2 \\sqrt{mk}} \\in \\mathbb{R}^+$ and $f(t) \\sim \\mathcal{N}[0, \\sigma_f^2]$.\n", + "\n", + "Let us make this non-linear by supposing that $M = \\log m \\sim \\mathcal{N}[\\bar{M}, \\sigma_M^2]$.\n", + "\n", + "This defines a state,\n", + "\n", + "$$ z = [x, v, M]$$\n", + "\n", + "transition mean function,\n", + "\n", + "$$ z_{t}, t \\to [x + v \\Delta t, v + (-c v - k x) \\Delta t, \\bar M]$$\n", + "\n", + "transition noise scale function,\n", + "\n", + "$$ z_{t}, t \\to [0, \\sigma_f, \\sigma_M]$$\n", + "\n", + "observation mean function,\n", + "\n", + "$$ z_{t}, t \\to [x, 0, 0]$$\n", + "\n", + "and observation noise scale function,\n", + "\n", + "$$ z_{t}, t \\to [\\sigma_x, 0, 0]$$\n" + ], + "metadata": { + "collapsed": false + }, + "id": "716c02a70cfecdff" + }, + { + "cell_type": "code", + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAGdCAYAAAD60sxaAAAAP3RFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMS5wb3N0MSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8kixA/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3QU1d/Gn9ndJJseSCAJSSCE3gOhS5MuWFBUUBQLir2hqPizvWLvDUVRsSKKBRWkiTTphN47SUghvWfrvH/czE7flm1J7uecnMxOvVtm5plvZViWZUGhUCgUCoXSjND4ewAUCoVCoVAovoYKIAqFQqFQKM0OKoAoFAqFQqE0O6gAolAoFAqF0uygAohCoVAoFEqzgwogCoVCoVAozQ4qgCgUCoVCoTQ7qACiUCgUCoXS7ND5ewCBitVqRW5uLiIjI8EwjL+HQ6FQKBQKxQlYlkVlZSXatGkDjUbdzkMFkAq5ublISUnx9zAoFAqFQqG4QXZ2NpKTk1WXUwGkQmRkJADyAUZFRfl5NBQKhUKhUJyhoqICKSkptvu4GlQAqcC5vaKioqgAolAoFAqlkeEofIUGQVMoFAqFQml2UAFEoVAoFAql2UEFEIVCoVAolGYHFUAUCoVCoVCaHVQAUSgUCoVCaXZQAUShUCgUCqXZQQUQhUKhUCiUZgcVQBQKhUKhUJodVABRKBQKhUJpdlABRKFQKBQKpdlBBRCFQqFQKJRmBxVAFAqFQqFQmh1UAFEoFAqFQvEKmzcDn3/u71Eo4xMBtGDBAqSmpkKv12PQoEHYtWuX3fWXLVuGrl27Qq/Xo1evXvj7779Fy3/77TeMHz8esbGxYBgG+/fvl+1j1KhRYBhG9Hfvvfd68m1RKBQKhUKxw8iRwD33AIF4+/W6APrpp58wZ84cvPDCC9i7dy/69OmDCRMm4NKlS4rrb9u2DTfddBNmzZqFffv2YcqUKZgyZQoOHz5sW6e6uhrDhg3DG2+8YffYd999N/Ly8mx/b775pkffG4VCoVAoFGUOHuSnP/sMMJv9NxYlGJZlWW8eYNCgQRgwYAA+/vhjAIDVakVKSgoeeughPP3007L1p02bhurqaqxYscI2b/DgwUhPT8fChQtF654/fx7t27fHvn37kJ6eLlo2atQopKen4/3333dr3BUVFYiOjkZ5eTmioqLc2geFQqFQKM0VhhG/LisDoqO9f1xn799etQAZjUZkZmZi7Nix/AE1GowdOxbbt29X3Gb79u2i9QFgwoQJquvb44cffkBcXBx69uyJefPmoaamRnVdg8GAiooK0R+FQqFQKBTPUF3t7xGI0Xlz50VFRbBYLIiPjxfNj4+Px/HjxxW3yc/PV1w/Pz/fpWPffPPNaNeuHdq0aYODBw/iqaeewokTJ/Dbb78prv/aa6/h//7v/1w6BoVCoVAoFOdoVgLIn8yePds23atXLyQmJmLMmDE4c+YMOnToIFt/3rx5mDNnju11RUUFUlJSfDJWCoVCoVCaOs1KAMXFxUGr1aKgoEA0v6CgAAkJCYrbJCQkuLS+swwaNAgAcPr0aUUBFBISgpCQkAYdg0KhUCgUijKBJoC8GgMUHByMjIwMrF+/3jbParVi/fr1GDJkiOI2Q4YMEa0PAOvWrVNd31m4VPnExMQG7YdCoVAoFIrrBJoA8roLbM6cObjtttvQv39/DBw4EO+//z6qq6txxx13AABmzpyJpKQkvPbaawCARx55BCNHjsQ777yDyZMnY+nSpdizZw8+F1RSKikpQVZWFnJzcwEAJ06cAECsRwkJCThz5gyWLFmCSZMmITY2FgcPHsRjjz2GESNGoHfv3t5+yxQKhUKhNGusVvm8qirfj8MeXhdA06ZNQ2FhIZ5//nnk5+cjPT0dq1evtgU6Z2VlQaPhDVFDhw7FkiVL8Oyzz+KZZ55Bp06dsHz5cvTs2dO2zp9//mkTUAAwffp0AMALL7yAF198EcHBwfjnn39sYislJQVTp07Fs88+6+23S6FQKBRKs0cp6TrQLEBerwPUWKF1gCgUCoVCcY/8fEAacfLpp76pCB0QdYAoFAqFQqE0P5SsPbW1vh+HPagAolAoFAqF4lE4sdOqFTBrlnheoEAFEIVCoVAoFI9x8SJQVESmQ0MBvZ5MB5oAarKFECkUCoVCofiWnBxAWEM4NJT8AUBdnX/GpAa1AFEoFAqFQvEIGzaIXwsFUKBZgKgAolAoFAqF4hGkeeVUAFEoFAqFQmnyKAmgQI0BogKIQqFQKBSKR7BnAaIxQBQKhUKhUJok0hYY1AVGoVAoFAqlySMVOYHsAqNp8BQKhUKhUBpETQ2QmQlUVIjn6/WBawGiAohCoVAoFEqDmDoVWL0aCAkRz6cuMAqFQqFQKE2W1avJf4NBPD8yEggLI9NUAFEoFAqFQmkWtGjBC6CaGv+ORQoVQBQKhUKhULyCUAApdYj3J1QAUSgUCoVC8QrUAkShUCgUGUcLj+L6n6/HoYJD/h4KpQlRaaj09xBstGgBhIeTaZOJ/AUKVABRKBSKn7j+5+vx67Ff0Xthb2w6v8nfw6E0Ad7d/i6iXo/CT4d/8vdQAIgtQEBgBUJTAUShUCh+4ljRMdv0qG9GYcO5DeorUygOMFqMeHzt4wCA6b9OR7XRN0E3L7ygviw+nqTGMwx5HUhxQFQAUSgUip/QacSl2JYcWuKnkVCaAlIB/d6O93xy3JdeUl8WH0/ETyDGAVEBRGm2sCwLo9nqeEUKxUtEhUSJXhfVFvlpJJSmwJ7cPQCANpFtMHfoXMwdOtfPI+ItP1wcEBVAFEoAsOZIPhZsOI3CSoPjlSkUD1NnrkOtiQREJEclAwBKakv8OSRKI+dc2TkAwOx+s/HmuDcRogtxsEXDsViU5/fqBaxfz78OxFR4KoAozZZjeSRTYm9WqZ9HQmmOfH/we9SaaxGjj8HHV3wMgAogSsM4X3YeANC+RXufHbOqSnn+Dz8Ao0fzr7mGqHV13h+Ts1ABRGn2MP4eAKVZsvz4cgDAnMFz0Da6LQAqgCjuU1pbil0XdwEA0lqk+ey45eXK8zUSdaGrD3dTsxj5A9oMldLsYRgqgSi+pdpYjX/O/gMAuLbbtYgMjgRABRDFPYwWI9I+TEOlkVi1u8V189mxpd3fOdQEkNns3fG4ArUAUZo9VP5QfM3evL0wWAxoE9kGPVr1QFxYHAASF0RFEMVVssqzUFZXZnsdGxbrs2MrCaBJk4CuXcXzqACiUAIQagCi+Bqu/k/v+N5gGAbhweG2ZYO+GOSvYVEaKRcrLtqm3x3/rk+PXalQdHrlSvl1lQogCoVCaWbkVebhvhX34WzpWWSXZ2Pi9xNxz4p7ACi7Kk6XnPb1ECmNnIuVRACNSh2Fx4Y85tNjGyRJtN99p7xeIAogGgNEafZQCxDFm9yz4h78dfIvLD2yFO2i2+FAwQHbssHJgxW3MVlMCNIG+WqIlEbOmjNrAPDlFHyJ0chP//ILMHWq8npSAbR2LVBaCgwZArRt690xqkEtQJRmDwMGR3MrsOlkIViW9fdwKE2MzLxMAEBZXZlI/GgZLUaljrK93n/Pftt0yzdb2moEUSiOWH16NQBgSpcpLm/77YFv0eOTHjhZfNKtY3MCaMwYdfED8AKIa4b65pvA9OnAli1uHdYjUAFEoYAURdx7oRQXigOoTCmlSRAfHi96PaDNALw6+lVsuWMLWoe3ts3vk9DHNl1lrLKlyVMo9rCyVhTXFAMAhqQMcXo7s9WM17a8htuW34ajhUcx8/eZbh2fE0DBwfbXC6o3aHIWoLIy8j8mxq3DegTqAqNQBC4wA22NQfEgZ0rOYF/+PtG8O9LvwH0D7lNc/4buN2DZ0WUAgG8Pfoubet3k9TFSGi8miwlDvxoKC0uK63DZhPYoqinCI6sfQZ25Dr8d+802f+fFnW6NwVkBJHWBcfWDoqPdOqxHoBYgSrNHGAKkofFAFA/y2n+vyeZN7jxZdf2l1y9F5mziMlt7Zi1yK3O9NjZK42fXxV22/l8AEKx1oEIAPLzqYSw5tEQkfjgsVterFLorgALBAkQFEIUigBZFpHiS3bm7AQB39b0LWkaLqzpfZav6rISG0aBfYj+kJ6TDylqxM8e9p3JK82B7znaXt/n33L+qy44UHnF5f+4IIJalAohCCQisgsBnNQtQtcGM05cqaZA0xWlYlrX1ZpozZA5yH8/FLzf+4tS2bSLbAABK62ifOoo6F8ou2Kav7nK14jrZ5dmY/ddsHLl0BFf8cAUKqgtU97fh3AaXx+COAKqt5S1BNAaIQvExQiFjFYT9aFQsQBtPFOJkQSX6to1BYnQozhVVY2y31tBp6TMERZmyujJUGEiZ3HYx7RAWFOb0ti30LQCIW2OwLIudF3eib0Jfn3T5pgQ2/5z9Bx/vJk10Hxr4EN4c96bieumfpaOktgSZeZnYm7fX7j63Zm/FI4MfcWkc7gggzvqj1QLh4aqbeB169aY0S6wCQ47ZKrQAKQugkwWk3Om+rDL8fSgPx/IqcPCiShdACgW8O6FNZBuXxA8AtAxtCYA0uLSyVuzL24f5m+djyJdD8OS6Jz0+Vkrj4mDBQYz7bpzt9aCkQdDr9KJ11p5ZK2qt4kj8AL5zgZ05Q6ZbtfJvHTYqgCjNEqHbS21aiJIwqjEEUFtjSsCxI2cHAHJzchXOAlRaV4qv9n2Ffp/3wwsbXwAAfLjrQ7eCVSlNg315+9BnYR/RPGn214qTKzDh+wlo93471f0MTBpom+bE08nik3h1y6t4adNLTo/HHQG0fj2Zvvxypw/jFagAojRLhELHYnUsgJQ8XSxoPBBFnbOlZwEAPVv3dHlbmwWorhTzN8+XLdfN1+GvE381bICURsmr/70qmyd1iVYbqwEAl6ovKe7jvzv+w867diLnsRw8MeQJHLn/CGL0MTBbzfjfv//DCxtfQH5VvlPjcUcAFRaS6U6dnDqE16ACiNIsOZbHd/ATCiA1SaNRiI62Uv1DsUNxLSlO50xtFincNgVVBTZrkJSrl16NrPIs9wdIaZQU1RTZpq/rdh16x/eWtVSZ1nOaalA0AFzW9jIAQFJUEt4a/xbSWqQhPSFd9Tj2cEcA1dWRab1efX1fQAUQpVmy4Tj/ZCSMAVLL8tIquMBYlkV5rQkbT1xCeY3J84OkNGq42AvOmuMK7Vu0BwDsz9+Po4VHAQDzL5+PYw8cE63X7v12+OnwT9hyYQvNUGwGsCyLgwUHAQB77t6DX2/8FQfuPSCL/wHU22K8NErZvTWy3UjRa2EAvj3cEUBcA9UQP8fyUwFEafaYLXwamJpVR6tgAWIB/L43B/uyyvDHgYteGh2lscLdQGJDY13eNq1FGgDiAjNZTegc2xnPjngWXeO6ytad/ut0jPh6BFadXtWwAVMCntzKXJTUlkDDaNCjdQ+7696efju+vPpLrJqxClEhUQCA50Y8h2dHPKu4/t397kZiRKLttTcFELUAUZo82dmApRHEauaV19mm1R6iFbPDWKC03vJTXGWUL6c0G57f8Dyu+OEK/Jf1H9aeWYulh5fasm7csQDFh8cjPIjPD+4W1802vfLmlXhs8GMy19rhS4fdHD2lscBZf7rEdlG0+ghhGAZ39r0TEztOxKH7DuHd8e/if8P/p1rsNSkqCRfnXMQVHa8A4HkBJOwFFigWIFoHiOIVVqwArroKuO024Ouv/T0a5+GCoA1mC3aeLUHXhEi0jlK+0NAgaEqVsQrPrH8GH+36CADflVuIOwKIYRiktUjDoUuHAEDUNHVSp0mY1GkSpvWYhsFf8rEfXENMStOFE0DCxrnO0Da6LR4b8pjD9RiGQWwYsVg6+3uiFiAKRcKLL5L/33zj12G4jJVlUVptxCcbziDzQil+2Jllmy9bl/ZNbfZM+2WaTfyo0bFlR7f2zbnBALEA4hiUPAivj3nd9vpSjXLGD6XpcPASEUC9W/f22jFSo1MBAE/+8yTqzHX2VwZQTRLOEOag1FWzFUALFixAamoq9Ho9Bg0ahF27dtldf9myZejatSv0ej169eqFv//+W7T8t99+w/jx4xEbGwuGYbB//37ZPurq6vDAAw8gNjYWERERmDp1KgoK1EuAUzxLY3B9KcGywM97smXzLQrBQcI5tIVY86PaWI2/T/1td53Xx7zudn+5zrGdbdNKAggAnhj6BGb3mw2A1B06kH/ArWNRGgecBah3vPcE0Oj2o23TL216CSeKTthd39mu7s0yCPqnn37CnDlz8MILL2Dv3r3o06cPJkyYgEuXlJ9Wtm3bhptuugmzZs3Cvn37MGXKFEyZMgWHD/P+7erqagwbNgxvvPGG6nEfe+wx/PXXX1i2bBk2bdqE3NxcXHfddR5/fxRluD4v9jh8sRwXy2q9PxgJ9rJlckprUGOUqzel4OiiKoNtWilLjNK0+fXYr4rzx6aNhU5DrvbjOoxTXMcZxncYb5vu0KKD4jpajRbXdrsWAHC86DjSP0vH1/u/xpVLrsRrW+Sd6CmNF4PZgONFxwF4VwCNSh1lm37tv9fQdUFXHCs8pro+19bCFQEUKBYgr8cAvfvuu7j77rtxxx13AAAWLlyIlStX4quvvsLTTz8tW/+DDz7AxIkTMXfuXADA/PnzsW7dOnz88cdYuHAhAODWW28FAJw/f17xmOXl5fjyyy+xZMkSjB5N1OzixYvRrVs37NixA4MHD1bcjuI5HFmAcstqse4oscg9Nq6z/ZU9jMGs7rsS1gcSouQCyxcETyvVCaI0bXZf3G2bntFrBqqMVYjRx2DxNYtRYajA+bLzLsdqCLk89XLc3e9uxIXF4YpOV6iuJ7xhAcAdf5Br7cpTKzFv+Dy3j08JHH489CNOlZyC2WpGjD4GyVHJXjsWwzB4fMjjeGf7O7Z53xz4Bq+PfV1xfVctQCZT4FiAvCqAjEYjMjMzMW8efxJqNBqMHTsW27dvV9xm+/btmDNnjmjehAkTsHz5cqePm5mZCZPJhLFjx9rmde3aFW3btsX27dsVBZDBYIDBwD/RV1RUOH08ihyhBYhl5S6iijrv182pMpgRHqyVuSDsCSAlao0WGB1so5QmT2m6sCyLdWfXAQB+nPojpvecLloerY9ukPgBiHXn86s+d7ieXqfHmlvWYML3E2TLjBYjgrUOolP9iNUKTJwIpKYCnzt+q82S3Mpc3PzbzbbXveN7u+1WdRapwNp0YZPqus4KoKgofv1AsQB51QVWVFQEi8WC+Ph40fz4+Hjk5yuX2c7Pz3dpfbV9BAcHIyYmxun9vPbaa4iOjrb9paSkOH08ipgzZ4BTp/jXdQpxdGpNRz3F2cIqLNp8FqsPy79vVTFTeAnYvBn4ZRnwwQfA4sXA2bP4ZW+Ow+PVGi3Ym1Xa0GFTGgnHi47jRPEJhOpCManTJH8PB+M7jEf2Y/LYtU92f+KH0TimuBj44gvSE2rdOmDRIqCUnj6KcIUwObwZAM0hDdzPq8xTXM9kAmrroxgcCaC4+qoNRUWBYwGiWWD1zJs3D+Xl5ba/7Gz5xYTiHFIDW5FCRXWhAPJGBdtd50gNi+P5cpeWwSzxz5lMwOpVwCefAhs2AEeOEsd2Vhbw/Xco+msNwDq2Gm06UeiJoVMaAVyfpNSYVFuROX+THJWMLXdsEc17bM1jWHZkmZ9GpM511wF33w3cfjs/r2VLYL687Vmzh4v74fBm/A/HsLbDRK/zq/IVr9Oc9QfgLTxqCAVQs7AAxcXFQavVyrKvCgoKkJCQoLhNQkKCS+ur7cNoNKKMi85yYj8hISGIiooS/VHcQyp4du+WryP0GCllWHmS3buBfmPL8NUfpVh7JB85pYLA66NHgQ8/AHbWZyZ26QyMHg3ccguQkUFSvTZtAr7/HrhwATAaFI9BaV5wfZLc6fPlTS5LuQxPDn1SNO/GX27Ec/8+h8LqwoBpl7F5M/mfmyue//zzvh9LoMM11eUYmjLU68eM0cfgvQnvYVbfWQAAg8WACoM8LKSmhvzX6/lCh2o0OwtQcHAwMjIysH79ets8q9WK9evXY8iQIYrbDBkyRLQ+AKxbt051fSUyMjIQFBQk2s+JEyeQlZXl0n4onmHfPvk8YcyM2csCaNR4I87WXcLcdwtxJLcC28/UF/jaswdYtgyoqgZaxADTpwPTbwKGDwc6dACuvJI8qgYHAWfPkYqOb71NBBEtgtgsYVkWxTXFASuAGIbBG+PewObbN2PesHno1boXAODlLS/jyh+vRNirYViwa4Hitq9ueRVvbX3Ll8OlOEFOBXHBx4XFYfm05Q5bYHiKRwc/ii+u/gIRwREAoNgd3hVLDieAhDFAjmoHeRuvZ4HNmTMHt912G/r374+BAwfi/fffR3V1tS0rbObMmUhKSsJrr5GUzUceeQQjR47EO++8g8mTJ2Pp0qXYs2cPPhdEyJWUlCArKwu59Y8PJ06QOgUJCQlISEhAdHQ0Zs2ahTlz5qBly5aIiorCQw89hCFDhtAMMD/w8svE3N22LT9P6ALztgXIYLEgCCQzbf16YMwYAGdOA6vqa7gMGUJmarXyjXv1AhITgM1bgLNnSdWvjRvJzkaPlq2+eOs5dI6PxGUdA+vGSPEMr2x5Bc9teM72OtAEEMfwdsMxvN1wjG4/GuO+I6n4uy4SK+eDqx7Evf3vhVbD/94PXzqM//37PwCk0rSvbrIU+1QZq7Di5AoAwMLJC3FN12t8PobkqGQcLzqOrPIsdInrIlrmigBq0YIkwwiNkOHh6uv7Aq/HAE2bNg1vv/02nn/+eaSnp2P//v1YvXq1LdA5KysLeXl8gNXQoUOxZMkSfP755+jTpw9++eUXLF++HD179rSt8+eff6Jv376YPHkyAGD69Ono27evLU0eAN577z1ceeWVmDp1KkaMGIGEhAT89ttv3n67FBWeeUb8WphW7g0LkDDGmtHw+//vP5Bg52W/kOI+6X2A8eOUxQ9HXCtiCXricWBSfcDrli3EfSahrMZkiz+iNC3K68pF4gcAWuhb+Gk0zjE2bSw+u/Iz2fzsCj7GsbyuHL0+7WV73fPTnjhfdt4Xw6M44Lblt6HWTFz2SVFJfhkDV4Nq/Pfj8dPhn0TLXBFAGo04TohhgNBQT43SPXwSBP3ggw/iwoULMBgM2LlzJwYNGmRbtnHjRnwtaRZ1ww034MSJEzAYDDh8+DAmTRJnWdx+++1gWVb29yLXfwGAXq/HggULUFJSgurqavz2228uxRFR3IPzCUv57z/g9Gn+tfApwNsWIOgkQc9LlgAGA/Rtk4ArrwLgXEZaVGgwMGAAcFm9D/6PP8RRgAJy/VDgkeJd5q6bK5t3e/rtvh+Ii1zd5WrZvFPFfJrmrb/fKlv+8uaXAyZeqLlSVleG347xD+3pCel+GYewCOf0X8XlHlwNZhYmZoeF+b+CPs0Co3gUpYwvgMQPd+rEvxZagA7klOG7HRdQXuud2kAanSSDq6wciG0J/e0zFS0/IUHKp0XXhEgyMXoMkJJCugCuWgWleKCfdtMswqaGUuXnbq26KawZWCREyB/8xn8/HsnvJmPlyZXYmr1VtvzLfV+i00edcKbkjC+GSFFAWH3575v/dtj93Vtc1eUq1WWuCqDISH7antHdV1ABRPEohQ4ywbkTRigZ9meVoajSgI0nvNPMkZFagEKCgZtuhj5GOdMvMkQ5NC5EV3+6aDTAVVcCWg1w4gRw6rTi+pSmhdkq7u/y49Qf/TQS19l420bc0P0G/G/4/2zzLlZexJU/XomSWuKy7dCiA2b2mWlbfqb0DDp+1NGphpgUz8Nlf41KHWW3Eri3GZs2Fqcf4q9xO3N24qOdH4FlWZcFkKCjFQKh1jAVQM2Mf/4hpW68hSMBxBU7UzKvG0zeaa/OaCX7HTAQiI1FWLDyI8iEHsqu0iCt4HRp1ZovePTverFPj9LkqDRU2tKAdRodbuh+g6z6cyAzMnUkfr7hZzw9TN5+CACGtx2O0w+fxkujXpItO3LpiLeHR1GAE0BpMWl+HgnQvkV72/TgLwfj4dUPY+nhpQFTz8ddqABqRhQVAePGkeQlg5fK2TgvgOTLLG6IiIKKOvySmYOCCuWnVJYFghMkcTqDBgIAwoOVLT2to5TPZp1W4rC+bBigDwHyC4Aj9CbRlLlYeREAEB0Sjcp5lfjp+p8cbBGYRARH4Jlhz8jmD04mYr5tdFvc0vsW0bIzpdQN5g/OltULoBb+F0AaRi4VDl867LIA6t/fg4PyAFQANSOE9+i8PNKHZ80aUpbeU3AxQO3bAwsUyo1wtSmVpA4XDM0FtTvDL5k5yC6pwc8qMTdWK8AESVxgEcQRHaF3rQqETiM5XUJDgaGXkekNGwCrgw6wlEYLF4+RFJUEvU7v9V5M3uSVMa/I5g1JJvXRGIbBd9d+J1omLcRH8Q1c/FUgCCAAmNBB3Gvu1f9exRcF9wBwXgD98ounR9UwqABqRgj9r3l5pP/OxImAICmvwXAWoKuuAhIT+fnayFpowgw2C5BSd3WLlQifZZk5+Gl3tlMiiOvrpZZKL+xKb6kUn6URglifAaktoWEYjOsu7kMnRLHf6aBBpJhFSQmwb7/D8VI8i8UHovPPE3/iup+vAwCMajfK68fzBWtvWWubjgiOwPB2w0XLv53yrW3a04HQgRD7EcgYzAZc8cMV2JJF2poEigD6+Yaf8eroV0XzttR8DkRlOy2A2rUjybMAcKs8+dDnUAHUjBBaenJzgaVLyfQZD17fOAEUFwf060emmWAzInpnI7LvBbsusDqTBdVGCy6W1iKvvA7Vxobf3IQCiLWKf+5CAZTWKhwPju6Inkmko9+0AfJmuBolBRQcTCpHA6RAotHY4DFTnONo4VG0eKMFXtokj1vxBHXmOvRZ2AfXLOWLz/mjEJ03GNdhHNgXWFTOq8SZh8/ICjre2udWfDPlGwC8K0bI/9b/D7P/mu2WAOVugEpIjazNBeHnuCNnB1afXm173Tm2sz+GJCMqJArzhs+D5XkLFkwSmPdTN7qU0XX11aSm7OLFnh+jqzTTn1vzRBj3k58PmM3q67oLJ4BatSJq/4cfAE0In95eVUX+K1mAaowWLNrMX2y1broZTBYrcsuIc9pSxx+btZD9WetjosMFAkjDMKL2HPEqcUCK9M8gZU6rqoCt8pRiind4fO3jqDRW4oWNL3hl/5vOb8LBgoOieb7ow+RLIoIj0Dq8teIyrv7L6RJxlmN5XTle/e9VLNq7CNtztrt8TEmLRhHBwS7vrtHz2pbXEP16NP7L+g/VxmqcKjklWt4iNLCKbWoYDe4fcD9fAys6y9bbzVnat6dp8BQfIxRAdXVi64in4GKAWrUi/4eLLeuoqeHifBzvq9poRkWd87WB8svrcKmyDsfz+A7wlqMn+RXqLUDc+xZmgUmNO0rGHqUxJ0TrAa2ORJcDwPZtpF0GgGqDGZkXSlHrAUsWRU5xjf3gtZyKHBgt7lvkSutKRa9/nPqjrS9Sc6B7q+7QMlpklWfhXOk52/wjhXww4dYs1wW/vQQMRw01myLP/PsMqk3VGL54OPov6o/M3Ezbso+v+NiPI7NP26j63kbRWZg9279jcRcqgJoRwguPxeIdCxDnZuMa38XEAGB45VBbXyDZGQH03fYL+HLLOdSZnBMQP+7Kwg87smDkFI7VAksm34mVtYgFkDCtXZqB5kyQ6zXpbTCM6/nVrSvQpg1gMgM7dgAA/jqQi80nC7H2aD7yymux4fglp98LxT4GswF78/baXnNC52zpWfT+tDeY/2OQ8l4KQl4OwVPrnoLB7Hra44WyC6LXjSnt3RO0CG2By9qSIP+0D9NsdZAOX+KDCfcX7IeVda18hT0BFNF89CUAeYD58aLjWJhJWjr9fP3PeGDgA/4YllO0i2lHJlqcC4h4HnegAqgZIbzwmM3eEUD1xg9bkzvpBa2mtt4C5EI39bIa1ypEbz5Zb4bavQeWUj7ikrUSUcMLIPsNWTPatUAXrvqzAmmtIgQiiuHNXZmZgMWMvHLihjtbWI2lu7KxP7sMn248g0uVtLBcQ/nn7D+wsLyY5LqYP/XPUzh06ZBo3Te3vYn5m+e7fAxhP6x3xr/j3kAbOZ1b8vEnd/5xJyxWi6gu0NLDS/HE2idcEkFc6vTEicAXX5BexBzNTQA9vvZxxfl6nR5Tu0/18Whco3d8bzKRtAsabeN8sKMCqBkhtQB5wgXGsmJrDmfh4ZrcSQ0ptfUXP1faf7kilvgD1QCbNsICgaNZYAG6rGOcyMoTGiR3SI/o3AqTeiXK5gsRusquuXEUqfVeWwucPKW6zQ87spx7DxRFjBYj7v7rbtG8pUeWoqyuDIXVyoWouI7aQl7Y8AIS3k7AyWLiJi2oKsCizEUwWoworyu3PYl/efWXmDNkjoffReOgzsKL9e8Ofoev93+Nw4WHReu8t+M96F7S4f6V9zu1T+461KMHMGuWuCN4UxNA1cZqHC2UN03m+OfsP4rzu8R2Uay9E0j0ad0XMEQA+nJcqD7h7+G4RWB/whSP4mkL0OefkwvWfMHDtVJhrD/+BOLriytfMBaitNro/UaLmzYDtXU4Fs5X3uKCoC0WoGNrcqWd0jcJY7vFIzYixImd8mPW1IsnoYgKDtYBffqQFwf2N2z8FFWm/TINeVV5AIAR7UYAIG6Z6b9MtwUtvzr6VfRs3dO2TWGNWBidKz2Hlza/hILqAtzxxx2Y9888JLyTgNkrZuPr/V/jlS18rZzUmFQvv6PA5ZFBj4heL9q7CP+e+1e2HgsWn+75VNRkVQ3uOhRSf8oJO4IHQmCsJ3lk9SPo8UkPLD28VLaMZVnVNiP+6vzuClaLFqgk4yw1Fvh5NO5BBVAzwtMxQM8/T7q/v/ACyQAH5BYgAOjVE+jZg0wXW8vx855s1yxArmql4mJg924AwObqDH4/giBoLuOrfVw4eiVHO7VbcdaYfHmITssLoFOngOqq+nUbb9G8QKOopgjLjy+3veYK+AHAmjNrUFpXipSoFMy9bC4O3XcIu+7aBQDIrcxFbmUuAGINSvuQr62yLXsbXt/6uu310sNL8da2t2yvA6UOiz/o36Y/Lk+93PZ658WdtmlhLSGOB1c96HCf3EOSkgCyeqcbjt/4ct+XAICbfr0Jj6x6xGaJPFV8Crsu7rLFVT008CFM7DjRtt3Q5MDPNjSZAFSTbJdSo4MWAAEKFUDNhH37xJWgzeaGu8CEhQ7//FO8T6EFyMqytuwOk5mkuyvXglbGYmVRYzTjm23nsetcieMN/lkHWK0oa99XPF8QA6RVrGqozNXpbTCicyskRvNXaq4mkDB2KFinIdHfycnEx3eIuApcOBTFAU+te0r0emzaWPSJ7yOa9+6Ed6HTELE6IGkABiaR1icL9xCXltR9JmXDeb5Z3o09bmzWFiAAWDJ1CTISM2Tz+7fpj6cuE38fa8+stTVXVUNqARJeK7yRmeovpFbuD3d9iIdWPYQKQwUGfzkYg78k7Uc0jAYfTPwAq2aswo9Tf8SsvrMahcuVCCBSQoEKIErAUlBAihKeElinPeECExo23nsP+Ogj/rXoqY5loas3nphN3Dznj2NlWew5X4qSaiO2ni7CwZwyWwVoGRfOA8dPABoGhoHDFVexWFyrMdShVQQy2olrcXCbC7PHbEHVveuDA4/UCyCqgDzC36f+xlf7vxLNiw+Px/579yMsKAwAEKQJwvXdrxetM3foXADAx7s+xpmSM8ivynfqeIOSBjWqju/eIiEiAbvv3o3b+twmmh+jj0H7mPay9WPfjEVORY7q/jgBxAmfibzho0kJIGkZBYAE1s9ZM0ckEqNComyu9Ok9p+OLq79AaFCobNtAw2QCUEMsQCV1VABRApSTJ+XzPOECk24/R/DQwj3dldUYcehiucgCBLjm1jJZWGRe4C8m649dwn+nFU44lgXW1JvlMzJgjlQuIGaxyIOzXYVzayVE6dE6KgRdEyLBoH6n3bqRA+RcBMrsiDWKS3AWnEFJg5CRmIG0FmnoEtcFALD59s3om9AXf930l2y7a7tei9bhrVFaV4o3t74pW/7MsGfwxVVfyOb/OPXHgA9E9RUMw2DxNYvxwsgXRPOEsSpj2o+xTae8l6JYp2n7dr4CPXeNmD4dtjoyTckFtv7sesX5nFuMIyokyhfD8TgmE4CKZADAliwXKyEGCPTsbgYoCR1PWIDUntZCQviS9kt2ZeFsYbVNAPEWIOcVkFKn93NFNfIVDxwgTc5CQoCRo2TjY81a27iDtQ376XNGHa2Gwc0D2+KKXom8qIqIAFLra2QcVc8AoTiPlbXaXFMLr1yInXftxLEHjiFYS0oHZ7TJwN579mJCxwmybbUaLa7rSnp5fb73cwBAr9a9MKnTJJx88CReGfMK7ux7J04/dBrD2g4DANzd7260byG3bjRnGIbBvGHzcG3Xa/HhxA8BAK3CWtmWf3blZ6L1496KQ5WxSjRvqCC0hRNADAPMmEGmm5IFaMnhJU6tl1XeOLNCTSYAh6cBADae3+hWrS1/QwVQM0BJ6FgsDW9bpXSx0oQZoA/nFxhM5JGOc4G5YwEymOUHYiDxsVeUA+vrU0pHjESFJVw0voQEcsyKXWkYm9yhwW4pYUA0Y8sIE6zQvT7q+4g4ZZjiHmdKzqDKWAW9To+erXtCq9HaxI8z3NL7FtHrGb1mYOXNK9EpthMA8h12aNkBv934G/bfsx+fX/W5R8ffVAjRheC3ab/hoUEPAQAGJg3EzD4z8ezwZ9GhZQfkzskVrb/uzDr1fQkSL7nsr6ZiAWJZFmvPyIPEmxImE4CSjoBZDxYsLlZe9PeQXIYKoGaASaGOoNnMZ2O4i1QAaSNrEdn3AiL6npetK7QAaRjGpTR4s0W+rtlqxdfb6o9TWwN8/wNQVQ20boXV5YPw3nvA7j1kcWIicM89JBuNNemgsbqfazulbxLaxOhxRU95fSCbCwwAute7wXLzSKd4SoPgqj73ju9tC3B2BWlDyfEdxiuu1yq8Ffok9FFcRpHDMAy+mfIN5o8mtTASIxNhfZ5XMVy5AiWEwc+cxbixW4DOl53HxO8n4sfDP6LGRKzU30z5Bv8b/j+snrFacZtXRr+iOD/QIfcVBppK0jg6uzzbr+NxByqAmgFcdWYhFos8Ld5VpJaloFhi7q6qIzsrF1RwtlmATMT9da5YYVAqKFVprjZYSIVog4F0XC0sBKKigJtnYOcu8rM+dpS4vbT1x9ZbSWBhQyxf7ePCMW1AW7QMl1sfRBagsHDS8Q+gbjA3OV1yGgfyD4BlWdzyO7Hg9Evo59a+hN3OQ3WhSE9I98QQKQowDIP7+t8HAHYDzpUsQI1dAL2w8QWsObMGM34jPr3Y0FjM7DMTL49+GRM6TsDX13yNtBZpeG7Ec7ZtuCD9xgb3YK2rJgKoMbryqABqBlRVyeeZzWJzs5KVyBGOLlZfbeUbKNosQPWi6VJhBbBjB7TLfwfWrSV5+sXyoEkAMFpU7OJGIxE/F3OBsFDglluAaEFNHyuDij2piCvoiNkj0qCvV2ENdf2pIav300PdDbb+WAHyy2lLDDXMVjOGLx6O9M/S8VnmZ7Z6KVxvKldhGAbXdr0WEcERWHnzSqd6vVHcJyGCVD4VCiCpxTkmhp9uKi6w0lpx5lebyDai17el34YzD5+xFZjsHNsZQdrG2QGWu2cEVXUEAJwobnzVoF23JVMaHUoC6LvvxK9NJrFJ2hnUBFDrePk83gLEApu3kHSQujpoLSZYuAsAwwDXXgv06iXa1qCURWW1AMt+BrKzycBvuZVvQV9P5f62gEULfTAQHqJBcL3RxlsCSHZL7dYVWLkSyC8AiouAWN4KcTCnHAdzyvHYuM7SrSgA9uXts90871t5n23+TT1vcnufv9z4C+rMdbaUeYr3SIwgLuLNp/bCYGAREsKgVJIV3rYtP91UXGAx+hjR6w4tOyiuFxsWi8K5hQgPCldc3hjgBJC+ojuqAbstPwIVagFqBigJIGkIjkcsQPUKYNRIwCpxW9ksQEYW2LCBPA62ioNu7Bhg0CBSPJBlgd9/B/79FzDxKqWsRqpYWGDFCuD0GbLjW24RV2Xk1qrv/cUJH68LIKkCCg0DOtRXET54SLY+RZ3dubtl8x4Y8AC0GvfjtzSMhoofHzGx40QwrA4nKjMx68nTACATQMJTtim4wKysFd8d5J8s01qk2W2iGxcW1yjq/ajB3TNCa0gpilMljtugBBrUAtQMUBJAUgoKgJYtXduvPLuMiJ6gIHmaOyeALKwGrEYL5qorgT59oA0NAurMRPysXAFk7gW2bAEOHgQmTwY6dUK1QXJV3LgJ2LefKI7rrweSlPvmcAKIizXwvgBScKv07gOcOk1S9EeNangBombCudJzsnmcW4US+KREp4DN6wO0ycQPa4/g2eOd8Lik8blOcPcRusBYtvGcJizL4n///g/ZFdn4/uD3tvmrZqwStbZoCpSWkjBL7rviBFCIiSjZS9WX/DQy96EWoGaAMwKoe3dgs4u1rLintX/+4eN9AU4AidfVaXk3lmnMRCA9HWAYPm6GYYDJVwJXX03ieMrLgSVLgK++BHIE2QU7dgCbNpHpyZOBzmIX0qgOiag92xq1Z1sDFnKm+koAKdK1C6APIe/n/HkfHrhxc65MLoCENWcojYDC7uR/qyPo1g1YLUiCWrNGvCrnAissJK6xwkZSWHjzhc147b/XROIH4Jv0NhXOnCEPyKNH8/NsFiALaYdRVFMEi7VxmfCoAGoGKGWBKfHqq67tlxNA7dqRVhicC0ynk1uAdGd586i5B59mLOrJxTBA377AA/cDQ4aQR43sHOCrr4C//wb+/IO/cg4fDmSI+xOVlwND0/Uw5sXAmBdjm88JIM4K5VMBpAsCetR3JT9wwIcHbrxYWast7V0YU9EqnAqgRkVpvfs3WpwdNGYMMF5ShUDYBT4nB/jkEy+PzU2Ka4rx9f6v8eXeL5FbmYt/z/0rW+fy1MubnKv1hx/If+FDsk0Asa3AgIGVtaK4VjmRJVChAqgZ4IwFyB04F5hOB4QJzvegIHmMEbNrJ3QgG5jAZz0IBdCg9vU+uKBgcoV89BFiKWJBurvv209E1tixwGi+QzVHdjZQVy2PEfGrBQgA0usF39GjgLHxVUv1NYcvHca5snMIDwrH7X1ut82nFqBGRmV9kE//zwGGtwALs784tJLTtrbWe8NqCLNXzMYdf9yBu/66C6O/GY29+Xtl67SLaeeHkXkXpQQZTgAF63SIDYsFABRUFfhwVA2HCqBGgMEArF3r/kXBWQHkqt+dswBptWIBpNOJm4TiUgFw7hx0IGeMMOBaJ7IASQ4QEQlccw1wywwgrT3QuRNw883AZZcprFw/Hqt8vq+CoNXQt2+HbkEG8saPHvPtwRshFytIRdnOsZ1tvb4ANPuu7I2OKkGU88RHwcUIhiskPmkkdyJ3kjJ8wW/HfrNNnyg+gRUnV4iWj0sbh7fGveXrYXkdJQHEeRbCwoDkKNIT7HzZed8NygNQAdQIePhhYMIE4N573dvengBKTnZvn4BYAIWGAoxaEPSuXWQ+1xFeEDwttAAxCqImuUUo0KEjcOtM4KabgY6dVMdD6ozI9yG1APny4poaF4Z7R3ZAi9H1MQEH9vvu4I2UgmryFNk6vDWu7nI1urfqjldGv4KU6BQ/j4ziEjV82QcM+ggYSjKiguU1RGUWoHffJQ99vuSP439g1alVLm+XEJGAsqfKUPZUGdbeulZUdLOpECpIVuMeIAvqjT3x8UCnluS63NgywagAagR8Xt+W6Ntv3dueE0A6HfD22/z8q68WBy+7AsvyAkinq4+vqdceQUEAy1m8a2tIRhcArZ64voSprkGCpqQs5BWfg3XO/0TVWnv40wIUotOCYRhoJownn8/5C0BJ4/KT+xoum6R1eGu0iWyDI/cfwTPDn/HzqCguky9pKTJ+LtBrCazBpbJVpQIIIA99vqLCUIEpP03BpCWTUFZXprpe6/DWsnkZiRmI1kcjWh+tsEXTQCiAKivJ//z6GpcJCbwA2pq91ccjaxhUADUDOAG0fr1Y8IwbB0RGurdPYcVWrVbsPtPpgFWH63sA7d1HOqAmJkATrJVva6cpqb1lSrC1IYrzueBnfwigsPr3zLRuDXSqt15t3+67ATRCuDiC+HCFipqUxoMpHPg/C7DyY37e1BnYE/W8bFWpC8zX5FXyPct25OxQXMfKWlFcI394+b9R/+e1cQUKQoO+kgCa0nUKAOIidMeK5i+oAGoGcL7a8HDxk5Ze774AElpxtFqgjaDiu04H5JXXEaWzu76g3cBBYOoFjVAAxUaod/SWCqBIvXrZqplD2iG6WNlF4k8BFFHfNZ5hmPrYJZBg7qpK3w2ikcE10IyPoAKo0cNqgN0PAMW86/qg/mPZakoWIF+RV5mHrgu62l5n5mYqrncg/wAsrAURwRG2eVEhUchok6G4flNCGDbQrx+5hl+qL/vTujUwIGmATQSpCchAhAqgRoS7T0nl5eR/ZKS4+FhIiOcEUMuWwLvvAfffL1jpzBly8LAwoGdPaDSAtS4IadEtbKu0DA/G9RnJuGVwO0g9YO3jwkXFBcND1AVQbEQIKsrIBzR8uHgZ9545IWTwYSJWuE0AAWjbDmibQj48rpYRRUZ2Ban7lBLVvGN+Tp4EZs8mp1Gj58Cttskotq1ssT8tQI+ueVT0Wq2n1fpz6wEAI9uNxB/T/0Db6LZYPm25l0fnf957D5g1i39dWgpcvAjUkGb3iKjXg33iicvzYuVFH4/QfagAakS42qsLIO4vzmSZmCi3AEVFuTcWoQDiBEZaGitux3Wwvu5N716ATgcNA7BWBqwgU0vDMEhpGYZWkWL31fBOcRjdVexvbxFmv2kgV2pf0hLMJnykDVl9QXhwvQDiZnCVxDIzgfw8xW2aO1xX6bbR8htlc2L8eGDRImDaNH+PxANU8VW8rZBnIehUnm2k5TS8wbFCcWamWk8rTgCNaT8GV3e5GhcevYDL28vLcTQ15syRzzMa+ZhLLsmEc1l/ue9L5FTk+Gh0DYMKoEZEqBttY3Jzyf/ISLkFSOoCW70auOqGOnz13zmcyLfvohFlcimZr+vqgOPHyXRv8mSg0QBgGUDgAhN2UBde6/qntoQ+SLzjmDB1dxkAlJWR/60lcYpSAeTLLLDwEPIebO+zXSrQswd5s2vXOdye9cUdIICwWC22NPjmLoAuXCD/M5U9Mo2LAzOBk5MAADXMJVhZq2ix2rXNF+7qOrM4eyIzL9MmwjlyKnKw+jQpZT0qdZT3BxXgVFbylnROAAldg8uOLPPDqFyHCqAAR1j7J8yN4qKcAOJidBzFAG06l4eCUhP+PmTfOiF1gQESD9bRo4DZArRuBSSSpz+mXgCJLUD2xy9cHGHHBQY4bwHyrQASuMA4xo4FtBrg3Dlg40bVbc0WK77fmYU1R/K9O8gAIq8qDxbWAp1GR3t/1aNmHQl0RF1qLCHA0j8AcwissOBs6VnRukqp8YBv3NVc9eLNt2/GkOQhACCr8Hzjshtt0z1b9/T+oAKcigq5ABJ+LiZrgBZykkAFUIAj7J8jtWw4g7BWAyDuwNy/v1xUMVorSkoc71cogDj/vchYsa++QmqfdHAyRqMBak+3hlUggISBzp3jiRoTBkYLhYO9GCCAtwBJBZA0BsiXAiikPo1faOlCdAzQrz5w8rnnVO38WSU1KKo04GhuhZdHGThkl5P4n6TIpAZ1fm9KKBUObAzIrDdWHZDXD4A8UFatCKu3BZDFakFpLXly6hTbCZelkESF7dniTM3tOeQ1AwZBWvuu+KaE8DovRGgB4kIz+ib2Rf82/QFAZkELVKgACnC4mzpATOE33ujaDZy7CHE/0q5dSfPSrCygRQtevdtgeEuKPYRFELmLl81dU3gJyLlIFE+f3rZtwksSYKnWi6o1C4VBq8gQ3DW8PW4eyLs+hnWMgz5IiyEdYm0p5UqYzXysU2yseJk/LUCKHeIBYPgwQKdF3fad+P6btTaLm8liRUUdGWDzcn4RuAtncy96KNTEajehQEfRfZUzGABw6++3wmB2rG686QL75egvWHNmja3+WMvQlhiSQixAn+/9HJm5mTCYDXhrK1/Zedusbd4bUACiVlutokIeAwQA92TcAwBYsHsByuvKvTy6htNIjavNB+kPcNkyYOpU5wMjhf26OMaM4adlAgisrHlqRZ0JpwqqkF9eh8u7tkJYsM62X6FLzdYBft8+8r9zZyCc9wvrtEQMsFa+5rNGluoufrqKCQvGvSPTwDAM6kzqd4IKgZFEGpPECR/uM/BmEHSwTgOj2YoRnVuhYyv+vWukQigyCujXD1tzjqPwv50oTGqPSb0SseJgLs4X1eCKXgmiIpHNBS4DrLnH/whd34HaF8sRygJokG3y5yM/49Y+tyqsxOMtC9BfJ/7CDctusL1OikxCsDYYg5MH2+b1X9Rftl2X2C6yeU0Ztd/eo4/y5VWE9xDh5/fJ7k8wb/g87w3OA1ABFOAo/QBduSBy1o4gFautzPfOkG3y8siTp1YLfLnlnG2xhgFahAdjx9FKMLoUaAVqw8KVhz5AKj+jb1/RrrlVGSt/Y9c60YCMs6Log7To2DoCpy9VSZbzQlGr5dMzOXzpAps1rD2qDWbERoiVpeLbHDAQB3ftBk6cBMrKYLGyOF9EBv/fqSJc3tUNn2cjx2YBauYp8MLfsNq5G+hwAmjRIuC//4BvvgFweqJt+b/n//WbAPrt+G+i1+M6jAMAtIlsg8SIRFstKikx+hjvDChAUbvXcDWAALEA6t6qu216W07gW8ua3yNmI0PJBKkWMKiEkgVIiNQCxDAs/vmHtN/45Rd5FlJFnQnbzxSjpMaI4Dal0GpJsO7GE5eQXVJDMr9qaoDICKBjR9G2Nq0kEECu1v9IbiFPF9EwjO1iGxwMjBwpXu5LF5g+SCsTP9wYZcTFAWlpxN+RuQcrDuaq7re5ZINRCxBBKIAa61fPiZdx44BHHuFmRuPB+J8BAKdLTjvch7dcYNJjj0sbZ5ved88+RaFz6YlL6i7tJoozD9vC8iwaRoPPryS9m5xxcfobKoACHHsCqKwMuOceYMsW9e0dWYDatZPMEJzfq1YBJovy1ddqBRidFTodsOtcCfZllcFosgD/1Q+mX4ZM3dgsQKxAALl4QVEajU7LC6CQEKIpfv6ZX+7PGCAO1bc5cCD5v3cvzuaV2WazrDgDztpIb4KuYLKYsPsiqRxOLUD8tC/rVilhNDoXFyjEYuEFUFiYONu0tZ50YFarFRMtaKnlLQvQyeKTotfC1Pb4iHjkPJaDqBC+SFrNMzVoFS7JrmgGqMUACZE+RHdsSR58G0MgNBVAAY6SAuesOc89Ryw1I0aob+/IAtSnD/D664IZDH+nTU0FDGbluBurFWA0LLRaILu0/mp94gSQXwCEBAODBsm2sYW0WFxzgTkiNEgrsgAB4hRc7r37IgZIDdV32bkTueLX1AJH+AJsVsljv6UZKKDNFzbjYuVFhOpCbcGozRWhALJY/GsF6tOHVHrnMkqdQTj+8HCJAAohAuhixUVZPSCAJDFw5+/DD3v+vX974Ftbw12OxIhE0evw4HBUGvhaaKFBbhRhawI4YwGSCiAugSG7IjvgLddUAAU4Sgqcywo5dcrx9o4sQADw1FN80TXpcQxm+QUKqL8oaazQaoHKOjMAFthc395h4CDFymY2F5iFjxuSBkE7Qul8Cg3S2p4UpZ3fgUCxAKm8T0YDZNSnxAsq3knfplQQNUVOlZAf9Ji0MYgLi/PzaPyLNBHBX5lghYV8PdN1jut22uDGzzDkUiAUQOFsIrSMFiarSdENlpTEn787dwJ797o5eAUKqgpw2/LbRPNahbVSPD/ZZpmHKcYtAVRvva0x1dhqLAUqPhFACxYsQGpqKvR6PQYNGoRdu3bZXX/ZsmXo2rUr9Ho9evXqhb///lu0nGVZPP/880hMTERoaCjGjh2LUxI1kJqaCoZhRH+vi0wdjQOlHyBn7XAmONKRBYijbVvgeUmTZpNJXQBZrUBwq0roWpWTdU6eAvLygeAgYPBgxW2UYoA84VIPDZZbgOwJoFOngFGjgBUrGn5sZ7H7Pvumk+jy7GxSQgByoaekf/ZmleLb7edRZfCzj8RDcMXx0mLS/DwS/yMN5PeXG4zT5oBr8Tjc+MPCeBHEYTbqMCaNpKL+fux32/xVq8h5uXixOJOz2IP30D9O/CF6/f2132P33bsV1/3y6i8BAIuvWey5ATQynHGBSeM4Q3S8Iuq+oDss1sCt4+B1AfTTTz9hzpw5eOGFF7B371706dMHEyZMwKVLlxTX37ZtG2666SbMmjUL+/btw5QpUzBlyhQcPnzYts6bb76JDz/8EAsXLsTOnTsRHh6OCRMmoE7ybb300kvIy8uz/T300ENefa+epKCA1PyRaD8ArgkgZyxAHFIlbzYDBpXUc66ju65dAXHvcM09Bw60VVcc1aUVbhyQAl29lScM5CrIClxgrhs2+A2C64sMdmwdIRNAwver1Ax10ybgqqtcPbb72I11iogEOten1+4lJQRYyfOnkgVo04lCFFcZseNMYD9lOcu5MpJtmNaCCqBAEUDZ2fy0KwKIswBxRRyFP/+6OmB4W9KxeOOFjbb5EycCGzYAHTqIrwuetNhuOL/BNv3T9T9hRu8ZaBcjDYQk3Nn3TpQ8WYLb02/33AAaCadOASkpwBtvuLc9VxCxsKYQuy7aN3j4E68LoHfffRd333037rjjDnTv3h0LFy5EWFgYvvrqK8X1P/jgA0ycOBFz585Ft27dMH/+fPTr1w8ff/wxAGL9ef/99/Hss8/immuuQe/evfHtt98iNzcXy5cvF+0rMjISCQkJtr/wRlRS9eGHSc2fPIVszPz6zgietAABcgFkMgFGix0XGMiFjTlzhvTcCNIBQ0jsRrBOg75tWyApJhSzR6bh7hFpCNGSQbBWBv3atUCPNlGI0rtfieHWIe1wZe9EdE+McskC5A+CtA5MXf1IhVwc2A+YTWBZceaXxY5SbCruMZsFiAqggBFAQu67D/jqK7l7TgluHaX2PSYTkBqTCgBYfXo1Xtn8imwdoQXIk3WQDl8iD9Irb16JG3vc6GBtoEVoC88dvBHx+ONATo7dTj12+ezKz2zTOy/u9MygvIBXBZDRaERmZibGjh3LH1CjwdixY7F9+3bFbbZv3y5aHwAmTJhgW//cuXPIz88XrRMdHY1BgwbJ9vn6668jNjYWffv2xVtvvQWznauIwWBARUWF6M+fnDypvuzpp4GVK71vATKZALOdLDCAhLBoNtY/VfXLAMKIyLxtaCq/X50WESE620XNbAZGdm6F8T0SXE4rFd7ro/RB6BQfCUaSBi/8D/AmWn/2VArROWjr0LEDCYaurQOOHoPRbMXerDLbYoVYURtNJTWXE0DtW7T380j8j/SmHyjVoGfNAiIigL/+sr+e1AIEEAGVnAzccguQHJVsm//shmdRZRTX9hK6VcobWFD4QtkFTPtlGnZd3IXjRSSgqUerHg3baRNH6fbXpQtw//3Obd8vsR9eHPkiAOBgwUHPDczDeFUAFRUVwWKxIJ5rRFVPfHw88jkzhoT8/Hy763P/He3z4YcfxtKlS7Fhwwbcc889ePXVV/Hkk0+qjvW1115DdHS07S8lxb9puI6e+B54wLkbekMsQLUWE0wOLEAaqwk157NJDMsQEvtzfUayYuNSbgzeuJgL0+ABseDjxqokAisr5fO8gT5IfqrFhAkGJAyG3kNiEi6W8ndBe1aepiB/yuvKUVZXBgBoH0MFkDT9OxAsQELuvtv+ciUB9MknfAueQUmDMDBpoG3Z3jwS6bzm9BqcLjktsgAJ2wG5w51/3omfj/yMQV8MgtlqRkpUSrOvM+UIqdtx4kQSDN+9u/L6SnSN6woANtEZiDTZLLA5c+Zg1KhR6N27N+6991688847+Oijj2BQKSwxb948lJeX2/6yhc5vP+DogscwYlGj5p93ZAE6kV+JzAuk+6lUAJ0LOoeNJwpl2+g0jM0CpKmtt9X36kWafAKqLRyEFiB3qTMA3/8A7BD3UrRrAbIngB591P2xuII+SGwBYhggNVbiku3blzz6ZucABeIHBCvLYte5EvyamQOzRJQ2BQMQly0SFhSG8ODG46r2FoEugKTtZqRU1Rt0pFEH3G81NCgUO+/aicmdJgMALv/mcqw6tQoTf5iIEYtHiCxAUgHkamq1tPHqiHYjmozV1FtIf29q94+oKOX5AC+AThSf8NCoPI9XBVBcXBy0Wi0KJAUkCgoKkJCQoLhNQkKC3fW5/67sEwAGDRoEs9mM8+fPKy4PCQlBVFSU6M+fOCOAhD9KNb+8IwvQ34fysPlkEQorDTIBpGL8gZUVuMAM9YHnQy+zLVer7uwJC9CDDwBnTgNr1ojnS9PgnbUA/fCD+2NxBZ0k3T9Iq0GotLlrRATQjVw0hCnxAIkB2nq6CFklNTieLzZbuVpMMhApqSUivIW+ecZcSGnsAmhDvVfckSGdu0laWSue2/AcACCvKg+Mjn+iKysDzpedx0c7P8L/bfw/xLwRg80XNjs1TpPFJMtCenPcm05t25yRWoC4a6fwUjNgALB2rfo+uDivktoSVBudCBzzA14VQMHBwcjIyMD69ett86xWK9avX48hQ5QLnQ0ZMkS0PgCsW7fOtn779u2RkJAgWqeiogI7d+5U3ScA7N+/HxqNBq1bN47+StIfYJs24tfnzokvimoCyJ4FSPgkVWeyyFpsWBUEEMsSa4TNBQYr0KUzIPhc1YobNtQCVFYGHDvC/2SFgaJSC5BQhMXEkP9Kn4GvAqOlT5xBWka5u31GfQPGAwcBI38XFD70ykoTNH79g9JaUmq4uQadSvG0AMrLA4qK+NfLlgGffaa+viMcudT37yf/HWVaPj+Sr72RmceLfkMEXx+orAwY9fUoPLz6Yby46UVUGCpw86834+cjPzu0Bu3P3w+DRfxhxofHq6xNUYO7To4aRf4zDLBrl2K9WxtRIVGICCYNoS9WXvTuAN3E62Ghc+bMwW233Yb+/ftj4MCBeP/991FdXY077rgDADBz5kwkJSXhtddeAwA88sgjGDlyJN555x1MnjwZS5cuxZ49e/D556S/CMMwePTRR/Hyyy+jU6dOaN++PZ577jm0adMGU6ZMAUACqXfu3InLL78ckZGR2L59Ox577DHccsstaNGicVxgpRe8JUuA774DvvySnye0grhjARIWF9ZoGJkFSKl8A1u/U+uufQAGEAE0bLhoHTWLREMrMZtMgLEgCrqYGphLw3H+PO+TlgogAPjxR9K0j2tJpvQZ+DIzbEBqS+w+TywdV/RMVF6pfSophVtcDBw6bIsLEsYASS/6jVn/GC1GjPp6FLbnkASGlqEt/TyiwMCTAqi8nH+AYlkSYH1jfQLUlClAvIoesGepdWQB4h68HBnShe0mhORn3A9ouwOWIHz//ftAR3Gl1ouVFzHtl2mom2LATIWGqizL4qcjP+HjXR/Lx65xMHiKzIrPXSe7dwcOHADsOFtsMAyDpMgknCg+gYsVF9E5trPjjXyM1wXQtGnTUFhYiOeffx75+flIT0/H6tWrbUHMWVlZ0Ag+7aFDh2LJkiV49tln8cwzz6BTp05Yvnw5evbsaVvnySefRHV1NWbPno2ysjIMGzYMq1evhr6+K1tISAiWLl2KF198EQaDAe3bt8djjz2GOXPmePvtegwlH2yupFemMEzJHQuQWWDiYSDuwQNIXGDV1UB5Gaw7TgOHT8J6MQLAADBRkUCy+CKmVt2Zu2i66wIzmwGwGtQcJ1fz3bvtC6Dp08XbK1vB3BuLOwzrFIdhneJgsbLQahjUKdZYYoD+/Ym63bMbyOgHgEFVHf+DkHbFaMzxDAfyD9jED0BdYByeFEAHDvDTJhNw5Ij6ceyNQQh3LrMscOYMqd0j/Bm6kn2qRG7QJmBAfW2xAZ+ornffGxsw4tVbkZoqnv/u9nfxxLonAAAMGHw95Wvctvw2Uc8vijrS7174Pfbu7fx+kqOScaL4hK3JcaDhk8TgBx98EA8++KDiso0KhQZuuOEG3HDDDar7YxgGL730El566SXF5f369cMOaZRsI0N6Yw4OlgsgIY4sQEoXIqGLi2GAzl1Y+XKTEfjjT9tVs9JkAIJCwAb1AkyAJlr+BKfW3cITFiAhK1aQCtb33w/bBVDqxhOiVJPEH7EV2voPSB+kRZ+UaJRWm5BVIvDn9ekDrF9P+qrlXASSk7HiIF8QSpoR1njlD5BfJQ72pi4wgicFUEkJP11bK24UbK/IoL36O5wAeuUV0pPwmWfItHS/zgigGb1m4IdDdoLxtOpvvib0JJ56xog2M5/CFZ2uwPgO41FSW2ITPwBwfffrMbPPTAxrO0zW84uijDSpxl0hy2XbBWpj1CabBdbYkbprgoPlWVpCHFmAlNw/QgvQ0l3Z0OjFv3prnQlY+Bn/yBgWhtpuPYBBg2AdMw4Ayd6WolVRQA0NgpZerNesAUaPJumZq1eTefYEUGgo0RZCkpOV1/UVo7vGY2qGZBChoUDP+jolmXtk28gEUCNWQNILY/c4F/JsmzBSAdSQxAFh7M+OHcBbb/GvGyqAniNxy3j1VfFyJYusGl9e/SX+b9T/2V7bs9Jc1+06vDVO8AbiD+BA9Mt4f+f7mPD9BDy/4XmZ2+uqziQQKa1FWrNtauoqnhJA7aJJle0LZQrNJgMAKoACFCUBtGgRCQm57Tb5+tLKsRz2LEDSDuMbTxTi6qt5/68lN588PkZGAnfeAcydC8yYAUycCDaCdDdUyvhSiwFqaBC0dDulGj5q8QwcN93keB/+QPaRcTWBjhwBuEy7eqTWQXezwFiWVa3z5CsulIsvjMPbDVdZs3nhSQuQsJr899+Ll9lrb2Gv4rOjIGiXCrDqQjC+w3jb6yXXLeEXWrXA0t8Qs/sNbJq5Fb/e+CueGPoEjtx/BChPBkKqcCJhvm31+Zvny9Leh6YMdTwIigiPCaD6NiOBmgpPBVCAIg0yDA4mvtc9e4B775Wv754FSHwnrTFZ0LcvcMMksjOrhSWK4p7ZQIq4cJitDpDCvddRELQnLEBqyXzSWAAp0j6tniyz3xAYqSMrORlo1QowmYHDR0SLpN+buwagdUcL8PG/p1FUZSfYw8twrQm6xnXFgkkLMCjJTlpJM8JdAXTqFPmNf/cdv43w2iAtb2av07pKu0YAnhVAgDgzK1ofjZ137cT8y+fj+651wPFrUbbySWhzeSHTvVV3IEtZLB+6dAgA8NyI57D1zq3o0LKDc4Og2PCUABqSPAQMGGzJ2oJzpeew+cJmzF07F+vOrGv4ID0AFUABivSCJzQlt1Po3acmgDjXkDMWIJY0oIJmD2leZw3SE3NTeIRsW2ErDClqLjBhM1J34C6qiYlEGyjR3kER4ZEjSQow1wne3bF4mhBZpWiGFEYEZHcps8UqzgRzUwEdySX17vecL3VvBx6AK5P/xVVf4P4B9zfqgG5P4q4AuvFGYOdOYOZMUqcFEHf0lgqgO+4gD1VKqBTrB+B8FpizN872LdrjoYEPYe7QuQgLCsPApIF4dsSzmHGTDldeSdbZt0+yUWUb2X4AIKciBwAwNm0stf64ib0gaFfoEtcFveNJ1HRmXiZGfzMab29/G9cvux5mq/+LW1EBFKDYE0Dx8XLfupIAOs2X0lCMH5JaEqxWFti8GZp8UrPBEtMSbRKVg1JtdYAULUCKm9hSYt1tsya8qLZUyZbu0sXxfq6/HkhPJ9N1dXZX9RnXpCtczPv0BrQaEv0uKPxZWiMOmpZZjxoJBrPBVh8kEFNk/YmrAshqJeKGq78D8NPC3/iZM/JtX35ZeZ+cALruOtK/a/RofpmnBRAAfHjFh4pFCnv1Iv+PSzsq6Mvs7i8hwolcbYoiUguQM7FcanSK7QQAuPuvu2Fhifm/wlCBbdnb3N+ph6ACKECRBidKm3tKCyMqCaBSwYO98OLFYZE0Oi09cQbYtAlaEPNOfqEOwzq1QmK0XratzQUm+AWN6x6PKX2TVJ/iuYKE7vb2EcYzbdkiX56SAiQlObev+ooJsFgCo8puYnQoeidL6hCEhQNd6itD7+OtQNklNfhtL19YrLEaTR5Z/YhtOi4szo8jCTxcEUAXLpDSUW0V2luxrGORXyjvdgOA19yJicSlVl+KDYB3BJAaXDsNmbU2c7Zt8uluX6Bjy46ixW0ilS1EFPuwrOdcYADf24/r9ccxd91cWO11efYBVAAFKPYsQIDcP69kxuZEUbdupAGhFIvQjWIxA3/+SVxgHXg/0lVXMlAqlaPkAuuSEIn2cep9nDgBJOli4jTCeCalEvvOih9AbBELFDeYopDh3GAHD6reBRup/sFnmXwpYur6EuOKAFq/Xv2hYv164Jtv7B9LLQ6Os9Ry9cFCBQlUau1uAHID9aQA4vYhy1i7OBD44AzwTg6uSr4Ts/rOEi3mqhBTXMNikSdaNOR7vCfjHtFrzjK36+IuHC086v6OPQAVQAGKPQsQAEybJn69bp38R8tlhinVvwEAi7AQ0LbtQFExEBEOUz8+EPXgftLBmUPDMDAYlIOgHWUjcQLo2DFg0ya7qyoivKj++ad8uaOnUiF6gVHLYCAWpeJi18fkLa7qU//02iENiI4CausUfACExigejBb+EbN1eONoT+NLpOe/vcQBJbcWx7hxjo+llkHKWY64c0VY1dneDVF4A/WEAOICrhVT9kvTgMokGAwMrm75NBYMWw4No8HLl6v49SgOUcoMbMj32KFlB1vPNwCY0mUK+rchLX9Ol5xW28wnUAEUoEhPdunNff584I03eHFSVyd/kuMsQNKOzBy2GKCyUmBzfXPB8ROQ0JZXWyzLoKqaV1amWh1efx1YtYq8FlqA1GJ/ODgBBAAvvmh/XSWEAoiL4RHiigDS6fin2GXLgBEj+KrS/kIYy9MijOs+qAHSlYOhbdsFuP5Ze2YtYt+Mxc9H+Ap8q0+vtk3vumuXP4YV0EgtPvYsQPYEkDOoWYA4KxQngCIiYAtI1su94jaE166GxI5wqFqABJSWAj16AA+MvQb91lTg4b7PNPzAzRQli3hDhSxXEBEA/jfif+jUksQFnSo+1bAdNxAqgAIU6QVPepNLSgKefJJkS3NPSMKKr4DYAlRnsshqvpgtLACWqBmzmaRQ9eqJiAigBRdkbGXw1des7eJzeL9YZQhN4Y4sEcInSEddopWwV9MIcE0AAfxF/KefyH97ab++RpRJ1zed+LnOnSNi1QlKqo348r9zOJRT7pXxucLM32eipLYE034hZstTxads0/dk3GOrFULhkVp87Akgd2PqOBxZgITu4qlTHY9HKFQ86QITHlPaqFlY7HHPtnB8+WWAPxUEMH/9JZ/X0O/xoYEPYXDyYCyduhTJUcm2uCBpHTBfQwVQAGK1KndiV4Jh+IwoqQDiLECh4RZ8uvEMFm05K1pusbLAiZPAyVMk22jSJHARJRH1ViOWZVBZyWLDBvI6LEhcAMSR1UeIsEhhWprz23E4iiu4807X9sdd2O09WfoUoTtR+MFGx/Af2P4DkJJbVotiSS2fDccvoaLWhH+OuRlw5SGsrBUF1fwY9uTuwc2/3Yw6M7m7TugwwV9DC2i4mz33gGHvN9rQ36+zFiDAuXY23hJAwv1K3TTS9xAw53QjRKnQbkO/xys7X4nts7ZjWk/y4JMURQI2/d0lngqgAMTVrCROAEljWLgnO10kuVoYTFILkBXYtJG8GDIUiOMzcWxOL5YBGBa59dVkw4LFAkipDpAaWi1wT308nDvFEKVFHQXDxcaNpEi1K3AX9kDIApMii6fi3GD79wGSzImzhdX4drv4Scriyy6vdjhfdl70esCiAdiTy0fsd2vVzccjahxw5wdnNVWz0gCeEUBKPxclC5CaAOKsr1VVwNH6uFaGcd0qqwR3812xguQCAHIBVFUlfq3m9qe4h1ocqbskRdYLoAoqgCgSXL2gxcaS/1IBxFmAwgTZGzmlNTZXmHXXLiAvn1xhhg4Rb8xdEK0MGA2LuvonrDaSoog6hsEdl6Vi9gjnTDrchcleCX41pBagbdv4qrcjR7oeC8Nd2BvqQvAUwuFrpW+ma1cgVA+UVxCrnQIFFXUorDTI9uVPbvntFvVlvW+h9X9U4AQGl4ElvcErresuVqvy+eiKBYgTOn37kng6wH6mmCsIrQ+cldeRALLXN5HiOmqFZ92FswD5u0s8FUABiKsCiHtKlF4EuKdGfSj/eLdsTw5WH84HWBbmrxaTmRkZQKhY4nNbsCwAhn8ajAsNQ/VhQfNOBogJC0Z4iIPa+PVwQZHuCCBpDFCnTsD27aRImzucP0/+qyRX+RXZzUOnAzJI5gS2bAYUihMs2ZmF73dcgNXKBkxm2LmycwAg68LdN6Evvrv2O2hcMSE2IzgLEJc4YE8A2bteOGsJUXKDuWoBuv12cfHVhjRwFSJsu8F9Do4EUKCUtmiMcKJbiKcFUJfYLtAyWuRX5cusxL6EXn38SHk5cP/9wH//iedLLy6jRtnfD/eE9vTT/E0dEFiAJObL05eqULx8JYwnTxM1cdllsn3GckHQ9S4w7mJoNALmcn6HpSWyTe3iTEaHGp6sLRKICEWL0AJkmxwymLz53DzxnUaC0WINCAvQpepLyK8i5YSPPXAMTwx5wrasZ+ue/hpWo4C7BjgjgOxZgE4pJNkoZVAqudhcsQDV1jquN+QuwvP9RH1PTUcCKFB6/DVGuAfqzgLjrDDcwBNEhkRicDJpzOjPvmBUAPmRZ58FPv0UGC7p6ScUB6+8Avz8M+zCXaAKCoDLL+fn2yxA0pRVlsW3P2zA3jZdgYEDSX6rgE7xERjPNWeud4FxT3PSJytXI00aYgHyhQAKkNAZUQxQkLb+NA0LB/rXW4E2K1uBABLr5Sn3gztcrLiIlze/jM/2kEKHfeL7IFofjWdHPIuxaWMxOHmwYssDCo+zLrDPPxe3vxAmF9xwA6ni/MsvwFNP8fOV4nIaagHyJtLzPTPTsQCyFzNFsQ/XQSAykp/naQEEAPf2vxevj3kdl7e/3PHKXsI5vwXFK5xUDuWwXVyCg4FnnChnodcDmjADQtMuIftCHC5dCkV2bSnOaSrB6JKk3i0SSZifD4QEA0PlzQITovQID6/C+PHAsq3EAsQJA3ebNHK4awGqqwMeeIBMO+pE7SzTpwNLl4rnGY3+ix8QWm2EWWDBWg2M5vrA56FDgN27gOwc4MhRUvxEQk5ZDYoq3VCYDeRk8Un8eOhHfL73c+RW5trmz84gLQui9dFYd2tgdIEOdJx1gXFJBQApLpqWxpeY4M61qVOBq68mdcMA5wRQbS0/T80C5KuHBakA+usv+TxpBiy1ALkHy/K/NaEA8sZD5y293Yxd8CDUAuRHuAtRSHIJjufzHUJdtXTo9UB494vQRdcionc2kjrX4I1vClFtrUPUoDO4aBUUuDEZgfX/kOkRI4GwMESH8gdKaxVuc8UIhYaaAHLVkuOuBUjYh8hTJ+OcOfJ5/myOKg3baREWBJ2GQbywF1tEJDC03mX515/yKz+AtUcKUGXwbWoby7IY9tUwvLjpRZH4SYpMkpXCpzjGlSBojqAgceFB4fkrnNZogCd4byQAscWkuJi4QTh3upoFyNE5vGaN4zE7g/R8/7//I9ZzIf/+K35NBZB7COO25s4F+vUjXQaaKlQA+RGdDtBG1EHfrgi/7sy3zXdVAIWEAJoQ/oYX3jMHezIBY/1+TBrBlerAQaCyCmgRAwwaSI6j5e+8Y7vF227EnEAzl4eCZYGW4cEyAWRyUchw72nJEnG8kiPy8vhpT1SXBZRTOwOlOzwA3DokFfeO6gC9TnKajhxJOl8ajMS/YXFO7JgtVrBeemzPqchBYY28q+aEDhOg1XggF7qZ4YwFSGp9DQpSt14KxbVGA7z5JokP6lrfoUAoGH7/XbxvNQuQI5Fhc6M3EHcsvtQF5h7C733YMOJuHDvWf+PxNlQA+RGdDmCCyJWOq0YM8D9CZ098vR4kWFmA1aokpFhg924yOWgwoCUHsLXEABAaxN+sOAFUcyIRtRdicV2/JLkFyEVXllC8ZGQ4v50wm0Wpsas7KGXI+PPJkZGELms1DIK0GpgsEtGi0RC/RlgoUYZr1zrcd3mtCQs2nMHao94pjJhTkaM4PzEyUXE+xT72LEB795Kq5VKxrtOJzy81rRsURARRx478Q4BQMEiLsCpZgEwm+yLjo4/Ul7mKvQfBTp2U51MLkHsIBZCnQg0CGSqA/IjwByZsOOqMBchiZUklZ6j35eGsM7aLYlYWuXIG6YA+ffj1BC0yNBr+NswJINakgyE7FpH6II9ZgABF740qQrHCFX5sKIFuAeKoMiiozKgo4NpryfSu3ap9wgDintqfXQYry+JoboXqeg1BSQC1jW6Le/vf65XjNXU4CxD3u+fOu/37yYNDYiJwxx3ibaQuMLU4O65uGMB3eBcKBqkAEl5fODFkMJD6W1Iuv5yM/cEHlY/tDvaugwohcAAC8zxuDAhdYJ4oYhnoUAHkR4gAkj+mObIAsSyLb7efx6ItZ2GxsggJYQFGvB8GCkJqyxbyv3dv0VXNKhlCcL3LRamUjFQApbZXHqMa7rqvhGLFUy4wJQtQIMUAcVTWqbi4OnbiaySsWqUaKGJl5S1LPOkKY1kWn2WSjK/pPafj4L0HUfu/Wlx49AKSo5IdbE1RgrsGcL9R7nfJxWNYrcT7KSQoSHzT4rJ5pAhrujgjgIQWIO6ykZdHym5ICQ/3XAFEDnsCqGNH5fm0FYZ7UAsQxWdIf2AsS5KzuAuX9Ab97/ECfL31HGpNFpTVmFBrtKC81oQqXaVs3xUV/H6CggBkZwGnz5Cr02XDROump8QgJiwIwzuRXMeuCVHo0DoCXaPl1a+EAqhzZ+Caa1x6y24HMAs/K089mYSGyucF4pNjsCAGqH+qxP83cgSQkkyuXNu2Km5vtlpllaW3nCpSXNdVqo3VaPVWK6w/tx4A0Lt1b/SK7wW9zk67cIpdWJYXIVILkL2sS+n1RM3CKhRAzrjAhBYge13gAe+0oLB3zRBas4QEYnubxoDwc/NnKQ1f0QzeYuAivWDt2UNM21zwoPTkPpBdjtIaE/Znl9nmWawsqjTVsn0LTZlBQQA2biIv0tNtQTRtYvTolhiJjHYtcMdl7dE/lfiWtBoGV/dpg3YR4put2cxfiMeNA266yfULntR646whQvhE5ykBpNHI3WD+jQFSZkKPBLSJ0eOG/skY3qkVEoRZYWBIUDRAfkDVciuQRaEydOYF57rKO+L347+juJbvwTKl6xSP7Lc5Izx3uRJdnDC3d2N3lBreujX5f911/DypBchsloeUKVmA1JCUFPMI9m7EavGA1ALkHtxvT6t1vbVQY4QKID+i1UJ01xs4UKwG1J5uiqv4wBuWZaHT2W8dH3bpPHD2LLmS1Fdd1GoYTOyZiIk9E/lCew6oq+MvLJ3DSWv3K3u7FuQqvUg7m60hvPBff71Lh7SLtCu9Py1Aaa3I3SM0WKzw4qP0mDagLZJbqHQk7NABSEoCTGZgx07ZYrOVlbnAxLBgWRbFVQZYpf5QOxwtPIpbf78VADCgzQDsumsXbW7qAYQCqCEWIGkrimPHSAmwvn35eZwA4s7Dd98ldXbU9utIALVta3+5O0gtUkK4LDkp1ALkHq4m4DR2qADyI45MjMLqm8IbU62Jv7KxAMqscgsQR3wCC/3G1eTFgAFATAwm9kzA7BFpovo/zlBXx58gyeHReHhMJ9tN21mkFiBn6psAvPCaPBlo08alQ9pFEAsOwL8CKCFaj1uHtMMdl6W6uCVDclYBkuVnEL+JgvI6bDvDW2mU4n8yL5Ti2+0XsP74JdkyNf5v0//Zpmf1nYUBSQNcHDdFCeHNmxNAxcXESuOMBejnn8k58v334uUtWwK9eonncRZQzgIkrLfFIbQEOBJA7V2MCXQGez3F1LqUUwuQe1ABRPEZspOUUbcAGQWZWsKH+ZzSGruVi9M050mPDL2exIsA0GkY6IMc+5GkJlChANLpiBXJVaQWIGcLInKfladLsksFkL/TZ+MiQhCis//dKLoNu3QBWsURU8HefaJFKw7miV6bFaw8nEA6fLFc9bgmiwmltcquMxrs7DmULEAWC+la44wF6IYbgIsXFVv8yZBagBw1MHVUJb1DB8fHdBV7ViWlOD6AWoDcRegCaw5QAeQnzhdV40LQBegiBE/rEj0hTPcWCiDh7WvzySL7oiC3vsDiFVfYOr4769u94griWeEQPoG6+4Qg7SrsqgDydEn23r3FrwMxCNopGAYYPIRM794FsOp+A7O0rhDU448A4PClw5j4/UTEvRWHlm+2RPK7ybjup+tgsvAKvmWoh2oTeIm6OuCuu0iRv0BHyQIEAIcP2xfo7pwb0iBoe+4mgFit1bIwY2OJSPM0UVHAmTPKrm/VEiDUAuQW1AJE8Qm/77uIWhgQkqJeDEf4dGPrBQXgYqn4KsgVS1OC1QUBU6aI7vQWBxc5jvBwUq2ZE1hGI39hcfcE6dJF/NrfAmiAxGvTaAUQAPTqhdZsHVBaBpxS7xZvVrnLZVdk4/fjv9s6uHPcvvx2rDmzBhUGUkPoYuVF/H78d/x+nKiJYG2wrbNzoPLdd8CXX4oDgAMVoRVG6uKxVzvLnXODE1jOCiBAXXS0b++9G2damrJ7jVqAPMfevXxtKSqAKF7HKjU3S1xgQnOzyY5qYRhg3HDlOzebni7z81hdqAGj04kbmDb0CUGnI5YlDmcFkLeeTFq2BFau5F/72wXmDKxKF3gEBSHosqEIN9WSbvEq37OSBQgAvtr3FQ4WHMT1P1+PE0UnbPOPFh61O56td26VZZkFGsLv1ZGbx99wv3WNRn6D37hRfTt3zg1OYFXXhxE689moCSBvp00rWZ6oBchzDB4M7KzPoaAuMIrXccXfLrQAyTh1CkN3vIuWKJYtapUgvypaXMj0AZQFUEMsMX/9xXca9rcFCAAmTQJmzSLTjdoCBICZMAFMcDAJAjl8WHEdaQzQqYIqWAQus63ZWzH4y8GoMFTgYMFB1JrVVWFqTCoyEl3oaeInhK7XHOWuHQGDMA5DetO/eFH8+rXX+Gl3BAhnAeIEkNQCpKRr1USHt2+aSgKIWoA8h1A0UgsQxetIBZD0YsNdaIqqDPhtr+TKx3HmNPDTUsBkgi5IfrUSprxyhAW7dqXiRMfAgcSwADTsBNFqgXiSRR8QAgjgL6SNQQDZM+BpYmLAjBlDXvzzj+KjsNQFZray2HJBXESxrK4MPx/5GX0WSqLEJXSL6xbw1h9AfEO8cMF/43AGobXT3kd71VVA//4NOxYngP7+G3jySblwcCSAbrqJn/a2BUgpAJtmgXkHagGieB2ZV0vFBfZrpsoja0U58NtvZEc9eiCotbgq2OjRyhel9nGuVS9UEh0NfULgnuYCRQBxF/VffgGOHyedshsjDAMwI0cCMdGkHLhCjzC5AZDF9pztsvXu/utu23THlh1R8EQBusZ1RXhQOAYlDQIA3ND9Bk8O32sIK5gXyw2lAQX3YOToHEtKIr23rrwSeOAB944lFBBvvSX/bJSuH0IBdNtt/PTQoe6NwVmUiixyD1Ic06eT/9QC1DCaiwWombzNwGPDBuDCeTKdkQFkZsrXMWmMKKiwosao4CurqQa++x6oqQUSEoBrpyB4ifhxTemJ6ao+bVx+Ylc6GZqqADp+HOjWjQiJ2lrHab+BhoZhwAQFAcOGAytWANu3ARn9AB3/wV0oFteNqjbVoMbEV6TUWZNg1ogtjsunLUfr8NY49sAxAKSW0OmS0+jQ0gt5z15A+DsrLPTfOJyBu3k7egrnen9JCxe6gqNK7o4EkF5Piiv++iswd67743CGlBR+OikJWLCAnKezZ5P6RStWkPpHS5dSC5CrSHs8NhcBRC1APub0pSr8d7LY5koCgJ49gVtvBXRRtQjtmA8miFwB/ys4jyU7s+Q7MRiA738AioqA6Ci0nT0T0OoQIvGRS/3j47rHo2Nr12vVB4IFyNvpmdL0fJYlBpRAxF4EF8PUuy369CGBVuUVwBpxb4OdZ8WpRKtOrQIAtNC3wI09bsSdXT+R7bddTDvJcRh0iu0EDdM4LiHCC3yRZ9qgeQ2pBUjBiAdAnlHpDo4EkPS8AOQCqFcv4MUXvdMHTIiwHtCuXXwfwoULyXc6eTL/mXnCAlRQQNLvmwOVknaS1AVG8QonCyrx30mxnTk8nKR5hnXNQ3B8BfRtiwGwKjd7FvhjOWnHHB4G3HorBg3oBEAeJCgVQC3C3WujriSAGmqJCTQLkLQgovCYgUa7liqBDyAWICsLcieYUn+H2LNHNSAagC29PVofjW5x3dA1ritWzViFF0e+CADQMlpEBBPhvD+7DAcEvegaC41JAEktQNI4vpAQYm2ZPbvhx1KLoeG46ir5PFeao3qSZEGtTaE7jGH4orHChI2GkpBAus1fcr44eqNF2pKoubgQqQDyMRpGfnJKfduaEPLrUxRAO3YCx44DWg1w081AbJytlxcnKlgTuXLq9cCAVL5AXZAblZsB71iAOOEjDKK0h7cFkLQeEOC8OPM1QzrEYnTX1uiWGCVbxgIwc8FlaR2AYfXlgH//nTw2K9iPqk3EJTYqdZRt3sSOE/HCqBdw8N6DOP/oeQBAncmCDccv4d/jl+xnJQYgjVkAAcBTT/HTN9wAvPmmZ84Fe81LMzLIcaT4SwDFxQGvvw688gopjqiEJy1AHHaeHZoMUgFUXa28XlODCiAfo2EYmQCSXkSsRh3AKJghc3KAdevI9PgJtjLNunphwz0FWQ3kyhgSIm5XoXOy6akUbwigXbvIf2ee1FgW+PFH9bF4grAw4O67xfMCVQAFaTXokxKjmM1ntlhF4mTkPTdiQJcEkt+8ahXw1WLg0CFRKlmNkVztwoPkPoxe8b1sbS6EtahcqSUVCAgFULl6t4+AgHNHcKUiAODee/lptQ7o7tCunfL8668nhkPhGDj8JYAAIgSfeUZ9uacsQMKfd3OwhkgFkLM9Ghs7VAD5GA3DOE41ZQEwEhdYXS1JUbJagR7dgYG8yYITNgMGEFeOpZqYgvR6iLqA64MCRwC5grBuS5ZCSJSnELYeAQJXAHEo1XMyW1hRnZ9+7ePQ+Z35pPqkTgdkZ5PMwZ9/AqxWWFgL6ixEHYQH2feHCI9GBZD34DKxhL0AhZYaJVHiLhoNMHGifL5SU1QOoatdrQ6Pv/CUBUhYoqQ5CCBpAVgqgCheQauRW4CkMDqrXACtWUuu3C1bAlddDWH3Js4CpNUCb8+Nww1X6jFwELlQCm+GegdNNtXwtgBydC8V1ua58krPHVeKNOMrUGOAOJSamlYa5FdrjU5Lijg9+AAweBD58o6fAP5eiUtVBWQdMNAH2b+bCduLuVhL0+8IxayaAKqrI2nUX3/tkyGpwrnohAJIKHo8bQVVynS0F9AsFARqrih/wX02ZrPj64o9hOd+cxBANAaI4hMYhRggjltuqV9HawGjZXkX2NkzwP79RPNMmSK7YgndXFH6IHz1TjQeuS0Sl3dtLXKHaDwYA+TJi7A0BVOK8OY1fLjnjitFas5vTBagIC35boO18u9Yy5kYo2OACROBqVPJbylzLw7tWwMASGvZAYzdlqjiAoquVhP3N85YgJYsAX76ie+H5C+ULEDCU97T1lclAaTW8BQQnxeBViZC+Nn06AEsXuzefpq7AGouUAHkY5RcYBwdOhA3li6mFlEDzpKT2WjgC30MHCQuhlGP0IXGgoVWw2By70Skp8SIusi7i7ctQI5OPu7mJexM7w2kF/PGJIBuHtQO7ePCMbZ7vGw9mfDt2hWYSBqyXSomPsVucd1Eq7CSx+djeRX4djtfQtnaBAWQ8Pt2pimot1ASQEJ8IYDsIfycAq0IuPBadewYcOed7u2HCqDmARVAPsaRC6yiEraK0CFBVuDPv4CyciAmhpR2VkBj5yqkc9PqI8QbAuiyy/hpRycfd8G191TqCRqbBUhokWkZHowpfZOQGC13Yym2Phk4EBg6BIX1ro5WFvFdcMmuLJEIWn1Y3CG+sccAKQ1fGFxcUOD9ManhSAAp1eZpCImJrq0fyK5hT4lD4XsM9OuAJ6ACiOITpC6wG28UL+dim4cNY8GsWQ0cOUIima+5RlUBaAUCSHphH5QWizYxeoxTsAw4izcE0G+/8dPOCiBvm9sbmwWoV1I0ACC5hf3YnSCtBrcPTRW5SgHAcPkIVNS/51a7j4iWXaowoKJW/dHX4oIAYlkWa4/kY+tp/+WfCwWQxSIP+gTE5+VXX3l/TGqoCaAFC4DrriNFUz3JvHmurR/I54WnXPPC34LSb6Wp0RzeoxJUAPkYDcPAaLIAjBWpqaTtgpC0NBZz7yjCGNMaYPduEqsx9XogNVV1n0IDUGyEWCRFhOgwbUBb9Ky/WbqDN1phtG7NP3k6Ovm4m5e3LUDS/QfyhR4A0lpF4PahqbiuX7LDdVuEB6Nrgjh9qLCGCJJIA6A/eAwoEFt5WLAoqzGixigXQvZcRGcKq3CqgC8tW1xtxJHcCuw6V6K+kZeRtr9QanornKdWfdkXqAmg++8nLSc8fR7ExJCeYhxLlthfP5DPC50OiHb/UmejqQmg118HPvpIfTm1AHmRBQsWIDU1FXq9HoMGDcIurgiMCsuWLUPXrl2h1+vRq1cv/P3336LlLMvi+eefR2JiIkJDQzF27FicknSvLCkpwYwZMxAVFYWYmBjMmjULVQGQ21daV4zVlW8B3X5HmDDTwmImgc4LFyJs8QJg504yf/wEoHt322ojOsvt3wzD4NYh7TClbxJaR3q+MIe3gqC5KrSBYgGSXugC+ULP0SI8WGbZAYh4uVB+AUU1vNVFut7RwqMAgNZMfY71tm2i5dVGCxZvPY/PNp2V7f/wReVAGouVxZ/7c7HiYB7qTCSX2GzhrUXS2CJfkZcnfq0kgITfvz8LwTlygXkD4efhqDhpoAU+S2nTpuH7aEoC6MwZYuV7+GF19yX3Hjk38KxZvhmbv/G6APrpp58wZ84cvPDCC9i7dy/69OmDCRMm4JJKffFt27bhpptuwqxZs7Bv3z5MmTIFU6ZMwWFBOc4333wTH374IRYuXIidO3ciPDwcEyZMQJ3gLJ4xYwaOHDmCdevWYcWKFdi8eTNme6J2fANZf/YfmGEAWh+GVl8NgAUOHAA++AD44w+g4BKg05KI6OuvBwYPFm2vdLMDgLiIEJe7vDuLt4KgnRVAvrIASXt/NQYBpMbJ4pP4ev/X6PhhR9SayNVNGgx9ovgEACCjXf1v7PAR0YdQXKWenndIRQAJY4MMJrmZyB+x0yzrWAD9+ivwxBP8a388K7EsyVo6Qb4WvwkgR3z8MdC+PfDll94bT0PIzW34PpqSAMrO5qfVhD33/d92G3D+vP06UE0Jr5eze/fdd3H33Xfjjvrc0oULF2LlypX46quv8PTTT8vW/+CDDzBx4kTMrW8tPH/+fKxbtw4ff/wxFi5cCJZl8f777+PZZ5/FNfXd8L799lvEx8dj+fLlmD59Oo4dO4bVq1dj9+7d6N+/PwDgo48+wqRJk/D222+jjSceEdyAZVlUnj0J1BfZytb+DNPXDHChPrsmMgLI6A/06c0X4jCJ78JGQx1MknnVXn5cVXpoNxga/pTMPUkWF9vfF3dP1um8+2QudetUVjbOkvAmkxF5ZXmABSivLMfhnMPo3ro7DHX8b2dHzk6UVBGXVEJSd5janiRVJrdvB0YTf4jSb02I0u/OaLHatqmuqYaODUJ1Db+fyqoqW+sWX1Fezt/EQkLIb7e0VPzdXn+9fBtff/cbNoizlkJDfTcG4UOIo2O2bUuKiTuzrj+Qijl3xih8GPLHb6GhsCyLmpoahIWF4fx5/sGnoED5gZZ7v1otaTniS9EXFhYGxk/phAzrRZu00WhEWFgYfvnlF0yZMsU2/7bbbkNZWRn++OMP2TZt27bFnDlz8Oijj9rmvfDCC1i+fDkOHDiAs2fPokOHDti3bx/S09Nt64wcORLp6en44IMP8NVXX+Hxxx9HaWmpbbnZbIZer8eyZctw7bXXyo5rMBhgEERKVlRUICUlBeXl5YjyULWv6upqtOzYF6EdFBpPOUnNif8Q1mWYaF75VgdOe0qzIvqym23Twt+GPi0DIYnOtxCvPbsHoWn9VZdXHVwDSyXf2FcXkwh9u97QRhDTRWXmn7DWVUEbGYeI3uPJeHb8TNy9FAqFAqCqqgrh9ipvukFFRQWio6Md3r+9+ihWVFQEi8WC+HhxBlJ8fDzy8/MVt8nPz7e7Pvff0TqtW7cWLdfpdGjZsqXqcV977TVER0fb/lIU6u14BFZsZrBUFqqsqIzV2MjtsRS/odG5FrzBaOwbiCN6TwCj4/2S4T0ut4mf+j3I98nQvAsKhRIY+LCjU2Azb948zJkzx/aaswB5krCwMGzd+h/unV+E04mvAQBqmN14tN83iA9LAwCM6dIK60+oi6IZAxdg65kSnC+uQXxUCNKTo9GxlXcdtsuWyavjeiJG4r77gO++I9NbtgB9+yqvt3Ahic+49lp+fW9y113A0qWk6/Qjj/Dzjx4l5n97HbT9SWZuJkZ+PRIQauyR5F+76Hao/e0bGJgaYNgbAAAto8XTw54Cy3X5OnKUxKEFBwP3zEZGjxRkZpXZPeaMgQsQE0ps6gs2nRMvG7AAMWFByC2vw+/7SRDOHUM+Va5L5EU2bwYmTQI6dSIm/uPH+WX79wMdO8q/04gIQOVZyWusXAlMm8a/9mUcUnIyUFbm++N6A6NR3NfvxAnST3DxYmD+fOfqKHG/GYC0R/niC8+OUfh788bnLXSBzZ/P4M03yfyRI0ldXU39c8j995OMx3btyO/vrbfIddmXhHHBoH7AqwIoLi4OWq0WBZKqYgUFBUhISFDcJiEhwe763P+CggIkCip4FRQU2FxiCQkJsiBrs9mMkpIS1eOGhIQgxMvpDQzDIFQfitOngwFdOpCwH9BZsbcwE1d37AoAMGjLodEXILvYgtSYVNk+WsVE4ZqMSNQZrYgO81JrdAk33QR88gnJygeA3r3t9wpyFmFQ84EDwLBhyutx7uGwMM8c1xHcBbKsjD/eli3AiBGkiPKxY94fgzucrDwJBAMQZnrU/0Qu1F5AqxAzEFQL1OuP+eMfgtUcDBOXpdWnD5C5B8i5CKz/F9r0uxEUZD/yPCI8HOH1Aki6bmhYGMLDg6E3MLZloWFhCA/x7XMXF9DaooW8qu/Jk0QASamuJr83X4YmaCW60Be/dY5XXgEeeIDc/Hx5XG8gHb/BAIwZQ6bNZuDHHx3vQ5jkYTJ59zPx1r4j6lWWRmB03bQJWL0auOEG8vrbb8l/LscoOrrxf/+u4FV7dHBwMDIyMrB+/XrbPKvVivXr12PIkCGK2wwZMkS0PgCsW7fOtn779u2RkJAgWqeiogI7d+60rTNkyBCUlZUhMzPTts6///4Lq9WKQYMGeez9ucMhLpnt1BXoENYPAIvM3Exsy96Gd7e/ixGLh2HulrH45sA3OFl8UrRtsE6DYJ0GITqtz8QPQAJHd+0iN4Wvvwb++ccz+xUGFu7bp74eF5rlq/TbrkSLioTO99+T/0LrQaBxqlhcCkIncWGVjroc6P8ZACC+RRXuGSaJ72EY0m1WwwBHj+L4xt0Oj2kvhJDLCBNmhvmjgjTX+iI6Wl7t+8YbibCVwrKuZUZ5An+2XLjvPiIGP/7Yf2PwFkJL3smT6usJaUpZYMLO9gCfc6NEqP2aqk0Orzvk58yZg0WLFuGbb77BsWPHcN9996G6utqWFTZz5kzME5QifeSRR7B69Wq88847OH78OF588UXs2bMHDz74IABiRXn00Ufx8ssv488//8ShQ4cwc+ZMtGnTxhZo3a1bN0ycOBF33303du3aha1bt+LBBx/E9OnT/ZYBxlFSVn8DsARjWFp/AFawsGDd2XWoNFYCDP9rPVfKuxTuG9UBs0ek+Xi0YsLCSJqkp0rxCzNPNm9W797sq1YYHFxxSi4dGWgcF8GDlw6KXo/vMA5/Tv8TyVGkUKJZk49y/TcoD1qCfqlh0GoY+WcenwAMJg8Stcv/BMrL7B7TXlo7J3aEPctYP/TYsieAAPWih54S+s4ivOn6+jmNYYiLUNMEQ7SEDgWplU2NpiCAzGYgM1PebNredVTp/GjKeP3nPm3aNLz99tt4/vnnkZ6ejv3792P16tW2IOasrCzkCYp0DB06FEuWLMHnn3+OPn364JdffsHy5cvRs2dP2zpPPvkkHnroIcyePRsDBgxAVVUVVq9eDb3g2/vhhx/QtWtXjBkzBpMmTcKwYcPweQAUN6irIzeDtm2B1lFR9dEXghsElyMPFiV1JE151vD20AdpfZ4+7G0ESXw4dYqkASvhawsQFz9QyRczDviLoMliwn9Z/wEABiUNwoQOEzAwaQCu6nIV7usvcOozRIH0bt0bgFjAcB3lMWokKdVdWQX88IPdN2/PosMtCiQLkCu/oauv9s541BDedLn+x5SGI7QAOSvwhN/Fhg2AStm6gOZ//wP69wc+/FA8354Aam4WIJ844x988EGbBUfKxo0bZfNuuOEG3MA5KRVgGAYvvfQSXnrpJdV1WrZsiSWOarr7gZr6e0lQMBAWFAZxxCpgKxIEK4prSIpxlN537i5f8uSTpPbOokXkKe3sWeV+r762AHE6WugCCdS4H45/zv6DKmMVWoa2xISOE8CAsfWIm50xG5/u+RQ5FTm29dtGtwUAPgAaQHyUHjmlteTHOWMGqXRXWESCJmbeCujkv0OuK3xJtbxeEGf5EYosfwggrhpGixZAif+6cTiEc4Fdc43nG542Zy5e5KfdEUAA0KuXfxvkugMX+CyFu44qnYrUAkTxKpwFiBSjYrDoqs/RL7EvZvaZiTmDHwNsNiErimuLUVpXqrqvxk54OMnKGDWKvK6pIaJDemL62gLEHYcTQJWVpCdtoMKyLGavIFXOR7YbCaY+/ZwL4I0Li8PJB08iaMW3tm0SIxPrt1XZaVQUcMsMckXMzialkhUagHHi5p+j8ruDkgvMH5Wgi+q7gcTFBfYFnrvpeqqjeXNm3Tpe7JwTJCc66wITWn+BxmkBUoMrhKjUzy+Qzw9vQAWQjzHWkY+c+xHe2GMa7h9wH9rHtEdkSBQq5pWi+plqJEeRG1RBVSN77HADLgvyxAmSHjpjhni5vyxABgMRCDk54uV+amelysnikzbrzhND+X4OwuqqoUGhYEr4dKeECOVsSFHWU6vWJAVQpwWOnwBW/Q2huxYA8sprcbawCtUKDVO5z0kUA+SHD08ogOLiHK//3HP89DvveGdMSnACyFMdzZszY8cCK1aQ6bOCVnbOWIAWLQLuucc74woEOIGtFHTvx4x0v0AFkI8JN0bBXBqO4CAgNS4MESE60VOxTqtBWFAYovWkc3e1qZHVYHcDLu3yk09IxoI0TdXXFiBOALEsuSlx9VE41BoK+ovtOdsBEOvP0JSh6JZIfjsDUluKVyxvZ5tMjEiEEhpp3nfbtsB1U0lNwz2ZwKbNosUbTxTij/25qFPs+1VvzRS5wJx5R55F2Fw0Odnx+pddxk8L+4N5GyqAPEuHDuS/UAA5YwFSaxnZ2OsjcbAs+VPKBouM9P14/AkVQD6mrlaD6qNJ6K3pjGv7JoNhGNETsq7+ESUymKiCTec34e9Tf/tlrL7C0VOHry1AQqFVVycXQIEWEH34Eqmt0Ce+DwBgXPcEzBjUFgNSW4hXrGwDbHsc2DoXLUJbSHcDhmGg2Gu3WzfgivqqcBs3kvQoiSVHp7Ah97O2+DkIWmgBckYAtWzpeB1vwD2RUxeYZ2jXjlg0heU2GpLldv/9DR+TEt5oumxvn2Yz8PLLQBeFrjiBWuTVW1AB5GO4k1F40xfeFLj7SGQI+SVWGisxeclkXKpuQk5oCY4Kb/krBog7tlQA+bo+jCOOFZEI7e6tugMAtBoGraP0sgaDViuAtW8D61SiIwH1poQDBgBjxiA97wSwdSvw01LAyOfXhgTJLyVKMUCZF0qxN8u3cW1c4HPLlkBSkuP1/VUIjssHoRYgzxASIhe8DRFA3qpC/+mnnt+nveatZ88Czz+vvIwKIIpX4WrfCC+ywodi7gYUqRf/EtedWeftofmNQLMAaTT8sRqDBSi/iuT5cvV+1HDG+CJzgQkZNgz95j1IYoJOnCT5tbt2AWBRa7TIVmdtLjD+wKcvVWHTiULUmeTrewvOdREZ6Vx2lT9SgcvLgTVryLSfGmM3Sdq3F792RgCpff7Dhzd8PEocOuT5fdoTQP/3f+rLaAwQxatwAkj4Q9MouA/SWsaKXmeVZ3lzWH7F0UnnawsQIE6FL5UYLPxlAcouz8a27G2y+YXVpHdc6/DWsmVClARQp3ix0FZ0gQmX3zIDuO12oGULoLoGWLUK2LhJJQYIWHskH7vPyy0+Fh8FAxmNvICOiFB2b3XuLH7tj0wYrlYRIC9cR3GfFhJPrzMxQFLRxJGoHDbXYOw0K3cbYZFZZ4mIaJqFMO3RzN6u/xk2jGQ5de/Oz1OKn7ipfzqMmtOo0hLLT1MWQEouB5Yl/cFqanxvAQJ4sWUwyIMf/WUB6vhRR1z21WX44eAPIusK5x51JICU0l7Hdxdng9m1AHHLk5ORP/VBZA+4lszctAlWQdsZjqziGhzJrVC09vgqFkj4JBweLr8hAiTQ+amn+Nf+sAAJP/ZAc7E2ZqQuHWdu8GoC2BOxOk8/DcydK54nDDzmMk8bij0LkBrNzf0FUAHkc+6/n/SVGj+en6dVEEAdWqZgYFooLBrydJ9V0XQFkJIF6JVXSKXoJ5/0vwVImvXljxtUaW0pjBZyBb7l91vw/UHSnKzSWAmDhXxArcJdr54XrOMvAQwcu180DGmf8dkiDb7a3Rt1Q+srV65cQcp5C6gzq7u5pGJsw/FL+G77eRjNnu2VwYnX4GDyp/Qb0uv55pDca18j/I1RAeQ5pDd1ZyxAnHj47z9xS5KGCqCaGuCNN4C33xbP5yxAly4B8fHyMiDuQAWQc1ABFAAoWYAYhsHG2zdi7S1rAQAXyux0sGvkKN1wuFosCxb4xwIkFEDSC58/LECjvxWXyJ65fCZ6fdoLd/5xJwAgIjiivrK4+4SHaDGpVyKCdRqM7qpsTWIYSYuQjGFAeh/i71r2s6jsrj03l9QCtD+7DEVVRpy6VKmyhXtwAsjexT0khIjt4cOBiRPJa19bgYS/MSqAPIfUuuzIulJVxaeHR0cD69cDc+aQ1w0VQGrbc7/Nb74hrlBnutU7QkkAOcoubE5d4DmoAAoAdHZ6fHEtC44UHsHifYth9Uc3SS/j6MT0hwVI6ALztwWoxlSD/fn7ZfMPXzqMX4/9CgBoF91OttxZrurTBqlxYRjeqRXaxITivpEd0CclRnFdDcOI4lXMFga46iqgY0fAZAZ++xUwkSu9PTeXWUUceTo2SBgArUZkJLEMbN5MQpoYBthOSishJsajw1GFWoC8g1T4OqrhJczICg8nf4MHk9cNjc1yJKA8KbqVBJCj/VMBRPELWjt+B04AAcCdf96JJ9Y+gUqDZ5+S/Y0jAeRvC5D0oulrC9D6s+sdrtMuxn0B1LF1BK7tm4zwEPJFcEH5LcLk+dgMIw7YNRoBaLTA9VOJLb+kFFhH2qhb7Gh1UYNUUaVot9+GIkoWoAMHxK4NpSBUbp43arQoQS1A3kEqgBx9n0eP8tOca5677njLAmSp9xQLBUhDfwNKQdCOHiCbWyNUgAqggCCjHYnMlGbkAKSFgfDp/r0d76HvZ31hMCs/jpTWluJU8SnFZYGKo7on/owBqqoCvvhCvMzXNyiuy3t6QjrOP3IePVv3lK2TGp3q0j6dERrXZSRjYPuWmNybT3/RMIzo4mq7qIfogauvItO7dwN794qEjZS/DuTapk2CgCBPB0crCaDevYGHH+ZfKwkgoQXQF1AB5B2kVg1HFqB2gueI+Hjyn/steFsACa+D0sxTV0lLA+6+WzzP0anV3FLgASqAAoKEaD3uHdkBk3sp51ne1e8u0eszpWfw77l/wbKsLRuoqKYIX+z9Au3eb4fOH3dG+sJ0nC877+2he4RAtABxF73335cv87UF6EI5CUq4tfetaBfTDq+Ped3W8JTj3v732t2H9OKn1AdISpQ+CJd1jENECP8FaRjxTcRkIoHUWg0DdOjId7ZduRJVh4+p7ruyjh+AyeK9StFqMUDCi709AWSx8DcobyL8TJUy1SjuIf3eN29Wzobk4M5tLu4H4K873nKBcb8v4f654p3uMnQo8Pnn4lYujs55agGi+I3QYK1qFV6uxYGQSUsmQfOSBpOXTIbZasbwxcNx9193o9JI3GMHCg5gzLdjvDpmT+HIAsT5s/1hAdqxQ77M10/onADiLIGTO09GwRMFuPDoBXRv1R2fTPoEveJ72d2H9OLnytOsMDWeYRgkWfnUeaMR6Bwfidkj0tA/tQUwcgTQqydgtaL8uyXA+XNKuxRhEmR+GRTqCTUENQEkFN32BBBAmmPau2l6AuH3sWiRd4/VnFAKfl+7Vn19TgAJxQAngPLyGjaW779Xns8JIOGDVUMFEIfw/TuyflEBRAlILmt7meqyVadX4dPdn+J40XHZsrOlZ3Eg/4A3h+YRHFmAuKwjXzbqs5cK7SsL0Kbzm9Dxw47YkUNUmDAerFV4K7SNbosj9x/BfQPuc7gvqWhzpVBabEQwdBoGUaFEqYbU8YrBZAJ6J0dDH6StjyFigKuvATqkAUYT8NNPfDdSEDErDdA0CYKFDPYCh9yA++1Ib4RCQaj0uxJaG++7D1i92qPDksEJoIwMvoknpeEoCSB7KeJKhWq530JpKWmD5y6vvKI8nxNAwnPUnUKGSlABZB8qgBoBLUNb4uKci8icnYnru1+PmX1mYnyH8eib0BcA8PDqh1W3Tf8s3UejdB9nex/5KiMHUBZAnFXAVxagd3e8izOlZ8ixtSEOrTz2kIo2V+qEBGk1uHdUB9w+NBUWC7GqsBZy6bCYNIiPIh+WzXul0wE33QSkpAB1BmDpUsBggMVCaqC8/TaxqHDu2/wK/gP1Vh0ge9lASgJc+psUBsd6A248vnTzNgeUMpvsxddwwkMoBoTWwHnzyP+8POCcY+OmUygJIE89ZLkigJpjDBDtO9xIaBPZBm0i22DZDcts8/45+w/GfTfO4bZrTq/BhI4TvDm8BuFM92ut1rcnqJK7LTKS+Ol9ZQESWvU6x3aGXudehb7//gPGjhXPc/UJM0irgdUKZPQH9u8HtOHJCEkpQUfECdYSxO9odcCNNwKLPift2NesQfWoq22LDQZSOkjLADvO8hYib6XBSwVQT3kcuQipN9qZAnoNgbMA0UaonkXJAmTPvcSd20oWIOH+2rTh99XQmC0lF5inHrJcKW5ILUCURoW0+eWiqxZh6dSliAuLE83/5egvvhyWyzhz0Y+I8G2TSKUncS5WxBcWoKKaIpwsPml7/dqY19ze1/jx8gBOdyrFFhYS8QMAlmo9ao63QU0Z/0HJ4pcjIoCpU0mJ6X37YD7DF/M0m3nXlzDGyORhF5iaAOrWDfj3X+C43HNsdz/eglqAvIPwe+c+W4FHVoaSBUj4nYSHi92n5883eIhetQC50mesOQogagFqxCRFJtmmXx/zui1b7Jqu18DKWvHXib8w/dfpOHjpoL+G6BTOWIB8jZK5mLuY+MICtOXCFgBE5K64aQX6JMgD4Z1FabzuCCAlAVpQwE8r2m7atgMGDQZ27MD+NQUASCC3yQQYLVYYTFaEh+hsWWFmi28sQABw+eXO76fSy6W3/JHp2BwQusDatgVOn7YvgJQsQML+YeHh4vNJaBlcupRcy66/3rUxKgkgTz1ktW3reB0O6gKjNCoiQyKREJGA/Kp8XN+dP+s4V0l6QjoAUjHYYrVAq/GyHd9NnBFAPuqdaUPpAsQFy/rCAnSqhNRyGpU6qkHiRw13giyVRGF+Pj+t+h1dfjlw6hS2FA8U7evn3dmidHgAMHs43cqZVhiu7MdbUBeYdxB+7y1bkv/C79JsFl9/lIKgo6P56ZAQsQDiHgqKikjYG0CsrVIha+9n7U0XmFpneyWaowWIusAaOTvv2omj9x9Fh5by1JGOLTtCr9OjxlSD0yWn/TA651C76At96750fwH2BZAvLEDZ5dkAgJSoFK/s3x0LkGMBpKKAgoOBW28R7+v4GZn4AcQ1gRpKbi7w++9kuqECaNUq7wpf6gLzDkIhw01zIuf554HYWOAk72lWTINv2ZJvlltbKz7/OXeYMK5IqV6QvRpCOTmkAOPixfJxNBShC2zBAvvrNkcLEBVAjZy20W3RrVU3xWVajRa9WpPMoc8zP/flsFxCzQIkNF8nJCiv4y2ULljC9hjeotpYjfe2v4ePd38MQB7n5bHjuCGAuCdcIYWF/LTd+OXoGNFL08atwEG5a9bswRigKVP46YYKoPPngdmzG7YPe1ALkHdgwSIkqQTa8Drb9YQTQPPnAxUVwDPP8OsrWYAAYPJk8r+qSmw9feEF8XaA8rXD3jXj22+BrCzn13eVH38E5s4l5RxO23kOpgKI0uSY3nM6AGBL1hY/j0Qd6UU/IwNITCQlZDi4hoS+QukCxD2de9MC9H+b/g9z1vJlaDu27OiV4zjrAhMadZSKQgpTinVa58103+EWnF1+QJZfrtYk1R127+anGyqAAOC775xfd9Eikn3nLDQGyDvszy6DPrUIEelZ6NePzJPVoRJYNpWCoAH+91NdLT7///qLdI/v25ef56oFSAlPCqDp04E33yRWdHvnARVAlCbH4GSiHIpqivw8EnWkFqBbbyXui6FDiQiaMIHUjvElSu0PuJuTNy1Avx37TfR6THvPVvMeU787ZyxAf/1FLG/r1qmvYzAAjz5KpnsnR6uvKIGFBt+xtwK//gqc4nvXedIFJsQTAshZ/v2XWIuGD3d+G+4mTC1A7mMwW/Dd9vPYdppc68prTLhUUYe77yaBydz3IRX/QgGkFAQN8L+fykr59i+/LBmHixYgJbz1kGWv4zsVQJQmR6uwVgCAwppCB2v6D+lFXyiIbryRVOGNE2f2e5333pPP84UFSFjr58EBD3o8cL0V+Tk49UR69dXApUskjd4eH3xA/ofotIjU68CywIqVpOu6Q6xWYPly213Fky4wIb4UQIcPu74NJ7i9XW+oKXMopxxFVUbsPFeC/04V4aut53A8vxJt2gA9eohjgISWTWcsQFwgdFmZ/PyXFlYMNAuQEHsihwogSpODqwlUZaxCnTkw20xLLUC+7PmlRp8+wOuvi+dxQdnetAAJBdDQlKEN3p/UkhUbS/57q8s5wzA4fhzI3EN0jdo4bMS3Jnedf/8FQFxg648V4FKlZz9kXwogd34fVAA1HKH3dPd5ebVD7gZ/+jQpUs7BBTKbTPz3IBUD3LlfWioXQEUS47onLEClpeLgbE+hsXPHpwKI0uSI0cdAy5CraqC6wTQa8YkZKHEQwqfAtDTgyivJtDcsQGV1ZThaeNTW+uKRQY/Y4rcagvBi/Mor/HvylogL1mkU3WuqAmhSfXTpvr22iOqDOeX4JTOnQePYu1f8OtAFEHcTDsSaWE0F4Q3+4kV+mrMACV1bagKovFxeEsEZAcTNCwtz7rf4669Aly7Axo2O1/UUVABRmhwMw6B9C1IMYm/eXgdr+w/hhT8QLECAWIgtW8b7z70hHgYsGoAen/RAWV0ZUqJS8O6Ed8F4IPdfONYnn+Qz2d5+2zvF/a7oqZyuZzYDYBXeT9u2QNcu5PF9+e+AlSglg8mKGqM8Td5Znn5a/NpfAsjZ+lXUAtRwHJ0uajd4TgBxDzYMI38IE5bkkHaFVxJAFot4PieAkpP5ekTO8PXXzq/bUKgAojRJxqWRfmEbzm3w80jUCUQBJBxHSAgvHjxtATJZTKI6TVd1vgoaxrlTU9WyUg93M9ZqyWcsbPL60UfOj9HZG3lcRAg6h8fbXnOWjdhQO33MJk0CQvVAbh6wdZtt9rqjBerbOEAa7Omp35SjzxsQCyAuu8vZ/VIB5D3CwgBoLYhIvwB9Kh8TKbUAhYXJxVRQEC+ihbWvALkAqq0FRo4E4uOBY8fIPO43odc7rkAutIY783vzFFQAUZokGYkZAIAjhUf8PBJ1hIHQgeICE44jJMR77qOscnERkCEpQ5zabt488jR55oz6OsILLyAWAhUVzo9x61bn1xXCCQDWaufxPDIKmDARGtYK7NwJWIhqyi1z/4OWWnw8VUhTerNTQvj7cDbWigqghuPoKw4PB0ISyqENNyAkiY9c5kS6WgA0B2cFKpDocqlI+eorcr5YrcC2bSSZ4NtvybKQEOD990lNHq64ohTh8b0hgF54gRhe8/KAzZuVj9tcoAKoGcAVSlx3dh3K68r9PBplAtECJBVAQguQJ1tzcG0vOJyt/fP660TE3HWX+jpSASS0ALkiClxJ6Rbe9DkBZLUwKs3C6unVE0xUFMnPP3LUNj7V6tIOEAogVyxdjjhxgmQlKlXF5igr46edFUA0Bsh1LFYWmRdKUVxFPmRHv+eQEBb6SLmiMJnIvoQp8GTfJbZ9A+oCSMrPP/PT8+eTchLff09e6/VATAzwySfAEJXnHOE56g0B9OKLpLBnQgLQQdBAoDmWYKACqBnQLY6vFH3ZV5e5fVPxJlJ3UyAgFEDBwfwTktVq/wboKitPrhS9TmuR5tL2Gzeqx/PYswB98IH4Zu0phC5CW4E/xv7VdVzPNrBk9CcvDpEK0bVGC97/5xT+PpRnZ0tlhAJowACXN1dl5EjgiivkMUZCDh3ip++/37n9UguQ62ReKMXmk4X4dvsFp9Y3W1mERsgVRW1kMT5cfwpbzxFlExoKrD2Sj80ni/D3Yd7f5awAEiJdV3j+qX3X3hZAAC8W27QB/v6bWKqaI1QANQNahPIRfEcKj+DQpUN21vYPXbrw04HiAhM+UYaEiONK3Gkloca6s3ylwXFp42y1m1zhxAnl+VIBJIwvqK0FXnrJ5UMB4FsDSMnKEmeu9I1LRFqrcHRvGScyABlyYwAQMXnL4HbomRQN9CJtW3DmLFDNp9qcyHc9WlvYfLKPm71kf/2VCKlrr5Uve/dd4Nw5+XyLBTh+nH/9yy/OHYsKINfJLZMG49k3AVmsLELCFSxALUh7+BOF5dBF1yAqCjhe/5srquQtQFzwslP1rVQQihu171o43xcxQFdcoW6NaupQAdQMOZDfgDPYS3Bl6oHAsQAJCQkhwowbmyvxM/YorinGiWKiXoqfLMbaW9e6lf2lZr6WCiCp5UoYA+AKf/whfr1wIelq0a6duLtFNCJwTXoSYNGKssCsBjJgs5mkzgMgd5ikJOJfPCJukeHIarl3r9j1wL3vp58W33Rc4brrSNrzb7+RSsJSPvlEPq+mhndnuQIVQK7D2vWpyrFYWQSHyLfhrJQGA6AJM9iaHksRZoI5i/RUdsYCJBQ9vgyCbo5QAdRM2HXXLtv0T0d+wg8HfwgoV1ibNvx0oAgg4cfDjYm7OHoqhZyL/0mJSkHLUBfyYyWoxZpIBZD05ty7t3vHk16877uPVNuVUlJC7gBcO43wcPLEyd27iooAjfAmwVmBDomtlI5aZGRkANOm8f3KOIsYV8HXXTiLmZJV0qpQtNrZHmtSaAyQ60g/f0fPDWYrixDR90h+U9xDgdEIgIFMAFUZzLhUWecRAeSMBUgoerxZdJVCBVCzYUDSAHwwkfQsWHlqJW75/Ra8ufVNP4+KJyaGnw4UF5gQ7kYYFUX+e8oCdL7sPADYajW5QoKg5I4jAcQJOKkAUrqJA8BDD7k8HEW4minvvkv+BwcDAwcCQ4eQO8Py5QArrA/UowfxZOTkiD5ks9pAJRw5QgI8168nr921/khR+k0qPT+4K4CoBch1hB+/0Wx1mAVmsbJoLwyv0/B7sFjIOcQwrEwALdp8Fj/syEJojLimwU03OR6jOxYg4U+dCiDvQgVQMyI5Kln0+pM9CjZ8PyF8Ug8UC5ASnrYAcQIoNSbV5W2FN2A1AcTFKnHxS9JUV6VYJqsV+Phj9eMKM0cckZtLCjxzn9ekSfX7SCN3hsJCYOmPgg0iInhzoCDIxmRWtwAJbxgMY+uqAcBzvyWl/ShpMmH7D1egAsg5sktq8Pu+HJRWG2EVnAALNpx2WDbBbLWKXZkMv31dHWcBkgsgDlYv3r8zFiF3LEDCc5kKIO9CBVAzQiqAakxuPq56Ac6yAgSOAEpTSMbytAWIq83UsYVzqe9ChEX21C6UXNl+7qI+c6Z4uZIA4oq3qbFpk3PjA4grihtDUBDQsf5tCr/jLZsld4m0eoV1li9wZLJjARK2Jrh0SbxvT2XrOeMC27IFmDPHvf1TAeQcv2Tm4HxRDf45ViBz4Z8ssP9UYrGyiAjnA/gZgQWopqZeeGhY0bVIiNQ9KbRaA85dt5yxAJULKpVQAeRdqABqRiRFJoleF9UU4VTxKZW15ZgsHsz9liCsQhooLrBevYClS8VFAD1pAcouz8b3B0mBkH6J/RysLUd4c1ezAHHj5NLCQ0OJcOKKsCkJIEfdzJOS7C8XPlmvWsX3XdJpeaEjLAwdJI17Sat3B549C87RcaG4WjVmTShG580DSgR9MKWdut1F6Tf50UdiEbR7t/v7pzFArlFlMLtci8tsYWFlBZ+x4HdaXQ0YjMRio2YBste0uVMnkgEpO6bE5WzPApSQIO/9RQWQd6ECqBmRGJkom/fG1jec2nbm7zMR9XoUJnw/AU+sfQJmq/t9mpQQumYCxQIEkMDaoYKm7Fwq7H//NXzfPx35yTY9OHmwy9u7IoCEF/WgIOC228i0OxYgR7AW/rJisQAHSVkf0QU/RHAj0OkkFqCUFKKKqqqJSQfA5pNF2H6mWPF4UjH6xRf8tKfqHKmJcs5Lt3Ur8Pjj8uU9ezq3f2oBco0grUbU/Z2HJT+4lStJBcKVK4A9e4CqKmSX1MBitdoyJoUWoKoqwGQEgmKrEB6urKy0uvr5WgsAViR+jUblTEypgBFe26QWpI0bSZ0pe9tTPAsVQM0IDaPBjT1uFM3748QfqDbaL2pzvuw8vjv4HerMdVh7Zi3e2f4Ofjr8k91tXCUlhZ8OFAuQElddRf57okvzntw9AIDHBj+G2LBYl7cXusDUBJDUBcbBxQQpCSClJ1kllFLDAcBUFAlzaTjqzscB4Ovi6PJI1PbILq1ET8Iyq4dWR/LpAeAsHwe081yJYoNUqTuSE1yA84UIHaH2m+RuelOmKC93NiWeCiAxpdVGfLHlLPZnl9nmmS284gjSMvI0eKuF1Gf4/XcievbuBfZkEjH0wQfY+dmPyM0pElmAtFEkDKCqCjBbAE2oEZdQAiWCggAm2ITowWcQ0SdbJIAqKpyrpCz83SdKnke57W+/nZ9HBZB3oQKombH4msV4f8L72Hz7ZqREpaCopghP//M0Zv0xC1XGKsVtcipyZPNOFp/06LhatSI1aXbt8lzfJm/QvTv57wkXGCeAJnWa5PK2FovjIOhdu0jfIUBdACm9D2ebvQrr7ohgGVQfTUKbYGIu41LSQ0zheHhMJ/Rr20J0IwiSWoAA2NJ1zp0VzT5bKFdsavFYDz9MXBOeQE0AcQJHzdLkqgCiLjDCxpOXUFlnxobjl1BeQ0ydRoEAYlmILUAlxcB33wH7D5C6CgMHAqNGAcOGAUltyBexfQfw4QcIOrIfAKBrUYOIXuTaVlXNfwfFFuUflE5HLEQAoI2oEwmg8nLnBJDQAiTM4uT2DwCffkr6iQFUAHkbero1M8KCwvDI4EcAAGPTxmLx/sX4eDdJ+dFqtPj8qs9F6+dU5OCjXfJmSoU1hbJ5DcWVflP+ghMSVcpa0WlKa0txppQE+XLNal1B2mVcSQANG8ZPS5uDtmtHrA0XLxKXVze+W4rtovvKK6SU/4cfKo/BkVBNSSEuIq57dkgIoK0v+iOM+1G86XNxQOcvkCd7DTGNGMzyYGi1qtwZrn+sqqgJIM4NqSZ0nBVA3HrUAkQQaB38vi8Ht1/WXpQJaDBbSf0oiwXYuAHYtp0EZOl0JMCtc2fB3kaTjsEbNwI5F6E7tBdAOkLbizvCcwJIFpNWj1YyX/gAYrU6J4CEmWOtJAXfufNArwfGjyfTVAB5F69agEpKSjBjxgxERUUhJiYGs2bNQpWDO0ddXR0eeOABxMbGIiIiAlOnTkWBpKFKVlYWJk+ejLCwMLRu3Rpz586FWXCl2bhxIxiGkf3lc1diCgCge6vuoteL9i7Cdwe+w+J9i1FSWwKWZTFw0UD8fET+qO8NAdQY4IREXR2pfuwsxTXi+JU1Z9YAALrEdhG1KnEWaXaTkgASriO1ALVuzffIktQctFmAkpPlcS3p6eLXU6eqjzGOeMBsHdSFT7/C2j+cBWhc93hE6nVIjQsD4hOAsPqIbS6KGiSTR4raTWL6dPWxuYpaXJqawOHqOSpV8mVZ4ORJ8bbUBcZTUm0UNSEtVbAA1ZossFqspF/Jf/Wt1zt2BO69RyR+OraOAMAAHToCs2YBU6ZAx8hFtNUCWGyB6MrKXif5bqxW/vvSaMRtZpTo0oVUF+eQftdCASWs3E6rQXsPrwqgGTNm4MiRI1i3bh1WrFiBzZs3Y/bs2Xa3eeyxx/DXX39h2bJl2LRpE3Jzc3Gd4FdjsVgwefJkGI1GbNu2Dd988w2+/vprPP/887J9nThxAnl5eba/1q1be/w9NmY6x3aWzZu5fCbu/PNODP1yKDQvaZBXpdyIsrC6eQsggFQ/doYlh5Yg7q04fLCDFKI0WUyY9ecsAMDUbnYUhB2kAsjRk6JSZgtXe0nq8hJWjxbe+F97ja+0zLFoETHZCy1IHNwTbmH9T0VoRVEKM+2ZFI27hqehV1I0MS+lctlgfBwQVxAxp7QGP+/OxqXKOtt4uRR7AHjiCc/GkjmyAAlp25YXx0oCadkycjMUpsxTAURgWRbfbDuPGqP4rn8wpwwmgQCqq6iC5ZtviPlSqwVuvBGYMQOIjRNtN6pLK3SO5378DNCnD4LGXy47rtXq2A2pDRL/aq1W4rbv1Qv45x/77+v++4H9++XnobBXnPC4QhcxtQJ5D68JoGPHjmH16tX44osvMGjQIAwbNgwfffQRli5ditzcXMVtysvL8eWXX+Ldd9/F6NGjkZGRgcWLF2Pbtm3YUX/lXbt2LY4ePYrvv/8e6enpuOKKKzB//nwsWLAARolfoHXr1khISLD9aRxJ9GaGveJ7XH8qIYuvWYx/biVnekG1Cy2RmxDuZKjdu+JeAMCjax7Fhzs/xDVLr7HVYLqhxw1ujcMZF5gQqQsM4DPv1ARQaKj4/aamyt9/ixbAvffK4xkA3gLEiQTRtpL4pfx8XixoON8aV4hJUBDRXN8SY9meHFwsq8XfB/Ns7z1WEEcubFzrCRzFAAmZNIkXl0rL7yU/B3wk8CzTGCCg1mhBtVHZ3LH+2CVcKK4hqVo7d4L94ANUHT1J4n2mTlVW4AAYhoFWctnXdZFX8nRGAGk0cgE0dCgJur9coqmk7uGwMOWq5MJzQmgBEs6nAsh7eE0RbN++HTExMejfv79t3tixY6HRaLBz507FbTIzM2EymTB27FjbvK5du6Jt27bYvn27bb+9evVCfHy8bZ0JEyagoqICR44cEe0vPT0diYmJGDduHLYKi7lQAADtotu5tH7v+N5IiSbpWseLjuPyby5HeV25g60o8RH8b/WR1Y9g1elVAAAto0WfePdalTvjAhOiZAHiai9JBRD3WmoBshfzoyQMpTEOwnV6CvqGbd9OMmI4l5VMAGVn2brDS11gtSar7QYhFECeFhKuCCC9nj++0nKl2kTNPQbIZLFi4aYzWLT5rOo6RWs3AO+9B6xeDdQZYElIAG67XVX8AKSrikXi8VKK1bFWVMJsiwFS/qFrtHIBpIawrhmg/nvUqcTC6XT8ayqAvIfXBFB+fr7M5aTT6dCyZUvVWJz8/HwEBwcjRlIgIT4+3rZNfn6+SPxwy7llAJCYmIiFCxfi119/xa+//oqUlBSMGjUKe/fuVR2vwWBARUWF6K+pE62PRgs9iT9ZPWM13h73NvQ65eZJY9qPQd+Evmgb3dY2b+P5jUj9INUXQ23USCtwc7w57k23Or8DcguQo4rHSgLIkQXIFQGkJBCkjUiF+0ppy2fUcc9Dv/5K/ttuMy1akKqLVhY4RKozmiR3s2Cdxib+uBpN3sAVF1hwsH0BpERzd4EVVtpR8KwVWLkShk8+BWrryO9i8mTgnnuIv9EOGoaRlU5QEiMXzllRSEpOocqsPBaNTiyA7BVilKa4uyqAAN5iRAWQ93BZAD399NOKAcbCv+Nc4Q8/0aVLF9xzzz3IyMjA0KFD8dVXX2Ho0KF47733VLd57bXXEB0dbftLERamacKceugULjx6ARM6TsDjQx9HzTM1GJI8xLbc+rwVuXNyseaWNWAYRiaQyurKfDzixkdprXI5YlctcBwvvyzvvO5IANlzgUkbeApdYEKvsSsWoNGjgfbt1dcJ1mmQrKwLRfVe0KcPMi4eIwEUYFFrsqDOxLtJgnUa23g97fYS4sgCJDx2SQkVQK5SWWfng1q/HtizBwZtMDDsMuChB4H+/QHG8e2LYUjVaCFBQQBrEn/Q5Syv1tW+A0bDQthx9Z571I8rLYCptk+h6JGuw52fTVYA/fADMH8+ydDzEy4LoMcffxzHjh2z+5eWloaEhARcqq/iymE2m1FSUoIEpYABAAkJCTAajSiTFNUoKCiwbZOQkCDLCuNeq+0XAAYOHIjTp0+rLp83bx7Ky8ttf9nZ2arrNiViw2JFVh2GYfDS5S8BAIamDAXDMEiMTIRWw5+dc4fOFe3DYm1eaQpcjQ6Gsf8UCJB2IwcLSGW+zNmZeG8CL8L7t+mvtpldnnvOuxYgoQvMWYQCYdkycs+SjlG4TvvYcMTrI23FEoWYhW6unj0xIvcIxmz7C8jLx9nCany6kb9ghmg1IouVt5AKIO6z4z534e8gJ8d5AcR91s2xFcaJ/Er8vi8HtUYLiqpULEAHDwJbtwEALt06Cxgz1inhI4QrvWB7rQWsBvUPWjUGSOACu/56oHNn9ZO/Qwfxb0Ztn0J3nFrj1CYrgD76CHj+eWDNGr8NwWUB1KpVK3Tt2tXuX3BwMIYMGYKysjJkZmbatv33339htVoxaNAgxX1nZGQgKCgI69evt807ceIEsrKyMGQIsUoMGTIEhw4dEomrdevWISoqCt27d5ftk2P//v1IlNolBYSEhCAqKkr011wZmzYWh+47hN+n/a64/M1xb6L2f/xds9LoodbojQQu9ZtlHV+cPt39KViwSI1JRc/WPXF99+sRHRKNa7tei3Yx7lmAlHBHAKnFAKkJiji5VrEhtO5w1qbLLlNfR6Nh0LtFIgwXxX4roxEIDRI8CoeGAtdeC63VAhzYLzuu0AXmzRYqwvgiblgAL1yEYi8oSC6Adu8GHniAWIeENzouHqi5WYAsVhZ/H8rD+aIabP//9t47zqkq//9/3fRM773CAEPvvQsIiiiKBSkr1dUfiIAfXd1d3f3uurKua1ldV0QXKwji2rBQpIvUgUHKMLSBGWaYoUyvaff3x8lN7r25aTPJJJmc5+ORR5KbW05ubu593Xe9dBMllRKNmUtLgc3fktejRwN9+ri9HRnD4PYeSYgLEypYVm9fAMkl0uQBoPBWFdTJ1QDIMS7dioOgVluLmQP2BdCQIeRr3XWX7WcdWgAVFxPfN8MIawO0M16LAerevTumTJmCxYsX4/Dhw9i/fz+WLl2KmTNnIiUlBQBQWlqK3NxcHD58GAAQGRmJhQsXYuXKldi1axfy8vIwf/58DB8+HMOGkV5Jt99+O3r06IG5c+fixIkT2Lp1K/74xz9iyZIlUJvPgG+88Qa++eYbXLhwAadOncLy5cuxc+dOLFmyxFtft8PRK6EXEkLtlw3QKDRQy8n+DrZA6LAw652byBgpoLalFmuOkcKSfx3/V6jkKqRFpKHi/yrw+QP2yii3DmcCSEocuBIDBAAffQQ88wwprGsP/t0uJ4BCQoDly+2PQcqtVF0NZMaKIkjnzoXSZAROn7aJPGUY4Xi52kYPPWR/rK1BfO/E7RuDgQhhTuh07gy8+qr1gseyZMhDhgD/+Q/w6KNCaxFXxDHYBNCOAusfp77FaOsCq68HNm4g/Sm6dbNNs3IRhgGSIjWYOzwLyZFWRf+3v8iRkQmMGGm7jPzYEUgVapDLAZmG/NEUCsDkwPw7apRrAkipJN7db7+1/Yw7xlytzB5QbNpEnkePlk4hbSe8anBdt24dli5digkTJkAmk2HGjBl4k1dWVq/Xo7CwEI28IITXX3/dMm9LSwsmT56M//znP5bP5XI5vvvuOzz++OMYPnw4QkND8cgjj+Avf/mLZR6dToennnoKpaWlCAkJQZ8+ffDTTz9hfCv/RBRpIjWRuN5wHdXN1ciE56wZ/o5MBuTmkgKCp06R9HAxLMti+H+H42rtVURrogX1ftSK1psq7J1zxa4W8XxS8TtSMUBGo1VMcZ//5jfOxyVlARK/FmfGiN8DxCKSkMAgOVKDazVmZTNpEuThYaRfwZXL1jYZZjgLkEZDmpJWVgKiPIk2I14f3wXGF59Hj5ImlzW8ewL+b/P118L1BKsAOl1mTTIxmVjbvl7btgJ19SSV8N57W90fh78U3xXWv7cc8yPI/1eM/NB+4NppoFdvICICCA0BUlIh5+XTK5VEALEsi62nKxCuUWBkThzOnwdO/MqiKbEUEZ21wE/EdOjod7X31TqsBYhlyV0V4Pk7FTfxqgCKiYnB+vXr7X6elZUFVnSm1mg0ePvtt/H222/bXS4zMxM//PCD3c+feeYZPPPMM+4PmOIWkWoigGpagssCBBCz9cmTpO9iXBwwTNTM/UTFCZy5cQYAsPH+jdAqtRJrcR9x24fISHKxFVuA+IaSjXb61kpZgPgn29bGAPFFj5b3tcVuOKnAbK5Q/Nhu8dhwuARDs2MAlQrK2ycBp26QKxZPALEsETwAEWFKpefFD2A/BshgELq/uPn4d/xSFZ85uN8zGGOAOIwmVuhOKioiWX8MQ8RPG3yb/CxLvsUmVE12tFRpOLmcAYpLyINDo4YsdyCASQDI78SyJHut4BoRcyM6xyInhwEbXo/vfm1Ec1QjgFjL/O7C/f9eeAEYMwZYupTUmPJkhXOfcPAgOXlqNMDDD/t0KLQyIKXVRGmiAASfCwywxoSsXQsMHw6IY+b/tu9vAICpXaZiUudJlul5ecCCBcA16QLbThE3L+WSFcUCiH/RnTxZel2cBYZvAfKmABILHqm4JE4gJEdq8cRtORiRQwKP5HdPIx+cKbD2LAAREF99ZTuGtlDZoJMMyuXHadizAHGuUXsCSEywWYB0BhOKbwnjfYwsaxUnRiPwo/nmdvAgW98jgLhwqyBKjHB8kPLjn/VGqwBSK8ilT1IA/XYhMGkSkJVJLFBaDdDcAnm+tYyKTEYElZ6n3LgaVdwkfuuM1ggg7hg7fpzc6Hzyic/1gmdYY+43+dBDwuZoPiAI7zconiJSQ1JHg9ECpBUZdC5csIqRFkMLvj/3PQDguVHPCeYbO5Zc9IqKgF273N+uWAClpRGjiCMBZO/kywkQ/jo5a5BS6d7FmG/G5wsbvpvLFQuQYNx8l8PwYUD458QtcvGSpd/TsePWC5CjeCxXMZlIKwYA+P/Gd4aadxX773+t1iVO6FRWkumWMSuEz4BjFwYnPoPFArT5RBmKRQHPLMtaXbYHDwI3bhK3k0TIQqf4UDS0WM1oWpWtgpExjGR8Dr+GlMqRAIqKICWeR4wwD9AEXCoC8+1WwOy5kzU32gRB640sFHJYKk+T35IFwLTJAtShqKmxmqSdtMVqDzr43837GI1G6J1FoHZQssOykRmaiebmZjQHmKNaqVRC3obbbfHJiX+CO1FxAk2GJsSHxGNE+gjBfNwd/+7drduulAACWieAuERHfs3P1qaUP/IIKdKbmCgUPe5agOxZS+RKBSl+dPAQcOqkRQBd5Vne5s1zb8xSGHkXziadUSCAEhKsbSy4/qxPPmldVqWyCkH+hZVrBisFdzxwv4HUPulIiMUPQCo1sywL1NYAe/eQiZNuBzS2buNBWTHYU2jtQ6iVaN0+IDMKRy+T9Dq+CywlSovqRj1UChmUcvsCyOb/wsiAzp2h/u08dPpXCUw6AyK2/4yiUakIDbOOUWc0QQu5pZK5QgGokquhuxbdKgGk0OohCzHB1OjF9Mb25tNPyV1Wz57EdO5jqABqJSzLory83KZmUTDxm4zf4L6k+xCljEIRr19ToBAVFYWkpKRWVWMWCwS+ljp/6zwAoEd8j1ZXeraHOAaIq37sKQtQawVQTg5w5YpQBABC0eOKBche52uFXEaqyx08BJwtJD2hlCrozd/zuedIBlZb4RsOpNKc33mHPEtl7oprusjl5PtwcU18uM8aGoTzBGP1DeICA7BtG6DTk+rOfaVT3uWi/1NKlMYSg8MRopK+sRnTJR5hagW6J0egwVwcUUoA2f3LhoRg7oI64P2PgUsG/PSXfyN87mzLx1wBT4YngLSdbsBQGQa5XKL/hhNuxhYhvD9Qe7iTw7T9gOL998nzb3/b6sB2T9JB9mr7w4mfhIQEhISEePxCFwho6jS41XQLcSFxSArzXSqju7Asi8bGRkstKUf1oewhdoHp9eTiyTDAxSpSqC8nJsdmOY2mbVkd4uKCXEVoRwLIXg9gKQsQ5wITfz9XENfKAaztvAD3XWB8FDKGtMWIjgKqqoFz54GePS3NUaX6O7UYjLh8sxHZcaEWl4cz+K4TR2nOUtsTxyCpVGR/iq12APkqxcVEAPEFUjAKIJOJhenGDeD0GfIHuvNOCPO3rIiLGvZIjsCOAmtNuLgwFVR2LLtalRwjzTFlXDVxt/tjJyYCs2YBn30GnL+Aug8+Ip3oNVpLjBF33HDDYFQGqNXSAshkYiGT2X5XlmUtNy4yrQ4mAKyJAYxy6PXSx5/f8+uvJOdfpQJmz3Y6e3tABVArMBqNFvETK3XWDxLUejWgBxgFA02AOay15iv89evXkZCQ4LY7TPx1x40jLu133wUKbhYAkBZAbdXJXMp3YiKwbx9wnhibBAKoqcmahq1Q2N8mJ0jq663izdNVlXN4u0As0qTS4O0KIDkDgAF69gJ+/pm4wZwIoC2nSOXoronhmNrHNZErtADZF0BSVjXxGCIjyW9x44btvHwBxAlQlapjxH3cqGtBfkk1hnaKQYTG+ZXaaGLBcn0au3RxmManEIkFhVyGST0Ssf0MCQDLSQh36T+WHKlBv4woFBdXO59ZTHY2uYCvXw8UXSYp3Y88YokxunyTmGm5Y4RRGRAeDlQ16PBTQQWGZMcgMzYUlQ06bDhSjN6pkegcH4bkSI3lRtpgYiHnHWMRQ0iT2Jr9XdHQQEotBBxc7Z+pU73buM8NaBZYK+BifkKkzuBBhJwhosFgcrHhkZ/B/X6tieGSspBwyQ2HrpLunq1tdeEITgB17kyuFdxFl/sKRiPQuzepHQc4DqrlBBDLWl1rR46QZ09diMPCrNsRF/KVulDZtwCZT1W9zU2Wzl8Ampstnb5VKuDC9XqsP1Rsaax56Qb5UucqHFcqb9IZsfFIMfaeu4FDRbcs0x11+5YSXGLLHpfgIhWczQnD8nLiOgQC3/pTUtmIb/JL8enBKzhVWoMffnUt1dHY0kJSnQBg4ECH88rlDIZ3JjedvVJJEoYgAF/j2j09wzAY3y0B0SFWs13jeTfqJ2RmAgvmA2GhQHkF8NkGbDlegut1zfj1Krn7sFhwlEaEhQHbCypwtaoJXx4jAWT7zt9Ai96Eo5ersPFICc6WW49TvdEElfkYk6n5fwpW0qUaEHxPEkNw772+HQcPKoDaQDC6vfgoZOQf3mJ00MnZj2nL72dPINS21KKomsRDSQmgth4ynAuMK40iFkBnzgh7CzoSQHz9vmYNsGUL8NRT5L0nLRHFxcDly67V57EbBM3d+SckkNRkoxE4exaNbDMgN0KpJBlGFbXN+PTgFeRdqXR5fIeKbqGsuhl5V6pwnGcRMDhQQFL7tVYYiuJQAA0YQJ5PnCDFcIHAD4D+Iu+qRXQCwHVHHd556I7nky7vUVFAF6IMlXI7LjCGQXZcKH47thMmdieV6mW8P5WrAohjQEo8WIMMjWeTAdbNP2diEjBnLvkzFhej/qNPsemXS5aP+RagsDCgXlTt2igKMuMfe3oDKy2IGdYmDjAgqKiwilx7dTl8ABVAlFYToY4AAwaN+kYcLTuKBl3b/pk3Gm7gRPkJNOmtlfkMJgMuVV3CtbprNkUzfYk9gXC5+jIAIFYba6mTZI/WfB1x3yvuJMsJoGPHhPM7EkB8MXbpEvDss9b3rYkBskdUlLAtgCOcd09nrK228/NhNAHqlGobi8zecw5Sr0Q066WFjvgCxceVrB4pAZSSArz0kjV26+efrZ95Io3fn3DkQrTCQn/QbHbkdXiPClFhxaSuGNctXjA3J4RDVArLDYxQALkXHJMeFYraQ52hvxUu1f0CgJOblsREYNbDgEoJXLoE3afrSIA+eBYgNRFA4lAf8fHF3186o8l6w8Dw5mOsMWOlpUD//sB77zn5kv4A19+zXz9yE+MnUAFEaTVKuVLQL+xKzZU2re9KzRXoTXqU1Fpzm69UX0FlUyVK60r9qumqPYFwqoLE/2RFZUl+zr8mOL/Y2yIWQGILkDjl2tmFeulS62t+BpavYlHsZYEJ6NePXE2uXIGpsQWMwuRWUCjLsjhXUWfJBLInrA0OBFBurvPtiAVQjx7kovXccySMRMzEic7XGUiId6ukoLx6FWx5OTlQB/S3TObcmH3TotA92WoKEccAAcL6PmFq9yxA5P/BrdO+1ckhGZnAnDnED1t0GVi3DtC1WIKgZRo96kyNqGoUutrFAtHEspZjUm80ISMD6D9AJMAY1hLf9+yzJKbYD8rpOIcrejZhgm/HIYIKIEqb4Gd/NeobUVJTAhPrIHjCDjqjNb2pxdCCkxUnUXCjAFXNVZbp1c3VbRqrJ7EnEN47SqriSQVAi2lNNhgngLiMI7EAEmeJORMG3M2YXi/MyvJVlolLojAiAsjtDgAwVtWA1cvdGu/Os9fx/a/XsPcciU62p3NMDgSQK+dxsQDijzEjw3Z+rj1SR4UvVCxwQWe9egFa25hKmYzBsE7WgFlxFhgAi2gA4HK2H4fgBsHOzy2VpWVDegYw1+wOu1JMssTMFcsVEU346eJVwewsy0K8O27V67Bm7yWcq6izBPdPuA3Q5lhNgwxD+uUBti5Xv2bvXvI8dqxvxyGCCiBKm1DKleib2NfyvqKhAtcbrjtYQprKJmvMRouxBS3GFjTohS61ZoP/FFu0J4DyK/IAAHP6zJH8nH/T1xoBZC8GiBMO4nhuZxYgvoDiCyBf1fZ0JICGdeJlXJqDZY21DWANMpfbYLAsawlSPVteh/oWg6D4oWAsDgRQfLzttCFDhO+5GCt+lheHeLy/+12AZva4QYtBdMVvqAdOnyavhwy2u1xUiApDs2MwpmucZNyeUiR6pESSK7B2YoBcXl9aGvCbuYBaBVy+AtOO3dYxigS6zmiC0U6M2Z7CG9CZ1ZHN/1fGWnrfOQrS9yvKy4Fz54h6GzXK16MRQAVQEHHjxg0kJSXhpZdeskz75ZdfoFKpsIPz0bYCpajIV12L+64qvgVIDGdl8icBFBoqMVHZiGodOTuNzhht8zHLCi00Gza4v11nLjCxBcgdAcS/tviqRY8jAcRl/wAgPqToKMtFwFUL0P4LtwTv39t7ya4LzFEMkFisvPyybWsTTiRzLgtHIi0g67q4iY0F6Hg+KQOdlgokpzhcdkROHAZmSqdO906NRI+UCEzrS9bRJYGklA/Kcn4QC/4vdn5uKbebXVJSgYdnAQoFVFetxWHFus1oYu0eX/UtBjTqyB/BtjpHAAog7o/Rt6/Pe3+JoQLIE3B5xL54uBFJGx8fj7Vr1+LPf/4zjh49irq6OsydOxdLly7FhDb6Zvsl9WvT8nqjtMlBzsgRF0KKl+mMula517yB5P84gsQuhanCEKG2TeEwGIQnrWXL3N+uPQHEncjbYgHii48kH9W1dDkuimGA/gNghFwyBkijFF45Nh0tQWl1E45cts0Osxes68gCpNEIrYB9+9rWNeLixKQsQICwh1hHF0Asy1qKD5onkM7AADDIvvXHFZRyGSb3TEJOAjFhKuQyzBySgdFdJMx0Inr1Au6/nxuTtNCRuZC6GRPK+3EzM4G5c5AuK8NA5GFS9xKb+fVGFkYHp26uuKNYADEMAk8A7TG3Nxk3zqfDkIIWQvQEjY3SZW3bg/p6O+YIae68804sXrwYs2fPxqBBgxAaGopVq1a1eRgKmQJdY7vi3K1zqNfVw2AyWNLkXUFcSyghNAFxIXFQypRQyBSQMTKYWBN0Rh00Ct9Xi5MUQJHkRJcekS5pqm/xQLUAcQwQd9FtaiLXlLYIIP6y8+e3faytwa3A8P79YNpZC03mTeDSdaCv9X4uRCVHi8FouT+4WtWEz4/YXogA+xcSey6K67XNKK1uQlRUFMrLye/MCVKOqgYdzhhLoUqKRlN5FABbAZSaan3dkQVQs96IL/KuWgKbAZBaDdXVREX27GGzDD/w2ZswDKnPN3Ys8MuJ1rvAbu+ZiG/yy9CkM4u8jEzIJ9+Ou7Z8B1xUAnVPCOoctOiNqG1y7me2OY0wLG6ZjZh83T58OLEou5pt2a7s20ee/Sz+B6AWoKDkn//8JwwGAzZt2oR169ZBLT57t5JwVThUchWMrNHtlHixAJIzcoQoQ6CUK8EwDFRycvWo1/lHFTDJGCCzBSgtIl1yGamYn6Ym22mOEMcAcedUk4msqy0uME58vPqqbdHC9sKlLDCOsHAYtUT8N278BOFq6+2ywcS6dOcOwG4MUJPOVgA1641Yd6gYuwtvoFpn/UHF4mZX4XUYZXpoO1+3Ow9f9HRkAXS6rFYofgDgMCkWin59AYXtlxenv3ubTz8FZswgLaoA4C5e5XCFqCbRnb2Tcf/ANMv7tGgtkiO1tq6yESNIXJBOT4oA8qzXF2608jzGsBaLIl+fHzwozOj0G65fJ8XJAL+L/wGoAPIMISHEEuOLRyuqUV+8eBFlZWUwmUy4fPmyx3YDwzAIUZLxuFMckWVZSwyQRqGBjJEhPlRU/8NcdbqsrsxDo/UCESTTIyXUdQEkrtvjDLELjP/z19V5xgXGVZH2Be6WBjCGEAWoOH8GRu5OE8R642rohsGOL+Jaja06rWywKkyTzKrWVCrhOowm1mbfi0VORxJANgKHh8UqwnG9glTyZhjbyHEzXLf29iI9HXj9Ncbi+uVbfUJV1h9yQvcEdEsKF2SGjTWLNRvBzTDAnXcAchlQWAj88AO4QKNDl1wv1Clcp7Vqu1i3+2Vfbu4/2asXEBfn27FIQF1gnoBh3HJD+RKdToc5c+bgoYceQrdu3bBo0SKcPHkSCR4qTsVZahwFNfMxmAxo1DfCyBrBgEGPuB5gwUIuEzq/Y7QxNllhfofZBZakdV0AHTgAjBzp+ibEAkgmI97X+noigDxhAXKlyJ+3cFcAmUCOEyX0MP64FUjNBFJTYTCxZjek8xg5exfvqkYdfjx5DfHhagzKIgG4/Hihp59m8bffkUq/P14qhiEm0pKpJmMYm/0o1SyVI1AFEMuS/bz5hP0bE5uK2r8cIM/duwPR0oHNrUzkahNxYVZLeBMvXolfW0itIMcbf3jh5kanYlcZy7IkuPu+GcAXm4CjeUBEpLX0dytgZNZWGGIB5HZj1/bAj+N/AGoBCjr+8Ic/oKamBm+++SZ+97vfoWvXrliwYIHH1q+Wk5OIKxlbLMui4EYBzt06BwDQKrWQyWQ24geApaqy3qj3m4rQAuMbYwI6/QQAyA7vLjk/XwCNGEGey8vd26Y4BgiwusGkLEDOToqBLoA4l5lqUD+SVfPZeqC6Ci16E3TitGs72AuCbmgx4mx5Hfadt1aX5M/68CzyrEm7hfoWAw5ctGaYMYztfuxoLrCDl27h3b2XUNOoR7NB2ndZUtko3L+1tcDJk+Q19yeQwBdthpRyGbomhkOlkCEjxvrnDuUJII2S/KH4VkNuWqRW+COO4oKwe/QAptxBXu/cCWzcAFRXoVXwWmE0q+oRklsGRtHKzvbtwe7d5HnMGJ8Owx7+uMsoXmL37t1444038MknnyAiIgIymQyffPIJ9u3bh3feeccj2+BcYI36RqfzNuobLa4yOSNHZqT9CD4u1Z4F26o6Q96Ac21D3gIsGAlEFwHNERideJfk/JwAysgA7jLPUummJZyLGeKLL04A3bplawFylikS6AKI+36Kp1fCmJoKNDQCGzd6rZAR/2LOyFj861/AnVNts+YYxlbUdDQBdODiLTTpjDhw6SYYO1WUv8i7KnQx/vIL+dEyM4VR4H7Cnb2T8OiYToKWGvHhVssQl12YGq1FekwIBmZGW8Ran7RIwbr6pUdZ3wwZAtx2G1EpZwuBt/9Dyji34mauvh6oadSjLqYMyth6qDOI8LZNmfcxN25Yxa6fWoCoCyyIGDdunE3n86ysLNRwhUo8ACeAdEYd9Ea9TY0gDqPJiIKbBZb3PeJ7QK2wH4wtY6xavaS2BIlhbnRu9hKZmUB0nB5VS3kR0Wfuh8woHZfFCSCNxppF5q4A4u7++B5XLgFx0iTb84yzoOJAF0Dc91OEa8DOfJh0dS2vAL7bbO467RlLgsnEQiZjBFWjTSyLZcuA7meAU6XC+RnYusA6cgyQI4ONpd7NzRvWys9jWu8G8iYMw1gasd7bPxUNOgMSIngCyOwCk8sYQSA0AHSKd5IJPHo06aHy4w+kZcY33wA7dwATJ7medWC2AK3dX2TRTjKln1qAfiIWcfTtK1051A/wt11GCXDkMrklTd2RFYifzZUYmuhQ/HDwa+sYTe6kC3mPevU54YTy/nYrPHPT1Wogxhz6UOWmJZwTQFIWIMBqceZwxwLEaeP2FECbN5POFlz6rltZYOB9P5mJrOj++0kAya8ngZ/3C+ZtbYVgAPjWHOPCd79y1iApl6yUC0ycOdhRBBDDMA5jdkg9JRbYspX8YN26AZ06t9v4WktWXCh6pkTyk7egVjq+ZEaFOPkh4+OBub8hLiG5HKirB77+yloR2wkMA9TXm4830WHndxagLVvI8+23+3YcDqACiOJxOCuQvd5d1+qu4Xzlecv72JBYyfnEdInpYjG1i9PmfYVeIawsjMJpduv9cNM1GqsA8oQFyNHFszUWoPa8GN91FxGBDzxA3rtjAeJX1t5/1RyEm5VljbfYsQM4dcoyf1tsQUU3G2AysQILkKNK0UDwCKAzZbV2XWCAeT8VniO1f+RyYLL/XhCl4IsetZu9xiRhGGD8eODZ35EGsCyAr74k+8fpsqzdLDC/sgCZTFYBdMcdvh2LA/xpl1E6CKFKcnW+0XgDTXrbVOLSOqu/IFIdaRFMziDmaXKl0Jt81KyKh86oA+54gry5kYvO314GajKdWoD4LrBTp8hFv7kZ+Oor5w0OG81GNb4AclQJwVksKXfhPXMGKDNriPZ2gclk1m26I4AaG60Cj3MDMgww78kHrOnV//ufORWXJZaIxgbyRQ3uHz+b8kqkLUC8ebgGqgzD2NRG5SpDczjqDRZo1LfY/+GKymuAreaL4YgRNplfuUlWE+aE7gmYNVSiU6wP0SjleHhIBuYOz/RscLZCSe4AevYkLUE2bnRuCZKx0BvIcW/5r3CuMH+6mh89SmoARUS4l+baztAYIIrHCVNZz/x1ujpoldYzv7iVRbg6HO6gkCmgM+rQqG8UbMcXfHLiEyDpVwBAXKwCYUbix7FnAeILoBjeNWDWLCAxEfj3v0kcz7Zt9rcpZQFyVLfHWXUGvuWBu7b7IgaoNQJIplMB0CE0VGj+jw5VIenBe1DOADh0mGTenC0AQkKBS5fI3alCAWSkA336Al27SHYiF1NW3Yy+6Va5I9XcfPXei+iaEA4GRJjefTfwidkT58gC5HfuC09y4ABQVQ1EhEsWw+OLij5pUe03LjdIinSt+rwzeXRbbgJ2nuUlcTAyEqum0wHnzwNffAGUFAMTJgBKCVXMsBYr0LVrwo/8SgD9+CN5njTJr9U9FUAUjxOqCoVCpoDBZECLQagGLlZazbwqucrS58tVuDvw4ppiJIR6pnZRazl49aDldZW8AJ3N50hXLEB8AbRpk/UcsX27/e3t2QOUmLs58K0+f/gD8Mor0su4I4A4fCGAOAHgTAAxDNmPZw6FQqNmAOgQzuuawIm4EI2SuMLi4oEtPwJlvKuFRg00twCXisgDAGJjSHpeeoY1sCghHkhLF5jR+Kn1nAuM74po0ZtwsrQGXRKJOOcnOjkSQB2W2lprMbxJt0teDH1R86e9mNY3GZtPWI+9vulRQgEEkIP/4ZnAzl3Azz8T0X75CjBnNhAmvEFUKgEDw0qWz/ArAbR1K3meMsW343ACFUAUr5AanoorNVfQZLC6wEwmE2parBlnvRJ6CbK7XCFSE4mmerLOo2VHAZAiidlR2e1eO+RExQnL68UDFqNgJ3ntigAKFxm+XDl58TO8+MImMhJYuxaQKufkrEWdvwggcVd7e8wZlonHn6vF/96NhjbnOpQx0i5ALdcMddAgRPbtgZoz58lddkYGkJhAMsXOnQNOnCCBWLfMj+P5whXFRAPT7iaxRRAKIJZl0aQzouCard+Si4nh78ugFEC7dpEfNSMd6NVTchZf1PxpL3ISwiFjyoXlExiJ7HdGRqw+GRnAt98AFRXAunXAI/MwfrwGu3YBAwcC+ddYNPOqQQNWF6zfCKDaWuDwYfJ60iTfjsUJVABRvAKXCVbbUosGXQNCVaGCuJ2usV3dFj8AkBaRhurmakGhxcqmSiSGJiJU1b7VuC9VXQIAPD3iafxxzB/xwFtkujMXmFpte7ISv2dZx/E7YstOly6uzSdGahu+EEAaJ9YzjrgwNS7ujwerB7hTv1TXBK3K6lMaP6QLvhbHmSUlkceYMWSjJSVAcTFQepXEZjAArhQDlVXAxx8Dw4YCEyfhIq+Hk5Fl8cNJkR/CDLdfg1oAVVQAJ/LJ69snw56DKC7Mf10knkCtlAnagcgYxm7/OXTpQu5k/ruWiPSPP8boefPRtasSCQnAuZ1AFS8QGoCli73fCKBdu0iQUk6On3ZntUIFEMUr8Du2F9wsQLfYbhbBo5KrBCnt7tI1tituNd4SBFMX3CxATkyOpWK0t6ltqcWtJpIB9vyY5xGuDnd6Ea+oIM9SLXH4QuS228g6fv7Z/klNLGw6dZKez1mHE6mx+kIAcVacRvuVEyxw+4TVKwTvBevjCSCn9gWNhlx4xCpS1wL88COxEh04CNTXo2z6vZYNGk0siiulB8xtk78vxd4fftyPnxQ3dxmn1dhNRuDbb4lG7dnDYdHDvmlR0BlMyIh1v69hIHBPvxRsO12BMV25nmEAJ4c6xYfiRl0L6pp5vt/oGGDuXOCTj4Fr18Ds2Y0ksyVFG8IK+oHx8RsBxMX/TJ7s23G4gL/sMkoHQylXWhqYAiTzS28kFiCFrG1XWJVcheTwZAxKGYSU8BTL9AuVF9qtTUZeWR4AICE0wRLIzfXnsmcBunKFPEvdFPEv/Lt2kbjRGzfsb1+cURQVJXy/eDEwaBCJD3JE//6203xhmXBHAHFisfmKue+WxFmM69kESDSpdBWVGpg+neToy2TAyVPA119b1Iq9Fhr8MRIBxAqmiefxZwxGE365eNOmMazDCgAtzcD69STbTqMxW38IXCVlPjIZg6GdYpEcqbX5rCOQHKnFIyOykB1H7lr4Lr+7+6Zgwchs24USE4G77yGvDx4gRSRhFkBgUS8hgPxCRLMs8N135PXUqb4diwtQAUTxGn2T+lqsPvW6epTUkgjetgogPtGaaMH72hYneeQeYtOZTQCAu7vebZnmzALEZbhyAoif7SV18mppIZmkUnd7YrRa4QX1j38kRXdjnZRYkstJ9hkfX1iAOEHXZFs1wQa+BcjUpJQUQPxu4vaERrSzonUcPXrwRNBJ4NAhANJZYFZ4MUCMP1yZWsfxkmoculSJDYdLBNNtGpxyVFeTgLSLl4iSvn8GEBGBu/ulYNbQDMwfmWWZNTM2BAtHS1z8Ozj845FhGEFneYDnvu3alVSONrGWk4U2hAXDAOf59VfNh5eXur+4x8mTQGkpuaMZP97Xo3EKFUAUryFjZOgcba34yvX98qQA0iq16BLTxZISX9FQ4bF126OkpgTvHCW90ybnWO9uHVmAjh61CqDevcnzpEnCbDAxeXnkRtCee4uPuPJwZKT9ecVUVwvf+7sLTCBoGGkLkCtVnx8a7Ea9mdxca0G3HT8B1ytcsgDJZABk9ufjPEPDhrk+FE/Dr+FjMrG4WtVoCfa+VS9tzpTUP0WXgPffB67fAMLDgPnzgM45AACFjEFihEZgAUqN0iJC07ECoVwJ6JaySKZFkzuAPmmRUPGD2iZNIkFu5y8AF86bXWCs5DlGr7feSB0/bq3r1a5wxQ/HjbMNevNDqACieJUIdQTSIoQ9c7hCiZ4iUhOJrKgsAMQCJK415Gk+PvGx5fWQ1CGW144sQEdJwhrGjiWxgRxiVxaf++4jz9evA660a+NXfXaW/cVHLJYCSQAxchNkEjV0+AJIJmMwsbuwd1xKlEYQKM2hlDOCTuACBg0EuuQABiOwbh1MFfab8go0moNr4sWLRIBGR9ufxxvoDCY0643Iu1KJ9/ZewuEiUpI8/2o1Nh29iu9+5a6e1sGzLAu92ewlCOKtqgI+3wh8/AkxVyYmAosWAclW9zT/95jYPREZMSHolxHlte/nz0hp8/sHpmHJ+BxM6J4oFO8xMcDQoeT1tm1Qq412g9q+/57MvnYtiwEDfNRrlov/8ePqz3yoAAoi1qxZg5SUFJhEt2/33HMPFkjlUHsAhmFs6vVEaz1/tlfL1ZbUYy7WyFscKz8GAMiJyUFGpNWK4MgCdN7c+UMcc+PqTdKBA9bXf/2r9Dz8n9WdwnoLFggtEL4oytcaFxgAQG6StAApeBcRBkDvtEj0TLEG3k/pmQwA+O3YToIYjIGZMZghanJphQHuvY/0c6qtg/GVfwJGFyo3ml1gUhcktdo9a50n+PVqNd7edQHv7L6IveduAgD2X7iJRp0Bx4urAQBXbhElyhdvb/x0Hqt3X0R9i8HaBuTiRWD1aqDgLLmyDxkMLJgPRAi/lIL3I/VOi8SMgWmCOK1gQsoCxDAMVOY2GzZxjKPHkDuEGzfRYiyEPFz4J2GURoQPKoI6/RZq9c1YvroIqkTPNbh2maoqa82nO+9s/+23ApoF5gFY1rU7V28QEuJ6MOUDDzyAJ554Art27cKECRMAAJWVldiyZQt++OEHr41RxsgQogyxNEdVyT2f9sq1ydAZddCb9FDDeXPV1lBWV4YfzpN99d609wSfObIAFZnr7XUW9YB0ZAHi8//+n/X1kiWuLeMqISHA3r1Ar17EEuFPQdDXrpFuAY8+Cvz2t2SawAIkY526wLgLjkLOE0XmZUJUCkAlvZwkWi0pUPfuuzhf1wTs2CnZ7JEfJPz5JhbFF/2nI8COAmnL1d5zN2wuvuK9YTCx+GDPBYy6dRHYfcLq101NBe65G4iXTjtsSyPaQEKrkgMuxOw5wiCOMNdoSDzN998j9NolhHQR7mNlDNmgJuMWlLH1kKkM0OZUQFfRzsp62zZihu7RwzW/vR9ABZAHaGx0z+XgSerrndd64YiOjsYdd9yB9evXWwTQF198gbi4OIz3csBaclgyKhoqkBruPbusSq6CzqgjPbq8xM6inWg2NKN3Qm+MyRwj+MyRBajeXD5GfLfvqgA6aC067RUXlVJJ+oEBvslOsieAHnsMOHaMPHMCiC94TM1KyGS2Fj+BBcj8Us5b0F5mmFRNIRsiIoF7pgOffUZ+mL59iduHB79R6sRJLKLucWG9Pqa2yWDT4NUmpqW5GYZPP8Hu0jJuBmDwYFLET6WCXMZINolVBIkAur1HIradrsDALPtWbmdZifz9lx0XCo1SjsKBA2A6cgRh16sF8955JyC4d/VlwD0X/+Pn1Z/5UBdYkDF79mz873//Q4v5Kr1u3TrMnDkTMi8XkYjWRiM3Ltft3l/uwFmWxO032gLLsoKGrodLSYXTCdkTbAo5OrIAcdPEgsdVAcTHWxYaudx3PansucD277edl3/9aDiTgijY3n3IRC4wAJAzfKuQ9DhcTpnv2pXUt2FZYNtWCFuiCrOknHWN9yfEQxXsDaMBWL8OKC0jar9vH2DxIhLvYS5yZM/SI5cHhwCKClHhwcHp6Bxv/47YmRbkx1dN75+KKb2SMKJrAjD5doTCWohzzBiiPR94gLcwK1x5u6XGG40kCAkIiPR3DmoB8gAhIdY7fF9s2x2mTZsGlmXx/fffY/Dgwdi3bx9ef/117wyunQlRhqCyqdLiavME/7ft//DvI//GkcVH0CexD05dPwUA6JfUz2ZeRxYgfhsMPt4QQOnp7q/T13DHsV5PHtx3vHXLdl6BBahJjUxZCuSy8wKhIbQA2VbKtSd0lC6ZgMxMnAgUFpJ+YoXnBF1p+WMJIP1jk9km2E1btwIlVwGtBnjkESAxyWZ5hYyBlP01WCxAriBOexdjNNqxoHXqjLDE84A50ZX7j/ToYX9den079SI9fJgULouMBEaPbocNegZqAfIADEPcUL54uOuu0Gg0uO+++7Bu3Tp89tln6NatGwYMGOCdHdPOcNllDfo2OuF5vHbwNeiMOvRd3Re/VvyKszfPAgC6x3e3mZcTNxs32tbk4CwbnhBA9ox1a9aQKtNffun+On0NX8g7C4QWH/NKJTC6CymvzWUWySVcYHzRI15H3/RIxISq0DXRDQtlVDQwbDh5zcU/mDEIBFDgKCC+cKtr1lstCCdOAEfMqYz33ScpfgBAYUdABksMkCs4S5W3iQGCNYg8bEBXyzSlzHq8SVWXB+wXZfU4XPHDKVMCqscLtQAFIbNnz8Zdd92F06dPY86cOb4ejscIMfd70hl1aNQ3IkQZApZlPdZscfQHoy2FFrvFdrP5nC9uNm4E+LvWky4we19n8WKSfRwIFYbFqNXWJpFNTUCEg04p4u+nUAD90qOQHReKSC05+fKzjrjZkyKsP5DYAnRbbmLrjpXRo4D846Sh6uHDwHAiiEqrrCrOWy4wk4nF2fI6pEZrLd+7LbBgBWN9f585cr+iAvhuM3k9ZgyQY6fxHOxbehR+06fB97RGC3ICMrRXFmDONJeVlwIgWaiWn010qOm8Fw4phAtECiD3F0AtQEHJbbfdhpiYGBQWFmLWrFm+Ho7HkMvkFhF05sYZHC07iuPlx1HT3LqU0G8LvxW858RPclgyIjW2GRZ8U7M4mNeTLjBHBKL4Aci4uX3hLKNSfC1VKMhddVSIyiJg+BYHLqYiMzYEwzvHYnSXOElXlz3xExeuxsBMO0GtKjVwG0kowN49QKOt9dFbFqDjJVXYerocGw4Xe2X9AIgJ4YsvSO2jLl1IgTsH2LP0UAOQlRCJ+lPO4DIYtSHW41Z3+Ro4xWMpfSbaz+1iASopAfLzyZ84gAKgASqAghKZTIaysjKwLItOAZKu6Cr8ujwAYGJNlq7t7mAwGXDPBunUnW5xttYfQBgHJs72ai8BFMi4WgxRygIkhn8h5vQHwzAY1ikWg7IclN8W0Ts1EnOHZWJkjh0fAwD06wskJwHNLcCPWyC+DbfXNaKtnC2vAwA08jqN26OuWY9TpTUwGE1QK6VP+zfrReYCkxH46kvg5k0gIhyYfo9Tha20E+zsKStsR+C23ESkRGlwV59kl5eREpYtdToSkwVesLMoC6xdLECbSFsgjBxJamQFEFQAUToUYaowJIcJTyytqQxdXGP/rjo3NldyOj9g1yCqj2cvBsjdIPaOjKvFEKUsQDbz8K4XrbHAcC0bOsWTuDK5jMEge6nNjIzkI8sY4NQp4NBhwceNeheKJYooqWzE3nM3YHDQcKyxxbnw4dhwuATbz1Tg8OVKQTYcH679BcACZ88C771PgrsVcuChmUCI83obcurqckqkVomHBmegixvxZlKuxTSUAEfIscYd4vIQoeJpFwvQJ5+Q54cfboeNeRYaA0TpcMhFvRHE6equcP7WecvroieLkP0va7Xg3DhpAWQurQTA9iJuLwYoJQUUM560APEtDtEh7qfBzBuRhVsNLUiNsv5g9oQDACAtHZh0O8mU2rYNSEgAsskx8+PJcphMQI8UB4FNIr7II3f2oWqFXfcbv4eXM7h5i2422GalsSxwrQwoLgFqa4ALF0lGD0ACWh94wOUDlX+hXjymE4wmVtjbitIq+C7bJ58Eygsq0XXbeeC0DJg4CSwrfWx5XQAdOULcX0ol8NBDXt6Y56FHJqXDIa403Zrmq+criQC6p9s9yIrKwh9G/8HyWWZUpuQygwaRMAlAeBE3GKwJQmILUIYb/TiDhTfecPy5UWT4sFe76PFxnbFodLagAaeraFVypEWHCISUs/RlDBsK9O5FfF6fbwSuWxvzbj1d7vYYAKC22bNtXRgwYPkuuitXgH//m1h7tm4FDhwk4kelBEaOAJ54wnpQu4CwAjexdkj1XKM4h6+3+QIoKgrIHR4DZGWRY+3QQbv1frzuAnv7bfL84INAbKyXN+Z5vCaAKisrMXv2bERERCAqKgoLFy5EvZNiOc3NzViyZAliY2MRFhaGGTNmoKJC2N172bJlGDhwINRqNfr16ye5nl9//RWjR4+GRqNBeno6/vGPf3jqa1ECgGhNNOJD4pEURlJ1DSb3XRCcBahLDDn5v3jbi3hx/IuY2mUqpuTYD/Tj2h3wLUD812IBlCmtpYKSwkLy/NVXjucTlxiwd5LXKOUI92C3ceep3Axw9z1ARjqJB/roI9LLoy3b9HDsjIzhxYtcugis+5RksKlUQG43YOgQ4N57gRUrgYmTgHD3CpcKe7DRuJ/WMLlnEpRyBtP7WavmS1rRRo4gz0fzYLKTaehVC9ClS8D69eT10qVe3JD38JoAmj17Nk6fPo3t27fju+++w969e/Hoo486XGbFihXYvHkzNm3ahD179qCsrAz3cS2xeSxYsAAP2TG31dbW4vbbb0dmZiby8vLwyiuv4M9//jPWrFnjke9F8X8YhkFmVKZFABlZo9txQCevnwQAdIm13v3+Ycwf8N2s76BR2O9gKhXHwq8MLRZAudLeNIoDxAJIqvK2N3DJlaNQkFiI1FSgsQn4+GOgrq7V2/R0/RxSaoAFLpwH1n8G6A1A1y7AUytJnM+UO4A+fVzv0iuicwKpgJwQoYbGTrA1xTE9UiLw/43LQVacNeZKqZA4DnJyiKtVpwNrJ87MawKIZYFly8ifcfJkYTflAMIrR2hBQQG2bNmC999/H0OHDsWoUaPw1ltvYcOGDSgrK5NcpqamBv/973/x2muv4bbbbsPAgQPxwQcf4JdffsFBXiOkN998E0uWLLGbvbRu3TrodDqsXbsWPXv2xMyZM7Fs2TK89tpr3viqFD9Gzsjd7hD/9LanoXlRg12XdwEABqUMcmub4jiWkhJr+IRWaxvAGxpq2yGeIg2XTSUOMG+vYm8uu9I0WmDuXCA5maizn34CAJTXuK/UpCpWn6+ow8361n1pBgxMV0uBTV+YG1d2J+4LVdubB/9meCa6JIRh7vBMzBycQTO/2oDY3SotvhmLydmeAPKaC+zf/yatL1Qq4NVXvbQR7+MVAXTgwAFERUVh0CDrxWPixImQyWQ4dOiQ5DJ5eXnQ6/WYOHGiZVpubi4yMjJw4MABt7Y9ZswYqHhFWSZPnozCwkJUVVXZXa6lpQW1tbWCByWw4TrEA665wViWxT8P/BMtRnJx6ZPYBwOS3auSLbYALV9uvWDbiyNVS1x7srJs19nRGeREa3KWH19ZgNyyaKjVpI09AJz8FbhxHZ8dLobRxGL7mQoUXLN/fuF3ZBcbgEoqG/Hdr9fwyYErgukHL0n0DDEjKMRYVgr2448BnQ6y7CzgvhmAvO25ME/cloPYMDUYhkFcmJpWfvYwUvszNVoL9OoJRESgM3tRcjmv3BycOwc88wx5/c9/Aj17emEj7YNXBFB5eTkSEhIE0xQKBWJiYlBeLh0MWF5eDpVKhaioKMH0xMREu8vYW0+iqDMz997RelatWoXIyEjLIz0QGypRbFDKiADid4hnRRGDey/vxfM7n8fETyYKpneN7Qp3EQugy5etn6Wm2sxOxigRpsIXS/ffb33NnXc6Ip99Rp7tdbvnhI+9IpPeRq1wzQIUF2a++UpJAbp3J2WB9u4DAOSXVONUaQ22nLJ/LuK3QhBf+Cpqpb/sgYu38NnhYptjGwBaDOao8fPnwbz8MkzNLUBaKpRzZnms+629FhgUzyBlTQtRyZESEwoMG4ap+B4TQoSGgjvvtCQieg69nvSBa24GJk0K2NgfDreO2meffRYMwzh8nD171ltj9SrPPfccampqLI+SkhJfD4niAbjK0BerLsJkMuFm403kXcvDpapLYFkWVc1VePS7R/Hivhexs2inYNm08DS3txdqdttzYR+XeDUY3RFAMbxafXfcAbz8Mun19fLLbg8pYODufbisOfG1XK8n04pFJZr8zQLUI4VXBXPMGPJ8+hRw86ZLris9r/aP2BXiqKJReU0zim7aVqIuq24CLl4ENm6AobGJFG2cNRuqEPdNi/yyABTfEhemJi7SgQOg0bAY1bgNUaFW8+i777Lo08fDG331VeDgQVLpdc2awC09b8Yt2+dTTz2FefPmOZynU6dOSEpKwvXr1wXTDQYDKisrkZQk3UQvKSkJOp0O1dXVAitQRUWF3WXsrUecOca9d7QetVoNtZQvghLQxIXE4UYjqWnSZGhCeT25865sqoRJZ0Jts60rIlYbi+rmatze+Xa3t5dsrsFYVgaUlgLV1dbP7BVJlRJA/OaGsbEBWWPMbfhxty0ttpYgnY7EVDU2WvuGAcCIEe0zPikL0MTuifipwHq+6ZoYLqyGnJREsqvOFgL79qKhm/O6B3peN3CxCLSX7cMhVRX6+sFjJC3faIJu0BDg/pmtdns9ODgdr28/16plKW0jJlSFygarJXtARjQRtyo1MGgw8PPPqG6wnkyiY1jY9MZoC8XFwIsvktdvvin00wcoblmA4uPjkZub6/ChUqkwfPhwVFdXIy8vz7Lszp07YTKZMHToUMl1Dxw4EEqlEjt27LBMKywsRHFxMYabGwy6wvDhw7F3717oeYEC27dvR7du3RAdbaeSK6XDEqoKtWRt6U16gYuguqXaZv7f9P0Nbj5zE3XP1eGOLne4vT3OylNaCqxeLfzMnr6WEkDdeN02HDUG7UjwBVBTk22sz4gRwAsvkNepqcCFCyTT3Mk9mcfQquTonRqJ3qlWC096jBaJvCar6TFa28DlMWPJ88lTaCi2JoF8cuAyrtU0oVlvxJfHruJUKelZx6/+LK5i/ctF+7E+fKoadFj7cxFOfPYdTE8/A+j0QKdOqLz/YYv4aW2Q8pBsq3lSIWMweygtZtUezBmWic4JYYgLU2Hu8EyoFDLrsTZoEMAwGA+SvNG7D6B0v/6nY555BmhoIIHXc+d6eOW+wSuO2+7du2PKlClYvHgxDh8+jP3792Pp0qWYOXMmUszBDaWlpcjNzcXhw6SUd2RkJBYuXIiVK1di165dyMvLw/z58zF8+HAM46XYXbhwAfn5+SgvL0dTUxPy8/ORn58PnTncfdasWVCpVFi4cCFOnz6NjRs34l//+hdWrlzpja9KCQDUcqI89EY99CbbbDB+rE9WZBYAQKtsnamfE0BlZUBRkWgcLgigDz4g1p7ly63TxH3FOioKhdXq09xsK4AuXQIKCsjrceOAzp2B3/zGY2EsLjGxRyIm9kjEnb2TMaF7AqJCVHh4iDVekAFj6xVITiaKlmXR8vkmcI6sm/U6fHW8FEcuV+LKrUZsP0MsSS0GngDiWXya9c5bX3Db/uXiLdQcOIKd72yAUW8gxQxnPiTYWSqF89N/r1Tbgy8z1tq/ZUROLBIiWpcyT3EPuYzB3X1TMHd4FuLCyMnE4iKNjAS65GA09mJOjzxMvdPDTXj37QM2biQH2H/+E/CuLw6vRa6tW7cOubm5mDBhAu68806MGjVKUItHr9ejsLAQjbyIxtdffx133XUXZsyYgTFjxiApKQlffvmlYL2LFi1C//798e677+LcuXPo378/+vfvb0mvj4yMxLZt21BUVISBAwfiqaeewgsvvOC0BlGwsGXLFowaNQpRUVGIjY3FXXfdhYsXpTMIOgpcJtiVmiuS9YD+Mv4vltf2Gp26CucCa2khFgo+Kjt3ZHwBNG8eqS2m0QDvvgv8+c8kjjZY4KxAUhYgADDfL+Hxx9tvTFJ0SwpHn7QoAEJLSly4SjoDavLtgFIBw8UiIO+YZXKL3oSz14R1guqarRmLTjxeNjAw+wZ//AH45hvAxOL65LuI+BGZBFR2GpfyGZgZjQHmVhzdk0lRRL6FSypNn9J+CA61gYPAAOh8eSfUSiOade4XgJWkuRlYsIC8XrAAng8s8h1e6wUWExOD9VyVSAmysrJsMhY0Gg3efvttvM2V15Zg9+7dTrfdp08f7Nu3z+WxthWWZdGod9LAyEuEKEPcMmU3NDRg5cqV6NOnD+rr6/HCCy/g3nvvRX5+PmQdtJEhlwnGfx+picTN2pvQKrXom9gXr09+HUfKjuCBHg+0aVtqNRE0er0wABpwTQDxCUbNrtUC9fXknOso/d/j2S1tZM6wTNQ06ZAcqUVtk/XC89jYzli95yIQHQPcNgFNW7eSukC5uZaIeXFPrzpe+wv+Xbwrd/RNLTrg8ceh31cIRKcAI0bg6qSJkIoFSY8OQVm14whyGQOMyolDp7hQJEdqzNOs66Lp7n5ETg4QFgbU1yP0yCGELrkbWLgAGD8eGDq09Vabl18md3PJyQFd80cK2gzVAzTqGxG2Kswn265/rh6hKuddmjlmzJgheL927VrEx8fjzJkz6NWrl6eH5xdwmWAcaoUaWVFZSNIkoaihCAzDYPmw5R7bXlgYUFVl7Sdp2a4dFxituGCFbwEKs/OX6tmTxBb7E/HhasSHkx+Y32tLLmMQopKT4OQhQ4ATJ4DycmD7dmD6dMl1CS1AfAHkZBAmE/a+8j5yPvwURYOnkwvf6FGwFwg7JDsG12qaUVxp/+ZNJmMglzFIjwnhTbN+TgWQbxHUeJLJgD69gV8OoM//PoTy4gXg978nnw0eDEycSKw3U6e63uJk+3bgpZfI6zfe6HD++I55y0+xy/nz5/Hwww+jU6dOiIiIQJY5kr9YnFvcgQhTCa+kYouQx7dn58JtzwL0+98D990HfP6598YUKPDrKEm5wABSfsSfPS98Q41cxmD+yGx0ig8lF6i7phI9cuIE6cUlgU4QBG2dbjQ6UEBGA/DFJuDXk7gSlwbMmAGMHg1HWUAKuQz39EtBuMb+fbCUi4tagPwHvfiY6NMXAHA4NBlFz/+N9HXTaknX9lWrSIBhly7EqmNy0h7o00+BKVNI+uX06cADbbOO+yPUAuQBQpQhqH/OcaNXb27bHaZNm4bMzEy89957SElJgclkQq9evSxB5B0RpVyJnvE9cfrGact7bxJqxyBnzwIUHg7873/eG08gwYnHhgb7Asjfm07zBZCMAeQKGSK4pqypaeRu/PAR4OtvSDATz9fXrDcK7ur5QdCfH7VTm0yvJ+r5wgVALofslVcAF2PZFHIZFozMxr92nJf8XErfCASQPyvRIMAmMD4xEUhJhrHsGoq1Mcj+8ktSjfXdd0nT282bSYPeZ58lKZSLF5OeXuJMglOnyGcmEyl8uHq1f991tBJqAfIADMMgVBXqk4c78T+3bt1CYWEh/vjHP2LChAno3r27w/YgHQl+VpejZqaewF0LEMUKt+/q6gJYAPFcYNz/U1DQcOIk8iXq6kiwMo81ey+hsNwaFM13gYljhQAWOHOGXJwuXCDBZLNm4VC6e65scbFFwWcS5xe+6HG0LMX78AWQJTtvwEDy2bafiBrPyiLWn3ffJWLonXeI6C4oAFauJCn0GzeaV9gMfPUVcZc1NxML0Nq1rW6O6+9QARREREdHIzY2FmvWrMGFCxewc+fOoCoPkBqeinBVOGK13r2C2hNA9lo8UKxwoQl8ARQZKUw88XsBJOGpEugEpRK4715yR33yFFBs7etlFAX6SMf9sKSy8/vvA5s2kTv70FBgzhygUyfUNrnW+NcVpAQQw7tqKKgA8il8ATS9n7kGR69egFKJ5oqbwM8/CxdQqYDHHgOuXgX+9S9ils7PB2bOJL7l1FTij6+oIH+6Tz6x7eDcgei434xig0wmw4YNG5CXl4devXphxYoVeOWVV3w9rHYjOTwZ3eK6QS7zbuEYewKoA1qQPY6UAIqJAX73O+s8/DYhgYKNqyglFRhobrS7dZu0aoLVAmQ0sYBBT1wTa9eS+IzSMkClBMaOBZY9AWR4viChMxcYTYP3LVEhxKyslDNWa5xaDfTuhea4BEDUkcFCTAxxfRUUWIuO/fQTEdPJycD8+cDevcKS9B0Qek8aZEycOBFnzpwRTJNqoEhpPfZigOi1wjmcAKqvtwoglYpkjQMkVKHD1EUaNw749SSpmllYaP2SPEwmFjXHf0XVex8C56qBZnMvMaWSCKhRo4BQ72WgSgU5UxeY/zC1dzIOXrqFwdmiu4JJk9D84L3AqM6OV5CdDbz+OhE8e/YAnToRt1d7Vhf1IVQAUSgehh/snJcHDCQueSqAXEAqBkipBAYMAA4dIg3W7TWV9Re6JobjcFElMnip45K3GKFhpD7Lvn3A7t1At65C/xIA4+EjWPvJFwDX0T0yEujbl4ifCPdSksfnJiC/uApVjY5dZKO7xGHf+ZsApNtl8CfRQ9q3RIeqcEfvZMv7gZnRyLtSBWi0aHJeONxKnz4dqsChq1ABRKF4GH6RvgEDrK+pAHKOlAuMKxQ5ZIhvxuQuKoUM80dmuZagMHwYcOQwibk4dtyqlgGgoQHXPt4AGBnS92PkSBLQ6uKBpJAxMPCCiGJCVOiWFIGDlxz3E1PKHUdG8N1e9Jj2L8Z0jceAzGi8t/cSmvVGsCzb6p5vwQCNAaJQPMySJSQbdeZM4XR6HnKOIwEUSLh80dGGAOPGk9c7d5ACSBzbt6PeyJCqj7NmEWXtxkGkUsgQE2pNPZTJpGN63IXvFmOoDcjv0Jh7vLGssK8cxRZqAaJQPEx8PEmyELvRqQByDlcSh98MNRAFkBiHYXaDBwPHjpGA1fffBzplAw2N1s6vU6e2KhPHyLKCGB25jHEpZsfZccpfBT2m/Q+FXAaVQgadwYRmvREaZXDE87QGagGiULyAQmF7cQjE7KX2houfamnpWALIITIZMG0aoFaRLJyjeVbxM348kJbWqtVqlXJBmrqcYVyyADmz6vCtW6Fqeg/tj6jNVqBmPbUAOYIevRSKl3nvPZK9PG6cr0fi/3RUAcRKh0FbSUsDlj1JChpWVpJpiYlAd9vMMDEqhQwGI2vTLHVqn2TsPmttSCeXMegUF4a9527atL+QMYxleVfaWzw4OB06gwlhVAD5JRqlHHXNBjSJK0VTBNCjl0LxMosW+XoEgQNXLbujCSCXCAlpVSbOb8d0wrt7L0FnsAqgoZ1ikBCugZEniqJCVJDLGCwcnQ2tyC3CMLCkqnVNDMOvVzVIjdbCHqlR9j+j+J4wtQKNOoONKKYIoQKIQqH4DZwFSKfrWALIW9ehrLgQKOQyJEVoBF3dubiP8d0SsOloCXqkRFgsO5a+ZDxkDMDZChRyGWYO8XxRRUr7Mb2/n9eK8BNoDBCFQvEbOqoLjE9OAil2FBXSti+WkxBmaX8wpVeS4LPeqaRGUFKkBr8d2xnjuyU4XNegLBKg1i0pvE1jolACCWoBolAofkMwCKA7eyfjVkMLGlqM+Pp4aavXo5AxloBkfjByuEYhqOWjUji/zx2SFYOs2FDEh6udzkuhdBSoBSjImDdvHhiGwWOPPWbz2ZIlS8AwDObNm9f+A6NQIIwB0unI644ggLhO3TKGgVzGICFc0+ZGop5MQZfJGCRFalwKgKZQOgpUAAUh6enp2LBhA5p4Rdeam5uxfv16ZHihoSKF4iodNQYoMzYU9w9Mw6LR1jLhbRcbVKxQKG2BCqAgZMCAAUhPT8eXX35pmfbll18iIyMD/fv39+HIKMFOR3aBpceECFxV/mQBolCCERoD5AFYlkVjY6PzGb1ASEhIq3q9LFiwAB988AFmz54NAFi7di3mz5+P3bt3e3iEFIrrdGQBJKatFiAZVUAUSpugAsgDNDY2IoxrY93O1NfXIzQ01O3l5syZg+eeew5XrlwBAOzfvx8bNmygAojiU6TqAKlU9ucPZNoqgOwtTZtfUiiuQQVQkBIfH4+pU6fiww8/BMuymDp1KuLi4nw9LEqQ01FjgKRojQBKjtTgWk0zAOoCo1DaChVAHiAkJAT19fU+23ZrWbBgAZYuXQoAePvttz01JAql1XACyGgEGhqE0zoazgTQ0OwYHCqqFEzj11OkAohCaRtUAHkAhmFa5YbyNVOmTIFOpwPDMJg8ebKvh0OhCMROTQ151nbQrgtiATRjQBpSojS4UtkIo4lFl4Qw9EqLxEf7L8NgItKHX1GauroolLZBBVAQI5fLUWDuOi2Xy53MTaF4H368T3U1ee6wAkgkYNKitZDJGHSOt8YTittW8JuqUvlDobQNmgYf5ERERCAiIsLXw6BQAAjjfYLNAiRzISaob1qU5bU9CxAVRhSKa1ALUJDx4YcfOvz866+/bpdxUChSMAxxg7W0dHwB5KoLix/30zMlAj8VVIBlgcyY1sf/USgUKoAoFIqfwQmgju4C4zMyx7UMTIZhsGh0J1Q16JBOBRCF0iaoAKJQKH4FFwcUDAJocs8kVDfpMDgr2uVlwtQKhKntn7ppbDSF4hpUAFEoFL+CywRrJuVuoNH4bizepkcKjb+jUHwFDYKmUCh+hbjuT0e2ALkCP/XdEfHhZMflJlFRRaG4ArUAUSgUv4IKoNZx/8A0lFY3ISs28GqSUSi+gAogCoXiV4h7f1EB5BoapVxQQ4hCoTiGusAoFIpfQS1AQli46AOjUChuQQUQhULxK6gAolAo7QEVQBQKxa+gAohCobQHVABRLOzevRsMw6CaK8AS4HS07xMsKIXtrzp0GjyFQvEdVABRKBS/4scfra8VCvIIZlxNg6dQKO5BBRDFp+h0Ol8PgeLHUPcXhULxFlQABRktLS1YtmwZEhISoNFoMGrUKBw5ckQwz/79+9GnTx9oNBoMGzYMp06dsnx25coVTJs2DdHR0QgNDUXPnj3xww8/WD4/deoU7rjjDoSFhSExMRFz587FzZs3LZ+PGzcOS5cuxfLlyxEXF4fJkydj1qxZeOihhwRj0Ov1iIuLw8cffwwAMJlMWLVqFbKzs6HVatG3b1988cUXgmV++OEHdO3aFVqtFuPHj8fly5c9tdsoPoK6vygUiregAsgDsCwLncHkkwfrpn38mWeewf/+9z989NFHOHbsGHJycjB58mRUVlZa5nn66afx6quv4siRI4iPj8e0adOg1+sBAEuWLEFLSwv27t2LkydP4uWXX0ZYGKk9Ul1djdtuuw39+/fH0aNHsWXLFlRUVODBBx8UjOGjjz6CSqXC/v37sXr1asyePRubN29GfX29ZZ6tW7eisbER9957LwBg1apV+Pjjj7F69WqcPn0aK1aswJw5c7Bnzx4AQElJCe677z5MmzYN+fn5WLRoEZ599ln3f0yKz3nnHV+PgEKhBAMM6+4VNEiora1FZGQkampqEBEhLC3f3NyMoqIiZGdnQ6PRQGcw4e1dF3wyziXjc6BSuKZjGxoaEB0djQ8//BCzZs0CQCwtWVlZWL58OQYPHozx48djw4YNFotMZWUl0tLS8OGHH+LBBx9Enz59MGPGDPzpT3+yWf+LL76Iffv2YevWrZZpV69eRXp6OgoLC9G1a1eMGzcOtbW1OHbsmGUeg8GA5ORkvPbaa5g7dy4AYNasWTCZTNiwYQNaWloQExODn376CcOHD7cst2jRIjQ2NmL9+vX4/e9/j2+++QanT5+2fP7ss8/i5ZdfRlVVFaKiomzGK/4dKf4BywIy8yEdHQ3wtHlQsuloCa5WNSFco8Ci0Z18PRwKxe9xdP3mE+ThhcHFxYsXodfrMXLkSMs0pVKJIUOGoKCgAIMHDwYAgciIiYlBt27dUFBQAABYtmwZHn/8cWzbtg0TJ07EjBkz0KdPHwDAiRMnsGvXLotFSLztrl27AgAGDhwo+EyhUODBBx/EunXrMHfuXDQ0NOCbb77Bhg0bAAAXLlxAY2MjJk2aJFhOp9Ohf//+AICCggIMHTpU8Dn/e1ACB343c7PhMai5s3cy8kuq0Ss10tdDoVA6FFQAeQClnMGS8Tk+23Z7smjRIkyePBnff/89tm3bhlWrVuHVV1/FE088gfr6ekybNg0vv/yyzXLJycmW16Ghtr2KZs+ejbFjx+L69evYvn07tFotpkyZAgAW19j333+P1NRUwXJqcdEYSofCYPD1CHxPqFqBkTlxvh4GhdLh8FoMUGVlJWbPno2IiAhERUVh4cKFghgPKZqbm7FkyRLExsYiLCwMM2bMQEVFhWCeZcuWYeDAgVCr1ejXr5/NOi5fvgyGYWweBw8e9OTXE8AwDFQKmU8eDOO6AOrcubMl9oZDr9fjyJEj6NGjh2Uaf19VVVXh3Llz6N69u2Vaeno6HnvsMXz55Zd46qmn8N577wEABgwYgNOnTyMrKws5OTmCh5To4TNixAikp6dj48aNWLduHR544AEozQVhevToAbVajeLiYpv1pqenAwC6d++Ow4cPC9bpzd+c0j5QAUShULyF1wTQ7Nmzcfr0aWzfvh3fffcd9u7di0cffdThMitWrMDmzZuxadMm7NmzB2VlZbjvvvts5luwYIFN1pCYn376CdeuXbM8xG6XYCQ0NBSPP/44nn76aWzZsgVnzpzB4sWL0djYiIULF1rm+8tf/oIdO3bg1KlTmDdvHuLi4jB9+nQAwPLly7F161YUFRXh2LFj2LVrl0UcLVmyBJWVlXj44Ydx5MgRXLx4EVu3bsX8+fNhNBqdjm/WrFlYvXo1tm/fjtmzZ1umh4eH4//+7/+wYsUKfPTRR7h48SKOHTuGt956Cx999BEA4LHHHsP58+fx9NNPo7CwEOvXr8eHH37ouZ1HaVd69iTPY8f6dhwUCqUDw3qBM2fOsADYI0eOWKb9+OOPLMMwbGlpqeQy1dXVrFKpZDdt2mSZVlBQwAJgDxw4YDP/n/70J7Zv374204uKilgA7PHjx9v0HWpqalgAbE1Njc1nTU1N7JkzZ9impqY2bcMXNDU1sU888QQbFxfHqtVqduTIkezhw4dZlmXZXbt2sQDYzZs3sz179mRVKhU7ZMgQ9sSJE5blly5dynbu3JlVq9VsfHw8O3fuXPbmzZuWz8+dO8fee++9bFRUFKvVatnc3Fx2+fLlrMlkYlmWZceOHcs++eSTkmPjjpvMzEzL/Bwmk4l944032G7durFKpZKNj49nJ0+ezO7Zs8cyz+bNm9mcnBxWrVazo0ePZteuXcsCYKuqquzui0D9HTs6p0+z7Esvsezly74eCYVCCTQcXb/5eCULbO3atXjqqadQVVVlmWYwGKDRaLBp0yZLajOfnTt3YsKECTYZO5mZmVi+fDlWrFghmP/Pf/4zvv76a+Tn5wumX758GdnZ2UhPT0dzczO6du2KZ555BnfffbfDMbe0tKClpcXyvra2Funp6S5lgVECE/o7UigUSsfD1Swwr7jAysvLkZCQIJimUCgQExOD8vJyu8uoVCqbdOXExES7y0gRFhaGV199FZs2bcL333+PUaNGYfr06fj2228dLrdq1SpERkZaHlxsCYVCoVAolI6HWwLo2WeflQww5j/Onj3rrbG6RFxcHFauXImhQ4di8ODB+Pvf/445c+bglVdecbjcc889h5qaGsujpKSknUZMoVAoFAqlvXErDf6pp57CvHnzHM7TqVMnJCUl4fr164LpBoMBlZWVSEpKklwuKSkJOp0O1dXVAitQRUWF3WVcZejQodi+fbvDedRqNU2pplAoFAolSHBLAMXHxyM+Pt7pfMOHD0d1dTXy8vIs2Vc7d+6EyWSyKVbHMXDgQCiVSuzYsQMzZswAABQWFqK4uLjNBe3y8/MFdWgoFAqFQqEEN14phNi9e3dMmTIFixcvxurVq6HX67F06VLMnDkTKSkpAIDS0lJMmDABH3/8MYYMGYLIyEgsXLgQK1euRExMDCIiIvDEE09g+PDhGDZsmGXdFy5cQH19PcrLy9HU1GQJgu7RowdUKpWlzxRXIfjLL7/E2rVr8f7773vjq1IoFAqFQglAvFYJet26dVi6dCkmTJgAmUyGGTNm4M0337R8rtfrUVhYiMbGRsu0119/3TJvS0sLJk+ejP/85z+C9S5atMjSABOARegUFRUhKysLAPDXv/4VV65cgUKhQG5uLjZu3Ij777/f49/RZDJ5fJ2U9oP+fhQKhRK80GaodnCURmcymXD+/HnI5XLEx8dDpVK5VZGZ4ltYloVOp8ONGzdgNBrRpUsXyGReqwlKoVAolHaENkP1IjKZDNnZ2bh27RrKysp8PRxKKwkJCUFGRgYVPxQKhRKEUAHUSlQqFTIyMmAwGFxq80DxL+RyORQKBbXcUSgUSpBCBVAbYBgGSqXS0rSTQqFQKBRKYEBt/xQKhUKhUIIOKoAoFAqFQqEEHVQAUSgUCoVCCTpoDJAduOoAtbW1Ph4JhUKhUCgUV+Gu286q/FABZIe6ujoAoF3hKRQKhUIJQOrq6hAZGWn3c1oI0Q4mkwllZWUIDw/3aKp0bW0t0tPTUVJS4rBAE8UxdD96BrofPQPdj56B7kfPEOz7kWVZ1NXVISUlxWGdN2oBsoNMJkNaWprX1h8RERGUB6anofvRM9D96BnofvQMdD96hmDej44sPxw0CJpCoVAoFErQQQUQhUKhUCiUoIMKoHZGrVbjT3/6E9Rqta+HEtDQ/egZ6H70DHQ/ega6Hz0D3Y+uQYOgKRQKhUKhBB3UAkShUCgUCiXooAKIQqFQKBRK0EEFEIVCoVAolKCDCiAKhUKhUChBBxVA7cjbb7+NrKwsaDQaDB06FIcPH/b1kAKOVatWYfDgwQgPD0dCQgKmT5+OwsJCXw8roPn73/8OhmGwfPlyXw8lICktLcWcOXMQMdyyIAAABfRJREFUGxsLrVaL3r174+jRo74eVkBhNBrx/PPPIzs7G1qtFp07d8Zf//pXp72cgp29e/di2rRpSElJAcMw+PrrrwWfsyyLF154AcnJydBqtZg4cSLOnz/vm8H6IVQAtRMbN27EypUr8ac//QnHjh1D3759MXnyZFy/ft3XQwso9uzZgyVLluDgwYPYvn079Ho9br/9djQ0NPh6aAHJkSNH8O6776JPnz6+HkpAUlVVhZEjR0KpVOLHH3/EmTNn8OqrryI6OtrXQwsoXn75Zbzzzjv497//jYKCArz88sv4xz/+gbfeesvXQ/NrGhoa0LdvX7z99tuSn//jH//Am2++idWrV+PQoUMIDQ3F5MmT0dzc3M4j9VNYSrswZMgQdsmSJZb3RqORTUlJYVetWuXDUQU+169fZwGwe/bs8fVQAo66ujq2S5cu7Pbt29mxY8eyTz75pK+HFHD87ne/Y0eNGuXrYQQ8U6dOZRcsWCCYdt9997GzZ8/20YgCDwDsV199ZXlvMpnYpKQk9pVXXrFMq66uZtVqNfvZZ5/5YIT+B7UAtQM6nQ55eXmYOHGiZZpMJsPEiRNx4MABH44s8KmpqQEAxMTE+HgkgceSJUswdepUwXFJcY9vv/0WgwYNwgMPPICEhAT0798f7733nq+HFXCMGDECO3bswLlz5wAAJ06cwM8//4w77rjDxyMLXIqKilBeXi74f0dGRmLo0KH0umOGNkNtB27evAmj0YjExETB9MTERJw9e9ZHowp8TCYTli9fjpEjR6JXr16+Hk5AsWHDBhw7dgxHjhzx9VACmkuXLuGdd97BypUr8fvf/x5HjhzBsmXLoFKp8Mgjj/h6eAHDs88+i9raWuTm5kIul8NoNOJvf/sbZs+e7euhBSzl5eUAIHnd4T4LdqgAogQsS5YswalTp/Dzzz/7eigBRUlJCZ588kls374dGo3G18MJaEwmEwYNGoSXXnoJANC/f3+cOnUKq1evpgLIDT7//HOsW7cO69evR8+ePZGfn4/ly5cjJSWF7keK16AusHYgLi4OcrkcFRUVgukVFRVISkry0agCm6VLl+K7777Drl27kJaW5uvhBBR5eXm4fv06BgwYAIVCAYVCgT179uDNN9+EQqGA0Wj09RADhuTkZPTo0UMwrXv37iguLvbRiAKTp59+Gs8++yxmzpyJ3r17Y+7cuVixYgVWrVrl66EFLNy1hV537EMFUDugUqkwcOBA7NixwzLNZDJhx44dGD58uA9HFniwLIulS5fiq6++ws6dO5Gdne3rIQUcEyZMwMmTJ5Gfn295DBo0CLNnz0Z+fj7kcrmvhxgwjBw50qYMw7lz55CZmemjEQUmjY2NkMmElyO5XA6TyeSjEQU+2dnZSEpKElx3amtrcejQIXrdMUNdYO3EypUr8cgjj2DQoEEYMmQI3njjDTQ0NGD+/Pm+HlpAsWTJEqxfvx7ffPMNwsPDLb7syMhIaLVaH48uMAgPD7eJmQoNDUVsbCyNpXKTFStWYMSIEXjppZfw4IMP4vDhw1izZg3WrFnj66EFFNOmTcPf/vY3ZGRkoGfPnjh+/Dhee+01LFiwwNdD82vq6+tx4cIFy/uioiLk5+cjJiYGGRkZWL58OV588UV06dIF2dnZeP7555GSkoLp06f7btD+hK/T0IKJt956i83IyGBVKhU7ZMgQ9uDBg74eUsABQPLxwQcf+HpoAQ1Ng289mzdvZnv16sWq1Wo2NzeXXbNmja+HFHDU1tayTz75JJuRkcFqNBq2U6dO7B/+8Ae2paXF10Pza3bt2iV5PnzkkUdYliWp8M8//zybmJjIqtVqdsKECWxhYaFvB+1HMCxLS21SKBQKhUIJLmgMEIVCoVAolKCDCiAKhUKhUChBBxVAFAqFQqFQgg4qgCgUCoVCoQQdVABRKBQKhUIJOqgAolAoFAqFEnRQAUShUCgUCiXooAKIQqFQKBRK0EEFEIVCoVAolKCDCiAKhUKhUChBBxVAFAqFQqFQgg4qgCgUCoVCoQQd/z/stnSsb+4R5QAAAABJRU5ErkJggg==" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAGdCAYAAAD60sxaAAAAP3RFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMS5wb3N0MSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8kixA/AAAACXBIWXMAAA9hAAAPYQGoP6dpAADXFElEQVR4nOydd3zU5B/HP9e9Wwq0Ze89ZClDRDYIDpQfIiIoIgiKTEVxAE5QmSKCqIgDBAFBluwpS/beo2W1UEr3bu/3x9MneZJL7nJ3udU+79cLLpfkck/vck8++U6D0Wg0gsPhcDgcDqcE4eXqAXA4HA6Hw+E4Gy6AOBwOh8PhlDi4AOJwOBwOh1Pi4AKIw+FwOBxOiYMLIA6Hw+FwOCUOLoA4HA6Hw+GUOLgA4nA4HA6HU+LgAojD4XA4HE6Jw8fVA3BXCgsLcfv2bYSGhsJgMLh6OBwOh8PhcDRgNBqRlpaG8uXLw8tL3c7DBZAKt2/fRqVKlVw9DA6Hw+FwODZw48YNVKxYUXU7F0AqhIaGAiAfYFhYmItHw+FwOBwORwupqamoVKmScB1XgwsgFajbKywsjAsgDofD4XA8DEvhKzwImsPhcDgcTomDCyAOh8PhcDglDi6AOBwOh8PhlDi4AOJwOBwOh1Pi4AKIw+FwOBxOiYMLIA6Hw+FwOCUOLoA4HA6Hw+GUOLgA4nA4HA6HU+LgAojD4XA4HE6JgwsgDofD4XA4JQ4ugDgcDofD4ZQ4uADicDgcDodT4uACiMPhcNyAK0lXMPXfqUjLSXP1UDicEgHvBs/hcDhuQOufWuNe5j3cz7yPr7t+7erhcDjFHm4B4nA4HDfgXuY9AMCO6ztcPBIOp2TABRCHw+G4mHsZ94TlMkFlXDgSDqfkwAUQh8PhuBCj0Yglp5YIzw/dPoSDNw+6cEQcTsmAxwBxOByOi8jJz0GjeY1wKemSsC4pKwmtfmqFtAlpCPELceHoOJziDbcAcTgcjou48uCKRPywJKQnOHk0HA7w8A8Pw/CxAYaPDXhx5YuuHo5D4RYgDofDcRH7b+xX3fYg+4ETR8Ipqby48kX8cfoPxW1/nP5D2GaAAS80fAFLei9R3NcT4RYgDofDcRGvrX1NdVtSVpITR8Ipacw/PB/Bnwerih85Rhjxx+k/4PWxV7GxDHEBxOFwOC7AaDSa3c4FEMdRPPzDwxi+fjgy8zMVt/dr2A/9GvZT3EaFUPAXwZh/eL4jh+lwuADicDgcF5Cem26ybkWfFcIyF0AcvaFWn8O3D0vWB/kGYV7PeTBOMsI4yYglvZdgSe8lwvN5PechyCdI8prMvEy8sf4NjxZBXABxOByOC0jJSTFZ17t+b7zWlLjFuADi6M24zeMkVh8/bz/M6zkPGe9nYFiLYaqvG9ZiGDI+yDARQkYY8eb6Nz1WBHEBxOFwOE7EaDTiePxxSZZXVHAUFj69EAAQGRgJQCqAFh5biJrf1MSZu2ecO1hOsYBafjLzRPHTr2E/5HyYY1b4yKFCiHWPFaIQw9cP98i4IC6AOBwOx4n8de4vNP2+KXou6QkAqF6qOhLeTsCgpoMAKAugwWsG48qDK3hzw5vOHzDH45Fbfvo17GdXNteS3kswr+c8eDES4o/Tf3icJYinwXM4HI4TmXVwFgAgIYNYgML9wyXb5QKIjdeIS4lzwgg5xYkXV75oYvnRI5WdWo7eWP8GjCAB/W+uf1Oyzd3hFiAOh8NxItHB0ZLn4QHmBdD4LeOFbbkFuQ4eHac4Ia/xo5f4oQxrMQzf9fxOsAQVohDjNo3T7fiOhgsgDofDcSEVwypKnssFUFRwlLCN3mlzOJaYf3i+Q8UPZViLYZjbc67wPDM/02PigbgA4nA4HCcSnx4veR7mFyZ5HhEQAUCsBM32A7uddhv5hfmOHSDH45m2bxpGbxwtPA/yCXJoBedhLYZJAqP/OP2HR4ggLoA4HA7HSRQaC7Hvxj7JOoPBIHke6h8KAMjIzSCPeRmS7SP/GenAEXKKAyvOrkBOQQ4AwAtemN5tusPfc0nvJSYiyN2DorkA4nA4HCfx0faPBDdW05imiAqOwtjWYyX7UItPem468gvzcSftjmT7vMPznDNYjsdCA+yjg6Mxt+dcpwUlL+m9BEG+Yp2gD7Z94JT3tRUugDgcDsdJfPHvF8Ly7kG7ET8uHtVLVZfsQwWQEUYMXTsUu2J3SbazFxgOR47RaBRqTO19da/TM7Kmd50uBEUnZSe5tSuMCyAOh8NxASF+ISbuL0AqcH4+/rOw/HyD5wEANSNrOn5wHI8lPTcdWflZAIDokGgLe+uPPCjaneOBuADicDgcJ6A1hd3L4KVo5Wka0xQA8CDrga7j4hQfzt47ixn7ZwAgQpoNoHcm8qDopaeXumQcluACiMPhcJzA9eTr4vKo66r7AVC8cMWExAAQs8M4HDkNvmuAybsmAwDKhZRz6VjYeCAjjG5pBeICiMPhcJzA5aTLAIBGUY1QJaKK2X2VBFClsEoAiIsjryBP/wFyihWWzjFnML2rmH3mjllhXABxOByOE7iSdAUAUCOyhsV9WQE08pGRmNVtFtpXbS/cUVMxxeFQcvJzJM8rh1d20UhE5K4wd8sK4wKIw+FwnMCVB0UCqJRlAVS7dG1heUjzIRjVahS8vbzRJKYJAODInSMOGSPHc7mRekPyvEq46y1AgNQV9iD7gVtZgbgA4nA4HCdABZCWLK5pXaahTaU26FGrB+qWqSusr1uaLP8b969jBsnxWKbvkxY7pEHz7kCATwAAEgvkTlYgLoA4HA7HCVy8fxGANgtQlYgq2PvqXqx/cT18vHyE9aUCSwEAvj/yPY7c5lYgjggV2JSWFVu6aCSmfN7xc6E2UHJ2sttYgbgA4nA4HAdzM/UmLt6/CC+DF5qXb27zccL9xc7xH+34SI+hcTyQGyk3UFBYAICUV1hzYQ22XN0ibA/2DRayBt0BtjaQO3WM5wKIw+FwHMyphFMAgAZlGwjd3m0hPEAUQFcfXLV7XBzPY/3F9ag8qzKGrx8OABi4aiCeWfqMsP3dR9/F1VHud24MazFMsAJl5me6hRWICyAOh1MsSMtJc/UQVEnNSQUAu8QPIMZSANDcFT47PxuJmYl2vS/HfZiwbQIA4IejPyA9Nx3LziyTbJ/0+CREBUe5YmgW6duwr7DsDrFAXABxOByPZ82FNQifGo4xG8e4eiiKpOUScUY7vdsKK3pouwNLPLrwUVSYUQH3Mu7Z9d4c94CeSwAQlxIn2Tag8QAE+gY6e0iaWdJ7iXAT4A6xQFwAcTgcj+efS//ACCNmHZxlUg/FHaDWqVA/+wSQl0Gcsu9m3EWhsdDs/tn52Th65yhyC3Kx/dp2u96b4x5QayJAzgGW/o36O3s4VvN5x88BuEcsEBdAHA7H48nMzxSWk7OTXTcQFQQLkJ0C6MVGLwrd4/ML83E/877JPinZKVh8cjFyC3KF4osAkFPgfsKQYz2sAKJd3ykPxTzk7OFYjTvFAnEBxOFwPB42xsUde2XRi5a9LrAw/zBcGXkFZYLKAADi0+NN9hm8ZjBeWvUS+q7oK6TeAyRziOP5sG5Q9vv/p/8/bpX5ZQ53iQXiAojD4Xg8bHyLW1qAdHKBUWijy+Vnl0vWrzi7AivPrQQArD6/WrJdXimY43nIA/1pg903WryB7jW7u2BEtsHGAsHgunFwAcThAMCsWcDAgUBurqtHwrESo9GIW2m3hOcPslxvAfr+8Pfot7If0nPTAegXBE05l3gOAPDp7k8l69/b+p7k+R+n/xCWuQDyfOTfIS1+SAtkehKfd/wcVcKrCDFBrsDH8i4cjv0UFgIGA/nndhiNwJii7KEyZYAZM1w7Ho5VXLh/AbfTbgvPXW0BuvrgKoatHwYAeLHhi3iqzlNCtk7ZoLK6vIdSCnxBYQFiU2JVX8NdYJ6P/Ds8e+8sAGmBTE9hWIthGNZimEvHwC1AHIeTlwc0bgw88YSrR6JCAhNIOHMmkJqqvi/H7Vh0fJHkeVJWkmsGUkRssihC0nPTcSftDg7cPAAAaFWxlS7v8eNTPwrL2fnZAIBj8ceQX5gPXy9fLHhygclruAXI85GnvVMLEFsgk6MdpwiguXPnomrVqggICEDLli3x33//md1/+fLlqFu3LgICAtCoUSNs2LBBsv2vv/5C165dUbp0aRgMBhw/ftzkGNnZ2XjzzTdRunRphISEoHfv3khISDDZj+N4jh0DzpwBNm0ixha3wmgEdu6UrgsPB27fVtyd435QcRHoQ+qf3Em/48rhSCxQGXkZ+HzP5ygwFqBhVENNjVC1MKjpIBiKgidSslMAAFP/nQoACPELUYwHSc5OdssSARztqIlYT7QAuQMOF0DLli3D2LFjMWnSJBw9ehQPPfQQunXrhrt37yruv2/fPvTr1w+DBw/GsWPH0KtXL/Tq1QunT58W9snIyEDbtm3x5Zdfqr7vmDFjsHbtWixfvhy7du3C7du38dxzz+n+93Es48WcZXl5rhuHCZ99BgQEAP36mW6rUAHYt8/5Y+JYzb1MEgBNrStsPJAzuJtxF2surBF6M7FZaJl5mVh1fhUA4MvOX8Kgkw/Yy+AlxBOl5BABRNPcO1XvhErhlTCmFXHrNolpItQPcrV1jGMfN1NvAoCkQS7ALUC24nABNGPGDAwZMgSDBg1C/fr1MX/+fAQFBWHhwoWK+8+ePRvdu3fHO++8g3r16uHTTz9Fs2bN8O233wr7DBgwABMnTkTnzp0Vj5GSkoKffvoJM2bMQMeOHdG8eXP8/PPP2LdvHw4cOOCQv5Ojjre3uJzjRjegCz6KQ+vcnUhEabJi0SKAFdVdu7qZYuMoQTPAmsY0BQDcSnWuAGr5Y0s8s/QZoSUBawG6n3lfiE9qWUHf7tz0rp+m2GfkZgAAnqtLbvRmdJuBxHcSsWfQHpQKIEGydzPu4q0Nb2Hp6aW6joXjHKi4/rKz9OafW4Bsw6ECKDc3F0eOHJEIFS8vL3Tu3Bn79+9XfM3+/ftNhE23bt1U91fiyJEjyMvLkxynbt26qFy5supxcnJykJqaKvnH0Qf2pjc723XjkPM6FuAAWmMSPgbq1QMGDADGjwfefZfskJEBvPqqawfJMUuhsRD3s0gxwKbligSQgy1ANOYGAPIK8oRU5P5/9cfAVQMxbrNY3XbFuRUAiHvO3j5gcuhdP3WBZeQRARTsFyzsUzqoNEL8QlA6iIj8dova4dtD36LfSgWrJ8ftoeKalkGglA3WJ7i+pOFQAZSYmIiCggJER0dL1kdHRyM+3rSAFwDEx8dbtb/aMfz8/BAREaH5OFOmTEF4eLjwr1KlSprfj2OefCZhpVYtYN06141FgAlGSkYE8Ndfoq9u6lRxv99/F5dnzQKqVweWS2uvcFxHUlaS0A6icXRjAKKbwBHMOTgHYVPCsO3qNuQX5qPTr50k2387+ZvkOc3SKRNURjf3FyXMPwyA6AKjFqBg32CTfan4YqsIczwPKoDYtHd/b3/UKFXDRSPybHgWWBETJkxASkqK8O/GDZ4xoQf37gEdOojPU1KAp55y3XgEUlKExYIOXYC6daXbV60Sl3ftAs6eJany164BEya4YTR3yYT2QooIiEDViKoAyEWe1t/Rm5EbRyKvMA8DVg3AhksbsCduj6bXPVblMd3HQt0e1AKUmUfagbAWIAqtHM1i5Oewx0EFUERAhLCuftn6uovrkoJDBVCZMmXg7e1tkn2VkJCAmBjlkt0xMTFW7a92jNzcXCQnJ2s+jr+/P8LCwiT/OPbzwQdAumOuRfbBBOHnR0aZbn/mGcCnKNBwzBhg61Zx25UrQGKi6Ws4TofG/0QFRyHMPwwhfiEAHB8HZDAYJH225JQPLS8sd6jaAQufVo55tAfqAhNigPLULUDVIqqZrNPaTZ7jPigJIL1dqyUJhwogPz8/NG/eHNu2bRPWFRYWYtu2bWjdurXia1q3bi3ZHwC2bNmiur8SzZs3h6+vr+Q4Fy5cQFxcnFXH4diP21YeOHVKWMw3rSlHApeor+7YMWDUKOn28+fJo9EIXLxIKj1ynA7NAKMFBqnwuJ12GwWFBQ5L+/YyeCn24aJ0qd5FWO5Zqyf8ffx1H4NgAZK5wIJ8g0z2lceMAPa5w47dOYYei3tIeo1xHEuhsVCw9kUERKBP/T4AgA8ec10vLU/H4S6wsWPH4ocffsAvv/yCc+fOYfjw4cjIyMCgQYMAAAMHDsSECROE/UeNGoWNGzdi+vTpOH/+PCZPnozDhw9jxIgRwj5JSUk4fvw4zp4l/vULFy7g+PHjQnxPeHg4Bg8ejLFjx2LHjh04cuQIBg0ahNatW6NVK30KkXG04ba6gMkGVBRAANClC9Cpk/K2du2AX38lcUF16gALTAvPcRwPtQDRIFDq6rmfdR/tFrVDrTm1JEHLenEz9Sa+2veV6vZHKz0qLFcMq6j7+wNiDNCknZOQk5+jGARNeb3F62hVsRW+6PiFYCGyRwA9/MPD+OfyP+j/V3+bj8GxjrScNBhB3JYRARFY/NxiXBt1DR2qdbDwSo4aDm+F0bdvX9y7dw8TJ05EfHw8mjRpgo0bNwqBznFxcfBiCsW0adMGS5YswYcffoj3338ftWrVwurVq9GwYUNhnzVr1ggCCgBeeOEFAMCkSZMwefJkAMDMmTPh5eWF3r17IycnB926dcN3333n6D+XI6OgwNUjUOHcOWFRVQB5eQFLlgBsUH7TpsQiBAAvvyyuHz4cGObasu4lEbkFiAqghPQE7LtB6jgdvXMUbSq1sfu9rImZaVu5rbDMusP0hLX0vPL3K8KyWhD0/sEkA/a7w98hIy/DpLGmNRQYyQ/73L1zFvbk6AVNgQ/wCUCATwAACHFvHNtwSi+wESNGSCw4LDvlVXgB9OnTB3369FE93iuvvIJXXnnF7HsGBARg7ty5mDt3rjVD5eiM2wqgy5eFRVUBBABlyxIBlJAA+PoCbduKAoglJET/MXIsQhuf0jiI0oEk3ftS0iVhH1oE0F6UAqurhFfBsv8tQ+8/e+PFRi/iZMJJ9KzVE7VL1xb2cdRFilq/AEjq+ihZgFjk2WOUrVe3Ytq+aZjXcx5yC3JRNrisYnxJXoFYG8vby9tkO8cxKMX/cOyDN0PlOBS3dIHl5wNXrwpPzdY6NBiAf/4BVqwAXnqJ/EFz5pjul54O/PwzwFgmOY5H6LLuR6oiUwHExqbQ2Bh7ofWGKH7efrg26hoMBgNujLlhkomz99W9SMxMRKVwx5TUULoQftbhM4uCj4qa+5nSv6f7791RYCxAu0XtcCftDvy8/TCl0xSsOr8Kf/T+A+VCSRwR/cwB8NYaToQLIP3hafAch+KWFqALFySqx6wFCCBur88/J8USGzQg4unCBWD0aOBHsSklXn2Vp8c7GXoxplYNWvBv/02x4KletW/kgqFUQClB9CilIbep1AZP13lal/dWYlybcSbrtFibooJJ1iMtIUChbq2bqTdRYCxAVn4WRm8ajV2xuzBt3zRhP/bzzC3I5en0ToILIP3hAojjUNzSAnT0qOSpRQEkp1o1oHZt0jl+8GCAjS07edL+8XE0Q+NYaF+sDlU7wACDpB2FbgJIZgEK9A3U5bi2EhkYiRV9VkjW0SBZc0QHk5g2VgAdvn3Y7GtmHJiB84kk85H9PI0wWpVObzS6cWaom8MFkP5wAcRxKEoWoHBXta25d4/E8AwcKFltt5Vq+HDgySfJ8vr1dh6MYw30YkxdYC0rtsSbD7+puI+9yC1A7kDv+r3xevPXhec9avWw+BpqAUrIEJXI2E1jLb6ux2JybHnwtDVFJ99/H4iJkRZY52iD9pTjdX/0gwsgjkNRsgC5rL/o228De/earLbaAqREz57kccMGHQ7G0YoQA1RkAQJg4nZylAXIXfiu53e4OeYmsj/I1nRxVHKB0XYicmpF1hKWryVfww9HfjD5PK3JJqNdZt56S/NLOEXsjt0NAHik/CMuHknxgQsgjkNRsq5kZbkoVIZthPvee8KiLgKoR9Gd9969wB9/6HBAjhYEF5ifKIA6VOuAoc2GCs+vPBArNv9y/Bf8cvwXm95LbgGiriRX42XwQoWwCpqLLVKRRNOqAaBO6Tom+w1oPABjWo2RrBu6bqhJ9tiQtUOsHbLrboI8mAv3LwAAWpRv4eKRFB+4AOI4FFYAPfsseTQagdxcJw8kORmg/d2+/hqYMkXYpIsYq1xZXH7xRR0OyNGCPAgaAHy8fPD9U99jY/+NAEh6N0BS5l/5+xW88vcrNtXAYS1ABhjwzRPf2DN0l1EqgDTSpCUEAOW2GGH+YYgOMRV58k7yO67vsHoMTv/9FwPoOctjgPSDCyCOQ2FdYKHiTTpSnd2Uets2IDubND0dN04yLt2sUUxFc47juZN2B4mZiTDAgAphFUy20+KHN1Jv4G7GXdxJvyNss8WdRV/zecfPcXvcbTxSwTNdEbSTOGsBYlPbKWH+YXii5hMSa5pecAuQ9Si5ezn2wQUQx6GwQsPLS6wXmJKivL/DOFyU5dKuHWAwSNxeugmgkSPJo5eXm+b/Fy8O3CTtTBpFN1K8Kw71D0XNyJoAgK/3fi0EkQJAYqb1zWypC6x8aHnEhGhvzuxuKFmAqHWBbZoaGRiJQN9AfP/U97gz7g44riO3IBe5BcRsxrp7OfbBBRDHobA6IDhYzABzugDaR9oioAXxnztEAJUmNWhQWEhcbhyHQoN4zdW+oVag9ZfWSwTQH6esi9MyGo3YdGUTALHYoqdCLUBpuWmYPSdfWAaAGpE1hP0aRonth2JCYmCAtNbRT0//JCyzZQc49vPbid/Qc0lPIeCcddlyC5B+cAHEcSisqTsy0kUCKC1NzP7q3BmAgwSQr6+4/OmnOh2UowYNxqVd0ZX4svOXAIDziedx6b7YHmPGgRlWvRfNwAEc19zUWbDWstHvJQMQU9lph3EAaBzdWPI6Nt1+QtsJ6F2vt/B8+PrhDhhpySS3IBcDVw/Ehksb8POxnwGImYyBPoHw8eINHPSCf5IcCVevkrI2774LdOxo//GymNhKlwmgM2eIKapcOVLEEDplfplj9mzSKZ7jMOhFwZwAigmJQVRwFO5m3MX269ttfq9rydeE5SYxTWw+jjvg4+WDEK/SSC+8D5Q7ijMJlQSXYOuKrbHkuSXILcg1aeI6s/tM1ClTB33q90GFsAqS1Pmlp5ciOjgaM7rN0K33WknlVMIpYTmngLQa4fE/joGfqRwJAwcCmzcDnTrpc7zsbHHZZQLoAkkfRb16wipWAOlarXrZMnH5wQP1/Tg2E5sci86/dsbiU4sBSDPAlKgcTjL0aHd4ijV9rOLT4wEAAx8aqNj2wtMo5V0UND6gGxrOr4+krCQAQJWIKujXqB9ebvKyyWsCfAIwutVoIeBcLnRmH5yN7de0i8ywMCAoCPjzTxv/iGLKrbRbwjJ18yqVe+DYDxdAHAm3b1vexxpYAdS5swstQADJACuCFUC6ZqQ8/zxQoejiQoUXR1c+2vERtl3bhuvJ1wEA4QHmS4tXCDXNEAOsywRLSCdVk2OCPTf4mSXUUN5kXWRgpEUxKWfOE9LGwDRQVwtpacRC3LevVW9Z7LmVKgqgO+l3YDQacSz+GADLYp9jHVwAcST46OwUpS6wAweA8uXJHR8AZGbKdrx3D/jtN+CRR4ArV6Ably6Ruj8A0KyZsNphAgggfcIA4OJF8/txbELeesHSRUFNAFmTCUZT6D05+4slMM80jqlMUBmrj/NS45ckz2/cycbmzbaNqaAAuHy5+PcTvpN2B11/64oBqwYgv9DUF88G6yekJ+C7Q9/hrX9I6WzeBkNfuADiSNBTABUWigXPqlcnj4FF/SPZ2CD89x8QFUX8b4cOATVrkmwtewN10tPFHl0A0Lq1sOhQAVSrqH2AnkKOIyAXNOZigACgQVQDYXnh0wtRrwxxhWrt7ZWZl4llZ4hrU0u3dU8gMKu6yTp/b22VpFnk5QeGjUpFt27ABx9YbwAdMID8dDZutHoYHsXai2ux5eoW/H7yd+y8vtNkO9ujbdu1bRjxzwjhORdA+sIFEEeCngIohwmxoMJHUQDNkZrRAQBHjthnQcnOBho0EI/RvDlQv76w2aECqGLR3fWtW+b349hEXqH0C5NnK8l5temreKr2UxjSbAgGNR2E0kEkjV2rBWjAqgHCMiumPBm/jJqm67z9bDoWmy4PfxKY/sUXEo8zAPOWndxcsYNMcY8JYksGrL2w1mS7ud51XADpC88C40jQUwCxIicggDyaCKCtW9VbQ588KREtVrFiBRAXR5aDg4E1aySbWQGke0YYjQHiAsghyC8QlkRJgE8A1vQTv3/q6nl367tCocSm5ZoqvtZoNOKvc38Jz6uXMrWceCIBKY0BWTmjRtGNbDrWppc2oc63dYhr0l/94m3uRiOR0aK1aqnvVxxga/rsjN1pul2hKjeFFrHk6AO3AHEkeHvrdywaAO3tLQoriQCKjQW6dBFf8PXXQKVK4vNFi2x/88uXyWOpUsQVVl4a9Mn2ItLdAsQFkMNIzUnFH6ftazZLCxleS76GZguaodmCZkJVaTls489jrx8rNjVYvB/UAX7bCCzaDq/LPdEqugO+6vyVTccqH1perBFkRgCZxP0xJIheH+GGJDMTWLdOZi0uBrAC/lTCKSEDj2KuTx23AOkLF0AcCY6wAFHRwy5nZQGYOlX6gtGjidWGuq02bSJd1m3pnEjFx+jRipuTmDknL49or969dbIGURfYzZs6HIzDsuj4IsnzSmGVlHc0g1LM0E9Hf1LYE1hwZIGw7On1f1gyMgBc6QZc74DC39ch/svtKBtc1ubjCYHogepxVenpqpsk2ad5eSR+sFYt4KmngDfesHlYTud4/HGUm14O3x/+XnWf1FxRABlhxKC/B0lEkDkLkKcX4XQ3uADiSGCLGdsLtQBR9xcgE0DbtpEnLVoAP/4oqq9atcTlf/4BFi607o2NRuCvIrdFBeUMoHv3xOXCQmD8ePKStaYueeuh7/ngQfG7fXUxbJE4ABLXllbYpqiUC/eVI3bf3fqu1cf3BDIypM+vX7fveI2iitxn1bYDEIN92LgfcwLoklikG7m55CdPRZE9hmBnM37LeMSnx2PY+mGSQpEschfumgtr8NVe0fpmzgLUrFwz1W0c6+ECiCNBTwuQWQGUUUjKTgPA6tXA4MHSF3/4obj8xRfWNRc9dEg08agEFCSqxL/qolfCw8V8f+4G05WkbPK9VouohvyP8m2yyoxsOdJk3cmEk6oXLACoFVm8AlPkAshe2ldtTxZKXQd8xR8RNd7++KOkDqkJbMZYXp5pILTTewfaiLeXGEOg1m+ObWtB+fXEr3hyyZNoNK8Rrjwg2aPU2jO8xXCE+4cjOjgatUoXr/PQ1XABxJGgpwA6eZI8Ci6w1FQErlsOAMg6fZmImpAQk/gcAEQALShyP9y4Aezcqf2N//uPPNaoATz2mOIurAWIxUuPX4TBIFqBbtzQ4YAcSmYeCSSZ+PhEycXGGlpVbIVtA7dJ1qXkpKDCjAqYsX8GHv7hYfx9/m8AQFRwFADgj972xR25G0oC6Nw5248nqcXkKx68Uydy3CFDzL+eFUC5uabz0Oef2z42Z0LPFwD4et/XivtQARTsFyysS8lJwfpL63H67mlh3e5XdiP/o3x81/M7XB99HVdGXuFtRnSGf5ocCezEQ1tEpKUBd0y9BhZ59VXymEotvvPnI3DFrwCArDvJZF2FCkQwyPH2JrPm88+T58eOaXvTU6eAt0jRMPTvr3xsqFuAdAsCp4WPaDA2RxcycsnFNdg32MKe5nms8mN4qvZTGNVyFDpVI31f4tPjMW7zOBy+fRi9lvXC1qtbhVYE0SHR9g3cAWRlASNHkp/Jfe1FrQEoC6BRo2wfi7eXN7wKiky9baYBjxKXzt69wDffWH49ayjNyzMVQJ5yH+FjEAeulp1I608F+IimcSrsWUoHlRZEfkRAhEQwcfSBCyCOBHbioS6sqChipGEzNSzB+v6F1+3fj0AQ83gWisxCZS0EXlK7uZbb04ICEjVJefZZyeaUFBLsHBennpGiiwUIENP3hw4FVq3S6aCcjLwiAWTnxcDX2xdr+q3BrO6zhN5Wcrr8JmYo0swxd+Krr0gJrR9/BMqUsU4EKQkge7MhvQuLvpO2XwFd3gVCyF2TlhwG1sV17pxpLF64+VqXqiQlAT17qlfasIfJOyfD8LEBYVPCkJJN/gC2iCFb74eFiur3275v9vi87YXj4QKII8GPqYV2l/xOBSF06JD24yhOphcvIgrkoNdRFbnwtSyAGhUFVy5cSG4nzc3S+/eT1HoAmDkTaNJEsnnYMBLs3KmTtEcZi24CqClTV+a553TuuFpyoXfK9lqAWML8LF9oAn0DLe7jbE5J48Hx3nvaXmc0Kt8A2HvuexXIvpOiWCB/DQWmk5PF5T17TLfbKoB+/hnYsIFUmdabj3d9DIBkbc0/PB83U29i/aX1wnZq6bn64Cr6reyHp/54CpuvbBayvPo16oeTw04qFqDcPlB7U1mO7XABxJHAxhrHxUktOdZkiLFVoL28QNTU2bNoiNMoG5aNdITiKJpZFkBsK4u2bYEXX1Tf90BRLZdnn1VMf1+9mjxevqwugHTrQ9S3L7H+UGbOBPbtK/6NjhwMdYEF+QbpdkwjzH8n/Rr20+299MRPdt2kOQWWKChQPg3tFkD5cgFEvistFiB2vlAiyMavm03AUHN728KdNGlMwJl7Z4S4Mcq9TBJo+POxn7H09FKsu7gO3X7vBoBU3Q73D0ej6EZCMU7K4ucWo0O1DvoNlqMKF0AcCWwdnNhY6eRljQBiX+ftbQSiSQyFV/VqqFCZ+LVTEUYKFZrD31+aEbZihfJsmZcnBhs0U04VZUWPmgDSrSiinx/w/fdA9+7k+dtvA48+Sm5HOTajlwuMRakhJcuiXot0ey89kQsgrQkMaoJEdwHkR/Le9cjgsiSQ1GBrkOkZR3Q5SRrbdyvtFuJS4iTrridfR3puOm6kmr5xuZByMBTFJ7Lu1fvj7+PFRmZu8ji6wgUQRwIrgO7ckZrKbRVAkrvNIUPgG0hm6qTyjXCu1SDLB/vwQ+CJJ8TnBxSq9p47J85w8pR6BeiEKo8N0L0q9A8/SK9MnlTUxA1xhAvMXGbNkaFHbO6R5Wjkv8ctW7S9Tk0AqeQLaMZgIoCIWKWudHtQu2Gx5nW21FM1OV5+Nv679R/a/9Jesn77te347vB3AICJ7SaiYlhFFBoLcfj2YcVYoL4N+grLrJjnrS6cCxdAHAmsAMjNtb0uDnvHJinhM24c/PzITPvK/emo/2wd/PqrhYP5+xPLyQsvkOc7dpjuc/YseWzdGihXzuL46MQoN63rMUlKqFiRBCJQoqLU9+WYpaCwANn55IvT0wX27qPvmnQ1p1jqNO9K5BYgoxE4etTy69SsKfZagEwFELEAyYOzy5Sx/ti0pIa1sALIVisS5eqDqwj8PBAtf2ypWDMqPZf8vV1qdBHqUx27cwx/X/jbZN92VdoJy+y5Z7BXhXKsggsgjgR5l3RWAFnTJkJqASr6UXftCvj6CneudEJ6+WWNB6U1fT7+2HSmP36cPDbQ1q1bqUgj4AALEAC89BLw5Zdk2ZZ6AhwAwD+X/wFA4if0zJCpElEFie8konBiIW6PvS3Z5s6ZOHIBBIj3AeZwlAvMkBsqXVF5D1D6Am6FrwBbHXrXLvOhfEps2ybmN1iDngJo9MbRputajkabSm0k6xpGNRQKZ47fOl5YP6iJaO2uVqqasNy7Xm8A7n2uFVe4AOJIYEVObq7UBWarABIoKnioNHFrgo3t+fRTcXnGDFFgtG2r6VBqAkh3CxCFVqTmAshmtlwhPp5+DfvB30dDapEVeHt5w2AwoFxoOUkxu/AA97UAKbmkgzV4Buk5HhIiXW+vAAq61UO6os0M4K26SO7aB5jsBfhmoHVrUiHi119JcmQHC7G+7FxB65taAyt6rPltn088j8RMMWr6v1v/Ye1FaW6+AQZ82vFTbB2wFd8/+T2alWuGmd1mIiIgArVL1wYgjS/7pMMn8PP2Q1RwFKqXqi6s712vN1b0WYHjrx+37o/j2E3xaG3M0Q3WAiK3AFljHWEnnooois2pTSYFm/uNPfSQuBwcTCo0+voC48aJ67t1U3zplSvicqlS4vjkKboOsQABYrXr27fN78dR5eRd4gfpUNWxGTI1I2sKtVrcNf4HUA56lgt6JagQ8PcHhg8ntbEA+xsBB13pB7RWN+dGNjqCnTuJ68fbG1i5kiRGPvqo5bEC2v42lvx84JNPxOdaLUBXkq6g3tx6CPELQdoEkrL+6ELTQU7vOh0hfkRFDm0+FEObi1mf8tYp//T/BxXDKuLqyKswwigpgmgwGNC7fm+tfxZHR7gFiIN584iB4upV/Vxgb78tLq9HT7LQsiUAZQuQplijwEDgs8/I8uLFQFiYNM3j0CEgJkbxpWzNlJgYF1iAqACKi+OB0DZyJYmo2Dpl6jj0fbrX6O7Q4+uFUns8LQKenuN+fsRwGlbkebG3D15Bnvk7m6iacSa/fUuihr1Zsva3uX+/9LnW12+9uhUAienZeHkjsvKyFDMFzZ2H1AJEKRdC4hIrhFXgHd3dCC6AOHjjDVIb5+23pSLnm2+A9u3F59ZYR9jWXY1RpD4efxyAsgXo2jWNB27eXHn9mDGkq7wKaUyD5ZwcJ8cAAUIZAADAoEG8HpAN0GwaR1dlfrvN2+jXsB9md5/t0PexF6VzVYuVg+7j50cyv2gSgr0CKD8fwJzzwG3lMhTGsDiTdeYE0ODBpJsNhf0Na0FuIdNqAYpNEYONnlj8BEZtVO4RUrdMXdVjVAirAH9v0bwcE6J8Y8ZxLVwAcQSyssxbeewyka9YITTaUrIA7d2r8TiPPQa0amW6vo55q4BcANHJ0GkWIPkfbW3zphJOQWGBUEFXLWNLLwJ9A7Gk9xLFrvHuBBVAbC9hLRd51gIEiEZUXQTQ/TrA4eHK24NNBRBrwJXzzTdi4idAegoajcC0aSQo2hLyhCotn016bjqm/DtFsu6Hoz8Iy49XeVxYrhJeRfU4XgYvtKoozlNlgmxIfeM4HC6AOAIGg3kLiF3WkVqiT1zJAsR2gzZLcDCxbRuNUiFUqZLZl6lZgPz9gd9+E7c5zAIESFta81ggq6AdtAH3Dkx2JlTIsJ3W7RFAttbaoQg3SKf7AommNyQ5gaZpXJGR6sfz9yeJozT0Ly0N2LgReOcdoHNny+OR/z1abm6+P/y92e0/P/MzQvxC0KV6F6FRqRpze8xF4+jGGPHwCIv7clwDF0AcAYPB/CRolwWoRg1hUckCNH06cPCglccsiikCYJUAysgQ/5aAAJKl/u675LnDLEAA8P774mzOtr/mWIS6vwJ9At06MNmZULEeEAD0LoqhtUYA0QQAXS1AAJAbCnx7Dm8G/gvcbAls/QIAkOlragEKD5dasFi8vcmc1LEjeZ6aCpw/r/B+Ksj/Hi2fzW8nf1Pd5ufth2qlquH22Nv4p/8/Fo/VIKoBTgw7gTk95lh+Y45L4AKII0EPAWTS97NcOUl+rloW2Jtvaju+QM+i4OqmTcXu6yqwAoidGGkqMB2TQy1AAFChqPM4F0BWQQWQo91fngQ9V319RTHjSheYNCjbgIZhjwI/HgDOEXWW5H0WF+9fNHldvXrmjxtaVF4oLU06P1kyotpiAbqfpe6apkUxQ/1DuUWnmMAFEEfAYDA/CWoVBybHqCVNCVUTQEeOaDu+QJcuJL1r3z4hvkgNpQBKb2/xIkBjgZS6ZOtK1arkUWvnSg4AICWHNJTi7i8RekG3VgCxQdCAAyxARQhV1lMqC+vqfGvqGguzUP+P/jazs4GbN8X1caYGJQm2WICoq1Wp1QovVFj84AKII8HcJKjVApSeLi63wn6J+wuwoxCiEg0baioQoiSAgoPFQElanl/PjtGK0GBt1pbPsQity8ODSUUcYQGyJzlRPj8IRt986e/zQdYDyXNLdcHYvy0pSVyfnGz+dXILkNJnYzQaMWbjGHy19ysUGguRlkMmCqVUdVrzh1N84AKohMNOWgUFyrVFKFotQBkZ4vJmdCX9uRhsLoRoB+zESWGr5pYtSx7v3XPwQOoWpc6uWgX8+6+D36z4cCeNVNCm9VQ44u/Rz08fAVRYaLsL2Gg0FUDVqinvS1uaUCzdENH7m5wc6dxi6W+V38wpucBOJpzErIOz8O7Wd5GWkwZjUcuOUoGmTUl5n67iBxdAJRx2krCUBaLVAkQnqSgkINQnW5qmAp0tQBqh8QLtxB6EklYAVADp0bnaLGwdI9rbjGOWK0lXsOzMMgBcALHYagFSc4EBtrvBTOL+IGvLsWSNsLjh0gbJfpbmA/Zvs0YAabEAsTE/N1JJxXpvgzcCfUzz8wsKzdwdcjwSLoBKOGzMi1L8Cw1ZAbTfHVIXWDAyFIsTmrMAKU2kekAFUG2mQCs7QdMm7Q63AJV2bBG/4ki337th/01S1rdcKBdAFKUYIC0CngoDKnxoQUTAdgGkdHPk4wM8/XTRk4tPYUKtpQCAuBRp8I6fH4CgRMCgLDBYAcTOUdZagJRu8JKyRNPwtQekGqtarM/tNF66orjBBVAJh42NUfKpDx8ODBtGllevNu8io2SkEzNyCNIVUzzM3fHZG4ipRHo6SaEF1AUQrUeSmqrtb7QLtqujwxWXZ5OSnYIrD8RGbm0ra2t2WxJgLUBUzPz+u+WbCHkVdINBXNZTAHl7A3/8IT4PM5K4mj1xe/D4oseFuK60wDPA+LLA8/9TPLZeFiClvy0+PV5YvpYsCiDW3fV2a9LXZ1zrceAUL7gAKuE0biwuKwUAh4WJFpvDh4E5GkpaZNwgd1XByAC+/dZkuzkLEDvB6QWN//H3Jxn5FFYAsXHUWkvm28zDD5PgbQDYutXBb+a5bL6yGRFfRgAgF6V779zjAoiBjQF66inT9WrILUDssq0CSCm+xmhkMsEAPFJPDCzeHbsbQ9eS5qGng+aSlfVWKx5bqwAqKCzA5aTLwnO6b3hR4qAlAXQ9+ToAUwvQV12+wpGhRzC29VjF8XE8Fy6ASjjspKAUKBwaKu2ps2yZ5WOmn74OAAgJhnQGLMLZAoi942Xjfthltiu8wwUQADz5JHmkjZg4Es7eO4tuv3cTnreq2IpngMlgLUBswLHWAoGs6Le3GvSVIiMd+5uilqhLl4Bdu4C2D0ndl39f+BtJWUnw9hKtLezvUL7OnAAyGo3o9Gsn1JpTC3+f/xuAOJ/R0ltKLn4aXA8A5xNJZmaYfxgMEMdkMBjQrFwz+PsoDI7j0XABxDELawECtKXJZlwkRf6CI5UnDHmTQhY2hV4v1AQQaw3y9RXjIJwigF5+mTxu3+7g8tOeyZyDUlPjpMcnuWgk7gsbA8SWwbLkwlVqBEx/47ZWez9V1O/44YfFdTSxoGZNknzg5+1nElxc+qvSErERFGpqvtIigGJTYrErdhcAIp4B0aJNBZDcApSVl4WFxxcKz2l2Wph/GGZ0m4FAn0B+3hVzuADyQNauJd3bKampwPPPk8xqa1CLFahencQuN20KtGljgwC6mgAACI4yLSYGmLcAOSIGiG18SoOdAaAiU+rDYLAuk8Zu6tQBSpUiV7EzZ5zwhp5FVLD4RT1Z+0m0qdTGhaNxT1gLEHtTYUnEKAkg+npb0+BpkkH16sC5c8CJE5YLHFIKISq2F14lWVnUQwxoC4I+d++csJyZl4kzZ4ANRclm9Hcun1u+3ve14njC/MPQOLoxkt9LxuT2k7X9ERyPxCkCaO7cuahatSoCAgLQsmVL/McGgSqwfPly1K1bFwEBAWjUqBE2bJCmTRqNRkycOBHlypVDYGAgOnfujEuXLkn2qVq1KgwGg+Tf1KlTdf/bnM3WrSSzgi2u/NlnwPLlwHPPWXcstcmudm3g0CHg6FGStMTGClgUQHl5SL9GAntDKkYo7mLOAmRvQ0ZzxwwIkMZky++UnSqADAagWTOybHUJ7OJPei4xBT5d52ms7bfWxaNxT9gYIFssQOzvmv4mbbUAsWOpW1caW8hSYDQd3DHv+cLy4BH3sHw5sGOHuJ3+LlNSpDdtt2+TufCvv4BziaIASs9Nx+zZ4n5qAuh4/HHFMdIYIN5zzrHMn0+yjOfPt7irw3C4AFq2bBnGjh2LSZMm4ejRo3jooYfQrVs33FXJ19y3bx/69euHwYMH49ixY+jVqxd69eqF06dPC/t89dVX+OabbzB//nwcPHgQwcHB6NatG7JlV89PPvkEd+7cEf699dZbDv1bnQE7MVBs7aqgJoBq1pQ+Z4OFLQqg779HRjY5rYIrKbd6NieAHGEBYgWQry/Qowe5YPTtK92PTrSOEGGKUAF09KiT3tBzeJBNqgW3qtDKxSNxX1gXmMEAeBXN5rbEAOklgCwVOc0tMO/uTcm/h//9T6zMDoi/S3mM4s8/E2t4795SC1B6brqkNzKtPCGfW9gAaGvGyNGHDz4AYmPJo6twuACaMWMGhgwZgkGDBqF+/fqYP38+goKCsHDhQsX9Z8+eje7du+Odd95BvXr18Omnn6JZs2b4tiibyGg0YtasWfjwww/xzDPPoHHjxvj1119x+/ZtrF69WnKs0NBQxMTECP+Cg5VdMp6EkkBgLRY//6z9WGqTnc0CKCsLmDYNGSAvCI5Qng1daQECiKvw1i3Tv9OpFiBAFEDHjjnpDT0HKoCUKvJyCHLRQa1AtsQA2SuA6OssCSClHlss15Ov44s9X+DSfdGiT3+X5uYe1gKUkZeBUsxpQ2uPyufOhAziqt/9ym7J+qsPeJ++koJDBVBubi6OHDmCzp07i2/o5YXOnTtj//79iq/Zv3+/ZH8A6Natm7D/tWvXEB8fL9knPDwcLVu2NDnm1KlTUbp0aTRt2hRff/018s38unNycpCamir5544oCSBWNLz6qvZjsRagcKbHpM0CaNcuIDYW6T4RAKQBxyzmJkk9BVBKChkvPSadSP38gOho0/2dLoAaNSKPBw4AK1c66U09A9ovind/V0cugLSKGEcEQWu1AK3ttxYdq3VU3T54zWB8sP0DdPmti7BOKTNMDlsrKj03XRhPr17qFiBahygmJEay/vEqj1t+Q45dULdXZCTw+eeuG4dDBVBiYiIKCgoQLbvaREdHIz5e2fwYHx9vdn/6aOmYI0eOxNKlS7Fjxw68/vrr+OKLLzB+/HjVsU6ZMgXh4eHCv0qsDdWNUErltPWCTScJb2+A1Y42C6Cizz+jdBWT17GwFqAY6dyjmzn033+BiAhg5EjlCV8JtueQU6hdW5zd//c/4J9/zO9fgqB352WDyrp4JO4LG3cDWG8BUosBun8fSEiwbSyWBFCHah2wbeA2bBu4zex+sSmxwnJEhIU3NxQIYgYAjsUfQ04u+RDCw8W/U1L1Pi9TiDOLDpFeS95r+56FN+TYywcfiC5NWmjXFRTbLLCxY8eiffv2aNy4MYYNG4bp06djzpw5yFG5uk2YMAEpKSnCvxs3bjh5xJZJTJSWjaFiRP4nde4MbN5s+XjspMX63Nn2FwAQHCDOqGYFUFFcV4YPCSLUIoAaNCDVaylXrpAAbHv56CPy+O232gWQ0y1Avr7Ajz+Kz/v0cUI7evfHaDTiRgr5/VUKd88bEXeAjQECxN+VJQFkLgYoM5NYR+vV0/47mDdPrHdqzr3N0rFaR3So2kFccVXdKhQaSuYJVYISUWgUo6Nvpt7Eupx3ARBxSOehnBzRwkUFk7+3P0L9QoWx1ChVA8F+nh8q4e7QOdlp8ZYqOFQAlSlTBt7e3kiQ3U4kJCQgRn7rX0RMTIzZ/emjNccEgJYtWyI/Px/Xr19X3O7v74+wsDDJP3dDrpTpj1l+Em3bBnTrBouwAqhsWWDFCmDjRlOTs+87o4XlwgIzCqhIAKUbQgGou8DYSdLHB+jfX5o1YiFJUBPsXSObBm8OpwsgAHjpJSAujszyGRnKUe4ljAfZD5CRRwq+VArjAkgNtRggS24seTd4QPxN7t5NBNSDB9obA7/xhrhsyQLEMuKREQCAmLzWwPleZvcdMMDMxhBTb8Je43RhPJGRYoA47TxDBVB0SDQMBgMWP7cYb7R4A3+/8Lf2P4BjE/PnK4twV+BQAeTn54fmzZtj2zbR3FlYWIht27ahdevWiq9p3bq1ZH8A2LJli7B/tWrVEBMTI9knNTUVBw8eVD0mABw/fhxeXl6IYgvBeBh790qfZ2eTi/XJk8r7W+oJJJ9Ae/dWEE5pafA+JQbpFl66TGbYOXOAadOk+xaJ0gxjURC0yo0UO0nSSZs1T9sah8DCCqCbN8mjVgHk9LuSSpXEytDPP++YVDgP4mYq+cLKBJVBoK9pV24O+W3T37e1FiAldxV9LZuQaEtRUmsE0LN1n8Xmlzbj+/Z/A5lSV6ePlw+MjLmZvbc1SbEPIfNO5fDKJu9BSwTQoow0SiIhnbyG1psqF1oOc3vORYMoc6Ymjh5MnUo8Cd7ero3/AZzgAhs7dix++OEH/PLLLzh37hyGDx+OjIwMDBo0CAAwcOBATJgwQdh/1KhR2LhxI6ZPn47z589j8uTJOHz4MEaMIHcLBoMBo0ePxmeffYY1a9bg1KlTGDhwIMqXL49evXoBIIHUs2bNwokTJ3D16lUsXrwYY8aMwUsvvYRSpTw3qyRQdi3IzgYuXlTfPyXF/PE0+e3Pn4c3U6gsLdsXePNNElzzzjvAnTvEglG1quDLyjCS9hdaLEBUALHXfFuLsbGwAujLL8mjpWBKl1iAKB0Yd8Bff7lgAO4D7dDNW1+owxYPl8cAqd1ApKUR0WROALGix9ECyGAwoEuNLni6c1n8OFsqgPIL85GaIyaisAKoXDlZXFAgOV+ql6qOtx4hpU78jeGAoVAYD309dRwIFqBghWwIjkNp04acq88/79r4HwDQ6LG1nb59++LevXuYOHEi4uPj0aRJE2zcuFEIYo6Li4OXl6jD2rRpgyVLluDDDz/E+++/j1q1amH16tVoyJQGHT9+PDIyMjB06FAkJyejbdu22LhxIwKKbvH9/f2xdOlSTJ48GTk5OahWrRrGjBmDsWM9u5mdvK1WTo75kBFLlgxNAuj4cTyEE8LTRJQBFiwQt1epAjz2GCnoUER6IVFqWmKA2NgDit4WIIpbusAor75KgimOHRMbK5VQkrOTAfAMMHN8+KG4rMUCdPUqUKMGqX2lVQClpVk/LmsEEMvDDcoCe6TrbqXdQngASU+tVg2AbwYQmISgoEoIDweSk4t2bPkNAKBUQCmMbDkSc/6bgxxDCjCgK/z8SLNhmjMjWIAypBYgjnOYPx/4809yju7b5+rROEEAAcCIESMEC46cnTt3mqzr06cP+vTpo3o8g8GATz75BJ988oni9mbNmuHAgQM2jdWdkQug7GzzAsiSJ8WiAMrMBIYORRCAO4hBOcQjAyHIRCCCkCUeZPt24SWFDRvj5mlyQC0CSMkCpIcA8lKwbSr0ZVXcrpRp53C8vUmFxmPHgEmTSPrKqFEuGIjzmH94PmpF1kKn6p0k61Oyieky3D9c6WUcANOni8taYoDmFjVcX7ZM7I2ldCNirQCSt7GzVQApZft1+a0Lroy8ggCfAMR57QQ+6AAU+OL+hYOIiGhK7rlKXQUqkRTWMP8wlA4sLR6g+jbQ4ZAkDyNm3uqDVUvzUCWcZKpyAeRcPviAiB8vL+A9N0i2K7ZZYMURJRfY/fvq+9stgJhZNnrvKsHUfg8qqcm7d2Ppu2K8kJY6QHTSZpOh9HCBKR3DUh3MUBK7bdOdry6w8WmjR5PUGpeYoxzP8fjjGL5+ODr/1hnXHlyTpDFzC5B1UPFizgLEZm/qaQF68EB5LNYiT0UHgNtpt/Hbid+QmpOKLouLXMTeefB/bJ5Ytyw8Ttg/pyBHsBhRvH1IoFR4OIDgeziZvxJrLqwRGqcqxQ1xHE9EhOvdXwAXQB6FXKh062a+jdSaNeSOTw2zAigvTwxynjkThjathUDCRBTFZrRvT1K4KK1bY/4C8ZTSmgUGkMO89BJZ1sMCpHQMrQLIER3pNSH/wN56S+rrKEZQkQMA1b+pjrYL2wrPU3K4BcgaDEXN1M1ZgNiECKqplQQQa9HRIoDk+9hqAfIyeOHDx8i53jBKDHfIys9CXEqcZN+CsCvExR0eC0SJLZLSctLgZZBd0vzIjzk8HEDoLWH1yQSSOVItopptA+bYRLdu5DzVkqXsDLgA8iBKl5Y+v31bajmR8/77wAsvkDhlJRQFUEEBKSK0ZQtpM1+6NAl4BiMQ3nyP5Mvu2EECn7duBY4eRYHBB3sYP75aATMlFxgg+un1EEC2WICo/nCZBYjtcEtZuFB7PrIHkZMvtWxdSrokuL4m7ZwEQGxKydGGVgsQTY5QEkAsWn4H8oL5tgogAPikwye4P/4+tg8UXeqjNo7CnTTpBHYz9Sbyog8CI2sBPcT+jibiB0C+T5GYDgcQdstke9WIqrYPmGMV7hb/A3AB5FHI/e1aUZvIFAXQH38Qed6zJ3n+2GNCQA0VEBk9+pD1lE6dgKZNsWWLuOqXX9THo+QCA+zvR8TikRagxx4jn//Zs2Jtg6QkIozkvgYP4sDNA1h2WmqKpFV4WeJS4iSZPzw+wzrMWYCU5gAlSyyLUhKF3CPbooX0uT0CyGAwIDIwEmWDy2J0y9HC+pdXvwxAFCs3U2/iatnZgLf0LmdKpykmx8z3SQZABdBNybYAnwBUL1Xd9gFzrGLqVCJ+vL3dI/4H4ALIo7A1HEQtG0xRAG3dKt2JETqCAMpQPh5rZm/SRH08ahMvXXZVDJDLLUAAMdnVqwdUZybm1FSgXz/XjckOjEYjWv/UGi+sfAFn7p4R1qflmn7IcSlxkligoc2HOmWMnojSuWzOAiRkTDFYsgDJf0Pbt5ObhDlzxHXyyvD2CCCWW2miteZOOrEANS9Huppm5mXigt8fkv1fjZ6nWMMnz4uxAAVLLamtKraCv4+GRmMcXXCn9HcKF0AehDkL0MSJRFWz1V0paoUSFQUQDSigdO8uLFoSQOxkaFKsjEHNBWZvQ0YWpWNYmpxdHgTNIr/CbdoEqFQxd2fYCxkby5GWQz7kPvX7oFfdXgBI/ycqgKqXqm4S0MoRqVOHPH7xhbjOnAVI6ebJkgCSH+f558mcUeQRV0QvAfRcvedM1jWKaiTN8mIILRQrhpcLLi8s53olAwDCwiDUC6IMbzHc/oFyNOGO7i+ACyCPwpwF6IkngClTSL0P1jsFqJeRpwJIMvmxTWqnTgXq1xee0muymouITpgtW6qP0+T9FNY7KgbIbB8zuIELTA57dQOAM2eU93Njjt4RSws/yBbdeNQFFuIXgsphJBMnLiUOB28eBMCLIFqCWnVbtRLXmbMAKd08WWsBkt/4KP2eLFWf18rzDZ43WVevbD1UDKuouH9wviiAavy7S1jOLbIAhYQACJSmzD5Z+0kdRsrRgrulv1O4APIgzAkgGkBcoYJpsDSgPDHR40kqJFMBtGED8O67kv0tWYDoxMtadZRQu0vU0wVmKRNGCVok3G1ijt97Dzh/Hniu6G740iXXjscG1l5YKyzHp4vimrrAQv1CUSWC1GQ5dfcUxm4mxUpD/FRSCDkAxBIXbGkMay1ArOhRqhovP47cla4ktPTq4uJl8MLUTlMl656t+6yqAArNFRMI/v27JnC2NwAgB8kAij6nIFEAjXxkJIJ8LRQG4+iOu6S/U7gA8iDUXGBPPFFUKbUIeb0gQDmGlhb8E7wtRiMxIQGkP5UMrQLIUi0Qdjt7F6mnC0xJRLVta7qOhSZhxca6vksxAOKOrFMHqF2bPL982bXjsZKkrCT8eExMU2SzeWi7ixC/EKEWy4ZLG4TtLzV6yUmjNCU/Hxg0CPj+e5cNwSL0/FTq6G6LBUgpU9Tc73D6dOXfmJ5t7Ma0HoPHKj8GHy8fHB16FL7evqhRqoaw3ddL/APys+VF0iIAgFSERtHnVOQCW913NWY/MVu/gXIs4m7p7xQugDwINQvQ669LnysJIKWK0VTICALo5k0ScOvjI150GSwJIDphWrIAqQkgR2SBPfww+Xf+vLJljCUqisQKGI1upjWoMvMwC9DAVQMlz2mGl9FoxPpL6wEAzcs3N6kC/PJDL2NQ00HOGaQCv/8OLFrkXneqcqgAsscCxAogpern5iyxb7+tvJ26kfXAz9sPuwftRup7qWharikAoGVF0b8+u/tsBBSWAbZ/YnrDkk3ix7KMyQCoBYhMgqWDLEwEHF1x1/gfgAsgj0JNAMkv7FoE0NWroi9WaBFBz87atRWjqfVygbECiP2bHBEDNHgw8N9/YtCoOQwGoGKRhd1t3GAAULMmefQwAURFDuVm2k1M3jkZh28fFjq+d6/ZHZGBkZL9GpR1bUfuY8cs7+NqqKXFHgsQ2y5GXs8HIL9Do5F0iFf6PciPOWwY8KQDwmoCfcUJrVuNbqgSXgXP1HkGQ5oPwVvZd4HdH5lanoosQNlGYgHy9zcKhRDLh5YHx3m4a/wP4KReYBx9UHOBhcnqxSk1/ZQLoOeYJItgvzzg0HGSgg0Azz6r+D5aLUCWXGDsxMsmNtE7Uj1jgKzNSrH0N7oEqt6uXSNB6bNmAV27unRIlohNjjVZt+HSBmy4tAEf7/oYAOn/FOQbZCKAYkJiTF7rTK5dc+nbmyUnh/w8rbUAWaohlpRkui4vD/j3X6BdO+XXsL/TuDhFr7nulA4qjWujrsFQlK0aWDTXmQigHGIBWp+wAMD3yDIkAb7kQysXwgWQK3C3+B+AW4A8CjULkNzsrGQBkpu4T4gN3hG8awPwyCPiipdfVnwfWifHXgsQCw05AhxjAbK2N5FbCqDoaCCmSBScOwd8/bVrx6OBFWdXWNyHxv7IBVCpwFIOGZNWEhJc+vZm+e034J9/xOdaLUB07njiCeXjPm+adIX8fFJ9QQ06p/j7O0f8UAxMqQ46182bR9zcAinigM4nnseD/KJyDBllgHyFO0SOQ5g/nzxGRgKff+7asSjBBZAHQScxeakeuQVISQCZuwMMOvav+KRjR+WWDNAvCBoQ+36NGSOu01MA0Ym7WFiAABLIRJHM9O6JvH+TElQABfoGIsBHvCiVCnCtAFJyB7kL8hIN1lqAglQSn5QuTnl5QI0apuspX35JHl3Zr5f9+yX1iS48LSyeSjiF2PSL5ElKFfdIcCghfPCBaF10N+sPwAWQR0EnMX9Z8VK5BYi9K6STImuulgf4+qJo4//+B2zcqPr+egVBA8BPP5HYnFGjxHVUAG3aBPz6q+VjqHHwoOkxtWLJyuUyfvxRTKG4eRM4fNi147HAtWTiR/qm+zdY22+t4j5sJ262wJ2rLUDuLIBY97GXl3IldSUXHhUpatXQAwNNt+XnK1uTKD/8YHm8joYVQJIyF0Zv4OirAICz987ivzv7yfpbD3MBxBHgAshDMBrFSUwe4yOPV2YnBVqUkLUAyTVOEiKJ1Wf5crMmE72CoOmYH35YOqGzb/3yy8rl+7Vwi+l5aG3/NLe1AEVFkS+OtsR4+GF9gqUcgNFoFLpt1ytbD8G+ylddVgDRWkCA6y1AbE0cS8UznQ1r/Q0MlD6n7WfY3zcdP/0dmGsHI/+t5OW51rqjBXauM8nbSK4KALiRegMrzi4n62If1zVVn2Med01/p3AB5CGw17pSFq4PrKigHdnZyU0eEJ2BYKB/f4tj0CsIWg3562xt/sqa+dkYIy24rQCiPPWUuLxtm+vGYYbrydcRmxILXy9ftK7YWrXgXKUw5cARV1uA2O9eD3esnrC/bfmNUOfO5JHeOHz3Hdm/b19RyKi5wABTPZ2fb/tv0FmwAsjk3i2DNNM9cPMAbqTeILE/53txC5CTcOf0dwoXQB4COxH98QdQrhxZ/uQT031ZUzAVQOzkxmZ8lME9vIU55hv8FGGpFYYtQdAscgFkzvxuDvYOr3dv615r6W90OU+LsQ146y1iubt5U31/F3A77TYAYuEJ9gtWFUBs88pedXoBACqEVpDEAzkbueBxNyMba/GRu8LpuZuZSX47b75JntOLECDe58i7uCuxaRPw11/2jdfRsCJQTQCduUdayPikVQPyA7gAchLunP5O4QLIQ2BN0S1aALdvE/P2Rx+Z7ssKIHrHRwVUerrYzXna23eQgGhULZVq2awEqXWksND0YmFNELQS8gnM1osPFUB160pamWnC7S1AwcHA8KImjpcvk39ulhWWmUfSg4L9yIepJoDqlaknLA9/eDjm9piLE8NOKO7raK5cIanlSm4gd4IVQPIbDfpbz8xU18Q1a5IboP37tb3fv/+armveXNtrnYEWCxDFL7MqAH2rVXMs447p7xReB8hDoALI29uyhYWNW6B+cTqx08wNACiVnwgvGIHKYiyGOVhx0Lo16Zp+8qRp9patFiD5Ha29AshcBosabi+AALFaI8XNbmkz8siHR4WP3KX109M/oXF0Y0k6c4hfCN54+A3nDZJh717SJsXfHxgyRLrN3QSQPAiahQqgrCx116+fn/3Vmps2BY4cse8YemE2BihNWu/HP6sKMuH+br3igLunv1O4BchDoAJIoUCzCeYEEHvnl3a9qDlg1aqaxkDFQUEByeA6d45Yoij2usBoECfF1olKqVGkVjxCAPXoIX1uba6/g8nIJR8eDX6ODIwUihvO6zkPrzZ9FS3Ka/DBOAmacZiTA3z7rXSbuwkg1gIkL4fBxvdcuKD8entPlZMngWbN7DuGnrDWZtMg6GoI9RPVnn8eiRtwt++0OOLu6e8ULoA8BLUUeCV69iSPVaqIkwL90bPdFNqtLirCQ1PFLKCUQcJOQPYGQdN4JYq9FqBiK4CaNAFmzBCfK3W6dRGxybF4aRUp8kRdYABwdOhR/PDUD3it2WuuGpoqciHB4m5B0GxcnNwCxJ7vSs1N/f1NX2MtjRrZ9rtyFKy7X0ncjW05XlgOyCcinFuAOBQugDwEagHSIoCqVycxAOfOiZNCbi6pbxJXVJ9uc6XBaIrj5AmbWWQGX1/TSYadkO21AMnhAsgMo0eLRZTol+pATt89jStJV1S3D183HM0XNMcLK18Q1rGxP+VCy+G1Zq/Bx8uzvO7uZi1gxyMXM97e4vygVM3aXAYYIP625YVV1fZzB9ggaIPBNHEiKqSMsBxYGA2ACyBH4ynuL4ALII/BGhcYAFSoQAQA6wI7fbpoW3gautxYSGa6K1eAhg01j0OpWBrF3iBoQNpM0ZaJKjub/ElAMRdABoPYssTBlaETMxPRaF4j1JxTE0aFwjhJWUmYf2Q+jt45igM3Dwjr1er/eBKeJIAAUeQoCSBLv4fDh0ky6MKFytuXLiWP8t/3o4+aP64jeeghcVmpblFMqBgIHcQFkFPwFPcXwAWQx2CNC4yFCqDvvxcnqoYpe8nChx8Sc5EV0ErJFFYA2RsEDQA//ywu23LxefJJMabDFgHktpWglahThwihxESS66wzqTmpWHp6KY7cFiNeaYYXy/Zr2xVfr5b95W6Yc4G5swBSstRQAbR6tfo2NRo3BmbPFtvOsbz+OqknBJhagL7/3vxxHYnBAEydSpaVBFCX6l3wbN1n8XiVxxGZ2wQAF0AcES6A3AAt1WatcYGx+Gab1vWvjDhin1RpemoOLRYgewRQmTKiQcqWiw9bG9AeC5Db1gFiCQoST56+fW0vna3C0LVD0W9lP7zy9yvCuqQs07bhm69sVny9p1iAPCkGiP1NyGPmAJKZqYYlAURR+v2yqeNyC5CrXWKslVsubkL9Q/FX37+w85WdCPQNEPbjOA53r/7MwgWQi8nNJWmlL75ISvC3bQvMnCnd5/p1YPBgsqzVBQYAGDAAfh9/YLI6yDcfOH6ctFewEnMCyN4gaAqdUK0VQJJeQLAt3ddjXGCUzz4Tl//+W9dDLzuzDAAQnx4vrJMLoOz8bNXO72wQtDvjSRYg9uKtVLpr/HjTdRStNwRKwsqcALL3924v7HxhrnUHnTtnzAD27HH8uEoinlD9mYULIBezZw9w4gSp7vztt6Qmydix0n2efZaIIMC0/L0i6emkI+jvv8MPprc7gS0aAJWU2xBYwtEWIMA0dV8rcgNIdLT1780KoCvqMb/uw/jxwCOPkGWt1e3sQC6AVp1bhQfZyllobK8vT8XdBJClljgfmN7vCGi1ANWtC/zyC5mLKLTyPGBq8fE0AXTuHNCunePHVRKZOpVcB7y93bf6MwsXQC6GDWRka+qwbrHjx8Vli3dxFy+Suj6tWgGAsgCqUd5knVYcHQQN2G4BundP+lwplsES7N83a5b1r3c6vr7AO++QZSd0iB+1cZSwXGgsxNjNRK2Paz3OZN9akbUcPh498CQLEDse2urCHOzNiFYBBAADBwJt2oj9dydNEre5mwWIvWFiBdC0acr7cRxHmzbknHv+efcPgAa4AHI57OTBWjBKlZJOOhSLFqCffgLu3xee+tYzvQgFNrAu8JlFiwvMXguQXgLIFgsQG+StVzq/w6H9Pi5eBGJjSb8wtk6QDShlewHAqbunhOWPd34suMdaVjCtJdUoupFdY3AWnhgDNHQoCVq2xIAB4rItMXHdugFLlpCQQYo7W4Co1bhCBWCcTJPLBRCPBdIXT3N/AVwAuRz2IhsbKy6npCg3OlUVQJmZRPzIuhf6jzKV4YEhtl/Z5XeRjqgDxNYusgZ54LINIU4wGIDJk8myOXO6W1GjBjElpqUR69/atWT2TzINWNZKQoY0j7pGKbGvSF4BuQovOrFIWNe0XFNhOSIgArGjYxHiJ0sZdFPMCaDOnW1vyusIqAAqr9GIGx4uLltjATKHu1qA8vLEOUApWUQugOz4eXAU8DT3F8AFkMth7zDlFgwK6yZTvIvLzSU5sa+9RppjUh59FEEVI012t6eSq1yAOSIIWl69WitsS6xPPgHKlrXt/Wl6cappAp174u8PVKtmuj4+3nSdRi4kir0UfLx8sOx/y+DvTa4qN1JvABBbXnzU7iPUjKwp7F8upJxHxf9YsvKccE1/VkXob0JL5lXlylKLrV4VnN3VApSbC3ToQJbZcAL5fhTGUM7RAU9zfwFcALkc1sqh9oNkLSqKFqBvvpHeps6dSyJ4165VnPTsuRM0J4D0tgBZK4BopkqHDsBHH9n+/h4ngADlPiV379p8uN2xuwEAT9R8AonvJKJ5+eaCyDl37xyy8rJwP4ucsGNakZYqVCA9Vvkxm9/XFViyNNrbPkJP6FjNCaB//iEFAletkp4WxdUCxAogmgmq1B+YW4Acy6ZN5BqwaZOrR6IdN/ppl0wsCaDcXOkEo3gXxxa/mTIFeOMNUuCwVCnFSc+eYEAtAshVQdB00rP3TtfjBdBjRQLks8+Ajz+2qVv8inMktf2Fhi8gPID4Uaib6+ido7idRm6xA30CEREQQda/fhQfPPYBvuzypY1/hGuwJIAyTWs/uoyUFPLIurbkdO9OEieaNZPGtOklgNzNAkTnpGvXzO8nn/doZi2n5MIFkIuxNPmmp0snGBMLUGysKLnDw0mPKAYlMWBP8J/ct56fTzLWdu8Wzc56ucCsHSe9zmsqFWAGKoDoxcYj+PZbkr+8fr0Y/LRtGwlosrIhT35hPs4nkvYarDWnYVlSofLC/Qu4+uAqAKBKRBUYioJo6petj886fiYIIk/B0nnWu7dzxqEFmiihVKtHCUe4wNzNAkTnJEsuLbkAOnPGMeMpiXhS/y8WLoBcjKXJNyPDggD66y+iQB5/nMyOsh2UJj0bDAICcvdWfj6weTN5e1o3xF4BQic0a8eptwDyKAtQs2akwEmPHqbxQD/8II2wt8CFxAvILchFoE8gqkRUEdZHBpJ4spScFFxKugQAqF26tv1jdzGWgt3tCKXSnQdFJZeUagAp4WgXmMHgeheh1t+7XACdO6f/WEoqntT/i4ULIBejxQLEVjQ22X9FURVeldtUpUmvYkXt45Mjz5jJzwe2y1pBuUqA6CWAqGi0Ryi6lMdkMTgJCSQ7jG20pkJ6bjoe+ZEUVmxXpR28DOIUQV1hKdkpuJxEgu09pdaPOTwm2w/WW4DYucMRLjBXW38A7b93ufVaLemEU3LgAsjFKMW57NkjTizp6VLRce8eSJGFzZtJoZt9+4hZ5tlnFY/PWoAiIoCvviKpvbYiF0Cff06yTVjsFSA0vsFaFxQNgrbX1E8nSk+6MEro1EkM/qCdYQHg1VdJARmFtKf8wnyk5qRiyp4pQsPT/9X/n2SfcP8iAZSTIqTJlw+1vaimq6Exa4mJlvfV0q/PGVhrAWKFkiNcYJ4kgOR/P88C0wdPdX8BgBucviUbJQtQjRpArVrERJuRIe3Dc+/iA7GtO+XDD1XNOuxdX8eOYtFgW5ELoOPHTS8OrhJAelmAbHXBuQ3BwcClS8DJk0DXrqQ+EC0b/MMPpHFqp06Sl3T6tROO3TkmETR1y9SV7EMtQNceXMPF+xcBAGWCyjjwD3Ec6elAvXpA8+ZislxYmLrVMS/P9ZWE8/PFOjdaLUDsfo6wALna/QUo1/xZsMB0nVwAaRG+HMtQ91dkpGe5vwBuAXI5SgIoMFCcrDIzpVkonb22mb5g5EjV47NiQI+CbkpF4+QFCK3tWC/HXQSQx1qAANIHpGtXsjx8uHTb6dM4ducY9t0g5VrTctKwO3Y30nLThOwuwNS9RS1AablpyM4nH3bpwNIO+gMcy+nTwM2bpH8sDY86fhwYQzL60UhWxNodqgaz84BS1QMlHG0Bcgfkv/c2bYAhQ0z3k//9Dx64V5FLjvPhAsjFKE2sQUHi3WZ2tigwZr5/D2/s6SfdefhwaZ16GaxgcdSPXZ5NUVwsQAUFxWSCNBiAiROFp/lnTqHZgmZ4dOGjSMpKwn+3/hO2peWmAQCaxDRBdIi0l4hSdpenWoCUXFpRUcAXX5BQqc2bgd9/F7e5gwBi3eVarVHm0uVtxd0sQPLfu5pAkwsgo9E9agFlZQEvv2xSxN9j6NaNRGF06+bqkViPG5y+JRv5xOrtTSYYehGmPn8AGJI7F75g4je6dAG++87iezQk2ct4/XU7BwtlCxB7oQBcJ4CoULT3Tpcdv0dbgVgmTRKCoG9eOSasvjdiEBZtmmqyOy1uyFI6yNTaQ91inobcvRkYSKwqAQHAK68QA1r//uJ2dzgP6FxhMGgvNspaivQScWazUl2ArQIIIN5hVzN7NgnVc6dyC1rxxP5fLFwAuRj5RBwaSiY4eof32mvk0csLCLp8kjxp3Jiojt9+0/QeO3aQkJAnn7R/vOb6JlFckQVWUEA6V7OvtxXWheexcUByvLxIqjyAa1ePCqtTN65B7KGtJrvTlHcWP28/PF3naeF5qYBSqBah0ILDA5ALGrWYGnouu8ICtHQpKe9EoRYga2KR2N+r1sBpS7AWIHvd3XqgtTCjkgByhyKXN264egS244n9v1jczJtbsjh92rThKU1blU9yISGA4RIJPMWXX5JyrxopU4b804N27Szv44o09Hv3RGtZnz72vb+PD7lwGI3uceevG3XqAAYDrpUS/T/JAUCqwkUsOjjadCWARlGNsObCGgBA7OhY+Pu4wRXQBuQp0Gqi2d+fnIfOFkDnzwP9irzdzz1Hmp9a0weMZelS4NQpUqtLD1i3lzsIIIOBzDl0vrBGAGVkOG5cAHGxBQWZnxNp+w5PpE0bEkvnSf2/WLgFyIVs2WK6jq3bwZKaCmLGAciFzEW0bWta90eOvZMinSzY7DdL0AtUQID2DBk1DIZiEggtx98fye+PxeBnxFVnooA0BYtC/bL1FQ/RtnJbYTnUX+VkdXP27gUGDpSuUxNA9EbE2efBunXiMq39Q89xa7PR+vYlXVG0WG+txVIjWWfBCgx3sQCdPElKhLRta34/T40z9HT3F8AFkEtRMn1SAaR4V5CbS67M8sI7ToZ2XFbDFRYgeoHSK1WZHqe41QpZ01BqPhjTXdkCFOynnGbUrUY3LHpmEY4OPaq43ROYMMF0nVqwsK1tWezl+HFxmV6kbbUAORJ3cCEB0jlHLT6KFUC0TJYjLUBffEGOf+SI+f081QLk6e4vgLvAXEpcHHksW1Y0yVMBpFp4rVYt+9utOxh7BRB9fX4++acl7dbWu2M1aPxR9+7Fq2JsYZ1awAXpOiqAWlZoiXKh5TC65WjV1xsMBrzc5GXHDdAJKLmDzbnAAOcLoMuXxeXZs8m8MGgQee7qekQs1lhpHQkrbrRYgMqWJUkTjhRwtL4UQESOWsacp1qAPN39BXALkEuh6eOsR4sKoILEB6YvAICmTR07KI2YSz6z1wXGTlRarUD0AqV3TIKnFEtLzk7GjP0zJHV8lDh976zJutyiC8bGPn9jVd9VeLyqTsEibgrtFcuiZgGiF62jTjZ4sSnvv/8OzJtHUvMB97IAuUuSACtg1QQQOzfQgHBHWoDYFHtzLlTWAuQuFcctURzcXwAXQC4jLo4EOnp7S1tThF45BhgMKDx0WLL/OvQkCy1bOnGU6gwfTlpMsVSoQEyh9hqoWAuS1jtMvS1AnsYb69/AuM3j0HNJT9V9cgtyMX3/dADAkxdMt4fetbLugIdijQWIWmJoIW1noRRbQ8s8uIMAoiKjtpv0wtUigMLCiKXitdeA6tXJOkcKINZ9bm4eYwWQUmskd6Q4uL8AJwmguXPnomrVqggICEDLli3x33//md1/+fLlqFu3LgICAtCoUSNs2LBBst1oNGLixIkoV64cAgMD0blzZ1yiAcJFJCUloX///ggLC0NERAQGDx6MdHnJYhdy8CB5bNIEiA4U871DT5KW6oWyr6YpjgEDBphGb7qQ0rKyMIsWAVOm2H9cLy9pIUgt6B0D5GmsOEua4h6PPw4AyMnPQVaedNa9lXpLWO5QWZrOF54NeNeuYznCvZgSE+PqEUhREkD0xsAdzvGDB4FnnhF7MbsaLQIIIJa0H34Q6yM50gXGCiBz8xgrgDwl6aJNGyJ+PNn9BThBAC1btgxjx47FpEmTcPToUTz00EPo1q0b7rIOUoZ9+/ahX79+GDx4MI4dO4ZevXqhV69eOH36tLDPV199hW+++Qbz58/HwYMHERwcjG7duiGbOcv69++PM2fOYMuWLVi3bh12796NoUOHOvrP1Qw1qTevnwW/SRMAQwFgKEQoSGUuuQAKrV2eVMtSSxNzAXJLj169hgDRDcYtQNowMCk+9zPvo/Ksyij7dVlcT74urL+TfkdYHvLxOvblaHWzaKFTJ+VGSsUIpXiecuWcPw5zKMWF0AulO1iAmjUDVq8m/dTcAdaFqSVmkM5VjrIAFRRI5y5zAoj9rt3FpWiJTZvIuDdtcvVI7MPhQdAzZszAkCFDMKgogm/+/PlYv349Fi5ciPcUbGezZ89G9+7d8U5R185PP/0UW7Zswbfffov58+fDaDRi1qxZ+PDDD/HMMySf99dff0V0dDRWr16NF154AefOncPGjRtx6NAhtGjRAgAwZ84c9OjRA9OmTUP58q7pYG00GnFw1zncS7yJrWeSgPLRqLluK4zZiUDfJ4DoEyj46QVkXE1AXoM4IJH5da76HRmOLlphA717AytXis/1GiL11yclaTsmDVr28dFnDD/+SEzljRo5vlaIHhhzjUDRRLrj4g7cfUBuMPZe3ouy9coCAM7eOgvkAq0qtoIhzwAwQqD+LUD4M19/nXyx7lDm1wEoGYJbtrT8PTvzPFASabQyure3Z5yTzoSNGywstPz50Ju3tDTHfJbyCtNJSUB0NLBmDbHuPfecuI09H2ndIL0ICgqS3BxxZBgdSE5OjtHb29u4atUqyfqBAwcan376acXXVKpUyThz5kzJuokTJxobN25sNBqNxitXrhgBGI8dOybZp127dsaRI0cajUaj8aeffjJGRERItufl5Rm9vb2Nf/31l+L7ZmdnG1NSUoR/N27cMAIwpqSkaPxrLZOenm4EwP/xf/wf/8f/8X8e+u8xIxCj2/H69eun2zWWkpKSYgQsX78d6gJLTExEQUEBoqOlVWWjo6MRHx+v+Jr4+Hiz+9NHS/tEyVI9fHx8EBkZqfq+U6ZMQXh4uPCvUqVKGv9KDofD4XBKClcBKF9HbeHPP//U7VjWwusAFTFhwgSMHTtWeJ6amqq7CAoKCkJaWhpSUgyo0yQJmf0ehTFUWg2xeehT2PXWH4h+twMywg8J69MmpLmlKXP7duDpovZQ166R+hp60KSJmIGzZQvQurX5/f/4AxgyBOjYkZiZ7eXwYaB9e6BiRZKtZzSK4VdNmwJ79tj/HnpRaCxExNQIFBpJkEi1iGq4lnwNAPB+2/fxfrv3EfJFiLD/+DbjMbH9RNT8pibi08lEtuCpBXix0YvAyJHAwoVkx/h4sWJcMeK110h7iC++IH+uOeLigPr1iUvWmUUxq1eX1pFh8fKyrk9eSWDePKAoagIjR5Lv1hxffw18/DHJKaElPe7cIcHwekyzJ0+SQGHKunXErUldX8eOkZJuANnv5Enp62fMAGjIKj0HASA21jT5RI1BgwZh+fLltv8RqvQCsACAPilrzz//vC7HsQWHCqAyZcrA29sbCbJ86YSEBMSopF3ExMSY3Z8+JiQkoBwTuZiQkIAmTZoI+8iDrPPz85GUlKT6vv7+/vB3cGMbg8GAkJAQhIQAV04HY1vsAvT/5wnJPvcMVxAcHIyQ0CBkMAG9+xL2oWuNrg4dny2wvvcyZaTdp+2BDbD29rZ8XDppBQfrMwYaVJmfT47Hnk4REfr9nXpwP/M+Cn3FVJJrmdeAonMnyysLJ5NOCs8BICYyBsHBwSjwKRDWl40oi+DgYKBrV1EAHT4M9FRPq/dUaDBxaKjl7zGyqCdsTg6JzXDWPYi56sCFhe51/rkD7I1XUJDlz4fqevpZLlpECk2+954+maxKQezsHGI0imNUyvwaOxYYM4YsszFBOTnav3tHWFbmzwc++AAAvsXnn3t2Bhjg4CwwPz8/NG/eHNu2bRPWFRYWYtu2bWitckvfunVryf4AsGXLFmH/atWqISYmRrJPamoqDh48KOzTunVrJCcn4whTg3z79u0oLCxESzepoxMdDbz4SHeMajlKsv52zmUUFBagbi1pqke337s5c3g2odRrx1ZYAaTlblfvLDB5CwS2bYm7parezVAxFYBkfrVZ2EayrnQQuYW8lymWuA7zL8oj7tNH6BqP77/Xd6BugjXnii1FOfXAXXpseQpsFpiWOmTyCt9vvEEep07VZzzywOrsbGlhRHa7pVR8tjYQewxX8MEH4hg8XfwATkiDHzt2LH744Qf88ssvOHfuHIYPH46MjAwhK2zgwIGYwDTnGTVqFDZu3Ijp06fj/PnzmDx5Mg4fPowRI0YAIFaU0aNH47PPPsOaNWtw6tQpDBw4EOXLl0evXr0AAPXq1UP37t0xZMgQ/Pfff9i7dy9GjBiBF154wWUZYGo8UuERyfN8Yy7i0+Ph7+cZNSrZyqV6duhgy8bTZpDmcFQvMDpBsndvblROCoBUyMhZenqpybrIwEiTdSF+RbfEXl6k9DAA/PNPsfS1WHOusAJID8uAVrgAsg6tdYAo8mbHerf0kM8RL78szQxjt5sTQOfOSd15rhZAxQ2HX2X79u2LadOmYeLEiWjSpAmOHz+OjRs3CkHMcXFxuHNHrE/Spk0bLFmyBAsWLMBDDz2EFStWYPXq1WjYsKGwz/jx4/HWW29h6NChePjhh5Geno6NGzcigEnbXbx4MerWrYtOnTqhR48eaNu2LRa4YX2TqGDTuvzdF3fH5iuk7j3bffusQhsDV9OqFZlw9K4HwoopLQLI0RYg9u5fnuLqSjLzMjHlX+uuzKUCSB+AciGiC9nXi7E41qsH1KxJrsJsjQPK3bsOqSDnrNRua84VtubOp586r2+TOQFEY0c4ItYKIPrdO8qaKxdAGRnAxo3S50rLLG+/TWJ/fv5ZXOdqAdStG5mbu7m/Q0ITTjEzjBgxArGxscjJycHBgwclbqidO3di0aJFkv379OmDCxcuICcnB6dPn0aPHj0k2w0GAz755BPEx8cjOzsbW7duRW1ZTfbIyEgsWbKkKOg4BQsXLkSIGwZ0BvmKDt7m5ZoDAE7fFYs+fvvEtygTRGr3N/iuAf6N+9e5A7RAaCgRKPIgPnux1gLkKAGUl0esXGxdFneyAA1cNRAbL5OZtU2lNor7VAqrhJ61emJUy1EY1XIUWlciruIDrx1Ai/It8FrT19Akpon0Ra+9Rh5pPBDl4kVkV6oFY+//6flnYO1achGbNUvXwypiT984a8TvypXA3LnWvwdgXgCtXWvbMYsztlqAHNXkVknUsFYmOocUFqq7VqdPN12nZS50FMWl/xeLZ/hZijEtyrdAs3LN8GKjF9E4urHJ9pqRNUX3BJRdGq4mOFjbpGMN1gogOrnoVbuPCiCjkfzg2TtFd7IArTwnWmjaV2kv2Va9FGl49GXnL7HuxXWY1X0WZnWfBS8D+XArh1fGoSGH8MPTP5hmGPbuTR4PHABuFzVYLShAfJ12CM+9iz4bXyXxQuaida1gwgRyKBr46UjsEcvWXID+9z9gxAjggkLfNUtQS9NPP5luc7e2He4AK4C0WOnMWYCuXbN/PEo3Saw3mQoka11vjhJsWigu/b9YuAByMX7efjgy9AgWP7cYrzV7zWR7sF8wEjPFluSBPjpGGrsxrEHPGgGkVyA2e3HMzZVOPNnZ7hmjUbu01Aq65oU1SHwnEf0a9bP+YDVqkLS+/HzS5fbIEeD6dSzCK8iFP1bif6QR1Pnzuoy9alVxWSdNpYq1Aoi14qRo7BfLXlitdVsUFoqxdc88Y5qI5+BkVY+EbXBrSwwQy2um07DVUIHz4oviOvY8OFsUzWCtJ9mVAqi49P9i4QLIjWhTqQ0ODzmMu2/fxcKnF2L/4P0AgPRc8XYiPkO/AlTuDOsKcQcBJJ8o3SETjD0vAKBLjS6Y13Oe8DzEL0TI+LIagwFIFIU31q4FLlxAAWSR7mfO2HZ8GWyLO0d3xLY2YP6NN8Q0a60WIPZu38vKWZYV1z4+phf0ktrvzhw+PsDu3URwaOkXbc4FplZ/yRqoBahiRbGGGevq+ucfUleqQwfrjutKAbRvX/FyfwFcALkdzcs3R9ngshjUdBBaVWwFABjcdLCwnRauK+5ERZHihoC2u269BRAb/JqTYzrxuIMFKCFdrJf1VeevUD60PHrUEuPlAnzs9AeyOcE//QR89pmpADp82L73KIL9PB09ydviAqtOvImaLUCsALJWLLOfhbe3VAD5+lovqEoKjz0GLF6srbGtORcYrbtDrTOJiSQAPi5O+1ioAAoJUbbYZWQAkydbf//gKgE0fz5x/UdGFh/3F8AFkEcwo9sM9GtI3Bjmar4UN2htD2ssQHo1EjQYxCJ4GzaYTpSuFED5hflYc2EN1l4k0bDVIqrhnUdJGdyKYRXRJKYJ6petLwTP28yYMcDy5eQKfPMmsH+/qQD65x/73qMI1urjjgIoIoI82mIBstbNIbcAsRmR3P2lD+ZcYEFBwH//kdjG0aOBF14AJk6UNjC1BHWBBQcrxyVmZ0sNrGwInrnz0lUCqLjV/6FwAeQBhPmHYVzrcQBI1d+SgjUXHb0tQADw1FPk8exZ04nH0W4ac/x24jc8s/QZjNlEIoajQ8S+eF4GLxwZegQnhp2At5edhZn8/EgkL9MihhVAozETOHVKl9QUNhjUWQLIGjFBXXRaMwDtEUBsEK/cBcYFkD7Iy1ywBAXRasfA7NkArbnL1NW1CGsBUhNArJW5cmVxOdK0TJeAK11gxREugDwEWrwuKavkVMJytQCqUkU8tjtZgNZclDY7kxc29DJ4wcdLx7S8F14QFgteE2//ZmM0rqIa0L273YqQTRt2RwsQvYhpzdrRywIkd4FxAaQPrAWILeYKqFuRrcl0peezmgDKyZEeLzKSVJu/dcv8HOYqAVTc6v9QuADyEGgwa1Z+FrLydC5b6qZQAZSSYjkzyBECiI0FUIsBys0ljVGdOTHVjpRme7FlEhzCQw8Bw4cDH32EglJSt1ohvICDB4GtW8kHtX+/6RVFA6wAcrcgaEA8r7S2w9BDAHl7E9cIF0D6Q7/PrCzT8y0gQLnnm5bYIgotlREcrO07CwoiAdPly5sXWq5IviiO9X8oXAB5CKF+ocJdfUmxAlEBVFho2fXgCAFEj5WZaTrx0Elz1CigXTsSK+AssvKlApgtpukQvLxIy+xPPjGpsWKsWSTGvv+edJNs0wZYssTqtyhuFiBWKNkjgNhHgAsgvWAFkPy3rdZs1Jo6ujRYPiJC6upiYW/q2O9VbX/ANRag4lj/h8IFkIdgMBgEV8f9rJIRBxQQIF6kLLnBXGUBmj+fPM6bB6dx6PYhyXMDnNSiHKYCIGf6t+R2+e+/yW0iYNMsyYoER0zy168Dd+4Q4xQVr7YIIK0WIFYo2iqAqCWAW4D0h/62jUbldndKFiBroPNVRIS0cOq4ceIyK4BYo6k5C5ArBFBxrP9D4QLIg6AC6GTCSTy37DnsvL7TtQNyMAaD9jggRwsgV8cAZeRm4NCtQzgRfwIHbh6QbDPCepeTrSQkSJ/nVKguiRECQDLGGja06kNiJ3a9J/kHD4Bq1Yh7gRUw1ogJ1mKgBfZPt7Z1ChVP9EJ4/bq4Ta8sx5IOO0/ISxvo4YJlBdDly+L6adPE8449F1nB7E4WoOLs/gK4APIoSgeSOKABqwZg1flV6PBLByRnJ7t2UA6GjQNS48oV8SKhZr62BXMWIGdngb279V088uMjaPJ9E+e+MUNuLrB+vXRdTg7ElBmWM2eA7dutOrbSsh6cOiUuswLCkS4w9oI2ZQpJo9aK3AKkli7NsR1fX9G1+OCBdFtenm2f87lz5AYhO1sUNxERphYdei6xlie2+KI7WYCKs/sL4ALIo1Cq6tvguwYuGInz0GIBevhhcblCBf3e250sQN8d+k51m9GGoGNbyMgw/RxycwE0aAC8/jrQooV04wGppcocrKDUW1yqVfZ1ZBC0PFbq00+1v5dcALGiy9FtQkoKBoP4ncrnFlvOv+XLSef25s1JhWf6HqGhJHyuRQtg82ayngog9qauATONu5MFqDi7vwAugDyKMP8wk3W3024jt6D4FoeggYfm3AjsHRwVTHqgJQbIGRiNRrNuLme5wJQuDIIgmj8fOHQI6CFWoka89qrljrQAxcaarvPykgYXW0KLBejBA6BmTWD8ePvOD3kQNNtDyklat0RAf99KAshaC9DBg+Tx1i2xmWp4ODnPmjUjP40uXch66gJjBdC0aeKy3ALUtKm47GwBtGkTEfObNjn3fZ0FF0AexMX7FxXXxyYrzPDFBHqXrvWuTE8XAZ0gL10yvfN3pgtMKeg9430xbcpZFiCli7pJWi7bOVQeMKSC0ejYVhj3FXIGrO2npcUC9NNPxB379dfaOpKrQc8teiHkAsgx6GkBundPXL5xgzyquePlLrDZs6XNgFkLUN26pMIE7SfGCyHqCxdAHkSrCq0U1x+9c9TJI3Ee5nr2UOidcqdO+r53qVLi8pYt0m3OtADdSbtjsi7INwjVS5EGVc83eN7mY+/ZAyxbpm1fsxYgkGyX7VeromDZCrJCowVIfly9J3kl0aLVlUWhF62NG9XHx1qU5ALImtg0eZba77+L27gLTD/oDc6tW9L1ahYgczdXd5ifKD3t1YLs5S4w+X6sAPL1JUUSP/qIPHdmHSCa4RoZCXz+ufPe15lwAeRBTGo/SXH9x7s+Rn6h+hV53419+P3k7ygotOO21EWYK1lPKVNUm2/6dH3fu2pV0glCCbkA0sPylFuQi9fXvo6/zv0lWZ+QoWxJOTL0CPYP3o+etXra/J7t2pEkrovKxkUJ9MIcGgp07lw0ZuZ7eeIJIkJn7GxGVsTGajJZyL9bZwgga2EvPDt3Ku/DBljbI5Dp50wvhL17i9u4BUg/6Pc1ZYp0fV6e8udsNJJzs1cvYOZM6TZW69NlNSsjXU/d+nIBJG9+y77GmRagqVOJ9TE0tHjG/wBcAHkUkYGRGNbc9Ew8l3gOx+OPq77uuWXPYcCqAZh9cLYDR+cYtPzw2bLzemIwAAMHKm/bulX63Jp4EjV+O/EbFhxdgN5/9pasZ7u+s0QERKBVxVYw2Ki+WGuCFmMNe2FWaia5dy95/GlrZXJ1uXWLdJW0gCcIoLp1xWW175pNraZuFdpAMzNTu3gxV6iRCyD9UCuZoSZeCwqAlStJySumPR4WLJBmGlLPr5oAkgc5W7IAsY/OdL3TAOg2bZz3ns6GCyAPgw2ETpuQho7VOgIATiWcUtw/NSdVsCDsv7nf8QPUGUsCqLDQcQKIfX8506YBv/0mPtdDACVmJiqup99fREAEDDBgdnd9hOzixeKyUr8iOWx2krlu2vDyJuYgQOwWbzQKt7x37khFiXxS13uS10MAPfKIuHz3LnDsmOk+7DlA43bCin6uRqP2FHq5BYjloYe0HYNjGbVsKzUXWF6e8jz0+uvS57/8Qh61CqAwWW6LkgWIrnOW67241/+hcAHkYfRr1A8AUK9MPYT4haBRVCMAwKm7UgH0b9y/iEuJQ1xKnLBuxdkV6P9Xf+QVuLCVuZVYEkBZWeJdsTMEEDsxsuVv9BBAbE+v7PxsXH1wFWM2jhFivAY2Hoi0CWl465G37H8zSK1bWho9shdmi7FZbduSx48/JjPoo48CpUrh+tIDKF9eKijk363ecQ5UALHfUb161h+HBqK++CLJ7NmxQ7qdFW408Jp2kQek7T7MoSSADh0CRoyQZgtx7EP+my1bljyqCfD8fFM3pzmLnFYBJO/+7g4WoOJe/4fCBZCH0SSmCU4NP4W9rxJ/Q+PoxgCAmQdmovefvZFfmI9jd47hsZ8fQ7XZ1UwyxJacWoJNVzwnp9GSAGLT4/WsAk2Rm6fZYNabN8VlPQSQr7c4893NuIt2P7fDrIOzsPgUMdWUCiyFYL9gm11e5tByZ6nkAlP6Xq5fBy63f01c8eijOLM/BZPzP8CyfqsASF0G8kldqTWBPVDLC3uhmaQcTmcW+bnw7bfk0WgE1q0T058BMSvIz0+0rmkVQEousBYtgDlzTC+WHNuRi36aiWXOAsQKoJQUaYYedXdStAogNtlCPi65AHKWBai41/+haLjv47gbDaMaCstUAAHAX+f+wvqL63E5idReLzQWIj7dNLgjPdfK2vwuxJIAoheV4GBSc8NR708JDhZFF3v3FxVl/3tl5olNoxLSE3ArTZqeUiqglPwlumGrAFJyL+XkALWahiCtVWeEHCDBUg1xBgAQDfF8TE0l5n/5d5ukc69fOsbISKkwsRa5ALpwgTyuXAn06SPdRuvCeHsTy2R2tuV2LhRzLjCOfshvWipXJpY2NSuLPDg6OVkUSsHBxKr5F5O/oHaOyYWXOQsQFc/0Nc6wAJUU9xfALUAez0PRDyEqWLz69lrWC98dFqsGX0q6ZPIataBad0SrBcgR7i/2/Slq6cxaXEhK/Hj0R3y0/SMYjUZk5IomgpkHZprsGxEQYdubaEDLxMrGANHP25xVI+6TRSbrEhAjLN/4l1gnnSmAKHoIICqmtm1Tf423t+hu+/prbe9jS7NWjvU0aSJ9TouomrMAsb+TlBTx3AoIMBVUtlqA2POMLjvTBVZS3F8AF0Aej6+3rxAHRLn64KqwfDLhpMlrbqfddvi49MJTBJAtE1OhsRBD1g7BZ3s+w94be5GWK7aN/uP0Hyb7u1oAsZYJGrhpzl2VHl4B6NtXdfvtDccV33vhQtO6S/bgKAFErQHm3J8+PmKHkCVLtFna6LnOLUCOZcIE6XP62zZnAWK3JSdLBZC89pMWARQcbPo9s242es45Mwi6pLi/AC6AigXBfupV1v65TLJwBjQegHcffRcA8N9ty6nJ7oK5WBNAFEB6NkFl0SqAbEndvpshNqn65uA3+HLvl2b39zQBlJEBYOBAFEI5Zin9BHHVKn12XbtaHo9W6EWKvdPWUwCZs/55ewPvvis+11Icm7vAnENwMPDkk+JzKjzy8pQreVuyANEgaooWAaQU08UKIOoCc6YFqLi3v2DhAqgY0Lpia4v7PFLhEQxvMRwAsPP6Tkl2mDujNQbIWRYgdnJisWViupFyQ1hefna5xf1LBeobA8T2TbM2BkiLAEpNBdCxI5KgHLmbcYYUSnR0cTdHWYAo5ixA3t5AdDRQqRJ5Lq86rAR3gTkPtvwD/W3n5yv/nvPzpfOQ3AL00kvS/W0VQGwyhytcYCUJLoCKAWNajUHbyiTtuHxoecV9SgeWRpWIKni8yuMAgCqzqii2WHA33M0FxqY1s9hyEb+ResPyTgy1S9e2/k3MEB0tLlsbA6RFACUnAwgIQPYPvytuz3iQA+zbJ7x3NYiuW7XP2RaUssCc6QIDgAoVyCObOagGd4E5D1YAsS4wtRsCNuYtJUU8twICyDn1zDPidjXBzH6v8vgfQNkCRM+jwkLHtkMpCe0vWLgAKgb4+/hjz6A9ME4yYnTL0Yr7PFr5UQDAe23FqLZt18xEb7oJ7iaA9HSBqVnh3mjxhsm6Ma3GIMBHQ7VCK2DN/I5wgdGsp/zO3RW3pyMEWLdO+OwikYRJmAxA2gHbXuhFi3VRONMFBohi8+5d9X0p3AXmPNQEkNrvgT3ff/xRDNinVhv2vNJiAbIkgOQWIMCxcUAlof0FCxdAxYyXm7yM15q+hpndxCyiV5u8isrhlQEA3Wt2R8sKLQGoV492J9iCexs3mroQ2DR4R74/Rc0FlpGhqeuDBNYFRpnWZRrm9pyLyY9PBgA0L9ccp4efxoxuM6w7uAbYO0lHuMCoOFXrjJ6BYODWLeRkkh38kIvmOAJAn+rNAPkb6ThYi5cjLUDsc7pM388aocldYI6H/U7ZGCC138Ply+LymTNiFWh5rA6gjwBSOq4j3WAlof0FC68DVMyICo7CD0//AADoU78P8grzUDWiqmSf15q9hoO3DmL2wdloWbElnq37rEOK6+kBnUTWryf/fH2l1hZnW4DMXZRatrSuV1NcKrEATe86HWfunsHV5KsY8cgIAKTx7dDmQxEdEg0vg2PuU+yxAFEXFRVASn83Pb5ZAbR/P7IfigdQAYHIQgCI8tFLALEuC7ZWky3iQt4uhP7N8otlaKho/aLWIWsK2XEXmPOw1gL088/S56wLDJDWIrtyRfkY1gRBy7PAAMdZgEpS/R8KtwAVYyqEVTARPwDQr2E/RAREIKcgB73/7I3dsbudPziNyKs7yycmeoFTs8zYC3s3X7u27fV+lLiefB0AUL1Udfz0zE/Y8fIO+PuIt6TlQss5TPwAUguQtTFA9PPOzJRuU9pfTQClIwS4fBlZb38EADIBpE/Xz7SiygJeXkDp0uJ6vSxAy5YRtwELK8bp+WNNECt3gTkPawWQpeOwrztyRHlfdg5xJwtQSar/Q+ECqAQS7BeMHrV6CM9pryl3pHJl89vpRdZRFwvWMDZ0qH4CyGg04kIiKSVcp3QdfQ5qJfZYgOjFgnY5V+rfpckCBCALROUGVigtCqD4ZMsD0gAVQKGh0guLXgLohRdM92PdsfYIIO4Cczzsd8WeH9b2o1MSQPI6QxRLLjClLDAvL3EucpQAKkn1fyhcAJVQpnedLizfy7znwpGYp1o189vpxVVPy4wakZH6vc/JhJNIy02Dj5cPakbW1OegVsIKE2tjgNiLRXa28gXDkgUoo8fzwMMPIxvk6hHQ8iEEdCTB+tmZ+lqAQkOlFx5bvke5AFKLf2JFOxVA1hSy4y4w58Gex+wydW1pRUkADR+uvK8lAcSKMtZC5eh+YCWp/g+FC6ASSkxIDKZ2Irb7m6kacnNdBOu2UIJOBno0I7VEy5amF6XwcOuPczfjLl5aRYqGtKvSTtIE1ZlY6wKj+/j4SO9SMzOVxYBFF1hhELBvH7LakqqHgaWDEDB+JADgbn4kcPGi5UFZgI4rNFT6XcnjebSgltYsp21bcVkeA8RdYO4FKzbYZT0EkFpvQksCqCZzP3T/vunreC0g/eACqARDawbdSXffekBqPXkozrAAnTsH7NgB1K8vfZ/q1UlgttrY1Gg8rzFO3z0NABjbaqyOI7UOa11gbMadj4/oopk0SbnRJz2+2ZoqPj7IatcNABFVAXWqCNv/nHDM8qAswFqAQkKAzZvJP3lsmRa0CqA6jEfTFheYUjd4jmNQc4HR2DatlClDHrV8v5aCoNkYMlYgObIdRkmr/0PhWWAlmPAAckuclpNmYU/3IiNDrGLsDAtQ3brkHyAVQEpZHunpynd1lPzCfCRkiP0QWleyXMXbUVjrAqNigqbABwaSi/XcuUCNGqb7W3SBFQkqNpMmIFj8Isdv6ojnLQ/LLPQ7qliRPHbpYvuxtAigf/+V1vqRu8C4Bci9YIUwu2ytBYiWWLBWAKnNFQcOEFfUiy+avs4RFqAPPiD1fyIjS078D8AtQCWaMH9yJUvNMVPMxQ1hU5vZzCRnULWq6bqDB8VlmpavxsX7UrdOqQB921tYg7UuMNadJH/NtWum+1t0gRV9VjTlPTBQ6pqq7K2hb4QFThb1An7oIbsPZVEAVa4MPPqo1HJjSxo8F0DOQ163x9bPPCaGPGr5HV0VC54LwlxOy5bAxInSc04ugAoKgPh468fKEeECqAQT6keuZJ4mgFiRQS+uzogBAoB+/YBx44C//xbXPfKIGKuUZsGYFpscK3nuyvpL1rrA5BYg1k2g1OLBWgtQYKDUDVE356R1hZUUoMKstg5dRCwJoIEDySMrgLgLzA34+GPyw1Uo1y6v3GyrALKm0CVtiwJYF4smd4H17QuUKwfstrOKSUl1fwFcAJVoPNUCxAogZ1uAvL2BadOAp5+Wrqd+e0sWoPh0192ynTgBLF0qPmctQFpaecgtQCxK7kCLafAKLjBvb6BbF/KC0Jx7QGys8os1Qr8PW4LV5VgSQPQctFcAcQuQjuzdC0yeTE78zZtNNmu1AI0fL33O3gABYuVkLd/v8OFEk50+bXlfFvk5tHIlefzqK+uOI4e6v4CS5f4CuAAq0VABlJabhkKjAzvs6QzrAnNmGrw5qCiwJIDY+J+naj/lwBGZ0qQJuRGmVV5ZYXL3LhFEU6cCe/Yov15uAXr4YXHb9eum+2u1ALEuMAB4uCVRDbnwA47ZFwitZ6VwrQJIqR2CNQGsXADpyKFD4vLChSab2c/Y21v9M586VVpJvIdYRg39+4sxibVqWR5SSAhxbzVoYHlfFjrHULFCsTZeiSPCBVAJhgogAMjIzTCzp2tZtUr6XMkC5CwXmBr0AmvJBZaQTgRQh6odsPR/S83vrCOsJ4m6hVhhEhcHLF9Oire1a6d8DLkFiL3ztCUNPi+PWJ5YFxjANMCFn3JwkRXoKYAsuSvoxTNM/FmhdWvpNu4CcyJGIzBmjPh81SpgzhzJLjQ2jGZxyW+kQkOBrVtJNuqqVcSS+P330v3Y5TlzgAEDiOFJb2iigdzaam3Gmpxu3Yosr93sO44nwgVQCSbAJwA+XuTXu+zMMhePRp1evYiVgF6Y3dECpNUFdjeTpAj1rNUTQb4O6t+hQGKiuExTb1kXWGysshWHhU0pB4D27YHZs9X3t+QCA8jnJRdAgliAL/DFF+YHZQE6Zr0tQPXqAV9+Kd1Oz8GGDYFvviEeF/pZcxeYCzh7VlwuW5Y8fvaZ5MQPCQFSUoAbRX2J5Z95bCzQqRNZbtOGWF+GDpXuw7qEY2KAX391TDNRal26dEm63l4BtG9fyer/xcIFUAnGYDCgViT5Vf164ldcun/JwiuUWXJqCXZc26Hn0Ezw9xdrdlCRsXixeKflagsQnQQtWYBovFVEQIRjByTjFpNQZTRKxQ8AJCRIRYJS7DGdaNlAZXpxUIJagJTcPuXKkcdTp0wbSkosQPfvAxs3qr+JGYxGx7nAfH1NawmxF8+33pKm3PM0eBdw6pS4HBdHfqR375IgPkZFhIUp99wCTG+s2OKGixaRrL8PP9R32GqUJ2XbcE9WuN8eF9j8+WTOiowsOf2/WLgAKuEseGoBAGBP3B7U/rY2Ju6YaNXrzyeeR/+/+qPjrx1htDNjxxL0IpaRQVLPX3pJ9Id7ggVow6UNWHdxHQCp+9EZJIihR8jJMbXK5OVJi8KlpJgeg060rACqX19MAZZjzgX22GPk8cAB0xggIaMGRVejJ56waZbPyhKFnFLgtrXIBZA8gU/pM2P3B6xrhcFdYHZy/Dh5HDKEKJxHHiHP332XpAXS7QxyAWROhL78Mqn7RGsAORp6/slFtNwiZA0lOQAa4AKoxFOvTD3J8093f4rEzESVvU05Hn9cWE7JMXMF0AHWAnT+vHSbqy1AWgRQzyU9heVQfx2uyFZARQZALrByC1B+PjBokPg8UeEUoBYg1vJhMJC7YAobBGpOANGA0rQ0dRdY7iOPiS/47jvTg1iA/S6CdPA2sgLI29u0JpS5cCWtLrCsLDHGg1uA7CAnh/iiALE3ibwY1HPPmaQ/ynsPuvrGikWwjCpkbMoDozna4AKohFMq0LQQ3+5Y9cISGbkZ6PxrZ/T+sze6/NYF/Vb2E7bdSrW/cJ05WAuQ/ALu6olKqwuM4kwL0MWLwLPPis+VLEBy2B5EABEzVNDIXT9sjZ0PPxRjJNRigL76SmoRUXWBlSkPzJtHnixYYH7ACtDvIihIvS+TNbACyGgEuncXa6gA0voucrS6wDp2FIUbF0B2sGwZcOcO+RB7Ft14tGgh3efaNZOUx88+k+7i6hsrFvq7UGo8fMuGqbck1/+hcAFUwvEymJ4Cu67vUt1/4bGF2HZtG/469xe2Xt0q2XYs/hgeZD3QfYwUagH67DPgwgXpNncRQGodwuXuQWcKoF69pDE9ShYgOWygOSD1QMmtKazVJzSUBEcDphagzp2JFemdd6QWEbkLTGIt6duXmJkuXiT+MiugLik9agABUgFUWEjOuddfB/77D3j7bfJPDa0uMPZP5C4wO6ATxEsviVVKO3QQt1MfrKw2UPPmUkHvwjqlJsgtQKzrzZb2GCXd/QVwAcRRYO8N9RxOc+6xAasGoMF3DRwWC8QGssozcFx9p0brgCg1BQWAnALpbZszBdC5c7KxaLAAyc3sbKaJPB2cFUBBQaY1b9hq3VTksPuopsHngjRLeuYZsqJ7d2kfAQtQAUS/G3thBQkrIB9+GPj6a2n6uxxb+jhxC5AFjEZg2zblH92ZM+SRLbYTEwP8/juJXh4wgKw7etTkpSNG6D5SXZALIPZc0lLIlGMKF0AcvNHiDQDAN92/AQDcTrutum9+oflb2Dvpd5Ccnazb2FjMxXG42gJEmxo+UDGAyRvOOjsImkWLAJKb2Vk3lfyuuGZNcdnLS/wu5C4wVqQKcT656i4wYYL/7Tdya56SQlwbGqHXRb0EEPt3W7KgyeECyAFs3UrMik2bSk/oefPEUs3Vq0tf078/iV5u1ow8P3rUJOXR1XOJGnIBxFoTrRVA3P1FcKgASkpKQv/+/REWFoaIiAgMHjwY6RYKpWRnZ+PNN99E6dKlERISgt69eyOBTWEBEBcXh549eyIoKAhRUVF45513kM+cDTt37oTBYDD5F887xykyp8ccxI+Lx3P1ngMA3Mu8p1oZWqltBn0d5W7GXZN99EAee8LiaguQRQGU6z4CSIsLTE0AKYlQ1hRfUCB+F/I0ePY7oheZxYvFdSZB0HRSDwkB/vc/sky7m2pAbxcYi60CyJpkNu4CUyA/n7RIb9EC+OMPsu76dSKQqd/2jTfE/eUCiNKwITkJk5JIinxmJhFO8fEun0vUMGcBstYFxt1fBIcKoP79++PMmTPYsmUL1q1bh927d2OovIqUjDFjxmDt2rVYvnw5du3ahdu3b+O558QLbEFBAXr27Inc3Fzs27cPv/zyCxYtWoSJE03Tty9cuIA7d+4I/6LYWuYcAS+DF6JDolE2mBQLyy/MV43luZF6Q/L8+ye/R5fqXSTr2HYPemJOALn6ro0KoKNHlRuDshag5+o9pxh7pTfbtwONG5uul1uAlC60agJI6TswGICZM0m4TufO6i4w9juigoANGld0gVHoH0JdGxrQ2wLEYq2Xl1bxvXxZ+8WKW4AU2LiRCJ8jR4CffxbXnzgBTJpkur88rYvi709EEECsQc8/T4RTixbw8XLPtkDcBaY/DpuFz507h40bN+LHH39Ey5Yt0bZtW8yZMwdLly7F7dvKLpaUlBT89NNPmDFjBjp27IjmzZvj559/xr59+3CgKDpw8+bNOHv2LH7//Xc0adIETzzxBD799FPMnTsXubKzICoqCjExMcI/Lz1SQYoxft5+KBVAruRf7f0KB28eNNnnwn1p9HG9MvXwaKVHYYDoH3CUBchcKwJX37WVYpLpKlcG7mfex//+/B+2XNkCAEjPFS2fS3s7pwVGp07SWnCU3FypW0rpQisXQEop8CyjR5N+kz4+2lxgSoKVjkPRXUQvZFY0R9U7BojFWgtQtWokRignxzSAnyIXVVwAKbBLPUEDR4+aFmMyF5j1VFEvvqQkYP16snzrFnobV6BsWZIl706wAsholLrArLUAleT2FywOUwT79+9HREQEWjCph507d4aXlxcOHjS9sALAkSNHkJeXh86dOwvr6tati8qVK2P//v3CcRs1aoRoxu7erVs3pKam4ozs7rBJkyYoV64cunTpgr0WmrPk5OQgNTVV8q8kUiaINMX5at9XaPVTK9zPFPOh8wrycDnpsmT/EL8QNIpuhJtjbwrNPe9lyEqV6oTpxdcIVNsOBCa5jQUIIJPTNwe/wcpzK9H1964oKCwQXGBNYprA19uxV7YZM4ApU9S35+SIF3A1ATRkiPQiHxdHHrXU05G7wMzFALHQGBtFC1DlyuQxNVU90lwGvRaauwbairUCyMtL/BPu3CGPCxeSlgl3i+4X5HFZrhb1bsnOneJy2bKk5Dbl/n3So4Sydq35Y40apbg6fMdq3LoFrFhh+zAdAfu7kJ8r1liA5s8H/vyz5La/YHGYAIqPjzdxOfn4+CAyMlI1Fic+Ph5+fn6IkN2yRUdHC6+Jj4+XiB+6nW4DgHLlymH+/PlYuXIlVq5ciUqVKqF9+/Y4qhDxT5kyZQrCw8OFf5UqVbLq7y0uyAv07bgutri4lnzNJAg6xI+kZpUPLS+0d8jMs7M5jQomAqjeKuDlTsDLHd1KAAHSmJ/3t72PlGxyNQ71c2wBxIwMYNw44P331fdhJ1AvL3VLAxU9ALBmDXnU0uPIXBaYfB/K3LnismK9k+BgsWOlxgap9K7YEbE01gogwDRObPBgYP9+0XMjt7o5Qrh5HAUFwI4dxLyYkiJmbd24QZTjN9+I58PJk6K6rFYNePJJ88cuXZoEw8jJy1Os9O1qWAEkt/hYI4CmThVj9Upi+wsWqwXQe++9pxhgzP47Ly/T62Tq1KmD119/Hc2bN0ebNm2wcOFCtGnTBjNnzlR9zYQJE5CSkiL8u3Hjhuq+xZlg32DJc9bicz6RfK+RgZHCOrahJ112mgBqWORKijkBeFlIa3Iw8ovsg2wxhuqrfV/hhZUvAHB8BWgtQbZsDJCaBUh+LGoQbdnS8vG1CCD2PUNDpXGrtN6TSZNHGgdUZA22BH1/R4hjWwSQWqkEWnSSFUBnz0pbk5RYRo8m1SH79SMp7YWFJLC5YkVxnypVTO9AfvpJ2/FpNhgATJhAHq1wszoTWocqJ8dUAFnjAmvThvwWn3++ZAdAAzYIoHHjxuHcuXNm/1WvXh0xMTG4e1caC5Kfn4+kpCTEqDQPiomJQW5uLpJlM0RCQoLwmpiYGJOsMPpc7bgA8Mgjj+Dy5cuq2/39/REWFib5VxKhFh3KlaQrwvKFRBK80LVGV7zU+CU8WftJlA8tL2wP9CEKxaECyCsf6PUK0PQnIEf8ju7lutekpVZKQP756g3b8kIN+QSqJoDYn6E1/anYGKB//xWbcqtZgOQ/W9W2IrTz6lJtMVRUADkilkYPCxB7rKwsUQB5eUk9OSWW/fuBb78Vn9Nyx12kSRcwGKTBLAMHSosemqN5c3GZlkuPjSXR+aNGaRbbzkCLBSg7mxRNV0rEALj7S47V90Zly5ZF2bJlLe7XunVrJCcn48iRI2hedJJt374dhYWFaKlyG9m8eXP4+vpi27Zt6N27NwCSyRUXF4fWrVsLx/38889x9+5dwcW2ZcsWhIWFoX79+qrjOX78OMrRFtQcVYL9ZBagB0Q0ztg/Ax9sJ+biuqXrYlJ704wLp1iAamwGmvxC/t1pKmzLNmrsQeEkEtKVM+FCfB0rgJQsQCdPSrPBsrLEm92MDLEzuxxWAFnToZxO1NeuiQV3AXULkDxNnVbVzs0l/wTRNXAgcVns2UNMUhZuUhxpAbJUR0kJNQvQhQsktopW0GYrTpdoZswwXdekCaCQ8YuXXhKFMWsdskSVKqSeUGiomKp39y7xIX/7LXGxObjJs1ZYASSvKE5/nytXkurkANlHHkfG3V9SHBYDVK9ePXTv3h1DhgzBf//9h71792LEiBF44YUXUL48sRrcunULdevWxX///QcACA8Px+DBgzF27Fjs2LEDR44cwaBBg9C6dWu0atUKANC1a1fUr18fAwYMwIkTJ7Bp0yZ8+OGHePPNN+FfNHPMmjULf//9Ny5fvozTp09j9OjR2L59O958801H/bnFBrmF4nLSZZxMOIlxm8chr5D8yuqUqaP4WqcIIANz5Sl3TFjMNWaYvsCFsC6wHrV6CMu5hfrlq6alAceOSednuduofHmgUSPpusxMYNUq8bmaqGEtFdYIILXSK6wQYZcjI6X7sa4fiRWoYkXRPCSzLitBx+wuLjA1C9Dp0+SRxvdyAQSS5k6jkGfNIqbEQ4fICV++vOn+XbuKJ3qTJta9V6dOpFN8qVLi+bV8ua0jdxhUABUWmlp6qQWItfwoFU3n7i8pDs0LX7x4MerWrYtOnTqhR48eaNu2LRYwTQ3z8vJw4cIFZDKz9syZM/Hkk0+id+/eaNeuHWJiYvDXX38J2729vbFu3Tp4e3ujdevWeOmllzBw4EB88sknwj65ubkYN24cGjVqhMcffxwnTpzA1q1b0Yma0Dmq5ORLIzFvpt7EQ/OlXZQbRcmuqEUIAijfMQLIxweAn3IhzexC8wU2nQ2thn3+zfNY12+dsJ5Nh7eXZs3Iv3/+EdfJBZCSkUTe50tvARQermxVUrMA0WxkdhsVASZuMGp9vmc509CRFiBbjAJUACUnmxdQXABB2qOrb1/g0UdNm5my+PqSpmwnTohFM63FYBCtQGyYxerVxO3mYp8R636W/4Zv3yYlstjfmFKuwL593P3F4tDcmcjISCxZskR1e9WqVU36RgUEBGDu3LmYy6aFyKhSpQo2bNigun38+PEYP3689QPmSGr41IqshUtJlyTbvQ3eqF9W2dVIBVBWnhXlbq0gJASqAijHDSxAX34JvPsuAEOBUDE7MjASBiadRE8BREPali4FehQZmeQuMCUBlJkJtG1LbqoBdVGzaRNJ3f71V/F6oDWjqmpVMSGHoiaA6igYFENCSExMmtyzGRVFZnYNFiB3DYJetoy0NVOjxAug/HwxjfG550yDxNQICFCu/GkNb74JyIv10tigyEhtaZAOgv3tyX8XX39N/g0fLq5TEkBt2hArkQv/DLeCVwbkSGAF0JLepuK1QlgFyQWdJdDXsUHQEaUKgaeVK4kXerteANGenaFRycI6WhpA7bneqFmAmNJaOHNG9CJMmaIugPbvB55+mngiaFKk1oBipfYT7IWdFSUhCmFRNA5I1QJkhQByRBC0kmizBJuoNGiQ+n4lvgXGrFni8iuvOPe9X31VfduVK+rbnAB7Hm/dqrwPWwqPxoxTeAC0KVwAcSQ8U4dcxeuXrY8W5Vvgw8c+lGw318PK0TFAt1JvqW4z+rpeAFELR4FPMgBSUoAWPVzRZwXaVm6LaV2m6f6+rBFVbgGipQPWrgW++44sp6eTiRAgVgl2YmXzCGivIBatYkLJ8sQKIHkavBzVTLAqVcjj7t1m33/FCrEfpp4WoIMHST/NX3+1/rVaK1KXeAvQFlI9He3bAz17Ove9vb1Jx3glNIhuR2IwiL+bw4eV92F/FnIrEQ+ANoULII6E9x97H7/0+gU7XiYFEJuVaybZ/mmHT1VfSwXQ3ht7UVCof12ee5lM3Ee+9CqRY3R9DJAggHyTAUitPb3r98aeQXtQJaKKQ8cgtwDRCTMgAChKpJTg4yMVI59/Dvz4I1lWcvNotU5YEkCWLEC04jSNdVi/nriNbjYtChj6/XezfcH69FF+L3t55BHy1tYkGlHkpWrUKLEC6OefgQoVxPifqVNJTQBn8+yzxEckD7a+edO6brYOgP7+tDQqkAsgHgBtChdAHAmBvoEY+NBARAWTEgPRIWLV7cmPTxYsRErULl0bAFBoLMSmK5t0H1t8OlNB/F49YOVi4E4TAEBmnustQEIBQG8y87ii47t8fmbFjVJhPV9fqQUmJIRUKFYSJfLjmUMpTZwVT6wXVckCREUAzW558kkSk/TRni6iWUulM7w8QNnVVcIp3AJkhoIC4n5i+0Q2aOCasQQFAXv3Eh9SVpZ4J2A0Ep8wXXYBcgFkrkE0K5K4+0sZLoA4ZqFCCABalG+hGv8DADUja6Jfw34AgN9PqpiR7UAqgBoAp14ErpHMPhp07EqoBajQh1ijHF30kMLOxfLmp+zFX6mPl6+vGOPJ7l+6tPJ7aRVASneo7IWdrXqsJLYU+4EBSM/wAnr1Ik/efVfBRwbIO+24iwDiFiAzKOVsq6lwZxIQQNR6zZrk+dat5Hn//i4ZjlwAmRPV7G+Qu7+U4QKIY5boYNEC5ONl+UryfIPnAQB/nvlT6H+lF5LqyptmkAtFCukwefWBwgTqZKgAMvo6VwCxbN8ufc5e/AMCTPf38QGKSmwBEAWOUqkVQLsLzJIAYuuYKB2T7ksr21IqVKD/gURmT55s8lp5Fxt36aqutbVFiRRAtFw4xdZUdkfx/ffS53/8IU2VdxJyAWROVLMuMO7+UoYLII5Z2Iu4li7mver2gq+XLwqMBbiRqm8/tdjkonYXOyYDGVEkDuM+cbtdvH9R1/eyBSHNuyggWw8BlJVFKrtu26Ztf7lBhL34ywsO0u2s2KGWmaJC7Ir7a6FrV9N17IWdtVopGRXpRD94sFjZVjgGG4Dzyy8mr5WLL3exAGltrlmiBdCzzwKzZ5OcbneiY0fgs8+k68yUYnEU9NzQIoDY38GmTcQCtEn/yASPhgsgjlkMBgMmPz4Zvev1xuNVHtf0mnplSSOjm6kqDWlsJDalSAClkEDihQuB8v5EAF1OuoxCow3FWXREEEB++lmA5swhFhA2jV2OuUrQ7MXfYCCV/VnkMUD0rrFhQ+X30iqARo40Xcdaep54ggQUjxql/Ho1EZCZCWIdoAXrMjNNVJ88+NNdBJBWSpwAys0VBVCLFuTkqVrVpUNShO3YC5jGLDkB+huiv3OtLjCOMlwAcSwyqf0krHh+Bby9vC3vDKBiGLlDN5e2bg15BXkYs3EMtlwtSo9NrgqAXEDjThIxlFOQg9kHZuvyfrYiXGh1FEBs3yg1i7s5ASQXLGrPx40jnyftKal2Z6lVAJlza9HlgwelJV8svR4oygqrUAG4dIn03MjMNPH7cQHkITx4QJS9v7+Yem6mn6PLKVWKiO0pU8R106c7dQjy34U5CxDNoJw/nzxGRpIsT44IF0Ac3akQSmI09LIAfbzrY8w6OEtckUxET0AAJKJs7OaxyCuQtUl2InILULCvxqAPMyxcKC5b6v6g1CNIfvGXZ2fR7dOmEUFC44SU3GWAdUX6mjaVPrfmwm7WAgQQcxZ9g9hYyT7u6gLTSokRQN99Z+rbrVfPNWPRSnAwiSKmXeT37CHFd1avtq1DrpVYI4BoRugHH4g1vXj8jxQugDi6QwXQrTR9LECf75HdtqRKi7C83lwMEjmRcEKX97QFvV1gsbFSq09OjvJ+1AIkFz+A6cVf3kVazaKjJoCsCShev550rqBYI57MWoAoNBZIVvKWW4A8gMJC4IcfTNdT16a7Q/tTHjoEPP44iV3y8QHGjBH3OX+euGtpBVIdsEYAZWa6TSN7t4ULII7uUBeY3jFAlBrVfCXzzDdPfKO71ckWBAHkT0wQ9goguSVDnhIuR6lGm1ywyG9S1QSNUisLc/srUa4c6VNGsebCriaAJC4+mg12U/qdyz83d8kCA7TFzZYIATR/vmi5GzBAXO8parVyZeChh0zXz5oFXLhAlt94A1i5kvQXk/umbcScAGKTBShz5pBH7v5Shgsgju5UCNPPAiRJfQfwYqMXcekSMGOGuM7P2w+PVHiEvKdOcUe2IAigoPsAgDJBZew6ntxaoyaA6F2e0hwrv57QVloUNXHgrRLuZa2YYPtYWnNhV7tzlViAaFuM8+cl+7izBeiJJ0TviRpO8KS4njffFJdnzgQ+/ZQUH/QkaMVqg4GoC2qRrF8fOHIE2LFD3He2PvGJ5gSQUs/YSZO4+8scbjQ1cIoLelqA/o0jLctblG+Bjf03IiIgQjGd2NFWJy0YDOSfMZAIoNJBKtUENaJVAFG0CKB+/YCBA9W3W0JrKjeFnZSt6aAut2a99x4p5iZZ/+ij5PHoURItXpQS8+CB9LVqYs5VWHJLJCY6ZxwuIT2dxM2wlC4NfPih8v7uTFQU+TLz88kPqUIF0ry1sBDo0kW67/vvk/4sR4+SfZ97zqa3NCeAlDLCqFtcyT3O4RYgjgOg7qikrCTcy7AQuWuBxExyNagSXgWlg0qrZqKVCykHAIjPiFfc7iyMRgBBZMx6W4B27yZz7B9/KO+vxQXm4yM2QlXaroSfH7mhlRlaNMFOyhLrjQXkf0v37uRRIgIrVCB+NqNR0qlbni3naXEQloLdPZrJk4EePcTnixe7bCi6Qe8inn9eXPfgAeljtm6duK5WLaBvXyKEbEyflwsg9velZGGl65SKoHK4AOI4ALYJaP+/7CsZn5RF7LelAsz3EaDWFrq/6zACoWRys0cAZWSQm0aWzz4j8+aLL0qtKda4wABpmwstAqiwEGjWDKhTx/K+cgwG4LXXSLhE+/baX8cKoCeeUG+NgXJE+LL9L+RNuz3NpVSsBRDrFgoNJYKguBAYCAwaJD6vU4d0s5cXxSosVK//YAG5AKpenZROMhiUKwgYjTz+xxxcAHF0h+0XJtTusREqaCIDVdKSiqDbXS6AevcH/IgSKR1ouwvsnXeAnTvVt8utQ4CyBUgPAWSvgPjhB+D4cetigFiT/c8/mxFA1Md2546wylMFEE2AmjjRteNwKMz3hK+/dj//pL1Ei62D0In0KcTQoaaN+BYsEO9cCguBefPIj8QC8t+Qvz9w4AD5WB9XqFNLEwJ4/I8yXABxHEL3msRnQV1TtuJxAqiR6J+yxwK0erX57UoXdSULkJLAYQWQluuPK1xIrJiLjtZgASq6sBYUmFpQwsIcM0ZbUfs8ly0jfczY5rTFitRUUQD9+CMRBsWNsWOJSeaRR0Ql26ABMenevEkikn19gZQUMQvu229JxpiSgpEhP5d9fclvmOoudzvX3R0ugDgOYfLjkwGQDC170CqAqLXlfuZ9u95PTwzWRgwzsDfKSrAJM/SCaosFKM91dSPN8sIL5JG25FAVQGxjVJDrC3UP/vQTCZx299p6lNKlpW3Oih1Hj5LHKlVIkzc7fh9uS9mypDbQwYOmKZcVKpCoZeqrOnGCVJJ+5x3yPDWVmHOUzLtFyAWO/Pd97RqJ13v4YXE43P2lDs8C4zgEGgcUmxKL3bG70a5KO5uOQ9PgywaXNbsfFUgJGQmIS4lD5fDKNr2fbiQ5tqAbm2RiLgZIyQIUGEhCL+7dA6pVc8z47OXVV0l8Q7Nm5LmqAKJ/wLVrAET3V+nS5Bg2kZxMatO0aQNMmGDjQbTRsiW5VgLmi9oVC6hPt0ULlw7D5VSrRsTP+PHARVkT59atyePVq4o/TksCKDJS+u/HH4Gnn9Zx7MUMbgHiOIRSgeJsPujvQWb2VKfQWIjziST1qG6Zumb3rRhWETVKEdGx4uwKm95PV5b9pVftM4tYGwQNkAKF27aRRBV3xMuLNOCmWS6qAqh6dfJYJIBoBhhbgdpqxo4l2Tvvv+8Q/9+LL5LHBg2kRYLZprTFjj//BD7+mCzbmAJebKD1q+Tih6V1a8VzT16gVC2Gb+NG4P59Ln4s4abTH8fTCfcXf6k0ld1aTsSfQEZeBny9fAVxo4bBYMBTtZ8CACSkq3QNdTAFhUxgTlp5pwkg6vLR6gKzFnewTFABlJ8vqydEBVBsLJCfL1iAbBJAGRnA/v2kei+FVvXVkdGjgTVrgF27gCZNgF69yDp3FaO6MGKEuFzSBVBlBev0kCHSolkJCSY97gBTC1CxPmecAP/4OA7B30dMV2gS08Tq1++J3YNmC4j/o0FUA/h6W05Xig4hkYCuqgWUlc8okLxApxUfoyEDWl1gWvn3X+IFogVvXQmb/iuJWypfnmzMzwdu3hSKCMrDL8xy7Rq5tQ4JIX8w20ujXj3dO377+ABPPUXcdF5ewKpVpBhysSUzU4xMnz6dF6Xp2VPMCvvgA9JXbOZM4hY7eVLcb9Ys0qWYUfysAPL1LZ5hVM6ECyCOw1jamzSCyslX6eJphjc3iKXyHy7/sKbXxISQOyhXWYCy8hgBlC8VQCkppASKlmrIak1P1aACSG8L0KOPkmBrdwjZYNN/JW4wLy+galWyfO2a8NlZdY2dMcO0gRjL22+bDUzlWODQIfJYvry0WWhJpU4d4OxZ4NQpUtzr2WdJl/moKKBRI7Fg1uzZJEDa25uYCBctkgggd2rx4qlwAcRxGLQ9hZILLK8gD8fuHEOhUVkRpOaIFyRrBdCmK5tw6NYha4drN4IFKN8PMHpJBFD79iSmZf58y8eR97KyxD//kCab1ALEmsWLyyTJWrJM4oBq1iSPp08L2zRZvn77DViyRLlPQGioNNJ87lxrhsuhJCYSUyJAFDU3WRCqVBFTHOUouchmzwYGDUJ40jVhlTs1+fVUuADiOAxanVlJAHX9vSuaLWiGX47/ovhaVhg1im6k6f3YQOlWP7VCcnayFaO1H6ERaz4xP7CWHFrj7PffLR/HnDFCjZ49gevXyXKlSuL64hJY6+0t1iwyEUCPPUYet29HXg45b3wLmQ8/MZFYcRYtIs9PnyYX4oEDgf79SaoMi58fadmwebP4YX70kZ5/Tslg1izii6R9vmjvNo55GjRQ3RR2UCwsW1xublwJF0Ach0ELAabkpCCvQAzcyMnPwc7rOwEAx+KPKb72QTbpaNmzVk+0rNBS0/tVCa8Cf2/iKyk0FuK/W//ZOnSb6LWsF1kIIApGybCgJWgxJcW29//7b/LIFqOtVcu2Y7kjqplgtOLujh3IO0TOJ79fFpAOqkYjMGUKiT0ZNIhU3F240PTgFSoQU1pyMmnYOXYsWf/VV+QxLY201uZoIznZ1N3VzrZSGCWORx5R3RR244ywbE1zYY4yXEPaSUFBAfLctZqcjvj6+sLbyrL1pQJKwQADjDDiftZ9wUXFure8DKaKIDUnFem56QCAZf9bprmgoMFgwOzuszFsPan7fvTOUXSt0dWqMdvD3QxpDwYlAaTlT7HFAsTCBgyXL2/fsdwJPz8S52QigJo1I0HMKSnI23cYQHP4Ig/48ksSZ7Fhg7jvG28ofyiffCJ2XGV54QViyTh4kOwzcKDYs4KjTHKyaergK68ATZu6YjSeR5s24vKvvwJ16wqiKPTqCWGTs7JMizNcANmI0WhEfHw8kpOTXT0UpxEREYGYmBjNgsTbyxuRgZG4n3UfiZmJigJIKUCaupLC/cMR7Bds1Rhfb/E6ridfx9S9U3EnzUI5ZR3JzjdVO0rBzFosQPYKoL59STPqDh2KV8iFqgXI25tkax04gNxk0nLeF0U3Jaz4odBO3F9/TTp4p6aqx2MAwNq1Yl79unXAqFG2/xElAXkflx9/tKMqZQnEz4/4zHfsIEWjvL1JhehWreAde1XYzeR3wLEaLoBshIqfqKgoBAUF2dX2wN0xGo3IzMzE3aIiK+XKae/vVSaoDO5n3Ze0qGAFUGa+6W3MmXvEzFszsqZN4y0XSsbnzHT4qw/EiemhA0dxAsoWIC1GNHsFUFQUCXMpbqgKIIDUAzpwAHkgkaG+L/wPWPqeuP3SJeCPP6SdRocNI6nvlihblhTxmzQJWL+eCyBzPHhAPmfKli1A586uG4+n8tBD5B+FxqJR8c7RBS6AbKCgoEAQP6XZxkrFmMDAQADA3bt3ERUVpdkdViaoDC7cv4Cuv3dF5vuZ8PbylgggSep4EYdvHwYAtChvW/41tTT9eeZPLH5uMXy8HH+aX066DABoVq4ZSuURU7+jXGBlykCod6NE0VdV7KCp8Iqm/6KCiIIAimZcMKtXk0wxeWyFFvFDoQG8W7YAZ86YDVQt0fTuTSwXADFFcvGjD9HRJOqZl2PQFR4EbQM05ieIFrMqIdC/15qYJxoInVuQiw2XiDtCYgHKM72aXU++DgCoU7qOTeNkO9Dvur7LpmNYy5WkKwCAGqVqCDVoHBEE7e0N9Olj/vXF9bSkwd20lYSEHj0AiALIryzTM4CmyT/MlFOw1kTWsaPYeJWm9HGkxMeL4gcA+vVz3ViKG97e4vnH0Q0ugOygOLu9lLDl7w30Fc0RcSlxAEhWGEVSPbkI2gGeptFbS8uKYtaYrW04rIVagGpG1hQEkFIMkL0WID8/y4USi6sFiBZ6jItTsAK1bg3Mm4fcNh0AAL7+3qT/1DffiNaayEhg0yaS3m6tBcdgAJ54giw7oD1GsUBeMbs4ReC7A2xdKo4ucAHEcShsf6ybqTcBACnZogBiLUAFhQV4a8Nb2HKV1LooFWBbEyo/bz88WftJAEBarpVVBW3kygPRAkRdNbZagCwJIEstNoqrBYj9XBSNkMOGIa8ScYX5+oKYyt56S7pP1662X0jqFFkkZ88mdYV4Go6UU6ekz62IFeRooGNHAIC/gUc/6wUXQByHwhY0pNYYauEBpALo37h/8e2hb4XnbEd5awnxI/EdNJ3e0bAWIHsEUHq6adFhtnxKSRZAM2aIy2qhEFQYOaRKLhVAqanE2vHuuw54Ew8mvijpoGVL8tlUrOja8RQ3ilq+BEKh5w3HJngQNMehhPqLpYipNYZ1S7FB0HJ3la0WIAAI9SPvm5bjeAvQnINzRAtQZA3h4qt0kbYkgD74QFyeMYMUMixfHmjenKzTIoCKqwvsmWfEZUsCiK2FpBt160qf73JOfJnHkFDUg2/ePF7zxxFUqwYACDKmIxnhFnbmaIFbgDgO5eP2HwvLVADdzxJT4jPyMrDr+i6cSjglCY4GgPAA23/kVACdvHsSC48tVO05Zi1GoxFzDs7Bruu78N+t/1BQWICRG0cK28uHljcRQEaj+HpLAmiLWOke5coBTz4prSlXki1AAMyKSwDW9QKzlqILkEAJqgFmkfx8oKhMBmJiXDuW4kp0NBAQgCBw16tecAFUgrh37x5iYmLwxRdfCOv27dsHPz8/bNu2zSHvWTm8Mpb3WQ4A2HBpA7Lzs7H41GJh++2022j/S3s0nt8Yu+N2C+t9vXwRFRxl8/tSF9iKsysweM1g/HnmT5uPZTQakVtArqx/X/gbIzeORPtf2qPljy1R/7v6kn29DF5Cjx5qjWDr1lgKgg5m6j7Szs9sPy+DoeRagACx/5FLXGDy5ks3b5Ku3hyi3AsLSWmBsmVdPZriicEAVKnCBZCOcAGkB0YjkJHhmn+secECZcuWxcKFCzF58mQcPnwYaWlpGDBgAEaMGIFOtJ+SA6DWGACoNLOS6n6Lji8CAPSp3wd3xt2Bn7ftfgx57Z9+K/sp1hzSQs8lPeH/mT+O3TmGI7ePSLZdvH9RWJ7fk7R6l1+k2awtSxYgJQEUESGuy8qSHu+770yPQbPQiiPsZ6tkgHGoAJJjNJLaQteuWd63uEPdgX378i6djqRaNXTGVgDaiqpyzMMFkB5kZpI7H1f8szITpUePHhgyZAj69++PYcOGITg4GFOmTHHQB0MI8w8TlrWkpVcJr2JzCjzF28t0dth3Y5/VxzEajfjn8j8AgGYLmiEjL0Nxv9KBpfF6i9cBiBdfvSxA7PUkJwcYWeRx69kTGD5cmpVtMGjLNPNU6GcxYwZxDdIG7xSHC6CtW0nz1T17gMaNyU3IK6+U7IywlStJ3zXANE6Koy81auBTfISpj63H6RMFJGuCYzPFeKrkqDFt2jTk5+dj+fLlWLx4Mfxp2pKDoO4olrGtxqJSmLI1iBVMtjK8xXChMzyFrT+kFXmdonuZ9xT3iwiIEJbNWYAsdXBmixOHKXwMgYHAgAHAyZPAX3+RdWy5FSsMgh4J/Wyp5WvQIOl2KjYdEgQNEPGzdSvQti3w7LNk3e7dRLn++KOD3tTNYStTVq7sunGUBJo0QRCy8K7fTNT9YRxQujQvzGkH3FapB0FBrlPiNkS8XrlyBbdv30ZhYSGuX7+ORo0aOWBgIkG+pmNsEtMEu2J34UbqDZNt1jZAVaJUYClkf5iNM3fPoOE80uiS1iGyBjZlHwAu3b+kuB8rjOSBumlMIpqlItqsBUipU0NQELHysF9ZcY75kWPJu+JUF1jbttLnQ4aQWi1ly0oDt4ozv/wiNXHWsa16O0cjNLtu2zbyDwBGjwZ27nTViDwabgHSA4OBXLlc8c/K6sy5ubl46aWX0LdvX3z66ad47bXXhCanjqJGZA180v4Tybpn6j6Dp+s8bbJv2aCyiuttpUFUAwxvMRyAPgKItumQw9YbkgdBWyzgx0BfGxNDen7JYeOBKCUpFsCtBFCnTkJxOoEaNaT5+sUZo5G4/yg//SRt4MnRnwYNTH/wR47w1vA2wgVQCeODDz5ASkoKvvnmG7z77ruoXbs2Xn31VYe/70ePf4QKoWIvmzD/MLz/2Pv4ouMXkv0S3k6wuQu8GpXDiVn+bob1Qo/tYg8ACRkJwvLPz/wsLLNp9rZagIxG4F6RIWnsWOm26dOJUWH+fO1jL44oCaAE8StxvAuMxWAANmwwXb9jB3DxInDnDnD1qhMG4iQuXyb91AwG4M03gRUrxG1z5gBOmEdKPAEBpDgYS3o6d4PZCBdAJYidO3di1qxZ+O233xAWFgYvLy/89ttv2LNnD+bNm+fw96ep5BQfLx9MeGwCVj6/EgDwVO2nHNJfjTZktbYvWHZ+Np5Y/ITitu96fIdXmryiuM1WC9D48aRxOWB6AR87lmQ9NWtmcdjFGiUBtHatuEzjrRwc1ibi70/6jcn57DPirmjcWFS17sqXXxK33aFD5vcbPhw4fJgsf/cd8Pzz4rYRIxw3Po4U9uSmLscrV1wzFg+HC6ASRPv27ZGXl4e2TOxC1apVkZKSguHDhzv8/X98mgSJTusyTbL+uXrPYf/g/Vj4zEKHvK+tAuhkwknkFCh3Hm0S0wQA8O6jpB3C+23fF7bJg6BZC9D+/eoiaBrzsShZMIpzdpdWlAQQGwZHLUBOE0AA6TdmNJIKyJTffiOmqYwMEiTtrqxaBbz3HpCYCHTvTk7Op54i/dKomrx6FejcmQR/KxEZ6bzxcoBWrcTlNm3I4+XLrhmLh8OnVI7TeLrO03jw7gOMazPOZFuriq0EoaI3ZYNIYTZrBRDtXt+qYiu80eINyTba4uOzjp/h8JDD+LiDWPFangbPWoAKC4FHHzV9rxRZgppTXDgeiJIAYsMf6DXbJZ/fsGHK6vbrr4kQcjfOnAGee058npQEzJoFrFtHxM7vvxOLT40aYsAtAIwZIz3O5s1OGS6niK+/Jufa/v3kuwG4BchGuADiOBU2XdxZUGGllsKuxp7YPQCAqhFVTeoS0eKOPl4+aF6+uaTwojkLEEA8DVmymozyDvBcACljSQC5xALE4uMDLFggXXfwIKld4E7ExSn7U7//Xlx+7TUS88Myf75pgFqTJroPj2OG0FBibWzVCqhZFC/JLUA2wQUQp9hTIawCDDAgNSfVqkDo7w6TYjOVwiqZWKfYJq9y5EHQbJAuRZ60Ib+mcAGkjDkBZDS62AJEqV5dXA4v6me3dStQUOCa8Sixc6f0JGxISkWYtSSsXg28/rq08JSvb8lKQ3Q3qAVo715g4kTXjsUDcZgASkpKQv/+/REWFoaIiAgMHjwY6RZq5WRnZ+PNN99E6dKlERISgt69eyNBdvUYOXIkmjdvDn9/fzRRufM4efIkHvt/e/ce1cSd9gH8mwQCRuQqEHkFL10reF0uigi2r5UtWrXrSrtacY+Kl7VFq+K21bpezlpLq3W71WopbrXbt6LWamvV1iNVi9WCF1pcFY3a9cKKaCsCIsgt8/7xc5JJCJBAwmSY53NODpNJZvIQMfPk+d2GDYO7uzuCg4OxevVqe/1aRII81B6GkWVnis9YdUxlbSXq9CyDmTJwikkCpFKompys0bwT9IULDZ9j3lIiHFADUALUmKYSoPp640SQolWAALY8Bm/XLpYE3b8P7NkjXkzmcnPZTzc34MABy/MrCH3/vXF4v1IJLF/Ottevd1iIxArCmbdXrmxfow7bgMMSoKSkJJw/fx5ZWVnYt28fjh49ilmzZjV5zIIFC7B3717s3LkT2dnZKCoqwnhhG/UjycnJmDBhgsVzlJeX4+mnn0a3bt2Ql5eHNWvWYMWKFcgwL0sTWekXwL7hXvz1olXP5+f/cVG6oI9/H3T17Gp4rIdPjwZrjQkJK0B6PZu12Zzwy7elxU0pAbKsqQRIOOO2qO9fp07A5cusI/RTTwEvPeo/NmuW8yyZwSdAn34KJCQ0HMU1ZAiQmMiGHup0DSd9XL6cVYua+UwnDubhwfoD8XJyxItFijgHKCgo4ABwp06dMuz75ptvOIVCwd28edPiMaWlpZyrqyu3c+dOw74LFy5wALicnJwGz1++fDk3cODABvs3btzI+fj4cNXV1YZ9r732Gte7d2+bfoeysjIOAFdWVtbgsaqqKq6goICrqqqy6ZxSJ+Xfe9438zisAPfKwVesev6Z4jMcVoALWBPAcRzHVddVc1gBDivAjfp0VJPH7trFcQDHxcZy3PHjbNv8du2a8fnXrzd8/OBB234/4bHt2fDhDd+rv/2NPVZSYtxXUyNunCYePuS4bt1YYDt2iB0Ne3NUKhbP9etsn17PcUlJbJ/gM5hIQE0Nxz37rPGPv65O7IhE19T1W8ghFaCcnBx4e3sjKirKsC8+Ph5KpRInTpyweExeXh5qa2sRHx9v2BcaGoqQkBDk2JDV5uTk4IknnoBa8BUwISEBOp0O9+7da/S46upqlJeXm9xI+8GvO1ZYXohv//MtTvzX8t8hL+XrFADAgxo2eketUuOz5z7DgMABhpmlGyOsAPF9E7t0MX2OsAnMUh8hqgBZZm0FyKkWJHdzY7NGA2yCRLFdv87aCzt0AIIfrcenUACffMKGwz/3nLjxEdu4urLlMHiXLS/XQxpySAJUXFyMgIAAk30uLi7w9fVFcXFxo8eo1Wp4m7VFBwYGNnpMY+cJDAxscA7+scakpaXBy8vLcAsOtrxQJ5Emfjbo7ee243f/9zsM+WgI8ovzG33+sRvHAMBk9ffn+z6PM7PPYGzvsU2+lrAPED8CzHxZC2ETGCVA1msqARKOAHPAfJqt0/VRE+p/bV+Oxe74rLxnT9M3Sqlki2sS6Rk+3Nj37NgxcWOREJsSoEWLFkGhUDR5u3jRuj4Wzmbx4sUoKysz3AoLGy7SSaRr0P8MarAv+1p2s8cJl++wlnAYPN/VwnyuOGEFyFKxkRIgyywNpDKvADnle+dMCRBfIeCHUJP24Zln2M+ZM4H33xc3FomwqVC8cOFCTBUufmdBz549odVqGyywWVdXh5KSEmi1WovHabVa1NTUoLS01KQKdPv27UaPaew85iPH+PtNncfNzQ1uog4dIY7U3bt7g32nb522+NyHdcZeyQf/ZPskb3wT2NWrxg7Q5l+shQmQ+TxBgO0XcZXKuUZZO4qlDuPmCZBT/jcOYRVInD3LesaLOa03/0fZv794MRD7S04GVqxg23PnsvvCadJJAzb9L/T390doaGiTN7VajZiYGJSWliIvL89w7OHDh6HX6xEdHW3x3JGRkXB1dcUhwYyjOp0ON27cQExMjNUxxsTE4OjRo6gVXGGysrLQu3dv+Pj42PLrknZGAdN2EX6mZ3P8XEGuSleEdQ6z+XX4CpAwselkNm2QsAnMHgnQRjZlEZYsse04qWkqAWrThVBtNWwY4OnJJiBsbs0tRzt7lv0cMEDcOIh9de1qep9m6G6WQ76GhIWFYeTIkZg5cyZOnjyJ48ePY86cOZg4cSKCHk2idfPmTYSGhuLkyZMAAC8vL0yfPh2pqak4cuQI8vLyMG3aNMTExGCIYO2TK1euID8/H8XFxaiqqkJ+fj7y8/NR8+jTb9KkSVCr1Zg+fTrOnz+PHTt24L333kOq+UxzRHbcXExLA0X3iyw+77/lrJlC66Ft0eKsfAVIyHziQ2EFyNL0WLZexGfNAoqK2FQg7Zn5DNqARCpAGg0bEg+wOXXs7euv2RIJ/ERIltTXsz82fiCK+ariRNoUCrauGy+7+SZ+uXNYHXbr1q0IDQ3FiBEj8MwzzyAuLs5kLp7a2lrodDpUCubFePfddzFmzBgkJibiiSeegFarxe7du03OO2PGDISHh+PDDz/EpUuXEB4ejvDwcBQVsYuZl5cXDh48iKtXryIyMhILFy7EsmXLmp2DSA4yMjIQFBQEvV5vsv/3v/89kpOTRYqq7SyKZR8O4dpwAMDN8pvgzC4Yek6P2M1ssS6+47StLCVA5pULe1eAADbSzOk6/9qZpVYbSVSAAOMicPbupPrhh8Do0cCrr5qu0C7088+sJ76wFNmtm33jIOJLSwM++ohtW5qAjJhw2GBRX19fZGZmNvp49+7dG1x83N3dsWHDBmzYsKHR47777rtmX3vAgAH43hHfshrBceLNb6bRWH/Re/755zF37lwcOXIEIx4Nyy0pKcGBAwfw9ddfOzBK5/D6sNcxpOsQRHSJQMA7Aaiqq0JZdZnJ+mTXS68btocGD23R61iqQJgNirTYB8jPD7h7l2077UVcZBs3AuYfK5KoAAHGBOiHH9iHhj2y1WvXTCfC+/xz9mFk3vdj0yY2qaFQc7M/E2kaOJD9PHPGfn9n7RStBWYHlZVsQk4xbrYkXj4+Phg1apRJYvr555+jc+fOGD58uAPeGefiqnJFwm8S4N/RHx5qDwANV4gXzhT9ytBXWvQ6li7AfN9E3ujRbNFtwNgE1rGj8XFKgCzz8mITFwvxyaTTV4AiIljn519+AZr4kmc1jjMuTyF082bDfcePm96fMIEujO1V377s7+zuXdYuThpFCZDMJCUlYdeuXah+9HV569atmDhxIpRijkoRgV8HNizLPAG68CtbuOu5Ps/Bv6N/i85tfgFeuRIwm5oKADD20XRCfAWIEiDrmI9246tmTl8BcnMzLo46d27rz/fdd8Zmjl69gMcfZ9t8ApSXx+b6efNN4Px5tu+nn1jGuG1b61+fOCd3d6B3b7ZNzWBNktdVz0E0GvYtXoybraMcx44dC47jsH//fhQWFuL7779HUlKSY94YJ+anYQnQ3cq7Jvv5ClCoX2iDY6xlfgF2d2/6y3ZBAfv5P4IphygBapx5AvTDD2wySaevAAGsjwavpKR15xL2Jfr2W+Mf0KxZrM9PVBSbi2HJEuDePVYV6N2bDVOk6k/7FhfH+nyNGQNMmiR2NE6LEiA7UCjYt3cxbrZ+jrm7u2P8+PHYunUrtm3bht69eyMiIsIxb4wT4ytAd6uMCZCe0+Pry6wvVJi/7cPfeeYJUFMVibt3jQs4C1aOgUrV4pdv94QJED+zxb17EqgAAcCf/2zsfMxXZVqKn+9s+HA2z1Dfvuz+5cuWJzmMjGTLX5D2LyOD/afQ61m1Lz1d7IicEiVAMpSUlIT9+/dj8+bNsqz+AJYrQAd/Poib91nzQWxwbIvPbUsC9OujFjgfH2PrCGmaMAHiqz01NRKpAAFAv37s55gxltdBsda1a+znCy+wn++80/A5o0YZtwXrLBIZEI4IbO8ThLUQJUAy9NRTT8HX1xc6nQ6TZFoe1XZks4LzCQ8AwwKpiWGJ6Obd8iHClprAGsN3Yu/QgX1ZI80TJkD8lAO1tRKpAAHGBKi8HNBqgVu3WnYePgHiK0puboBOxyo9ADBtGvDpp8bnP/lky16HSFNmprGPhKUZRAklQHKkVCpRVFQEjuPQs2dPscMRRS8/Ngnc5RLjysn/vsM6DLam+gOwrhbCRTubuiDzCZBGQwmQterqjNuSrAD97/+a3t+50/Zz6PXGJrTu3Y37H38cOH2ajRDbvJktQvfxx2yCvN/9roUBE8niv31VVlJfIAsoASKy9LgfGzGj+1VnmI/qZjmrBllaN8xWwotwUxUgfs1CSoCsJ/kK0MiRrHOyB5uKwTAzsy3++lfjdkgzE3ZOmcI6X8tspCcBsGqVcZv6AjVA/yOILAV1Ykuy6O7qoPybEpN3T4burs7ksdYQXoT57VALA8u2b2c/KQGy3quvsp/jxxsTzdpaCVWAAFa14ZundDrbj9+717jdVIZN5G32bGMfMYD6ApmhBIjIUkfXjib3t57ditKHpQDsnwDx16czZ4Dp0y0/nxIg602cyAY67dhhrADV1EioAsTj++7k5Vn/j6/XA6+9Bpw7x+4fOOCY2Ej7IewLdO8eVYEEKAEisqRxtTyBkqvSFVoPbavPb6kCpFY3vvySRgM8+yzb1rb+5du93/yG9bMSNoFJqgIEmDZdvfee6WOFhezbOr+uE2/vXmD1arat0QCPlrQhpEn8tzCOoyqQACVARJY6qjta3D8udBxcVRZWM7WRpQQIMJ3tWUijYfMAXbgAXLrU6peXDWEnaMlVgPhJjAAgNZX1C+JNm8ZmcJ4xA7hoXJ4FhYXG7YAA0972hDRm1SpjH7CSEufoEJ2ezpqCRaxIUQJEZMndxXK/iT7+fexyfktNYICx36s5vkIdGmq6YDdpmqUKkGQSIIXCdCFT4egwYdITFgbExLChzMJ1vl580eEhknZi9mzT9ef4zodiWrIEuH5d1IoUJUBElpQKy3/6nTWd7XJ+WytANEFvywg7QfMVIMk0gQGmScyNGyyLq6trODdQbi6Qn29cQTcuDvjLX9osTNIOzJ5t/KbFceJWgSZNav1SMHZACRCRPZXCuO5EY32DbCWs9AgrQMIESLjcha1ruhFG2Alacn2AALZ8hbDyc/48W8Gb7xQdJliS5fRp1vlZqQS++oqGtRPbrV1r3BZrWHx6uulivMKh+m2M/gcR2evq2dWwrYB9FokUJjfCJS6ECZC/YLF5SoBaRtgEVlvLtiWVAKlUwOHDQOyjyTezs4EjR9h2r15skVMev4K8QmHaf4gQa5kPi1+4sG1fPz0dSEkx3tdoTJuB2xglQET2unTqYtjmJ0hsLeGoZmHS4+1t3A4IMG5TAtQywk7QfAIkuX7BCgUwbBjbXrAAmDqVbYeHA0FBwEsvmT6fhgmS1sjMNFYP23qG6CVLjB+OSqVpRUoElAAR2QvqFISjU49i09hNiA1p3TIYvEeTSwNg1zeecJFuYdJDCVDLCCtA/BIZrq0fxNf2xo1ruC8igv3s3990/65dDg+HtHMTJhi3t21rmyRI2O9HoWCdskWs/gCUABEZWxS7CF5uXlgdvxrDug3DjIgZdjt3Y/PaCVsuqqqM25QAtYywEzSfAEmuAgQAgwc33McnQLNmGfe98AIQHd02MZH2KzPTtCnM0UnQpEmm/X58fERPfgBKgGTnwIEDiIuLg7e3N/z8/DBmzBj8/PPPYoclirT4NNx99S4e833M7uduamLfmTNZU9if/2zcRwlQywg7QUu2CQxg34jN+2PwCZBSCRQUACtWABkZbR4aaacsJUGO6BRt3ulZoxG147OQFD8qnA7HcaisrRTltTWuGigU1nfcffDgAVJTUzFgwABUVFRg2bJl+MMf/oD8/HwoZTiqRKVUNf+kFmgqAcrIYNXfnBzjPhoG3zKWmsAkmQABbOLDmTOBsjL2C/n5GR8LCwOWLxcvNtI+ZWYCe/awvkAAm5bh6FG23x7MKz8aDfDggX3ObQdS/ahwKpW1lfBIa2SGOwerWFzR6KzGliQmJprc37x5M/z9/VFQUIB+/frZOzzZiogAfvih8cddXU07RFMFqGWEnaAl3QcIYL9M795iR0HkZu1aNjKL/9a2bRvwxBOtb6Iyr/w4Qadnc/L7yi9zly9fxgsvvICePXvC09MT3bt3BwDcuHFD3MDamTfeYC0ap083/hxKgFqPn2Ty4cN2UAEiRAz8LNHCloQXX2xdn6BJk0wn+XSSTs/m6KPCDjSuGlQsrhDttW0xduxYdOvWDZs2bUJQUBD0ej369euHGn4WOWIXXl7AO+80/RxhAmRDKyYR4KcYqKyUeB8gQsTEJybmlaA9e1jVxtrEJT0dePll439GgFV+nDD5ASgBsguFQmFTM5RY7t69C51Oh02bNmHYo3lHjh07JnJU8iWcLZryz5bhE6AHD9pBExghYuITlJdeMs7jUVnJKjnz5gHvvdd4EmMp8QFYaduWBKqNUQIkIz4+PvDz80NGRga6dOmCGzduYNGiRWKHJVtKJevy8fPPwKBBYkcjTXzToTABogoQIS3EJyoLFxo7RgPsG9qLL9q2AG9UFHDqlH3jszPqAyQjSqUS27dvR15eHvr164cFCxZgzZo1Yocla2fPAvfu0QrwLWWpAkQJECGtMHs2+w/1wQctK6eq1exYJ09+AKoAyU58fDwKCgpM9nHCaYtJm3J1pSab1hAmQHz1nd5PQuxg9mx2a6x5y5xa3XQzmROiBIgQIlnCTtBUASLEAfhEqB2iJjBCiGRRExghpKUoASKESJawEzQ1gRFCbEHflQghkiWsAPEruVAFiBBiDaoAEUIky1InaEqACCHWoASIECJZfALEcUBVFdumJjBCiDUoASKESFZHCxOwUwWIEGINSoAIIZKlUhkXROVRAkQIsQYlQIQQSdOYrQdMTWCEEGtQAkQIkTTzhIcqQIQQa1ACJDNTp06FQqHAbAsze6akpEChUGDq1KltHxghLXTnjul9SoAIIdagBEiGgoODsX37dlTxw2YAPHz4EJmZmQgJCRExMkJaj5rACCHWoARIhiIiIhAcHIzdu3cb9u3evRshISEIDw8XMTJCWk9Jn2qEECtQsdgOOI5DZWWlKK+t0WigUChsPi45ORlbtmxBUlISAGDz5s2YNm0avvvuOztHSIhjeXsDpaViR0EIkRpKgOygsrISHh4eorx2RUUFOlqaDKUZkydPxuLFi3H9+nUAwPHjx7F9+3ZKgIjk5OYCoaFiR0EIkRpKgGTK398fo0ePxscffwyO4zB69Gh07txZ7LAIsVnv3oCPD3DvntiREEKkhBIgO9BoNKioqBDttVsqOTkZc+bMAQBs2LDBXiER0ubq68WOgBAiNZQA2YFCoWhRM5TYRo4ciZqaGigUCiQkJIgdDiEtRgkQIcRWNF5CxlQqFS5cuICCggKoVCqxwyGkxRIT2c+BA8WNgxAiHVQBkjlPT0+xQyCk1davBwYPNiZChBDSHEqAZObjjz9u8vEvv/yyTeIgxJ48PYGUFLGjIIRIicOawEpKSpCUlARPT094e3tj+vTpzXYUfvjwIVJSUuDn5wcPDw8kJibi9u3bJs95+eWXERkZCTc3N/z2t79tcI5r165BoVA0uOXm5trz1yOEEEKIhDksAUpKSsL58+eRlZWFffv24ejRo5g1a1aTxyxYsAB79+7Fzp07kZ2djaKiIowfP77B85KTkzFhwoQmz/Xtt9/i1q1bhltkZGSrfh9CCCGEtB8OaQK7cOECDhw4gFOnTiEqKgoAsH79ejzzzDN45513EBQU1OCYsrIyfPTRR8jMzMRTTz0FANiyZQvCwsKQm5uLIUOGAADWrVsHAPjll1/w73//u9EY/Pz8oNVq7f2rEUIIIaQdcEgFKCcnB97e3obkBwDi4+OhVCpx4sQJi8fk5eWhtrYW8fHxhn2hoaEICQlBTk6OzTE8++yzCAgIQFxcHL766qtmn19dXY3y8nKTGyGEEELaJ4ckQMXFxQgICDDZ5+LiAl9fXxQXFzd6jFqthre3t8n+wMDARo+xxMPDA2vXrsXOnTuxf/9+xMXFYdy4cc0mQWlpafDy8jLcgoODrX5NQgghhEiLTQnQokWLLHYwFt4uXrzoqFit0rlzZ6SmpiI6OhqDBg3CW2+9hcmTJ2PNmjVNHrd48WKUlZUZboWFhc2+ll6vt1fYkiC335cQQkj7ZVMfoIULF2Lq1KlNPqdnz57QarW4c+eOyf66ujqUlJQ02i9Hq9WipqYGpaWlJlWg27dvt7ovT3R0NLKyspp8jpubG9zc3Kw6n1qthlKpRFFREfz9/aFWq1u0IrtUcByHmpoa/PLLL1AqlVCr1WKHRAghhLSKTQmQv78//P39m31eTEwMSktLkZeXZxh9dfjwYej1ekRHR1s8JjIyEq6urjh06BASH81mptPpcOPGDcTExNgSZgP5+fno0qVLq84hpFQq0aNHD9y6dQtFRUV2O6+z02g0CAkJgVJJE4gTQgiRNoeMAgsLC8PIkSMxc+ZMpKeno7a2FnPmzMHEiRMNI8Bu3ryJESNG4JNPPsHgwYPh5eWF6dOnIzU1Fb6+vvD09MTcuXMRExNjGAEGAFeuXEFFRQWKi4tRVVWF/Px8AECfPn2gVqvxr3/9C2q1GuHh4QCA3bt3Y/PmzfjnP/9p199RrVYjJCQEdXV1qJfBQkQqlQouLi7tutJFCCFEPhw2E/TWrVsxZ84cjBgxAkqlEomJiYYh7ABQW1sLnU6HyspKw753333X8Nzq6mokJCRg48aNJuedMWMGsrOzDff5ROfq1avo3r07AGDlypW4fv06XFxcEBoaih07duC5556z+++oUCjg6uoKV1dXu5+bEEIIIY6j4DiOEzsIZ1ReXg4vLy+UlZXRelmEEEKIRFh7/abOHIQQQgiRHUqACCGEECI7tBp8I/iWQZoRmhBCCJEO/rrdXA8fSoAacf/+fQCgGaEJIYQQCbp//z68vLwafZw6QTdCr9ejqKgInTp1suvQ7/LycgQHB6OwsJA6V7cCvY/2Qe+jfdD7aB/0PtqH3N9HjuNw//59BAUFNTlvHVWAGqFUKtG1a1eHnd/T01OWf5j2Ru+jfdD7aB/0PtoHvY/2Ief3sanKD486QRNCCCFEdigBIoQQQojsUALUxtzc3LB8+XKrF14lltH7aB/0PtoHvY/2Qe+jfdD7aB3qBE0IIYQQ2aEKECGEEEJkhxIgQgghhMgOJUCEEEIIkR1KgAghhBAiO5QAtaENGzage/fucHd3R3R0NE6ePCl2SJKTlpaGQYMGoVOnTggICMC4ceOg0+nEDkvS3nrrLSgUCsyfP1/sUCTp5s2bmDx5Mvz8/NChQwf0798fp0+fFjssSamvr8fSpUvRo0cPdOjQAY899hhWrlzZ7FpOcnf06FGMHTsWQUFBUCgU+PLLL00e5zgOy5YtQ5cuXdChQwfEx8fj8uXL4gTrhCgBaiM7duxAamoqli9fjh9//BEDBw5EQkIC7ty5I3ZokpKdnY2UlBTk5uYiKysLtbW1ePrpp/HgwQOxQ5OkU6dO4cMPP8SAAQPEDkWS7t27h9jYWLi6uuKbb75BQUEB1q5dCx8fH7FDk5S3334bH3zwAd5//31cuHABb7/9NlavXo3169eLHZpTe/DgAQYOHIgNGzZYfHz16tVYt24d0tPTceLECXTs2BEJCQl4+PBhG0fqpDjSJgYPHsylpKQY7tfX13NBQUFcWlqaiFFJ3507dzgAXHZ2ttihSM79+/e5Xr16cVlZWdyTTz7JzZs3T+yQJOe1117j4uLixA5D8kaPHs0lJyeb7Bs/fjyXlJQkUkTSA4D74osvDPf1ej2n1Wq5NWvWGPaVlpZybm5u3LZt20SI0PlQBagN1NTUIC8vD/Hx8YZ9SqUS8fHxyMnJETEy6SsrKwMA+Pr6ihyJ9KSkpGD06NEmf5fENl999RWioqLw/PPPIyAgAOHh4di0aZPYYUnO0KFDcejQIVy6dAkAcObMGRw7dgyjRo0SOTLpunr1KoqLi03+f3t5eSE6OpquO4/QYqht4Ndff0V9fT0CAwNN9gcGBuLixYsiRSV9er0e8+fPR2xsLPr16yd2OJKyfft2/Pjjjzh16pTYoUjaf/7zH3zwwQdITU3F66+/jlOnTuHll1+GWq3GlClTxA5PMhYtWoTy8nKEhoZCpVKhvr4eq1atQlJSktihSVZxcTEAWLzu8I/JHSVARLJSUlJw7tw5HDt2TOxQJKWwsBDz5s1DVlYW3N3dxQ5H0vR6PaKiovDmm28CAMLDw3Hu3Dmkp6dTAmSDzz77DFu3bkVmZib69u2L/Px8zJ8/H0FBQfQ+EoehJrA20LlzZ6hUKty+fdtk/+3bt6HVakWKStrmzJmDffv24ciRI+jatavY4UhKXl4e7ty5g4iICLi4uMDFxQXZ2dlYt24dXFxcUF9fL3aIktGlSxf06dPHZF9YWBhu3LghUkTS9Morr2DRokWYOHEi+vfvjz/96U9YsGAB0tLSxA5NsvhrC113GkcJUBtQq9WIjIzEoUOHDPv0ej0OHTqEmJgYESOTHo7jMGfOHHzxxRc4fPgwevToIXZIkjNixAicPXsW+fn5hltUVBSSkpKQn58PlUoldoiSERsb22AahkuXLqFbt24iRSRNlZWVUCpNL0cqlQp6vV6kiKSvR48e0Gq1Jted8vJynDhxgq47j1ATWBtJTU3FlClTEBUVhcGDB+Mf//gHHjx4gGnTpokdmqSkpKQgMzMTe/bsQadOnQxt2V5eXujQoYPI0UlDp06dGvSZ6tixI/z8/KgvlY0WLFiAoUOH4s0338Qf//hHnDx5EhkZGcjIyBA7NEkZO3YsVq1ahZCQEPTt2xc//fQT/v73vyM5OVns0JxaRUUFrly5Yrh/9epV5Ofnw9fXFyEhIZg/fz7eeOMN9OrVCz169MDSpUsRFBSEcePGiRe0MxF7GJqcrF+/ngsJCeHUajU3ePBgLjc3V+yQJAeAxduWLVvEDk3SaBh8y+3du5fr168f5+bmxoWGhnIZGRlihyQ55eXl3Lx587iQkBDO3d2d69mzJ7dkyRKuurpa7NCc2pEjRyx+Hk6ZMoXjODYUfunSpVxgYCDn5ubGjRgxgtPpdOIG7UQUHEdTbRJCCCFEXqgPECGEEEJkhxIgQgghhMgOJUCEEEIIkR1KgAghhBAiO5QAEUIIIUR2KAEihBBCiOxQAkQIIYQQ2aEEiBBCCCGyQwkQIYQQQmSHEiBCCCGEyA4lQIQQQgiRHUqACCGEECI7/w+caxltFBX5PQAAAABJRU5ErkJggg==" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "from essm_jax.pytee_utils import pytree_unravel\n", + "from typing import NamedTuple\n", + "\n", + "\n", + "class State(NamedTuple):\n", + " \"\"\"\n", + " State for a single timestep of the Kalman filter.\n", + " \"\"\"\n", + " x: jax.Array # [n]\n", + " v: jax.Array # [n]\n", + " a: jax.Array # [n]\n", + " M: jax.Array # [n]\n", + "\n", + "\n", + "class Observation(NamedTuple):\n", + " \"\"\"\n", + " Observables for a single timestep of the Kalman filter.\n", + " \"\"\"\n", + " x: jax.Array # [n]\n", + "\n", + "\n", + "n = 1 # Single dimension\n", + "\n", + "\n", + "@dataclasses.dataclass(eq=False)\n", + "class ExcitedDampedHarmonicOscillator:\n", + " n: int = 1\n", + " dt: float = 0.01\n", + " sigma_f: float = 0.1\n", + " sigma_x: float = 0.1\n", + " zeta: float = 0.1\n", + " omega: float = 1.\n", + " m_bar: float = 1.\n", + "\n", + " def __post_init__(self):\n", + " self.c = 2 * self.zeta * self.omega\n", + " self.k = self.omega ** 2\n", + " self.M_bar = jnp.log(self.m_bar)\n", + "\n", + " example_state = State(\n", + " x=jnp.ones((n,)),\n", + " v=jnp.ones((n,)),\n", + " a=jnp.ones((n,)),\n", + " M=jnp.ones((n,))\n", + " )\n", + "\n", + " example_observation = Observation(\n", + " x=jnp.ones((n,))\n", + " )\n", + "\n", + " self.state_ravel_fn, self.state_unravel_fn = pytree_unravel(example_state)\n", + " self.obs_ravel_fn, self.obs_unravel_fn = pytree_unravel(example_observation)\n", + "\n", + " def batched_state_ravel_fn(self, state: State) -> jax.Array:\n", + " \"\"\"\n", + " Ravel a state with batched time axis into batched flat states.\n", + " \"\"\"\n", + " # This is needed because calling ravel is only defined on the per element basis.\n", + " return jax.vmap(self.state_ravel_fn)(state)\n", + "\n", + " def batched_state_unravel_fn(self, flat_state: jax.Array) -> State:\n", + " \"\"\"\n", + " Unravel a batched flat state into a state with batched time axis.\n", + " \"\"\"\n", + " # This is needed because calling unravel is only defined on the per element basis.\n", + " return jax.vmap(self.state_unravel_fn)(flat_state)\n", + "\n", + " def batched_observables_ravel_fn(self, observation: Observation) -> jax.Array:\n", + " \"\"\"\n", + " Ravel observables with batched time axis into batched flat observables.\n", + " \"\"\"\n", + " # This is needed because calling ravel is only defined on the per element basis.\n", + " return jax.vmap(self.obs_ravel_fn)(observation)\n", + "\n", + " def batched_observables_unravel_fn(self, flat_observables: jax.Array) -> Observation:\n", + " \"\"\"\n", + " Unravel a batched flat observables into observables with batched time axis.\n", + " \"\"\"\n", + " # This is needed because calling unravel is only defined on the per element basis.\n", + " return jax.vmap(self.obs_unravel_fn)(flat_observables)\n", + "\n", + " def transition_fn(self, z: jax.Array, t: jax.Array, t_next) -> (tfpd.MultivariateNormalLinearOperator):\n", + " dt = t_next - t\n", + " state = self.state_unravel_fn(z)\n", + " next_state_mean = State(\n", + " x=state.x + state.v * dt,\n", + " v=state.v + (-self.c * state.v - self.k * state.x) * dt,\n", + " a=(-self.c * state.v - self.k * state.x),\n", + " M=self.M_bar * jnp.ones_like(state.M)\n", + " )\n", + " next_state_scale = State(\n", + " x=jnp.zeros_like(state.x),\n", + " v=self.sigma_f * jnp.sqrt(dt) * jnp.ones_like(state.v),\n", + " a=jnp.zeros_like(state.a),\n", + " M=jnp.zeros_like(state.M)\n", + " )\n", + " return tfpd.MultivariateNormalDiag(self.state_ravel_fn(next_state_mean), self.state_ravel_fn(next_state_scale))\n", + "\n", + " def observation_fn(self, z: jax.Array, t: jax.Array) -> tfpd.MultivariateNormalLinearOperator:\n", + " state = self.state_unravel_fn(z)\n", + " obs_mean = Observation(\n", + " x=state.x,\n", + " )\n", + " obs_scale = Observation(\n", + " x=self.sigma_x * jnp.ones_like(state.x),\n", + " )\n", + " return tfpd.MultivariateNormalDiag(self.obs_ravel_fn(obs_mean), self.obs_ravel_fn(obs_scale))\n", + "\n", + " def create_initial_state_prior(self) -> tfpd.MultivariateNormalLinearOperator:\n", + " initial_state = State(\n", + " x=jnp.zeros((n,)),\n", + " v=jnp.zeros((n,)),\n", + " a=jnp.zeros((n,)),\n", + " M=self.M_bar * jnp.ones((n,))\n", + " )\n", + " initial_state_scale = State(\n", + " x=jnp.zeros((n,)),\n", + " v=jnp.zeros((n,)),\n", + " a=jnp.zeros((n,)),\n", + " M=jnp.zeros((n,))\n", + " )\n", + " return tfpd.MultivariateNormalDiag(self.state_ravel_fn(initial_state), self.state_ravel_fn(initial_state_scale))\n", + "\n", + " def build_essm(self):\n", + " return ExtendedStateSpaceModel(\n", + " transition_fn=self.transition_fn,\n", + " observation_fn=self.observation_fn,\n", + " initial_state_prior=self.create_initial_state_prior(),\n", + " more_data_than_params=False,\n", + " dt=self.dt\n", + " )\n", + "\n", + "\n", + "model = ExcitedDampedHarmonicOscillator(\n", + " n=1,\n", + " dt=0.01,\n", + " sigma_f=0.01,\n", + " sigma_x=0.001,\n", + " zeta=0.1,\n", + " omega=1.,\n", + " m_bar=1.\n", + ")\n", + "\n", + "essm = model.build_essm()\n", + "\n", + "samples = essm.sample(key=jax.random.PRNGKey(0), num_time=1000)\n", + "\n", + "obs = model.batched_observables_unravel_fn(samples.observation)\n", + "state = model.batched_state_unravel_fn(samples.latent)\n", + "plt.plot(samples.t, state.x, label='x', c='r')\n", + "plt.plot(samples.t, state.v, label='v', c='b')\n", + "plt.plot(samples.t, state.a, label='a', c='g')\n", + "plt.plot(samples.t, state.M, label='M', c='k')\n", + "plt.plot(samples.t, obs.x, label='observed', alpha=0.5)\n", + "\n", + "filter_result = essm.forward_filter(observations=samples.observation)\n", + "\n", + "future_samples = essm.forward_simulate(key=jax.random.PRNGKey(0), num_time=100, filter_result=filter_result)\n", + "future_latent = model.batched_state_unravel_fn(future_samples.latent)\n", + "plt.plot(future_samples.t, future_latent.x, c='r')\n", + "plt.plot(future_samples.t, future_latent.v, c='b')\n", + "plt.plot(future_samples.t, future_latent.a, c='g')\n", + "plt.plot(future_samples.t, future_latent.M, c='k')\n", + "plt.legend()\n", + "plt.show()\n", + "\n", + "filter_latent = model.batched_state_unravel_fn(filter_result.filtered_mean)\n", + "plt.plot(filter_result.t, filter_latent.x, label='x', c='r')\n", + "plt.plot(filter_result.t, filter_latent.v, label='v', c='b')\n", + "plt.plot(filter_result.t, filter_latent.a, label='a', c='g')\n", + "plt.plot(filter_result.t, filter_latent.M, label='M', c='k')\n", + "\n", + "filter_state = essm.create_filter_state(filter_result=filter_result)\n", + "for _ in range(100):\n", + " filter_state = essm.incremental_predict(filter_state=filter_state)\n", + " latent = model.state_unravel_fn(filter_state.filtered_mean)\n", + " plt.scatter(filter_state.t, latent.x, c='r', s=1)\n", + " plt.scatter(filter_state.t, latent.v, c='b', s=1)\n", + " plt.scatter(filter_state.t, latent.a, c='g', s=1)\n", + " plt.scatter(filter_state.t, latent.M, c='k', s=1)\n", + "\n", + "plt.legend()\n", + "plt.show()\n", + "\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:31:00.060073207Z", + "start_time": "2024-08-14T22:30:53.098775507Z" + } + }, + "id": "5f0b000d8cfe6d2b", + "execution_count": 4 + }, + { + "cell_type": "code", + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-08-14T22:28:27.399534612Z" + } + }, + "id": "6626b28fc500ba6d", + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/online_filtering.ipynb b/docs/examples/online_filtering.ipynb new file mode 100644 index 0000000..0658db6 --- /dev/null +++ b/docs/examples/online_filtering.ipynb @@ -0,0 +1,196 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ef3eaff3797489e6", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:52.089737670Z", + "start_time": "2024-08-14T22:35:51.100189045Z" + } + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import tensorflow_probability.substrates.jax as tfp\n", + "import pylab as plt\n", + "\n", + "from essm_jax.essm import ExtendedStateSpaceModel\n", + "\n", + "tfpd = tfp.distributions\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's define a non-linear transition function that forces the state proportionally to its magnitude,\n", + "\n", + "$$p(z_{t+1} | z_t, t) = \\mathcal{N}[T(z_t, t), \\sigma_t^2]$$\n", + "\n", + "with,\n", + "\n", + "$$T(z_t, t) = z_t \\left(1 + |z_t| \\sin \\left(2 \\pi \\frac{t}{10}\\right)\\right)$$\n", + "\n", + "and \n", + "\n", + "$$\\sigma_t = 0.1$$\n", + "\n", + "Now for the observation function, let's define also a non-linear one, that takes the absolute value.\n", + "\n", + "$$p(x_t | z_t, t) = \\mathcal{N}[O(z_t, t), \\epsilon_t^2]$$\n", + "\n", + "with \n", + "\n", + "$$O(z_t, t) = |z_t|$$\n", + "\n", + "and noise that oscillates with time,\n", + "\n", + "$$\\epsilon_t = 0.01 + 0.1 \\cos\\left(2\\pi \\frac{t}{5}\\right)$$\n" + ], + "metadata": { + "collapsed": false + }, + "id": "716c02a70cfecdff" + }, + { + "cell_type": "code", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGdCAYAAAAvwBgXAAAAP3RFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMS5wb3N0MSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8kixA/AAAACXBIWXMAAA9hAAAPYQGoP6dpAACJK0lEQVR4nO29d5gkd3ntf6rz5Nk4m7VJ0iohrXIECclGEgYEmAsYuJLgChuka2S4xpKxhX1tWHyNceCHSTYIbDAYkzFJrDIornJYaVcraVebZ8PkzvX7o/v7rW9VV+yu6q6eOZ/n2Qe029NdU9PT9dZ5z3teTdd1HYQQQgghHSDR6QMghBBCyNyFhQghhBBCOgYLEUIIIYR0DBYihBBCCOkYLEQIIYQQ0jFYiBBCCCGkY7AQIYQQQkjHYCFCCCGEkI6R6vQBuFGtVrFnzx4MDAxA07ROHw4hhBBCfKDrOiYmJrBs2TIkEu6aR6wLkT179mDlypWdPgxCCCGENMGuXbuwYsUK18fEuhAZGBgAUPtGBgcHO3w0hBBCCPHD+Pg4Vq5cKa/jbsS6EBHtmMHBQRYihBBCSJfhx1ZBsyohhBBCOgYLEUIIIYR0DBYihBBCCOkYLEQIIYQQ0jFYiBBCCCGkY7AQIYQQQkjHYCFCCCGEkI7BQoQQQgghHYOFCCGEEEI6BgsRQgghhHQMFiKEEEII6RgsRAghhBDSMViIEEIIceWp3WP4l3t2oFLVO30oZBYS6+27hBBCOs9f//czuH/HYRy/ZAAXHbuo04dDZhlURAghhLgyPlMGABycKHT4SMhshIUIIYQQV8rVKgBgfKbU4SMhsxEWIoQQQlwpV2rekPF8ucNHQmYjLEQIIYS4UqIiQiKEhQghhBBXKlIRYSFCwoeFCCGEEFdK9bFdYVolJExYiBBCCHGlXKm1ZiYKVERI+LAQIYQQ4oo0q1IRIRHAQoQQQogr0qxKjwiJABYihBBCXDEUERYiJHxYiBBCCHFE13WUq0aOiK5z3wwJFxYihBBCHFEX3VWqOqaLlQ4eDZmNsBAhhBDiSNmycZc+ERI2LEQIIYQ4UqqP7go4OUPChoUIIYQQR4RRVUBFhIQNCxFCCCGOiNFdASdnSNiwECGEEOIIFRESNSxECCGEONJQiNAjQkKGhQghhBBHymzNkIhhIUIIIcQRju+SqGEhQgghxBGO75KoYSFCCCHEEZpVSdSwECGEEOKI1SMykaciQsKFhQghhBBHSlRESMSwECGEEOJI4/guCxESLixECCGEOCJaM5lk7XIxztYMCRkWIoQQQhwRisj8vgyAmiKi67rblxASCBYihBBCHBGKiChEylUdM6VKJw+JzDJYiBBCCHFEmFUHcimkEhoAZomQcGEhQgghxBGhiKSTCQz2pAFwcoaECwsRQgghjghFJJXUMJhLAeDkDAmXSAuRTZs24ayzzsLAwAAWL16Mq666Cs8991yUL0kIISREhFk1laAiQqIh0kLkrrvuwvXXX4/7778ft912G0qlEn77t38bU1NTUb4sIYTMevaP5zE2HX1BUJGtGQ2DuXohQo8ICZFUlE/+85//3PTft956KxYvXowtW7bg1a9+dZQvTQghs5apQhmX/t1dWDyYxe0fuTjS1zJaMwkM9tRbM1RESIhEWohYGRsbAwDMnz/f9t8LhQIKhYL87/Hx8bYcFyGEdBMHJgqYLJSRPxT9GK00qyY09GWEIsJChIRH28yq1WoVN954Iy644AKcfPLJto/ZtGkThoaG5J+VK1e26/AIIaRryNdzPMpVPfJwMaGIJBOa4hFha4aER9sKkeuvvx5PPfUUvvWtbzk+5uabb8bY2Jj8s2vXrnYdHiGEdA15JVCsGnHIaVltzXBqhkRAW1ozN9xwA37yk5/g7rvvxooVKxwfl81mkc1m23FIhBDSteRLVfn/y9UqkolkZK9VVsyqAzlOzZDwibQQ0XUd//t//298//vfx5133ok1a9ZE+XKEEDInyJcNRaQSsSRSMo3v1i4ZE2zNkBCJtBC5/vrr8c1vfhM//OEPMTAwgH379gEAhoaG0NPTE+VLE0LIrKWgtGbKERci9uO7VERIeETqEfn85z+PsbExXHzxxVi6dKn88+1vfzvKlyWEkFmN2pqpVNqkiCRpViXREHlrhhBCSLjMtFERER6RVCJBRYREAnfNEEJIl6FOzUTtETEi3jVToBlvNElYsBAhhJAuwzo1EyWmZNW6IlKq6KZjIKQVWIgQQkiX0VZFRDGr9maSSCY0ABzhJeHBQoQQQroMdXw3co+I0prRNI2hZiR0WIgQQkiXUVCnZtplVk3WLhfG5AwLERIOLEQIIaTLUFsz5YjHd8Xzp5O1lowxOcMRXhIOLEQIIaTLaKdHpFQ1klUBmCZnCAkDFiKEENJltHNqplwRrRmrIsJChIQDCxFCCOkyOmNWrSsiOaarknBhIUIIIV3GTLF9HpFS1aKI9HBqhoQLCxFCCOky8uU2Ts04mVXpESEhwUKEEEK6DPP23Yg9Ig1mVU7NkHBhIUIIIV1Ge3fNmFszAzlOzZBwYSFCCCFdhnlqJupAM9GaoVmVRAMLEUII6TLUqZnIc0TqiojYMSNaMxM0q5r43iOvYNNPn+VW4iZgIUIIIV2GKVm1XWZVBpq58smfPosv3r0DLxyc7PShdB0sRAghpIvQdd3UmqlEblZ1CjQr8+6/jq7rODJdK8ymChWPRxMrLEQIIaSLKJTNhUfkOSLW8d16a6ZYqTYcy1xlplSRLTKek+CwECGEkC5C3bwLRO8RqVjGd/sySdTtIgw1qzOhGHcLZSoiQWEhQgghXcRMyXyhi9ojUrKM72qaZmSJ0CcCAJhQzkORikhgWIgQQkgXkbcUIpHniFjGdwHDJzLGUDMA5lFmtmaCw0KEEEK6iHy5fYqIruuy0BHjuwAnZ6ywNdMaLEQIIaSLyDd4RKK7Ay8pRlgxvguokzMsRABza8bq4SHesBAhhJAuwtqaKUU4NaPusREeEYDpqlYm2JppCRYihBDSRbTTI6IWOaZCRLRmqIgAsCgibM0EhoUIIYR0EdbWTJQeEbXIsW3N0CMCwKyIcGomOCxECCGki7DecUfpERGbdxMakDCZVY10VcLWTKuwECGEkC7C2pqJUhEpiTCzpPlSMZjj1IzKuKk1E6wQyZcqcz4qn4UIIYR0EQ1TM1GaVUWYmaKGAKoiwkIEsCgiJf8ekX1jeZz517/CR77zeBSH1TWwECGEkC6incmqwqzaUIjUPSITnJoBYDWr+ldEHn75MCYLZWx5+UgUh9U1sBAhhJAuop1TM2J8N21tzdQVkaPTxcheu5to1iOy8/A0AGCm2PqkzefvfAGf3bwt0NfkSxUcnio2vKfaDQsRQgjpIto5NSM2+6qjuwCweCALANg/Xpjz/gag+amZXSEVIvlSBf/vF1vxd7c9b1JnvPjF0/tw+l/dhvfe+lBLr98qLEQIIaSLEHevolsS6dSMZfOuYMlQDkCtTTRGn0jTOSJSEWlRkSiUqhD1YJCiRjy2J51s6fVbhYUIIYR0EeJC15epTa5Eq4iI1oxZEcmlk1jQlwEA7B3LR/b6QXhk5xH8f7dvk8fcLnRdb7k1U67qcstxMxQqRvFhVczcmBaFSIaFCCGEEJ+IC01ftlaItCNZ1Tq+CxiqyN6xmchePwibfvosPv3L5/HAi4fb+rr5UtVUDPrdNVOqVLHnqFHETbfQnlFfM4i6Ih7by0KEEEKIX0Rrpj/XBkWkaj++CwBLh3oAwHQx7SRClWh3q8jqyfDbmtl7NG8qIlsxjBYVNSXI84jWTG9dXesULEQIIaSLEBcaqYhEmiNib1YFgKV1RWRfTFoz4mLc7l0v1sV/flszoi0jaMWwqioiQQoRocLk6BEhhBDiF9Ga6c/WLh7R5ogIRaTxUrF0uFaI7IlJa0ZMqwTxSISBVRHxOzVjLURaac2YFJEAHpWZUq2IYmuGEEKIb/INZtXop2asZlUAWFZvzeyNSWtGFE3tzsSwhro1rYi0cNxqmmsziggLEUIIIb4REr7wiEQbaGY/vgsYZtV943EpRGrH2m5FZLJQK0QG6j8Pv62hXZZCpJMeEU7NEEII8Y244+6ve0TK7dg146KI7Dk6E4tQs1K5U4pIrTWzqL8W8uZ3aibM1oz6mn5fHzBUGOaIEEII8U2DWbUNyarWiHcAGBmqX3jLVRyd7nyomVAF8m02q4rWzMJ+43z4QRQi8+t5LK20ZlRFJMjzsDVDCCEkMHJ8Nxu9R6RUf+6kzfhuNpXEwv7aRTQOhlXhEQmiCISBmJpZOFA7F8VKFVWP4nBsuiTHjI8b6QcA5FtRRMqteUR6OL5LCCHELzLQrH4X2x5FpLEQAYwskU4bVsuVKsRpaPf4rmjNCEUEMCsUduw6Mi2/ZkFf7eumi81vMlYndYJ4ZPIMNCOEEBIEXdeNqZls9IFmbuO7gJKu2mHDaqmiBoO1e3zX3JoBvNszoi2zan6PzPCYaeG4TYVIgEJMFD/0iBBCCPFFsWIsN+tvh0ek6hxoBgDLRCFytLOtmWanRsJAKCLz+jLQ6qfJS5UxCpFeqUa0NL5bbu77564ZQgghgVDv9ptVREYnC76nXESRk3ZUROqtmQ6nq5Y6WojUVIXBXArZVO08eflU1EJEFAEzLbRmCmzNEEIIaQciuErTjLvYIIrI7Vv348y//hX++c4XfD2+5DK+CwDLhuOx+M5ciHSmNTOQSyGbqv1MvFozIkNk5fxepTXTXkWkVKnKllZvmmZVQgghPhAX2Z50Ui6iCzI189y+SQDA03vGfD3ebXwXUMyqnVZEyopHpENm1YFc2lBEmmnNFEPyiPgsRNTcklyms6UACxFCCOkSxF1zLp2UBtIgS++EcmCNJXd8vMv4LmAsvts7lu9oqFkxLopIuvYzcds3U65UsftITUFataBXGkVbaSk1M74rUlWTCQ0Zh0KzXbAQIYSQLkFcZHKphCwOgnhERFLqVMFfIeK2fRcARgZrhUixXMXhqaLv4wgb9cLf/vFdUYik5QXdrTWzdyyPclVHJpnAyEBOFiLtHt8VRW1vOglNs//5tgsWIoQQ0iXkVUWkXhwE8YiU6o+d9F2I1C5qTmbVTCohx1Y72Z5RPSLtDDTLlypSjfHrERH+kBXze5BIaIZZNSyPiM9CTBQ+uQ4bVQEWIoQQ0jWIFe9Zk0ekGUXEp6HRY3wXUA2r8ShE2jk1I9QQTQP6M0ZrpuByDKo/BDAyPELLEfGriMQk3h1gIUIIIV2DoYgkpEek7JHiqSKmJITB0ouKh1kVUH0i3pMzY9MlfHfLK74VGb90KkdEnMf+TAqJhKaYVZ1/Jg2FSCjju8b37FYEqcgMkQ6HmQEsRAghpGswPCJJJJPBFRGhHEwVK77MpcKsmnIwqwLG5MweHzHv/3LvDnzkO4/jmw+87OdwfWNKVvW5dC4MVKMqANmacTOrOhYirSy9U17P7/PMxCRDBGAhQgghXYPwP9QUkeAeEWE+rVR1XxK+YVb1VkT2+VBExLTIoZCNrSXlQlyp6qZWTZSIQqRfFiLeioiaIQIorZkWxnebyRGZiUmqKsBChBBCugZhRMylk6apGb+jsyUlc8RPe6TsQxER+2b2+PCIHK1vnA0ycuwH65I5r0CxsFAzRICaebf2+k14RNo8NWO0ZjobZgawECGEkK5B3O2qgWYA4FcUKSsFgJ9CpOQxvgsAy4ZrrZl9fgqR6ZoSEvaiPqsC0i6fiFNrxqkQGs+XcGS6VrwIRUTdNdNsFovJI1P29zxiaoatGUIIIb4Rd7tZRREB/KerqhdsP1kiXuO7gNqayaPqUWAIRSTs1onVk9GuQmTcoogYUzP2359oyyzoy8ilhWJ8tqo3Kjt+UV9P9/k8cdkzA7AQIYSQrmHGZmoG8O8TUU2dftJVvbbvArVQM02rXfy8vB9jdTWgHHJrpmR5vnalqzYqIu6tGas/BDBPrcwUmyugrIWHn+9ftGZys31q5u6778Yb3vAGLFu2DJqm4Qc/+EGUL0cIaSNfuvsFXPy3d+CxXUc7fShzBjXQzKyI+Luwq8qJP0XE26yaTiawqB5q5tae0XXdUEQC7MfxQ9xaM05TM1Z/CFA7f6LN1uzkjHVk188I7/RcyRGZmprCqaeeis997nNRvgwhpAP89Ml9eOnQNP7X1x6Sd3okWsSdbi5l9oj4NX8G9YiIwiXtYlYFjPbMHpfJmclCWSo34SsiVrNquwqRWmE1KFozHlMzdoUIoGaJhKOI+Clo4tSaidQue8UVV+CKK66I8iUIIR1CmN1GJ4t4760P4bsfPF9+IJNoKCitmURCg6bVPAF+FRH1gh3MrOp+z7p0qAePvzLmqogcnTZC1IJsDPaD9cLfrtaMOIdCEfGamtl5uL7szlqIpJOYyJdNG3GDYPWkBGnN9GQ4NWOiUChgfHzc9IcQEk/EB1kmlcC2A5O4/huPtC2/Ya6iju8CCJwlErQQ8TO+CwBLh70VkbEZoxCxejpapXOKiJNHxP734MB4rVATI88CoYg021Iq1L9/sbvOz/MwWdWBTZs2YWhoSP5ZuXJlpw+JEOKA+CD7f299FXozSdyzbRS3/PCpUNbB67r/bIy5RF4JNAOgZIn4KwBV5SSYR8Rfa2avS7qqWogECWHzQ6NHpM05IlkxNVMf33V4ffE705c1qxDGvpnghYiu69KTIhRJP4XITInju7bcfPPNGBsbk3927drV6UMihDggWjNnHDMPn33nRiQ04D8e3IWv/Pqllp/76q8+hCv/6d5Ae1TmAqpZFYCcnIlqakZc4FMu47uAEfPutzUTtnLWODXTaUXE/vWdsjuEItJMa0b1hwz21I7DT8w9k1UdyGazGBwcNP0hhMQPNSK8L5vCpSeM4E+vPAEA8C/37GjpucuVKu5+/iCe3TuOfeOd2+gaR8QFVkxnJANu4C0HzRGpiqV3rZtVj84Yo71uZlVd1/HU7rFA7ZXGHJH2FLDjshAxm1WdcjzE1uO+jL0i0kwBpbaBhnr8KyJszRBCuhpVQhZ3d288bRkAYP94viUlY1p57mbNe7MVcYEVd7GiQPCriKgFiy+PiF+zaj1ddf+4c6iZX7Pqjx7fg9/57L34x19t8zw+QefGd0WgmSVZ1aYQqlZ1Y9Fc1nzx721hakYtwkRImr/WzByZmpmcnMT27dvlf7/44ot47LHHMH/+fKxatSrKlyaERMh0/SKW0Iy7wIV9WaQSGspVHQcnC1KuD4r6YRz2uvhuR5pVU2aPiN9Wh3rRCtOsunggi4RWa5GMThWweCDX8Bi/ZtVX6ovxtrx8xPP4jOezFCJtMKsWy1WpRvgZ37Ur3gWi1dZM4S1eK5NKoDfjvxCZniutmYcffhgbN27Exo0bAQAf/vCHsXHjRtxyyy1RviwhXc+2/RP46q9fjO0UihGGlIJWt+onEhpGBo2472ZRWwbTBSoiKvliax6RctCldxXRmnG/VKSTCSwaqIWaORlWxZ4Z63FYEe/5Fw5OeR6foBOtGaGGAHbbdxvft+J3RtNqOTAqrZhVxfeeTSakidnP959Xfoc7TaRHcPHFF9P5TkgT/N+fPIN7to3imAW9eO2GkU4fTgNOqYwjg1nsPjrTUiEyTUXEEWFCFIVIcI9IsKkZURQkPRQRoGZY3T9ewN6xPE61GXg0tWZcFBHxb6OTBYznS76yaYTCkkkmUKxU2zK+K4yqfRkj5VbumrFRRIRRtSedRMJyPntbGN8V32s2nZAFjtfz6LouW6D0iBBCbBEX8tEJ990dncLJ/S/yEVoxmc6YPCIsRFTySqAZ0FqOyJQPtcmvWRVQl9/ZG1bNrRkXRURRS3b4VEWEOVR4NZzGZ8NkwmJUBdw9ItMuCkSulakZ0ZpJJuT4sJciUqxU5Xtm1rdmCCHNIXZyjCvyb5xw+lBdMlgf42yhEFHv1P3ctc8VdF1vGN+ViojPgDDz+K73e8uvWRWAbM0cnCzY/rtaiLgpOKWy8W87Dk56vi5gFDaiEGmHWdVqVAXcp2acincgpNZMOmm0ZjwUIdWHFQezKgsRQmKGruuynx7X1oSzIuK9/MwL9UNyilMzklJFh7h+5yzju814RKaKFc/WecnnrhmgZlgFgIMT9oWI79ZME4qIUYj4H19tlXFLhgigKiLOHhG7C79szbRiVk0mZIHq9f2Lgied1Dz9P+2g80dACDExXazIO9dJH6FTncDJcR+GWVWVp6mIGKh3ucKLIBJP/SSr6rpuUkTULBg7qlUdok4JpIg4FSIzfs2qiiIy6lMRqasohiLSPrOq2prJuEzNiFaYmyLSSmsmiEdk2mJ67jQsRAiJGUeU6YK4KiJCqbAGMy2pFyL7W2jNqDkifnwMcwVxcdGUkelkgKkZu8e4vb9Ur4ZXxDvg3prJlyqm4sDdrBpcESlYWzNtNKvatWbKVb0hS0eoiNZ4d8AoCJppzQizaiaZQE/G39TMjIs60wlYiBASM1QJeyKmhciMQ2tGRn2P55uemJumR8QWYYDMpZJyZDoVYGpGfYzotLgVImqxkPaIeAeARf21ItROEVH9IYC7WVU9zh2jU76KrFK5/a0ZW7Nq2jhPVp+IW5Kp8Fo1V4goiojP1oybcbYTsBAhJGaohUhcWzNSZrYkRC4erN0V50vVhouPX0ytGU7NSKwTM0Awj4h6YRzuzQBwL/TUQsTP+K5QREYniw3pqup7GvAwqyrHWSxXseeoc2y89Wvk1IyPXSutIlozg4oiklFaWNbJGTdFRCgZzSSrmjwiPlszMzEa3QVYiBASO9ReelxbM0Y8tPlDNZdOYl5v7Q6x2cmZmRI9InYYm3eNi0cgRUQpLIbrO0ncFt+prRk/47sL+mvFTaWqm9qLgBFmJlsXLq0Zq1rygo/JmUazajvHd43fgVQyIX8mToqIXTukldaM9IikklKR8W7N1DNN2JohhNjRHYqI8yjiEh+bWN1Qs0PoETGYKTUaDA1FxPvCKzwLCc24ePpRRJIJTbaC3EgnE5jfVytGrD4RMY6+sL+mmpRcjtdapPjxiQiD66DMEWlDa6bQaFYFlHTVBkXEbWqm3pppMeJdtmY8PDJux9IJWIgQEjOOdoFZ1c3stmSwtRFeNdadrRkDY/Ou8bGdCpAjUpLhZAkZSe5qVq342zOjsqjffnJGtOkW1ts3uu7cThLHKYzPfiZnrIFm7fWImFXBjEPMuzHybtOaCUERUQsRr4KGrRlCiCsms2qXBZoBraercnzXHmuYGRBsakYoIulkQk47uRUiFaVw8YvTCO9Y/T29qN6+AZxHeMVxHrdkAIA/RaRoNau2wSMickT6LZ4PmSViOYZpH+O7zSkiRoEqnsfLIzPjMH7fKViIEBIzjqitmUI5lvuaplxSIkdaHOE1je8y0Exi7JmxUUT8TJYIhSOp+VJE5OZdH/4QgVMhInxPojUDOKs44u83BChEOpusamnNpJ0UEWHwtjOrGopI0N93syIiPCJszRBCWmBMMatW9ebk2qhxa82InSN7m2zNzBQ5vmuHnSIiA818bGkWPopUIiHv4t3Or/p4vzgWItNmj0jtmJ1aM3VFZKRWiOwbz3u2KK1m1fZMzdi3Zpw8IrJ4t2mHqMpEUKNtQTGrBh3f7UlzfJcQYsMRy6hjHA2rUy6tmVbTVVWD6nSx0jAKOlcRBkx1hXwzUzPppCYLEbepGblnphmPiINZdYHSmnEyrIqiYkFfBgvrj3/RRRVRE2NFUVCp6q5ZJWFgjO9azar19ojl9UXx3pd1bs0AwW88TIqIHN91/97zJSoihHQlR6aKbVlCd9Qy+jgew0LEKdAMMDwizbZmrB/E0zFUhDqBuLiod89BPCJyb0wyIbMsXBWREFszwiMy3Jv2XNRnLNrTsHZhPwB3w6oaCa+qE1G2Z0qVqvx5+FdEhC+jsXhPJjRpcm22EMmqrZmye4tnmuO7ZK5RLFdj6XMIQqFcwaWfuQtX/uM9kd+hW8Of4jg546aILK1v4D0yXWrqYjBtmZSZjuH33wnsAs0CeUTKRmEx4McjUgnPrCo8IsM9GXnMToqFMa2TwNpFfQCAF1wUEfV5BrKGOhFlloiqUvYHnJrpc7j4G4bVYO931ayarT+HrttvADaOhVMzZA4xNlPC+Z/ajOu+/nCnD6UlRieLODxVxCtHZnBoquj9BU2i67qUsYfrwWBxbM24eUQGe1LyYtmMKmJd/BXHQqwT5OUFR1FEkv6TVUWxkk74m5opNzO+67BvRhTXQ71pWdg4FU/i7zMpTRYiO1xCzdRCJJNKSEUiSkVEtLR60smGQi3rsPjOaVGkwChEghVQouBQp2YAIO/yPNw1Q+YUz++fwOhkEXc+dzDynm2UqBL2bh+R080yUSjLi8qKeTVlYbIQrxFeXdcN451Nv1vTNJkBEdQnouu6/MAW179mNpLORlpNVg0+NSNaJAEUkbpH5Oh0yaQIiByR4Z60p8G2rJhkZWvGRRERrYmEVmtx5OQIa3Tvm3E5MdOoCDqP7wpFxN4g2qtMzgRBtIAyqQTSSU3+3riFmskcERYiZC5wuK4elKs6Xj403eGjaR71A9vP7otmEb30XDohJwzcDIWdoFCuyvXwTkuzms0SKVaqshBbUP/+qYjUmHHdNeMnWdUoLPxMzZSlp8S/IjLUk5aPPzRZ/92vVOV7eLg3I6dwSk5TM0rBJBSRF0enHFuiRSUfBYAywhrdjY/TxAxgjO8WlUJE13XpdbIr3oHmY94NRaS2DNHP5AxbM2ROcVhpY/jZGRFXptpUiIgdHfN6M/JiEbcLsXounD7ImlVE1EAnUYhZPSNzFdvx3UDbd8VyNGNqxq3tV2piaiaR0OTPTfhEVLP1YC4lCxXHZFWlsFg5vxfppIaZUgV7HYpacZzCm2EoElG2ZuwzRGqv3+gRyZe8i3eZJRLUI6IoIgCUQsRPa4bju2QOoBYi2w/MjkLklSPRFSKyl96TNgyFMVNExN1ULp1w3Mo60qQiIp47ndQw1CMKMbZmAOOCk0upikh9asZHxHtRaXn0+Shyy03kiACNhlUxBTaQTdWWwtULEafxXXVsOJ1MYNX8XgDOPhFRuGTiooiIQkh5fbWYdirem27NWL9/Hx6Z6RKnZsgcYvYoIsYv9VxXRMQHpVOvGzAUkaBmVXUfh/j+OTVTo2VFRGl5qFMzThNtzSSrAsBii2FVmK+H6ubrdMJ9A686ZgwAaxe5+0REC8RozfgL9WoFpwwRQJ2aUQsR7+I916RZVeTLiJaQn+9fvAbNqmROcMRUiHhHNccVdflalGbVMWVipr8+ijgRswuxUIfc7qaaTVdVo6f93LXPJYT50H77bpBAM0MRqerOykGpEtysCjQqImqGiHrMXmZVoxBxn5yRrZxU7Xn9hnq1wthM7T052GOniDS2ZsT72q14F0pJ0FZk0aKIZH14TUT7hx4RMidQR113HJjs2jyRdplVj0yJD+2MMdkQs9aMn9E/uW+myUKkJ5OU/WtOzdRoeWpGMZ/2ppPQ6jfmEw5TWXJJXgCPCNC4gVfNEAGMwqZkc8y6rivTOrXXXSdDzfwpIlmf+1Za4fBU7Xub35dp+De7qZkpHwFi4vcp6HGLFpAoQHo8WlMm4ywVETIXOKKkhE4Uyg1BR92C6hE5Ml2KzEApP7R70xiIqSLgFmYmEFMzByYKvu7WBTPKnWN/fbqA+2Zq2AWaGTkiwaZmEglN3p1POXhwrAWBXxo9IpbWjMv4rlpQiRaOoYjYFyLSrGptzURoVhU3WPP7sg3/Zjc1M+NDEWl1asb6/TuZddWpN3pEyJzgsCX8a3uX+kSsH9ZRqSLiQ3u4J91RReQXT+/D3/x8q+3I5LRLvLtgUX8WCa12YTk06b/4VO8ce32Ebs0lWvWIlCwKh9fkjOEpabI1M2kuRIZ70qZjthvfVbOGRAEkPCK7j86YpqqsX9PoEYmuNSM+1xbYKiKNHhE/7Uzxb0EVwKAeEfUcsjVD5gTiF3adj6jmOGO9K999tLk9Kl4ctTGrhu0RqVZ1T5XiL3/0ND5/5wt4cvdYw79N+1BEUsmEvCAFmZxRPSLSrMrWDAD71kygXTMWz4dYvuZU6BlJrM0pIgcmaj934XsaEoWITFZtLBTU4kQUIvP7MvJrdx1pzCKSioAc37WPWA+Tw1IRcWnNlGw8Ig4ZIoCxlTdoa6ZREXFvzYi2TCaZCFxkRkU8joLMSvKlivwFPHvNfADAC106wjtlacXsjmiE94giYxvpl+Emq/7ev9yP3/rMXY4f1MVyVWY2HJpqVDOmfcZDN5MlovpPej0ulHONQrmxNZPyWCCnIj0f9Qt8f33iw+n8iqIgGXR8t7/2cz84UYCu6yYDtvr6dsWT2q5JK68rpnzs2nQly/fldSEOI+H5kGsh4jw105N2MavKHBH/hYiu68bSO6GIpLwUkXiN7gIsREiEiLuGdFLDqSuGAXTvCK/IshALq4K0ZpymA+wQH9rzejMYzHmvag9KsVzF/TsOY8foFF50MP/tH8/LHrJ1AR+gRFW73N0BhmG1OUVEGd+NONDs8FQRP31yr6mnH0fEBSqn7JqRcek+PCLCHCpaGF4eHGvh4peFA7WLc75UxWShLFU+aVZ1SVYVKkwyoSGhKDFu7ZaG1oyNIiH45zu346RbfoEtLx8J9D2p6LoupwEX9DcWInZL7+TCO5ffmVw6eGumXNUh6rlssvb1WY/WVNxGdwEWIiRCRCEyrzeD9Yu9d0bEGfFhvX5kAID/QuSRnUdw8l/8Av9yzw5fjxfmXnV8dzLvnPUQFGGGBZxVHfV7sy1ESt53d4AxwhtEEVH9J4ZHJNrWzN/+4jl88BuP4MeP74n0dVolX3aemvHXmjG22gKGcdKp9Vdq0qyqFpEHJwqNOSIuZlXjGM2v2ePSthAFZINZ1eax971wCMVKFT98bHeg70llfKYsC6agiohbO7OZQDO1eBaKSI+H6XU6ZqO7AAsREiFqH3WdYjjrxshuUYgcVy+oXvFZiNz3wiHkS1XcvW3U87HVqlnGFq2ZclVvWKDVLGph4ZSHsmfM+HtxPCq+FZEm0lXV8V3x/FEHmr1S9x1si3HbsKTs4DHvmnHfZKvS2Jpx3zdTthQuQVAnZ8YazKrO47vWDBFBzmUkt2j5GrfWjHj/3+vj99EJ0a7sz6ZMm5AFQpEo2ozvuqkQbsWWE+rnQqNHxKEQidnCO4CFCAnAeL6EXYf9L65TU0Ln9WXk3UM3qiLig+T4JcEUETHC6GdyZCJfli2R4Z6MOeshpPaMqRBxVESMwsG2EPFYZy5oxiMi7xzTKWW8NNpCRHyPbj/TyUIZP3liT8eKaPWi0rwiYlY4vKZmKrKVE0wRAZQskUlDERnurf3+J/0oIpbXdBvJLYkcEeuuFZvHCkVwx+hU08GEbkZVwF4RMcZ3XVozTUzNiGInpbSyvMZ3/eQAtRsWIsQ37//6w3jt392Jlw/5KyTE9s359T7q+roq0o0+ETG+e2y9NbNvLO/rw1+MMIpz4YYo3PoySWRStayH/pBHWI8quS5Oqs5uU2um8bhlGJKHtNvMBt4ZpZcu0j+nIp6aEcXZ3jHnC9OX796BG775KH7/37YE8vyEhXp3nzXtmglgVrVEp3utEGg2WRUwFJH94wXDIyIj3p2PueSw3ybrkpZq3TUjFImCzWPHlEL8102qIm5G1dqxNnpExOdHj1trpokcEfEa6nvCy6w7U/Q+lnbDQoT4Ytfhady/4zBKFR2P7Trq62vEhXV+/U5o3eL6CG+MJXAnxIf12oV9SCU0lKu6HE90QyoiUwVPn4f1zhFA6Fki/hQRpRBxac30Zt0/yFRFxK/HZcqmNTPlsg8lDMSFco/LSPbz+ycAAPdsG8Xf/HxrZMfihBpmpmmGWhBEEbG2Pbwi9GXhEnB8FzAKkZdGp6SZ0jq+a7f0rlw1t48Ebu0GWYjUI97FRdmqiFSqumkT8D3bmytE3DJEaq/fWAjNlLzbmUJhzDehiGRMhYh7i8fweMXn8h+fIyGx5rZn9sv///Ihf+0Z653DOqmIdFdrplSpyl/4wVxa3un7GeEdrRcipYqO8Rn3YuKI5c4RgJIlEs4Ir8ms6qCI7FUuyLZmVb/ju/XzNF2s+M5CUWVjcaEsV3WZlRA26sVp37izyqVuXP7yPS/iB482b3ZshoLNnhlAUUR8TM0ULUZQt5FYoPnxXcAoRMTG7WwqIY/dMKs6KyJWj4ibAbPRI2J/IRaL6gS/2T5qG9jnhVdrxm7pnVREXFREuWumCY+I6lXxO77rZpxtNyxEiC/UQuQln62ZI46FSHcpIuoHdW82iWXDPQD8Lb9TI+1HbTI5VKzLwYBoFZGDEwXbDytVERl38Yi4xVUDtQ86cbHzu3PGcPSnTM/vFEPeKurFqVLVHVcQCEPrb504AgD4k+8+gadswt6iQoaZWcyRKZdMDitla6CZR9uv7ODX8IPwiAgDsPqeTrkYbJ1eM9j4rn1rQrz3c+kEejNJHJoq4tl940G+LQCNLWcr9kvvhCISbo5IwUYRyXoFmvn0eLUTFiLEkyNTRTz40mH5334VEeudgyhEdoxOBdo/0mnEB3UmlUA6mcDyeiHiJuUDtQ8UVQnw8okYiojSmgl538wRi8Jh3Y47ni+Zjtm2NRMgEGlpQJ+ImkCZTGhSko/KsGpVfPbY+ESmCmV53j79u6fikuMXoVCu4v1ffxijigm5UK5gPB+OcmXFbs8MEHBqxtL2MALzPJJVmylE6oqIOD8iQwRQsk9cds2kLSqM+L7tskGsS++cFBHxXp7fm8G5axcAAH7dRHtGLLxzbM0ou2ZES9HPxb9HmkyrvpUaO4+I1/jujE+PVzthIRIy5UoVd2w9EOnmx6AUypWmJEjB7VsPoFLVpePbr1nVWogsn9eDTCqBYrkaWTJpFIgPEVEULJeKiHtBNmqZlLH+txXrTg7AkM/DKkTGZszFkPXnINQQYQs4Ol1seO/4VUQAI9TMWvA4YW379EvDalRLBi2FiI3KJZSvwVwKQ71p/MM7NmLtwj7sGcvj8n+4B+dt2owT/vznOP7Pfo5X/cUvceuvXwz9OGdK9q2ZZqZmGsyqDmqbNXckCKIQEQwpikg66Vw8FT0VERePSNI6NWIudGTUfG8GF6xfCKDm+QmK28I7wGiTVHXje/TzO6MWKX4Nq814RGaoiMx+/vnOF3DtrQ/hr37yTKcPBUDtTXfFP9yDK//pnqaLEdGWeftZqwAAo5PFhn6rHdKsWi9EkgkNaxeKnTPd056ZtORmLPOpiBywyPxeI7zqnhmB9IhE0JoBGosp4Q9ZXf85VXVg0lIEBJF2lw2Jc+Wv8LRGYfd6bIhtFet48l6bn6loy6yY1wugZrr80v88AwPZFEYnC9g7ljddOB56qfnUTsF4vmTKochbVr0LkgGW3pUtHhFZ5DkUuS2N71oKEbW4Npbe2SgiDpM6bptpDbOqe46G+P0a6knhomNrhchDLx0OfNPobVY1jl0UQ34WRaptN7+FiOERaSxEnLKH2JqZ5VSqOr714E4AwHe2vBJo62hU/PKZfdgxOoWt+yZw2GYU04t8qYK7nj8IAHjL6cuxsN4X9WrPVKu6lLNVU1c3+kTEB7W4m1k+r66IeKg6Vr/BqEdr5uiMjUck674PJCjiZ7LC4XsQd/9rF/bLD/QxS/HiJ65asGpB7eK902f+zIzlA7vP42LZKtbxZLvWjDCqinMGAOsXD+CnH7oIX3/v2fjRDRfgno9ego+/4UQAwZeWWXn50BTO++RmXP2VB6W0L1szKfNHdiuKSJ/HUsVWxncX9GWgDPfIiRn1mO3MqjJ0LeHfI1Ism78vObViuRALv9NwTwbHLu7H4oEs8qUqHgkY9+43RwQwWknTBW+Dd0JpRfr1idgrIu6BZtIQztbM7OQ3L4xiT12CLpar+I96UdJJVHf//gB5DoJfbx/FTKmCZUM5nLRsEMcsqN0pexUi4/mS/HBU7/DXLe7iQkS2ZmrtBq+7/IOW8V67BXIqR6ajH98dq194T1o2CKAxS0R8T8uHc/LioaoGxXJVXqB6PSLeAWDl/Foh8sphb0VE13Ujo6Re5PTJkKdoChGrImLbmqkXIsuVQgSofW+vPm4RXrViGCvn92JB3aAZJAfCjm8/tAtTxQru23EIm589AED1iDQ/NWMNC1OnZuzGo8VzWuPW/ZBKJkyKgcms6rZ917IPR+BnfLcxWdWqiBhbgDVNw4WiPRPAJ6LrumeOiKZpMtOkUPeJiPe1m1kV8PZ3WCnaTc14tWbE7xinZmYn/7XlFQDAyvm1D6yv3/dyRxdpjU4WTNHiB8aDKzS/fLrWlvmtE0egaRqOqd/hek3OiF/WgWzKVK2vW1QrZLZ3UZaIXHhX/xARrZmJQtnVnCgUEfFBPjrhroiMyeVgikdEtmbCGt+tPc/Jy4YAOHtElg73SIOh2s5R79T8SLsr6xdvP4pIvlSVybLiQ9LIuoimNSO+t4X1IsLOy2IoIr2uzxX0ImJHparj+8rNw9//6nnouq7smbEqIrX/rvgKNDOHhYlzW9Xtj7lsSWINijingLm4Fs9nu/TOySMiRlJtPk9lIWJNVi1VTAWWdQvwhfX2TBDD6mShLD/T7RbeCdR01ULZiOf3+p0R73u/iojd1EzOJfwNMIr6HFszs4+xmRJ+/tQ+AMA/vP00LB7I4sBEAf/9ZOcWaf3k8T0myTaoIlKp6ti8tVaI/PZJSwAAq6Ui4l6IyNFdyy9rN2aJiF9csa20N5PCvPqHmVt7RqSqiu/ZryIyr89mfDeE1kShXJH94ZOW1xQR6wiyUPSWDfdIg6GaPTJdD2ZKJzXTh58Tq+qKyL7xvGfLQlU9xEVd7puJWBE5YamI7nfziPQ0/JuKLERaSIK9f8ch7B3LYyCXQl8miaf3jOOXz+yXEr81h0IoInZ7W6yULcFf6goBu/dXK2ZVwOwTUVszYiLGdmrGcdeMc9iXsfSublZVzKJqsWNdvicMq0/uHpOfV16ItkxtBNhZUVAnZ9T3g1c7RLZmfCsitceZCpGM8Rx2ShdbM7OY/35iLwrlKo4b6cfpq+bh6vNXAwD+9d4XI02FdOP7j9WKIPHm3h9QEXl05xGMThYxmEvh7DXzAUBRRNzvcA9NNRovAWBtXRE5PFWUv9RxZ9LiEQFUw6pLIVJXRDbUL3Je47uGmc44Z0I+D8OsKrweCQ3YsKRWiFij6tXWjFBmVEXEMJP6+xCb35eR7RWv3BXx3NlUQl5gvbIuWkV8bycurZ2P0clCw44OO4+IHT0Z9968H777SE1VfeOpy3DNBasBAP/wq23y4mFtzaQD5IgULfHpiYRmnF+b91cr47uAuRAxt2aciydr6JpAnlub/SlFS2smq6hG6uPV1gxQm+g6bqQfug7ct+OQr+/pkDSq2k/MCAyfSkVOfGVSCU+/TdAsETezqvrvKn4DCdsJC5GQ+M6WXQCAt52xEpqm4Z1nr0I2lcBTu8fxcEAzlBOP7TqKTT99Fj99cq9nBf/i6BQe33UUyYSGt56+AkCwnR+AMS3z2g2L5S95YEXE0kftzaTk+OuOLvGJWD0igDrC612InFC/yB10MS+XK1WZ8DnPJlk1jAuxGiE/MpiTUfVCKatUdbmgbtlwj61HZLrgr9ct0DRN+kS82jNGhojx3FGbVcU48+qFffLDXF3SN1OsyIvPimH31oy4+DTbmpkqlKWq+pbTV+C6i9aiP5vCs3vH8eMnxE2Fg0fER/KsXdvDmJyxa82Ep4iYc0Sc20llS1EhcEsLtXpEsqmEVHrUx6tmVUHQMd7Dk+7+EIHamvGz8E4gfFfBPSKNrRnAft+OHAVnITK72H5gAo/urF3037RxGYDaG/Utpy8HAHzl3tZzBXYcnMR7/vUBfPHuHfjgNx7B6X99G17/T/dg00+flXswVIRJ9aJjF+Kkuh/gQIBCRNd1/OLp2oeiaMsAhiKyf7zgKpe7GbrWhuQT0XUdLxycjHwJmfiQ7s82KiJ+CpEN9Y29E/my40ZMdQfGUEQ5IqI4HO5JI5nQsLRuuhXfw8GJAspVHcmEhsUDOXkXaypEAoSZCQzDqlchIlJVjec29s1E6xEZ7knbBtWJ8eaBbAqDPR5GwyaSMVV+/tQ+TBcrWLOwD6evGsZwbwbvrasiz++vR6U7eUR8BZo1tj3E+bVbISA9Jc0qIv32ikjaxWDr9JpZ12TV2teI9oSmaUYhoDxetBjV3y8xxvubF3wWIh5GVUFGef0pqUB4F++5phUR43cmndRkDpCdgsTtu12MrtdGc9/2hd9g87P7Tf/2X1tqF/1Ljl+ExQM5+ffXXrAGAPCLp/dhl8/xRTsm8iW8/9+2YCJfxrpFfTh2cU1OfHrPOL549w5c9blfmxbR6bqOHzxWO6arTluOkcH6JkwfS9oE2w9M4qVD08gkE3j1cYvk3w/3ZuQvstsdrpMiAhieiRdHW/OJ3PHcAVz6d3dh08+iXUImioBeZVzVK11V13WTR0RIzU7tKJG5MpBLmeRbOb4bQmvG2iOXqk699SAKkiWDOSQTmjQYqiOuQcLMBCvn+VNE7D4gjRyRaD0iQ71pWZip7bZdysSMumzOjh6Xi6UfvvdorS3zlo3L5Wu978K1shgFGiPekyKl1Nf4bqPa0J+rvRfsFRF7v4ZfnDwiculdgF0zvpbeKV+jtkYEVrMqABxfb1Ha5cfYccgjQ8R4fSPmXS6J9HHhF4vo/O6bKVYazaqaprlOzsjWjI+pt3bBQsQH+VIFN333Sdz0vSfx0EtH8L6vPYy//cVWVKo6ypUqvlfv6/7uGStMX3fcyAAuOnYhqjrwtd+81NRrV6s6/ujbj2H7gUksGczhP647F7d9+DV48E8vxT+8/TScccw8TBcruParD2L7gZoy8uiuo3j50DR6M0n89kkjMt0yiEfk83e9AAC4YP0CkxIAAKuFT2TU+cLiduewph6W1Woh8pvttb7u1ib2RQTBMKsqrRmZw2F/DsZmSvJDdfFgVjrsnXwiR232zACGWdXv0jg3rIFpy+utBlGA7K1naCyrX5CH3DwiAe6mVs33NzkjPyCV8yzO+XQLBlA3jipy/dJ6+NpeJUtkt8+JGcAoRIqVamCVbs/RGfzmhdr7+aqNy+XfD/Wm8b4L18j/biVZVU7BJNTWTO35Jm0UkVIL47uApRAxJau6KCKyWLJ4RNyW3pUbC6yczb4Vq0dEfd5iperrHIp4d+/WjJFlEsSTIQpvvxt4hYk5m7IWbvZFsa7r8hwy0KyL2H10Bv/ji/fh2w/vQkKDVAc+d8cLuPorD+KHj+3BgYkC5vdl8NoNIw1f/976h8i3H9rVlLz+9796Hr969gAyqQS++J4zsLheVCwezOGqjcvx9feejVNXDOHIdAnv+dcHsfvojGzLvO6kJejNpLB40Nj74OcD8v4dh/C9R3ZD04A/vPTYhn8/xodPRISnze9t/IUVqZ1+l+c58Vy9JeVlAm0VOb5ra1a1v5MSbZmhnjSyqaQ0tzn5ROxSVQHjQlwsVx3bOn6xRsiLYkqYMYUSIL43u9aMMN756XcLRKjZLo8sEfHcqptffHhHYVbVdV0aeId608bPVPGI+DWqAuYPdrsxUze+/+hu6Dpwzpr5spUleO+FazBYL0gbd80YioiXKd5WEXEZj251fHdxvRBJJjT0K787ctLHThGxjBgLZFqoXaCZTfFiVQTypYpsY6hFkXo+/fx+HXKYBrSiTs1MBdh265Yga0fRRg0CVHXO/DxqYcJCpEv49fZRvOGz9+KJV8Yw3JvGrdeeja+/92z84ztOQ086iXu3j+Ij33kcAPCm05bZjjO+5thFWLuoDxOFMr5bzxnxy0+f3IvP3r4dAPCpt5yCU1cONzymL5vCV689G+sW9WHvWB7v+dcH8OPHa8Y2cWe1oC+LZEKDrnunexbLVfzZD54CAPze2auwcdW8hses9jE549aaWbNAFCLTLe3A2bqvVohEPX3jZlbdP5G3zYoRhYj4MPariKh3a4BZhWnVJ2Ftzayw+FxEUSWUATuz6kyAfrdAtGZ2HZ52vVjatWYMRST8QiRfqsoP8uGeNJYNNbZm/I7uAua70iA+EV3Xpar6VouqCgCDuTT+6qqTccLSQbx2w2LTv6lqhdevkjXQDFByWuymZlo0q65d2I/XnTSC916wGgnlOFMu47t2xwgYF2g75cKaIwI0ZmmI93AyoclsHvVxgL+fmVe8u0BtzQTxZIgCwq8CWJDR/+afUdZhDNhuRD4OsBCxUK3q+NUz+/HOL92Pd/3LAzg8VcRJywbx4xsulGrIm05bjh/ecIHcmwI0tmUEiYSGa+qjvF/7zUu+L7zb9k/gI/9ZK3L+14Vr8JbT7Z8fqF3s/+1952DZUA47Dk7hyHQJC/uzuGBdbcNkMqFJ45hXlsi/3LsD2w9MYmF/Bh993Qbbx/hRROT4rs0v7PJ5PUgnNRTLVdtIbT8cnirKi/2R6WKkI9KiEFGLggV9GWRSCei6/TkVyoeQp0W4k1Ps/xEHRSSZ0AxVoEWfSENrxtJe2q2M7gKwDTSb8rEzw4poa0wUyg27blTs2j5RBpoJ82I6WTvHS+uFmeoXCKKIaJrmeCfqxhOvjOGFg1PIpRO44uQlto9502nL8bMPXSR/9wRJ5QLvla5qt9l2wGUqqdSiWTWR0PDF95yJj73+RNPfp118LY5TMy7KRake8Z6xbc3UHived4O5lMnrk0gYKah+VKzDHgvvBGprZsqm5ehE0BHwgoMi4jRlZDciHwfmfCGSL1WwbyyPZ/eO49/uewmXfeYu/K+vP4z7dhxCMlEbw/3uB85vkEuPGxnAD2+4AO8+dxU+ePE6OZlix1tPX4GBbAo7Rqdw17aDvo7pf//Ho5gpVXDB+gW46Qr7gkBl2XAPvv6+c+To5xtPXWYyPY4MCZ+IcyGy6/A0/mnzNgDAn155gknCVFm9sHYu3GLej7jcOSQTmgy6cvOZuKH6QkoV3TR1EjZ2ZtVEQpN30HaTMyLF1ihEaufBaQOvnZFOIBff2fTxg2D1oagjyLquKx4Rc2tGDTRrxnHfk0nK87DLwVMDGHdvagssykAzQ4XKQNM0I7pf9Ygc9e8RAYwiKkghIrJDXnfSEgzk7H/nnFDVCi+PgzSfpmwUEZtCxKkoaBU/ZlWrL8VNubAuvQOUKZt60TKmjK5byXrsZlE5FHR8t1Q19if5UCACJ6s6LEO088jU/jt+EzMAEB/bbBu5Z9tBfPS/nsCR6aKtw30gl8Lvnb0K//P81fLD2o6BXBp/fdUpnq/Xl03hf5y1Ev9674u49dcv4ZLjF7s+/m9+vhVb901gQV8Gf//203wvnVq/uB/fev95+K8tu/CBi9eb/m1kwFsR+csfP418qYpz187HmxXDnBVxV7ZnbAb5UqXBQJcvVeRdgJ0iAtQMqy8cnMKLo5MyajkIz+0zjywfnio2tDXCwk4RAWqKwkuHpm3TVaUi0i9aM0IR8TKrNp6v/lwKByYKISgi5vaPmBLJl6o4PFWUrRlRiIhCNF+qyp+zaA/1BNxTsWp+Lw5OFLDz8DRetWLY9jHiPNspIlFMzRjno/YaoiU1kS9jIl9COpmQqpsfRQQIHvM+U6xIT9dbXVRPJ8yKiHMhouu6Yj5Vp2ZcCpGqfVHQKmlpsLUb37UvfhKJWpJvsVxtUC6sgWZAo6fECAts/IzIpZOYyJd9FSJ+WzMZpTVjKCLeF39x3IGnZhySaK3qUdBAwnYxJxWRZELD3rG8LEKSCQ0L+jI4adkg/uINJ+L+my/FzVee4FqEBOXq81ZD04C7nj/omp9xx9YD+OqvXwIAfPptp5rGgf1w/JIBfOz1JzZU7F6TM798eh9+9ewBpJMa/vqqk11HFRf0ZdCfTUHXjR66imgzpBKaNNlZMSZnmlNEGguR6DYdiw8Sa4iX24p7cQETSoD44Br1GN8dtvmgHAgp1Mza/smmktLDsv3ApPyQFd/XQDYlL3QiDGqmFNysChhR726TM3bTBX1yfNffB7PTAjc7rHfJfdmUvFDtHctLNaQvk/Rd5AbdnvrjJ/ZgPF/Gyvk9cglbENQiwW3fTKWqyz0+aZtAM3uPSGvju07IpXe2gWbOaa5i87C1YCjZmFWFImFVROx+jn7HrmeKFVlgeppVU40R735G3oOuCbCLeAecF981M/XWDuakInLK8iH88PoLMK83g+G+NAayKc+MgFZZtaAXl24Ywa+e3Y+v3/cS/u+bTm54zMGJAv74v2q+kGvOX41LNrgrJ0GQWSIOiojI4rjuorVYv3jA9bnE8run94zjpdHphseLu/55fRnH89rq5MzWhkIknKVwVkqVqjSj9ls+SES4m90mYWsh4uURER+U6p4ZQX9IMe927Z/l83pwYKIg03/7MkkZ3KVptULyyHQJR2dKWDyYa/qDTCy/c5ucsWv7iOJvplRBpR625sSjO4/gnV++H+84axX+4o0neR6TSFVVi7+lQzmMzZSw5+iMfO+umNfr+/NBhpr5vKP9xv0vAwB+7+xjTIZOvyQSGjQN0HV3RUT9t5TN1MyUTetLbt9t0iPihBHx7hxOZqcC59JJjNsoFzLQzEYRsZpV7VqfongsePzMxK6odNJseLUjm1Y8IjZKnxO9AVt7dhHvgHNrRtxIxGnzLjBHFZGBXBqnrhzGqgW9GMylIy9CBNfWUxL/a8srDevHq1Ud/+c7j2N0sogNSwZ8+UKCIMZ+9080XgiPTBVlpscHLl7n6/lWL3AuJI64jO4KxORMM1ki1aou02TFnXZUiojaErBKqyK6/dm9jcm2ToWIk0fEUERsWjPSIxJOa0Z9DaH6PfTSYQC1toz6+2CEmtW+dipgxLtAeKzcgv2mZf/aeG61KPHyifx/t29HvlSV34sXdpNK6lj27gBGVYFbFLmVJ18Zw+OvjCGTTOB/nBm8LSPwkyVSUiZUVBVF/BytRa6u6/ICH7ap0Vh6Z+cRcc4uscvGqFR1+X2bWjMW9UQWIg6tGcA+hVRFzUbyumaoEe9GCGCA1oxPT1TRZvuu+jzWgnimWHt83BSROVmIdIrz1y3AcSP9mC5W8J2Hd8m/z5cq+PQvn8Ndzx9ENpXAP75jY4PvolVEa8Yu5n17/Y5++XCPb7OcUAPspHY/Mchr6jHvuw5PBw5/2nVkGtPFCjKpBE5fNQzAmNIJG9EOyaQSDRL1ictqhcj2g5MNFx7r1Iw6vmvXOnAKNAPCSVfNlwxZebjPrIgAwJaXaorIMks70gg1q51f446qudaMm1nVLoEym0rIi5Jbe+aFg5PYvPUAADQU+U5Yx5kBI8xt79hMoNFdQRBF5BsP1NSQK05ZIj1EzWBkiTj/HqkXffV97DQ1oxY16SbHd51wyxFx8ogA9tkYaoFlGt+VHhHz1IytR0TsCCq6fw4d8jkxA1iSVQPkiBjvH3+fiXYR74Bba6ZxjUIcYCHSRjRNwzXn1wLOvn7fy5jIl/Dlu3fgov93B/75zlqS6Z+9/gQcv8S9NdIMbq2ZF+qelXWL+30/32olC8SKn0JkZCCHXDqBclWXI5J+EW2Z9Yv6pdJzOKJQM7s9M4IlgznM602joig0QO3DUZwDYVYV56Jc1TE+01hQuJlVjX0zzbefnHIURJaIUFvEhVhgDTUT5yOotCsUkd1HZhwLT7u2j6YZ48t27QOBus/JbyEyZrMEbemQoYi8osS7+0XeiXpc1MbzJfywvh37Xecc4/v57Ui5KAwCccFOaGaFw8kMbG7lhKyIuCSrllxC1Oxi3otqIWIKNDOP5BpFZ/NTM+IzxsuoCijju+quGR9m1aCtGUdFJNWoHgFGgRy3qRkWIm3mzRuXY6gnjZ2Hp3HOJzfjEz99FgcnClg2lMMn33wK3n1uax9KTiypX7CPTJcanNTCPLt+kf9CRCgidlkifgqRREKTxUzQ9owwqm5YMiBfI6pQM5kkavMhommaVEWe2WOMEwuPTDKhSWNoLp2UBYA1XbVUqUrlZZ6NIiILkRYUEfWOUJWVrRdZYVQVDFtCzZpdmDUymEMmWSs8947Z+5TsWjOAuiHW/vs/PFWUI7BArc3gJ657zEaFWqbsmzEUEX+ju4D/qZnvP7IbM6UKjhvpx1mrG0MDg6CmqzphZIKYP/KdVgiYlIa2mlWdFRG7xXclZYLG3iPibVZtpjXjRUZpzTQXaObvd118ljt7RLrDrNqWQuRzn/scVq9ejVwuh3POOQcPPvhgO142lvRkknjH2SsB1N4Uxyzoxd+89RTc+ceX4PfOWRWZX2WoJy1/OQ5YJmdEa2bd4r6Gr3NCmE1fOTJj+tACjF9Yp9FdQbM7Z0QhcrxSiETVmpGpqg4KgMiPeWavUYgIf8jC/ozJgLhwwN6wKooETYNtaywMj4j0oFgKneWW1fbOrZm6ItJEoBlQu1iKosepPTPjEB/fKwsR+wvFN+5/GflSVW45BowpHzfstrEuU/bNGBkiAVozPgLNdF3Hv9dNqu8655iWf+eFEuBWfMkLvMV7oRZ5astQfa6wx3dTsjXjvH3XdmrGtjVjPF49j9btu2MuU2l+p2bkwjuPiRnT6zcb8e57asbJrGo/vjszV8d3v/3tb+PDH/4wPv7xj+ORRx7Bqaeeite97nU4cOBA1C8dW264ZD1+/9Vr8Q9vPw2bP/wavP2sVbbx8GGiaZpje0ZMfQRRRBYPZJFLJ1Cp6g05GuKi5yVhNjs5I8LMjl8yIF8jMkXEIUNEcOLSRkXk4GTt/KpLvwDjfFiLpjHlgmhnDOwPURGxfhA3KCLWQkSYVevH2EzEu8DLsGpklJg/JN2yRArlCr52X+2i/gevWScLJD/tmTFbj4jRmhGj7oEUER9r3B966Qi2HZhETzqJN5/unNfjFz8eEXnBtnzOiPd1VTerOKp/I3SzqlBEbAqnkkusfI9NbLnd/hygUeXwMzXj2ZqpG+J9tWbEJE7AiHejNdOqR8RpamaOtmY+85nP4LrrrsO1116LE088EV/4whfQ29uLr3zlK1G/dGwZyKVx85Un4KqNy32HlYXByEBjlki+VJG98CAeEU3THCdn1PFdN5qZnMmXKtKXsmHJoHyNqAoRES3uFM8sWjPP7h2X8f1yYsZiQDQMq2ZF5IhDkSDoDyFHxG2pnnlqxOIRCUkRAby38M44tGb6XDwiP3psD0YnC1gymMOVpyxtaCW5YVecjQzmoGmG96AnnbRtlznhZ2mZUEPedNoyDAZMUrVDXLRdFRGbMDOg9nMUQoL6/jJMo1roKm1KUXCsxu2SW46IjSJSsNm8C6htnLpZ1Udrxmt812+8O2CJeC8EMKsG3ODs5BFxUuakWXUuje8Wi0Vs2bIFl112mfGCiQQuu+wy3HfffQ2PLxQKGB8fN/0h4WGEmhmKyI6DU9D12p2Cn0pfxfCJmC8sQRWRIIXI9gOTqFR1DPWkMTKYbaMiYn/hXbuwD5lUAlPFirzAWkd3BWKE96DFWCtaZU6F24BL+qVf7CZEBGrrYcmQs1m1UtXlHVYzhYix/M7enDztUOT0ObRmdF3Hv9ZNqlefvxqZVAKDAQqRMZtJikwqYSogV8zrCXQh9mrNHJos4GdP7QXQuklV4MsjUra/wGuasRlXVdzKEY3uAuYpHOsxuy3ayykXd4GjIpIyFIFqVZetOrv3v9+Nt4cCeETU1kwQFUJVA72Op1LV5fmztmayDt+TXWhgHIi0EBkdHUWlUsHIyIjp70dGRrBv376Gx2/atAlDQ0Pyz8qVK6M8vDnHYtGamTAKke1KWybonc8xDoqI9Ii45IgAhkdk99EZ3y5x1R+iaZr8UJgpVQJtPPWLsfbe/g4ilUxIb4LwiRyQm3fNF/UFDqFmj+6sjc6KNo+VMMZ37TJEBCJLZNFAtkHiVTfwqh9qzbRm3NJV3YocJ7Pqr7cfwtZ9E+jNJPF7Z69qOF43ypWq9NxYJ5WWKu2pIP4QwFha5nQRuX/HYZQqOjYsGcApK5z3UwXBV46ISziZXWCevMCHPLprPQarYTXo1IzcM+OwrTdfqmCiUJabie3Hd+3bGFYOB/KIiEWVJfk9+UlWzaYSUqHyKkTUrd9+k1XFf885j0gQbr75ZoyNjck/u3bt8v4i4hsjS8S4EIqJmXUB/CECoYioika1qstWg9cv7MJ+IyreLehK5bn9xsQMULtICbf8oQhCzaRZ1SXAy+oTcVZEjCwRFZFqeqbD9EQYZlWjNdP4QSx8IlZ/CKAsvpsuScVC08zbUP0iPCJ2awHcihyn8d1/vXcHAOB/nLlS3ulKc61HIaIuSbSuIViutKeC+EMAb0VkPF+qv0Z46yOkIuIyviuj020KC7tzJvfMhDy6a31Oa7qqaAlZd6cA9kZOu4V36mPz5apUvnrSyYZCG2hs4zhx2OfCO/V41G3TfiZVTBucPUbAVSNq4/iufXE1J6dmFi5ciGQyif3795v+fv/+/ViypHHddTabxeDgoOkPCY8lNq0ZaVQN4A8RiImR32w/JCcMxvMleWdmZwxT0TRNbvL1257Zqigi4jmiHOF1yxERyBHeve6FyIJ6b1ktmPKlCp7eMwYAOPOY+bbPH+b4rt3PRHh9Vs1vvOgO9Yhk1aJhuksnm/INiEJkdLLYoG64FTl2ikipUsWvXzgEAHhnXQ0BjO/Pa2pGFGYDuVSDT2upMsIcJEME8J568FPYBiXpQxFxG4ud12v8jAXSNBqBh01tu1iLp7JHxDtgHrMtypaT/fhqoVRxNaoCSvFYdr7wF8oVeSPgL0ekXojUXzud1HwPJMgR3pL777tQRBKazbZij10zc6o1k8lkcMYZZ2Dz5s3y76rVKjZv3ozzzjsvypcmNiy2mZoRYWbNFCKnrhjCeWsXoFip4p9+tQ2AUQwMZFO2dx9W1iysva7fQuS5+sSMOqoZ5Qiv8GW4hRGdZMkSsaaqCoQiMqooIo/vOopSRcfigaxjG6Bf2bcSNIVWIEdVbdplbz59OW687Fj80WXHNvybvLDny1K6dzLuejHUk5Z339YRXnWs0FrkCIVkSrm4bz8wiWK5iv5sCscq712/rRk38+LSIVURabIQcbi7bjYi342US0CYoOSicKiql0DGpkfgEUnW9+MAaHg/u7WQ7CLenTwiqlnUbkzb/nmdFZEjU0YgoB+DsZiaEecxSCvTz+QVYHhlMqlEw++MMb5rPr/5uTo18+EPfxhf/vKX8bWvfQ3PPvssPvCBD2BqagrXXntt1C9NLFhbM5Wqjh31AqCZ1oymafg/rzseAPBfj7yCF0enfGeICNbU2zt+RniPThflxM9xI0YhIlpAUaSreo3vAsDxSwahacC+8TwOTRZcpmYa982ItsxZq+c7qgzqRatZw6q4yNi1ZgZzadx42XFYa/MeUD+8RQHbyofYyvn2y++mXcaCRZicqog8ubumIp20bNCU1SILkWn3QsQuzEyw3OQRaa414xTRPeWQldIKSR9TMyL4y05psO4TAtyXz4WBaBGVLMcsTbV2ZlW7ZFUxNeMS6OUWZub0vFaEijmvN+NrOaH1JizI74zfUDyn0V31OZwUkbBXiLRK5IXI29/+dnz605/GLbfcgtNOOw2PPfYYfv7znzcYWEn0iEJkolDGVKGMV45Mo1iuIptKBJagBWccMw+v3bAYlaqOv7/t+cCFSJDJGdGWse7EEdLykenoFBE3o1l/NiXbGw+9dET+sjspIhP5suzvbqkXImcc45yumUklpNTb7AZeN7OqG+lkQl4099Tbb61s7nQyrDpNzAD2UzNP1QuRU5abDZ+G38H9vWAX7y5ozawq+vvta82kfEzNyHFcm4uoKMbU3x85vRKBR0R9Xqsi4rbx100RcTOruu2ZUR9bcDGrSqOqz8816xRLoELEpyLiNLoLGMVV49K75nOAoqQtZtUbbrgBL7/8MgqFAh544AGcc8457XhZYqE/m5IXlf3jeWlUXbuov6UxvY/89nEAgB8/sQe/qfft/f7CBklXVaPdVaJszcjNmR4XDmFYvev5gwBqHzzWrxnMpeVF49BkEdWqjofrm2KdjKqCVkd4xYXZy7djh7hj3jMWhiJiH2rm1ru2CzQThcjJ1kKkfqyerZlpZ7l+zYLaSPbC/mzgkXbv1oy3whYUPx4Rt2kUu+wVaVaNoDWjPq+1eHLLEbG7yy86ju8aRYuXR8Tpoq0SJN4dsCtEArRmfCsi9vHugH27qVrV5fveKY6gU8RqaoZEj5ElUlAmZvxHu9tx0rIhvP5VS6HrwL/Vw5q8RncFohDZP17w3K9gNaoKZJZIh1ozgGFYvbteiFjVEKC2X0fdwrv94CTG82X0pJM4wWF0V9BKqFm+VJF3kc0UIuJivVcqIi0UIvPcCxE7N78onsX7o1ypSmNwQyEiL6ru58ktV2WoN43v/sH5+PbvnxvYlOt1EQmyAM0vgRSRoGbVCMZ31eNoNKs6H6f9+K67WTVfrsjvy26hJKAWLS6tGTEx42N0F2hUKZpRRJx2KwncFJGskqwqQuOe3TeOqWIFfZmkVHDjAguROYYwrB6YyLc0MWPlw791HBKacVfmZ9YeqH04iIvjS6PuI7zPKdHuKuLDoVNmVcBQRMT0kNUfIhCTM6NTBTz8Uq0tc9rKYc/FYq3EvAtpOpnQmroTFz+fMBQRp9aMkIztWmB9liLshYNTyJeq6MsksXah+QNVFCJeUzNGa8a+MDtlxVBTvim/rZloFBEXs6rDdAlgFGNHFI9I2UWZCAOhzFj3zTgt5wOUMduyWojYX4zFY3XdMIc7tWbsntfKoQDx7kBrHhFRJDy12z3Q080jonpAxOPu31FTX89aM7+tid5+iNfRkMhRR3i3tzAxY2Xdon689fQV8r/9KiKAd3tmPF/CX/zoaTy66ygANKgHRrpqdDkifhURgZ0iAsCkiDz8sr+2DAAM1EPNmskSkW0Zy+Zdv8hCJGSPSFW5g5+S0dN2ikjt9YRqYhhVhxqMg8aSPg+PiItZtRU8FREfnqOgGEvkfASa2bRa5tm0swyvRjSXCKG0OCWr2nlZ7FbbGx4Re0UEMEzWToWIn6V3QVsztWh847+DTJqdu3YBAOC+HYdcH+fqEVGKE+F9ub/+fOL54wQLkTmGaM3sGyu0FGZmxx9eeqy8gwrSW1/jkNCq6zr++4m9uOzv7sKtv3kJug68+9xVpnFNwNj9EGWOiJdHZPGA2U/gVIgsVCZn/BhVBa0oImL0sNmLrvgAD2NqZsW8HmRSCRTKVdMIr9tiMDE1IxQRJ38IYCgcU8WK7XZXgdv4biuIi1q5qtu+fhStGT9TM2WHFgbgYFaN2COStjGrVqq6TEC1K4DsTJxyasai3GSSRkKpVyHia2pmMphZVdM0k3ejN8CUyrlr50PTaiPqYgLPDqmI2JyrdFKD+NHlyxVUqzoefPFw/flZiJAOs7heiDyzdwzj+TI0zVAkWmXl/F589HUbsHpBLy48dqHvr7NOzuh6zcR5zVcfwvXffAQHJgpYs7AP//6+c/DXV53ScFc/v6/2ARN2a6ZYrkozXL/HHaymaSZVxKk1IyZnnts3gZcPTUPTgNN9FCIDsj3hvUPFytiMe4/cCxFqJu64WylEUsmE3PIsPD+A1/iuoYjouq4UIo2+mkHlYuPWnjHMqs2dEydyGeMj1e7CFkVrRl7UXc2q9hdswLxPSKhUboVLGCRtVBy1cHOLeFcTRZ3MqmohIEb+nc2q/qdm/Cy8E6gtkyBTUsO9GWxYUntv3++iihQrdbOqTcqxKaG1VMGz+8YxNlNCfzaFk5c1/t50GhYic4yRukfkkZePAqiZB8OcKb/u1Wtx5x9fYhsX7oQohJ7eM45/vnM7Lv27u/C7X7gPdz1/EJlkAjdedix+9qGLHIsb8eEwkS+73gULRicLuHfbqOfjVPNsn487WFMh4tiaqf395mdracPHjwz4CkhqSRHx2O7rhfUDvNXRPzH19JxtIeI8NVOp6pguVvB0PTjOOroL1C5womhzm5zxmqRolkwyIe9E7dozk1Emq7q898su3guhFOi6MR5einh8V5pVFV+LWkjZ5ojYtWaE98VlckScc6fRdfG4YqXqqCoFbc0A5mmWoMX7eT7aM6JwsovDB8wTXPfVpxnPWj0vdv4QgIXInEO0ZsSdRBj+kFYRhcize8fx/37+HHaMTqE3k8TvnrECP7/xItx42XGuxdJwT1p++B/xoYr8yX89gXf/6wP41TP7XR8nPsCyqYSvX151aZ1jIVL/IBO7Tvy0ZYDW9s0Y8e7N3f1bC5hWUxmPty1EnHNEVFn7yd1jmClV0JtJ2gawAfC1gdcr5KpZ3HaF6LpujINH4BFx377rrIhkU0l53kV7JvLxXdmaMY5ZbdPYHafdSKqTRwQw+yQA79aM9blVDgVYeCfItFCInLu2tu7BXRGpt2Yc9j6puSvCqBrHtgwAxCvVhETOiGUjbBwKkXWL+rGwP4PRySLOWj0PbztzJV5/ylLfd42JhIZ5vRkcmiri0FRRtp/sqO0oqakhm7fux2UnOgfrBY3jPsmHIrLQ0rLxY1QFWpyaaSFDxO7rmo14F4hCZOs+YyrAbXw3kdDQm0liuljBA/UP1BOXDjpm3wz3prH76Izj4jtd111377RKLp3EVLHSoIgUysYdtx+FzS++klVlYWF/0ZrXm8F00Thn5YjHd+3Mqmqbxu5nKwoG4b9JJxOOUzPq4wV2o9qAuWDJlyoNv+/lipFFEsSEb1ZEgv3OnLNmATQN2HFwCvvH8/IGUsVLEREFynSxjAdfjK9RFWAhMucQ47uCVjNEwqAnk8RPP3QRiuVq4Ehtwfy+WiHiZVh9Zs+4lHbv2TYKXdcdJ0kMGd3fRWPNwn4M96YxkS87fh8NhYjDojsrAy3kiBydaq01M2hVRFps5Yn+90uHppEvVZBLJ13Hd4FaMThdrOCB+geqnVFV4DXCO1WsyAtg0KRZPziFmqk/uzCTLX3liHi0WoZ6asWbUETcAtDCwM6samSdaLa/k6oqmi9VkE4mFI+I++M1zfgdspJI1BbSFctV28V36o6jwR7/PzfVIxJUERnqTePEpYN4es847t9xCG86bXnDY6Qi4rDTSxRYj+48ivF8Gf3ZlOlmKU6wNTPHyKWTprvAOCgiALB4INd0EQL4T1cVu10A4JUjMw15FirTxWCjlsmEhn9/3zn4t/ee7dhLVqXdkUHnRXdW+ltIVpWKSMCUUIH1Yt1qa2ZkMIvh3jQqVV1Obk27jO8CRqiZmDTyU4g4tWbE32dSiYa75jBwiuieLhiL/VpJMraSTPrYvlv/N6e753l95h09bgFoYZCy2TUj/B5OKkw2ZUzCiJsJp6V34vGCoZ60646YXMp5cka8N1MJzfH82R6v8t5qRkUUPhGn9kyhfqxOW33Fe/uu52ohi2fHMD9EEM+jIpGitmfCGt3tNMbiO/cskS317A7BvdudTavNTDicvHwI5693nhhSC5Qzj3FedGelP2tswQ3K0bDNqi22ZjRNw/Ejoj1T84lMuZhVAaM9JkYW7YyqAq/Fd2q8ezO5Kl44LRyLwqgK+PSIeCgiotg82naPiGI8ddkzA5gnYcS5NcZ3nQPQAO/3vtsGXtGi7c00boZ2o9nxXcG5shA5bPvvBamI2F/GRUG8ZeeR+vP5U187AQuROYhozyzszzRtYIwbonfr1pqpjQXXfikvrBcLv3YpRCYjWNmeSyfl3hi/RlXAKIa27Z/A737+N/jtv78L535yM87btBkPvWT/QSUwNu82qYg0TM207m8wJmdqPhG3HBHArErl0gnXlqKx+M5BEWmxMPPCKdRMbt4Nec+Hr2RVD8+HNV21HPX2XZuIdz8jw8Z6+9q5FS0kW7OqcvH3MiXbLdQTCEUk6Mi1qTXTxM/87LXzkdBqsQb76qnGKm6BZoDRmhFKWVz9IQALkTmJMD7NFjUEMKZR3Fozuw7P4MBEAemkhg9evA4A8JsXDjlK2lMBPSJ+WbeoH5qGQFkrYjX9dLGCh18+guf3T2LfeB57x/L4yH8+7rqnp1Wzak86aerBh1GIHF/3iQhFxJiacfKIGK954tJB1wvkUK+/1kwURlUAyDm0ZqJIVQX8ekTcI9vnWc5Z2cV7EQYyR0QpntyyTgSyyKtPJLl6RNTWjEcR7hZqZqx5CPZzy7RgVgVqSzJFC9KuPeMW8Q6YC7GBbMo01Rc3aFadg4jFY9Yttt2MaHkccYn2FpHqJy0bwtlr5qM/m8LR6RKe2TOOU1Y0Sv2TEV04vvQ/z8Deo3kcN+L//K9a0Iuvv/ds7Dk6g8GeNIZ60ujJJHH9Nx7BzsPT+Mwvn8ef/c6Jtl/rtQbdC03TMNSTwWi97RWG0dI6wuuliKgXATd/CODtETFSVaNRA3sctrkKiT/MMDNAmZpxi3j3UBtEa0aaVetFTZheFpW03fiux2QPoCgXQhERrRmXHBEgiCJi4xERymjAAlxtmQT9WsG5axfgiVfGcN8Lh3DVRrNh1UsRUT0qcfaHAFRE5iTvOe8Y/PHrjscHLl7f6UMJjfn1aZRDLht4hVFVhPoIqdLJJzIVUU9/8UAOp64cDvx1rz5uEd5x9ipcecpSXLB+IU5fNQ+ffPMpAICv/PpFPLrzSMPXzBQr8s5pXpNmVcCsHjT7oaoiCpEDEwUcmSp6ekTUZNuWC5EWCzMvnDwiojUTZrw7EGz7rtPFSPx8j05bFZGIzarq1IwPRcTqEXHNEVEuxJ4eEZuwNMGUh1rnfKzGz9nJhO2FW7BZwas1oxRicW7LACxE5iTz+zK4/pL1WDLknLfRbRiL75wLkS0vid0uNdPWhetrv5xOPhGRbRH2HWyYXLJhMd68cTmqOvAn331C3iUJRFsmldBaKiDUD/JmP1RV+rMpOTG0dd+EVER6HD7s1Yu3m1EV8GFWbbFV5YXcwNugiERT2BoekeYi3gEj7E6oRXJ8N2qzqnLMxYp7sQQ0ejmkR8RDEfH6WeccfmaA8TkQtEWrKhLNqqpnrp6HZELDzsPTcrO3oFhXhZzMqmo+CgsRQtrAfI9CZGymhOcP1NoAwiQqPBoPvnTYtTcc9oUjbP78d07Egr4Mnt8/ic/dsd30b2pwVysTIuoHeVgZGCJP5Ok9Y/Ii5FQsiWIwm0o0LD20ItoMToqIyBeJyqzqlCMip7A64BExCgunQDPz1uKot++mbULYyj6Kn5yl7eU6NWMZ33VDju+WGz8Hmi0gTVMzTapgA6pP5AWzKuKliPTU9x4N5FIN28HjBgsRMitYoHhEqjYfyI/sPAJdB1Yv6JWpp+sW9WNkMItiuSrzKVSiMquGzfy+DP7ijScBAP75zu2mxFLR8291OkqEmmVTidB8A8Kj9Oiuo/LvnNQWUfyc4GFUBQK0ZqJSRCyGSkEUm3cBI0ek7LZrxmt819KaEQVCOmJFpGQbaOb887W2vZyW3gHhTc0Y47vNtWaSAfNHrDjFvRelWdX+ucXxnr16fmRen7BgIUJmBeJCW9XtxzatbRmgZsK8cP0iALWUVStRmVWj4HdetRSXnTCCUkXH9d94BL+pt5vCGlUVKkOY6pDwiTxaLwLdPrBPWzmMVELD605a4vm84qIzU6o0tKqA6D0inopIRDkirQSaCePueL6ESlVXklXbN75b8pjsAdRNuWaPiFeyqnch4h1oFrS1KZSKoPkjVmSw2Yv2iohTIfI7r1qKK05eghteG38vYPw/YQnxQSaVwEAuhYl8GYenCg3JpiJrw7rb5cJjF+C7j7xi6xOJqqcfBZqm4RNvPhmP7jyCFw5O4ff+5QFcuH4h1tbzNlr1Q4iv72kimMkJoYjsqWck9KadP7DPW7cAT/7F63z5UwZyKWhabZvs2EypYe+PMb4b1dRMewPNkjZ7W6x4Bpr1Ght4x2dKkY/vpmzGd/1klzR6RJzNquoF2utn7RpoVmxufDerFCKtsHFV7TNr1+EZjOdLclt30WN895gFffj8u89o6bXbBRURMmswDKtmRaRUqeLxV44CAM60hIhdsK7mE3lqz1jD5t5uMKuqjAzm8LMbL8LV5x2DdFLDvdtH8fX7XgbQ+kVXXKjCyBARrF7YZ7qAeLUs/JpkEwlNfliPzTR6hsai9oh4RLyHMXWkkvYR8e7lEUknE/J9fnSmFPn4bsou0KzqXfxYPSKhmVUjHN9tVVEd6kljpB5CKVYiAEaom5NHpJvo/u+AkDqGYdUc8/50fdHdUE+6IcRt8WAOx430Q9cbR+SCLr2LA4sHcvjLN52M2z9yMX73jBUQ15GRQfttwH4R0naYhUg6mcA6xXga5iI4N5+IGvEeBd7JqtFMzZRdklX9KBzDMl21qHhKohrfbfS1iDt8XzkiPsyqoXlEmvy5iYj5MKbMjl1cUw+37zcKES+PSDfR/d8BIXXm99WzRCzKxsP1tswZx8yzXXx1QT3u3Zon0k2tGSsr5/fi0287Fb/8o1fj5is24NoL1rT0fOeuXYC1C/vwhlOXhXSENU5QQvXCbPs4FSKlSlWaRiMb3/XYvhvVrhlfiohLYSHOx9h0yUhijdqsqk7NVP17RPwsvVNzRFrxiEzJVQ+dUUQAYznptvrkH+A9NdNNdN8nLCEOyNaMJdRMTMRY/SGCC9cvxFd//RLufv4gdF2XPoWpCHbNtJv1iwewfnHrCbojgznc/n8ubv2ALByvFCJhqi1OhYj4b02rjUZGgRibtF7UDIm//R4RP20PsYvoyHRRFgjRm1UbA81cFRERPFY2T81kUo3fl/BOZFMJkzri/rzhBZoJ9fXYkdZXaYjn2HbAThHpHsXWie79hCXEwvz+xn0zuq7LRNUzlYkZlfPWLUBvJolXjszg8VfGcNrKYRTLVfkhF3buAzEwFSIhFnxDlnFUgfjvgWwqMv+DnJopOiki0SSruk7N+FgoJ5cFTpfkAr2ozapq8SSnZlzu8K3KRcm1NVP7Oz/Klz+PSLD35xnHzMOvb3otRgZaa4sCRmtmm01rZjYoIt3/HRBSZ77NBt6dh6dxsL7o7lU2+2SA2p3OZSeMAAB+9NgeADAtkesmj0i3IULNgOZWpTvhpIgcmKhN6Czob/3i4ISXRyT8XTONe1usyNRSl+JLZonMlDzNra3ialZ1OUZraq3bDp1j5vdB04yLuBuurZkWovmXD/eEoiqJEL/dR2dky3g2tWa6/zsgpI41XVXXdXz6l88DqOVQuMmzb6x7H37yxB5Uqrq8e82mErFeFtXtjAxmIzHCOhUiYupgXX2sOQqcIt7FnXWYyg8QniIiWjNHTWbViJfembbvinaQi0dE2Qmj67proNmqBb3Y/OHX4Avv8R5hdVVEYjA9N68vg4V1xfeFg5Om751mVUJihGjNiELke4/sxo8f34NkQsPNV57g+rWvPm4RhnrSODBRwAMvHpoV/pBuQNM02Z4JM3HUqRB5fn/N7HdsgM3HQemxmcCIstXna2rGR2qp2poxNuFG1ZoRS+/U1oz3pE5WUS7Uto6TKrB2Ub+vAsJQRBrPobgpCbNQbgZpWN0/KdUQgIoIIbFCXXz30ugUbvnhUwCAP7rsWJy+yt6oKsikErji5Fpq548f39OVo7vdykn1PRgivTUMnBbfPV/vsXvtq2kFtTWj67WLpdrqC337bpAcER9mVVNrJqrxXZtYej+TOqpyocbDtxKhbn1elVKlKr0YnU5Ylj6RA5OyqAWoiBASK9TWzIe+9SimihWcvWY+PnCxv4hj0Z756ZP7ZBBWpz985gLvf/VafPDidXjXuatCe85hG0VE13Vsqysix0WoiIj8CDUqXRS2mVTCVZVohiDJqmkXz8ewsvhO5o5EpojYmFV9LNoz4vOrpvj+Vk21shCxLL2bVgzHYReQQRGTM9sPTKBQCq8IiwP8lCWzhgX1HJFipYrHXxnDYC6Ff3j7ab6nI85ZuwCLB7I4MFHAz57cB6B7UlW7maVDPfjo5RtCfU671szoZBFHpkvQNDQE24WJmocyU6ogk0rIVl8U76cgHhE3RURdfCd+Z6JTRERrxkYR8bH0rlCqSFVA01pPgFW9JypCyUq1uLguDIwskUllbDnR0h6buND9pRQhdXoySdNF4FNvfRWWDff4/vpkQsPrX7UUAPDjJ2rTM/SIdCeDNoWIUENWze8NJe3SiXRSkxdGIfUb6Zzhv27SRl2wYqgNboWIkSNS9vH4VkjbTPr4SX9Vp1vUiZlWL8ZOUzOqV6zTF3zRmtl5eFq2HLOzQA0BWIiQWYZYcPaOs1biylOWBv560Z4Rd0ZURLoTO0VEGlVDCHhzQ9M0wydSl/anItzkbBeXrlKp6qhbVVzv6kU7ayJfltK/WyunFaQiYmrNeI8MGy2UqswQCUOpcPKIGD+3znvFFvZnMK83DV0Htu4bB2CYd7ud2fFdEFLnpis24JrzV+OWN5zY1NeftnIYK+cbKkqnnfKkOUSboVCuyovL8/XR3eNCSLr0ImfJEolyXYCXIqK2P9xaLWoMupg8i0oRMQy2yvhu2VuFMVoohlk1jKkR8fMqVXRTi6vZzbtRoGmaLKKf2l0rRDrdLgqL2fFdEFLnylOW4i/eeFLTC9Q0TcMbXmXsU2FrpjvpV5JThSrSDqOqQMS8G4VIPUMkgsJWKAhOHhFTIeLipUglExjI1d7vUY/vCqXFHGjmY9eMcl4LZe9Wjl/Ulq6qikS1MblZ1teL6Kf3jAEwjNHdDgsRQiy88TSjEGFrpjvRNA2D9Yvq2EwJuq4bo7ttUERklkjR7BGJ4v3kpYioF3uviR0xwiuIenxXLZJKfnbN1M+rrhsTLWFMIakjsGoh0uyemagQY+fP7KEiQsisZsOSQSnfUxHpXlSfyMGJAsZmSkhEPDEjsMa8R7V5FzAUAUdFpOp/usS6lyUyRSTZWDyV/eyaURa8jdeVrjAuxomEJls86uI7UezE5XNAtGYmlHHw2cDs+C4ICZmbrtiAV60YwuX1kDPSfQzJyPKSVEOOWdDnuYk1DKwekSgl/qSHWdXPWKxA9Yn4/ZpmSNm2ZryzS9JJDeKfJwqlUI8xJwoRVRGJWbChVc2bDWFmAHNECLHltRtG8NoNI50+DNICqiKy6/A0gGgTVVWMfTO1i2uUiohfj4ifcLLG1kw0iogonsytGe80VzGRNFWsYHymdk7TqXCOMZdOYjxfthQiwtsTj0vl4oEsBnIpTOSpiBBCSOxRC5FtcmImeqMq0NiaiXRqxqbNoRIkrt3amolqfFeoGKZkVR85IoChNk3kQ1ZEbEZ4RaBZXMyqtckZo5ieLYrI7PguCCHEwlCPYVbdJpfdtUcRyVnMqtJrEMnUjLtHxFh4560cDLdJEbEzq8r0V4/ixyhE6qpAaIVI4+K7OI3vCtQcHCoihBASY8QSvbHpogwza5ciYvWIRNmaUadmxJI9lSAekWGLR6TV6HQn0jbtJD/pr4AR4jUecnuix04RkdH88VBEAHMxnU3F57hagYUIIWRWIloz2w5MYjxfRkID1i7qa8trW1sz0xGO76qTLXaiiByL9aFuzOtrk1lVbt9tnJrxUjjE5Mx4yK2ZbNrs6wGMAjIuHhHA2DkDUBEhhJBYIwqRR3ceBQCsXtDXtjtIGWhWFIpI3fQYoSICGG0YFbmTxYffQ6hIQDjL5JyQrZmqTY6IR8EkjMBifDeMQDPAySMixnfjozwcq6h69IgQQkiMEYvvhCrRLn8I0CjzC7NqFBK/6qmw84mUAygiQ4pZNSqjqvrcum4cc9nHrhnA8HIIj0jo47vl+AaaAcCyoZz0GlERIYSQGGOdAGmXPwSwyREpRu8RAewnZ/wskxOo47tRGVWtzy2UEN9TMynz1Ex4ZlXzokJAzX+JTyGiaRrW19/L9IgQQkiMsYZzHdvGQkS0D4zWTPTbdwGgUnFWRNwSSwWqWTWqVNXacxvHIoqnss8x44apmZBUAaG0FMqNUzNxas0AwPEi+TkmY8WtEp8yjxBCQsRaiLRj665ANauWK1VpgIxCEUkkNGharc1RcvWIeBcWgz1p+VxRGVUBsyJStigiXgWQKETCNqvaTc1Emf/SCr//mnXIppJ46xkrOn0ooRCvs0sIISGhFiLJhIY1C9szMQOYL2rTyoUtqjvrVEJrWGEvCDI1k0xoGMylMTZTisyoCpiLDVEoCWXES+Gw5n1EGWg2VYxua3IrrFvUj7+66uROH0ZosDVDCJmV9GaS8oK3ekFvW/vpuYyhiIi76lRCi2xbqrFvxqY1Uw12wZ5X99ZEqYhomtYQxBZUERGEFfFuHd8tVaoo1ts0cfKIzEZYiBBCZiWapknDajuNqoCqiFRN8r6mRTQO67JvRka8+1Q4xLLAKM2q6vOLAsRv8JpQRAThJ6ua03ABoDdmHpHZBgsRQsisRYzwttOoCpgnMKZkOmd0d9Upl30zQZJVAcOwGqVZFTBGeMtWRcTn1Ix8ntDGd+vFY10FEZNO6aQ2a6ZT4goLEULIrGXxQBYAcOLSTikiRmsmSp+B274ZYyw2Pq0ZQE1XrULXdd85Ij2ZiAoRy/juVAxTVWcrPMOEkFnLx648EXdvO4jLThhp6+uqUzNR7pkRGPtm7KZm/JtVAWPxXfStmVoBUaroJiXHK0cka/GIhLZrJiPGd0UhEt2iQmKGhQghZNZyyoohnLJiqO2vmxMR76VKW7Io3DwifpUGgfDV+H18s6SV4kk12Xp6RFJWj0hIEe8p89RMHDfvzlbYmiGEkJARioiuA4enankXUU5eqBt4rYicjozP6ZJ2eUSS0qyqm/JPPD0i1qmZ0Md36x4RKiJtg4UIIYSEjHqxPDRZABCxWdXVIxJMEVlY99VYvRhhI82qFYsi4uURiagQyVqmZqYijOUnZniGCSEkZNLJBNLJWsjYaL0QiXIE1C1HJKhH5NINI7j2gtW48pSl4R2gDeqkjzjGhFZLinWjMUckZEXE4hGhWTV6eIYJISQCcukkSpUyRieLANpjVnXziPhVDnoySXz8DSeFd3AOpJTxXaNY8j7GxhyRsD0i5vHduO2ZmY2wNUMIIREgWgiyNRPhnbWhLrhMzUTs+QhKWhnfFUqOn3AyqyIS9tK7fJGKSLuJrBD5xCc+gfPPPx+9vb0YHh6O6mUIISSWCI+FUESinL5Iuk3NBAw0axfm8V3/7aOozKri5yVaM1IRoVk1ciJ7ZxaLRbztbW/DBz7wgaheghBCYotQREalWTX6QDPbqRm5ayZeikhKGd8NYqi1tmbCTlYVywNpVm0fkZ3hv/zLvwQA3HrrrVG9BCGExBYRvFUQi9PaEWhmY1YtlusX+ZgpIqKAKFd0RbXpnCKiPm8tEbc+vkuPSOTEqtQrFAooFAryv8fHxzt4NIQQ0jw9ljv3KHNEUi7JqrLtETOPiCieSpUqigEmexo8ImGN7ypeE3M0f6wuk7OSWJXImzZtwtDQkPyzcuXKTh8SIYQ0hTXvomNTM8IIGpKpMyzSyviuCF3zyhABGpNV0z6D2rxIJDR5jvLlqty+S0UkegK9M2+66SZomub6Z+vWrU0fzM0334yxsTH5Z9euXU0/FyGEdBJrIFi0Ee/OHhFjaiZehUhKDTQLMGKcqme0CMJSRACjyJkpGtH8VESiJ9AZ/shHPoJrrrnG9TFr165t+mCy2Syy2WzTX08IIXHB2kKINuLdeWomaKBZu7ALNPN7jLlULaMFCHcaqCeTxHi+jHypIiPeo0zEJTUCneFFixZh0aJFUR0LIYTMGtrZmlHbHFYMtSFehYidWdWvoTabTmKi7uEIs+WUkwZjY2tyL8d3Iyey34ydO3fi8OHD2LlzJyqVCh577DEAwPr169Hf3x/VyxJCSCywFiJR3llLj0jFLdAsbq2Zulm1WjVGjH0aatUR3jAVETVddZrju20jsjN8yy234Gtf+5r8740bNwIA7rjjDlx88cVRvSwhhMQC1SOS0BrzL8LENUck5oFm5YqOolRE/BUiapEXptKTUxbfTRVFsioVkaiJ7J156623Qtf1hj8sQgghcwHVI9KXSUHTomuNuHpEYtuaUSPeReiav0tSzlSIhHcZE9kvk4UyiiL/hWbVyIlXiUwIIbME9a49annfdWqm7H+hXDsRraJSVQ+s2qjqUqhTM/Wf2eGpovw7tmaiJ17vTEIImSWorZneiLMokkm37bvB/BftIqUoIqWAoWuiYEglNCRC/L7E+O6h+n6gdFKLXf7KbIRnmBBCIkC9a496BNSPRyR+iohIVg2uiGTrptKwfS+ieDw0VUv4ZoZIe4jXO5MQQmYJPRaPSJQYyao2UzNxXXonzKrVauAcEVEwhP09iakZoYhw8257YCFCCCERYDKrRtya6capmbQSSx9k+y5gtFDCbpsIFUt4RHrpD2kL8XpnEkLILKGdZlU5NWOzfTe+yap1s2pF2TXjN1k1HU1rxmpWpVG1PbAQIYSQCDCZVSNuzbjvmgmmNrSLtMms2tzUTNiFiBjfPTTF1kw7idc7kxBCZgmqItIf9dSMLEQaPSJCbQhzzDUMjGRVQxEJGmgWVWtmbKYEgGbVdhGvdyYhhMwSch3IEXELNItraybo9l3AUC5Cn5pp2A9ERaQdsBAhhJAIUFszkU/NyDZH93hE0soxG/twgnlEMmFPzVgKESoi7YGFCCGEREAnklWtikilqkOv/1U6Zh6RpJKsahRLnfWIWPcBRd1SIzXi9c4khJBZQjvHd8VF3WpWLSnbeOOriFSNEWOfikhPVFMzKSoinYCFCCGEREAyYcSDR92aSTtEvKuFSdxyRFJK8SQme9I+zacXrl+I01cN421nrgj1mKytGXpE2gPLPUIIiYhcKoFiudqGHBH7qZmyoojErhBRFZGAu2YWD+bwvQ9eEPoxZS2tGSoi7SFe70xCCJlFDPdmAADz+zKRvo6TR6RYL0Q0zShW4oJszTSxfTcqqIh0BpZ7hBASEZ9488nYuncCx430R/o6Th4Rw3sRv3tO0ZopqVMzHfaxNIzvUhFpCzzLhBASERcduwgXHbso8tdxUkSMzbvxUkMAc2tGFCKdLpgaFRFeIttB/MpkQgghgZAeEUuOiLF5N34f9emkoeKUYxK6Zh3f7WXEe1uI37uTEEJIILwUEb/L5NqJKJ5KqiLSaY9IiopIJ2AhQgghXY7T1IyRWBq/j3rRhilX9NgUTI3JqlRE2kH83p2EEEICkXLIEYmLCdSOlDI1I/fhdLhgylpyTGhWbQ8sRAghpMtxnJqp/3fcNu8C6vhuNfD23ahIJDRTMcLWTHuI37uTEEJIIFJOZtWYXODtSCmtmbh4RACjPZNOGsm4JFp4lgkhpMtx9ojEo+VhhyiOSsquGb/JqlEiJmeYqto+4vfuJIQQEgjnqRmhNHT+Am9FHd+VY8YxUCCEItJHo2rb6PxPnRBCSEsYioi1NSPyOeL3Ua8WT6VyfBJgxQhvL/0hbaPzP3VCCCEtIVovjdt346uIqMVRvlyp/13nj1O0ZmhUbR8sRAghpMtRR2FV4rJMzg7VDzJdrBUicSiY2JppP/F7dxJCCAmE1/bdOJhArajqR7Ecn+A1UYjQrNo+Ov9TJ4QQ0hLGrhnz1Ew5xh4ROz9IvFozVETaRfzenYQQQgLh5RGJY6BZIqHBKtTE4TipiLSfzv/UCSGEtETSwSNiTM10Xmmww6rUxEG5EVMz/VRE2kbnf+qEEEJawitHJA7eCzvSFkkkDgXT8nk9tf8d7unwkcwdqD0RQkiXo+aI6LoOTTNSS4F4TKPYUVNAKvK/45Aj8v5Xr8XGVcM4e838Th/KnIGFCCGEdDnqVExVB0TdEffWjLVAisNx5tJJXHTsok4fxpyi8+UnIYSQlkgqhYi6b8YINIvnR33S2pqJ4ZgxiZ54vjsJIYT4RvWAqD6ROAeaAebjTic12VIic4t4vjsJIYT4RlUWRDsGiHegGWBuzcTVUEuihz95QgjpctRCw04RicNYrB3qccXBH0I6QzzfnYQQQnyTSGgQXQ1bj0hMFRG1gIpr+4hED3/yhBAyC7DLEhFtmnQqnh/1avER1/YRiZ54vjsJIYQEwtg3o7Zm4u0RUdsxVETmLvzJE0LILMBu30wp5lMzacvUDJmbxPPdSQghJBApm30zIlk1rkZQ9bjiaqgl0cOfPCGEzALsPCKiKImrIqKOHce1fUSiJ57vTkIIIYEw9s0YUzNx3zWjFkhxLZZI9PAnTwghswA7j4jMEYlpWJiqgsS1fUSiJ57vTkIIIYFQN/AKukoRiWmxRKKHP3lCCJkF2OaIVGOuiKjju6l4FkskeuL57iSEEBIItxyRuAaaqQVSXIslEj38yRNCyCwgaTc1I3JEYjqRkjYFmsXzGEn0sBAhhJBZgJEj0jg1E9eMjhS37xKwECGEkFlBsn4hV1szpWrMA83U1kxMj5FEDwsRQgiZBaRspmaM1kw8P+q5fZcALEQIIWRWYOcRMbbvxlNtSCW5a4awECGEkFlByiZZVfz/uPov0tw1Q8BChBBCZgW2ikg53oFmaoEU18keEj0sRAghZBZg5xGRgWYxVRu4fZcALEQIIWRWIC7k5hyReCsi5tZMPI+RRE9khchLL72E973vfVizZg16enqwbt06fPzjH0exWIzqJQkhZM5iVUTu2HoAVb329/3ZVCcPzRFza4b3xXOVyN6dW7duRbVaxRe/+EWsX78eTz31FK677jpMTU3h05/+dFQvSwghcxLpEalUMV0s489+8BQA4NoLVqM3E89CxJysykJkrhLZu/Pyyy/H5ZdfLv977dq1eO655/D5z3+ehQghhISMqoj8/W3PY/fRGSwf7sEf/dZxHT4yZ5IMNCOIsBCxY2xsDPPnz3f890KhgEKhIP97fHy8HYdFCCFdj7ioP/HKGP77yb0AgL++6uTYqiGAZfsuC5E5S9u0sO3bt+Ozn/0sfv/3f9/xMZs2bcLQ0JD8s3LlynYdHiGEdDVCEfnR43tQqer4nVctxSUbFnf4qNxJc9cMQROFyE033QRN01z/bN261fQ1u3fvxuWXX463ve1tuO666xyf++abb8bY2Jj8s2vXruDfESGEzEGSykV9MJfCLW84sYNH4w+TWZWKyJwlsGb3kY98BNdcc43rY9auXSv//549e3DJJZfg/PPPx5e+9CXXr8tms8hms0EPiRBC5jzq3pabrjgBiwdyHTwafzBZlQBNFCKLFi3CokWLfD129+7duOSSS3DGGWfgq1/9KhKU3gghJBL66iO6Zx4zD+84qzva2mZFhNeHuUpkLqbdu3fj4osvxjHHHINPf/rTOHjwoPy3JUuWRPWyhBAyJ3n3ucdAA/A/z1uNRJfEpdOsSoAIC5HbbrsN27dvx/bt27FixQrTv+m67vBVhBBCmmH5cA8+evmGTh9GIFQVhGbVuUtkP/lrrrkGuq7b/iGEEEKSCUa8E+6aIYQQ0iHSbM0QsBAhhBDSIdR2DFszcxf+5AkhhHSEFHfNELAQIYQQ0iHU4oOtmbkLCxFCCCEdIZVgoBlhIUIIIaRDmMd3qYjMVViIEEII6Qj0iBCAhQghhJAOYZqaoUdkzsJChBBCSEdQ2zEZKiJzFv7kCSGEdIRUksmqhIUIIYSQDsFdMwRgIUIIIaRDpJMJ9GaSSCc19GWTnT4c0iEi275LCCGEuJFMaPjie85AsVxFb4aXo7kKf/KEEEI6xkXHLur0IZAOw9YMIYQQQjoGCxFCCCGEdAwWIoQQQgjpGCxECCGEENIxWIgQQgghpGOwECGEEEJIx2AhQgghhJCOwUKEEEIIIR2DhQghhBBCOgYLEUIIIYR0DBYihBBCCOkYLEQIIYQQ0jFYiBBCCCGkY8R6+66u6wCA8fHxDh8JIYQQQvwirtviOu5GrAuRiYkJAMDKlSs7fCSEEEIICcrExASGhoZcH6PpfsqVDlGtVrFnzx4MDAxA07Smn2d8fBwrV67Erl27MDg4GOIREis81+2D57p98Fy3F57v9hHVudZ1HRMTE1i2bBkSCXcXSKwVkUQigRUrVoT2fIODg3xTtwme6/bBc90+eK7bC893+4jiXHspIQKaVQkhhBDSMViIEEIIIaRjzIlCJJvN4uMf/ziy2WynD2XWw3PdPniu2wfPdXvh+W4fcTjXsTarEkIIIWR2MycUEUIIIYTEExYihBBCCOkYLEQIIYQQ0jFYiBBCCCGkY8yJQuRzn/scVq9ejVwuh3POOQcPPvhgpw+p69m0aRPOOussDAwMYPHixbjqqqvw3HPPmR6Tz+dx/fXXY8GCBejv78db3/pW7N+/v0NHPDv41Kc+BU3TcOONN8q/43kOl927d+Pd7343FixYgJ6eHpxyyil4+OGH5b/ruo5bbrkFS5cuRU9PDy677DJs27atg0fcnVQqFfz5n/851qxZg56eHqxbtw5/9Vd/ZdpNwnPdHHfffTfe8IY3YNmyZdA0DT/4wQ9M/+7nvB4+fBjvete7MDg4iOHhYbzvfe/D5ORkNAesz3K+9a1v6ZlMRv/KV76iP/300/p1112nDw8P6/v37+/0oXU1r3vd6/SvfvWr+lNPPaU/9thj+pVXXqmvWrVKn5yclI/5gz/4A33lypX65s2b9Ycfflg/99xz9fPPP7+DR93dPPjgg/rq1av1V73qVfqHPvQh+fc8z+Fx+PBh/ZhjjtGvueYa/YEHHtB37Nih/+IXv9C3b98uH/OpT31KHxoa0n/wgx/ojz/+uP7GN75RX7NmjT4zM9PBI+8+PvGJT+gLFizQf/KTn+gvvvii/p3vfEfv7+/X//Ef/1E+hue6OX7605/qH/vYx/Tvfe97OgD9+9//vunf/ZzXyy+/XD/11FP1+++/X7/nnnv09evX6+985zsjOd5ZX4icffbZ+vXXXy//u1Kp6MuWLdM3bdrUwaOafRw4cEAHoN911126ruv60aNH9XQ6rX/nO9+Rj3n22Wd1APp9993XqcPsWiYmJvRjjz1Wv+222/TXvOY1shDheQ6XP/mTP9EvvPBCx3+vVqv6kiVL9L/927+Vf3f06FE9m83q//Ef/9GOQ5w1vP71r9ff+973mv7uLW95i/6ud71L13We67CwFiJ+zuszzzyjA9Afeugh+Zif/exnuqZp+u7du0M/xlndmikWi9iyZQsuu+wy+XeJRAKXXXYZ7rvvvg4e2exjbGwMADB//nwAwJYtW1AqlUznfsOGDVi1ahXPfRNcf/31eP3rX286nwDPc9j86Ec/wplnnom3ve1tWLx4MTZu3Igvf/nL8t9ffPFF7Nu3z3S+h4aGcM455/B8B+T888/H5s2b8fzzzwMAHn/8cdx777244oorAPBcR4Wf83rfffdheHgYZ555pnzMZZddhkQigQceeCD0Y4r10rtWGR0dRaVSwcjIiOnvR0ZGsHXr1g4d1eyjWq3ixhtvxAUXXICTTz4ZALBv3z5kMhkMDw+bHjsyMoJ9+/Z14Ci7l29961t45JFH8NBDDzX8G89zuOzYsQOf//zn8eEPfxh/+qd/ioceegh/+Id/iEwmg6uvvlqeU7vPFJ7vYNx0000YHx/Hhg0bkEwmUalU8IlPfALvete7AIDnOiL8nNd9+/Zh8eLFpn9PpVKYP39+JOd+VhcipD1cf/31eOqpp3Dvvfd2+lBmHbt27cKHPvQh3Hbbbcjlcp0+nFlPtVrFmWeeiU9+8pMAgI0bN+Kpp57CF77wBVx99dUdPrrZxX/+53/iG9/4Br75zW/ipJNOwmOPPYYbb7wRy5Yt47meY8zq1szChQuRTCYbJgj279+PJUuWdOioZhc33HADfvKTn+COO+7AihUr5N8vWbIExWIRR48eNT2e5z4YW7ZswYEDB3D66acjlUohlUrhrrvuwj/90z8hlUphZGSE5zlEli5dihNPPNH0dyeccAJ27twJAPKc8jOldf74j/8YN910E97xjnfglFNOwXve8x780R/9ETZt2gSA5zoq/JzXJUuW4MCBA6Z/L5fLOHz4cCTnflYXIplMBmeccQY2b94s/65arWLz5s0477zzOnhk3Y+u67jhhhvw/e9/H7fffjvWrFlj+vczzjgD6XTadO6fe+457Ny5k+c+AJdeeimefPJJPPbYY/LPmWeeiXe9613y//M8h8cFF1zQMIb+/PPP45hjjgEArFmzBkuWLDGd7/HxcTzwwAM83wGZnp5GImG+BCWTSVSrVQA811Hh57yed955OHr0KLZs2SIfc/vtt6NareKcc84J/6BCt7/GjG9961t6NpvVb731Vv2ZZ57R3//+9+vDw8P6vn37On1oXc0HPvABfWhoSL/zzjv1vXv3yj/T09PyMX/wB3+gr1q1Sr/99tv1hx9+WD/vvPP08847r4NHPTtQp2Z0nec5TB588EE9lUrpn/jEJ/Rt27bp3/jGN/Te3l793//93+VjPvWpT+nDw8P6D3/4Q/2JJ57Q3/SmN3GktAmuvvpqffny5XJ893vf+56+cOFC/aMf/ah8DM91c0xMTOiPPvqo/uijj+oA9M985jP6o48+qr/88su6rvs7r5dffrm+ceNG/YEHHtDvvfde/dhjj+X4bit89rOf1VetWqVnMhn97LPP1u+///5OH1LXA8D2z1e/+lX5mJmZGf2DH/ygPm/ePL23t1d/85vfrO/du7dzBz1LsBYiPM/h8uMf/1g/+eST9Ww2q2/YsEH/0pe+ZPr3arWq//mf/7k+MjKiZ7NZ/dJLL9Wfe+65Dh1t9zI+Pq5/6EMf0letWqXncjl97dq1+sc+9jG9UCjIx/BcN8cdd9xh+/l89dVX67ru77weOnRIf+c736n39/frg4OD+rXXXqtPTExEcryarisxdoQQQgghbWRWe0QIIYQQEm9YiBBCCCGkY7AQIYQQQkjHYCFCCCGEkI7BQoQQQgghHYOFCCGEEEI6BgsRQgghhHQMFiKEEEII6RgsRAghhBDSMViIEEIIIaRjsBAhhBBCSMdgIUIIIYSQjvH/A88Jwld5pWR1AAAAAElFTkSuQmCC" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "\n", + "def transition_fn(z, t, t_next):\n", + " mean = z * (1 + jnp.abs(z) * jnp.sin(2 * jnp.pi * t / 10))\n", + " scale = 0.01 * jnp.ones(np.size(z))\n", + " return tfpd.MultivariateNormalDiag(mean, scale)\n", + "\n", + "\n", + "def observation_fn(z, t):\n", + " mean = jnp.abs(z)\n", + " scale = 0.001 * jnp.ones(np.size(z)) + 0.01 * (2 * jnp.pi * t / 5.)\n", + " return tfpd.MultivariateNormalDiag(mean, scale)\n", + "\n", + "\n", + "n = 1\n", + "\n", + "initial_state_prior = tfpd.MultivariateNormalDiag(jnp.zeros(n), jnp.ones(n))\n", + "\n", + "essm = ExtendedStateSpaceModel(\n", + " transition_fn=transition_fn,\n", + " observation_fn=observation_fn,\n", + " initial_state_prior=initial_state_prior,\n", + " more_data_than_params=False, # if observation is bigger than latent we can speed it up.\n", + " dt=1.\n", + ")\n", + "samples = essm.sample(jax.random.PRNGKey(0), 100)\n", + "\n", + "plt.plot(samples.t, samples.observation)\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:52.766401446Z", + "start_time": "2024-08-14T22:35:52.094140163Z" + } + }, + "id": "5f0b000d8cfe6d2b", + "execution_count": 2 + }, + { + "cell_type": "code", + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGdCAYAAADqsoKGAAAAP3RFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMS5wb3N0MSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8kixA/AAAACXBIWXMAAA9hAAAPYQGoP6dpAACOj0lEQVR4nO3deXxU1d348c+dmcxM9p0sJBBQFJBVEESNglJxqdUGXFAfKCrWPuLD8utTpYtgtYW6Qi2PtrYureCG0SpaWkBQVAQFQUF2wxayACH7MsnM/f1xc29mspFJZkv4vl+veb3IzJk7J5fJ3O+c8z3fo6iqqiKEEEII0c2Zgt0BIYQQQghfkKBGCCGEED2CBDVCCCGE6BEkqBFCCCFEjyBBjRBCCCF6BAlqhBBCCNEjSFAjhBBCiB5BghohhBBC9AiWYHcgUFwuF8ePHyc6OhpFUYLdHSGEEEJ0gKqqVFRUkJ6ejsnU/ljMWRPUHD9+nMzMzGB3QwghhBCdcPToUTIyMtptc9YENdHR0YB2UmJiYoLcGyGEEEJ0RHl5OZmZmcZ1vD1nTVCjTznFxMRIUCOEEEJ0Mx1JHZFEYSGEEEL0CBLUCCGEEKJHkKBGCCGEED3CWZNTI4QQwv+cTif19fXB7oboRsxmMxaLxSflViSoEUII4ROVlZUcO3YMVVWD3RXRzURERJCWlobVau3ScSSoEUII0WVOp5Njx44RERFBcnKyFDkVHaKqKg6HgxMnTpCXl8eAAQPOWGCvPRLUCCGE6LL6+npUVSU5OZnw8PBgd0d0I+Hh4YSFhXH48GEcDgd2u73Tx5JEYSGEED4jIzSiM7oyOuNORmpCmNPpZOPGjRQUFJCWlkZ2djZmsznY3RJCCCFCkgQ1ISo3N5fZs2dz7Ngx476MjAyWLl1KTk5OEHsmhBBChCaZfgpBubm5TJkyxSOgAcjPz2fKlCnk5uYGqWdCCNGzqKrKvffeS0JCAoqisH37dsaPH8+cOXOMNllZWSxZsiRofWyNoii8++67bT7e/Hc4W0hQE2KcTiezZ89udUmkft+cOXNwOp2B7poQQvQ4q1ev5uWXX2bVqlUUFBQwZMgQcnNzefTRR9t8zpkCiu5ow4YNKIpCaWmpz4+9cOFCRowY4fPjtkaCmhCzceNGY4TGZI8mZuwUzFGJxuOqqnL06FE2btwYrC4KIUSPcfDgQdLS0rjkkktITU3FYrGQkJDQoR2hu0qKFPqeBDUhpqCgwPh39KgfEj/+J8SM+XG77YQQItSoqkq1oyEot44W//vJT37CAw88wJEjR1AUhaysLKD9qRu9zY9//GOP5wD885//5MILL8Rut9O/f38eeeQRGhoajMcVReG5557jRz/6EZGRkfzud7/r0PP279/P5Zdfjt1uZ/DgwaxZs6bj/xGN/vGPfzB69Giio6NJTU3l9ttvp7i4GIBDhw4xYcIEAOLj41EUhZ/85CcAuFwuFi1aRL9+/QgPD2f48OGsXLnSOK4+wrNu3TpGjx5NREQEl1xyCXv37gXg5Zdf5pFHHmHHjh0oioKiKLz88ste97+jJFE4xKSlpRn/DkvqA4A5Mr7ddkIIEWpq6p0MfvjfQXnt7347iQjrmS9vS5cu5ZxzzuEvf/kLX375ZYdWl3755Zf06tWLl156iWuuucZ4zsaNG5k2bRp//OMfyc7O5uDBg9x7770ALFiwwHj+woULWbx4MUuWLMFisZzxeS6Xi5ycHFJSUti8eTNlZWWdypWpr6/n0Ucf5fzzz6e4uJh58+bxk5/8hA8//JDMzEzefvttJk+ezN69e4mJiTFqDS1atIhXX32V559/ngEDBvDJJ59w5513kpyczBVXXGEc/1e/+hVPPfUUycnJ3Hfffdx111189tln3HrrrezcuZPVq1ezdu1aAGJjY73uf0dJUBNisrOzycjIID8/H0ucFriYwpuGQRVFISMjg+zs7GB1UQgheoTY2Fiio6Mxm82kpqZ26DnJyckAxMXFeTznkUce4aGHHmL69OkA9O/fn0cffZRf/OIXHkHN7bffzowZM4yf77rrrnaft3btWvbs2cO///1v0tPTAfj973/Ptdde69Xvetdddxn/7t+/P3/84x+56KKLqKysJCoqioSEBAB69epFXFwcAHV1dfz+979n7dq1jBs3znjup59+yp///GePoOZ3v/ud8fNDDz3E9ddfT21tLeHh4URFRWGxWDp8jrtCgpoQYzabWbp0KVOmTCEsQXsDm+xRQFNRqyVLlki9GiFESAsPM/PdbycF7bUDbceOHXz22WfGlBJoCz9qa2uprq4mIiICgNGjR3v1vN27d5OZmWkENIARYHhj69atLFy4kB07dnD69GlcLhcAR44cYfDgwa0+58CBA1RXV/ODH/zA436Hw8HIkSM97hs2bJjxb30mobi4mD59+njd166QoCYE5eTk8OKKlSzcbgO0hGHQ6tQsWbJE6tQIIUKeoigdmgLqKSorK3nkkUda/Xx2L/sfGRnZqed1RVVVFZMmTWLSpEksX76c5ORkjhw5wqRJk3A4HG0+r7KyEoAPPviA3r17ezxms9k8fg4LCzP+rX8B1wOnQDp73nHdzAUXj4ftmwCITkpl/fr1UlFYCCFCQFhYWIuyGhdeeCF79+7l3HPP9epYZ3reoEGDOHr0qFFZHuCLL77w6jX27NnDqVOnWLx4MZmZmQB89dVXHm303bHdf6/Bgwdjs9k4cuSIx1STt6xWa8DKkEhQE6LyTlYZ/651KmRffgVmk+ypIoQQwZaVlcW6deu49NJLsdlsxMfH8/DDD/PDH/6QPn36MGXKFEwmEzt27GDnzp089thjbR7rTM+bOHEi5513HtOnT+eJJ56gvLycX/3qV171t0+fPlitVp599lnuu+8+du7c2aIOT9++fVEUhVWrVnHdddcRHh5OdHQ0P//5z5k7dy4ul4vLLruMsrIyPvvsM2JiYow8oI6cr7y8PLZv305GRgbR0dEtRnp8RZZ0h6hDbkENQEWt1DMQQohQ8NRTT7FmzRoyMzON3JJJkyaxatUq/vOf/3DRRRdx8cUX88wzz9C3b992j3Wm55lMJt555x1qamoYM2YM99xzj0f+TUckJyfz8ssv89ZbbzF48GAWL17Mk08+6dGmd+/eRrJzSkoKs2bNAuDRRx/lN7/5DYsWLWLQoEFcc801fPDBB/Tr16/Drz958mSuueYaJkyYQHJyMq+99ppX/feGonZ0QX83V15eTmxsLGVlZcTExAS7O2d0//JtfPBtUy2aDT8fT1ZSZDvPEEKI4KmtrSUvL49+/fr5LBdEnD3ae/94c/2WkZoQlddspKa0RkZqhBBCiPZIUBOCVFXl0CktqImwaonBpdVtZ6gLIYQQQoKakHSiso5qhxOTAheka0NtZTJSI4QQQrRLgpoQdOhkNQDpceEkR2sZ4qXVEtQIIYQQ7ZGgJgTpK5/6JUUSG67VDpCgRgjRHZwla0+Ej/nqfSNBTQjS82n6JkYQF6FVaSytkZwaIUTo0guDtlehVoi2VFdrMxTulYk7Q4rvhSA9qMlKjMTp0qJXyakRQoQyi8VCREQEJ06cICwsDJNJvjOLM1NVlerqaoqLi4mLi+ty1XwJakJQXmNOTb+kSE5W1gFQJtNPQogQpigKaWlp5OXlcfjw4WB3R3QzzXc97ywJakKMqqocNqafIql3aiM1UqdGCBHqrFYrAwYMkCko4ZWwsDCf7WsoQU2IOVHRtJy7T0KEMVIjdWqEEN2ByWSSisIiaGTSM8TolYR7x4djtZiIDdeSpiSnRgghhGifBDUh5vApLZ8mK1Hb58lY/VRdL0slhRBCiHZIUBNi8txWPgHENdapaXCpVDmcQeuXEEIIEeokqAkxepKwviO3PcyE1aL9N0lejRBCCNE2CWpCTNNy7ghAWyYZJ3k1QgghxBlJUBNCmi/n1ul5NVKrRgghhGibBDUhxH05d2Z8hHG/nlcjtWqEEEKItklQE0L05dwZ8RFGHg1ATHjTCighhBBCtK5TQc2yZcvIysrCbrczduxYtmzZ0mbbF154gezsbOLj44mPj2fixIke7evr63nwwQcZOnQokZGRpKenM23aNI4fP+5xnKysLBRF8bgtXry4M90PSU6nk9WfbwMgRqnD6Wxa6SSbWgohhBBn5nVQ88YbbzBv3jwWLFjAtm3bGD58OJMmTaK4uLjV9hs2bGDq1KmsX7+eTZs2kZmZydVXX01+fj6g7cy5bds2fvOb37Bt2zZyc3PZu3cvP/rRj1oc67e//S0FBQXG7YEHHvC2+yEpNzeXrKwslrywHIDP//NPsrKyyM3NBWhKFJaRGiGEEKJNXm+T8PTTTzNz5kxmzJgBwPPPP88HH3zAiy++yEMPPdSi/fLlyz1+/utf/8rbb7/NunXrmDZtGrGxsaxZs8ajzZ/+9CfGjBnDkSNH6NOnj3F/dHS0Tza8CiW5ublMmTIFVVVJGp0OQMPp4+Tn5zNlyhRWrlxJXNxQQKafhBBCiPZ4NVLjcDjYunUrEydObDqAycTEiRPZtGlTh45RXV1NfX09CQkJbbYpKyvTljLHxXncv3jxYhITExk5ciRPPPEEDQ0NbR6jrq6O8vJyj1uocTqdzJ4926gUHBbfFNTo982ZM4cYuxZ7yvSTEEII0TavRmpOnjyJ0+kkJSXF4/6UlBT27NnToWM8+OCDpKenewRG7mpra3nwwQeZOnUqMTExxv3/8z//w4UXXkhCQgKff/458+fPp6CggKeffrrV4yxatIhHHnmkg79ZcGzcuJFjx44ZP1vi0gCoP63lE6mqytGjRyk4fBCQkRohhBCiPQHdpXvx4sW8/vrrbNiwodVdXOvr67nllltQVZXnnnvO47F58+YZ/x42bBhWq5Wf/vSnLFq0CJvN1uJY8+fP93hOeXk5mZmZPvxtuq6goMD4tzk6CZPVjupsoKHMMz+prqIEiJTie0IIIUQ7vApqkpKSMJvNFBUVedxfVFR0xlyXJ598ksWLF7N27VqGDRvW4nE9oDl8+DAfffSRxyhNa8aOHUtDQwOHDh3i/PPPb/G4zWZrNdgJJWlpaca/w5K03KH608fB5TmtlpmSCN/XSlAjhBBCtMOrnBqr1cqoUaNYt26dcZ/L5WLdunWMGzeuzec9/vjjPProo6xevZrRo0e3eFwPaPbv38/atWtJTEw8Y1+2b9+OyWSiV69e3vwKISU7O5uMjAwURSEsqS8A9SePGI8rikJmZibjLxkDyPSTEEII0R6vp5/mzZvH9OnTGT16NGPGjGHJkiVUVVUZq6GmTZtG7969WbRoEQB/+MMfePjhh1mxYgVZWVkUFhYCEBUVRVRUFPX19UyZMoVt27axatUqnE6n0SYhIQGr1cqmTZvYvHkzEyZMIDo6mk2bNjF37lzuvPNO4uPjfXUuAs5sNrN06VKmTJmCVR+paQxqFEUBYMmSJSREaVN1NfVOauud2MPMwemwEEIIEcK8DmpuvfVWTpw4wcMPP0xhYSEjRoxg9erVRvLwkSNHMJmaBoCee+45HA4HU6ZM8TjOggULWLhwIfn5+bz33nsAjBgxwqPN+vXrGT9+PDabjddff52FCxdSV1dHv379mDt3rkfOTHeVk5PDypUrmfuhlhxcf/IwABkZGSxZsoScnBxcLhWTAi4VymvqJagRQgghWqGo+trhHq68vJzY2FjKysrOmK8TaKqqMmTBv6lyOJlzfiUj+qWQnZ2N2dwUvIz87X84XV3Pf+Zeznkp0UHsrRBCCBE43ly/A7r6SbTueFktVQ4nYWaF+6fdTJi5ZapTXISV09X1klcjhBBCtEE2tAwB+4oqAOiXFNlqQAMQa2xqKQX4hBBCiNZIUBMC9jcGNQPamVbSN7WUZd1CCCFE6ySoCQH7iioBOK9X20GNPlIjQY0QQgjROglqQoA+UnNeSlSbbeKM6ScJaoQQQojWSFATZC6Xyv5ibaSmvemn2AgrIJtaCiGEEG2RoCbI8ktrqHY4sZpNZCVGtNlORmqEEEKI9klQE2T7i7Wpp/7JkVjaWPkEkigshBBCnIkENUGmJwm3N/UETUGNjNQIIYQQrZOgJsj0GjXn9Wo7SRggNlxyaoQQQoj2SFATZPs7OFITKzk1QgghRLskqAkil0vlQOPKp/aWc0PT9FNFbQNO11mxXZcQQgjhFQlqgujY6Rpq6p1YLSb6JLS98gmaRmpA26lbCCGEEJ4kqAkiPZ+mf1L7K58Awswmomza/qOlEtQIIYQQLUhQE0T7ivVKwu3n0+hkU0shhBCibRLUBJGeJHymfBqdsaxbRmqEEEKIFizB7sDZyOl0snHjRjbvKQfMnJMc2aHnGQX4ZAWUEEII0YKM1ARYbm4uWVlZTLjyKo6VNwBw7603kJube8bnyvSTEEII0TYJagIoNzeXKVOmcOzYMSyxKZjCbLjq68jfu4MpU6acMbBpKsAnIzVCCCFEcxLUBIjT6WT27NmoqlZjJiypDwANJcdQXU4A5syZg9PpbPMYsv+TEEII0TYJagJk48aNHDt2zPhZD2ocJ48AoKoqR48eZePGjW0eQ9+pW3JqhBBCiJYkqAmQgoICj5/DEtIBaDh1rN127mT1kxBCCNE2CWoCJC0tzeNnkz0GAGd1abvt3Bk5NZIoLIQQQrQgQU2AZGdnk5GRgaIoAJjCtdo0rlqtVo2iKGRmZpKdnd3mMWLsZgCOFp1iw4YN7ebfCCGEEGcbCWoCxGw2s3TpUkALYEz2xqCmpsIIdJYsWYLZbG71+bm5udz64xsAKCypYMKECWRlZXVoKbgQQghxNpCgJoBycnJYuXIlvXv3xmTXtkZw1laQkZHBypUrycnJafV5+lLwgsMHADCFa8/Nz8/v0FJwIYQQ4mwgQU2A5eTkkJeXhz06AYDXXv4reXl5bQY07kvBXTWNU1UmM4o13Fgefqal4EIIIcTZQIKaIHC4oEGLR7jmysvbnHICz6XgakMdaoOWJKyP9HRkKbgQQghxNpCgJghKG+vMhJkVIqxtBzTQcom3q64KAJMtot12QgghxNlGgpog0IOa2HCrkSTcluZLvF11NQCYrBHtthNCCCHONhLUBEFpjTaFFBt+5k3Smy8Fbz5S05Gl4EIIIcTZQIKaINC3OYiLsJ6xbfOl4KqjWvu3LaJDS8GFEEKIs4UENUGgb3Og7+V0Ju5LwV11WlBjskaccSm4EEIIcTaRoCYIjJyaiI4FNaAFNocOHWLieG2a6Wez57a7FFwIIYQ420hQEwRlxkjNmaef3JnNZs7t2xuAlN59ZcpJCCGEcCNBTRCUNSYKx3kxUqOLsmnJxRW1DT7tkxBCCNHddSqoWbZsGVlZWdjtdsaOHcuWLVvabPvCCy+QnZ1NfHw88fHxTJw4sUV7VVV5+OGHSUtLIzw8nIkTJ7J//36PNiUlJdxxxx3ExMQQFxfH3XffTWVlZWe6H3SlRqJwJ4IauxbUVNZJUCOEEEK48zqoeeONN5g3bx4LFixg27ZtDB8+nEmTJlFcXNxq+w0bNjB16lTWr1/Ppk2byMzM5OqrryY/P99o8/jjj/PHP/6R559/ns2bNxMZGcmkSZOora012txxxx3s2rWLNWvWsGrVKj755BPuvffeTvzKwddUp8b7oCbarj2nUkZqhBBCCE+ql8aMGaPef//9xs9Op1NNT09XFy1a1KHnNzQ0qNHR0eorr7yiqqqqulwuNTU1VX3iiSeMNqWlparNZlNfe+01VVVV9bvvvlMB9csvvzTa/Otf/1IVRVHz8/M79LplZWUqoJaVlXWovT9ds+QTte+Dq9QNe4u9fu47246pfR9cpd7+wiY/9EwIIYQILd5cv70aqXE4HGzdupWJEyca95lMJiZOnMimTZs6dIzq6mrq6+tJSNA2dMzLy6OwsNDjmLGxsYwdO9Y45qZNm4iLi2P06NFGm4kTJ2Iymdi8eXOrr1NXV0d5ebnHLVSUVTfm1HRipEbPqZGRGiGEEMKTV0HNyZMncTqdpKSkeNyfkpJCYWFhh47x4IMPkp6ebgQx+vPaO2ZhYSG9evXyeNxisZCQkNDm6y5atIjY2FjjlpmZ2aH+BYJRp6YLOTUVklMjhBBCeAjo6qfFixfz+uuv884772C32/36WvPnz6esrMy4HT161K+v11F1DU6qHU6gczk1MlIjhBBCtO7Mmw+5SUpKwmw2U1RU5HF/UVERqamp7T73ySefZPHixaxdu5Zhw4YZ9+vPKyoq8tiUsaioiBEjRhhtmiciNzQ0UFJS0ubr2mw2bDZbh3+3QNFr1ChKU9KvN6Jl9ZMQQgjRKq9GaqxWK6NGjWLdunXGfS6Xi3Xr1jFu3Lg2n/f444/z6KOPsnr1ao+8GIB+/fqRmprqcczy8nI2b95sHHPcuHGUlpaydetWo81HH32Ey+Vi7Nix3vwKQafv+xRjD8Nsan+H7tboIzXVDidOl+rTvgkhhBDdmVcjNQDz5s1j+vTpjB49mjFjxrBkyRKqqqqYMWMGANOmTaN3794sWrQIgD/84Q88/PDDrFixgqysLCMHJioqiqioKBRFYc6cOTz22GMMGDCAfv368Zvf/Ib09HRuuukmAAYNGsQ111zDzJkzef7556mvr2fWrFncdtttpKen++hUBEZZF/JpoCmnBrTRms5MYQkhhBA9kddBza233sqJEyd4+OGHKSwsZMSIEaxevdpI9D1y5AgmU9MA0HPPPYfD4WDKlCkex1mwYAELFy4E4Be/+AVVVVXce++9lJaWctlll7F69WqPvJvly5cza9YsrrrqKkwmE5MnT+aPf/xjZ37noDIK73UyGLFZzFgtJhwNLipq6yWoEUIIIRopqqqeFXMY5eXlxMbGUlZWRkxMTND6sXLrMX7+1g4uPy+Zv981plPHGPXoGk5VOVg9J5uBqcH7XYQQQgh/8+b6LXs/BVhpF2rU6IytEmQFlBBCCGGQoCbAuppTA26bWsoKKCGEEMLgdU6N6Jqu5tRA8GrVOJ1ONm7cSEFBAWlpaWRnZ2M2mwPaByGEEKItEtQEmF5NODbC2uljBKNWTW5uLrNnz+bYsWPGfRkZGSxdupScnJyA9UMIIYRoi0w/BZgvcmoCvVN3bm4uU6ZM8QhoAPLz85kyZQq5ubkB6YcQQgjRHglqAsynOTW19T7pU3ucTiezZ8+mtUVy+n1z5szB6XT6vS9CCCFEeySoCTA9qOlKfZlAbmq5ceNGjxGaxGtnEz/xp8bPqqpy9OhRNm7c6Pe+CCGEEO2RoCbAjERhH4zUBGL6qaCgwPi3KSKOqGE/IGbUDZijEttsJ4QQQgSDBDUB5HSplNfqIzXdI1HYfZNRc0RT0SNr+nltthNCCCGCQYKaAKqorUdPTenS9JMtcEFNdnY2GRkZKIqCyR5t3G9L04IaRVHIzMwkOzvb730RQggh2iNBTQDpU0+RVm3/ps5qShT2f1BjNptZunSp9m+3kRpb2vkoirbL+JIlS6RejRBCiKCToCaASo2VT52fegK3Jd0BqlOTk5PDypUrSUzrY9xnTT2XjMxMVq5cKXVqhBBChAQpvhdAeo2aru6sHR2EvZ9ycnIojB/K4//eB4DJFsGHm3YwOD0uYH0QQggh2iMjNQHkixo1ENg6Ne7KmgVRO/MrAvr6QgghRHskqAkgXyznhqY6NVUOJ05Xy6J4/lJapfXfatbeNl8fLQ3YawshhBBnIkFNAOlBTVeWc0PTSA1AlSNwU1CnG6fPxvRLAGCHBDVCCCFCiAQ1AeSr6SebxUSYWVt5FMi8Gj0oG39+MgB7iyqoccj2CEIIIUKDBDUBVFrT9c0sQasNE8haNTp9pGZQWgwpMTacLpWdx8sC9vpCCCFEeySoCaCy6q7v+6TTl3UHolaN7rRb/4dnxAEyBSWEECJ0SFATQKU+mn6CwFYVBm3jyrLGkab4SCsj+sQBkiwshBAidEhQE0BNdWq6ligMbjt1B2hZd5XDSb1TW2kVHxHGCBmpEUIIEWIkqAkgXyUKA0QHcKdugNNVWkBmtZgIDzMzNCMWRYFjp2s4WVkXkD4IIYQQ7ZGgJkBUVfVZnRpoGqkJ1PST3vf4iDAURSHaHsa5yVGAjNYIIYQIDRLUBEiVw0lDY6G8OF9MPwVwU0toWvkU77Zv1YjMOAC2S1AjhBAiBEhQEyB6Po3VYsIe1vXTHuiRGj2ocR9lGt4Y1Kzfkcdrr73Ghg0bcDqlbo0QQojgkKAmQIypp3Bt+qarYvSdugM0UtPU/6aRmpL92wDYkV/G7bffwYQJE8jKyiI3NzcgfRJCCCHcSVATIOU+TBKGwC/pNnJqIrX+5+bmMmfGLbjq6zDbo7DEpwGQn5/PlClTJLARQggRcBLUBIhRo8YH+TTgllMT8OknK06nk9mzZ6M6G3AUHQDAln4+oCVEA8yZM0emooQQQgSUBDUBYmxm6auRmgDXqSk1EoXD2LhxI8eOHQPAcXwfAJFDrjTaqqrK0aNH2bhxY0D6JoQQQoAENQHjq32fdAGvU2MsR7dSUFBg3F+x/V+46usIzxpJ1PBJHs9xbyeEEEL4mwQ1AeLLfZ8gGHVqmpZ0p6WlGfc3nD5O6SevaI9deQ+W2BTjsaKiIpmCEkIIETAS1ASILwvvgVuicIBHauIjwsjOziYjI8NYxVXx1fvUHvkWkzWcxOvmANr9c+fOldVQQgghAkaCmgDRp59iI3yTKKzv0l3paMDVWNTPn9wThc1mM0uXLgVoDGxUTn24BJejBnufoUSPvsF4nqyGEkIIESgS1ASIe50aX4hunH5SVaiu9+8UT4PTZVQu1keacnJyWLlyJb1799balBVx+qO/aW0un44lIaOxf7IaSgghRGBIUBMgvtzMEsBmMWExadM8/p6C0vsOnkFZTk4Ohw4d4plnntH6sWM1NXnbMIXZSLp+Lvo0lKyGEkIIEQgS1ASA0+mk6HQlAAe/+8YnIxaKogRsWbeeTxNtt2Axe75lzGYzKSlNycGn/rUUV30ttvTzCUvM8Ggrq6GEEEL4U6eCmmXLlpGVlYXdbmfs2LFs2bKlzba7du1i8uTJZGVloSgKS5YsadFGf6z57f777zfajB8/vsXj9913X2e6H1C5ublkZWVxqqIGgJnT7/BZ8mygCvCVtrKZpTv31VDOilPUnzwKgCWhd5vthBBCCF/zOqh54403mDdvHgsWLGDbtm0MHz6cSZMmUVxc3Gr76upq+vfvz+LFi0lNTW21zZdffklBQYFxW7NmDQA333yzR7uZM2d6tHv88ce97X5A5ebmMmXKFI4VFGGy2gFw1lT4LHk2UCug3Fc+tab5aqiG0/kAhMVrQY2iKGRmZpKdne3XfgohhDi7eR3UPP3008ycOZMZM2YwePBgnn/+eSIiInjxxRdbbX/RRRfxxBNPcNttt2Gz2Vptk5ycTGpqqnFbtWoV55xzDldccYVHu4iICI92MTEx3nY/YIytBFQVsz0KANXlRHVU+yx5NjpAtWrcVz61pvlqqPrTxwGwJKQbgc6SJUswm81+7acQQoizm1dBjcPhYOvWrUycOLHpACYTEydOZNOmTT7pkMPh4NVXX+Wuu+5qsZv18uXLSUpKYsiQIcyfP5/q6uo2j1NXV0d5ebnHLZDctxIwhWvBl6u20njcF8mz0QHaqbvsDCM14LkaqqFEC2rC4tPJyMhg5cqV5OTk+LWPQgghhMWbxidPnsTpdHokhgKkpKSwZ88en3To3XffpbS0lJ/85Cce999+++307duX9PR0vvnmGx588EH27t3b5hTOokWLeOSRR3zSp85wT4o1RycC4KwsabedtwKVU3OmkRpdTk4ON954I6+8v4HfflFL6oBhbP1HnozQCCGECAivgppA+Nvf/sa1115Lenq6x/333nuv8e+hQ4eSlpbGVVddxcGDBznnnHNaHGf+/PnMmzfP+Lm8vJzMzEz/dbwZ96RYix7UVJxqt523jK0SApRT05Hl6GazmclXX85vv/gPpXUqNQ0qURLTCCGECACvpp+SkpIwm80UFRV53F9UVNRmErA3Dh8+zNq1a7nnnnvO2Hbs2LEAHDhwoNXHbTYbMTExHrdAck+eNUdpQU1DZVNQ44vkWWNTyzr/Luk+0+qn5mIjwkiI1NoeOlnlt34JIYQQ7rwKaqxWK6NGjWLdunXGfS6Xi3Xr1jFu3Lgud+all16iV69eXH/99Wdsu337diB0lwm7J8+ao5OAppEaXyXPGtNPfh+p0aefOl44sF9SJAB5EtQIIYQIEK9XP82bN48XXniBV155hd27d/Ozn/2MqqoqZsyYAcC0adOYP3++0d7hcLB9+3a2b9+Ow+EgPz+f7du3txhhcblcvPTSS0yfPh2LxXNW7ODBgzz66KNs3bqVQ4cO8d577zFt2jQuv/xyhg0b1pnfOyD05NmoZG0qzVlxEsBnybNG8T2/16nRE4U7vm9VVqIW1MhIjRBCiEDxOqfm1ltv5cSJEzz88MMUFhYyYsQIVq9ebSQPHzlyBJOpKVY6fvw4I0eONH5+8sknefLJJ7niiivYsGGDcf/atWs5cuQId911V4vXtFqtrF27liVLllBVVUVmZiaTJ0/m17/+tbfdD7icnBxeOPIJuwsr+H//fQ/jz59Pdna2T5JnA1enxrvpJ4D+yTJSI4QQIrA6lSg8a9YsZs2a1epj7oEKaNWC9bos7bn66qvbbJeZmcnHH3/sdT9DRVFFHQC3/mgSg9J8l9tjLOkO0EiNN9NP+khN3ikJaoQQQgSG7P3kZ7X1TkqqtJGO1Bi7T48dHYDVTzUOJ3UNLkByaoQQQoQ2CWr8rLhcG6WxWUw+26FbZ0w/+XGkRp96spgU4/U6IispAtBGeU43BnVCCCGEP0lQ42eF5bUApMbaW1RI7qpA7NLtXnjPm/5HWC3GyJRMQQkhhAgECWr8TA9qUnw89QTudWoaOpS31BmlHdgioS36aI2sgBJCCBEIEtT4WVFZ40iNH4IafaTGpUK1o/MbY7anMyufdP2StI08Ja9GCCFEIEhQ42cFZU3TT74WHmbGbNKmhPyVV+PNFgnN9WscqZGgRgghRCBIUONnReX+G6lRFMXvVYXLZKRGCCFENyFBjZ+5Jwr7g79XQPlipObQySq/5fwIIYQQOglq/KywzH+JwuD/WjXuq5+8lZkQgUmBKoeTE40FCIUQQgh/kaDGj1wutWn6ye8jNf5Z1t2V1U82i5ne8eGATEEJIYTwPwlq/OhUlYMGl4qiQK9om19eQ18BVR6CIzUgeTVCCCECR4IaP9JHaZKibISZ/XOqI63axpiffL6FDRs24HT6dml3V0ZqAPolNq6AkgJ8Qggh/EyCGj8q9GONGoDc3Fzez30LgFffWMmECRPIysoiNzfXZ69h1KmJ7OxITeMeUCckqBFCCOFfEtT4UYEfqwnn5uYyZcoUKkqKAVCsWu5Kfn4+U6ZM8Ulg43KplNV0fvUTQFZjUHNIRmqEEEL4mQQ1fqRXE07zcZKw0+lk9uzZqKqK6qgBwGTTpnn0pdNz5szp8lRUeW09+krsuPDOjdT0b8yp+f5EJctXvOaXKTIhhBACJKjxK3/VqNm4cSPHjh0DwFVXDYDJGmE8rqoqR48eZePGjV16Hb1GTaTVjNXSubfK5g2rUZ0NNLhg+s9m+2WKTAghhAAJavyqyE/TTwUFBca/XXWVAJjCo9tt1xldXfmUm5vLrTdPoaFU60dYQm/At1NkQgghhE6CGj8q8FOicFpamvFvZ1UpAObI+HbbecvpdPLJF1sBsLjqvJ4ycp8iqy85rh0nPh3w7RSZEEIIoZOgxo+K/LSZZXZ2NhkZGSiKgrPqNADmiDjjcUVRyMzMJDs7u1PHz83NJSsri1//dhEAe3Z85fWUkfsUWf3pfKBppAZ8N0UmhBBC6CSo8ZOqugYqGvdj8nVQYzabWbp0KQCu6lIATBExoJhQFG3X7iVLlmA2m70+tr6q6tixY5jCtSRfV02F11NG7lNfDfpITVzLkaOuTpEJIYQQOglq/ERPEo6yWYytDHwpJyeHlStXkhYfjaq6UExmTOExZGRksHLlSnJycrw+pvuUEYApPEa7v6bC6ykj96mvhnJt2bklJrnddkIIIURXSFDjJ00bWfpnewTQAptDed8Ta9P+G196bSV5eXmdCmjAc8oIwGzXko9dtRWAd1NG7lNkRlAT28t4vKtTZEIIIURzEtT4SaGf8mmaM5vNpCdowUfWwGGdmnLSNZ8KMkcnAhh5O221a6tfxhRZxUkATLZIFFtkl6fIhBBCiNZIUOMnRo2amHC/v1Zy42aZJyrqunSc5lNBltgUABrKitpt1xZ9iiy9VxLO6jLtmDHJXZoiE0IIIdoiQY2fFBmF9/w3/aRLitJe42Rl14Ia9ykjaJou0oOazkwZ5eTkcOjQIfqnakvO//CnF7o0RSaEEEK0RYIaP/FXjZrW+Gqkxn3KyBwejcmm7dvkLCvu0pSR2WxmYKaWJJyYea5MOQkhhPALCWr8xF/VhFuTFKVV/O3qSA24rao6dwgAzsrTqA2OLk8Zpcc1brh5uqbLfRRCCCFa4/u1xgJoShROiw1gTo0PghrQAhv7uWP57xXbyeoVzT/Wryc7O7tLIyy9G4OaY6US1AghhPAPCWr8oN7pMgKMlEDm1FQ4fHbM/FKt/yMG9GH8+JFdPl5GvBbUHJegRgghhJ/I9JMfnKioQ1XBYlJIivR/UOPrkRqAY6e13b/1YKSrZPpJCCGEv0lQ42NOp5MPN3wOQIwVVNXl99fUR2pOVzuod/rm9Y41Bh++Cmr06afiijrqGmQTSyGEEL4nQY0P6RtBzn7wYQCOH/zO640gOyM+worZpKCqUFLlmymopqAmwifHS4i0Yg/T3m56vpEQQgjhSxLU+Ij7RpBGJd6KU15vBNkZZpNCQqS2Aqqry7qhcTuExumnTB+N1CiKIlNQQggh/EqCGh9ovhGkOTpJu7/ylNcbQXZWcpTv8mpOV9dT7dD6qgciviAroIQQQviTBDU+0HwjSEtU00gNeLcRZGclResroLoe1OhJwr2ibdjDfFcoT1ZACSGE8CcJanyg+QaPlsQMoOWeSR3ZCLKzfDlSc7RECzoyE3yTT6NLj5XpJyGEEP4jQY0PuG/waLJFYk3pD0Ddse/abOdrSdGNVYV9UKvG18u5db0bj5cvIzVCCCH8oFNBzbJly8jKysJutzN27Fi2bNnSZttdu3YxefJksrKyUBSFJUuWtGizcOFCFEXxuA0cONCjTW1tLffffz+JiYlERUUxefJkioqKWhwrGNw3grRlXoCimKg/dRRn1WmgcxtBesuXIzW+Xs6t03NqZPpJCCGEP3gd1LzxxhvMmzePBQsWsG3bNoYPH86kSZMoLi5utX11dTX9+/dn8eLFpKamtnncCy64gIKCAuP26aefejw+d+5c3n//fd566y0+/vhjjh8/HjI7PbtvBGnvMwyA2iM7Abq0EaQ3kn2YU9O08snH009GUFOLy6X69NhCCCGE10HN008/zcyZM5kxYwaDBw/m+eefJyIighdffLHV9hdddBFPPPEEt912GzZb29V1LRYLqampxi0pKcl4rKysjL/97W88/fTTXHnllYwaNYqXXnqJzz//nC+++MLbX8Ev9I0go8+5EIDao98CdHkjyI7yz0iNb4Oa1Fg7JgUcThcnq3xX/VgIIYQAL4Mah8PB1q1bmThxYtMBTCYmTpzIpk2butSR/fv3k56eTv/+/bnjjjs4cuSI8djWrVupr6/3eN2BAwfSp0+fNl+3rq6O8vJyj5u/XXnNDzEl9gHgmYd+xvr168nLywvIiJKx+qmLQY2qqn7LqQkzm0ht3LU81JKFnU4nGzZs4LXXXmPDhg1+XX4vhBDCP7wKak6ePInT6SQlJcXj/pSUFAoLCzvdibFjx/Lyyy+zevVqnnvuOfLy8sjOzqaiogKAwsJCrFYrcXFxHX7dRYsWERsba9wyMzM73b+O2pJXgqrCOcmR3DvtNsaPH+/XKSd3+khNaXU9jobOb5VwqspBbb0LRYG0OLuvumcwCvCFUF6NXgl6woQJ3H777UyYMCEglaCFEEL4Vkisfrr22mu5+eabGTZsGJMmTeLDDz+ktLSUN998s9PHnD9/PmVlZcbt6NGjPuxx6774vgSAi/sn+v21mosND8Ni0vJ3TnVhaudoiTZKkxpjx2bxfUDWO8Rq1bhXgnYXiErQQgghfMuroCYpKQmz2dxi1VFRUVG7ScDeiouL47zzzuPAgQMApKam4nA4KC0t7fDr2mw2YmJiPG7+9sX3WrG9YAQ1JpNibGzZla0S/LXySdc7hLZKaF4JWglryvkKVCVoIYQQvuNVUGO1Whk1ahTr1q0z7nO5XKxbt45x48b5rFOVlZUcPHjQqOsyatQowsLCPF537969HDlyxKev2xWl1Q52F2p5O2P7JwSlD0atmi7k1fgrSVgXStNP7pWgo0ZeT+bsN4gdd6vxeCAqQQshhPAdi7dPmDdvHtOnT2f06NGMGTOGJUuWUFVVxYwZMwCYNm0avXv3ZtGiRYCWXPzdd98Z/87Pz2f79u1ERUVx7rnnAvDzn/+cG264gb59+3L8+HEWLFiA2Wxm6tSpAMTGxnL33Xczb948EhISiImJ4YEHHmDcuHFcfPHFPjkRXeWeT9Mr2ve5KB2R7IORGl9vZNlcUwG+4O3U7XQ62bhxI2+//TYAUSOuJfHqnwFg6zMUNr3h0d6flaCFEEL4jtdBza233sqJEyd4+OGHKSwsZMSIEaxevdpIHj5y5AgmU9MA0PHjxxk5cqTx85NPPsmTTz7JFVdcwYYNGwA4duwYU6dO5dSpUyQnJ3PZZZfxxRdfkJycbDzvmWeewWQyMXnyZOrq6pg0aRL/93//19nf2+f0fJpx5wR+6kmnTz+drOx8VWF/j9Q0TT9V++X4Z5Kbm8vs2bObRmiGXU3ipPuNx82R8S2e489K0EIIIXzH66AGYNasWcyaNavVx/RARZeVlWXkJ7Tl9ddfP+Nr2u12li1bxrJlyzrcz0DaFMR8Gp1egK9rOTX+Wc6t04Oa8toGKmrribaH+eV1WqMnBevvx8ghV5JwjfY+rt6/mYgBYzFHxhntFUUhIyPDr5WghRBC+E5IrH7qzpxOJ6v+s549BWUAXNQ3Lmh9MYKaTubUuFyqMVLj680sdZE2C3ERWiBzPIBTUM2TgiMHjyfxujkoionyre9z6l9aRWhzRCyYLAGrBC2EEMJ3JKjpAr2+yS2zfomKguPkEUYPOT9oy4C7uvrpZGUdjgYXJkWr/usvxm7dpYGbgnJPCramnkvidbNRFBMV2z7g9No/46qpQHXWA2COjAtYJWghhBC+I0FNJ7nXN7H3GQpA3ZFvg1rfJLmLVYWPNo7SpMWGE2b231sjGMnCerKvYo0g6UcPopjDqNr7GSVrnm9soeKsKgXgL/94PWCVoIUQQviOBDWd0Hwqw56pBTW1R78Nan2Tro7U6Pk0vf2UT6MLdK0ap9Np1FZKvGYWYfFpNJQWcupffwSa8r30XdX7DBgiU05CCNENSVDTCe5TGSZ7FGG9sgCoPaJtYhms+ib6SE1FbQO19d4HVEY+jZ9WPunSYrV+fr5jj9/3WdKnCOfOnUvU8GuIHHQ5qrOBE+89jlpXZbRTFAWrSxs58sWmoEIIIQJPgppOcK9bYsscgqKYcJw8jKu6rM12gRBjt2BtnDbydgrK6XSyZddBABrKivwWaOTm5vLY/P8HwJad+/26z5L7FGFYchYJE+8F4PTHr+Ao2Ge005OCx424AIDicglqhBCiO5KgphPc65bUnzhM6Sf/oPLrf7XbLhAURXHLq+l4rRp9NGP1xi0A/O2Pj/sl0NCDjOK83QBYYnsB/tlnyX2KUAmzk3zjgygWK9UHv6Tiy3c92upJwWOGng/AicrgFQYUQgjReRLUdEJ2djYZGRkoikJDaQFlm96gYtsq43FFUcjMzAxKfZOkKG2rhI7m1biPZlhitX20GsqKfB5ouAcZDeUnADBHJYDJ4pc8JPcpwpjRNxKWmElDxSlOffAM7nk0zzzzjJEU7Is6P0IIIYJHgppOMJvNLF2q1TXRpy50wa5v4s0KKM+EZwVLrFbBuaGsyOeBhnuQ4aouxVVfh6KYsERrxQp9nYfkPvVn7z8KgLLPVuCqKfdol5KSYvw/SVAjhBDdmwQ1nZSTk8PKlSvp3bu3x/3Brm/izQoo90DDHJWAYg5DdTlxVmjVkX0ZaDTPL3LqozUxvdpt11n61J9iDceWrk0r1eR93WY76HrxQiGEEMHVqW0ShCYnJ4cbb7yRjRs3UlBQQFpaGtnZ2UFdDuzNSI17ABGW1AeAhtJCUF1ttuus5vlFDRUnCEvMwBKTTF077TpLnyI8ZU9HMZmpP12As7zYeLy1LRD0jUhPVNRpuTjNRuGEEEKENglqushsNjN+/Phgd8PgzUiNewBhTdV2THcUHmi3XWfpQUZ+fr6WV1OmBRiWGG3Ky9f7LOlThPf832oAag9vNx5ra4pQP3e19S4q6hqICeC+VEIIIbpOpp96GG9GatwTno2gpqgpqPFlwnPzPCRj+im2l9/ykHJycjgv+wYAag/vMO5va4ow3Gom2qbF+ZJXI4QQ3Y8ENT1MQoR2Uf7++MkzFrZzDzRsjUFNXeNIjT8CDfc8JH0FlCUm2W95SMUVtRTWaG/x15f+lhUrVrB+/fp2t0CQZGEhhOi+JKjpQXJzc7n5hmsA7aLckcJ2OTk5vPLaSiyxKQA4CrUCfP4KNHJycjh06BCPL5wPQP8ho/y2z9Kmg1rC8+C0GG64egJTp05l/Pjx7QZpSRLUCCFEtyVBTQ+h15vJP6gVtjPZIlAstg7Vm+k7UpteSolQWP7yX884mtFVZrOZ68dfDECpw4TJ5J+34WcHTgJw2YCkDj+nlwQ1QgjRbUlQ0wO415tRHTW46rWKuObIuA7Vm/k2X9ve4aJzUzs0muELqbHaSqOaeienq+t9fnxVVfnsgDZSc8k5iR1+nj79VCxBjRBCdDsS1PQA7vVmAFxVpQCYozpW2G5nY1AztHesfzvqxmYxGwHE8VLf79Z9+FQ1+aU1hJkVxvRL6PDzJKdGCCG6LwlqeoDmdWQcJ48AYE07t912um+DENQApMeFA5Dvh6Dm08appwv7xBNh7XjlguQoKcAnhBDdlQQ1PUDzOjJ1x3YBYMu4oN12AKerHBw7rQUVFwQ4qMloDGr8MVLz+UEtqLn03I7n04CM1AghRHcmQU0P4F5vBqDu2HcA2HsPBtqvN7PzuDZK0zcxgtjwwBabS4/T8mryT/suqHE6nXy0fj3rvzsOwLj+8V49X4IaIYToviSo6QGaF7arK9yP2uDAHBVPWHw60Ha9GX3qaUiAR2mgafrpeJlvgprc3FyysrK4dupMapwmXHXV5FwxyqudxvWtEkqq6nC61DO0FkIIEUokqOkhPDbYdDZQV7AfgNShl7Zbb0ZPEh4WxKAmv7S2y8fSl7QfO3YMe98RANQe+Zb8Y0fPuKTdXUKkFZMCLhVOSV6NEEJ0KxLU9CB6Ybv169czaZSWJHzrrF+2W28mWEnCAL31oKaL00/uS9oB7FnDAW1rhI4saXdnNikkRsmybiGE6I4kqOlh9A02p193KQBbD5e22ba02sHRkuAkCUNTUHOyso7a+jMHHG3xWNJutmDL0HKJ9E0sz7SkvTlZASWEEN2TBDU91Kg+Wm2W709Wtbm55c78ciA4ScIAcRFhhIdpeT6FZZ2fgnJfqm5LH4QpzI6z8jT1jUvbW2vXHkkWFkKI7kmCmh4qNiKM81OiAfjq0OlW2wQzSRi0pGZ9BVRXlnW7L1UPd5t6aq9de2SrBCGE6J4kqOnBRmdpy5m/OlTS6uPf5pcCwcmn0fWOjwDgWBeCGvcl7XqScE3j1BO0v6S9NTJSI4QQ3ZMENT3YRVnaFNSXhz1HapxOJxs2bGDTXq2Wy+C06ID3TdfbByM1+pJ2xRqBNW0AALWHtJEavXZPW0vaWyNBjRBCdE8S1PRg+kjNrvwyqh0NQFMtl6uuvYHTDu2//45rs72q5eJL6bG+qSqck5PDgmWvopjM1Jccx1lxAoCMjIx2l7S3RoIaIYTonjq+KY7odnrHhZMWa6egrJbtR0sp3PEJU6ZMQVVV7H213JP60wUU5O1nypQpXl/8fcEowOeDWjWu5AGw/xBXD+/D1StWkJaWRnZ2ttc7jsvqJyGE6J5kpKYHUxSF0Y1TUFu+P+VRy8WaqtWxcRQd8LqWiy/1jvfdppb6fk83Xz6MqVOnMn78eK8DGoBeMdqUmIzUCCFE9yJBTQ93UeMUVO7GHUYtF8ViI+K8cQA4Cg4A3tdy8ZXebjt168FVZxSX17KvqBJFgXH9E7vUJ336qbKuwZi2E0IIEfokqOnhyg9+DcChCgUUE4o1gl63PIItfSCu+lqqD3zh0b6jtVx8JSXGjqKAo8HFqSpHp4/z+cFTAFyQHkN8pLVLfYq0mo36OTJaI4QQ3YcENT1Ybm4u/zP9Zly1lZhsEdj7Didl6u+xZw7BVVdF8Ru/oaEk3+M5Ha3l4itWi8moC9OVZOHPDmhTT5eek9TlPimKIsnCQgjRDUlQ00MZ+yG5nNTl7wGg1+SHsaWei7OqlMIV86nL322097aWiy91dQ8oVVWNoOaSc7se1ICsgBJCiO6oU0HNsmXLyMrKwm63M3bsWLZs2dJm2127djF58mSysrJQFIUlS5a0aLNo0SIuuugioqOj6dWrFzfddBN79+71aDN+/HgURfG43XfffZ3p/lnBfT+k2mO7AFAsYTSUn6BwxYPUF39vtO1MLRdfSo/rWrLwoVPVHC+rxWo2GTlEXSUroIQQovvxOqh54403mDdvHgsWLGDbtm0MHz6cSZMmUVxc3Gr76upq+vfvz+LFi0lNTW21zccff8z999/PF198wZo1a6ivr+fqq6+mqqrKo93MmTMpKCgwbo8//ri33T9ruOfG1Hz/Farqor4kn8Llv2gx5dSZWi6+1LuTy7r1IoJLX/sXACP7xBJh9U2Vgl4xMlIjhBDdjddXgKeffpqZM2cyY8YMAJ5//nk++OADXnzxRR566KEW7S+66CIuuugigFYfB1i9erXHzy+//DK9evVi69atXH755cb9ERERbQZGwpN7bkx9cR7H//JTnJWnUBs8k3GfeeYZHnjggaCM0OiaRmqqO/yc3NxcZs+ezbFjx0i6aT6R51/KpndfITe50CfBmTFSI0GNEEJ0G16N1DgcDrZu3crEiRObDmAyMXHiRDZt2uSzTpWVaRstJiQkeNy/fPlykpKSGDJkCPPnz6e6uu2LYF1dHeXl5R63s4n7fkgADaUFHgGNnkMT7IAGvB+pyc3NZcqUKY3Tawr2PkMBKP5WKy7oi+rIek5NsQQ1QgjRbXgV1Jw8eRKn00lKSorH/SkpKRQWFvqkQy6Xizlz5nDppZcyZMgQ4/7bb7+dV199lfXr1zN//nz+8Y9/cOedd7Z5nEWLFhEbG2vcMjMzfdK/7kLfDwmacmZ0wc6haa6pqvCZc2qMBGi9iGBKf8zhMbjqqqk7vg/wTRHBxMgwAPYeLmDDhg0BL0oohBDCeyG3+un+++9n586dvP766x7333vvvUyaNImhQ4dyxx138Pe//5133nmHgwcPtnqc+fPnU1ZWZtyOHj0aiO6HlJycHFauXEnv3r097g92Dk1z+kjNqSoHtfXtBw/uCdCAsSt37dGdoLp8UkQwNzeXu26fAsDh4tNMmDCBrKysoO2PJYQQomO8yqlJSkrCbDZTVFTkcX9RUZFPcl1mzZrFqlWr+OSTT8jIyGi37dixYwE4cOAA55xzTovHbTYbNputy33q7nJycrjxxhvZuHEjBQUFnd4PyZ9iwi1EWs1UOZzkl9ZwTnJUm22bFwe0Z2l7WNUe3t5uu47Sp7ZMkQlkAOaIOEAhPz8/aPtjCSGE6BivRmqsViujRo1i3bp1xn0ul4t169Yxbty4TndCVVVmzZrFO++8w0cffUS/fv3O+Jzt27cDgS8W1x2ZzWbGjx/fpf2Q/ElRFGMKavk7H7Y73ePx/20Ow5YxGIDaQ9vbbtdB7lNbzupSrW9mC6bw6KDujyWEEKJjvJ5+mjdvHi+88AKvvPIKu3fv5mc/+xlVVVXGaqhp06Yxf/58o73D4WD79u1s374dh8NBfn4+27dv58CBA0ab+++/n1dffZUVK1YQHR1NYWEhhYWF1NRoORYHDx7k0UcfZevWrRw6dIj33nuPadOmcfnllzNs2LCungMRZLm5uez7Wtuu4Yllf213usc9AdrWeyCmMDsNlSXUnzwCdK2IoMfUlsuJs1pLWDdHarVvgrU/lhBCiI7xOqi59dZbefLJJ3n44YcZMWIE27dvZ/Xq1Uby8JEjRzyG/o8fP87IkSMZOXIkBQUFPPnkk4wcOZJ77rnHaPPcc89RVlbG+PHjSUtLM25vvPEGoI0QrV27lquvvpqBAwfy//7f/2Py5Mm8//77Xf39RZDp0z2VxVrOkzkmGcCY7mke2LgnQIdnjQSg9vAOoOsJ0M2nrJyVJdprRie2204IIURo6FSlslmzZjFr1qxWH9uwYYPHz1lZWWfcfflMj2dmZvLxxx971UcR+tynexrKteKNlphegPaeUBSFOXPmcOONN3oEKXoC9NwPtSKC+tRTRkYGS5Ys6XTOS/Mpq4byE1h79cPSGGi11U4IIURoCLnVT+Ls4T7d4yw/ATQFNdD+dM9V196AObk/AE/8/B7Wr19PXl5el5J4m9f2cRqBlhbUBHN/LCGEEGcmQY0IGvdpnPpT2vRTWEp/QGmzne6L70/hUqF/ciQ/m36bTxKgm9f2aWgMtMwxvUKuto8QQoiWJKgRQeM+jeM4cQiXoxazPYqwxIw22+n0Xbkv89Gu3Dr32j4NxuhRcsjV9hFCCNGSBDUiaDyme1xOHAVaRWBb70FA+9M9nzYGNZf6OKgBLbA5dOgQS373MAAZ5w/r8tSWEEII/5OgRgRN8+meuvzdANgyBrU73XO8tIbvT1RhUuDi/p4rk3zZtx9eeQkApXWAIn8qQggR6uSTWgSV+3RPrR7UpA9qd7pHn3oalhFHbHiY3/rWK9qOxaTQ4FJlt24hhOgGJKgRQadP97z13OMAhCVmsHXn3janez4zpp78M0qjM5sUUmPtAOSXtr0jvBBCiNAgQY0ICWazmR9ePYFze2n7Pu04Vt6ijdPpZP369azbqS0DH9c/we/90rdvyC+t9ftrCSGE6BoJakRIGdVH25Jg65HTHvfn5uaSlZXFpFt+QkWDCVd9LVMnjvX7ztn6DuLHS2v8+jpCCCG6ToIaEVJG9W0Mag43BTX6VgrHjh3D3ncEAHXHviP/6OFWt1LwpfQ4bfpJghohhAh9EtSIkHJhY1Cz42gpjgaXx1YKAPasEYC2NUIgds5Ol5EaIYToNiSoESGlf1IkcRFh1DW4+K6g3HPnbJMZe+YQAGoa93vy987Z+vTTsdMS1AghRKiToEaEFJNJ4cI+TVNQ7lskRJx/KSZbBM7qMuqL8zye56+dsyWnRgghug8JakTI0fNqth0+3bRFgslM3GV3AlCx9X3Ac2d3f+2cndYY1JTXNlBRW++X1xBCCOEbEtSIkKMHNV8dLuGyyy4jIyOD6GFXE5aQjrOqlPIv3zXa+nvn7CibxSjwV1Amy7qFECKUSVAjQs7wjDjMJoWi8jqKKut54umlxFx6GwBln7+OWq8FF4HaOVufgsqXvBohhAhpEtSIkBNuNTM4LRqAZ1esYmtFFJaoRNTKk1TsWG20C9TO2U0F+CSoEUKIUGYJdgeEaC43N5cd6z7DNPBK/vqvzYSfcxFmexRTh0Rzzdo1FBQUkJaWRnZ2tl9HaHS9pVaNEEJ0CzJSI0KKXmjv1N4vAYi6YAJmexSOE4d4/P5bKCkpYerUqYwfPz4gAQ1IrRohhOguJKgRIcO90F5d447dutKPXwHV5ddCe20Jleknp9PJhg0beO2119iwYUPAz4MQQoQ6CWpEyHAvtOesOElD+QkAao/toubgl34vtNeW3vH6SE3wVj/pe19NmDCB22+/nQkTJpCVleX3va+EEKI7kaBGhIzmBfQqd67DVVvJ6Y/+1m47f9NXPxWW19LgdAX0tcFz7yt3+fn5ft/7SgghuhMJakTIaF5Ar2zjqxxdOhVHwb522/lbcpSNMLOC06VSXFEX0NduvveVu0DsfSWEEN2JBDUiZGRnZ5ORkWHUn9E0Xcz9XWivLSaTQmpscFZAuU/JWVPOIXP260SP/pHxeLCm5IQQIhRJUCNChtlsZunSpQDNApvAFdprS+8gJQt77H01+ApM9ihiL5mKYrG22U4IIc5WEtSIkJKTk8PKlSvp3bu3x/2BKrTXlmCtgHKfarOlnQeAOTyaiIGXtdlOCCHOVlJ8T4ScnJwcbrzxRjZu3BjwQnttCdZu3fqUXP7xAqwp5xr3R4+8jqqdH6EoChkZGQGfkhNCiFAkQY0ISWazmfHjxwe7G4amAnyBXdatT8lNve//YbLacTlqUcxmbOkDsaaeS33RwaBNyQkhRKiRoEaIDkgP4qaWOTk5zD5h4s08cBTsxVlVSuTgK0i9dArP3H5R0KbkhBAi1EhQI0QHBGv6SWfpdQ7kHSFn/GgyTKX85SBEDLqCidddFZT+9HROpzOkpj+FEB0jQY0QHZDeuKllRV0D5bX1xNjDAvr6O46WAnDDJUO5ZkgqG5Z8wr6iSv7wxgaGWE/KhdeHcnNzmT17tkexw4yMDJYuXSqjYgEgAaXoCln9JEQHRFgtxEdogUygR2tq653sLawAYFhmHIqiMMR+GoCXNh6QbRN8SKo3B5dsByK6SoIaITooPVabglrxz38HdEPJXcfLaXCpJEXZSI+1k5uby9K5d+By1GBN6oMtcwggF96ual692Zp2HiZ7FCDVmwNBAkrhCxLUCNEBubm57NzyCQDP/PnlgH6D1KeeRmTG4nK5mD17Nq66aqq+2wBoy7tBLrxd5V69OfzcsaRNe5rE6+cZj0v1Zv9pHlCaoxKIGj4JTGZ5XwuvSFAjxBno3yArio4AYIlJBgL3DfKbY6UADMuI87jwVnz9IQAR512CKSIOkAtvV7hXZY656CYAwvuNRLGGt9lO+Ib7+xog+ce/JPGaB4i56MeAvK9Fx0lQI0Q73L9BNpSfAJqCmkB9g9xxrAyA4ZlxHhfU+uI86o7vQzFbiDj/Uo/nyIXXe3pV5rCkvtj7DAVAMYdh7zOs1XbCd9zfr+EDLsaWPhBoHIVUTK22CzVOp5MNGzbw2muvBXR6WnjqVFCzbNkysrKysNvtjB07li1btrTZdteuXUyePJmsrCwURWHJkiWdOmZtbS33338/iYmJREVFMXnyZIqKijrTfSE6zP0bpLO8GABLbKrxuL+/QZZV15N3sgqAYb1jW1xQq/ZoU2KRIbptQnf6oNerN0dfeL3H/eH9RwHB21D1bGC8XxUTcZf/l3G/JbYX4eeMbtkuyJq/r1euXCkJziHC66DmjTfeYN68eSxYsIBt27YxfPhwJk2aRHFxcavtq6ur6d+/P4sXLyY1NbXVNh055ty5c3n//fd56623+Pjjjzl+/LgsrxR+5/7N0HHyMABhyVke3x6bt/Olb/JLAeibGEF8pLXFTubVez4DwJZ5AebI+JC68Ha3lSxms5k/PL2UyAvGA1C2RetneP9RQd9QtafT39dRF4zHmtQXZ00FFdv/BUD0hT8M+ff1zTffLAnOIcLroObpp59m5syZzJgxg8GDB/P8888TERHBiy++2Gr7iy66iCeeeILbbrsNm83WqWOWlZXxt7/9jaeffporr7ySUaNG8dJLL/H555/zxRdfePsrCNFh7t8MG0qO46qrxmS1E5aY0WY7X9KThIdlxAEtdzJ3Vpyg7vgeFMVExPmXAKFx4e2uK1nqe1+IyRqBWlZA2afLURscWGJTyBg8KqgbqvZkel2anCk3E3vp7QCUf7GS8i9WoqouwvtdiCU+LaTf162RBOfg8CqocTgcbN26lYkTJzYdwGRi4sSJbNq0qVMd6Mgxt27dSn19vUebgQMH0qdPnzZft66ujvLyco+bEN7yGBlRXTiKDgJgTdU2l/T3N0gjnyYj1riv+U7mVY2jNfHDrgqJC2/zlSym8BjCEjOB0P6gV1WVf3yhjcYtvONKPvrPagbEaR+RD//f60E/rz2R+6jHS5/sxxKXirOyhIptq2goK6Lm+60ATP31/wX9/Dd/X+tMEbFEXjCBpBt+Tu/7/07SDT83HpME58DzKqg5efIkTqeTlJQUj/tTUlIoLCzsVAc6cszCwkKsVitxcXEdft1FixYRGxtr3DIzMzvVP3F2az4yYgQ1Kef6fUpCVVW2G8u54zwey8nJ4dChQ6xfv57F998CgCn1PC77wXU+74e3mq9k6TVlIWkznsWacg4Qeh/0en7E7/76FgeKK4m0mpkyOpPx48dz2xVawvDH+08GuZc9j/uohxJmI3bcrQCUfvYaakMdc+bM4Ve3aLliX5fZ+fe69UHNzWr+vrZlDiH1v54mY9Y/SPrh/yNy8HgsUQlEDh5v1DfShUqCc3fKceusHrv6af78+ZSVlRm3o0ePBrtLoptyHxmpK9gPgDVtABkZGX4bGXE6nbzz7/WcqKjDpMDAlKgWbfSdzP97+m2MyIxDVWH1zs59ufAl9w9wc1QCtvTzUMwWokff2Gq7YH7Quo8ULP1wOwAVOz9izYfvAzD+fG2l2+a8EmocPe8CECzNRz2iR/0Ic1Q89acLqPzmPyiKwttvv81/3zSeBJtKeW0DU+YtCmpuVvPAJP7Ke7T3tmLCUXSQsk1vGisk9ZFcXSgkOHe3HLfO8iqoSUpKwmw2t1h1VFRU1GYSsC+OmZqaisPhoLS0tMOva7PZiImJ8bgJ0Vn6yMjfnngYgOjMQew/cNAvAY3+4XPn//wagNrC7xl03jntfvhcP1T70Pzgm+B/I3T/ANerHQNEDsrGFBnn0S6YH7TuIwXm6EQiBlwMQNHGN428n3OSo+gdF46jwcUXeaf83qezhfuoh2KLJHbsZADKPl0OLqcxmrdo0e85+O+XAIga2bQqLRi5We7va0tCBrbUc1GdDeT/+R4KXp5N6Sd/py5/N6CN5ELorJjrrjluneFVUGO1Whk1ahTr1q0z7nO5XKxbt45x48Z1qgMdOeaoUaMICwvzaLN3716OHDnS6dcVwltms5lbrp1AlM1CvQu+P+X7PaDcP3xsaecBUFe4/4wfPtcO1YL7LYdKKK6o9Xm/vOGeh6TXewGt5kv08GuMD/qTJ08G7YO2+UhB1PBrUExmao98a6xymzNnDi6Xi8vP00ZrPt57wm/9Odu4j3rY+wzFZI+ivuQ4Vd997NFu6dKlVO5Yg6u+DlvquVjTzweCk5vl/r6OHHw5ADWHvqahtGl01FF4ANBGckNlxVzz97q973DCevUDQjvHrbO8nn6aN28eL7zwAq+88gq7d+/mZz/7GVVVVcyYMQOAadOmMX/+fKO9w+Fg+/btbN++HYfDQX5+Ptu3b+fAgQMdPmZsbCx333038+bNY/369WzdupUZM2Ywbtw4Lr744q6eAyE6zGRSGNJbG/X7Nr/Mp8duuffQAAAcBfvO+OGTER/B8MYpqGff2RjUOXP3PCR7phbU6BerqJHXgsnCU089xdy5c43fSwmzEZbUFwjMB61HfoRiImrY1QBUfP2B0Qc970efgvp4nwQ1vuI+6qHnWtUd2wV4JuGWlJTgqq2gurEeU7TbaE2gc7Pc39eRg64AoLpZEFZXqE1P21LO8ev0tDc8tv8YMI6U235H6h2PY4nV8lhDLcetq7wOam699VaefPJJHn74YUaMGMH27dtZvXq1keh75MgRjyj8+PHjjBw5kpEjR1JQUMCTTz7JyJEjueeeezp8TIBnnnmGH/7wh0yePJnLL7+c1NTUHjVkJroPfXn1t8d8G9S4f/iYwmOwNX4rrSvYB5z5w6cPWjLrC//6Kuhz5jk5Oby4YiVhiRmoqouSdX+hobIES1Qiv3zuTZKTk5umH8JspNz+B9LvXmZURvb3B637Z5Qt8wIs0Yk4ayqo3vdFi3aXnJOIxaSQd7KKP738eo9NsAwk91EPPajRk/BBm7ZJSEgwfq7Ypm0JEjkwG1O4ZypBIJNwc3JyeOaVtwlL6I2rvpbqA5sByMzM5M033yT3r1rQY4lLZduuvUEPaKDp/Jhjkkm8bjYAJms4Cdc80Gq77q5TicKzZs3i8OHD1NXVsXnzZsaOHWs8tmHDBl5++WXj56ysLFRVbXHbsGFDh48JYLfbWbZsGSUlJVRVVZGbm9vpPB4humJob2159Tc+Hqlx/1CJu3waJms4juI86osPtdlOl5uby/O/+ikAtj5DjNyVYM6ZJwzSRlGzYi28+tfnuHmE9iVlV32yx1Ry4rWzsTUmVsZfda/HXkv++qB1HykwvnXv+xxcDS3arfnwfeqP7wFg/h9f7bEJloHkPurRPKjRp21mz55ttHcU7qeuYD+KJYyooVd5HCvQSbgV8dqXjYszI1n+8t9Yv349eXl53HzzzVz3gwn0S4oE4LuCyoD2qzk9Af+7774Dk5mkH/0Csz0KR3EervpawrNGaJuGNgqFZGZf6LGrn4TwFz2o2V1QTr3T5bPj6h8q1rTziBquTYeUrHme5kPyzT989GmrhrIi6o7v1QrxnacV4gvmnPnm77XE2quG9mXq1Kn88tYrMCsqXx8p5fG/vg5AzMU3EznoclRnPQ0Vp7BEJxJ76VTjGP76oDVGCsxhxuiQ+1RC87yf0j2fA01bJvTEBMtAy8nJ4aXXVmKJTkRVXTiK8wCMaZtf/epXHtWzKxsrDEcNvxZQgpKE63KprGpMxr/7ByOYOnUq48eP98iZGdL4+eDr6WlvuCfgP/bYY8Rddif23oNw1VZyIvcxSj/5BwDxE+7GHJ1MQkICTqezR4xASlAjhJf6JkYQbbfgaHCxr6jCJ8fUP1ASEhNJ+MHPUBQTld+ua8wz0LT1Ie4+bVW991MAYsbkYI5OBJqmcp599tmAfmh90RjUXNxfm0bYuOYDynduACD6whsIP+ciY5+fkjXPc2r1H7W+j76RsKS+JCcnk5+f75fpHn2kwN5vJObwaBoqTlF7dCfQNFLgnvdT8/02AG1zS3NYj0ywDIa+I7Q6NGlRZpa/8qIx6pGTk9OiRlTV7k9w1VURlpCOva+2yWigk3C3HCqhsLyWaLvFyLVqbpge1Ph4erqjmq90sve7kNhxNwNwavWzNJQVUbH1fWrzd2OyRZB4zSxKSkqYOHEiWVlZvPXWW926lo0ENUJ4SVEUhmX47oNL/1Y1ceJEHJljsKUNwFVbyekNL3m8JrT+Ie4+RVP57ToayooIi0slZepizNFNH7xz584N2LRJcUUtB09UoSgwpl+CMZpU/tV7AEQOupykG/4XRTFRse0DKnf8m9rvt1K993MUk5mEH9zHiRMnuPPOO/023ZOTk8Okmb8CGoNBVRt100cK3PN+6k/k0VBxCpPVjj1jMNDzEiyDYddxrdL7mHPTWh31cK8RpdbXUrVrPQC9LskJaBKuPpXz9EotYfmaC1KwWVoPpoI5UtN8sYE5Mp6k6+cBULHtA6r3atXHUV2c+nApaoOD8P6jiByiTekdO3aMW265pVvXspGgRohOGOKjvBr3b1Wm8BjiLp8GQOnGV3FVlxrt2ltJ4T5F46opp3DFQ9SfLiAsPo3U2xdhjullPB6oaZPN35cAMDA1hrgIqzGa5Cjcr+1VZQnDZIug9si3lKx7wXheyboXcDlqsfcZSuQFE/za72pHA3sqrAA8MetWVqxY4TFS0Dyfp/bQ1wDYs0Z43N9TEiyDYddx7e9HX1HYGvfq2b+6bTwAYVmjuWxiYKpnG1M5V/2ATflauYTX//DzNt+LFzT+LvmlNZyqrAtIH3XNqx7HTbgLc2QcjuI8Sj76q3F/dHQ0DSXHKN24HICEq2YSfu4YFGtEi2N2t6lWCWqE6IRhveMA2NmFoKb5t6r48T/BHB6No+ggFV9rqz0SEhJYu3atcaFtTfOdu53lJyh67SHqS45jiUsl9Y7FWOK0pPpATZtszvOcenK/8Jd/+U8AGsqKOfHPxR7JuRHUUva5lm8TP+EuFFuk3/q9dncxNfVO+iREcNeNV7YYKWiez1N7eAeg1flw11MSLINhZ742UnNBemy77fTq2fPuupVRfeNpcKm8+ZX/q8S7f+kIzxqBOTyGhsoS8r/e0OaFPsYeRv/GZOFAj9a4/50p1ggiG/PFTq1+Fpz1xmMVFdq0efmX71B3fB8mexS9Jj9M5uzXSJ32DHET7sLaWCeru021SlAjRCfo00+7C8qpa+jcH7r7typr+vlGrZRT/3nOmAopKSnBbDa3mzfQPPcAwFlxSgtsTh3FEtOLlNsXY4rQ+hyIaZMvGkdqLu6v5fW4X/ir92ykeOVvKXz157iqPT/0KyoqKP/yXepPHcUcGU/cZXcYj/m63+9tPw7Aj4anG+fNXfNgUQ9qrKnnYrJFhky12O6qrKaeIyXVAFyQ3vGK77eP6QPAS5/sZ/kK/+V9NP/SETl4PADVez5FdWmv19aFXh/J7cqXns7wWNU38FIUixXHycM4GstCtKC6OPHOY1R8/SH1JfkoJjO2tAHEjskhZeoiY+SmO021SlAjRCdkxIcTGx5GvVNlX2Hnlm66f6uKufAGACq/XYujcflwa+3a0nznbgBnZQmFK+ZrIzbRSUQO9Lz4+mva5ERFHQeKtXMyJksbqWkeINQc3IKzssR4jkddElcDJWv/DEDU8KtRwuw+7bfT6eSDNetZv0erBHv90JRW2zUPFp2VJThOHkFRTMZoTbCrxXZn3zXm02TEhxMXYe3w8+rztqDWVXGqVuWe3zztt7wPj60cLDbCB2hlRqq+2wC0f6HXv/R8E+BkYY+qxxdcCUDVzvXG44qikJzsmeDsrCyh5D//x/EXfsqxZdM58d4TNFScxBRmw+62zQl0j6lWCWqE6ARFURjaOHf+t3fXdOrbovGtymwh/NwxAFQ0Lltttd0Z6LkHzzzzjHGfq7qUqp1aTRhbY4Krt8f1htPp5KUPtA/5zGgTMXbtgt/aaJKutboktYe2U3+6AFOYnfBzLvJZv/X8iNt+vginquAozuPqsUPbvCA2Dxb10ZqkCy4NiWqx3ZmeT+PNKE1ubi633zKFim/WABA94lrAP3kf7hfwyAvGY7KGU3+6oMWoR2sX+mCN1Oh/Z+boZGOLEj0I0//Oli1b5vEFw52z8hTVuz+m5sAWAOxZ3W+qVYIaITohNzeXT959FYB/rPq4U98W9W9V4f0uxGSLoKHiJI7jTR+YnZneMJvNPPDAA57TJo3Lwm0ZF3T6uB2hBwx/ePFtAHatf9fjnLQ2mgRt1yWp3qMtT48YeJlP+u2eHxHRWHCvavcnZ7wguieq3j9ZWyVyziXXSUDTRfoFf8gZ8ml07tNBFdtXAxB+zmjM0Ul+yfswLuCKiZjGDTcrtr7fdjs3F6THoChwvKyWkwFOFs7JyWHm7/4CQO3hb3BWaJXG9b+zm2++uc0vGDojf6zPMKNdd5lqlaBGCC/pF8eT+7XaJdbGarjeflvUv1VFnNdY/G3v5+iF9rqyGV7zURFHwT5UZz2W6EQjYdjX0yYeG3E2fkOsPfpti3PiHiA0X23Uoi7JHm3EJ7z/aEyNVYY722/3C6IpMs74Flu9+5MOXRD1RNWfT78JkwLfn6yioMz3G5qeTfTl3Pqoxpm4Twc1lByj9vA3KCazsUrO13kf+peOyPMvJSw+HWdNOZXf/Nt4vL0LfbQ9zKgsHOhkYVVV2eeIA+C/rx/d4u8M2v6Coas98i0A1l79MDdWJ+8uU60S1AjhBfeLo7Ejb3JWpwuy/fBHN5E8Qpv7NmpI0P4S7o7wqO/R4KCusa9pwy7z+bSJR8AQEYu1cWPKuqO7Wj0neoBwprok9cXfU19yHFOYjd5jrulSv90viFFDf4BiMlOXv4eGsiKg4xfE2PAwY++vzw6c6lRfhLac/uAJLe+qo9NPzad5qg9qUyS21AHttusss9nMkiVLibl4CqCN0qj12qhLR750BKsI3878cg4UV2KzmJg7pfW/M2j9C8Zbb71FRkYGrppyHEXfA5A+4opuNdUqQY0QXnC/ODrLT+CsLkMxW7D26gd4/21x0/enqHEqJEVZ+fDvz7b6raqz3D+0rh7RH4DbHvi1zz+c3M+JPgLiKM7DVaN9E/f2nLj3+wcDtdVT18z8ZZf6rV/oTLbIpqmEbavabNeeS8/V+vTZgZOd7s/ZbndBBS4VkqNt9Iqxn/kJtJzm0feKCkvp3267rkgemo015RzUhjoqtja9XzrypUMP1lZv3hWQyrx6gcBFr2s5dBMH9SLaHtbuc5p/wZgyZYrxtzfhAm0U5465v+02AQ1IUCOEV5pf9OqO7wWaLuZttWvL6p1au0kXpHLlhAltfqvqLP1Da9q12hTX1sOnfXJcd+6/q74sveb7re22OxO93/97m5bD8vG+E1TWNZzhWW3TL3QxY3K0Tf1OHKJq9ydttmvPpeckAVpQo49ECe98pxfd8yJJuPkKuvrGvaLC4lJRfLjEXg8OXnvtNRa9+xUAd11+Huv+9V6Hv3Tk5uby2zn3AFqBTn9X5jUKBF55FRsPa8vk3/vjrzv1evrf3l3Xa58ZX+SVnOEZoUWCGiG80PyiV3PwSwDCzx3bbrvWNDhd/HuXNv1x3VD/rioY1TcegIMnqnxe5VT/XS0JWtKzqrqMzQdba+eNwWkx9EuKpK7BxbrdRZ3uY3Z2NhnnDiJ69I0A2oZ+atNmpN5cEC/sG4/NYqK4os6YQglV7hfoUNrHp6NF99w1z7ly1VbSUFYMgK1xtKareR/uG0H+5P8tZHeJC9XlJKNqX5tTpq0dY8qUKeTv/AJVdWGJTsIUGee3yrzu+Wz2fiMxR8bjrColf+vaLr3emH4JmE0KeSerOF7affLHJKgRwgut1VsBsPUeiCki1quL45ZDJZRUOYiPCGNsvwS/9js+0sp5KVEAfOXj0Rr9nMSM+iEANfs3G7kq0LWVE4qicH1jwPfBN97nSugX9TfffJOht/4vJquduuN7qDmw2eM1oOMXRHuYmYsa6+98uj90p6DcL9Chto/ProIzb4/QmuYJro5iLe8j+byRXc77aL4RZOxYLZematcGZt7RseDAPb9Mra+l/pR2LFvKuX5ZodW8QGBUY9J01e5PUJ0NXXq9aHsYQxvzgj4/2H3yxySoEcILLQqyVZyirvAAimIiorGeypkujvqF9o+5Wo7JDwalYDH7/09xdOOF+EsfDyebzWYWP7XU2BTPPVelK6u4dPoo1kd7injp1Y6POLhf1P/rvtnsqtNGq1xfv+vRrjNJ2Zc05tX8c/PekBsFgZYXaF2w9/FxOp2s/Wg9uxunnwY2BtrecM+5+uFlIwG4eebcLgU0zYMDS0IG4eddDED55pVAx4KD5nsvGYsJeg8EfL9Cy6NAoDWc8AFan/WNP7v6epeco73PPz8YusF7cxLUCOGl5t8W9W/9CUPHn/Hi2HShvZJPD2tTF689+WBALjJ6dd8v/ZBXU5s+EpM1HLX0uFHjArq+igtg96a1qOWFNLhg1u//3KERh+YX9bhLb0cxh1F7aDvFOz/lkUce6VJSdn3jktetR8u5/Q7/7STuDT1YXr58Offdd1+r+T7B3MdHf+9fN/UenKqCs7aS7AsHdynv49arLwG0xOOuaB6MxIz5MYpionrfJupPHe1wcNBiE9Qj3wAQ3m9Uu+06y/044VkjMYXZqS/Jx1G43yevd+m5Wv7YpoOnuk3+mAQ1QnSC+7fF38zQ8jQi+o/iuhtubPM5HrVceg/CEpWAq7aS/K0fBeTb8+gsbaRiV34Z1Y7OJ90253Kp/H3TIQAemz6x1Ro0nZWbm8vNN0+hbOcGACIbC/G1N+LQ2rfuyCHasvnTn/wdRVH461//yi233NKppOzc3FzmzrgZZ20lJltkp+sU+ZL7qNSdd97JiRMnAK14WtpdfyJ+4k+NtsHYx8f9vW9NOQcAR+HBLp+zwWna9NX+4gocDa4ztG6b50VfIaJxxKP8q/faaddSy5w7LdHYljYAc2R8m+06y/04tsZCeTV5287Yr44a1Tceq9lEQVkth05Vd66TASZBjRCd1LRz8C2kx9qpqXe2OUzb/EIbMbCx4N6BLaiNu+f6+9tzRnwE6bF2Glwq24+Udvl4+sjAIy+s5PCpamLsFiaPyuxwQmVHjq+fs+rdeiG+USjW8HZHHJp/647LvhPFZKZ63yatEGEXLupGn1zOFrt2B2sUpNWpJsVEbPad9LrtMazJWcSMuoGI8y7xeJ5+gfZ3MnHz974R1BQf7PI5y4gPJ9puod6psr+486M17hf9sKQ+mCNicTlqqcv/rs12rWmec+eqLqWusUp4+DmjfV6Z1/317H21FZh1jaOI0PVKwPYwMxf2jQPg5Q8/D8mp1uYkqBGiixRFYeJgbVPENd+1vkKn+YU24rxxAFTv1bYCCNS3ZyOv5lDXpqDcRwae/0jbgLN0279Yveq9Mzyz49zPWf3Jw9SfOopisRoX57bOmfu3aXNUAhHna+1LN77aZrvO9EkPasL7XWg8HuhRkOYBA4A5OpGUqb8n7pLbUBQTjhOHAIj/wX2YbJFGu7S0tIAkEzd/71tTtJpO9Y3F3bpyzhRFMUZr9A0yO8MjOGgsz1CX/x007sbd0eCgtT3O9MUE4R3MufOG/nruRS9rj+70eP2uvl5cnbbC7P9y14VcwnlrJKgRwgcmDtKCmrW7i3G5Ws49u19ArSnnYInphctRQ03e122284eL+ulBTeeThd1HBtyXcRdsfMOn0y/Nz0Xlzo8ArdYMKG22c/82HT5gHIpioi5/D/UnD7fZrjN9qv1+K6rqwt5nKGFJfdrtu780DxjsfYaR9pM/Ys8cgquumhPvPU7BK3OoP3UUS1QCcRPuAiAhIYENGzYEJJm4+bkIS8gAwNHs/6Oz52xwY62b7wo6H9S4ByN6UKNvFeBtcNAi504v+9DvQl5/0/eVeXNycvjlMy8BnkUvfZHPlpuby8uP/xLQ94HSzkWwE87bI0GNED4wtn8CUTYLJyrq+KaVvV48LrSNNW1q876Gxqmn1tr5w0WNeTVf5p3k1RXeDyU3HxmIvrBxGfeBLTSUFgK+m35pfi4qvv4QV1011uS+hJ97UZvt3L91R56vTfNVuW1B0ZUheffXaigratyvC2Ial/+21Sd/cQ8EFGs4ST/6X8wRsdQVHqDg5dlU7/4EnA2c+tezAEQPn4S9zzBKSkp45JFHjP9HW+ZQwhq/6ft6Gs39XJhskZgjtGXCDafbDka9ode66cpIDWjBwZtvrSS8r5abogc1nQkO3HPuXnr6UeJsCkqYndRh/tkQ0pmkTeldf9EAn+Wz6X/rdcf34XLUYI6IJSzZP+8RX5KgRggfsFnMXHFeMgB/+9eWFnPP7hfa8HPHAE1710DgdsHd+dla1Loq6pxw9/972OuhZI+RAbOFqMYEXH33Yl9OvzTPT1DrqqjY9gEAseNuafOcuQ/J2zK1ncn1fbW6OiTfvE/lX7wFQOTgK7DEpgR8N2P3QCBm7GTMkfHUl+RT+Or/0lDaFDTU5X9nnLuEa2ahWKyAthlrytRFpN6+iJSpv0f/Ju6v/0dLgjZ60VBxCrW+Fuj6e9+Yfioo7/IKnaGX/gDFHo3VDH974uEuBQd6zt3tt0/lmmHa6NRHe4q71L+2bPpeqyNz8+XDfVaV3PhbdzVQd2wX0JQ/BsFJOO8ICWqE8JHYqiMAvP3FvhZzz/qF1hyVgC31XFTVZQxL+2ru+0xyc3O55eYp1BzVPqBsGYMB74aS3UcG7BkXYLJF4Kw8Te3hb9ps11mt5SeUf/VPXPV12NIHYusztM1zlpOTw9wnX9I2rizYh7Ncu5h0dUi+xQ7oRQepyduGYjIbe0oFcjdjPWCwRCcSc9FNAJz++GWPEcCkpCQSEhI4/fErNFScJCw+nfir7iXxh/+PtOlLjOkWc0QslnjP0RJf/z+GJaQD0HD6OOCb9/65vaIIMytU1DZw7HTXKt9+0RgcjO2fxJ23+27LkisH9gK0oMbXS6OLymv5/kQVigJj+yX67Lju//c1hxrzx84Z0267UCBBjRA+kJuby+NzpqG6nFh79cMSq+XYuAcMOTk5zH78rwA4ju/DVa1NU/li7vtM3KeN9G9dtgxtFMOboWT3kQF7Y4JsTd5WQG2zXVc0z09wVZdS9e0aAC6b+SgJCQltrsgoDNP6cOcVF/h8o1D3PpVtehOAqGE/4MUVgd3NWA8YYi+7A1OYndpju6jZtwnQAgZFUXjggQcoKSlBdVRT8p//AyB6xDVG9dnKb9fhOKkF5PrGrDpf/z8m9NGK0NU3BjW+eO9bLSYG9IoGYFcXp6D0oObi/r4LDkCr92I1mzhSUs3BE1U+Pbbe58FpMcRGtL+BpTfc/+/1kU57nyGYIuPabBcKJKgRoov0gMFZU0Fd4yiInjfTPGAosWsXwtvHD/XphfZM3KeN9KAmvO9wlDCb0c+ODCV7TKP11wqKudfF8Mf0i3t+wooVK3h+zs2YUNlfYWbS1HtbXZFxusphlHa//0eX+nyjUPc+vbj4l5wbZ0Ixh7GlLCpgy171pdiHS+uJGvYDAE6vf8l4XA8YBgwYYNxXc2CLkXBde/gbCl6ezakPn6EufzfQFNT46//xxjtnAnDLNVf49L1/gQ+ShV0u1W9BTaTNwsWN1Xk/2tP5Pcxao/d5nI/77P637iwvpi5/D4rJTOT5Wq2oQE+1dpQENUJ0kXvAUN1YXVgv3gVNAcPaDZ/w2QGtjs09113s8wtte9yHiOvy91J/+jgmexSRg8e32a41TdNoiViTs1BVl5bwjH+n0fT8hKlTp2KuLaVip1YGPvbim4027qNia74rwulSGZQWQ1ZSZFuH9Umfbr99Klcka5uErtpTxp0zZgZuV+YJE1i8eg+g4Dq8lV/9dGqLYLn5N+lTHzxD/vN3U/T6L3EUHQTcdrzu1d+v/496AbebJl7q0/e+sQLqeMsk/Y7aV1zB6ep6wsPMDMvo+EabHXVV4xTU25v2+TTw3dQYvI87x7dBTfOp1qo92q72EYOyAzZl3hkS1AjRRe6BQPW+TaguJ/a+w7D1HuTR7tP9J6lrcNE7LpyBqdEB7aPnhU01kkajR93QTrvW5eTkMPeJFwBwFOzDVasVPQvkNFpZY4JuxPmXYEnUkjDdR8U++Fab3rhuSKrf+qLLzc1lwcwcHMV5mGwRxoqwgOzK3Hc44f1HozrrKfz3n1m4cCE2m80jYGie3Ayqx4ajAI7GmjHWXv389v+oqiqHTmpTL/19HGj6olbNF43BweiseML8sBeb47A2ornnVD13+CjwLSir4dCpakxKU7kGX3Kfaq3e85lWwiDjAjLOG+r3v/XOkqBGiC5yDwSc5cVUfqPlfOg1QXSHG7QP3qsG9XK7wARG8wtb5bdrcTlqsCZnYesz1Ouh5MqoTABuyQ7ONFr9qaNUN+aOxF58i/G4qqrkF5ewcZ+2TcC1Q/073++eq6QHWtGjf4QSZgvArswKceNnAFDx9b+MPJXmr9dawrVO//nnM6cCYIlJ5utde/3y/3iqykFFXQOKApkJET499qDGkZrjZbWcrnJ06hhffK/VbvL11BNogehP75iC4+RhFJPZKNjY1cBXH6UZ0juWGLvv8mnc6VOta99fSf8obSuKX/357ZAMaECCGiG6rHnAUPbpclyOGuy9BxFx/qWNAUMf9pRp35yvaizUF0jNL2xqXRVVejG7UT8COj6U3OB0sXG/No02fdJFQZtGMxJ0h1xJ/FX3gqJ9nIWfOxYXCmppPt98+h+/9sdj6nHPp9SfLsAcEUvkBdpSd3/uyhwxKBtb6rm46qoo+/z1dl+veXKzLiMjg7fffpvHFvyazIRwAPYW+zaRVZfXOErTOy4ce5hv3ysx9jAy47X+L1v+rtdTOy6XyuY8/+TTuAeiNQcaC/E1VhfubOBrbF66RjveWD+M0rjTp1rv+sEIAD74ttCvr9cVEtQI0UXNAwZn1WnKt7wDQNwV01EVM1dOmU5xRR2RVjMX9/fvB1Bbml/YKratArQg4C/LOz6UvP1oKRW1DcSGhzE8I85f3W2V+6iYo3A/pzdoibExo39Er1sewWSLJKKx4F75ro/9XvXUIwdJdVHx9YcARF4wvu12Pnk9hdhxtwJQvuUdo4pse6/XPOG6+eiaL6Zw2qMHNf38kOOUm5vLkR3aCp3H//qaV1M7TqeT5R+s53R1PVYzXJAW5dO+uQeiRnXh/qPApAV23ga+7vlUm/O00aXnfzsvINV9rx2Sitmk8M2xMmMqMdRIUCOEDzQPGMq35NJQWUJYfDrRI6/l3S+1hMzKA1/ywXv/DGo/9QvbK8/+gQsSTSgmE1vLO75q5+PGqZ3sAUmYTcGdRivf/DbFub/D5aghPGskqdOeNob2q/ZoFzl/Vj1tnoNUvfsTI+/AHNOrzXZdfb3wARdjTe6Lq66K8sbCh+31S+eecN18dG1QY1Czu6DzG0O2Rw9qshJ9G9ToOUZlh7XNJ629+gOtT+0037xz5cqVZGVlcf/CZwAoO7CNAef092mA4Jmkvxtn5WltNG9gdpvt2uKeT2WO6YUlLhXV5eT4jo0B2bYgKcrGJY0JyR98G1r1aXQS1AjhI+4Bw+z776Ps0+UAxF5yGxGNyyBLdn4S9D1T3C9sFyc0rtr5roQ7ps3o0DdcPagZf36vNtv4S6sbBu7fROGrv6ChrJiwhN4oljDqTx2l/uRhv1c9bR5kOStPUXtYK68fOfiKTi97bWvnbP314i7RRmkqtr6PWtf0jbkry2wHpXV9WXR7DvlhpMZ9asdIdk7RgprmUzutbd558803c+zYMWxu+z35OsHbI8BUXZRv1TZ9jbnYu601mm9RohdNdBTsx1VX7fG7+tMNw7QCiq995ttVXL4iQY0QPmQ2m8nOzmblypVUfrMGx8kjmCNisSb1QVVdVDcOP4fCnim5ubksuHcK9aWFmMOjjSmT9j7UT1bW8c0xbdns5QOSAtldQ2v5IfUn8ij4+1xqG2vwVO3a4PEcf1U9bS3IqvpOe219uby3y17b2znbbDbz00eexZp6Li5HDeVfNe2K3tVltvr004HiChwNLq+ffyb+mH5yn9pxFGtBTVhiJtZ0rcifHtTedtttTJ48ucXmnRoFe+YQQAtqfJ3g3TzwbdrDLIvw/qM7HIi2tnmp3mcI3LYFjrwvUZ0NHKtUmf7AgyG3a7cENUL4mPHho7oo3dBUDE2vIhwKe6YY3/pcTiO3Rl+K3NaHutPp5C/vaX3uE20iMdI/qy06Qh8Ve+aZZ4z7XNVlFL32Kwpf/Tllm1d6tPdn1dPmQVb1vs9RG+qxJvdl6d+9WyXiPr3g7tixY0yePJk5c+eyvkgrmMiBjR65NF1dip0RH0603UK9U+XgicpOHaMtLpfKoVO+D2rcg1VnxUlqD3+DYjKTctvvCB8wznhs5cqVLZ9stmDrPZDYy27HHBGLy1GLo3A/4NsAobUk/Yrt/wKaRms6Eoh6bF5qiyTiPK0WVu3h7W2287Xc3Fym3TbFKLgZMUgLxEJp124JaoTwMY89Uw5+aeyLpC9Bbq1doLl/66v8Zg0uRy3WXv2MJFv9Q33hwoUeuQdPr9A+jL/9zxtB/3ZmNpt54IEHPGuwuBqoy98DLi0YC1TVU/epx+UvvcDo3nYAKhMHdvgYzacXWvPn3I84WOZCddaz4LbL2kz67QxFURiU6p9k4aKKWmrrXVhMChmNq5R8oXmwWvz2b6k++CWmMBvJP55vBOo6S0IGseNuJeX2xfSZ8yapdz5J3KXacvbaI98Y7xudr/5GWyTpf/VP1IZ67JlDeOKljgW+HpuXXvhDTLZIHCcOt9h3zV8BvPv7s3q3VogvcuDlgIKqqqiqyn333cfy5cuDOiUlQY0QPtb8Q+XEu4s4+eFSyr/6Z7vtAsn9w1qtq6Lym38DkPSjXxA14lrjsccee8wt9yDfSMKtydsaEt/OOlKDJVBVT91zle6ZqO1m/P7247hcHdvA0D3QtGUOofd/v0Jyzq+NPboAYhtzaSq3r2bmnbdSUlLi0yX1emXe3T7Oq8lr3O+oT0IEFh8Wtmuxk3t9LSfefpSK7f9CUUwk/OA+4q+6l9hLbiPtrj/Re+bzxF3+X9gzh6BYrDirSqne+zklH/2VUx8uaXF8X/6Nuge+//jLn7giSwvudqsZreZPtfW7mqzhRI/WyjBoZQ2095e/A/jmldNd9XWEJWaQcttjxl53J06c4M477wzqlFSn3l3Lli0jKysLu93O2LFj2bJlS7vt33rrLQYOHIjdbmfo0KF8+OGHHo/rG681vz3xxBNGm6ysrBaPL168uDPdF8Kvmn/QumortE0YXQ1AaOyZ0vzD+vT6l6j8dg2KyUzipPsbCwd6BgnWlP6YI+Nw1VVTl7/HL8XlOqO9GizBqno6YWAvou0WjpfVsuVQSYeeYwSaJguJk2ZhiU4kYsDFpN7xB1KnPU3sJbdh7zsc1VlP2WbtYuHrcz8oTat0vbvQx0FN49STr7esaDWoVV2U/HsZpz9+BdCW+8dl36lt6+FsoObgV5xa/Sz5L/yUY3+6kxPv/p6KL9/1mMrz19+oe+C74NZLUVBZu6eYq2+ebuRP9e3bl9/+9rctghz9d40aPglzRCz1p49TvWejx+/uzwDe44uQo4aS1c/ictRi7zuctLuWNVYnb/rMCNaXHq+DmjfeeIN58+axYMECtm3bxvDhw5k0aRLFxcWttv/888+ZOnUqd999N19//TU33XQTN910Ezt37jTaFBQUeNxefPFFFEVh8uTJHsf67W9/69HugQce8Lb7QvhdKI0etKVF6XxXA6c+XErpJ/8AIHZMDkk3PURYUl+iL/whyTm/IWXqIgBqD+8wArRQyA+CM9dgCTR7mJlrG7doeO7Drzq0SkQPNKNH/ZCwxAycVaep2P4v1AYHtrTziMu+E4DKnR/hrDjhl3M/OE3b8+i74+XtToN565CflnND20Ft+RdvcfL9J2moOEnNwa84+eESjv3pTopXLqRyx79pKMlv9XiB+hvdsfE/VO3VpqRjxjS9T/Pz81mwYEGrG7Ved8ON9Jl0t/H7oWoJ3YEI4Jt/Ear6bgMFL82i9vA3mKx2Eib+lJTbF2OJ11ZHBetLj6J6+c4dO3YsF110EX/6058AcLlcZGZm8sADD/DQQw+1aH/rrbdSVVXFqlWrjPsuvvhiRowYwfPPP9/qa9x0001UVFSwbt06476srCzmzJnDnDlzvOmuoby8nNjYWMrKyoiJienUMYTwRm5uLrNnz/ZI+szMzGTJkiUhUWJcT0oFPC5gEYOuIOm6OSiWlonAzqrTnHjvceoaV1zoVqxYwdSpU/3b4W5m8Uvv8PxeK86aCo4t+y9wNpCRkcHSpUtb/f93Op1knT8U5YaFmGyRnPrXUiq/WYMpIpbokdcTfeH1oJgofGWOx95Nvjz3tfVOLljwb5wulS/mX0VqrN0nx73nlS9Zu7uYR28awn9d3Ncnx2zO6XSyceNG1q1bx2OPPdbp4wTib9TpdJKVlUWxM4K0aU+jOhvI//NMnBUnWrRVFC1n5ZFHHuFU/GD+mR9OWqydR8eaOFFUSFpaGtnZ2X7/kqT3OT8/v1nAqxA18lrix8/AZA2noayI/L/c65GftH79esaPH9/p1/bm+u3VSI3D4WDr1q1MnDix6QAmExMnTmTTpk2tPmfTpk0e7QEmTZrUZvuioiI++OAD7r777haPLV68mMTEREaOHMkTTzxBQ0NDm32tq6ujvLzc4yZEIIXa6EFzbX3Drd79MUWv/wpndRlqQz01h3ZwesNLHH/pfzj2p2ktAhoIbn5QKMrNzeWX99xMQ8UpzOHRhPcbBbRfEO7NN99k0K2/0BJACw9Q+a32pc5VXUbZZys4tmwa+c/9pMVmlL489/Yws7HZ5HcFnd/xujljObcfRmp0+tTOwoULm23g2bbMzEzefPPNgP+N6vkpjoJ91B7egWK2kH7XsyTd8HMiBl2OYms6T3oAseCR3/L2d9p1LH/ty5SXng7oFiVtj0CrVH79Icf/dj81h77m9PoX/ZZw3SGqF/Lz81VA/fzzzz3u/9///V91zJgxrT4nLCxMXbFihcd9y5YtU3v16tVq+z/84Q9qfHy8WlNT43H/U089pa5fv17dsWOH+txzz6lxcXHq3Llz2+zrggULVLQMKo9bWVlZR35VIc4aDQ0N6vr169Vf//rXHn8risWmKhZbq39HRhtFUTMzM9WGhoZg/xoho6GhQc3IyFABNX7C3WrfB1epSTc+2Oo5e/vtt422Yb36q31+8Z7a98FVavLgi9s97/489//z2ja174Or1D99tN8nx6tvcKrn/vIDte+Dq9Rjp6t9cswzefvtt1VFUVRFUVo9d3PmzFHXr18ftPftihUrjL6E9eqv9v7vV9S+D64ybn1+/q7a69bH1Kjh16im8BgVUCOHXKX2fXCVmjHrH6opzKYqiqK+/fbbAe+7+3u2o7f169d36TXLyso6fP0OudVPL774InfccQd2u+ew57x58xg/fjzDhg3jvvvu46mnnuLZZ5+lrq6u1ePMnz+fsrIy43b06NFAdF+Ibqetb7hqQx1qQ+t/XxA6+UGhxn2ViFGIb2A20aNvBJrykGbOnOlRkyZh4r0oiomq3R9z4rsveOSRR4zp9kDmZp2fou19tPqLnT5Zmnu8tJZ6p4rNYiItxjfTWWfS1ihkZmYmb7/9Ns8880zARjha4z66Vl/8PfnPzdDqK216S9vJ22whPGsEidfMImPWP+h1y2+NpeflW97FVa/9XQYjSd99BPrVV18lOTm5zVGxYCyK8CqoSUpKwmw2U1TkOfxZVFREampqq89JTU3tcPuNGzeyd+9e7rnnnjP2ZezYsTQ0NHDo0KFWH7fZbMTExHjchBBtay/BuTXBXF0UytyH2h1FByn/8l0AEq6aSfxVM43dxF966SVjaiFiYDb2zCG46ms5vf4lFEXhr3/9K08++SRvv/12wFZ25ebmsujBWQBs+77IJ0tzjZVPiZGYArhXWChP/7ZI1Fdd1OXvofSTVyj42/3k/3kmpze8RF3hARSTmfB+F2KJS8VZU0HFdm31sBrEJH39i9Add9xh5MaGyqIIr4Iaq9XKqFGjPBJ4XS4X69atY9y4ca0+Z9y4cR7tAdasWdNq+7/97W+MGjWK4cOHn7Ev27dvx2Qy0atX4PefEaKnau8bbjByD7qjFsvlP/qrlmcAxIy+kaQbH0SxWEExYe87nMTr5pJ47WwAyr9YibPipMcFK1AXZz1x/PgurUSHJSEdJczW5aW5eY3VibOSInzW145qbwPPYDrTF4iG0gLKN79N4StzyP/LvZz+5O/UHt5ByZrnUB01Hm2DWcQTQq+kgsXbJ8ybN4/p06czevRoxowZw5IlS6iqqmLGjBkATJs2jd69e7Nokbb8c/bs2VxxxRU89dRTXH/99bz++ut89dVX/OUvf/E4bnl5OW+99RZPPfVUi9fctGkTmzdvZsKECURHR7Np0ybmzp3LnXfeSXx8fGd+byFEG3JycrjxxhvZuHEjBQUFAVtd0VPo38LdV4mUb8mloeIkSdfNJfL8SwmLT8cUHo0lumn/rLrjeyjf8o7HsfQLln5x9hf3arFqdam2k3RUPGFJfXEU7ENRFObMmcONN97o9fvg0Clts0Vf16jp7vRgoPkKyeYaTh+nfNOblG96s9XHQyFJP5Q+M7wOam699VZOnDjBww8/TGFhISNGjGD16tWkpGgVBY8cOYLJ1DQAdMkll7BixQp+/etf88tf/pIBAwbw7rvvMmTIEI/jvv7666iq2urSRJvNxuuvv87ChQupq6ujX79+zJ07l3nz5nnbfSFEB/j7ItqT6d/Cp0yZYizHBaje/QlFlSUk5/waa69+ADhrK6nevZGqXR9Rl7+7xbECdcFqvlmi40Qe4VHxWHv1x1Gwz2PkyJv3hdPp5Kt9Wj6j49QxnM7zJDh20zwY2L9/Py+88EK7QY5OURQyMjKCWsTTXah8Znhdp6a7kjo1QohAaq1OEYAlMYPoEddSd3SXtmu7s77Fc/ULVl5eXkCCgNdee43bb7/d+Dku+7+IveRWqnZt4OSqJ43726uJo9eJ0b+pnzx5krlz5+K6bgFh8WkULn+QZMrarNMjNO7ncf/+/SxcuBDwrCWlT1mdLTlt3ly/vR6pEUIIcWb6t/Bnn32WuXPnGvc3nDrG6XUvtPm8YCRYNh8Rqvl+K7GX3Iq9/ygtsbmxcm1bI0dtBXCYLPSJ1fIe60/nk19dxpQpU86ai3FnNB/xGDJkSItzm5GRETJFPEONjNQIIYQftV2JtXXBqDrdoo+KiYwHXsUcHkPh8gdx5H/X5siRnmDc2u9mScig98zncdVVc3TJLUDgR6F6guajYGdbjpuM1AghRIhoK8cGPEvgDxgwIGgXrJZ9dFFz8CuihlxJxLljcOR/1+rIkXuCsTtryjlEnH8ZEQMvBaD+9HHjsc7m55zNQiVfpTuQoEYIIfysrZUuoTSN0LyPNQe3EDXkSqIHXsJfH/hhq31snmAcNeJaYsZOJiyuqQ6Zq76Oyu2rWzw32EuRRc8kQY0QQgRAKC17bYt7H/OOFfDYLnDGpnHh5eNbtHU6nR41yGIvuc3YSdxVX0vNwa+o3vsZNQe/RK2vbfH8UFiKLHoeyakRQgjRqttf+ILPD57iNz8czN2X9TPub54YHHvpVOIuuwOA0k+XU74lF7W+9S02JKdGeMtvu3QLIYQ4e1w5UFu5tG5301Y3emJwU0BzuxHQnF7/EmWfvdZuQAOyX5jwHwlqhBBCtGriIK2o6pa8Espr61skBsdedjtxl2n1bU6vf5HyLW+3ezzZL0z4m+TUCCGEaFVWUiTnJEdy8EQVn+w7QVTJPmOEJnr0jcRdqgU0JR/9jYovPbd4yMzM5KmnniI5OTlkc4hEzyNBjRBCiDZdNSiFgye+Z/mGb0ja/wEA9qyRxE+4C9CmnJoHNL/+9a9ZuHChBDAi4CSoEUII0aaw4j2AlU8PnubYsv/DEpem7TRuMlP5zX9anXK66qqrJKARQSE5NUIIIVqVm5vLg3ffjLOmAnN4DPZ+I0me/BvM9ijq8vdw6j//59FeURQyMzNDZpNFcfaRoEYIIUQLRlKwy0nt91sBSL7xIaxJfWioOMWJd34HzgajvaxsEqFAghohhBAtuFcLrj64BQCTNRy1wcGJd36Hs+q0R3tZ2SRCgeTUCCGEaMF9G4Pa77eiNtSjWMI4tfpPOAr2GY/NmjWLyZMny8omERIkqBFCCNGC+zYGrroqilcuRLFGULN/k0e7yZMny2aLImRIUCOEEKKF7OxsMjIyyM/PR1VVag/v8Hhc3+5AkoJFKJGcGiGEEC2YzWaWLl0KNCUB6yQpWIQqCWqEEEK0Kicnh5UrV9K7d2+P+yUpWIQq2aVbCCFEu5xOJxs3bpTtDkRQeHP9lpwaIYQQ7TKbzZIMLLoFmX4SQgghRI8gQY0QQgghegQJaoQQQgjRI0hQI4QQQogeQYIaIYQQQvQIEtQIIYQQokeQoEYIIYQQPYIENUIIIYToESSoEUIIIUSPcNZUFNZ3gygvLw9yT4QQQgjRUfp1uyO7Op01QU1FRQUAmZmZQe6JEEIIIbxVUVFBbGxsu23Omg0tXS4Xx48fJzo6GkVROn2c8vJyMjMzOXr0qGyM6WdyrgNHznXgyLkOLDnfgeOvc62qKhUVFaSnp2MytZ81c9aM1JhMJjIyMnx2vJiYGPkDCRA514Ej5zpw5FwHlpzvwPHHuT7TCI1OEoWFEEII0SNIUCOEEEKIHkGCGi/ZbDYWLFiAzWYLdld6PDnXgSPnOnDkXAeWnO/ACYVzfdYkCgshhBCiZ5ORGiGEEEL0CBLUCCGEEKJHkKBGCCGEED2CBDVCCCGE6BEkqPHSsmXLyMrKwm63M3bsWLZs2RLsLnV7ixYt4qKLLiI6OppevXpx0003sXfvXo82tbW13H///SQmJhIVFcXkyZMpKioKUo97hsWLF6MoCnPmzDHuk/PsW/n5+dx5550kJiYSHh7O0KFD+eqrr4zHVVXl4YcfJi0tjfDwcCZOnMj+/fuD2OPuyel08pvf/IZ+/foRHh7OOeecw6OPPuqxV5Cc68755JNPuOGGG0hPT0dRFN59912PxztyXktKSrjjjjuIiYkhLi6Ou+++m8rKSv90WBUd9vrrr6tWq1V98cUX1V27dqkzZ85U4+Li1KKiomB3rVubNGmS+tJLL6k7d+5Ut2/frl533XVqnz591MrKSqPNfffdp2ZmZqrr1q1Tv/rqK/Xiiy9WL7nkkiD2unvbsmWLmpWVpQ4bNkydPXu2cb+cZ98pKSlR+/btq/7kJz9RN2/erH7//ffqv//9b/XAgQNGm8WLF6uxsbHqu+++q+7YsUP90Y9+pPbr10+tqakJYs+7n9/97ndqYmKiumrVKjUvL09966231KioKHXp0qVGGznXnfPhhx+qv/rVr9Tc3FwVUN955x2PxztyXq+55hp1+PDh6hdffKFu3LhRPffcc9WpU6f6pb8S1HhhzJgx6v3332/87HQ61fT0dHXRokVB7FXPU1xcrALqxx9/rKqqqpaWlqphYWHqW2+9ZbTZvXu3CqibNm0KVje7rYqKCnXAgAHqmjVr1CuuuMIIauQ8+9aDDz6oXnbZZW0+7nK51NTUVPWJJ54w7istLVVtNpv62muvBaKLPcb111+v3nXXXR735eTkqHfccYeqqnKufaV5UNOR8/rdd9+pgPrll18abf71r3+piqKo+fn5Pu+jTD91kMPhYOvWrUycONG4z2QyMXHiRDZt2hTEnvU8ZWVlACQkJACwdetW6uvrPc79wIED6dOnj5z7Trj//vu5/vrrPc4nyHn2tffee4/Ro0dz880306tXL0aOHMkLL7xgPJ6Xl0dhYaHH+Y6NjWXs2LFyvr10ySWXsG7dOvbt2wfAjh07+PTTT7n22msBOdf+0pHzumnTJuLi4hg9erTRZuLEiZhMJjZv3uzzPp01G1p21cmTJ3E6naSkpHjcn5KSwp49e4LUq57H5XIxZ84cLr30UoYMGQJAYWEhVquVuLg4j7YpKSkUFhYGoZfd1+uvv862bdv48ssvWzwm59m3vv/+e5577jnmzZvHL3/5S7788kv+53/+B6vVyvTp041z2tpnipxv7zz00EOUl5czcOBAzGYzTqeT3/3ud9xxxx0Acq79pCPntbCwkF69enk8brFYSEhI8Mu5l6BGhJT777+fnTt38umnnwa7Kz3O0aNHmT17NmvWrMFutwe7Oz2ey+Vi9OjR/P73vwdg5MiR7Ny5k+eff57p06cHuXc9y5tvvsny5ctZsWIFF1xwAdu3b2fOnDmkp6fLuT7LyPRTByUlJWE2m1usBCkqKiI1NTVIvepZZs2axapVq1i/fj0ZGRnG/ampqTgcDkpLSz3ay7n3ztatWykuLubCCy/EYrFgsVj4+OOP+eMf/4jFYiElJUXOsw+lpaUxePBgj/sGDRrEkSNHAIxzKp8pXfe///u/PPTQQ9x2220MHTqU//qv/2Lu3LksWrQIkHPtLx05r6mpqRQXF3s83tDQQElJiV/OvQQ1HWS1Whk1ahTr1q0z7nO5XKxbt45x48YFsWfdn6qqzJo1i3feeYePPvqIfv36eTw+atQowsLCPM793r17OXLkiJx7L1x11VV8++23bN++3biNHj2aO+64w/i3nGffufTSS1uUJti3bx99+/YFoF+/fqSmpnqc7/LycjZv3izn20vV1dWYTJ6XM7PZjMvlAuRc+0tHzuu4ceMoLS1l69atRpuPPvoIl8vF2LFjfd8pn6ce92Cvv/66arPZ1Jdffln97rvv1HvvvVeNi4tTCwsLg921bu1nP/uZGhsbq27YsEEtKCgwbtXV1Uab++67T+3Tp4/60UcfqV999ZU6btw4ddy4cUHsdc/gvvpJVeU8+9KWLVtUi8Wi/u53v1P379+vLl++XI2IiFBfffVVo83ixYvVuLg49Z///Kf6zTffqDfeeKMsM+6E6dOnq7179zaWdOfm5qpJSUnqL37xC6ONnOvOqaioUL/++mv166+/VgH16aefVr/++mv18OHDqqp27Lxec8016siRI9XNmzern376qTpgwABZ0h0qnn32WbVPnz6q1WpVx4wZo37xxRfB7lK3B7R6e+mll4w2NTU16n//93+r8fHxakREhPrjH/9YLSgoCF6ne4jmQY2cZ996//331SFDhqg2m00dOHCg+pe//MXjcZfLpf7mN79RU1JSVJvNpl511VXq3r17g9Tb7qu8vFydPXu22qdPH9Vut6v9+/dXf/WrX6l1dXVGGznXnbN+/fpWP5+nT5+uqmrHzuupU6fUqVOnqlFRUWpMTIw6Y8YMtaKiwi/9VVTVreSiEEIIIUQ3JTk1QgghhOgRJKgRQgghRI8gQY0QQgghegQJaoQQQgjRI0hQI4QQQogeQYIaIYQQQvQIEtQIIYQQokeQoEYIIYQQPYIENUIIIYToESSoEUIIIUSPIEGNEEIIIXoECWqEEEII0SP8f/Jd51hKJolQAAAAAElFTkSuQmCC" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "filter_result = essm.forward_filter(samples.observation)\n", + "\n", + "filter_state = essm.create_initial_filter_state()\n", + "\n", + "\n", + "for i in range(100):\n", + " filter_state = essm.incremental_predict(filter_state)\n", + " # incorperate new data\n", + " filter_state, _ = essm.incremental_update(filter_state, samples.observation[i])\n", + " \n", + " plt.scatter(filter_state.t, filter_state.filtered_mean, c='black')\n", + " \n", + "\n", + "plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')\n", + "plt.legend()\n", + "plt.show()" + ], + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:58.070899470Z", + "start_time": "2024-08-14T22:35:52.776945243Z" + } + }, + "id": "initial_id", + "execution_count": 3 + }, + { + "cell_type": "code", + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-14T22:35:58.071206840Z", + "start_time": "2024-08-14T22:35:58.066994695Z" + } + }, + "id": "6626b28fc500ba6d", + "execution_count": 3 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/essm_jax/essm.py b/essm_jax/essm.py index 71fb768..93ba669 100644 --- a/essm_jax/essm.py +++ b/essm_jax/essm.py @@ -87,19 +87,19 @@ class ExtendedStateSpaceModel: Args: transition_fn: A function that computes the state transition distribution - p(z[t] | z[t-1], t). Must return a MultivariateNormalLinearOperator. - Call signature is `transition_fn(z[t-1], t)`, where z[t-1] is the previous state. + p(z(t'), t' | z(t), t). Must return a MultivariateNormalLinearOperator. + Call signature is `transition_fn(z(t), t, t')`, where z(t) is the previous state. observation_fn: A function that computes the observation distribution - p(x[t] | z[t], t). Must return a MultivariateNormalLinearOperator. - Call signature is `observation_fn(z[t], t)`, where z[t] is the current state. - Note: t in [1, num_time] is the observation time index, with t=0 being the initial state. - initial_state_prior: A distribution over the initial state p(z[0]). + p(x(t) | z(t), t). Must return a MultivariateNormalLinearOperator. + Call signature is `observation_fn(z(t), t)`, where z(t) is the current state. + Note: t is the observation time, with t=t0 being the initial state. + initial_state_prior: A distribution over the initial state p(z(t0)). Must be a MultivariateNormalLinearOperator. more_data_than_params: If True, the observation function has more outputs than inputs. materialise_jacobians: If True, the Jacobians are materialised as dense matrices. - dt: The time step size, default is 1. + dt: The time step size, default is 1. t[i] = t0 + (i+1) * dt """ - transition_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] + transition_fn: Callable[[jax.Array, jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] observation_fn: Callable[[jax.Array, jax.Array], tfpd.MultivariateNormalLinearOperator] initial_state_prior: tfpd.MultivariateNormalLinearOperator more_data_than_params: bool = False @@ -119,9 +119,9 @@ def __post_init__(self): self.latent_size = np.size(_initial_state_prior_mean) self.latent_shape = np.shape(_initial_state_prior_mean) - def get_transition_jacobian(self, t: jax.Array) -> JVPLinearOp: + def get_transition_jacobian(self, t: jax.Array, t_next: jax.Array) -> JVPLinearOp: def _transition_fn(z): - return self.transition_fn(z, t).mean() + return self.transition_fn(z, t, t_next).mean() return JVPLinearOp(_transition_fn, more_outputs_than_inputs=False) @@ -134,15 +134,15 @@ def _observation_fn(z): more_data_than_params = self.latent_size < observation_size return JVPLinearOp(_observation_fn, more_outputs_than_inputs=more_data_than_params) - def transition_matrix(self, z, t): - Fop = self.get_transition_jacobian(t) + def transition_matrix(self, z, t, t_next): + Fop = self.get_transition_jacobian(t, t_next) return Fop(z).to_dense() def observation_matrix(self, z, t): Hop = self.get_observation_jacobian(t) return Hop(z).to_dense() - def sample(self, key, num_time: int, t0: Union[jax.Array, int] = 0) -> SampleResult: + def sample(self, key, num_time: int, t0: Union[jax.Array, float] = 0.) -> SampleResult: """ Sample from the model. @@ -162,19 +162,22 @@ def sample(self, key, num_time: int, t0: Union[jax.Array, int] = 0) -> SampleRes # observation: S -> O def _sample_latents_op(latent, y): - (key, t) = y + (key, t, t_next) = y new_latent_key, obs_key = jax.random.split(key, 2) - transition_dist = self.transition_fn(latent, t) + transition_dist = self.transition_fn(latent, t, t_next) new_latent = transition_dist.sample(seed=new_latent_key) observation_dist = self.observation_fn(new_latent, t) new_observation = observation_dist.sample(seed=obs_key) - return new_latent, SampleResult(t=t, latent=new_latent, observation=new_observation) + return new_latent, SampleResult(t=t_next, latent=new_latent, observation=new_observation) - # Sample at t0 + # Sample at t0 forming initial state init = self.initial_state_prior.sample(seed=init_key) + t_from = jnp.arange(0, num_time) * self.dt + t0 + t_to = t_from + self.dt xs = ( jax.random.split(latent_key, num_time), - jnp.arange(1, num_time + 1) * self.dt + t0 + t_from, + t_to ) _, samples = lax.scan( _sample_latents_op, @@ -239,7 +242,8 @@ def forward_simulate(self, key: jax.Array, num_time: int, observation_fn=self.observation_fn, initial_state_prior=initial_state_prior, more_data_than_params=self.more_data_than_params, - materialise_jacobians=self.materialise_jacobians + materialise_jacobians=self.materialise_jacobians, + dt=self.dt ) return new_essm.sample(key=key, num_time=num_time, t0=t0) @@ -327,25 +331,42 @@ def incremental_predict(self, filter_state: IncrementalFilterState) -> Increment """ # Predict step, compute p(z[t+1] | x[:t]) - t = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype) + t_next = filter_state.t + jnp.asarray(self.dt, filter_state.t.dtype) - Fop = self.get_transition_jacobian(t=t) + Fop = self.get_transition_jacobian(t=filter_state.t, t_next=t_next) F = Fop(filter_state.filtered_mean) if self.materialise_jacobians: F = F.to_dense() - predicted_dist = self.transition_fn(filter_state.filtered_mean, t) + predicted_dist = self.transition_fn(filter_state.filtered_mean, filter_state.t, t_next) predicted_mean = predicted_dist.mean() # [latent_size] Q = predicted_dist.covariance() # [latent_size, latent_size] predicted_cov = F @ filter_state.filtered_cov @ F.T + Q # [latent_size, latent_size] return IncrementalFilterState( - t=t, + t=t_next, log_cumulative_marginal_likelihood=filter_state.log_cumulative_marginal_likelihood, filtered_mean=predicted_mean, filtered_cov=predicted_cov ) + def create_filter_state(self, filter_result: FilterResult) -> IncrementalFilterState: + """ + Create an incremental filter state from a filter result. + + Args: + filter_result: the filter result + + Returns: + the incremental filter state + """ + return IncrementalFilterState( + t=filter_result.t[-1], + log_cumulative_marginal_likelihood=filter_result.log_cumulative_marginal_likelihood[-1], + filtered_mean=filter_result.filtered_mean[-1], + filtered_cov=filter_result.filtered_cov[-1] + ) + def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) -> IncrementalFilterState: """ Create an incremental filter at the initial time. @@ -358,20 +379,11 @@ def create_initial_filter_state(self, t0: Union[jax.Array, float] = 0.) -> Incre """ # Push forward initial state to create p(z[1] | z[0]) t0 = jnp.asarray(t0, jnp.float32) - t1 = t0 + jnp.asarray(self.dt, t0.dtype) - init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) - - init_Fop = self.get_transition_jacobian(t1) - init_F = init_Fop(self.initial_state_prior.mean()) - if self.materialise_jacobians: - init_F = init_F.to_dense() - init_predicted_mean = init_predict_dist.mean() - init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() return IncrementalFilterState( - t=t1, + t=t0, log_cumulative_marginal_likelihood=jnp.asarray(0.), - filtered_mean=init_predicted_mean, - filtered_cov=init_predicted_cov + filtered_mean=self.initial_state_prior.mean(), + filtered_cov=self.initial_state_prior.covariance() ) def forward_filter(self, observations: jax.Array, mask: Optional[jax.Array] = None, @@ -433,6 +445,8 @@ def _filter_op(filter_state: IncrementalFilterState, y: YType) -> Tuple[Incremen ) filter_state = self.create_initial_filter_state(t0=t0) + filter_state = self.incremental_predict(filter_state) # Advance to first update time + if mask is None: _mask = jnp.zeros(num_time, dtype=jnp.bool_) # dummy variable (we skip the mask select) else: @@ -503,33 +517,40 @@ class Carry(NamedTuple): smoothed_cov=filter_result.predicted_cov[-1] # [latent_size, latent_size] covariance of p(z[T] | x[:T]) ) - def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: + class XType(NamedTuple): + t: jax.Array # time t + filtered_mean: jax.Array # [latent_size] mean of p(z[t] | x[:t]) + filtered_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t] | x[:t]) + predicted_mean: jax.Array # [latent_size] mean of p(z[t+1] | x[:t]) + predicted_cov: jax.Array # [latent_size, latent_size] covariance of p(z[t+1] | x[:t]) + + def _smooth_op(carry: Carry, x: XType) -> Tuple[Carry, SmoothingResult]: """ A single step of the backward equations. """ - Fop = self.get_transition_jacobian(t=y.t) - F = Fop(y.filtered_mean) + Fop = self.get_transition_jacobian(t=x.t - self.dt, t_next=x.t) + F = Fop(x.filtered_mean) if self.materialise_jacobians: F = F.to_dense() - # Compute the RTS smoother gain + # Compute the C # J = y.filtered_cov @ F.T @ jnp.linalg.inv(y.predicted_cov) - predicted_cov_chol = lax.linalg.cholesky(y.predicted_cov, + predicted_cov_chol = lax.linalg.cholesky(_efficient_add_scalar_diag(x.predicted_cov, 1e-6), symmetrize_input=False) # Possibly need to add a small diagonal jitter - tmp_F_P = F @ y.filtered_cov - J = hpsd_solve(y.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T + tmp_F_P = F @ x.filtered_cov + J = hpsd_solve(x.predicted_cov, tmp_F_P, cholesky_matrix=predicted_cov_chol).T # Update the state estimate - smoothed_mean = y.filtered_mean + J @ (carry.smoothed_mean - y.predicted_mean) + smoothed_mean = x.filtered_mean + J @ (carry.smoothed_mean - x.predicted_mean) # Update the state covariance - smoothed_cov = y.filtered_cov + J @ (carry.smoothed_cov - y.predicted_cov) @ J.T + smoothed_cov = x.filtered_cov + J @ (carry.smoothed_cov - x.predicted_cov) @ J.T # Push-forward the smoothed distribution to the observation space - observation_dist = self.observation_fn(smoothed_mean, y.t) + observation_dist = self.observation_fn(smoothed_mean, x.t) R = observation_dist.covariance() - Hop = self.get_observation_jacobian(t=y.t, observation_size=y.observation_mean.size) + Hop = self.get_observation_jacobian(t=x.t, observation_size=np.shape(filter_result.observation_mean)[1]) H = Hop(smoothed_mean) if self.materialise_jacobians: H = H.to_dense() @@ -540,49 +561,46 @@ def _smooth_op(carry: Carry, y: FilterResult) -> Tuple[Carry, SmoothingResult]: smoothed_mean=smoothed_mean, smoothed_cov=smoothed_cov ), SmoothingResult( - t=y.t, + t=x.t, smoothed_mean=smoothed_mean, smoothed_cov=smoothed_cov, smoothed_obs_mean=smoothed_obs_mean, smoothed_obs_cov=smoothed_obs_cov ) + xs = XType( + t=filter_result.t, + filtered_mean=filter_result.filtered_mean, + filtered_cov=filter_result.filtered_cov, + predicted_mean=filter_result.predicted_mean, + predicted_cov=filter_result.predicted_cov + ) + if include_prior: + # prepend the initial state + t0 = xs.t[0] - self.dt + init_filter_state = self.create_initial_filter_state(t0=t0) + init_predict_state = self.incremental_predict(init_filter_state) + xs = XType( + t=jnp.concatenate([jnp.asarray(t0)[None], xs.t], axis=0), + filtered_mean=jnp.concatenate([init_filter_state.filtered_mean[None], xs.filtered_mean], axis=0), + filtered_cov=jnp.concatenate([init_filter_state.filtered_cov[None], xs.filtered_cov], axis=0), + predicted_mean=jnp.concatenate([init_predict_state.filtered_mean[None], xs.predicted_mean], axis=0), + predicted_cov=jnp.concatenate([init_predict_state.filtered_cov[None], xs.predicted_cov], axis=0) + ) + final_carry, smooth_results = lax.scan( f=_smooth_op, init=init_carry, - xs=filter_result, + xs=xs, reverse=True ) if include_prior: - # Transition prior to compute predicted p(z1 | z0) - t1 = filter_result.t[0] - init_predict_dist = self.transition_fn(self.initial_state_prior.mean(), t1) - init_Fop = self.get_transition_jacobian(t=t1) - init_F = init_Fop(self.initial_state_prior.mean()) - if self.materialise_jacobians: - init_F = init_F.to_dense() - init_predicted_mean = init_predict_dist.mean() - init_predicted_cov = init_F @ self.initial_state_prior.covariance() @ init_F.T + init_predict_dist.covariance() - - t0 = t1 - jnp.asarray(self.dt, t1.dtype) - y = FilterResult( - t=t0, - log_cumulative_marginal_likelihood=jnp.asarray(0.), - filtered_mean=self.initial_state_prior.mean(), - filtered_cov=self.initial_state_prior.covariance(), - predicted_mean=init_predicted_mean, - predicted_cov=init_predicted_cov, - observation_mean=filter_result.observation_mean[0], # dummy unused - observation_cov=filter_result.observation_cov[0] # dummy unused - ) - final_initial_prior_carry, _ = _smooth_op( - carry=final_carry, - y=y - ) + # Trim + smooth_results = jax.tree.map(lambda x: x[1:], smooth_results) smoothed_prior = InitialPrior( - mean=final_initial_prior_carry.smoothed_mean, - covariance=final_initial_prior_carry.smoothed_cov + mean=final_carry.smoothed_mean, + covariance=final_carry.smoothed_cov ) return smooth_results, smoothed_prior diff --git a/essm_jax/tests/test_essm.py b/essm_jax/tests/test_essm.py index da7f0cf..f55b72e 100644 --- a/essm_jax/tests/test_essm.py +++ b/essm_jax/tests/test_essm.py @@ -18,7 +18,7 @@ def test_extended_state_space_model(): num_time = 10 - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = 2 * z cov = jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -137,7 +137,7 @@ def _compare(key): def test_jvp_essm(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = jnp.sin(2 * z) cov = jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -193,7 +193,7 @@ def observation_fn(z, t): def test_speed_test_jvp_essm(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = jnp.sin(2 * z + t) cov = jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -245,14 +245,14 @@ def observation_fn(z, t): def test_essm_forward_simulation(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = z + jnp.sin(2 * jnp.pi * t / 10 * z) cov = 0.1 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) def observation_fn(z, t): mean = z - cov = t * 0.01 * jnp.eye(np.size(z)) + cov = 0.01 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) n = 1 @@ -273,17 +273,19 @@ def observation_fn(z, t): # Suppose we only observe every 3rd observation mask = jnp.arange(T) % 3 != 0 - # Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1]) - log_prob = essm.log_prob(samples.observation, mask=mask) - print(log_prob) - # Filtered latent distribution, p(z[t] | x[:t]) filter_result = essm.forward_filter(samples.observation, mask=mask) + assert np.all(np.isfinite(filter_result.log_cumulative_marginal_likelihood)) + assert np.all(np.isfinite(filter_result.filtered_mean)) + + # Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1]) + log_prob = essm.log_prob(samples.observation, mask=mask) + assert log_prob == filter_result.log_cumulative_marginal_likelihood[-1] # Smoothed latent distribution, p(z[t] | x[:]), i.e. past latents given all future observations # Including new estimate for prior state p(z[0]) smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True) - print(smooth_result) + assert np.all(np.isfinite(smooth_result.smoothed_mean)) # Forward simulate the model forward_samples = essm.forward_simulate( @@ -322,7 +324,7 @@ def test__efficienct_add_scalar_diag(): def test_incremental_filtering(): - def transition_fn(z, t): + def transition_fn(z, t, t_next): mean = z + z * jnp.sin(2 * jnp.pi * t / 10) cov = 0.1 * jnp.eye(np.size(z)) return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov)) @@ -349,26 +351,14 @@ def observation_fn(z, t): filter_state = essm.create_initial_filter_state() - import pylab as plt - for i in range(100): + filter_state = essm.incremental_predict(filter_state) filter_state, _ = essm.incremental_update(filter_state, samples.observation[i]) - plt.scatter(filter_state.t, filter_state.filtered_mean, c='black') assert filter_state.t == filter_result.t[i] - filter_state = essm.incremental_predict(filter_state) - # print(i, np.abs( - # filter_state.log_cumulative_marginal_likelihood - filter_result.log_cumulative_marginal_likelihood[i])) - # print(i, np.max(np.abs(filter_state.filtered_mean - filter_result.predicted_mean[i]))) - # print(i, np.max(np.abs(filter_state.filtered_cov - filter_result.predicted_cov[i]))) - # print(i, np.max(np.abs(filter_state.filtered_cov))) np.testing.assert_allclose(filter_state.log_cumulative_marginal_likelihood, filter_result.log_cumulative_marginal_likelihood[i], atol=1e-5) - np.testing.assert_allclose(filter_state.filtered_mean, filter_result.predicted_mean[i], atol=1e-5) - np.testing.assert_allclose(filter_state.filtered_cov, filter_result.predicted_cov[i], atol=1e-5) - - plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent') - plt.legend() - plt.show() + np.testing.assert_allclose(filter_state.filtered_mean, filter_result.filtered_mean[i], atol=1e-5) + np.testing.assert_allclose(filter_state.filtered_cov, filter_result.filtered_cov[i], atol=1e-5) @pytest.mark.parametrize('use_sparse', [False, True]) @@ -386,7 +376,7 @@ def test_performance_sparse(use_sparse: bool): else: m = jnp.asarray(m) - def transition_fn(z, t): + def transition_fn(z, t, t_next): if use_sparse: mean = matvec_sparse(m, z) else: From d4948f66886b0d657c4c6ec663deea6b13d50da9 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:36:57 +0300 Subject: [PATCH 6/8] * bump to 1.0.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 83220de..bf6d268 100755 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ long_description = fh.read() setup(name='essm_jax', - version='1.0.0', + version='1.0.1', description='Extended State Spapce Model in JAX', long_description=long_description, long_description_content_type="text/markdown", From 0c51ac825aca8f07f6e6bef591bf539655110273 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:39:47 +0300 Subject: [PATCH 7/8] * fix test --- essm_jax/tests/test_essm.py | 2 +- essm_jax/tests/test_jvp_op.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/essm_jax/tests/test_essm.py b/essm_jax/tests/test_essm.py index f55b72e..e9ae0b8 100644 --- a/essm_jax/tests/test_essm.py +++ b/essm_jax/tests/test_essm.py @@ -3,9 +3,9 @@ import jax import pytest +jax.config.update('jax_enable_x64', True) from essm_jax.sparse import create_sparse_rep, matvec_sparse -jax.config.update('jax_enable_x64', True) import numpy as np import tensorflow_probability.substrates.jax as tfp from jax import numpy as jnp diff --git a/essm_jax/tests/test_jvp_op.py b/essm_jax/tests/test_jvp_op.py index fe3aaf6..5022f45 100644 --- a/essm_jax/tests/test_jvp_op.py +++ b/essm_jax/tests/test_jvp_op.py @@ -1,4 +1,7 @@ +import jax import jax.numpy as jnp + +jax.config.update('jax_enable_x64', True) import numpy as np import pytest From 0aa3bd4d154ea6805067049eca5dd55a365939c2 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Thu, 15 Aug 2024 01:43:29 +0300 Subject: [PATCH 8/8] * fix test --- essm_jax/tests/test_jvp_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/essm_jax/tests/test_jvp_op.py b/essm_jax/tests/test_jvp_op.py index 5022f45..1799632 100644 --- a/essm_jax/tests/test_jvp_op.py +++ b/essm_jax/tests/test_jvp_op.py @@ -16,7 +16,7 @@ def test_jvp_linear_op(): def fn(x): return jnp.asarray([jnp.sum(jnp.sin(x) ** i) for i in range(m)]) - x = jnp.arange(n).astype(jnp.float32) + x = jnp.arange(n).astype(float) jvp_op = JVPLinearOp(fn) jvp_op = jvp_op(x) @@ -73,8 +73,8 @@ def test_multiple_primals(init_primals: bool): def fn(x, y): return jnp.stack([x * y, y, -y], axis=-1) # [n, 3] - x = jnp.arange(n).astype(jnp.float32) - y = jnp.arange(n).astype(jnp.float32) + x = jnp.arange(n).astype(float) + y = jnp.arange(n).astype(float) if init_primals: jvp_op = JVPLinearOp(fn, primals=(x, y)) else: