From 54bc5128cf0b954702607a0eb6a24a08df5d545d Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 29 Jul 2024 14:56:50 -0700 Subject: [PATCH 01/15] set num threads and memory binding for better OOB performance --- docker/Dockerfile.intel | 9 +++++++-- optimum/intel/ipex/modeling_base.py | 28 +++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile.intel b/docker/Dockerfile.intel index 60fd51b424..0d4e5dc5da 100644 --- a/docker/Dockerfile.intel +++ b/docker/Dockerfile.intel @@ -10,6 +10,8 @@ ARG BASE_IMAGE=ubuntu:22.04 FROM ${BASE_IMAGE} +ENV http_proxy=http://proxy-chain.intel.com:912 +ENV https_proxy=http://proxy-chain.intel.com:912 RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ @@ -27,6 +29,8 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ libpng-dev \ python3 \ python3-pip \ + python3-dev \ + libnuma-dev \ && rm -rf /var/lib/apt/lists/*" RUN /usr/sbin/update-ccache-symlinks RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache @@ -43,7 +47,8 @@ RUN python3 -m pip install --no-cache-dir \ torchaudio==${TORCHAUDIO_VERSION} \ -f https://download.pytorch.org/whl/torch_stable.html && \ python3 -m pip install intel-extension-for-pytorch==$IPEX_VERSION && \ - python3 -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ + python3 -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ + python3 -m pip install --no-cache-dir numa ARG OMP_NUM_THREADS=1 ENV OMP_NUM_THREADS=${OMP_NUM_THREADS} @@ -51,4 +56,4 @@ ARG KMP_BLOCKTIME=1 ENV KMP_BLOCKTIME=${KMP_BLOCKTIME} ARG KMP_HW_SUBSET=1T ENV KMP_HW_SUBSET=${KMP_HW_SUBSET} -ENV LD_PRELOAD="/usr/local/lib/libiomp5.so /usr/lib/x86_64-linux-gnu/libtcmalloc.so" \ No newline at end of file +ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so" diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 18f38cd666..bd8bccc809 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -60,7 +60,6 @@ from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device - logger = logging.getLogger(__name__) @@ -129,6 +128,21 @@ def ipex_jit_trace(model, task, use_cache): return trace_model +def get_int_from_env(env_keys, default): + """Returns the first positive env value found in the `env_keys` list or the default.""" + for e in env_keys: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return default + +def get_number_of_sockets(): + sockets = set() + with open('/proc/cpuinfo') as f: + for line in f: + if line.startswith('physical id'): + sockets.add(line.strip().split()[-1]) + return len(sockets) class IPEXModel(OptimizedModel): auto_model_class = AutoModel @@ -153,6 +167,18 @@ def __init__( else: self._device = torch.device("cpu") + import numa + import psutil + n_sockets=get_number_of_sockets() + num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / n_sockets) + os.environ["OMP_NUM_THREADS"]=str(num_cpu_threads_per_process) + torch.set_num_threads(num_cpu_threads_per_process) + numa.set_affinity(0,range(num_cpu_threads_per_process)) + numa.set_membind([0]) + print("affinity", numa.get_affinity(0)) + print("membind", numa.get_membind()) + + # CPU only support jit model for now. if export: if isinstance(model, torch.jit.RecursiveScriptModule): From 5d19b461c13ee583aad90d3243796697fa5d1307 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 29 Jul 2024 14:59:33 -0700 Subject: [PATCH 02/15] clean env var --- docker/Dockerfile.intel | 2 -- 1 file changed, 2 deletions(-) diff --git a/docker/Dockerfile.intel b/docker/Dockerfile.intel index 0d4e5dc5da..fc9058673a 100644 --- a/docker/Dockerfile.intel +++ b/docker/Dockerfile.intel @@ -10,8 +10,6 @@ ARG BASE_IMAGE=ubuntu:22.04 FROM ${BASE_IMAGE} -ENV http_proxy=http://proxy-chain.intel.com:912 -ENV https_proxy=http://proxy-chain.intel.com:912 RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ From 9387717a44c31c600ee4aea7888b86829783232a Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Thu, 1 Aug 2024 13:54:20 -0700 Subject: [PATCH 03/15] added core and memory binding util for improved performance --- docker/Dockerfile.intel | 2 - optimum/intel/ipex/modeling_base.py | 29 +------------- optimum/intel/utils/modeling_utils.py | 58 +++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 30 deletions(-) diff --git a/docker/Dockerfile.intel b/docker/Dockerfile.intel index fc9058673a..a7f1dc978f 100644 --- a/docker/Dockerfile.intel +++ b/docker/Dockerfile.intel @@ -48,8 +48,6 @@ RUN python3 -m pip install --no-cache-dir \ python3 -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ python3 -m pip install --no-cache-dir numa -ARG OMP_NUM_THREADS=1 -ENV OMP_NUM_THREADS=${OMP_NUM_THREADS} ARG KMP_BLOCKTIME=1 ENV KMP_BLOCKTIME=${KMP_BLOCKTIME} ARG KMP_HW_SUBSET=1T diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index bd8bccc809..02dc02d141 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -60,6 +60,7 @@ from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device + logger = logging.getLogger(__name__) @@ -128,22 +129,6 @@ def ipex_jit_trace(model, task, use_cache): return trace_model -def get_int_from_env(env_keys, default): - """Returns the first positive env value found in the `env_keys` list or the default.""" - for e in env_keys: - val = int(os.environ.get(e, -1)) - if val >= 0: - return val - return default - -def get_number_of_sockets(): - sockets = set() - with open('/proc/cpuinfo') as f: - for line in f: - if line.startswith('physical id'): - sockets.add(line.strip().split()[-1]) - return len(sockets) - class IPEXModel(OptimizedModel): auto_model_class = AutoModel export_feature = "feature-extraction" @@ -167,18 +152,6 @@ def __init__( else: self._device = torch.device("cpu") - import numa - import psutil - n_sockets=get_number_of_sockets() - num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / n_sockets) - os.environ["OMP_NUM_THREADS"]=str(num_cpu_threads_per_process) - torch.set_num_threads(num_cpu_threads_per_process) - numa.set_affinity(0,range(num_cpu_threads_per_process)) - numa.set_membind([0]) - print("affinity", numa.get_affinity(0)) - print("membind", numa.get_membind()) - - # CPU only support jit model for now. if export: if isinstance(model, torch.jit.RecursiveScriptModule): diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 9b68266d16..ad19755669 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -110,3 +110,61 @@ def _find_files_matching_pattern( files = [Path(p) for p in repo_files if re.match(pattern, str(p)) and str(p.parent) == subfolder] return files + +def get_number_of_sockets(): + """linux only""" + try: + sockets = set() + with open('/proc/cpuinfo') as f: + for line in f: + if line.startswith('physical id'): + sockets.add(line.strip().split()[-1]) + return len(sockets) + except Exception as e: + print(f"Error retrieving number of sockets: {e}") + +def bind_cores_for_best_perf(): + """ + In a multi-socker system binds CPU cores to single socket and numa node for better OOB performance. + + System configuration is equivalent than running the following command when launching the script: + numactl -C '0-'${PHYSICAL_CORES_PER_SOCKET} --membind 0 python script.py + + Returns: + None + """ + + import importlib.util + import platform + system = platform.system() + if system == "Linux": + if importlib.util.find_spec("numa") is not None: + import numa + import psutil + import os + + nodes = numa.get_max_node() + 1 + n_sockets = get_number_of_sockets() + if n_sockets != nodes: + print(f'Warning: number of sockets {n_sockets} does not match number of NUMA nodes {nodes}.') + print('Newer CPUs enable sub-numa cluster (SNC) but LLMs may show improved performance with SNC disabled in BIOS.') + if os.getenv("OMP_NUM_THREADS") is None: + # set OMP_NUM_THREADS to number of physical cores per socket + num_cpu_threads_per_process = int(psutil.cpu_count(logical=True) / n_sockets) + os.environ['OMP_NUM_THREADS'] = str(num_cpu_threads_per_process) + print(f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob performance.") + else: + #do not override if OMP_NUM_THREADS already set + num_cpu_threads_per_process = int(os.getenv("OMP_NUM_THREADS")) + torch.set_num_threads(num_cpu_threads_per_process) + + # Bind the current process to the specified range of CPU cores + numa.set_affinity(0, range(num_cpu_threads_per_process)) + # Check if the current memory binding policy includes all NUMA nodes + if len(numa.get_membind()) == nodes: + # Bind the process's memory allocation to the first NUMA node + numa.set_membind([0]) + else: + print("numa module not found, skipping binding cores") + else: + print("OS not supported, skipping binding cores") From 191f77218eed21f83a8de179de8f383ea76eb0b2 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 5 Aug 2024 10:17:36 -0700 Subject: [PATCH 04/15] add example usage in docstring --- optimum/intel/utils/modeling_utils.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index ad19755669..b65f5642a1 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -127,11 +127,28 @@ def bind_cores_for_best_perf(): """ In a multi-socker system binds CPU cores to single socket and numa node for better OOB performance. - System configuration is equivalent than running the following command when launching the script: + System configuration is equivalent to running the following command when launching the script: numactl -C '0-'${PHYSICAL_CORES_PER_SOCKET} --membind 0 python script.py + Example: + .. code-block:: python + + from optimum.intel.ipex import IPEXModelForCausalLM + from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf + + bind_cores_for_best_perf() + model = IPEXModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16, export=True) + tokenizer = AutoTokenizer.from_pretrained("gpt2") + input_sentence = ["tell me a story about a trip to the moon"] + model_inputs = tokenizer(input_sentence, return_tensors="pt") + generation_kwargs = dict(max_new_tokens=500) + generated_ids = model.generate(**model_inputs, **generation_kwargs) + Returns: None + + Note: + For distributed and multi-rank applications rely on vLLM, Ray,... to set proper system configuration. """ import importlib.util @@ -142,7 +159,6 @@ def bind_cores_for_best_perf(): import numa import psutil import os - nodes = numa.get_max_node() + 1 n_sockets = get_number_of_sockets() if n_sockets != nodes: From 2773fe62e05fc5ab823a267dda54cb37dfe9c58b Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Wed, 7 Aug 2024 11:20:20 -0700 Subject: [PATCH 05/15] change utlity for best oob to support world_size and rank >=1 --- optimum/intel/utils/modeling_utils.py | 80 +++++++++++++++------------ 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index b65f5642a1..487323b2a8 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -18,6 +18,7 @@ import torch from huggingface_hub import HfApi, HfFolder +import os MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} @@ -111,24 +112,18 @@ def _find_files_matching_pattern( return files -def get_number_of_sockets(): - """linux only""" - try: - sockets = set() - with open('/proc/cpuinfo') as f: - for line in f: - if line.startswith('physical id'): - sockets.add(line.strip().split()[-1]) - return len(sockets) - except Exception as e: - print(f"Error retrieving number of sockets: {e}") +def get_int_from_env(env_keys, default): + """Returns the first positive env value found in the `env_keys` list or the default.""" + for e in env_keys: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return default def bind_cores_for_best_perf(): - """ - In a multi-socker system binds CPU cores to single socket and numa node for better OOB performance. - - System configuration is equivalent to running the following command when launching the script: - numactl -C '0-'${PHYSICAL_CORES_PER_SOCKET} --membind 0 python script.py + """ + Set number of threads per rank, numa cpu affinity and numa memory binding if not already set for better OOB performance. + Works for wold_size >= 1 and rank >= 1 Example: .. code-block:: python @@ -147,8 +142,6 @@ def bind_cores_for_best_perf(): Returns: None - Note: - For distributed and multi-rank applications rely on vLLM, Ray,... to set proper system configuration. """ import importlib.util @@ -158,28 +151,43 @@ def bind_cores_for_best_perf(): if importlib.util.find_spec("numa") is not None: import numa import psutil - import os + import math + + world_size= get_int_from_env( + ["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1 + ) + rank_id= get_int_from_env( + ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 + ) nodes = numa.get_max_node() + 1 - n_sockets = get_number_of_sockets() - if n_sockets != nodes: - print(f'Warning: number of sockets {n_sockets} does not match number of NUMA nodes {nodes}.') - print('Newer CPUs enable sub-numa cluster (SNC) but LLMs may show improved performance with SNC disabled in BIOS.') + rank_per_node = math.ceil(world_size / nodes) + num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) + node_id = int(rank_id / rank_per_node) + rank_offset_per_node = rank_id % rank_per_node if os.getenv("OMP_NUM_THREADS") is None: - # set OMP_NUM_THREADS to number of physical cores per socket - num_cpu_threads_per_process = int(psutil.cpu_count(logical=True) / n_sockets) - os.environ['OMP_NUM_THREADS'] = str(num_cpu_threads_per_process) - print(f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob performance.") + # set OMP_NUM_THREADS to num of physical cores per socket + num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) + print("setting OMP_NUM_THREADS to", num_cpus_per_rank) else: - #do not override if OMP_NUM_THREADS already set - num_cpu_threads_per_process = int(os.getenv("OMP_NUM_THREADS")) - torch.set_num_threads(num_cpu_threads_per_process) - - # Bind the current process to the specified range of CPU cores - numa.set_affinity(0, range(num_cpu_threads_per_process)) - # Check if the current memory binding policy includes all NUMA nodes + num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) + print("OMP_NUM_THREADS already set to ", num_cpus_per_rank) if len(numa.get_membind()) == nodes: - # Bind the process's memory allocation to the first NUMA node - numa.set_membind([0]) + # if numa memory binding is not set, set it to the node where the rank is running + numa.set_membind([node_id]) + + torch.set_num_threads(num_cpus_per_rank) + + + if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True): + #if numa affinity is unset (default value is set to all logical cores) set it to the physical cores assigned to the rank + cpu_start = num_cpus_per_rank * rank_offset_per_node + numa.set_affinity( + 0, + list(numa.node_to_cpus(node_id))[ + cpu_start : cpu_start + num_cpus_per_rank + ], + ) + print(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") else: print("numa module not found, skipping binding cores") else: From fa5526aa92c5ad8283534e6e312d66dc7231fba4 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 19 Aug 2024 22:17:59 -0700 Subject: [PATCH 06/15] fix style --- optimum/intel/ipex/modeling_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index c1e2e3a73e..67e707d594 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -129,6 +129,7 @@ def ipex_jit_trace(model, task, use_cache): return trace_model + class IPEXModel(OptimizedModel): auto_model_class = AutoModel export_feature = "feature-extraction" From 09ac1cedc7099fc379b57b9c63c4a344a2f9fc7d Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 19 Aug 2024 22:42:08 -0700 Subject: [PATCH 07/15] fix node_id value to account for rank_id starts at zero --- optimum/intel/utils/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 487323b2a8..89d435a5a1 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -162,7 +162,7 @@ def bind_cores_for_best_perf(): nodes = numa.get_max_node() + 1 rank_per_node = math.ceil(world_size / nodes) num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) - node_id = int(rank_id / rank_per_node) + node_id = int((rank_id+1) / rank_per_node) rank_offset_per_node = rank_id % rank_per_node if os.getenv("OMP_NUM_THREADS") is None: # set OMP_NUM_THREADS to num of physical cores per socket From 8b476f80615bb83fdc40471937a3f898fad0a2c8 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 19 Aug 2024 23:57:56 -0700 Subject: [PATCH 08/15] numa node assignment calculated from local size not from world size --- optimum/intel/utils/modeling_utils.py | 41 ++++++++++++++------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 89d435a5a1..95e4795de7 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import re from pathlib import Path from typing import List, Optional, Union import torch from huggingface_hub import HfApi, HfFolder -import os MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} @@ -112,6 +112,7 @@ def _find_files_matching_pattern( return files + def get_int_from_env(env_keys, default): """Returns the first positive env value found in the `env_keys` list or the default.""" for e in env_keys: @@ -120,8 +121,9 @@ def get_int_from_env(env_keys, default): return val return default + def bind_cores_for_best_perf(): - """ + """ Set number of threads per rank, numa cpu affinity and numa memory binding if not already set for better OOB performance. Works for wold_size >= 1 and rank >= 1 @@ -138,31 +140,33 @@ def bind_cores_for_best_perf(): model_inputs = tokenizer(input_sentence, return_tensors="pt") generation_kwargs = dict(max_new_tokens=500) generated_ids = model.generate(**model_inputs, **generation_kwargs) - + Returns: None - + """ - + import importlib.util import platform + system = platform.system() if system == "Linux": if importlib.util.find_spec("numa") is not None: + import math + import numa import psutil - import math - world_size= get_int_from_env( - ["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1 - ) - rank_id= get_int_from_env( - ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 - ) + local_size = get_int_from_env( + ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 + ) + rank_id = get_int_from_env( + ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 + ) nodes = numa.get_max_node() + 1 - rank_per_node = math.ceil(world_size / nodes) + rank_per_node = math.ceil(local_size / nodes) num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) - node_id = int((rank_id+1) / rank_per_node) + node_id = int(rank_id / rank_per_node) rank_offset_per_node = rank_id % rank_per_node if os.getenv("OMP_NUM_THREADS") is None: # set OMP_NUM_THREADS to num of physical cores per socket @@ -174,18 +178,15 @@ def bind_cores_for_best_perf(): if len(numa.get_membind()) == nodes: # if numa memory binding is not set, set it to the node where the rank is running numa.set_membind([node_id]) - - torch.set_num_threads(num_cpus_per_rank) + torch.set_num_threads(num_cpus_per_rank) if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True): - #if numa affinity is unset (default value is set to all logical cores) set it to the physical cores assigned to the rank + # if numa affinity is unset (default value is set to all logical cores) set it to the physical cores assigned to the rank cpu_start = num_cpus_per_rank * rank_offset_per_node numa.set_affinity( 0, - list(numa.node_to_cpus(node_id))[ - cpu_start : cpu_start + num_cpus_per_rank - ], + list(numa.node_to_cpus(node_id))[cpu_start : cpu_start + num_cpus_per_rank], ) print(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") else: From 85ccfc4734a9fa471bf3f5e5771590a6ed866a33 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Wed, 21 Aug 2024 15:31:48 -0700 Subject: [PATCH 09/15] reorg imports, moved checks to import_utils, remove prints for logger --- optimum/intel/utils/__init__.py | 1 + optimum/intel/utils/import_utils.py | 13 ++++++++++++ optimum/intel/utils/modeling_utils.py | 29 ++++++++++++++------------- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/optimum/intel/utils/__init__.py b/optimum/intel/utils/__init__.py index d77588f896..50cdfa143e 100644 --- a/optimum/intel/utils/__init__.py +++ b/optimum/intel/utils/__init__.py @@ -22,6 +22,7 @@ is_neural_compressor_available, is_neural_compressor_version, is_nncf_available, + is_numa_available, is_openvino_available, is_torch_version, is_transformers_available, diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index 6be0aac47a..5d0e3e2471 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -150,6 +150,15 @@ except importlib_metadata.PackageNotFoundError: _accelerate_available = False +_numa_available = importlib.util.find_spec("numa") is not None +_numa_version = "N/A" + +if _numa_available: + try: + _numa_version = importlib_metadata.version("numa") + except importlib_metadata.PackageNotFoundError: + _numa_available = False + def is_transformers_available(): return _transformers_available @@ -272,6 +281,10 @@ def is_accelerate_available(): return _accelerate_available +def is_numa_available(): + return _numa_available + + # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): """ diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 95e4795de7..0eb92ae052 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -12,17 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import math import os +import platform import re from pathlib import Path from typing import List, Optional, Union +import psutil import torch from huggingface_hub import HfApi, HfFolder +from .import_utils import is_numa_available + MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} +logger = logging.getLogger(__name__) + def get_model_device(model: torch.nn.Module) -> torch.device: """ @@ -145,17 +153,10 @@ def bind_cores_for_best_perf(): None """ - - import importlib.util - import platform - system = platform.system() if system == "Linux": - if importlib.util.find_spec("numa") is not None: - import math - + if is_numa_available(): import numa - import psutil local_size = get_int_from_env( ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 @@ -169,12 +170,11 @@ def bind_cores_for_best_perf(): node_id = int(rank_id / rank_per_node) rank_offset_per_node = rank_id % rank_per_node if os.getenv("OMP_NUM_THREADS") is None: - # set OMP_NUM_THREADS to num of physical cores per socket num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) - print("setting OMP_NUM_THREADS to", num_cpus_per_rank) + logger.info(f"Setting OMP_NUM_THREADS to {num_cpus_per_rank} for better performance") else: num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) - print("OMP_NUM_THREADS already set to ", num_cpus_per_rank) + logger.info(f"OMP_NUM_THREADS already set to {num_cpus_per_rank}") if len(numa.get_membind()) == nodes: # if numa memory binding is not set, set it to the node where the rank is running numa.set_membind([node_id]) @@ -188,8 +188,9 @@ def bind_cores_for_best_perf(): 0, list(numa.node_to_cpus(node_id))[cpu_start : cpu_start + num_cpus_per_rank], ) - print(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") + logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") else: - print("numa module not found, skipping binding cores") + logger.warning("numa module not found, skipping binding cores") + else: - print("OS not supported, skipping binding cores") + logger.error("bind_cores_for_best_perf: OS not supported, skipping binding cores") From 54073599068d75c6b2e5d751af1e5f593e926b7e Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Fri, 23 Aug 2024 12:12:58 -0700 Subject: [PATCH 10/15] raise Errors with missing pkg and unsupported OS --- optimum/intel/utils/modeling_utils.py | 76 +++++++++++++-------------- 1 file changed, 36 insertions(+), 40 deletions(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 0eb92ae052..9370bc5ee9 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -153,44 +153,40 @@ def bind_cores_for_best_perf(): None """ - system = platform.system() - if system == "Linux": - if is_numa_available(): - import numa - - local_size = get_int_from_env( - ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 - ) - rank_id = get_int_from_env( - ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 - ) - nodes = numa.get_max_node() + 1 - rank_per_node = math.ceil(local_size / nodes) - num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) - node_id = int(rank_id / rank_per_node) - rank_offset_per_node = rank_id % rank_per_node - if os.getenv("OMP_NUM_THREADS") is None: - num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) - logger.info(f"Setting OMP_NUM_THREADS to {num_cpus_per_rank} for better performance") - else: - num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) - logger.info(f"OMP_NUM_THREADS already set to {num_cpus_per_rank}") - if len(numa.get_membind()) == nodes: - # if numa memory binding is not set, set it to the node where the rank is running - numa.set_membind([node_id]) - - torch.set_num_threads(num_cpus_per_rank) - - if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True): - # if numa affinity is unset (default value is set to all logical cores) set it to the physical cores assigned to the rank - cpu_start = num_cpus_per_rank * rank_offset_per_node - numa.set_affinity( - 0, - list(numa.node_to_cpus(node_id))[cpu_start : cpu_start + num_cpus_per_rank], - ) - logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") - else: - logger.warning("numa module not found, skipping binding cores") - + if platform.system() != "Linux": + logger.error("bind_cores_for_best_perf: OS not supported, this function can only be run on Linux systems.") + raise OSError("bind_cores_for_best_perf: OS not supported, this function can only be run on Linux systems.") + if not is_numa_available(): + logger.error("'numa' module not found") + raise ImportError("'numa' module not found, install with 'pip install numa'") + import numa + + local_size = get_int_from_env(["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1) + rank_id = get_int_from_env( + ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 + ) + nodes = numa.get_max_node() + 1 + rank_per_node = math.ceil(local_size / nodes) + num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) + node_id = int(rank_id / rank_per_node) + rank_offset_per_node = rank_id % rank_per_node + if os.getenv("OMP_NUM_THREADS") is None: + num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) + logger.info(f"Setting OMP_NUM_THREADS to {num_cpus_per_rank} for better performance") else: - logger.error("bind_cores_for_best_perf: OS not supported, skipping binding cores") + num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) + logger.info(f"OMP_NUM_THREADS already set to {num_cpus_per_rank}") + if len(numa.get_membind()) == nodes: + # if numa memory binding is not set, set it to the node where the rank is running + numa.set_membind([node_id]) + + torch.set_num_threads(num_cpus_per_rank) + + if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True): + # if numa affinity is unset (default value is set to all logical cores) set it to the physical cores assigned to the rank + cpu_start = num_cpus_per_rank * rank_offset_per_node + numa.set_affinity( + 0, + list(numa.node_to_cpus(node_id))[cpu_start : cpu_start + num_cpus_per_rank], + ) + logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") From 95b65fe8695990123cabdaa9d02e062c01f0a51f Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 26 Aug 2024 10:41:06 -0700 Subject: [PATCH 11/15] added missng env var to list --- optimum/intel/utils/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 9370bc5ee9..2a3c5e15a3 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -161,7 +161,7 @@ def bind_cores_for_best_perf(): raise ImportError("'numa' module not found, install with 'pip install numa'") import numa - local_size = get_int_from_env(["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1) + local_size = get_int_from_env(["LOCAL_WORLD_SIZE≈", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1) rank_id = get_int_from_env( ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 ) From 7fb3cb5bc111aeb2a7e4b7060bb01c5b82d93d92 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 26 Aug 2024 20:19:54 +0200 Subject: [PATCH 12/15] Update optimum/intel/utils/modeling_utils.py --- optimum/intel/utils/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 2a3c5e15a3..a9c7774dde 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -161,7 +161,7 @@ def bind_cores_for_best_perf(): raise ImportError("'numa' module not found, install with 'pip install numa'") import numa - local_size = get_int_from_env(["LOCAL_WORLD_SIZE≈", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1) + local_size = get_int_from_env(["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1) rank_id = get_int_from_env( ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 ) From 3613f6942e5385e5fdd44fc480b867b748e1b288 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 26 Aug 2024 20:25:49 +0200 Subject: [PATCH 13/15] Update optimum/intel/utils/import_utils.py --- optimum/intel/utils/import_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index 5d0e3e2471..36fa5c147a 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -151,7 +151,6 @@ _accelerate_available = False _numa_available = importlib.util.find_spec("numa") is not None -_numa_version = "N/A" if _numa_available: try: From c825876bb9436dc6fe80a054c1676e486f83c7b9 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 26 Aug 2024 20:26:22 +0200 Subject: [PATCH 14/15] Update optimum/intel/utils/import_utils.py --- optimum/intel/utils/import_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index 36fa5c147a..032280e940 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -154,7 +154,7 @@ if _numa_available: try: - _numa_version = importlib_metadata.version("numa") + importlib_metadata.version("numa") except importlib_metadata.PackageNotFoundError: _numa_available = False From e658ffcb3d0f9ffdd86ec5e15738f22caa2da7b4 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Mon, 26 Aug 2024 12:44:15 -0700 Subject: [PATCH 15/15] fix style quality error --- optimum/intel/utils/modeling_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 76ac6b0139..1d2f7b03c5 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -186,7 +186,9 @@ def bind_cores_for_best_perf(): raise ImportError("'numa' module not found, install with 'pip install numa'") import numa - local_size = get_int_from_env(["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1) + local_size = get_int_from_env( + ["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 + ) rank_id = get_int_from_env( ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 )