try-alex-custom-embedding-with-no-bnorm #138

Open
@david-thrower

Description

Kind of issue: Enhancement

Additional context: Alex developed a custom embedding layer:

import tensorflow as tf


class CustomEmbedding(tf.keras.layers.Layer):

    def __init__(self, input_dim, output_dim, **kwargs):
        super(CustomEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim

    def build(self, input_shape):
        # One embedding vector per "bin" of the soft one-hot lookup.
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer='uniform',
            trainable=True,
            name='embeddings'
        )
        # Per-feature learnable scale applied to the softmaxed inputs.
        self.scaling = self.add_weight(
            shape=(input_shape[-1],),
            initializer='ones',
            trainable=True,
            name='scaling'
        )
        super(CustomEmbedding, self).build(input_shape)

    def gaussian_one_hot(self, x, depth, sigma=0.1):
        # Soft (differentiable) one-hot: a normalized Gaussian bump centered at x.
        y = tf.range(depth, dtype=tf.float32)
        x = tf.expand_dims(x, -1)  # add an extra dimension for broadcasting
        vec = tf.exp(-tf.square(y - x) / (2 * tf.square(sigma)))
        vec = vec / tf.reduce_sum(vec, axis=-1, keepdims=True)
        return vec

    def call(self, inputs):
        scaled = tf.nn.softmax(inputs) * tf.exp(self.scaling)
        # tf.print(scaled)
        retrieve = self.gaussian_one_hot(scaled, self.input_dim)
        # (batch, features, depth) x (depth, output_dim) -> (batch, features, output_dim)
        embedded = tf.einsum('bik,kj->bij', retrieve, self.embeddings)
        return embedded

    def compute_output_shape(self, input_shape):
        return tuple(input_shape) + (self.output_dim,)
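
A quick smoke test can confirm that the layer builds, produces the expected shape, and that gradients actually reach both weights. The batch size, feature count, and dimensions below are made up for illustration:

layer = CustomEmbedding(input_dim=16, output_dim=4)
x = tf.random.uniform((2, 8))               # (batch, features) -- arbitrary test shape
with tf.GradientTape() as tape:
    y = layer(x)                            # expected shape: (2, 8, 4)
    loss = tf.reduce_mean(y)
grads = tape.gradient(loss, layer.trainable_variables)
print(y.shape)                              # (2, 8, 4)
print([g is not None for g in grads])       # gradients should flow to 'embeddings' and 'scaling'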

I want to try this without the BatchNormalization layer and see if it will fix the issue with the gradients.
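
A minimal sketch of what that could look like, wiring CustomEmbedding straight into a small functional model with no BatchNormalization in between (the input width, embedding sizes, and output head below are assumptions, not the project's actual architecture):

inputs = tf.keras.Input(shape=(8,))
x = CustomEmbedding(input_dim=16, output_dim=4)(inputs)      # (batch, 8, 4), no BatchNormalization
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(3, activation="softmax")(x)

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")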
