
Commit 7d3c77c

Fix re-compilation bugs (#1541)
The primary bug this fixes is a set of surprising re-compilation behaviors in our default compilation. E.g. create a `GemmaCausalLM`, generate some text without specifying a sampler, then `compile()` again without a sampler, and generation will have silently switched from `"greedy"` to `"top_k"`. Or create a `BertClassifier`, `fit()`, then `compile()` again without specifying an optimizer, and the optimizer will have switched from `"adam"` to `"rmsprop"`. The fix leans a little more heavily on the `"auto"` style option we introduced for `jit_compile`: KerasNLP tasks will by default use `loss="auto"` and `optimizer="auto"`, which resolve to a default for the given task. Since these defaults live in the `compile()` override's signature itself, recompilation will not silently change behavior.
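In sketch form, the pattern looks like this (a simplified illustration of the `"auto"` approach, not the exact code in this diff; `MyTask` is hypothetical):

```python
import keras


class MyTask(keras.Model):
    # Because the defaults live in the `compile()` signature as "auto"
    # sentinels, calling `compile()` again with no arguments re-resolves
    # to the same task defaults, instead of falling back to
    # `keras.Model.compile`'s generic defaults (e.g. "rmsprop").
    def compile(self, optimizer="auto", loss="auto", **kwargs):
        if optimizer == "auto":
            optimizer = keras.optimizers.Adam(5e-5)
        if loss == "auto":
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        super().compile(optimizer=optimizer, loss=loss, **kwargs)
```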
1 parent 9ac3335 commit 7d3c77c

28 files changed: +211 −316 lines

keras_nlp/models/albert/albert_classifier.py

Lines changed: 0 additions & 11 deletions
@@ -186,17 +186,6 @@ def __init__(
         self.activation = keras.activations.get(activation)
         self.dropout = dropout
 
-        # === Default compilation ===
-        logit_output = self.activation == keras.activations.linear
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(
-                from_logits=logit_output
-            ),
-            optimizer=keras.optimizers.Adam(5e-5),
-            metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(

keras_nlp/models/albert/albert_masked_lm.py

Lines changed: 0 additions & 8 deletions
@@ -124,11 +124,3 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
             outputs=outputs,
             **kwargs,
         )
-
-        # === Default compilation ===
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-            optimizer=keras.optimizers.Adam(5e-5),
-            weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )

keras_nlp/models/bart/bart_seq_2_seq_lm.py

Lines changed: 1 addition & 10 deletions
@@ -14,7 +14,6 @@
 
 
 from keras_nlp.api_export import keras_nlp_export
-from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.models.bart.bart_backbone import BartBackbone
 from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
@@ -200,14 +199,6 @@ def __init__(
             **kwargs,
         )
 
-        # === Default compilation ===
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-            optimizer=keras.optimizers.Adam(2e-5),
-            metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )
-
     def call_decoder_with_cache(
         self,
         encoder_hidden_states,
@@ -460,7 +451,7 @@ def repeat_tensor(x):
             cache,
         )
 
-        decoder_token_ids = self._sampler(
+        decoder_token_ids = self.sampler(
             next=next,
             prompt=decoder_token_ids,
             cache=self_attention_cache,

keras_nlp/models/bert/bert_classifier.py

Lines changed: 0 additions & 11 deletions
@@ -170,17 +170,6 @@ def __init__(
         self.activation = keras.activations.get(activation)
         self.dropout = dropout
 
-        # === Default compilation ===
-        logit_output = self.activation == keras.activations.linear
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(
-                from_logits=logit_output
-            ),
-            optimizer=keras.optimizers.Adam(5e-5),
-            metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(

keras_nlp/models/bert/bert_masked_lm.py

Lines changed: 0 additions & 10 deletions
@@ -128,13 +128,3 @@ def __init__(
             outputs=outputs,
             **kwargs,
         )
-
-        # === Default compilation ===
-        self.backbone = backbone
-        self.preprocessor = preprocessor
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-            optimizer=keras.optimizers.Adam(5e-5),
-            weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )

keras_nlp/models/bloom/bloom_causal_lm.py

Lines changed: 1 addition & 11 deletions
@@ -14,7 +14,6 @@
 
 
 from keras_nlp.api_export import keras_nlp_export
-from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.models.bloom.bloom_backbone import BloomBackbone
 from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import (
@@ -167,15 +166,6 @@ def __init__(
             **kwargs,
         )
 
-        # === Default compilation ===
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-            optimizer=keras.optimizers.Adam(2e-5),
-            metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            sampler="greedy",
-            jit_compile=True,
-        )
-
     def call_with_cache(
         self,
         token_ids,
@@ -273,7 +263,7 @@ def next(prompt, cache, index):
             cache,
         )
 
-        token_ids = self._sampler(
+        token_ids = self.sampler(
             next=next,
             prompt=token_ids,
             cache=cache,

keras_nlp/models/causal_lm.py

Lines changed: 64 additions & 13 deletions
@@ -68,23 +68,74 @@ class CausalLM(Task):
     ```
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
     def compile(
         self,
-        *args,
-        run_eagerly=False,
-        jit_compile=True,
+        optimizer="auto",
+        loss="auto",
+        *,
+        weighted_metrics="auto",
         sampler="top_k",
         **kwargs,
     ):
-        xla_compatible = True
+        """Configures the `CausalLM` task for training and generation.
+
+        The `CausalLM` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `weighted_metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        The `CausalLM` task adds a new `sampler` to `compile`, which can be used
+        to control the sampling strategy used with the `generate` function.
+
+        Note that because training inputs include padded tokens which are
+        excluded from the loss, it is almost always a good idea to compile with
+        `weighted_metrics` and not `metrics`.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.SparseCategoricalCrossentropy` loss will be
+                applied for the token classification `CausalLM` task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            weighted_metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.SparseCategoricalAccuracy` will be
+                applied to track the accuracy of the model at predicting the
+                next token. See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `weighted_metrics` values.
+            sampler: A sampler name, or a `keras_nlp.samplers.Sampler` instance.
+                Configures the sampling method used during `generate()` calls.
+                See `keras_nlp.samplers` for a full list of built-in sampling
+                strategies.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(2e-5)
+        if loss == "auto":
+            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+        if weighted_metrics == "auto":
+            weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()]
+        # Keras 2 does not jit_compile by default.
+        if not config.keras_3():
+            kwargs["jit_compile"] = True
         super().compile(
-            *args,
-            run_eagerly=run_eagerly,
-            # Only `jit_compile` if not eager and in a compatible environment.
-            jit_compile=jit_compile and xla_compatible and not run_eagerly,
+            optimizer=optimizer,
+            loss=loss,
+            weighted_metrics=weighted_metrics,
             **kwargs,
         )
-        self._sampler = get_sampler(sampler)
+        self.sampler = get_sampler(sampler)
         # Clear the compiled generate function.
         self.generate_function = None
 
@@ -127,7 +178,7 @@ def compiled_generate_function(inputs, stop_token_ids, state):
                 non_trainable_variables,
             ) = state
             mapping = itertools.chain(
-                zip(self._sampler.variables, sampler_variables),
+                zip(self.sampler.variables, sampler_variables),
                 zip(self.trainable_variables, trainable_variables),
                 zip(self.non_trainable_variables, non_trainable_variables),
            )
@@ -137,7 +188,7 @@ def compiled_generate_function(inputs, stop_token_ids, state):
 
             # Get updated sampler variables from the stateless scope.
             sampler_variables = []
-            for v in self._sampler.variables:
+            for v in self.sampler.variables:
                 new_v = scope.get_current_value(v)
                 sampler_variables.append(new_v if new_v is not None else v)
             return outputs, sampler_variables
@@ -151,7 +202,7 @@ def wrapped_generate_function(
 
             # Create an explicit tuple of all variable state.
             state = (
-                self._sampler.variables,
+                self.sampler.variables,
                 # Use the explicit variable.value to preserve the
                 # sharding spec of distribution.
                 [v.value for v in self.trainable_variables],
@@ -165,7 +216,7 @@ def wrapped_generate_function(
             )
             # Only assign the sampler variables (random seeds), as other
             # model variables should never be updated in generation.
-            for ref_v, v in zip(self._sampler.variables, sampler_variables):
+            for ref_v, v in zip(self.sampler.variables, sampler_variables):
                 ref_v.assign(v)
             return outputs
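The net effect on the public API, as a hedged sketch (the preset name is illustrative; `sampler="top_k"` is the `CausalLM.compile` default shown above):

```python
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
causal_lm.generate("a quick brown fox")  # Samples with the default "top_k".
causal_lm.compile()                      # Re-resolves "auto"; nothing silently changes.
causal_lm.compile(sampler="greedy")      # Switching strategies is now always explicit.
```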

keras_nlp/models/classifier.py

Lines changed: 61 additions & 0 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import config
+from keras_nlp.backend import keras
 from keras_nlp.models.task import Task
 
 
@@ -49,3 +51,62 @@ class Classifier(Task):
     classifier.predict(["What an amazing movie!", "A total waste of my time."])
     ```
     """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `Classifier` task for training.
+
+        The `Classifier` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.SparseCategoricalCrossentropy` loss will be
+                applied for the classification task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.SparseCategoricalAccuracy` will be
+                applied to track the accuracy of the model during training.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.softmax
+            loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
+        if metrics == "auto":
+            metrics = [keras.metrics.SparseCategoricalAccuracy()]
+        # Keras 2 does not jit_compile by default.
+        if not config.keras_3():
+            kwargs["jit_compile"] = True
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
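One subtlety in the `loss="auto"` branch above: `from_logits` is now inferred from the classifier's `activation`, rather than each subclass hard-coding a `logit_output` check against `linear`. A small standalone sketch of that resolution logic:

```python
import keras

# Mirrors the `loss == "auto"` branch: only a softmax activation means
# the model outputs probabilities; anything else (including the linear
# default resolved from `activation=None`) is treated as logits.
for activation in (None, "linear", "softmax"):
    resolved = keras.activations.get(activation)
    from_logits = resolved != keras.activations.softmax
    print(activation, "->", f"from_logits={from_logits}")
```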

keras_nlp/models/deberta_v3/deberta_v3_classifier.py

Lines changed: 0 additions & 11 deletions
@@ -212,17 +212,6 @@ def __init__(
         self.hidden_dim = hidden_dim
         self.dropout = dropout
 
-        # === Default compilation ===
-        logit_output = self.activation == keras.activations.linear
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(
-                from_logits=logit_output
-            ),
-            optimizer=keras.optimizers.Adam(5e-5),
-            metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(

keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py

Lines changed: 0 additions & 8 deletions
@@ -130,11 +130,3 @@ def __init__(
             outputs=outputs,
             **kwargs,
         )
-
-        # === Default compilation ===
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-            optimizer=keras.optimizers.Adam(5e-5),
-            weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )

keras_nlp/models/distil_bert/distil_bert_classifier.py

Lines changed: 0 additions & 11 deletions
@@ -192,17 +192,6 @@ def __init__(
         self.hidden_dim = hidden_dim
         self.dropout = dropout
 
-        # === Default compilation ===
-        logit_output = self.activation == keras.activations.linear
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(
-                from_logits=logit_output
-            ),
-            optimizer=keras.optimizers.Adam(5e-5),
-            metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(

keras_nlp/models/distil_bert/distil_bert_masked_lm.py

Lines changed: 0 additions & 8 deletions
@@ -132,11 +132,3 @@ def __init__(
             outputs=outputs,
             **kwargs,
         )
-
-        # === Default compilation ===
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-            optimizer=keras.optimizers.Adam(5e-5),
-            weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )

keras_nlp/models/f_net/f_net_classifier.py

Lines changed: 0 additions & 11 deletions
@@ -141,17 +141,6 @@ def __init__(
         self.activation = keras.activations.get(activation)
         self.dropout = dropout
 
-        # === Default compilation ===
-        logit_output = self.activation == keras.activations.linear
-        self.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(
-                from_logits=logit_output
-            ),
-            optimizer=keras.optimizers.Adam(5e-5),
-            metrics=[keras.metrics.SparseCategoricalAccuracy()],
-            jit_compile=True,
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(
