Skip to content

Commit

Permalink
simplified FeatureMapOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Oct 4, 2023
1 parent ea4e273 commit 0298f5f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
3 changes: 1 addition & 2 deletions apax/bal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def kernel_selection(
"max_dist": selection.max_dist_selection,
}[selection_method]

base_feature_config = feature_maps.FeatureMapOptions.model_validate(base_fm_options)
base_feature_map = base_feature_config.base_feature_map
base_feature_map = feature_maps.FeatureMapOptions(base_fm_options)

config, params = restore_parameters(model_dir)

Expand Down
9 changes: 4 additions & 5 deletions apax/bal/feature_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from pydantic import BaseModel, Field
from pydantic import BaseModel, TypeAdapter


def extract_feature_params(params: dict, layer_name: str) -> Tuple[dict, dict]:
Expand Down Expand Up @@ -76,7 +76,6 @@ def apply(self, model):
return model.apply


class FeatureMapOptions(BaseModel, extra="forbid"):
base_feature_map: Union[LastLayerGradientFeatures, IdentityFeatures] = Field(
LastLayerGradientFeatures(name="ll_grad"), discriminator="name"
)
FeatureMapOptions = TypeAdapter(
Union[LastLayerGradientFeatures, IdentityFeatures]
).validate_python

0 comments on commit 0298f5f

Please sign in to comment.