Skip to content

Reverse sort matmul benchmarks by memory use to avoid fragmentation#4449

Merged
jacobhinkle merged 7 commits intomainfrom
jh/sort_benchmark_problems
May 20, 2025
Merged

Reverse sort matmul benchmarks by memory use to avoid fragmentation#4449
jacobhinkle merged 7 commits intomainfrom
jh/sort_benchmark_problems

Conversation

@jacobhinkle
Copy link
Copy Markdown
Collaborator

@jacobhinkle jacobhinkle commented May 14, 2025

Before this PR memory use on H200 grows slowly to reach 126 GiB out of 128 GiB capacity.

After this PR memory use on H200 never goes above about 66 GiB.

Also, we previously had a 20 GiB cutoff for included benchmarks but this changes that to 90% of GPU capacity.

@jacobhinkle jacobhinkle requested a review from Priya2698 May 14, 2025 13:01
@github-actions
Copy link
Copy Markdown

github-actions bot commented May 14, 2025

Review updated until commit be28037

Description

  • Reverse sort matmul benchmarks by memory use

  • Update OOM cutoff to 90% of GPU capacity

  • Introduce maybe_skip_oom_case utility

  • Add docstrings and utility functions


Changes walkthrough 📝

Relevant files
Enhancement
test_matmul.py
Enhance matmul benchmark memory management                             

benchmarks/python/test_matmul.py

  • Import functools for comparison utilities
  • Add row_mem function to compute memory usage
  • Add three_way_cmp and mem_cmp functions for sorting
  • Reverse sort matmul problems by memory use
  • Introduce maybe_skip_oom_case to skip OOM cases
  • Update OOM cutoff to 90% of GPU capacity
  • Use maybe_skip_oom_case in test functions
  • +42/-5   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Memory Calculation

    The memory calculation in row_mem and maybe_skip_oom_case assumes half-precision (FP16) data types. Ensure that this assumption is correct for all use cases and that the memory calculation is accurate.

            def row_mem(row):
                """Compute gmem bytes used in a half-precision GEMM
    
                This computes only the space required for operands and output,
                ignoring intermediates like split-K buffers and assuming no bias
                terms.
                """
                m, n, k, _ = row
                return ((m + n) * k + m * n) * 2
    
            def three_way_cmp(a, b) -> int:
                """Perform a three-way comparison like the Python2 cmp() function
    
                This returns 0 if a == b, -1 if a < b, and 1 if a > b
                """
                return int(a > b) - int(a < b)
    
            def mem_cmp(row1, row2):
                """Compare two rows based on memory use"""
                return three_way_cmp(row_mem(row1), row_mem(row2))
    
            # Reverse sort by expected memory use to avoid fragmentation
            rows.sort(key=functools.cmp_to_key(mem_cmp), reverse=True)
    
            return rows
    
    
    def maybe_skip_oom_case(m: int, n: int, k: int):
        expected_mem = (m * k + n * k + m * n) * 2  # operands plus output
        expected_mem *= 2  # account for multiple runs/deferred frees
    
        _, total = torch.cuda.mem_get_info()
        max_mem = total * 0.9
        if expected_mem > max_mem:
            pytest.skip(
                f"Case takes more than {max_mem / (2 ** 30): .2f} GiB. Skipping to avoid OOM"
            )
    Three-Way Comparison

    The three_way_cmp function is a reimplementation of the Python 2 cmp function. Consider using Python's built-in comparison operators or functools.cmp_to_key directly for simplicity and readability.

    def three_way_cmp(a, b) -> int:
        """Perform a three-way comparison like the Python2 cmp() function
    
        This returns 0 if a == b, -1 if a < b, and 1 if a > b
        """
        return int(a > b) - int(a < b)
    Memory Threshold

    The memory threshold is set to 90% of GPU capacity. Verify that this threshold is appropriate for all target hardware and that it balances performance and resource utilization effectively.

    max_mem = total * 0.9
    if expected_mem > max_mem:

    @jacobhinkle
    Copy link
    Copy Markdown
    Collaborator Author

    !build

    @Priya2698
    Copy link
    Copy Markdown
    Collaborator

    I believe you switched to custom memory checks for matmul since retry_on_oom resulted in errors pertaining to benchmark fixture being used twice?

    @jacobhinkle jacobhinkle requested a review from Priya2698 May 19, 2025 16:18
    Comment on lines +39 to +43
    if a < b:
    return -1
    elif a > b:
    return 1
    return 0
    Copy link
    Copy Markdown
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    I think this condition can be re-written as return (a > b) - (a < b).

    Copy link
    Copy Markdown
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Clever!

    return rows


    def maybe_skip_oom_case(m: int, n: int, k: int):
    Copy link
    Copy Markdown
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Should we also clear allocated memory here?

    Copy link
    Copy Markdown
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Since we reverse sort the benchmarks by required memory, we shouldn't need to clear the allocated memory, since subsequent tests will fit into the already-allocated memory pool.

    Copy link
    Copy Markdown
    Collaborator

    @Priya2698 Priya2698 left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM

    @jacobhinkle
    Copy link
    Copy Markdown
    Collaborator Author

    !build

    @jacobhinkle jacobhinkle merged commit c1d8a3c into main May 20, 2025
    15 of 16 checks passed
    @jacobhinkle jacobhinkle deleted the jh/sort_benchmark_problems branch May 20, 2025 17:59
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants