Skip to content

Commit 4d01530

Browse files
authored
feat: introduce barex allocator (#932)
1 parent f89c207 commit 4d01530

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

mooncake-integration/allocator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,56 @@ def get_allocator(cls, device: torch_device) -> CUDAPluggableAllocator:
4444
so_path, "mc_nvlink_malloc", "mc_nvlink_free"
4545
)
4646
return cls._instances[device]
47+
48+
49+
class BarexAllocator:
50+
_instances: Dict[torch_device, CUDAPluggableAllocator] = {}
51+
_lock: Final = threading.Lock()
52+
53+
@classmethod
54+
def _get_so_path(cls) -> str:
55+
"""Dynamically locate libaccl_barex.so for barex memory allocation"""
56+
# Check common system paths for libaccl_barex.so
57+
possible_paths = [
58+
"/usr/lib/libaccl_barex.so", # Ubuntu [deb]
59+
"/usr/lib64/libaccl_barex.so", # AliOS [rpm]
60+
]
61+
62+
for path in possible_paths:
63+
if os.path.exists(path):
64+
return path
65+
66+
# Try to locate in mooncake package installation
67+
try:
68+
# Attempt to locate package resource
69+
with resources.path("mooncake", "libaccl_barex.so") as so_path:
70+
if so_path.exists():
71+
return str(so_path)
72+
except (ImportError, FileNotFoundError, TypeError):
73+
pass
74+
75+
# Fallback strategy: check in package location via import metadata
76+
try:
77+
import mooncake
78+
79+
base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
80+
so_path = os.path.join(base_path, "libaccl_barex.so")
81+
if os.path.exists(so_path):
82+
return so_path
83+
except (ImportError, FileNotFoundError, TypeError):
84+
pass
85+
86+
raise ImportError(
87+
"BarexAllocator requires libaccl_barex.so to be installed. "
88+
"Please install the barex library or ensure it's in the system path."
89+
)
90+
91+
@classmethod
92+
def get_allocator(cls, device: torch_device) -> CUDAPluggableAllocator:
93+
with cls._lock:
94+
if device not in cls._instances:
95+
so_path = cls._get_so_path()
96+
cls._instances[device] = CUDAPluggableAllocator(
97+
so_path, "u2mm_alloc_wrapper", "u2mm_free_wrapper"
98+
)
99+
return cls._instances[device]

0 commit comments

Comments
 (0)