@@ -44,3 +44,56 @@ def get_allocator(cls, device: torch_device) -> CUDAPluggableAllocator:
4444 so_path , "mc_nvlink_malloc" , "mc_nvlink_free"
4545 )
4646 return cls ._instances [device ]
47+
48+
49+ class BarexAllocator :
50+ _instances : Dict [torch_device , CUDAPluggableAllocator ] = {}
51+ _lock : Final = threading .Lock ()
52+
53+ @classmethod
54+ def _get_so_path (cls ) -> str :
55+ """Dynamically locate libaccl_barex.so for barex memory allocation"""
56+ # Check common system paths for libaccl_barex.so
57+ possible_paths = [
58+ "/usr/lib/libaccl_barex.so" , # Ubuntu [deb]
59+ "/usr/lib64/libaccl_barex.so" , # AliOS [rpm]
60+ ]
61+
62+ for path in possible_paths :
63+ if os .path .exists (path ):
64+ return path
65+
66+ # Try to locate in mooncake package installation
67+ try :
68+ # Attempt to locate package resource
69+ with resources .path ("mooncake" , "libaccl_barex.so" ) as so_path :
70+ if so_path .exists ():
71+ return str (so_path )
72+ except (ImportError , FileNotFoundError , TypeError ):
73+ pass
74+
75+ # Fallback strategy: check in package location via import metadata
76+ try :
77+ import mooncake
78+
79+ base_path = os .path .dirname (os .path .abspath (mooncake .__file__ ))
80+ so_path = os .path .join (base_path , "libaccl_barex.so" )
81+ if os .path .exists (so_path ):
82+ return so_path
83+ except (ImportError , FileNotFoundError , TypeError ):
84+ pass
85+
86+ raise ImportError (
87+ "BarexAllocator requires libaccl_barex.so to be installed. "
88+ "Please install the barex library or ensure it's in the system path."
89+ )
90+
91+ @classmethod
92+ def get_allocator (cls , device : torch_device ) -> CUDAPluggableAllocator :
93+ with cls ._lock :
94+ if device not in cls ._instances :
95+ so_path = cls ._get_so_path ()
96+ cls ._instances [device ] = CUDAPluggableAllocator (
97+ so_path , "u2mm_alloc_wrapper" , "u2mm_free_wrapper"
98+ )
99+ return cls ._instances [device ]
0 commit comments