Skip to content

Commit da734ee

Browse files
Add Tests for Kaggle Upload Validation (#1524)
* Add Kaggle upload validation tests. * Use bert_tiny as test model.
1 parent 9ee8041 commit da734ee

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

keras_nlp/utils/preset_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def _validate_backbone(preset):
223223
weights_path = os.path.join(preset, config["weights"])
224224
if not os.path.exists(weights_path):
225225
raise FileNotFoundError(
226-
f"The weights file doesn't exist in preset directory `{preset}`."
226+
f"The weights file is missing from the preset directory `{preset}`."
227227
)
228228
else:
229229
raise ValueError(

keras_nlp/utils/preset_utils_test.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
from absl.testing import parameterized
2020

2121
from keras_nlp import upload_preset
22-
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
23-
from keras_nlp.models.backbone import Backbone
24-
from keras_nlp.models.bert.bert_classifier import BertClassifier
25-
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
26-
from keras_nlp.models.task import Task
22+
from keras_nlp.models import AlbertClassifier
23+
from keras_nlp.models import Backbone
24+
from keras_nlp.models import BertBackbone
25+
from keras_nlp.models import BertClassifier
26+
from keras_nlp.models import BertTokenizer
27+
from keras_nlp.models import RobertaClassifier
28+
from keras_nlp.models import Task
2729
from keras_nlp.tests.test_case import TestCase
30+
from keras_nlp.utils.preset_utils import CONFIG_FILE
31+
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
2832
from keras_nlp.utils.preset_utils import check_preset_class
2933
from keras_nlp.utils.preset_utils import load_from_preset
3034
from keras_nlp.utils.preset_utils import save_to_preset
@@ -116,4 +120,51 @@ def test_upload_empty_preset(self):
116120
with self.assertRaises(FileNotFoundError):
117121
upload_preset(uri, empty_preset)
118122

119-
# TODO: add more test to cover various invalid scenarios such as invalid json, missing files, etc.
123+
@parameterized.parameters(
124+
(TOKENIZER_CONFIG_FILE), (CONFIG_FILE), ("model.weights.h5")
125+
)
126+
@pytest.mark.keras_3_only
127+
@pytest.mark.large
128+
def test_upload_with_missing_file(self, missing_file):
129+
# Load a model from Kaggle to use as a test model.
130+
preset = "bert_tiny_en_uncased"
131+
backbone = BertBackbone.from_preset(preset)
132+
tokenizer = BertTokenizer.from_preset(preset)
133+
134+
# Save the model on a local directory.
135+
temp_dir = self.get_temp_dir()
136+
local_preset_dir = os.path.join(temp_dir, "bert_preset")
137+
backbone.save_to_preset(local_preset_dir)
138+
tokenizer.save_to_preset(local_preset_dir)
139+
140+
# Delete the file that is supposed to be missing.
141+
missing_path = os.path.join(local_preset_dir, missing_file)
142+
os.remove(missing_path)
143+
144+
# Verify error handling.
145+
with self.assertRaisesRegex(FileNotFoundError, "is missing"):
146+
upload_preset("kaggle://test/test/test", local_preset_dir)
147+
148+
@parameterized.parameters((TOKENIZER_CONFIG_FILE), (CONFIG_FILE))
149+
@pytest.mark.keras_3_only
150+
@pytest.mark.large
151+
def test_upload_with_invalid_json(self, json_file):
152+
# Load a model from Kaggle to use as a test model.
153+
preset = "bert_tiny_en_uncased"
154+
backbone = BertBackbone.from_preset(preset)
155+
tokenizer = BertTokenizer.from_preset(preset)
156+
157+
# Save the model on a local directory.
158+
temp_dir = self.get_temp_dir()
159+
local_preset_dir = os.path.join(temp_dir, "bert_preset")
160+
backbone.save_to_preset(local_preset_dir)
161+
tokenizer.save_to_preset(local_preset_dir)
162+
163+
# Re-write json file content to an invalid format.
164+
json_path = os.path.join(local_preset_dir, json_file)
165+
with open(json_path, "w") as file:
166+
file.write("Invalid!")
167+
168+
# Verify error handling.
169+
with self.assertRaisesRegex(ValueError, "is an invalid json"):
170+
upload_preset("kaggle://test/test/test", local_preset_dir)

0 commit comments

Comments
 (0)