Skip to content

Commit

Permalink
Don't use einsum for one replica
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Feb 7, 2024
1 parent 22aa095 commit 5fdf9ed
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions n3fit/src/n3fit/backends/keras_backend/multi_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
replica_seeds: List[int],
kernel_initializer: Initializer,
replica_input: bool = True,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self.replicas = len(replica_seeds)
Expand All @@ -76,6 +76,8 @@ def __init__(
)
self.replica_input = replica_input

self.matmul = None

def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
Expand All @@ -98,6 +100,18 @@ def build(self, input_shape):
self.input_spec.axes = {-1: input_dim}
self.built = True

if self.replicas == 1:
if self.replica_input:
self.matmul = lambda inputs: tf.tensordot(inputs, self.kernel[0], [[-1], [0]])
else:
# Manually add replica dimension
self.matmul = lambda inputs: tf.expand_dims(
tf.tensordot(inputs, self.kernel[0], [[-1], [0]]), axis=1
)
else:
einrule = f"b{'r' if self.replica_input else ''}nf,rfg->brng"
self.matmul = lambda inputs: tf.einsum(einrule, inputs, self.kernel)

def call(self, inputs):
"""
Compute output of shape (batch_size, replicas, gridsize, units).
Expand All @@ -110,9 +124,7 @@ def call(self, inputs):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)

input_axes = 'brnf' if self.replica_input else 'bnf'
einrule = input_axes + ',rfg->brng'
outputs = tf.einsum(einrule, inputs, self.kernel)
outputs = self.matmul(inputs)

# Reshape the output back to the original ndim of the input.
if not tf.executing_eagerly():
Expand Down

0 comments on commit 5fdf9ed

Please sign in to comment.