From 0298f5f0048fe3bd91bfe82b14d3ea14a6f91209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 4 Oct 2023 15:35:45 +0200 Subject: [PATCH] simplified FeatureMapOptions --- apax/bal/api.py | 3 +-- apax/bal/feature_maps.py | 9 ++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/apax/bal/api.py b/apax/bal/api.py index c62d267e..09d07c7d 100644 --- a/apax/bal/api.py +++ b/apax/bal/api.py @@ -77,8 +77,7 @@ def kernel_selection( "max_dist": selection.max_dist_selection, }[selection_method] - base_feature_config = feature_maps.FeatureMapOptions.model_validate(base_fm_options) - base_feature_map = base_feature_config.base_feature_map + base_feature_map = feature_maps.FeatureMapOptions(base_fm_options) config, params = restore_parameters(model_dir) diff --git a/apax/bal/feature_maps.py b/apax/bal/feature_maps.py index 8cf29005..ce8ba2b5 100644 --- a/apax/bal/feature_maps.py +++ b/apax/bal/feature_maps.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp from flax.traverse_util import flatten_dict, unflatten_dict -from pydantic import BaseModel, Field +from pydantic import BaseModel, TypeAdapter def extract_feature_params(params: dict, layer_name: str) -> Tuple[dict, dict]: @@ -76,7 +76,6 @@ def apply(self, model): return model.apply -class FeatureMapOptions(BaseModel, extra="forbid"): - base_feature_map: Union[LastLayerGradientFeatures, IdentityFeatures] = Field( - LastLayerGradientFeatures(name="ll_grad"), discriminator="name" - ) +FeatureMapOptions = TypeAdapter( + Union[LastLayerGradientFeatures, IdentityFeatures] +).validate_python