Skip to content

Commit

Permalink
Update phishing_email_detection_gpt2.py
Browse files Browse the repository at this point in the history
  • Loading branch information
david-thrower authored Dec 8, 2023
1 parent d092dc6 commit 01a3df0
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion phishing_email_detection_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ def from_config(cls, config):
#
return cls(max_seq_length=config['max_seq_length'])




class CastToFloat32(tf.keras.layers.Layer):
    """Keras layer that casts its input tensor to ``tf.float32``.

    The layer has no weights and no configuration of its own. Elsewhere in
    this file it is applied to ``tokens`` before they are concatenated with
    the flattened embedding output.
    """

    def __init__(self, **kwargs):
        # Nothing layer-specific to store; forward every keyword argument
        # (name, dtype, ...) straight to the base Layer constructor.
        super().__init__(**kwargs)

    def call(self, inputs):
        """Return ``inputs`` cast to ``tf.float32``."""
        as_float32 = tf.cast(inputs, tf.float32)
        return as_float32

    def get_config(self):
        """Return the base layer config; this layer adds no entries to it."""
        base_config = super().get_config()
        return base_config


# GPT2 configurables

max_seq_length = 250
Expand Down Expand Up @@ -147,8 +161,9 @@ def from_config(cls, config):
# I think concatenating the embedded and
# un-embedded tokens may emulate a wide and deep model.
# Worth a try.
float_tokens = CastToFloat32()(tokens)
concatenated_inputs =\
tf.keras.layers.Concatenate(axis=1)([flattened, tokens])
tf.keras.layers.Concatenate(axis=1)([flattened, float_tokens])

tokenized_embedded_model=\
tf.keras.Model(
Expand Down

0 comments on commit 01a3df0

Please sign in to comment.