Skip to content

Commit bf5644e

Browse files
Gsk/ch21380/upgrade ml model schema (#196)
* minor naming changes * added load_v2 in base and all framework model classes * Needed changes to modela base class, and Tensorflow-Keras Class for schema, save, load and Tensorboard. * moved tensorboard file serialize to base class * Write Tensorboard files became a static method. * Needed changes to PyTorch Class for schema, save, load and Tensorboard. * changed opened array naming in _write_array * Needed changes to SkLearn Class for schema, save, load and Tensorboard. * changed attribute dtype to numpy.uint8 * change on how all pickled objects are loaded * fix tensorboard key, value loop * pytorch model unit tests and fixes in save and load * tensorflow model unit tests and fixes in save and load * Removed old Tensorboard Class, removed tensorboard file properties. * removed groups leftovers * removed custom layer model leftovers from unit tests * added model and optimizer weight getters to model base class * added model and optimizer weight getters to Tensorflow-Keras model class * added model and optimizer weight getters to PyTorch model class * updated model notebooks * fix mypy getters return error * added abstract methods decorator in model base class preview, get_optimizer_weights and get_weights methods * removed abstractmethod decorator from getters. * moved _write_array method to _base.py, updated Tensoflow Keras model class accordingly. * Updated Sklearn and PyTorch model classes with new write array functionality. * Added mode='w' in write metadata * PR review changes * Added tiledb-ml version check * One liner version check. * Changed version check function and unit tests. * Correct TileDB-ML version in file properties. * changed version check back to schema characteristics, removed unnecessary mocks from unit tests. * DRYfy get_weights/get_optimizer_weights Co-authored-by: George Sakkis <george.sakkis@tiledb.com>
1 parent ddb447c commit bf5644e

13 files changed

+670
-839
lines changed

examples/models/pytorch_tiledb_models_example.ipynb

Lines changed: 74 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@
5353
"outputs": [
5454
{
5555
"data": {
56-
"text/plain": [
57-
"<torch._C.Generator at 0x15dae6650>"
58-
]
56+
"text/plain": "<torch._C.Generator at 0x1218c0b70>"
5957
},
6058
"execution_count": 2,
6159
"metadata": {},
@@ -105,18 +103,11 @@
105103
]
106104
},
107105
{
108-
"data": {
109-
"application/vnd.jupyter.widget-view+json": {
110-
"model_id": "1834a028d80340888143ba2c4d99a1b0",
111-
"version_major": 2,
112-
"version_minor": 0
113-
},
114-
"text/plain": [
115-
" 0%| | 0/9912422 [00:00<?, ?it/s]"
116-
]
117-
},
118-
"metadata": {},
119-
"output_type": "display_data"
106+
"name": "stderr",
107+
"output_type": "stream",
108+
"text": [
109+
"100.0%\n"
110+
]
120111
},
121112
{
122113
"name": "stdout",
@@ -129,18 +120,11 @@
129120
]
130121
},
131122
{
132-
"data": {
133-
"application/vnd.jupyter.widget-view+json": {
134-
"model_id": "55ca1ec83c0a4526ab2f201e1615c490",
135-
"version_major": 2,
136-
"version_minor": 0
137-
},
138-
"text/plain": [
139-
" 0%| | 0/28881 [00:00<?, ?it/s]"
140-
]
141-
},
142-
"metadata": {},
143-
"output_type": "display_data"
123+
"name": "stderr",
124+
"output_type": "stream",
125+
"text": [
126+
"100.0%\n"
127+
]
144128
},
145129
{
146130
"name": "stdout",
@@ -153,18 +137,11 @@
153137
]
154138
},
155139
{
156-
"data": {
157-
"application/vnd.jupyter.widget-view+json": {
158-
"model_id": "2284fcddf2b243ad8989c7326457b9ed",
159-
"version_major": 2,
160-
"version_minor": 0
161-
},
162-
"text/plain": [
163-
" 0%| | 0/1648877 [00:00<?, ?it/s]"
164-
]
165-
},
166-
"metadata": {},
167-
"output_type": "display_data"
140+
"name": "stderr",
141+
"output_type": "stream",
142+
"text": [
143+
"100.0%\n"
144+
]
168145
},
169146
{
170147
"name": "stdout",
@@ -177,18 +154,11 @@
177154
]
178155
},
179156
{
180-
"data": {
181-
"application/vnd.jupyter.widget-view+json": {
182-
"model_id": "86116abe57bb42c09ed7de5fbd982b5e",
183-
"version_major": 2,
184-
"version_minor": 0
185-
},
186-
"text/plain": [
187-
" 0%| | 0/4542 [00:00<?, ?it/s]"
188-
]
189-
},
190-
"metadata": {},
191-
"output_type": "display_data"
157+
"name": "stderr",
158+
"output_type": "stream",
159+
"text": [
160+
"100.0%"
161+
]
192162
},
193163
{
194164
"name": "stdout",
@@ -197,6 +167,13 @@
197167
"Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
198168
"\n"
199169
]
170+
},
171+
{
172+
"name": "stderr",
173+
"output_type": "stream",
174+
"text": [
175+
"\n"
176+
]
200177
}
201178
],
202179
"source": [
@@ -283,7 +260,7 @@
283260
}
284261
},
285262
"source": [
286-
"We continue with the training loop and we iterate over all training data once per epoch. Loading the individual batches\n",
263+
"We continue with the training loop, and we iterate over all training data once per epoch. Loading the individual batches\n",
287264
"is handled by the DataLoader. We need to set the gradients to zero using optimizer.zero_grad() since PyTorch by default\n",
288265
"accumulates gradients. We then produce the output of the network (forward pass) and compute a negative log-likelihood\n",
289266
"loss between the output and the ground truth label. With the backward() call we now collect a new set of gradients which we\n",
@@ -299,6 +276,14 @@
299276
}
300277
},
301278
"outputs": [
279+
{
280+
"name": "stderr",
281+
"output_type": "stream",
282+
"text": [
283+
"2022-12-07 17:00:23.979857: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
284+
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
285+
]
286+
},
302287
{
303288
"name": "stdout",
304289
"output_type": "stream",
@@ -388,8 +373,8 @@
388373
"cell_type": "markdown",
389374
"metadata": {},
390375
"source": [
391-
"We can now save the trained model as a TileDB array. In case we want to train the model further in a later time, we can also save\n",
392-
"the optimizer in our TileDB array. In case we will use our model only for inference, we don't have to save the optimizer and we\n",
376+
"We can now save the trained model as a TileDB array. In case we want to train the model further in a later time, we can also save\n",
377+
"the optimizer in our TileDB array. In case we will use our model only for inference, we don't have to save the optimizer, and we\n",
393378
"only keep the model. We first declare a PyTorchTileDB object and initialize it with the corresponding TileDB uri, model and optimizer,\n",
394379
"and then save the model as a TileDB array. Finally, we can save any kind of metadata (in any structure, i.e., list, tuple or dictionary)\n",
395380
"by passing a dictionary to the meta attribute."
@@ -443,17 +428,20 @@
443428
" '../data/pytorch-mnist-1/__schema',\n",
444429
" '../data/pytorch-mnist-1/__fragments']\n",
445430
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK, Value: PYTORCH\n",
446-
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.10.2\n",
431+
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.12.0\n",
447432
"Key: TILEDB_ML_MODEL_PREVIEW, Value: Net(\n",
448433
" (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n",
449434
" (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))\n",
450435
" (conv2_drop): Dropout2d(p=0.5, inplace=False)\n",
451436
" (fc1): Linear(in_features=320, out_features=50, bias=True)\n",
452437
" (fc2): Linear(in_features=50, out_features=10, bias=True)\n",
453438
")\n",
454-
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.13\n",
439+
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.9\n",
455440
"Key: TILEDB_ML_MODEL_STAGE, Value: STAGING\n",
456441
"Key: epochs, Value: 1\n",
442+
"Key: model_state_dict_size, Value: 90053\n",
443+
"Key: optimizer_state_dict_size, Value: 90064\n",
444+
"Key: tensorboard_size, Value: 22674\n",
457445
"Key: train_loss, Value: (2.358812093734741, 2.285137891769409, 2.3066349029541016, 2.2708795070648193, 2.2367401123046875, 2.24334716796875, 2.1832549571990967, 2.1485116481781006, 2.1049115657806396, 2.0044069290161133, 1.8622523546218872, 1.8843708038330078, 1.7973158359527588, 1.6879109144210815, 1.508046269416809, 1.764279842376709, 1.4700727462768555, 1.3514467477798462, 1.2905819416046143, 1.0177571773529053, 1.042162299156189, 1.0987662076950073, 1.2285516262054443, 1.1495932340621948, 0.8452475070953369, 0.9741130471229553, 0.8569056987762451, 0.9234588146209717, 1.0218565464019775, 0.8069543242454529, 0.8789511919021606, 0.8185049891471863, 0.8055434226989746, 0.8231522440910339, 0.8543609976768494, 0.7746452689170837, 0.718348503112793, 0.5433375239372253, 0.7593768239021301, 0.65492182970047, 0.6999298930168152, 0.8053513765335083, 0.790733814239502, 0.7599329948425293, 0.540409505367279, 0.6412327885627747, 0.6593738198280334)\n"
458446
]
459447
}
@@ -494,18 +482,21 @@
494482
"output_type": "stream",
495483
"text": [
496484
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK, Value: PYTORCH\n",
497-
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.10.2\n",
485+
"Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.12.0\n",
498486
"Key: TILEDB_ML_MODEL_PREVIEW, Value: Net(\n",
499487
" (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n",
500488
" (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))\n",
501489
" (conv2_drop): Dropout2d(p=0.5, inplace=False)\n",
502490
" (fc1): Linear(in_features=320, out_features=50, bias=True)\n",
503491
" (fc2): Linear(in_features=50, out_features=10, bias=True)\n",
504492
")\n",
505-
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.13\n",
493+
"Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.9\n",
506494
"Key: TILEDB_ML_MODEL_STAGE, Value: STAGING\n",
507495
"Key: epochs, Value: 1\n",
496+
"Key: model_state_dict_size, Value: 90053\n",
508497
"Key: new_meta, Value: [\"Any kind of info\"]\n",
498+
"Key: optimizer_state_dict_size, Value: 90064\n",
499+
"Key: tensorboard_size, Value: 22674\n",
509500
"Key: train_loss, Value: (2.358812093734741, 2.285137891769409, 2.3066349029541016, 2.2708795070648193, 2.2367401123046875, 2.24334716796875, 2.1832549571990967, 2.1485116481781006, 2.1049115657806396, 2.0044069290161133, 1.8622523546218872, 1.8843708038330078, 1.7973158359527588, 1.6879109144210815, 1.508046269416809, 1.764279842376709, 1.4700727462768555, 1.3514467477798462, 1.2905819416046143, 1.0177571773529053, 1.042162299156189, 1.0987662076950073, 1.2285516262054443, 1.1495932340621948, 0.8452475070953369, 0.9741130471229553, 0.8569056987762451, 0.9234588146209717, 1.0218565464019775, 0.8069543242454529, 0.8789511919021606, 0.8185049891471863, 0.8055434226989746, 0.8231522440910339, 0.8543609976768494, 0.7746452689170837, 0.718348503112793, 0.5433375239372253, 0.7593768239021301, 0.65492182970047, 0.6999298930168152, 0.8053513765335083, 0.790733814239502, 0.7599329948425293, 0.540409505367279, 0.6412327885627747, 0.6593738198280334)\n"
510501
]
511502
}
@@ -524,49 +515,6 @@
524515
" print(\"Key: {}, Value: {}\".format(key, value))"
525516
]
526517
},
527-
{
528-
"cell_type": "markdown",
529-
"metadata": {},
530-
"source": [
531-
"For the case of PyTorch models, internally, we save model's state_dict and optimizer's state_dict,\n",
532-
"as [variable sized attributes)](https://docs.tiledb.com/main/how-to/arrays/writing-arrays/var-length-attributes)\n",
533-
"(pickled), i.e., we can open the TileDB and get only the state_dict of the model or optimizer,\n",
534-
"without bringing the whole model in memory. For example, we can load model's and optimizer's state_dict\n",
535-
"for model `pytorch-mnist-1` as follows."
536-
]
537-
},
538-
{
539-
"cell_type": "code",
540-
"execution_count": 10,
541-
"metadata": {
542-
"pycharm": {
543-
"name": "#%%\n"
544-
}
545-
},
546-
"outputs": [
547-
{
548-
"name": "stdout",
549-
"output_type": "stream",
550-
"text": [
551-
"Type: <class 'collections.OrderedDict'> , Keys: odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])\n",
552-
"Type: <class 'dict'>, Keys: dict_keys(['state', 'param_groups'])\n"
553-
]
554-
}
555-
],
556-
"source": [
557-
"# First open arrays\n",
558-
"model_array_1 = tiledb.open(uri)[:]\n",
559-
"\n",
560-
"# Load model state_dict\n",
561-
"model_1_state_dict = pickle.loads(model_array_1['model_state_dict'].item(0))\n",
562-
"\n",
563-
"# Load optimizer state_dict\n",
564-
"optimizer_1_state_dict = pickle.loads(model_array_1['optimizer_state_dict'].item(0))\n",
565-
"\n",
566-
"print(f'Type: {type(model_1_state_dict)} , Keys: {model_1_state_dict.keys()}')\n",
567-
"print(f'Type: {type(optimizer_1_state_dict)}, Keys: {optimizer_1_state_dict.keys()}')"
568-
]
569-
},
570518
{
571519
"cell_type": "markdown",
572520
"metadata": {
@@ -581,7 +529,7 @@
581529
},
582530
{
583531
"cell_type": "code",
584-
"execution_count": 11,
532+
"execution_count": 10,
585533
"metadata": {
586534
"pycharm": {
587535
"name": "#%%\n"
@@ -592,10 +540,7 @@
592540
"# Place holder for the loaded model\n",
593541
"network = Net()\n",
594542
"optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)\n",
595-
"\n",
596-
"# Load returns possible extra attributes, other than model's and optimizer's state dicts. In case there were\n",
597-
"# no extra attributes it will return an empty dict\n",
598-
"_ = tiledb_model_1.load(model=network, optimizer=optimizer)"
543+
"tiledb_model_1.load(model=network, optimizer=optimizer)"
599544
]
600545
},
601546
{
@@ -607,7 +552,7 @@
607552
},
608553
{
609554
"cell_type": "code",
610-
"execution_count": 12,
555+
"execution_count": 11,
611556
"metadata": {
612557
"pycharm": {
613558
"name": "#%%\n"
@@ -724,16 +669,16 @@
724669
"number of fragments: 2\n",
725670
"\n",
726671
"===== FRAGMENT NUMBER 0 =====\n",
727-
"fragment uri: file:///Users/konstantinostsitsimpikos/tileroot/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1660811273615_1660811273615_23699d36dbc744809486d88176c2920f_13\n",
728-
"timestamp range: (1660811273615, 1660811273615)\n",
672+
"fragment uri: file:///Users/george/PycharmProjects/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1670425246498_1670425246498_5ca20757611a43009e22606647ee9b22_16\n",
673+
"timestamp range: (1670425246498, 1670425246498)\n",
729674
"number of unconsolidated metadata: 2\n",
730-
"version: 13\n",
675+
"version: 16\n",
731676
"\n",
732677
"===== FRAGMENT NUMBER 1 =====\n",
733-
"fragment uri: file:///Users/konstantinostsitsimpikos/tileroot/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1660811314379_1660811314379_0309938da153404e88a7a64ff044fc20_13\n",
734-
"timestamp range: (1660811314379, 1660811314379)\n",
678+
"fragment uri: file:///Users/george/PycharmProjects/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1670425278236_1670425278236_8e60255a3abe4173b21458369995c20c_16\n",
679+
"timestamp range: (1670425278236, 1670425278236)\n",
735680
"number of unconsolidated metadata: 2\n",
736-
"version: 13\n"
681+
"version: 16\n"
737682
]
738683
}
739684
],
@@ -789,7 +734,7 @@
789734
},
790735
{
791736
"cell_type": "code",
792-
"execution_count": 13,
737+
"execution_count": 12,
793738
"metadata": {
794739
"pycharm": {
795740
"name": "#%%\n"
@@ -826,7 +771,7 @@
826771
},
827772
{
828773
"cell_type": "code",
829-
"execution_count": 14,
774+
"execution_count": 13,
830775
"metadata": {
831776
"pycharm": {
832777
"name": "#%%\n"
@@ -915,7 +860,7 @@
915860
},
916861
{
917862
"cell_type": "code",
918-
"execution_count": 15,
863+
"execution_count": 14,
919864
"metadata": {
920865
"pycharm": {
921866
"name": "#%%\n"
@@ -924,11 +869,9 @@
924869
"outputs": [
925870
{
926871
"data": {
927-
"text/plain": [
928-
"'../data/tiledb-pytorch-mnist/pytorch-mnist-2'"
929-
]
872+
"text/plain": "'../data/tiledb-pytorch-mnist/pytorch-mnist-2'"
930873
},
931-
"execution_count": 15,
874+
"execution_count": 14,
932875
"metadata": {},
933876
"output_type": "execute_result"
934877
}
@@ -949,13 +892,22 @@
949892
},
950893
{
951894
"cell_type": "code",
952-
"execution_count": null,
895+
"execution_count": 15,
953896
"metadata": {
954897
"pycharm": {
955898
"name": "#%%\n"
956899
}
957900
},
958-
"outputs": [],
901+
"outputs": [
902+
{
903+
"name": "stdout",
904+
"output_type": "stream",
905+
"text": [
906+
"file:///Users/george/PycharmProjects/TileDB-ML/examples/data/tiledb-pytorch-mnist/pytorch-mnist-1 array\n",
907+
"file:///Users/george/PycharmProjects/TileDB-ML/examples/data/tiledb-pytorch-mnist/pytorch-mnist-2 array\n"
908+
]
909+
}
910+
],
959911
"source": [
960912
"tiledb.ls(group, lambda obj_path, obj_type: print(obj_path, obj_type))"
961913
]

0 commit comments

Comments
 (0)