-
Notifications
You must be signed in to change notification settings - Fork 19
/
model.py
26 lines (21 loc) · 918 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, BatchNormalization, Dropout,AveragePooling2D
import tensorflow as tf
from tensorflow.keras.applications import DenseNet201
from keras.models import Model
from keras.models import Sequential
from keras.regularizers import *
from tensorflow import keras
from tensorflow.keras import layers
def download_model(save_path: str = "model/model.h5") -> None:
    """Build the DenseNet201-based classifier and save it to disk.

    The network is a ``Sequential`` stack: a DenseNet201 backbone created
    with ``weights='imagenet'``, ``include_top=False`` and ``pooling='max'``
    for 224x224x3 inputs, followed by batch normalisation, a 2048-unit ReLU
    layer regularised with ``l1_l2(0.01)``, a second batch normalisation and
    an 8-unit softmax output.

    Only the last five layers of the backbone are left trainable; every
    earlier backbone layer is frozen. The layers stacked on top of the
    backbone keep their default trainable state.

    Args:
        save_path: Destination passed to ``model.save``. Defaults to
            ``"model/model.h5"``. A missing parent directory is created
            before saving.
    """
    # Function-scope import so the block does not depend on the file's
    # top-level import list.
    import os

    trainable_tail = 5  # number of trailing backbone layers to fine-tune

    conv_base = DenseNet201(
        input_shape=(224, 224, 3),
        include_top=False,
        pooling='max',
        weights='imagenet',
    )

    model = Sequential()
    model.add(conv_base)
    model.add(BatchNormalization())
    # NOTE(review): only one factor is passed to l1_l2; the value used for
    # the other one depends on the installed Keras version - confirm.
    model.add(Dense(2048, activation='relu', kernel_regularizer=l1_l2(0.01)))
    model.add(BatchNormalization())
    model.add(Dense(8, activation='softmax'))

    # Keras layers are trainable by default, so merely setting
    # ``trainable = True`` on the tail leaves the whole backbone trainable.
    # Freeze everything before the tail explicitly, then enable the tail.
    for layer in conv_base.layers[:-trainable_tail]:
        layer.trainable = False
    for layer in conv_base.layers[-trainable_tail:]:
        layer.trainable = True

    # Saving fails if the target directory is missing, so create it first.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    model.save(save_path)