|
19 | 19 | from absl.testing import parameterized
|
20 | 20 |
|
21 | 21 | from keras_nlp import upload_preset
|
22 |
| -from keras_nlp.models.albert.albert_classifier import AlbertClassifier |
23 |
| -from keras_nlp.models.backbone import Backbone |
24 |
| -from keras_nlp.models.bert.bert_classifier import BertClassifier |
25 |
| -from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier |
26 |
| -from keras_nlp.models.task import Task |
| 22 | +from keras_nlp.models import AlbertClassifier |
| 23 | +from keras_nlp.models import Backbone |
| 24 | +from keras_nlp.models import BertBackbone |
| 25 | +from keras_nlp.models import BertClassifier |
| 26 | +from keras_nlp.models import BertTokenizer |
| 27 | +from keras_nlp.models import RobertaClassifier |
| 28 | +from keras_nlp.models import Task |
27 | 29 | from keras_nlp.tests.test_case import TestCase
|
| 30 | +from keras_nlp.utils.preset_utils import CONFIG_FILE |
| 31 | +from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE |
28 | 32 | from keras_nlp.utils.preset_utils import check_preset_class
|
29 | 33 | from keras_nlp.utils.preset_utils import load_from_preset
|
30 | 34 | from keras_nlp.utils.preset_utils import save_to_preset
|
@@ -116,4 +120,51 @@ def test_upload_empty_preset(self):
|
116 | 120 | with self.assertRaises(FileNotFoundError):
|
117 | 121 | upload_preset(uri, empty_preset)
|
118 | 122 |
|
119 |
| - # TODO: add more test to cover various invalid scenarios such as invalid json, missing files, etc. |
| 123 | + @parameterized.parameters( |
| 124 | + (TOKENIZER_CONFIG_FILE), (CONFIG_FILE), ("model.weights.h5") |
| 125 | + ) |
| 126 | + @pytest.mark.keras_3_only |
| 127 | + @pytest.mark.large |
| 128 | + def test_upload_with_missing_file(self, missing_file): |
| 129 | + # Load a model from Kaggle to use as a test model. |
| 130 | + preset = "bert_tiny_en_uncased" |
| 131 | + backbone = BertBackbone.from_preset(preset) |
| 132 | + tokenizer = BertTokenizer.from_preset(preset) |
| 133 | + |
| 134 | + # Save the model on a local directory. |
| 135 | + temp_dir = self.get_temp_dir() |
| 136 | + local_preset_dir = os.path.join(temp_dir, "bert_preset") |
| 137 | + backbone.save_to_preset(local_preset_dir) |
| 138 | + tokenizer.save_to_preset(local_preset_dir) |
| 139 | + |
| 140 | + # Delete the file that is supposed to be missing. |
| 141 | + missing_path = os.path.join(local_preset_dir, missing_file) |
| 142 | + os.remove(missing_path) |
| 143 | + |
| 144 | + # Verify error handling. |
| 145 | + with self.assertRaisesRegex(FileNotFoundError, "is missing"): |
| 146 | + upload_preset("kaggle://test/test/test", local_preset_dir) |
| 147 | + |
| 148 | + @parameterized.parameters((TOKENIZER_CONFIG_FILE), (CONFIG_FILE)) |
| 149 | + @pytest.mark.keras_3_only |
| 150 | + @pytest.mark.large |
| 151 | + def test_upload_with_invalid_json(self, json_file): |
| 152 | + # Load a model from Kaggle to use as a test model. |
| 153 | + preset = "bert_tiny_en_uncased" |
| 154 | + backbone = BertBackbone.from_preset(preset) |
| 155 | + tokenizer = BertTokenizer.from_preset(preset) |
| 156 | + |
| 157 | + # Save the model on a local directory. |
| 158 | + temp_dir = self.get_temp_dir() |
| 159 | + local_preset_dir = os.path.join(temp_dir, "bert_preset") |
| 160 | + backbone.save_to_preset(local_preset_dir) |
| 161 | + tokenizer.save_to_preset(local_preset_dir) |
| 162 | + |
| 163 | + # Re-write json file content to an invalid format. |
| 164 | + json_path = os.path.join(local_preset_dir, json_file) |
| 165 | + with open(json_path, "w") as file: |
| 166 | + file.write("Invalid!") |
| 167 | + |
| 168 | + # Verify error handling. |
| 169 | + with self.assertRaisesRegex(ValueError, "is an invalid json"): |
| 170 | + upload_preset("kaggle://test/test/test", local_preset_dir) |
0 commit comments