
Commit f313c78

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 06cac3f commit f313c78

File tree

4 files changed: +16 additions, -10 deletions

4 files changed

+16
-10
lines changed

tests/jax/test_distributed_softmax.py
Lines changed: 5 additions & 1 deletion

@@ -138,7 +138,11 @@ def impl_test_softmax(
 @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
 @pytest.mark.parametrize(
     "softmax_fusion",
-    [SoftmaxFusion.SCALED, SoftmaxFusion.SCALED_MASKED, SoftmaxFusion.SCALED_UPPER_TRIANG_MASKED],
+    [
+        SoftmaxFusion.SCALED,
+        SoftmaxFusion.SCALED_MASKED,
+        SoftmaxFusion.SCALED_UPPER_TRIANG_MASKED,
+    ],
 )
 @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
 @pytest.mark.parametrize("dtype", DTYPES)

tests/jax/utils.py
Lines changed: 5 additions & 5 deletions

@@ -267,7 +267,7 @@ def __call__(
         # softmax_offset shape: [1, h, 1, 1], attn_weights shape: [b, h, q, k]
         extra_col = jnp.broadcast_to(
             softmax_offset,
-            (attn_weights.shape[0], softmax_offset.shape[1], attn_weights.shape[2], 1)
+            (attn_weights.shape[0], softmax_offset.shape[1], attn_weights.shape[2], 1),
         )
         attn_weights = jnp.concatenate([attn_weights, extra_col], axis=-1)

@@ -1043,14 +1043,14 @@ def __call__(self, qlen, klen, bidirectional=True):

 def convert_softmax_type_str_to_enum(softmax_type: str) -> AttnSoftmaxType:
     """Convert softmax_type string to AttnSoftmaxType enum.
-
+
     Args:
         softmax_type: String representation of softmax type.
             One of "vanilla", "off_by_one", "learnable".
-
+
     Returns:
         AttnSoftmaxType enum value.
-
+
     Raises:
         ValueError: If softmax_type is not a valid option.
     """

@@ -1063,7 +1063,7 @@ def convert_softmax_type_str_to_enum(softmax_type: str) -> AttnSoftmaxType:
     else:
         raise ValueError(
             f"Unknown softmax_type: {softmax_type}. "
-            f"Valid options: 'vanilla', 'off_by_one', 'learnable'"
+            "Valid options: 'vanilla', 'off_by_one', 'learnable'"
         )
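The last hunk drops a redundant f-prefix from a literal that contains no placeholders. Adjacent string literals in Python are concatenated at parse time, so the error message is unchanged; a minimal sketch of that equivalence (the variable value below is hypothetical, not code from the repo):

softmax_type = "something_invalid"  # hypothetical value for illustration
msg_old = (
    f"Unknown softmax_type: {softmax_type}. "
    f"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
msg_new = (
    f"Unknown softmax_type: {softmax_type}. "
    "Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
assert msg_old == msg_new  # removing the unused f-prefix does not change the text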

transformer_engine/jax/flax/module.py
Lines changed: 2 additions & 2 deletions

@@ -213,8 +213,8 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp
             outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
         else:
             raise ValueError(
-                f"Unsupported softmax fusion: {self.softmax_fusion}. softmax_fusion must be [SCALED,"
-                " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
+                f"Unsupported softmax fusion: {self.softmax_fusion}. softmax_fusion must be"
+                " [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
             )
         assert input_dtype == outputs.dtype
         return outputs

transformer_engine/jax/flax/transformer.py
Lines changed: 4 additions & 2 deletions

@@ -236,13 +236,15 @@ def apply_swa_mask(original_mask: Array) -> Array:
         # softmax_offset shape: [1, h, 1, 1], attn_weights shape: [b, h, q, k]
         extra_col = jnp.broadcast_to(
             softmax_offset,
-            (attn_weights.shape[0], softmax_offset.shape[1], attn_weights.shape[2], 1)
+            (attn_weights.shape[0], softmax_offset.shape[1], attn_weights.shape[2], 1),
         )
         attn_weights = jnp.concatenate([attn_weights, extra_col], axis=-1)

         # Pad mask if present to match new shape
         if mask is not None:
-            mask = jnp.pad(mask, ((0, 0), (0, 0), (0, 0), (0, 1)), mode='constant', constant_values=0)
+            mask = jnp.pad(
+                mask, ((0, 0), (0, 0), (0, 0), (0, 1)), mode="constant", constant_values=0
+            )

         def convert_to_softmax_type(attn_mask_type, mask):
             """Convert the attn_mask_type to SoftmaxFusion"""
