Skip to content

Commit

Permalink
Add a flag to allow quantization along the embed_size dimension
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 597700612
  • Loading branch information
The praxis Authors committed Jan 12, 2024
1 parent 7b4e27e commit 545e00a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
4 changes: 2 additions & 2 deletions praxis/layers/quantization/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def setup(self) -> None:
tensor_split_dims_mapping=w_sharding,
)
out_bias_dims = sorted(out.index(d) for d in (set(out) - set(x)))
# Fan-out dims must be at the end of `out`.
assert all(d >= len(out) - len(out_bias_dims) for d in out_bias_dims)
bias_shape = [self.w_shape[w.index(out[d])] for d in out_bias_dims]
self.set_up_weights(
weight_name='w',
weight_params=pc,
scale_shape=bias_shape,
)
if self.use_bias:
# Fan-out dims must be at the end of `out`.
assert all(d >= len(out) - len(out_bias_dims) for d in out_bias_dims)
if w_sharding is not None:
b_sharding = [w_sharding[w.index(out[d])] for d in out_bias_dims]
else:
Expand Down
7 changes: 6 additions & 1 deletion praxis/layers/quantization/embedding_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ class SharedEmbeddingSoftmax(embedding_softmax.SharedEmbeddingSoftmax):

quantization: QuantizationParams = instance_field(QuantizationParams)

fq_reverse_contract_embed_dim: bool = False

def setup(self) -> None:
if self.feed_forward_tpl is not None:
wp = self.weight_split_dims_mapping
Expand Down Expand Up @@ -283,7 +285,10 @@ def emb_lookup(self, ids: JTensor) -> JTensor:
'w', use_symmetric=self.quantization.weight_params.use_symmetric
)
else:
eqn = 'xy,zy->xz'
if self.fq_reverse_contract_embed_dim:
eqn = 'xy,yz->xz'
else:
eqn = 'xy,zy->xz'
emb_var = linear_layer.theta.w
if self.quantization.quantization_type == QuantizationType.AQT:
contract_dims = quantized_operations.eqn_to_weight_contract_dims(eqn)
Expand Down

0 comments on commit 545e00a

Please sign in to comment.