Skip to content

Commit

Permalink
Change axis in weight constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Mar 5, 2024
1 parent 8c95d73 commit 067d7ad
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions n3fit/src/n3fit/backends/keras_backend/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
"""

import tensorflow as tf
from tensorflow.keras.constraints import MinMaxNorm
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import MinMaxNorm


class MinMaxWeight(MinMaxNorm):
Expand All @@ -14,15 +14,12 @@ class MinMaxWeight(MinMaxNorm):
"""

def __init__(self, min_value, max_value, **kwargs):
super(MinMaxWeight, self).__init__(
min_value=min_value, max_value=max_value, **kwargs
)
super().__init__(min_value=min_value, max_value=max_value, axis=1, **kwargs)

@tf.function
def __call__(self, w):
norms = K.sum(w, axis=self.axis, keepdims=True)
desired = (
self.rate * K.clip(norms, self.min_value, self.max_value)
+ (1 - self.rate) * norms
self.rate * K.clip(norms, self.min_value, self.max_value) + (1 - self.rate) * norms
)
return w * desired / (K.epsilon() + norms)

0 comments on commit 067d7ad

Please sign in to comment.