diff --git a/changelog.md b/changelog.md
index 92b294b1a..486544735 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,13 @@
# Changelog
+## Unreleased
+
+### Fixed
+
+- Since `fork` hangs when HDFS has been used in the main process, we now auto-detect whether the currently running program has interacted with HDFS (or loaded a JVM) before auto-picking a process start method.
+- We now account for pipe selection (i.e. `enable`, `disable` and `exclude`) when loading a model from the Hugging Face Hub.
+- We no longer instantiate pipes listed in `exclude` when loading a model (previously, they were instantiated but not added to the pipeline).
+
## v0.18.0 (2025-09-02)
📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 support in the next major release (v0.19.0), in October 2025. Please upgrade to Python 3.10 or later.
diff --git a/docs/assets/overrides/main.html b/docs/assets/overrides/main.html
index 586ec84ab..8e558138b 100644
--- a/docs/assets/overrides/main.html
+++ b/docs/assets/overrides/main.html
@@ -8,5 +8,5 @@
{% block announce %}
- Check out the new span classifier training tutorial !
+  Check out the new span classifier training tutorial and the Slurm tutorial!
{% endblock %}
diff --git a/docs/index.md b/docs/index.md
index c1869d29b..de1cefa0d 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -69,7 +69,7 @@ To learn more about EDS-NLP, we have prepared a series of tutorials that should
--8<-- "docs/tutorials/index.md:classic-tutorials"
-We also provide tutorials on how to train deep-learning models with EDS-NLP. These tutorials cover the training API, hyperparameter tuning, and more.
+**Deep-learning tutorials**: we also provide tutorials on how to train deep-learning models with EDS-NLP. These tutorials cover the training API, hyperparameter tuning, and more.
--8<-- "docs/tutorials/index.md:deep-learning-tutorials"
diff --git a/docs/tutorials/hpc.md b/docs/tutorials/hpc.md
new file mode 100644
index 000000000..f02ba5c17
--- /dev/null
+++ b/docs/tutorials/hpc.md
@@ -0,0 +1,303 @@
+# Running an existing model on HPC (e.g. Slurm)
+
+This tutorial shows how to run an existing deep-learning-based EDS-NLP model (for example the
+public pseudonymisation model [eds-pseudo-public](https://eds-pseudo-public.streamlit.app)) efficiently
+on a cluster. In a Clinical Data Warehouse like AP-HP's, most research projects typically need to:
+
+1. first fetch a corpus of documents with PySpark. Depending on your computing setup, this might run on a specific cluster like Hadoop/YARN.
+2. run the NLP model on these notes. This is often best done on a GPU cluster, for instance one managed by Slurm.
+
+## Python inference script
+
+Let's start with the Python inference script. We'll write a script that:
+
+- loads an existing model, e.g. `AP-HP/dummy-ner`, which annotates documents with the entities of the [DEFT 2020 dataset](https://hal.science/hal-03095262/document)
+- reads notes from a Parquet dataset (e.g. exported from Spark)
+- applies the model on these notes
+- writes entities back to a new Parquet dataset (e.g. to be re-imported in Spark)
+
+```python { title="inference.py" }
+import logging
+import os
+from datetime import datetime
+from typing import Union
+
+import confit
+import pyarrow.fs
+
+import edsnlp
+
+
+def make_fs(path: str, endpoint: str = None):
+ """
+ This function can be used to define the filesystem explicitly.
+ Otherwise, it's automatically created from the path
+ (ex: "s3://", "hdfs://", ...) using default parameters.
+ """
+
+    # For instance, if you have an S3 volume (S3 is not necessarily AWS!),
+ # you can use the S3 filesystem and provide credentials as env vars.
+ if path.startswith("s3://"):
+ return pyarrow.fs.S3FileSystem(
+ access_key=os.getenv("S3_ACCESS_KEY"),
+ secret_key=os.getenv("S3_SECRET_KEY"),
+            endpoint_override=endpoint or os.getenv("S3_ENDPOINT"),
+ )
+ return None
+
+
+app = confit.Cli() #(1)!
+
+
+@app.command("inference")
+def main(
+ *,
+ input_path: str,
+ output_path: str,
+ model_name: str = "AP-HP/dummy-ner",
+ batch_size: str = "32 docs",
+ show_progress: bool = False,
+ output_file_size: Union[int, str] = 10_000,
+):
+ """
+ Run inference on a corpus of notes stored in Parquet format.
+
+ Parameters
+ ----------
+ input_path : str
+ Input Parquet path (e.g. s3://bucket/notes/ or hdfs path)
+ output_path : str
+ Output Parquet path (e.g. s3://bucket/note_nlp/ or hdfs path)
+ model_name : str
+ Model to load: local path, installed model package or EDS-NLP
+ compatible Hub repo (e.g. 'AP-HP/eds-pseudo-public')
+ batch_size : str
+ Batch size expression (e.g. '32 docs', '8000 words')
+ show_progress : bool
+ Show progress bars
+ output_file_size : Union[int, str]
+ Size per Parquet file (e.g. '1000 docs', '40000 words')
+ in the output dataset
+ """
+
+ logging.info("Model loading started")
+ nlp = edsnlp.load(model_name)
+ # Do anything to the model here
+ print(nlp)
+ logging.info("Model loading done")
+
+ input_fs = make_fs(input_path)
+ output_fs = make_fs(output_path)
+
+ print(f"Job started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+
+ # Read OMOP-like parquet (note_id, person_id, note_text, ...)
+ docs = edsnlp.data.read_parquet(
+ path=input_path,
+ converter="omop",
+ filesystem=input_fs,
+ read_in_worker=True,
+ doc_attributes=["note_id", "person_id"], #(2)!
+ )
+
+ # Apply the model lazily
+ docs = docs.map_pipeline(nlp)
+
+ # Configure multiprocessing with automatic resource detection
+ docs = docs.set_processing(
+ backend="multiprocessing",
+ batch_size=batch_size,
+ show_progress=show_progress,
+ # You can set num_cpu_workers and num_gpu_workers here,
+ # otherwise they are auto-detected
+ )
+
+    # Write entities to parquet, with a fallback row when a doc has no entities
+ # Feel free to change the output format here
+ def doc_to_rows(doc):
+ rows = [
+ dict(
+ note_id=getattr(doc._, "note_id", None),
+ person_id=getattr(doc._, "person_id", None),
+ offset_begin=ent.start_char,
+ offset_end=ent.end_char,
+ label=ent.label_,
+ snippet=ent.text,
+                date=getattr(ent._, "date", None),
+ # You can add other ent attributes here
+ # like ent._.certainty, ent._.family, etc.
+ nlp_system=model_name,
+ )
+ for ent in doc.ents
+ ]
+ return rows or [
+ dict(
+ note_id=getattr(doc._, "note_id", None),
+ person_id=getattr(doc._, "person_id", None),
+ offset_begin=0,
+ offset_end=0,
+ label="EMPTY",
+ snippet="",
+ date=None,
+ # You can add other ent attributes here
+ nlp_system=model_name,
+ )
+ ]
+
+ # We declare here where we want to write the output
+ # All writers trigger the execution by default (unless execute=False)
+ docs.write_parquet(
+ path=output_path,
+ overwrite=True,
+ batch_size=output_file_size,
+ converter=doc_to_rows,
+ filesystem=output_fs,
+ )
+
+ print(f"Job done: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+
+
+if __name__ == "__main__":
+ app()
+```
+
+1. We use [confit](https://github.com/aphp/confit) to create a CLI application and enforce parameter types.
+2. Columns listed in `doc_attributes` are stored as extension attributes on the produced documents (e.g. `doc._.person_id`), so we can retrieve them later when building the output rows.
+
+!!! tip "Converters and schemas"
+
+    - If your input is not OMOP-like (i.e. with `note_text` and `note_id` columns),
+      provide your own reader converter instead of `converter="omop"` (see the
+      [Converters](/data/converters) page and the sketch below).
+ - See the tutorial [Processing multiple texts](/tutorials/multiple-texts) for more about batching expressions (`batch_size`) and the `backend` options.
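+
+For instance, if your table has `doc_id` and `text` columns instead (hypothetical
+names, adapt them to your schema), a minimal custom reader converter could look
+like this:
+
+```python { .no-check }
+import edsnlp
+from spacy.tokens import Doc
+
+# The `note_id` extension is normally registered by EDS-NLP; we guard the call
+# so this sketch also works in a bare spaCy environment.
+if not Doc.has_extension("note_id"):
+    Doc.set_extension("note_id", default=None)
+
+# Blank pipeline used only to tokenize the raw text: the loaded model's pipes
+# are applied later via `docs.map_pipeline(nlp)`.
+tokenizer_nlp = edsnlp.blank("eds")
+
+
+def my_converter(row: dict):
+    # `row` is one Parquet record; keep the identifier on the Doc
+    doc = tokenizer_nlp(row["text"])
+    doc._.note_id = row["doc_id"]
+    return doc
+
+
+docs = edsnlp.data.read_parquet(
+    path="hdfs:///user/USERNAME/notes/",
+    converter=my_converter,
+)
+```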
+
+## Accessing computation resources on Slurm
+
+Slurm is a workload manager for HPC clusters. You request resources (CPUs, memory, GPUs, time) and submit jobs with scripts. The key commands are `sbatch` (submit a job), `squeue` (inspect the queue) and `scancel` (cancel a job).
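+
+Depending on your cluster, you can usually list the available partitions and the
+GPUs they expose with `sinfo` (a quick sketch; ask your cluster admins which
+partition and GPU type you should request):
+
+```bash { data-md-color-scheme="slate" }
+# Partition name, generic resources (e.g. GPUs), node count and node state
+sinfo -o "%P %G %D %t"
+```
+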
+Below is a Slurm script that activates your environment, shows GPU info, and runs the inference script.
+
+```sbatch { title="job.sh" }
+#!/bin/bash
+# Name your job clearly to find it in queues and reports.
+#SBATCH --job-name=nlp
+# Walltime limit. Increase if you hit the time limit.
+#SBATCH --time 1:00:00
+# For instance here, we request one V100 GPU. Adapt to what your cluster
+# provides, and contact your admin if unsure.
+#SBATCH --gres=gpu:v100:1
+#SBATCH --partition gpuV100
+# Single-node job with 10 CPU cores for rule-based pipes, preprocessing,
+# collation and postprocessing.
+#SBATCH --nodes=1
+#SBATCH --cpus-per-task=10
+# RAM (not GPU VRAM!) per node, in MB. Adjust if you hit OOM errors.
+#SBATCH --mem=50000
+# Container config (if your Slurm allows this). Adapt to your cluster
+# setup and contact your admin if unsure.
+#SBATCH --container-image /scratch/images/sparkhadoop.sqsh --container-mounts=/export/home/$USER:/export/home/$USER,/export/home/share:/export/home/share,/data/scratch/$USER:/data/scratch/$USER --container-mount-home --container-writable --container-workdir=/
+# Stdout/stderr file patterns with `%j` expanded to the job ID.
+# You can put these in a logs/ directory if you prefer, but MAKE SURE
+# THAT THIS DIRECTORY EXISTS BEFORE SUBMITTING!
+#SBATCH --output=slurm-%j-stdout.log
+#SBATCH --error=slurm-%j-stderr.log
+
+set -euo pipefail
+# Set up the environment. The simple setup below is for the AP-HP cluster;
+# refer to your HPC documentation for your own setup.
+/etc/start.sh
+export HADOOP_HOME=/usr/local/hadoop
+export CLASSPATH=`$HADOOP_HOME/bin/hdfs classpath --glob`
+export ARROW_LIBHDFS_DIR=/usr/local/hadoop/usr/lib/
+source "$HOME/.user_conda/miniconda/etc/profile.d/conda.sh"
+
+# Activate your environment(s), e.g. conda/venv/uv or a mix of these
+conda activate your-conda-env
+source path/to/your/project/.venv/bin/activate
+
+# You can install packages here. Installing inside the job helps ensure the
+# installed versions match the execution environment (glibc, CUDA versions,
+# etc.). Otherwise, install them in your env beforehand.
+pip install "edsnlp[ml]" "pyarrow<17"
+
+# Check available GPUs
+nvidia-smi
+
+cd path/to/your/project
+python inference.py \
+ --model_name "AP-HP/dummy-ner" \
+ --input_path "hdfs:///user/USERNAME/notes/" \
+ --output_path "hdfs:///user/USERNAME/nlp_results/" \
+ --batch_size "10000 words" \
+ --output_file_size "1000 docs" \
+ --show_progress
+```
+
+## Run and monitor the job
+
+1. Launch the job and store its job ID in a `JOB_ID` variable:
+ ```bash { data-md-color-scheme="slate" }
+ JOB_ID=$(sbatch job.sh | awk '{print $4}') && echo "Job: $JOB_ID"
+ ```
+ ```
+    Job: 123456
+ ```
+
+2. See the currently running and pending jobs with `squeue`:
+ ```bash { data-md-color-scheme="slate" }
+ squeue
+ ```
+ ```
+ JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)
+ 123456 gpuV100 nlp USERNAME R 0:10 1 gpu-node-01
+ ```
+3. Cancel the job if needed with:
+ ```bash { data-md-color-scheme="slate" }
+ scancel $JOB_ID
+ ```
+
+4. Follow the logs in real time with `tail`. See the `#SBATCH --output`/`--error` directives above if you want to write them to a `logs/` directory:
+ ```bash { data-md-color-scheme="slate" }
+ tail -f -n+0 slurm-$JOB_ID-std*.log
+ ```
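+
+5. Once the job has finished, you can usually inspect its final state and resource usage with `sacct` (a sketch; the accounting fields available depend on your cluster's configuration):
+    ```bash { data-md-color-scheme="slate" }
+    sacct -j $JOB_ID --format=JobID,JobName,State,Elapsed,MaxRSS
+    ```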
+
+## Fetching data with PySpark
+
+The above job requires a Parquet dataset as input. You can use PySpark to extract notes from your CDW and write them to Parquet.
+In theory, you could run the whole pipeline end-to-end with Spark using:
+```python { .no-check }
+docs = edsnlp.data.from_spark(...)
+```
+However, this interleaves Spark’s distributed CPU scheduling with the GPU-based inference, often mobilizing many CPUs in an uncoordinated way while documents stream through both PySpark and the GPU workers.
+
+A more robust pattern is to decouple document selection from inference. In a Spark-enabled notebook or a Spark-submit job:
+
+1. Extract your input corpus with Spark and write it to Parquet (HDFS or S3):
+ ```python { .no-check }
+ from pyspark.sql import SparkSession
+
+ spark = SparkSession.builder.getOrCreate()
+
+ note = spark.sql("""
+ SELECT note_id, person_id, note_text
+ FROM your_database.note
+ WHERE note_datetime >= '2024-01-01' and note_text IS NOT NULL
+ LIMIT 10000
+ """)
+
+ note.write.mode("overwrite").parquet("hdfs:///user/USERNAME/notes/")
+ ```
+
+2. Run the Slurm GPU inference on that Parquet dataset, as in the sections above (point your `--input_path` to the Parquet location and `--output_path` to a destination Parquet directory).
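+
+    For instance, if you make `job.sh` read these two paths from its positional arguments (`$1` and `$2`, a hypothetical tweak to the script above), you can submit it with:
+    ```bash { data-md-color-scheme="slate" }
+    sbatch job.sh "hdfs:///user/USERNAME/notes/" "hdfs:///user/USERNAME/nlp_results/"
+    ```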
+
+3. Load the inference results back into Spark if needed (aggregation, joins, etc.)
+ ```python { .no-check }
+ note_nlp = spark.read.parquet("hdfs:///user/USERNAME/nlp_results/")
+ note_nlp.createOrReplaceTempView("note_nlp")
+
+ # Example: count entities per label
+ spark.sql("""
+ SELECT label, COUNT(*) AS n
+ FROM note_nlp
+ GROUP BY label
+ ORDER BY n DESC
+ """).show()
+ ```
+
+This approach keeps GPU inference scheduling independent of Spark, avoids excessive CPU pressure, and is easier to monitor and reason about.
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 9f563e8fc..3818be564 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -46,6 +46,14 @@ We provide step-by-step guides to get you started. We cover the following use-ca
---
Improve the inference speed of your pipeline
+=== card {: href=/tutorials/hpc }
+
+ :fontawesome-solid-microchip:
+    **Running on HPC (e.g. Slurm)**
+
+ ---
+    Use an existing model at scale with a High-Performance Computing (HPC) job scheduler like Slurm.
+
=== card {: href=/tutorials/reason }
:fontawesome-regular-hospital:
@@ -85,6 +93,8 @@ We provide step-by-step guides to get you started. We cover the following use-ca
---
Quickly visualize the results of your pipeline as annotations or tables.
+
+
### Deep learning tutorials
We also provide tutorials on how to train deep-learning models with EDS-NLP. These tutorials cover the training API, hyperparameter tuning, and more.
diff --git a/edsnlp/core/pipeline.py b/edsnlp/core/pipeline.py
index c4ff93ddc..6ed7106e8 100644
--- a/edsnlp/core/pipeline.py
+++ b/edsnlp/core/pipeline.py
@@ -558,7 +558,8 @@ def from_config(
config["nlp"]["components"] = Reference("components")
config = config["nlp"]
- config = Config(config).resolve(root=root_config, registry=registry)
+ with DraftPipe.disable_auto_instantiation():
+ config = Config(config).resolve(root=root_config, registry=registry)
if isinstance(config, Pipeline): # pragma: no cover
return config
config = dict(config)
@@ -590,6 +591,7 @@ def _add_pipes(
enable: Container[str],
disable: Container[str],
):
+ components = {n: comp for n, comp in components.items() if n not in exclude}
try:
components = DraftPipe.instantiate(components, nlp=self)
except ConfitValidationError as e:
@@ -1225,6 +1227,7 @@ def load(
auto_update=auto_update,
install_dependencies=install_dependencies,
**kwargs,
+ **pipe_selection,
)
except (
ImportError,
@@ -1237,7 +1240,7 @@ def load(
if not isinstance(model, Config):
raise ValueError(error) from base_exc
- return Pipeline.from_config(model)
+ return Pipeline.from_config(model, **pipe_selection)
def load_from_huggingface(
diff --git a/edsnlp/core/registries.py b/edsnlp/core/registries.py
index e6c22ad0d..7866149bd 100644
--- a/edsnlp/core/registries.py
+++ b/edsnlp/core/registries.py
@@ -1,5 +1,6 @@
import inspect
import types
+from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from typing import (
@@ -71,12 +72,26 @@ class FactoryMeta:
T = TypeVar("T")
+_AUTO_INSTANTIATE_COMPLETE_PIPE_DRAFT = True
+
+
class DraftPipe(Draft[T]):
def __init__(self, func, kwargs):
super().__init__(func, kwargs)
self.instantiated = None
self.error = None
+ @staticmethod
+ @contextmanager
+ def disable_auto_instantiation():
+ global _AUTO_INSTANTIATE_COMPLETE_PIPE_DRAFT
+ old_value = _AUTO_INSTANTIATE_COMPLETE_PIPE_DRAFT
+ _AUTO_INSTANTIATE_COMPLETE_PIPE_DRAFT = False
+ try:
+ yield
+ finally:
+ _AUTO_INSTANTIATE_COMPLETE_PIPE_DRAFT = old_value
+
def maybe_nlp(self) -> Union["DraftPipe", Any]:
"""
If the factory requires an nlp argument and the user has explicitly
@@ -92,15 +107,20 @@ def maybe_nlp(self) -> Union["DraftPipe", Any]:
sig = inspect.signature(self._func)
if (
- not (
- "nlp" in sig.parameters
- and (
- sig.parameters["nlp"].default is sig.empty
- or sig.parameters["nlp"].annotation in (Pipeline, PipelineProtocol)
+ _AUTO_INSTANTIATE_COMPLETE_PIPE_DRAFT
+ and (
+ not (
+ "nlp" in sig.parameters
+ and (
+ sig.parameters["nlp"].default is sig.empty
+ or sig.parameters["nlp"].annotation
+ in (Pipeline, PipelineProtocol)
+ )
)
+ or "nlp" in self._kwargs
)
- or "nlp" in self._kwargs
- ) and not self.search_nested_drafts(self._kwargs):
+ and not self.search_nested_drafts(self._kwargs)
+ ):
return self._func(**self._kwargs)
return self
diff --git a/edsnlp/core/torch_component.py b/edsnlp/core/torch_component.py
index e78b4fc6b..7f0ee1292 100644
--- a/edsnlp/core/torch_component.py
+++ b/edsnlp/core/torch_component.py
@@ -79,7 +79,9 @@ def cached(key, store_key=False):
def wrapper(fn):
@wraps(fn)
def wrapped(self: "TorchComponent", *args, **kwargs):
- if self._current_cache_id is None or len(args) == 0:
+            # We once got an error in the CI where self._current_cache_id (not None) was
+            # missing from _caches. This should not happen, but we check here just in case.
+ if _caches.get(self._current_cache_id) is None or len(args) == 0:
return fn(self, *args, **kwargs)
cache_key = (
fn.__name__,
diff --git a/edsnlp/processing/multiprocessing.py b/edsnlp/processing/multiprocessing.py
index c193f173f..71f0f7a41 100644
--- a/edsnlp/processing/multiprocessing.py
+++ b/edsnlp/processing/multiprocessing.py
@@ -7,6 +7,7 @@
import multiprocessing
import multiprocessing.reduction
import os
+import re
import sys
import tempfile
import threading
@@ -18,8 +19,8 @@
from contextlib import nullcontext
from itertools import tee
from multiprocessing.connection import wait
+from multiprocessing.queues import Empty
from typing import (
- TYPE_CHECKING,
Dict,
Iterable,
List,
@@ -39,20 +40,13 @@
decompress_dict,
)
-doc_size_fns = {
- "words": len,
-}
-
-if TYPE_CHECKING:
- import torch
-
-
-# Singleton is important since the STOP objects may be passed to
-# other processes, i.e. pickled, unpickled, while they should
-# always be the same object.
+WORKER_STOP_CHECK_INTERVAL = 5.0 # seconds
class StopType:
+ # Singleton is important since the STOP objects may be passed to
+ # other processes, i.e. pickled, unpickled, while they should
+ # always be the same object.
instance = None
def __repr__(self):
@@ -218,6 +212,63 @@ def cpu_count(): # pragma: no cover
return max(1, min(os_cpu_count, cpu_count_affinity, cpu_count_cgroup))
+def hdfs_or_jvm_loaded() -> bool:
+ libs = ("libhdfs", "libhdfs3", "libhadoop", "libjvm")
+
+ def loaded_libs():
+ if sys.platform.startswith("linux"):
+ try:
+ with open("/proc/self/maps", "r") as f:
+ return re.findall(r"/[^\s]+\.so[^\s]*", f.read())
+ except Exception:
+ return []
+ try:
+ import psutil
+
+ return [m.path for m in psutil.Process().memory_maps() if m.path]
+ except Exception:
+ return []
+
+ try:
+ loaded = loaded_libs()
+ except Exception:
+ return False
+ return any(any(k in os.path.basename(p).lower() for k in libs) for p in loaded)
+
+
+def get_multiprocessing_context(has_torch_pipes, process_start_method):
+ methods = multiprocessing.get_all_start_methods()
+ default_method = "fork" if "fork" in methods else "spawn"
+ has_hdfs = hdfs_or_jvm_loaded()
+
+    # Base choice: torch-based pipes and HDFS/JVM usage prefer spawn
+ method = process_start_method or (
+ "spawn" if has_torch_pipes or has_hdfs else default_method
+ )
+
+ # Warn for torch + fork
+ if has_torch_pipes and method == "fork":
+ warnings.warn(
+ "Using fork start method with GPU workers may lead to deadlocks. "
+ "Consider using process_start_method='spawn' instead."
+ )
+ method = "spawn"
+
+ # Avoid fork if HDFS/JVM already loaded
+ if has_hdfs and method == "fork":
+ safe = "forkserver" if "forkserver" in methods else "spawn"
+ warnings.warn(
+ "Using fork start method with HDFS may lead to deadlocks. "
+ f"Consider using process_start_method='{safe}' instead."
+ )
+ method = safe
+
+ if default_method != method:
+ logging.info(f"Switching process start method to {method}")
+
+ return multiprocessing.get_context(method)
+
+
# Should we check if the multiprocessing module of edsnlp
# is responsible for this child process before replacing the pickler ?
if (
@@ -394,7 +445,7 @@ def run_stage_thread(self, stage):
except BaseException as e:
if self.stop: # pragma: no cover
return
- print(f"Error in {self.uid}:\n{traceback.format_exc()}", flush=True)
+ print(f"Error in {self.uid}/{stage}:\n{traceback.format_exc()}", flush=True)
self.main_control_queue.put(e)
finally:
try:
@@ -427,7 +478,12 @@ def run(self):
self.main_control_queue.put("READY")
while not self.stop:
- notification = self.worker_control_queue.get()
+ try:
+ notification = self.worker_control_queue.get(
+ timeout=WORKER_STOP_CHECK_INTERVAL
+ )
+ except Empty: # pragma: no cover
+ continue
if notification is STOP and not self.stop:
self.stop = True
self.on_stop()
@@ -444,16 +500,6 @@ def run(self):
print(f"Error in {self.uid}:\n{traceback.format_exc()}", flush=True)
self.main_control_queue.put(e)
finally:
- # print(f"Waiting time for {self.uid}", flush=True)
- # print(
- # "\n"
- # + "\n".join(
- # f"Waiting time for {self.uid}/{k}: {v:.2f}"
- # for k, v in self.waiting_times.items()
- # )
- # + "\n",
- # flush=True,
- # )
for thread in threads:
thread.join()
for stage in self.stages_to_run:
@@ -517,6 +563,16 @@ def iter_tasks(self, stage, stop_mode=False):
task_idx = 0
for item in iter(self.stream.reader.read_records()):
if self.stop: # pragma: no cover
+ # I don't know why, sometimes when a KeyboardInterrupt is captured
+ # by the main process, and converted to a stop=True in the workers,
+ # this line is correctly reached, but the StopSignal exception is
+ # not bubbled up most of the time to `self.send_results(items)` in
+ # `CPUWorker.process_items` and therefore not caught in
+ # `self.process_items(stage)` in `Worker.run_stage_thread`.
+ # This makes the worker hang at the end instead of stopping.
+ # Simply wrapping this iter_tasks() generator in a try/except
+ # next(iterator) doesn't catch the exception when raised here, as if
+ # the exception bubbling mechanism was broken in these cases.
raise StopSignal()
if isinstance(item, StreamSentinel):
yield item
@@ -529,7 +585,7 @@ def iter_tasks(self, stage, stop_mode=False):
task_idx += 1
else:
- task_idx = -1
+ last_task_idx = -1
deterministic = self.stream.deterministic
schedule = self.stage_schedule[stage]
@@ -538,8 +594,8 @@ def iter_tasks(self, stage, stop_mode=False):
raise StopSignal()
if stage > 0:
- task_idx = (task_idx + 1) % len(schedule)
- while schedule[task_idx] is None:
+ task_idx = (last_task_idx + 1) % len(schedule)
+ while schedule[task_idx] is None: # pragma: no cover
task_idx = (task_idx + 1) % len(schedule)
if deterministic:
@@ -563,9 +619,15 @@ def iter_tasks(self, stage, stop_mode=False):
name = f"from-{prod}_to-stage-{stage}_of-{self.uid}"
queue = self.data_queues[name]
t = time.time()
- item = queue.get()
+ try:
+ item = queue.get(timeout=WORKER_STOP_CHECK_INTERVAL)
+ except Empty: # pragma: no cover
+ continue
+ else:
+ if stage > 0:
+ last_task_idx = task_idx
+ self.waiting_times["get-" + name] += time.time() - t
- self.waiting_times["get-" + name] += time.time() - t
if item is STOP:
schedule[:] = [s if s != prod else None for s in schedule]
self.num_producers_alive[stage] -= 1
@@ -751,18 +813,18 @@ def process_items(self, stage):
del batch, item
def iter_tasks(self, stage, stop_mode=False):
- offset = -1
queues = [
key
for key, q in self.data_queues.items()
if key.endswith(f"-{stage}_of-{self.uid}")
]
# Get items from the previous stage
+ last_offset = -1
while self.num_producers_alive[stage] > 0:
if self.stop and not stop_mode: # pragma: no cover
raise StopSignal()
- offset = (offset + 1) % len(queues)
+ offset = (last_offset + 1) % len(queues)
while queues[offset] is None: # pragma: no cover
offset = (offset + 1) % len(queues)
@@ -773,9 +835,13 @@ def iter_tasks(self, stage, stop_mode=False):
name = names[conns.index(ready)]
queue = self.data_queues[name]
t = time.time()
- item = queue.get()
-
- self.waiting_times["get-" + name] += time.time() - t
+ try:
+ item = queue.get(timeout=WORKER_STOP_CHECK_INTERVAL)
+ except Empty: # pragma: no cover
+ continue
+ else:
+ last_offset = offset
+ self.waiting_times["get-" + name] += time.time() - t
if item is STOP:
queues[:] = [s if s != name else None for s in queues]
@@ -826,7 +892,7 @@ def __init__(self, stream):
self.stream = stream
self.stages = stream._make_stages(split_torch_pipes=num_gpu_workers > 0)
self.has_torch_pipes = has_torch_pipes
- mp = self.get_multiprocessing_context(
+ mp = get_multiprocessing_context(
has_torch_pipes=has_torch_pipes,
process_start_method=stream.process_start_method,
)
@@ -867,7 +933,7 @@ def __init__(self, stream):
for cpu in self.cpu_worker_names:
name = f"from-main_to-stage-0_of-{cpu}"
if not share_queues:
- queue = mp.SimpleQueue()
+ queue = mp.Queue()
self.data_queues[name] = queue
self.input_queue_names.append(name)
@@ -951,7 +1017,7 @@ def __init__(self, stream):
self.stopped = False
self.num_alive_workers = num_cpu_workers
self.workers_status = [True] * num_cpu_workers
- self.current_worker_idx = -1
+ self.current_worker_idx = 0
self.error = None
self.dequeue_notifications_thread = None
self.queue_feeder_threads: List[threading.Thread] = []
@@ -1080,18 +1146,11 @@ def iter_outputs(self, stop_mode=False):
)
missing_sentinels = len(self.cpu_worker_names) if requires_sentinel else 0
buffer = []
+ last_heartbeat = {name: time.time() for name in self.cpu_worker_names}
while self.num_alive_workers > 0:
if self.stopped and not stop_mode: # pragma: no cover
raise StopSignal()
- self.current_worker_idx = (self.current_worker_idx + 1) % len(
- self.cpu_worker_names
- )
- while not self.workers_status[self.current_worker_idx]:
- self.current_worker_idx = (self.current_worker_idx + 1) % len(
- self.cpu_worker_names
- )
-
if deterministic:
worker_idx = self.current_worker_idx
else:
@@ -1112,7 +1171,20 @@ def iter_outputs(self, stop_mode=False):
name = f"from-{self.cpu_worker_names[worker_idx]}_to-main"
queue = self.data_queues[name]
- out = queue.get()
+ try:
+ out = queue.get(timeout=10)
+ except Empty: # pragma: no cover
+ continue
+
+ last_heartbeat[self.cpu_worker_names[worker_idx]] = time.time()
+ self.current_worker_idx = (self.current_worker_idx + 1) % len(
+ self.cpu_worker_names
+ )
+ while not self.workers_status[self.current_worker_idx]:
+ self.current_worker_idx = (self.current_worker_idx + 1) % len(
+ self.cpu_worker_names
+ )
+
if out is STOP:
self.num_alive_workers -= 1
self.workers_status[worker_idx] = False
@@ -1131,8 +1203,6 @@ def iter_outputs(self, stop_mode=False):
else:
yield out
yield from buffer
- if self.error:
- raise self.error
def feed_queue(self, queue, items):
"""
@@ -1225,6 +1295,7 @@ def teardown(self, garbage_collected=False):
self.revert_environ()
self.revert_pickler()
+ self.error = None
for _ in self.iter_outputs(stop_mode=True):
pass
@@ -1328,26 +1399,6 @@ def adjust_num_workers(stream: Stream):
has_torch_pipes,
)
- @staticmethod
- def get_multiprocessing_context(has_torch_pipes, process_start_method):
- if has_torch_pipes:
- if process_start_method == "fork":
- warnings.warn(
- "Using fork start method with GPU workers may lead to deadlocks. "
- "Consider using process_start_method='spawn' instead."
- )
-
- process_start_method = process_start_method or "spawn"
-
- default_method = (
- "fork" if "fork" in multiprocessing.get_all_start_methods() else "spawn"
- )
- if process_start_method is not None and default_method != process_start_method:
- logging.info(f"Switching process start method to {process_start_method}")
- process_start_method = process_start_method or default_method
-
- return multiprocessing.get_context(process_start_method)
-
@staticmethod
def setup_environ(disable_implicit_parallelism):
old_environ = {
diff --git a/edsnlp/tune.py b/edsnlp/tune.py
index c7c46d681..0f56e0194 100644
--- a/edsnlp/tune.py
+++ b/edsnlp/tune.py
@@ -595,7 +595,7 @@ def compute_remaining_n_trials_possible(
remaining_gpu_time, compute_time_per_trial(study, ema=True)
)
return n_trials
- except ValueError:
+ except ValueError: # pragma: no cover
return 0
diff --git a/mkdocs.yml b/mkdocs.yml
index 8cc5cbaaf..92da01032 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -51,6 +51,7 @@ nav:
- tutorials/endlines.md
- tutorials/aggregating-results.md
- tutorials/multiple-texts.md
+ - tutorials/hpc.md
- advanced-tutorials/fastapi.md
- tutorials/make-a-training-script.md
- tutorials/training-ner.md