From abe1ab5cd9f426f8e6348dd5434f2193f3d932eb Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Thu, 4 Dec 2025 14:51:36 -0800
Subject: [PATCH] add TE BF16 dot + unit tests

Signed-off-by: Phuong Nguyen
---
 src/MaxText/configs/types.py           | 172 +++++++++++++++++++------
 src/MaxText/layers/quantizations.py    |   5 +-
 tests/integration_tests/train_tests.py |  17 +++
 3 files changed, 151 insertions(+), 43 deletions(-)

diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 86eb77919d..167b42f45a 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -78,6 +78,11 @@ class QuantizationType(str, Enum):
   FP8_NANO_V2 = "fp8_nanoo"
   FP8_GPU = "fp8_gpu"
   FP8_FULL = "fp8_full"
+  TE = "te_noscaling"
+  TE_FP8_DS = "te_fp8_delayedscaling"
+  TE_FP8_CS = "te_fp8_currentscaling"
+  TE_MXFP8 = "te_mxfp8"
+  TE_NVFP4 = "te_nvfp4"


 class KvQuantAxis(str, Enum):
@@ -290,13 +295,16 @@ class EmergencyCheckpointing(BaseModel):
   local_checkpoint_directory: PathStr = Field("", description="Local directory for emergency checkpoints.")
   local_checkpoint_period: NonNegativeInt = Field(0, description="Frequency (in steps) for local emergency checkpoints.")
   multi_tier_checkpointing_backup_interval_minutes: NonNegativeInt = Field(
-      0, description="Interval in minutes to back up local checkpoints to persistent storage."
+      0,
+      description="Interval in minutes to back up local checkpoints to persistent storage.",
   )
   mtc_data_parallelism: int = Field(
-      0, description="Number of identical pipelines in the job for multi-tier checkpointing. 0 defaults to num_slices."
+      0,
+      description="Number of identical pipelines in the job for multi-tier checkpointing. 0 defaults to num_slices.",
   )
   enable_emergency_checkpoint: bool = Field(
-      False, description="Legacy flag for enabling emergency checkpointing. Prefer `enable_multi_tier_checkpointing`."
+      False,
+      description="Legacy flag for enabling emergency checkpointing. Prefer `enable_multi_tier_checkpointing`.",
   )
   use_replicator_service: bool = Field(
       False,
@@ -371,12 +379,14 @@ class ModelArchitecture(BaseModel):
   head_dim: int = Field(128, description="Dimension of each attention head.")
   mlp_activations: list[str] = Field(["silu", "linear"], description="Activation functions in the MLP layer.")
   mlp_activations_limit: float = Field(
-      -1.0, description="Upper bound to clip the MLP activation values. -1.0 means no clipping."
+      -1.0,
+      description="Upper bound to clip the MLP activation values. -1.0 means no clipping.",
   )
   normalization_layer_epsilon: float = Field(1.0e-05, description="Epsilon value for normalization layers.")
   fused_qkv: bool = Field(False, description="If supported, fuse the Q, K, and V projections.")
   attention_bias: bool = Field(
-      False, description="If True, adds a learnable bias to the query, key, and value projections."
+      False,
+      description="If True, adds a learnable bias to the query, key, and value projections.",
   )
   fused_mlp: bool = Field(False, description="If supported, fuse the MLP layers.")
@@ -435,7 +445,10 @@ class Attention(BaseModel):
   use_post_attn_norm: bool = Field(False, description="Apply LayerNorm after the attention block.")
   use_post_ffw_norm: bool = Field(False, description="Apply LayerNorm after the feed-forward block.")
   use_ragged_attention: bool = Field(False, description="Whether to use ragged attention kernels.")
-  use_tokamax_gmm: bool = Field(False, description="Whether to use the Tokamax library for GMM kernel implementation.")
+  use_tokamax_gmm: bool = Field(
+      False,
+      description="Whether to use the Tokamax library for GMM kernel implementation.",
+  )
   ragged_block_size: int = Field(256, description="Block size for ragged attention.")
   enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
   use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
@@ -489,7 +502,8 @@ class SplashAttention(BaseModel):
   sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in splash attention.")
   sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in splash attention.")
   use_max_logit_estimate: int = Field(
-      -1, description="-1 means no estimate, any > 0 value will be used as max logit estimate"
+      -1,
+      description="-1 means no estimate, any > 0 value will be used as max logit estimate",
   )
   cost_estimate_flops_fwd: int = Field(
       -1,
@@ -502,7 +516,8 @@
       "to overlap for communication (backward)",
   )
   dq_reduction_steps: int = Field(
-      0, description="the number of reduction steps. For now, only 3 or all the kv steps are supported."
+      0,
+      description="the number of reduction steps. For now, only 3 or all the kv steps are supported.",
   )


@@ -527,15 +542,18 @@ class MoEGeneral(BaseModel):
   load_balance_loss_weight: NonNegativeFloat = Field(0.01, description="Weight for the load balancing auxiliary loss.")
   use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.")
   use_ring_of_experts: bool = Field(
-      False, description="Whether to use Ring of Experts for sparse matmul expert parallelism."
+      False,
+      description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
   )
   use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
   interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
   expert_shard_attention_option: Literal["fsdp", "context"] = Field(
-      "fsdp", description="How the expert axis is used to shard attention weights and activations."
+      "fsdp",
+      description="How the expert axis is used to shard attention weights and activations.",
   )
   moe_fsdp_use_two_stage_all_gather: bool = Field(
-      False, description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose."
+      False,
+      description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.",
   )
   fsdp_shard_on_exp: bool = Field(
       False,
@@ -543,10 +561,12 @@
       "num_experts % fsdp_parallelism != 0.",
   )
   norm_topk_prob: bool = Field(
-      False, description="Enable top-k probability normalization for router weights (Qwen3-specific)."
+      False,
+      description="Enable top-k probability normalization for router weights (Qwen3-specific).",
   )
   float32_weight_sum: bool = Field(
-      True, description="Whether to use full fp32 precision for weight_sum during final unpermute in MoE."
+      True,
+      description="Whether to use full fp32 precision for weight_sum during final unpermute in MoE.",
   )


@@ -555,22 +575,40 @@
 class MoEKernels(BaseModel):
   megablox: bool = Field(True, description="Whether to use Megablox kernels for MoE.")
   sparse_matmul: bool = Field(True, description="Whether to use sparse matmul kernels for MoE.")
-  wi_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wi.")
+  wi_tile_fwd_batch_seq: int = Field(
+      512,
+      description="forward pass tiling dimension for batch/sequence in GMM for wi.",
+  )
   wi_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wi.")
   wi_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wi.")
-  wi_tile_dlhs_batch_seq: int = Field(512, description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wi.")
+  wi_tile_dlhs_batch_seq: int = Field(
+      512,
+      description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wi.",
+  )
   wi_tile_dlhs_embed_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for embedding in GMM for wi.")
   wi_tile_dlhs_mlp_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for MLP in GMM for wi.")
-  wi_tile_drhs_batch_seq: int = Field(512, description="bwd pass drhs tiling dimension for batch/sequence in GMM for wi.")
+  wi_tile_drhs_batch_seq: int = Field(
+      512,
+      description="bwd pass drhs tiling dimension for batch/sequence in GMM for wi.",
+  )
   wi_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wi.")
   wi_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wi.")
-  wo_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wo.")
+  wo_tile_fwd_batch_seq: int = Field(
+      512,
+      description="forward pass tiling dimension for batch/sequence in GMM for wo.",
+  )
   wo_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wo.")
   wo_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wo.")
-  wo_tile_dlhs_batch_seq: int = Field(512, description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wo.")
+  wo_tile_dlhs_batch_seq: int = Field(
+      512,
+      description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wo.",
+  )
   wo_tile_dlhs_embed_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for embedding in GMM for wo.")
   wo_tile_dlhs_mlp_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for MLP in GMM for wo.")
-  wo_tile_drhs_batch_seq: int = Field(512, description="bwd pass drhs tiling dimension for batch/sequence in GMM for wo.")
+  wo_tile_drhs_batch_seq: int = Field(
+      512,
+      description="bwd pass drhs tiling dimension for batch/sequence in GMM for wo.",
+  )
   wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
   wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")
@@ -602,7 +640,10 @@ class Qwen3Next(BaseModel):
   gdn_value_head_dim: int = Field(128, description="Head dimension for the value in the Gated Delta Net.")
   gdn_num_key_heads: int = Field(16, description="Number of key/query heads in the Gated Delta Net.")
   gdn_num_value_heads: int = Field(32, description="Number of value heads in the Gated Delta Net.")
-  gdn_chunk_size: int = Field(64, description="Chunk size for the parallel scan algorithm in the Gated Delta Net.")
+  gdn_chunk_size: int = Field(
+      64,
+      description="Chunk size for the parallel scan algorithm in the Gated Delta Net.",
+  )
   use_qk_norm_in_gdn: bool = Field(
       True,
       description="Whether to apply L2 normalization to query and key tensors inside the Gated Delta Rule kernel.",
@@ -638,7 +679,8 @@ class HardwareAndMesh(BaseModel):
   param_scan_axis: int = Field(1, description="Axis to scan over for parameters.")
   context_parallel_load_balance: bool = Field(True, description="Whether to use load balancing for context parallelism.")
   context_parallel_strategy: str = Field(
-      "all_gather", description="Strategy for context parallelism ('all_gather' or 'ring')."
+      "all_gather",
+      description="Strategy for context parallelism ('all_gather' or 'ring').",
   )
   custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")
   allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")
@@ -653,7 +695,8 @@ class LayoutAndSharding(BaseModel):
   logical_axis_rules: Any = Field([], description="Rules for mapping logical axes to physical mesh axes.")
   data_sharding: Any = Field([], description="Sharding for input data.")
   input_data_sharding_logical_axes: list[str] = Field(
-      ["activation_embed_and_logits_batch", "activation_norm_length"], description="Logical axes for sharding input data."
+      ["activation_embed_and_logits_batch", "activation_norm_length"],
+      description="Logical axes for sharding input data.",
   )
   sharding_tolerance: float = Field(
       0.02,
@@ -778,7 +821,10 @@ class Tokenizer(BaseModel):
   tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")
   add_bos: bool = Field(True, description="Whether to add a beginning-of-sentence token.")
   add_eos: bool = Field(True, description="Whether to add an end-of-sentence token.")
-  use_truncation: bool = Field(True, description="If False, use chunking for long sequences instead of truncation.")
+  use_truncation: bool = Field(
+      True,
+      description="If False, use chunking for long sequences instead of truncation.",
+  )
   num_vocab_tiling: int = Field(
       1,
       description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.",
@@ -804,16 +850,19 @@ class DatasetGeneral(BaseModel):
       description="Whether to pack multiple short examples into a single sequence.",
   )
   max_segments_per_seq: int = Field(
-      32, description="Maximum number of segments that can be packed into a single sequence."
+      32,
+      description="Maximum number of segments that can be packed into a single sequence.",
   )
   num_epoch: int = Field(1, description="Number of epochs to train for.")
   expansion_factor_real_data: float = Field(-1.0, description="Factor for partial data loading on hosts.")
   reuse_example_batch: int = Field(0, description="For performance testing, repeatedly uses the same batch.")
   generate_padding_batch_train: bool = Field(
-      False, description="Whether to generate a padding batch for training to ensure divisibility."
+      False,
+      description="Whether to generate a padding batch for training to ensure divisibility.",
   )
   generate_padding_batch_eval: bool = Field(
-      False, description="Whether to generate a padding batch for evaluation to ensure divisibility."
+      False,
+      description="Whether to generate a padding batch for evaluation to ensure divisibility.",
   )
   enable_rampup_batch_size: bool = Field(False, description="Enable rampup batch size.")
   per_device_batch_size_start: float = Field(4.0, description="Start per device batch size for rampup.")
@@ -854,11 +903,13 @@ class GrainDataset(BaseModel):
   grain_file_type: str = Field("arrayrecord", description="File type for Grain data.")
   grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
   grain_per_worker_buffer_size: int = Field(
-      1, description="Buffer size for each worker for Grain data loading during training."
+      1,
+      description="Buffer size for each worker for Grain data loading during training.",
   )
   grain_worker_count_eval: int = Field(1, description="Number of workers for Grain eval data loading.")
   grain_per_worker_buffer_size_eval: int = Field(
-      1, description="Buffer size for each worker for Grain data loading during evaluation."
+      1,
+      description="Buffer size for each worker for Grain data loading during evaluation.",
   )
   grain_ram_budget_mb: int = Field(1024, description="RAM budget (MB) for auto-tuning worker count.")
   grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.")
@@ -868,7 +919,8 @@
       500, description="Prefetch buffer size for Grain ReadOptions during evaluation."
   )
   grain_data_source_max_workers: int = Field(
-      16, description="Max workers for ThreadPoolExecutor when mixing multiple Grain data sources."
+      16,
+      description="Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.",
   )


@@ -929,7 +981,9 @@
   )
   warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of total steps for LR warmup.")
   learning_rate_schedule_steps: int = Field(
-      -1, ge=-1, description="Total steps for the LR schedule. -1 defaults to `steps`."
+      -1,
+      ge=-1,
+      description="Total steps for the LR schedule. -1 defaults to `steps`.",
   )


@@ -999,7 +1053,8 @@ class YarnRope(BaseModel):
   rope_interleave: bool = Field(True, description="Whether RoPE sin/cos are interleaved vs concatenated.")
   rope_truncate: bool = Field(True, description="Whether to floor/ceil the correction range for YaRN.")
   rope_attention_scaling: bool = Field(
-      False, description="Scale the rotary embedding output. Used by some models like gpt-oss."
+      False,
+      description="Scale the rotary embedding output. Used by some models like gpt-oss.",
   )


@@ -1237,7 +1292,10 @@ class VisionTower(BaseModel):
   )
   tile_size_for_vit: int = Field(336, description="Tile size for the Vision Transformer.")
   patch_size_for_vit: int = Field(14, description="Patch size for the Vision Transformer.")
-  conv_stride_for_vit: int = Field(14, description="Convolutional stride for the Vision Transformer's patch embedding.")
+  conv_stride_for_vit: int = Field(
+      14,
+      description="Convolutional stride for the Vision Transformer's patch embedding.",
+  )
   num_hidden_layers_for_vit: int = Field(34, description="Number of hidden layers in the Vision Transformer.")
   rope_theta_for_vit: int = Field(10000, description="RoPE theta value for the Vision Transformer.")
   vision_output_dim_for_vit: int = Field(4096, description="Final output dimension of the vision-to-language projection.")
@@ -1307,12 +1365,17 @@ class RLEvaluation(BaseModel):
   eval_sampling_strategy: str = Field("greedy", description="Sampling strategy for evaluation.")

   generation_configs: dict[str, Any] = Field(
-      default_factory=dict, description="Configurations for different generation strategies."
+      default_factory=dict,
+      description="Configurations for different generation strategies.",
   )
   num_eval_passes: int = Field(1, description="Number of generation passes during evaluation.")
-  eval_corr_lst: bool = Field(False, description="If True, only include correct responses in the list during evaluation.")
+  eval_corr_lst: bool = Field(
+      False,
+      description="If True, only include correct responses in the list during evaluation.",
+  )
   eval_make_lst: bool = Field(
-      False, description="If True, return a list of (question, answer, responses) during evaluation."
+      False,
+      description="If True, return a list of (question, answer, responses) during evaluation.",
   )


@@ -1603,7 +1666,11 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
       # To work around SDK bug b/454725283, remove the trailing back slash from the managed_mldiagnostics_dir.
       self.managed_mldiagnostics_dir = os.path.join(output_dir, "managed-mldiagnostics")
     else:
-      self.checkpoint_dir, self.metrics_dir, self.tensorboard_dir = None, None, None
+      self.checkpoint_dir, self.metrics_dir, self.tensorboard_dir = (
+          None,
+          None,
+          None,
+      )

     # B. RESOLVE TOKENIZER PATH
     # If the tokenizer path is a relative name without a directory, resolve it against the assets root.
@@ -1714,7 +1781,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         self.global_batch_size_to_eval_on,
         self.micro_batch_size_to_eval_on,
     ) = calculate_global_batch_sizes(
-        self.eval_per_device_batch_size, self.expansion_factor_real_data, self.num_target_devices, 1
+        self.eval_per_device_batch_size,
+        self.expansion_factor_real_data,
+        self.num_target_devices,
+        1,
     )

     # Calculate ramp-up batch size parameters if enabled.
@@ -1795,7 +1865,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     num_stages = int(self.ici_pipeline_parallelism * self.dcn_pipeline_parallelism)
     if self.num_pipeline_repeats == -1:
       num_pipeline_repeats, remainder = divmod(
-          self.pipeline_parallel_layers, num_stages * self.num_layers_per_pipeline_stage
+          self.pipeline_parallel_layers,
+          num_stages * self.num_layers_per_pipeline_stage,
       )
       assert not remainder, (
           f"The number of layers per stage ({self.num_layers_per_pipeline_stage}) times the number of stages "
@@ -1843,7 +1914,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       stage_idx = self.mesh_axes.index("stage")
       data_idx = self.mesh_axes.index("data")
       if stage_idx > data_idx:  # Ensure 'stage' comes before 'data' for correct sharding logic.
-        self.mesh_axes[stage_idx], self.mesh_axes[data_idx] = self.mesh_axes[data_idx], self.mesh_axes[stage_idx]
+        self.mesh_axes[stage_idx], self.mesh_axes[data_idx] = (
+            self.mesh_axes[data_idx],
+            self.mesh_axes[stage_idx],
+        )

       # Adjust data_sharding to also prioritize 'stage'.
       if (
@@ -1856,7 +1930,12 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         self.data_sharding[0].insert(0, "stage")

     # Add sharding for FP8 amax history when using pipeline parallelism.
-    if self.quantization and "fp8" in self.quantization:
+    if self.quantization and self.quantization in (
+        "fp8",
+        "nanoo_fp8",
+        "fp8_gpu",
+        "te_fp8_delayedscaling",
+    ):
       self.logical_axis_rules.append(["aqt_amax_history", ("stage",)])

     self.model_fsdp_ag_once = self.pipeline_fsdp_ag_once  # Backward compatibility alias
@@ -1890,7 +1969,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         raise ValueError("`sliding_window_size` must be an integer > 0 for 'local_sliding' attention.")
     if self.quantize_kvcache and not self.kv_quant_axis:
       raise ValueError("`kv_quant_axis` cannot be empty when quantize_kvcache is True.")
-    if self.quantization in ("fp8", "nanoo_fp8", "fp8_gpu") and self.gradient_accumulation_steps > 1:
+    if (
+        self.quantization in ("fp8", "nanoo_fp8", "fp8_gpu", "te_fp8_delayedscaling")
+        and self.gradient_accumulation_steps > 1
+    ):
       raise ValueError("FP8 quantization is not compatible with gradient accumulation.")
     if self.num_experts > 1:
       is_fully_moe = (
@@ -1911,7 +1993,13 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
       raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
     if self.use_multimodal:
-      valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e")
+      valid_mm_models = (
+          "gemma3-4b",
+          "gemma3-12b",
+          "gemma3-27b",
+          "llama4-17b-16e",
+          "llama4-17b-128e",
+      )
       if self.model_name not in valid_mm_models and self.model_name != "default":
         raise ValueError(f"Multimodal is only supported for {valid_mm_models}, not {self.model_name}")
     if self.use_sft:
diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py
index d0f9353b6c..eafdd99ed2 100644
--- a/src/MaxText/layers/quantizations.py
+++ b/src/MaxText/layers/quantizations.py
@@ -748,6 +748,7 @@ def _get_recipe(recipe_name: str):
     from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error

     RECIPES = {
+        "te_noscaling": None,
         "te_fp8_delayedscaling": recipe.DelayedScaling,
         "te_fp8_currentscaling": recipe.Float8CurrentScaling,
         "te_mxfp8": recipe.MXFP8BlockScaling,
@@ -755,7 +756,9 @@
     }
     if recipe_name not in RECIPES:
       raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
-    return RECIPES[recipe_name]()
+
+    te_recipe = RECIPES[recipe_name]
+    return te_recipe() if te_recipe is not None else None

   def get_block_size(self):
     """Get the block size for quantization for recipes that require blocks.
diff --git a/tests/integration_tests/train_tests.py b/tests/integration_tests/train_tests.py
index d29338e502..f09376759f 100644
--- a/tests/integration_tests/train_tests.py
+++ b/tests/integration_tests/train_tests.py
@@ -109,6 +109,18 @@ class TrainTests(unittest.TestCase):
         "enable_goodput_recording=False",
         rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
     ],
+    "te_noscaling": [  # tests base config with te_noscaling i.e. BF16
+        None,
+        os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+        "base_output_directory=gs://runner-maxtext-logs",
+        "run_name=runner_test",
+        "dataset_path=gs://maxtext-dataset",
+        "quantization=te_noscaling",
+        "steps=2",
+        "enable_checkpointing=False",
+        "enable_goodput_recording=False",
+        rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+    ],
     "te_fp8_delayedscaling": [  # tests base config with te_fp8_delayedscaling
         None,
         os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
@@ -234,6 +246,11 @@ def test_gpu_fp8(self):
   def test_gpu_nanoo_fp8(self):
     train_main(TrainTests.CONFIGS["nanoo_fp8"] + ["attention=dot_product"])

+  @pytest.mark.integration_test
+  @pytest.mark.gpu_only
+  def test_gpu_te_noscaling(self):
+    train_main(TrainTests.CONFIGS["te_noscaling"] + ["attention=cudnn_flash_te"])
+
   @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available")
   @pytest.mark.integration_test
   @pytest.mark.gpu_only