Skip to content

Commit 37fe816

Browse files
Pass costum activation function
1 parent 98c1619 commit 37fe816

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

deeplc/deeplc.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050

5151
from tensorflow.keras.models import load_model
5252

53+
# "Costum" activation function
54+
lrelu = lambda x: tf.keras.activations.relu(x, alpha=0.1, max_value=20.0)
55+
5356
try: from tensorflow.compat.v1.keras.backend import set_session
5457
except ImportError: from tensorflow.keras.backend import set_session
5558
try: from tensorflow.compat.v1.keras.backend import clear_session
@@ -410,20 +413,23 @@ def make_preds_core(self,
410413
if isinstance(self.model, dict):
411414
ret_preds = []
412415
for m_group_name,m_name in self.model.items():
413-
mod = load_model(m_name)
416+
mod = load_model(m_name,
417+
custom_objects = {'<lambda>': lrelu})
414418
uncal_preds = mod.predict(
415419
[X, X_sum, X_global, X_hc], batch_size=5120).flatten() / correction_factor
416420

417421
p = list(self.calibration_core(uncal_preds,self.calibrate_dict[m_name],self.calibrate_min[m_name],self.calibrate_max[m_name]))
418422
ret_preds.append(p)
419423
ret_preds = np.array([sum(a)/len(a) for a in zip(*ret_preds)])
420424
elif not mod_name:
421-
mod = load_model(self.model)
425+
mod = load_model(self.model,
426+
custom_objects = {'<lambda>': lrelu})
422427
uncal_preds = mod.predict(
423428
[X, X_sum, X_global, X_hc], batch_size=5120).flatten() / correction_factor
424429
ret_preds = self.calibration_core(uncal_preds,self.calibrate_dict,self.calibrate_min,self.calibrate_max)
425430
else:
426-
mod = load_model(mod_name)
431+
mod = load_model(mod_name,
432+
custom_objects = {'<lambda>': lrelu})
427433
uncal_preds = mod.predict(
428434
[X, X_sum, X_global, X_hc], batch_size=5120).flatten() / correction_factor
429435
ret_preds = self.calibration_core(uncal_preds,self.calibrate_dict,self.calibrate_min,self.calibrate_max)
@@ -442,14 +448,16 @@ def make_preds_core(self,
442448
if isinstance(self.model, dict):
443449
ret_preds = []
444450
for m_group_name,m_name in self.model.items():
445-
mod = load_model(m_name)
451+
mod = load_model(m_name,
452+
custom_objects = {'<lambda>': lrelu})
446453
p = mod.predict(
447454
[X, X_sum, X_global, X_hc], batch_size=5120).flatten() / correction_factor
448455
ret_preds.append(p)
449456
ret_preds = np.array([sum(a)/len(a) for a in zip(*ret_preds)])
450457
elif isinstance(self.model, list):
451458
mod_name = self.model[0]
452-
mod = load_model(mod_name)
459+
mod = load_model(mod_name,
460+
custom_objects = {'<lambda>': lrelu})
453461
ret_preds = mod.predict([X,
454462
X_sum,
455463
X_global,
@@ -458,7 +466,8 @@ def make_preds_core(self,
458466
verbose=cnn_verbose).flatten() / correction_factor
459467
elif isinstance(self.model, str):
460468
mod_name = self.model
461-
mod = load_model(mod_name)
469+
mod = load_model(mod_name,
470+
custom_objects = {'<lambda>': lrelu})
462471
ret_preds = mod.predict([X,
463472
X_sum,
464473
X_global,
@@ -469,7 +478,8 @@ def make_preds_core(self,
469478
logging.critical('No CNN model defined.')
470479
exit(1)
471480
else:
472-
mod = load_model(mod_name)
481+
mod = load_model(mod_name,
482+
custom_objects = {'<lambda>': lrelu})
473483
ret_preds = mod.predict([X,
474484
X_sum,
475485
X_global,

0 commit comments

Comments
 (0)