diff --git a/.gitignore b/.gitignore
index 3432f0519d..d7db964722 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,3 +73,6 @@ docs/reference
docs/changelog.md
docs/contributing.md
.vercel
+
+# Local development
+dev/*
diff --git a/changelog_unreleased.md b/changelog_unreleased.md
new file mode 100644
index 0000000000..84b530771b
--- /dev/null
+++ b/changelog_unreleased.md
@@ -0,0 +1,9 @@
+### Added
+- New `eds.external_information_qualifier` pipeline component: qualifies spans in a document based on external information and a defined distance to these contextual/external elements, as in Distant Supervision.
+- New `eds.contextual_qualifier` pipeline component to qualify spans based on contextual information.
+- Add the `edsnlp_blank_nlp` fixture for tests.
+
+### Fixed
+- Correct the contributing documentation: remove the `$ pre-commit run --all-files` recommendation.
+- Fix the `Obj Class` in the doc template `class.html`.
+- Fix the `get_pipe_meta` function to return a default `FactoryMeta` instead of a plain dict.
diff --git a/contributing.md b/contributing.md
index 026b954029..213d9f8e7c 100644
--- a/contributing.md
+++ b/contributing.md
@@ -43,16 +43,7 @@ $ pre-commit install
The pre-commit hooks defined in the [configuration](https://github.com/aphp/edsnlp/blob/master/.pre-commit-config.yaml) will automatically run when you commit your changes, letting you know if something went wrong.
-The hooks only run on staged changes. To force-run it on all files, run:
-
-
-
-```console
-$ pre-commit run --all-files
----> 100%
-color:green All good !
-```
-
+The hooks only run on staged changes.
## Proposing a merge request
diff --git a/docs/advanced-tutorials/distant_annotation.md b/docs/advanced-tutorials/distant_annotation.md
new file mode 100644
index 0000000000..0b4876810a
--- /dev/null
+++ b/docs/advanced-tutorials/distant_annotation.md
@@ -0,0 +1,163 @@
+# External Information & Context qualifiers
+
+This tutorial shows how to qualify spans or entities with two pipes: the `ContextualQualifier` and the `ExternalInformationQualifier`.
+
+### Import dependencies
+```python
+import datetime
+
+import pandas as pd
+
+import edsnlp
+from edsnlp.pipes.qualifiers.contextual.contextual import (
+ ClassPatternsContext,
+ ContextualQualifier,
+)
+from edsnlp.pipes.qualifiers.external_information.external_information import (
+ ExternalInformation,
+ ExternalInformationQualifier,
+)
+from edsnlp.utils.collections import get_deep_attr
+```
+
+### Data
+Let's start by creating a toy example.
+```python
+# Create context dates
+# The elements under this attribute should be a list of dicts with keys value and class
+context_dates = [
+ {
+ "value": datetime.datetime(2024, 2, 15),
+ "class": "Magnetic resonance imaging (procedure)",
+ },
+ {"value": datetime.datetime(2024, 2, 17), "class": "Biopsy (procedure)"},
+ {"value": datetime.datetime(2024, 2, 17), "class": "Colonoscopy (procedure)"},
+]
+
+# Text
+text = """
+RCP du 18/12/2024 : DUPONT Jean
+
+Homme de 68 ans adressé en consultation d’oncologie pour prise en charge d’une tumeur du colon.
+Antécédents : HTA, diabète de type 2, dyslipidémie, tabagisme actif (30 PA), alcoolisme chronique (60 g/jour).
+
+Examen clinique : patient en bon état général, poids 80 kg, taille 1m75.
+
+
+HISTOIRE DE LA MALADIE :
+Lors du PET-CT (14/02/2024), des dépôts pathologiques ont été observés qui coïncidaient avec les résultats du scanner.
+Le 15/02/2024, une IRM a été réalisée pour évaluer l’extension de la tumeur.
+Une colonoscopie a été réalisée le 17/02/2024 avec une biopsie d'adénopathie sous-carinale.
+Une deuxième a été biopsié le 18/02/2024. Les résultats de la biopsie ont confirmé un adénocarcinome du colon.
+Il a été opéré le 20/02/2024. L’examen anatomopathologique de la pièce opératoire a confirmé un adénocarcinome du colon stade IV avec métastases hépatiques et pulmonaires.
+Trois mois après la fin du traitement de chimiothérapie (avril 2024), le patient a signalé une aggravation progressive des symptômes
+
+CONCLUSION : Adénocarcinome du colon stade IV avec métastases hépatiques et pulmonaires.
+"""
+
+
+# Create a toy dataframe
+df = pd.DataFrame.from_records(
+ [
+ {
+ "person_id": 1,
+ "note_id": 1,
+ "note_text": text,
+ "context_dates": context_dates,
+ }
+ ]
+)
+df
+```
+
+### Define the NLP pipeline
+Each key of `patterns` (e.g. `lf1`) becomes a span extension, and each inner key is the value assigned to it when the associated patterns match around the span. Patterns can be given either as a `ClassPatternsContext` or as a plain dict with the same keys, as shown below with `lf1` vs `lf2`/`lf3`.
+```python
+import edsnlp.pipes as eds
+
+nlp = edsnlp.blank("eds")
+
+nlp.add_pipe(eds.sentences())
+nlp.add_pipe(eds.normalizer())
+nlp.add_pipe(eds.dates())
+
+
+nlp.add_pipe(
+ ContextualQualifier(
+ span_getter="dates",
+ patterns={
+ "lf1": {
+ "Magnetic resonance imaging (procedure)": ClassPatternsContext(
+ **{
+ "terms": {"irm": ["IRM", "imagerie par résonance magnétique"]},
+ "regex": None,
+ "context_words": 0,
+ "context_sents": 1,
+ "attr": "TEXT",
+ }
+ )
+ },
+ "lf2": {
+ "Biopsy (procedure)": {
+ "regex": {"biopsy": ["biopsie", "biopsié"]},
+ "context_words": (10, 10),
+ "context_sents": 0,
+ "attr": "TEXT",
+ }
+ },
+ "lf3": {
+ "Surgical procedure (procedure)": {
+ "regex": {"chirurgie": ["chirurgie", "exerese", "opere"]},
+ "context_words": 0,
+ "context_sents": (2, 2),
+ "attr": "NORM",
+ },
+ },
+ },
+ )
+)
+
+nlp.add_pipe(
+ ExternalInformationQualifier(
+ nlp=nlp,
+ span_getter="dates",
+ external_information={
+ "lf4": ExternalInformation(
+ doc_attr="_.context_dates",
+ span_attribute="_.date.to_datetime()",
+ threshold=datetime.timedelta(days=0),
+ )
+ },
+ )
+)
+```
+
+### Apply the pipeline to texts
+The `doc_attributes=["context_dates"]` argument copies the `context_dates` column onto each document as `doc._.context_dates`, where the `ExternalInformationQualifier` reads it (cf. `doc_attr` above).
+```python
+doc_iterator = edsnlp.data.from_pandas(
+ df, converter="omop", doc_attributes=["context_dates"]
+)
+
+docs = list(nlp.pipe(doc_iterator))
+```
+
+### Let's inspect the results
+```python
+doc = docs[0]
+dates = doc.spans["dates"]
+
+for date in dates:
+ for attr in ["lf1", "lf2", "lf3", "lf4"]:
+ value = get_deep_attr(date, "_." + attr)
+
+ if value:
+ print(date.start, date.end, date, attr, value)
+```
+
+```python
+# Out : 120 125 15/02/2024 lf1 Magnetic resonance imaging (procedure)
+# Out : 120 125 15/02/2024 lf4 ['Magnetic resonance imaging (procedure)']
+# Out : 147 152 17/02/2024 lf2 Biopsy (procedure)
+# Out : 147 152 17/02/2024 lf4 ['Biopsy (procedure)', 'Colonoscopy (procedure)']
+# Out : 168 173 18/02/2024 lf2 Biopsy (procedure)
+# Out : 192 197 20/02/2024 lf3 Surgical procedure (procedure)
+```
diff --git a/docs/pipes/qualifiers/contextual.md b/docs/pipes/qualifiers/contextual.md
new file mode 100644
index 0000000000..b3c4fe319f
--- /dev/null
+++ b/docs/pipes/qualifiers/contextual.md
@@ -0,0 +1,8 @@
+# Contextual {: #edsnlp.pipes.qualifiers.contextual.factory.create_component }
+
+::: edsnlp.pipes.qualifiers.contextual.factory.create_component
+ options:
+ heading_level: 2
+ show_bases: true
+ show_source: true
+ only_class_level: true
diff --git a/docs/pipes/qualifiers/external_information.md b/docs/pipes/qualifiers/external_information.md
new file mode 100644
index 0000000000..0e9aa8c862
--- /dev/null
+++ b/docs/pipes/qualifiers/external_information.md
@@ -0,0 +1,8 @@
+# External Information {: #edsnlp.pipes.qualifiers.external_information.factory.create_component }
+
+::: edsnlp.pipes.qualifiers.external_information.factory.create_component
+ options:
+ heading_level: 2
+ show_bases: true
+ show_source: true
+ only_class_level: true
diff --git a/docs/tutorials/training-ner.md b/docs/tutorials/training-ner.md
index 50ddc56146..b6b6e38744 100644
--- a/docs/tutorials/training-ner.md
+++ b/docs/tutorials/training-ner.md
@@ -115,7 +115,7 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
# 🎛️ OPTIMIZER
optimizer:
- "@core": optimizer
+ "@core": optimizer !draft # (2)!
optim: adamw
groups:
# Assign parameters starting with transformer (ie the parameters of the transformer component)
@@ -133,7 +133,6 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
"warmup_rate": 0.1
"start_value": 3e-4
"max_value": 3e-4
- module: ${ nlp }
total_steps: ${ train.max_steps }
# 📚 DATA
@@ -216,6 +215,14 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
1. Why do we use `'@core': pipeline` here ? Because we need the reference used in `optimizer.module = ${ nlp }` to be the actual Pipeline and not its keyword arguments : when confit sees `'@core': pipeline`, it will instantiate the `Pipeline` class with the arguments provided in the dict.
In fact, you could also use `'@core': eds.pipeline` in every config when you define a pipeline, but sometimes it's more convenient to let Confit infer that the type of the nlp argument based on the function when it's type hinted. Not specifying `'@core': pipeline` is also more aligned with `spacy`'s pipeline config API. However, in general, explicit is better than implicit, so feel free to use explicitly write `'@core': eds.pipeline` when you define a pipeline.
+    1. What does "draft" mean here? We'll let the train function pass the nlp
+       object to the optimizer after it has been `post_init`'ed: `post_init` is
+       the operation that looks at some data, finds how many labels the model
+       must learn, and updates the model weights to have as many heads as there
+       are labels observed in the train data. This function will be called by
+       `train`, so the optimizer should be defined *after* it, when the model
+       parameter tensors are final. To do that, instead of instantiating the
+       optimizer right now, we create a "Draft", which will be instantiated
+       inside the `train` function, once all the required parameters are set.
To train the model, you can use the following command:
@@ -277,9 +284,8 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
# 🎛️ OPTIMIZER
max_steps = 2000
- optimizer = ScheduledOptimizer(
+ optimizer = ScheduledOptimizer.draft( # (1)!
optim=torch.optim.Adam,
- module=nlp,
total_steps=max_steps,
groups=[
{
@@ -333,6 +339,15 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
)
```
+    1. What does "draft" mean here? We'll let the train function pass the nlp
+       object to the optimizer after it has been `post_init`'ed: `post_init` is
+       the operation that looks at some data, finds how many labels the model
+       must learn, and updates the model weights to have as many heads as there
+       are labels observed in the train data. This function will be called by
+       `train`, so the optimizer should be defined *after* it, when the model
+       parameter tensors are final. To do that, instead of instantiating the
+       optimizer right now, we create a "Draft", which will be instantiated
+       inside the `train` function, once all the required parameters are set.
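+
+       In plain Python, the mechanism can be sketched roughly like this
+       (hypothetical names, not the actual confit API):
+
+       ```{ .python .no-check }
+       class Draft:
+           """Holds a class and partial kwargs until late parameters are known."""
+
+           def __init__(self, cls, **partial_kwargs):
+               self.cls = cls
+               self.partial_kwargs = partial_kwargs
+
+           def instantiate(self, **late_kwargs):
+               # e.g. late_kwargs = {"module": nlp}, once nlp has been post_init'ed
+               return self.cls(**self.partial_kwargs, **late_kwargs)
+       ```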
+
or use the config file:
```{ .python .no-check }
diff --git a/docs/tutorials/training-span-classifier.md b/docs/tutorials/training-span-classifier.md
index 83aa88b171..154866be91 100644
--- a/docs/tutorials/training-span-classifier.md
+++ b/docs/tutorials/training-span-classifier.md
@@ -184,13 +184,14 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
```
1. Put entities extracted by `eds.dates` in `doc.ents`, instead of `doc.spans['dates']`.
- 2. Wait, what's does "draft" mean here ? The rationale is this: we don't want to
- instantiate the optimizer now, because the nlp object hasn't been `post_init`'ed
- yet : `post_init` is the operation that looks at some data, finds how many labels the model must learn,
- and updates the model weights to have as many heads as there are labels. This function will
- be called by `train`, so the optimizer should be defined *after*, when the model parameter tensors are
- final. To do that, instead of instantiating the optimizer, we create a "Draft", which will be
- instantiated inside the `train` function, once all the required parameters are set.
+    2. What does "draft" mean here? We'll let the train function pass the nlp
+       object to the optimizer after it has been `post_init`'ed: `post_init` is
+       the operation that looks at some data, finds how many labels the model
+       must learn, and updates the model weights to have as many heads as there
+       are labels observed in the train data. This function will be called by
+       `train`, so the optimizer should be defined *after* it, when the model
+       parameter tensors are final. To do that, instead of instantiating the
+       optimizer right now, we create a "Draft", which will be instantiated
+       inside the `train` function, once all the required parameters are set.
And train the model:
@@ -309,13 +310,14 @@ Visit the [`edsnlp.train` documentation][edsnlp.training.trainer.train] for a li
```
1. Put entities extracted by `eds.dates` in `doc.ents`, instead of `doc.spans['dates']`.
- 2. Wait, what's does "draft" mean here ? The rationale is this: we don't want to
- instantiate the optimizer now, because the nlp object hasn't been `post_init`'ed
- yet : `post_init` is the operation that looks at some data, finds how many label the model must learn,
- and updates the model weights to have as many heads as there are labels. This function will
- be called by `train`, so the optimizer should be defined *after*, when the model parameter tensors are
- final. To do that, instead of instantiating the optimizer, we create a "Draft", which will be
- instantiated inside the `train` function, once all the required parameters are set.
+    2. What does "draft" mean here? We'll let the train function pass the nlp
+       object to the optimizer after it has been `post_init`'ed: `post_init` is
+       the operation that looks at some data, finds how many labels the model
+       must learn, and updates the model weights to have as many heads as there
+       are labels observed in the train data. This function will be called by
+       `train`, so the optimizer should be defined *after* it, when the model
+       parameter tensors are final. To do that, instead of instantiating the
+       optimizer right now, we create a "Draft", which will be instantiated
+       inside the `train` function, once all the required parameters are set.
!!! note "Upstream annotations at training vs inference time"
diff --git a/edsnlp/core/pipeline.py b/edsnlp/core/pipeline.py
index c4ff93ddc6..a93ef14e0e 100644
--- a/edsnlp/core/pipeline.py
+++ b/edsnlp/core/pipeline.py
@@ -338,7 +338,7 @@ def get_pipe_meta(self, name: str) -> FactoryMeta:
Dict[str, Any]
"""
pipe = self.get_pipe(name)
- return PIPE_META.get(pipe, {})
+ return PIPE_META.get(pipe, FactoryMeta([], [], False, {}))
def make_doc(self, text: str) -> Doc:
"""
diff --git a/edsnlp/metrics/span_attribute.py b/edsnlp/metrics/span_attribute.py
index d701813e03..cbb74f20dc 100644
--- a/edsnlp/metrics/span_attribute.py
+++ b/edsnlp/metrics/span_attribute.py
@@ -121,6 +121,8 @@ def span_attribute_metric(
continue
getter_key = attr if attr.startswith("_.") else f"_.{attr}"
value = BINDING_GETTERS[getter_key](span)
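+            # Soft labels come as {value: prob} dicts: score the most probable
+            # value (argmax) for metric computation.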
+ if isinstance(value, dict):
+ value = max(value, key=value.get)
if (value or include_falsy) and default_values[attr] != value:
labels[micro_key][1].add((eg_idx, beg, end, attr, value))
labels[attr][1].add((eg_idx, beg, end, attr, value))
diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py
index aea3f0f088..e095b6bfad 100644
--- a/edsnlp/pipes/__init__.py
+++ b/edsnlp/pipes/__init__.py
@@ -74,6 +74,8 @@
from .qualifiers.negation.factory import create_component as negation
from .qualifiers.reported_speech.factory import create_component as reported_speech
from .qualifiers.reported_speech.factory import create_component as rspeech
+ from .qualifiers.contextual.factory import create_component as contextual_qualifier
+    from .qualifiers.external_information.factory import (
+        create_component as external_information_qualifier,
+    )
from .trainable.ner_crf.factory import create_component as ner_crf
from .trainable.biaffine_dep_parser.factory import create_component as biaffine_dep_parser
from .trainable.extractive_qa.factory import create_component as extractive_qa
diff --git a/edsnlp/pipes/qualifiers/contextual/__init__.py b/edsnlp/pipes/qualifiers/contextual/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/edsnlp/pipes/qualifiers/contextual/contextual.py b/edsnlp/pipes/qualifiers/contextual/contextual.py
new file mode 100644
index 0000000000..c9568c4c2b
--- /dev/null
+++ b/edsnlp/pipes/qualifiers/contextual/contextual.py
@@ -0,0 +1,228 @@
+import re
+from dataclasses import dataclass
+from itertools import chain
+from typing import Dict, List, Optional, Tuple, Union
+
+from pydantic import NonNegativeInt
+from spacy.tokens import Doc, Span
+
+from edsnlp.core import PipelineProtocol
+from edsnlp.matchers.utils import Patterns
+from edsnlp.pipes.base import (
+ BaseSpanAttributeClassifierComponent,
+ SpanGetterArg,
+)
+from edsnlp.pipes.core.matcher.matcher import GenericMatcher
+from edsnlp.utils.span_getters import (
+ get_spans,
+ make_span_context_getter,
+ validate_span_getter,
+)
+
+
+@dataclass
+class ClassPatternsContext:
+ """
+ A data class to hold pattern matching context.
+
+ Parameters
+ ----------
+ terms : Optional[Patterns]
+ Terms to match.
+ regex : Optional[Patterns]
+ Regular expressions to match.
+    context_words : Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]]
+        Number of words to consider as context, either a single int or a
+        `(before, after)` tuple.
+    context_sents : Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]]
+        Number of sentences to consider as context, either a single int or a
+        `(before, after)` tuple.
+ attr : str
+ Attribute to match on.
+ regex_flags : Union[re.RegexFlag, int]
+ Flags for regular expressions.
+ ignore_excluded : bool
+ Whether to ignore excluded tokens.
+ ignore_space_tokens : bool
+ Whether to ignore space tokens.
+ """
+
+ terms: Optional[Patterns] = None
+ regex: Optional[Patterns] = None
+ context_words: Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]] = 0
+ context_sents: Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]] = 1
+ attr: str = "TEXT"
+ regex_flags: Union[re.RegexFlag, int] = 0
+ ignore_excluded: bool = False
+ ignore_space_tokens: bool = False
+
+
+@dataclass
+class ClassMatcherContext:
+ """
+ A data class to hold matcher context configuration.
+
+ Parameters
+ ----------
+ name : str
+ The name of the matcher.
+ value : Union[str, bool, int]
+ The value to set when a match is found.
+ matcher : GenericMatcher
+ The matcher object.
+ context_words : Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]]
+ Number of words to consider as context.
+ context_sents : Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]]
+ Number of sentences to consider as context.
+ """
+
+ name: str
+ value: Union[str, bool, int]
+ matcher: GenericMatcher
+ context_words: Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]]
+ context_sents: Union[NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt]]
+
+
+class ContextualQualifier(BaseSpanAttributeClassifierComponent):
+ """
+ The `eds.contextual_qualifier` pipeline component
+ qualifies spans based on contextual information.
+
+ Parameters
+ ----------
+ nlp : PipelineProtocol
+ The spaCy pipeline object.
+ name : Optional[str], default="contextual_qualifier"
+ The name of the component.
+ span_getter : SpanGetterArg
+ The function or callable to get spans from the document.
+    patterns : Dict[str, Dict[Union[str, int], ClassPatternsContext]]
+        A dictionary mapping each qualifier name (the span extension to set,
+        e.g. "lf1") to a dictionary mapping the value to assign (e.g. a class
+        name) to the patterns to look for, given as a `ClassPatternsContext`
+        or an equivalent plain dict.
+
+ ??? note "`ClassPatternsContext`"
+ ::: edsnlp.pipes.qualifiers.contextual.contextual.ClassPatternsContext
+ options:
+ heading_level: 1
+ only_parameters: "no-header"
+ skip_parameters: []
+ show_source: false
+ show_toc: false
+
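+    Examples
+    --------
+    A minimal usage sketch, adapted from the tests:
+
+    ```python
+    import edsnlp, edsnlp.pipes as eds
+
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe(eds.sentences())
+    nlp.add_pipe(eds.dates())
+    nlp.add_pipe(
+        "eds.contextual_qualifier",
+        config=dict(
+            span_getter="dates",
+            patterns={
+                "lf2": {
+                    "Biopsy (procedure)": {
+                        "regex": {"biopsy": ["biopsie", "biopsié"]},
+                        "context_words": (10, 10),
+                        "context_sents": 0,
+                        "attr": "TEXT",
+                    }
+                }
+            },
+        ),
+    )
+
+    doc = nlp("Une biopsie a été réalisée le 17/02/2024.")
+    doc.spans["dates"][0]._.lf2
+    # Out: 'Biopsy (procedure)'
+    ```
+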
+ """
+
+ def __init__(
+ self,
+ nlp: PipelineProtocol,
+ name: Optional[str] = "contextual_qualifier",
+ *,
+ span_getter: SpanGetterArg,
+ patterns: Dict[str, Dict[Union[str, int], ClassPatternsContext]],
+ ):
+ """
+ Initialize the ContextualQualifier.
+
+ Parameters
+ ----------
+ nlp : PipelineProtocol
+ The NLP pipeline object.
+ name : Optional[str], default="contextual_qualifier"
+ The name of the qualifier.
+ span_getter : SpanGetterArg
+ The span getter argument to identify spans to qualify.
+ patterns : Dict[str, Dict[Union[str, int], ClassPatternsContext]]
+ A dictionary of patterns to match in the text.
+ """
+ self.span_getter = span_getter
+ self.named_matchers = list() # Will contain all the named matchers
+
+ for pattern_name, named_dict in patterns.items():
+ for value, class_patterns in named_dict.items():
+ if isinstance(class_patterns, dict):
+ class_patterns = ClassPatternsContext(**class_patterns)
+
+ name_value_str = str(pattern_name) + "_" + str(value)
+ name_value_matcher = GenericMatcher(
+ nlp=nlp,
+ terms=class_patterns.terms,
+ regex=class_patterns.regex,
+ attr=class_patterns.attr,
+ ignore_excluded=class_patterns.ignore_excluded,
+ ignore_space_tokens=class_patterns.ignore_space_tokens,
+ span_setter=name_value_str,
+ )
+
+ self.named_matchers.append(
+ ClassMatcherContext(
+ name=pattern_name,
+ value=value,
+ matcher=name_value_matcher,
+ context_words=class_patterns.context_words,
+ context_sents=class_patterns.context_sents,
+ )
+ )
+
+ self.set_extensions()
+
+ super().__init__(
+ nlp=nlp,
+ name=name,
+ span_getter=validate_span_getter(span_getter),
+ )
+
+ def set_extensions(self) -> None:
+ """
+ Sets custom extensions on the Span object for each context name.
+ """
+ for named_matcher in self.named_matchers:
+ if not Span.has_extension(named_matcher.name):
+ Span.set_extension(named_matcher.name, default=None)
+
+ def get_matches(self, context: Span) -> List[Span]:
+ """
+ Extracts matches from the context span.
+
+ Parameters
+ ----------
+ context : Span
+ The span context to look for a match.
+
+ Returns
+ -------
+ List[Span]
+ List of detected spans.
+ """
+        matches = chain.from_iterable(
+            matcher.matcher.process(context) for matcher in self.named_matchers
+        )
+        return list(matches)
+
+ def __call__(self, doc: Doc) -> Doc:
+ """
+ Processes the document, qualifying spans based on contextual information.
+
+ Parameters
+ ----------
+ doc : Doc
+ The spaCy document to process.
+
+ Returns
+ -------
+ Doc
+ The processed document with qualified spans.
+ """
+ for matcher in self.named_matchers:
+ span_context_getter = make_span_context_getter(
+ context_words=matcher.context_words, context_sents=matcher.context_sents
+ )
+ for ent in get_spans(doc, self.span_getter):
+ context = span_context_getter(ent)
+
+ matches = matcher.matcher.process(context)
+
+ if len(matches) > 0:
+ ent._.set(matcher.name, matcher.value)
+
+ return doc
diff --git a/edsnlp/pipes/qualifiers/contextual/factory.py b/edsnlp/pipes/qualifiers/contextual/factory.py
new file mode 100644
index 0000000000..c5b865bd24
--- /dev/null
+++ b/edsnlp/pipes/qualifiers/contextual/factory.py
@@ -0,0 +1,7 @@
+from edsnlp.core import registry
+from edsnlp.pipes.qualifiers.contextual.contextual import ContextualQualifier
+
+create_component = registry.factory.register(
+ "eds.contextual_qualifier",
+ assigns=["doc.spans"],
+)(ContextualQualifier)
diff --git a/edsnlp/pipes/qualifiers/external_information/__init__.py b/edsnlp/pipes/qualifiers/external_information/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/edsnlp/pipes/qualifiers/external_information/external_information.py b/edsnlp/pipes/qualifiers/external_information/external_information.py
new file mode 100644
index 0000000000..8ad7c77fd3
--- /dev/null
+++ b/edsnlp/pipes/qualifiers/external_information/external_information.py
@@ -0,0 +1,319 @@
+import datetime as dt
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from spacy.tokens import Doc, Span
+
+from edsnlp.core import PipelineProtocol
+from edsnlp.pipes.base import (
+ BaseSpanAttributeClassifierComponent,
+ SpanGetterArg,
+)
+from edsnlp.utils.bindings import make_binding_getter
+from edsnlp.utils.span_getters import (
+ get_spans,
+)
+
+
+@dataclass
+class ExternalInformation:
+ """
+    Parameters
+    ----------
+    doc_attr : str
+        Document attribute holding the external information, e.g.
+        `"_.context_dates"`. The elements under this attribute should be a
+        list of dicts with keys `value` and `class` (List[Dict[str, Any]]),
+        for instance:
+
+        ```python
+        import datetime
+
+        context_dates = [
+            {"value": datetime.datetime(2024, 2, 15), "class": "irm"},
+            {"value": datetime.datetime(2024, 2, 7), "class": "biopsy"},
+        ]
+        ```
+    span_attribute : str
+        Span attribute to compare to the external values, e.g.
+        `"_.date.to_datetime()"`.
+    threshold : Union[float, dt.timedelta]
+        Maximum distance between a span value and an external value for the
+        span to be qualified, e.g. `datetime.timedelta(days=0)`.
+    reduce : str
+        How to aggregate the matches. One of `["all", "one_only", "closest"]`.
+        Defaults to `"all"`.
+    comparison_type : str
+        How to compare the values. One of `["similarity", "exact_match"]`.
+        Defaults to `"similarity"`.
+ """
+
+ doc_attr: str
+ span_attribute: str
+ threshold: Union[float, dt.timedelta]
+ reduce: str = "all" # "one_only" , "closest" # TODO: implement
+ comparison_type: str = "similarity" # "exact_match" # TODO: implement
+
+
+class ExternalInformationQualifier(BaseSpanAttributeClassifierComponent):
+ """
+ The `eds.external_information_qualifier` pipeline component qualifies spans
+ in a document based on external information and a defined distance to these
+ contextual/external elements as in Distant Supervision (http://deepdive.stanford.edu/distant_supervision).
+
+ Parameters
+ ----------
+ nlp : PipelineProtocol
+ The spaCy pipeline object.
+ name : Optional[str], default="distant_qualifier"
+ The name of the component.
+ span_getter : SpanGetterArg
+ The function or callable to get spans from the document.
+ external_information : Dict[str, ExternalInformation]
+ A dictionary where keys are the names of the attributes to set on spans,
+ and values are ExternalInformation objects defining the context and comparison
+ settings.
+
+ ??? note "`ExternalInformation`"
+ ::: edsnlp.pipes.qualifiers.external_information.external_information.ExternalInformation
+ options:
+ heading_level: 1
+ only_parameters: "no-header"
+ skip_parameters: []
+ show_source: false
+ show_toc: false
+
+ Methods
+ -------
+ set_extensions()
+ Sets custom extensions on the Span object for each context name.
+ distance(spans, ctx_spans_values)
+ Computes the distance between spans and context values.
+ threshold(distances, threshold)
+ Applies a threshold to the computed distances.
+ mask_to_dict(idx_x, idx_y)
+ Converts mask indices to a dictionary.
+ reduce(mask, reduce_mode)
+ Reduces the mask based on the specified mode.
+ annotate(labels, filtered_spans, ctx_classes, name)
+ Annotates spans with labels based on the reduced mask.
+ __call__(doc: Doc) -> Doc
+ Processes the document, qualifying spans based on their distance to context
+ elements.
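+
+    Examples
+    --------
+    A minimal usage sketch, adapted from the tests. The external information is
+    expected on the document, e.g. filled by a converter with
+    `doc_attributes=["context_dates"]`:
+
+    ```{ .python .no-check }
+    import datetime
+
+    nlp.add_pipe(
+        "eds.external_information_qualifier",
+        config=dict(
+            span_getter="dates",
+            external_information={
+                "lf1": dict(
+                    doc_attr="_.context_dates",
+                    span_attribute="_.date.to_datetime()",
+                    threshold=datetime.timedelta(days=0),
+                )
+            },
+        ),
+    )
+    ```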
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ nlp: PipelineProtocol,
+ name: Optional[str] = "external_information_qualifier",
+ *,
+ span_getter: SpanGetterArg,
+ external_information: Dict[str, ExternalInformation],
+ ):
+ """
+ Initializes the ExternalInformationQualifier component.
+
+ Parameters
+ ----------
+ nlp : PipelineProtocol
+ The spaCy pipeline object.
+ name : Optional[str], default="external_information_qualifier"
+ The name of the component.
+ span_getter : SpanGetterArg
+ The function or callable to get spans from the document.
+        external_information : Dict[str, ExternalInformation]
+            A dictionary where keys are the names of the attributes to set on
+            spans, and values are ExternalInformation objects defining the
+            context and comparison settings.
+ """
+ for key, context in external_information.items():
+ if isinstance(context, dict):
+ external_information[key] = ExternalInformation(**context)
+ self.distant_context = external_information
+
+ super().__init__(nlp, name, span_getter=span_getter)
+
+ def set_extensions(self) -> None:
+ """
+ Sets custom extensions on the Span object for each context name.
+ """
+ for name in self.distant_context.keys():
+ if not Span.has_extension(name):
+ Span.set_extension(name, default=None)
+
+ def distance(self, spans, ctx_spans_values):
+ """
+ Computes the distance between spans and context values.
+
+ Parameters
+ ----------
+ spans : List
+ The list of span attributes.
+ ctx_spans_values : List
+ The list of context values.
+
+ Returns
+ -------
+ np.ndarray
+ The computed distances.
+ """
+ doc_elements = np.array(spans) # shape: N
+ ctx_elements = np.array(ctx_spans_values) # shape: M
+ distances = doc_elements[:, None] - ctx_elements[None, :] # shape: N x M
+
+ return distances
+
+ def threshold(self, distances: np.ndarray, threshold: Union[float, dt.timedelta]):
+ """
+ Applies a threshold to the computed distances.
+
+ Parameters
+ ----------
+ distances : np.ndarray
+ The computed distances.
+ threshold : Union[float, dt.timedelta]
+ The threshold value.
+
+ Returns
+ -------
+ np.ndarray
+ A mask indicating which distances are within the threshold.
+ """
+ mask = np.abs(distances) <= threshold
+ return mask
+
+ def mask_to_dict(self, idx_x: np.ndarray, idx_y: np.ndarray):
+ """
+ Converts mask indices to a dictionary.
+
+ Parameters
+ ----------
+ idx_x : np.ndarray
+ The indices of the spans.
+ idx_y : np.ndarray
+ The indices of the context values.
+
+ Returns
+ -------
+ Dict[int, List[int]]
+ A dictionary mapping span indices to context value indices.
+ """
+ result = {}
+ for x, y in zip(idx_x, idx_y):
+ if x not in result:
+ result[x] = []
+ result[x].append(y)
+ return result
+
+ def reduce(self, mask: np.ndarray, reduce_mode: str):
+ """
+ Reduces the mask based on the specified mode.
+
+ Parameters
+ ----------
+ mask : np.ndarray
+ The mask indicating which distances are within the threshold.
+ reduce_mode : str
+ The mode to use for reducing the mask.
+ One of ["all", "one_only", "closest"].
+
+ Returns
+ -------
+ Dict[int, List[int]]
+ A dictionary mapping span indices to context value indices.
+ """
+ if reduce_mode == "all":
+ idx_x, idx_y = np.nonzero(mask)
+
+ result = self.mask_to_dict(idx_x, idx_y)
+ return result
+ else:
+ raise NotImplementedError
+
+ def annotate(
+ self,
+ labels: Dict[int, List[int]],
+ filtered_spans: List[Span],
+ ctx_classes: List[Union[str, int]],
+ name: str,
+ ):
+ """
+ Annotates spans with labels based on the reduced mask.
+
+ Parameters
+ ----------
+ labels : Dict[int, List[int]]
+ A dictionary mapping span indices to context value indices.
+ filtered_spans : List[Span]
+ The list of filtered spans.
+ ctx_classes : List[Union[str, int]]
+ The list of context classes.
+ name : str
+ The name of the attribute to set.
+ """
+ for key, values in labels.items():
+ span = filtered_spans[key]
+ label_names = [ctx_classes[j] for j in values]
+ span._.set(name, label_names)
+
+ def __call__(self, doc: Doc) -> Doc:
+ """
+ Processes the document, qualifying spans based on their distance
+ to context elements.
+
+ Parameters
+ ----------
+ doc : Doc
+ The spaCy document to process.
+
+ Returns
+ -------
+ Doc
+ The processed document with qualified spans.
+ """
+ for name, context in self.distant_context.items():
+ # Get spans to qualify and their attributes
+ doc_spans = list(get_spans(doc, self.span_getter))
+ binding_getter_span_attr = make_binding_getter(context.span_attribute)
+
+ filtered_spans = [
+ span
+ for span in doc_spans
+ if not pd.isna(binding_getter_span_attr(span))
+ ]
+
+ filtered_spans_attr = [
+ binding_getter_span_attr(span) for span in filtered_spans
+ ]
+
+ # Get context to annotate distantly
+ binding_getter_doc_attr = make_binding_getter(context.doc_attr)
+ context_doc: Optional[List[Dict[str, Any]]] = binding_getter_doc_attr(doc)
+ if isinstance(context_doc, list):
+ ctx_values = [i.get("value") for i in context_doc] # values to look for
+ ctx_classes = [i.get("class") for i in context_doc] # classes to assign
+ if len(ctx_values) > 0:
+ assert isinstance(ctx_values[0], (dt.datetime, dt.date)), (
+ "Values should be datetime objects. Future: add support for"
+ " other types"
+ )
+ else:
+ ctx_values = []
+ ctx_classes = []
+
+ # Compute distance
+ if context.comparison_type == "similarity":
+ distances = self.distance(filtered_spans_attr, ctx_values)
+ mask = self.threshold(distances, context.threshold)
+
+ labels = self.reduce(mask, context.reduce)
+ else:
+ raise NotImplementedError
+
+ # Qualify / Annotate
+ self.annotate(labels, filtered_spans, ctx_classes, name)
+
+ return doc
diff --git a/edsnlp/pipes/qualifiers/external_information/factory.py b/edsnlp/pipes/qualifiers/external_information/factory.py
new file mode 100644
index 0000000000..72ff409ee2
--- /dev/null
+++ b/edsnlp/pipes/qualifiers/external_information/factory.py
@@ -0,0 +1,9 @@
+from edsnlp.core import registry
+from edsnlp.pipes.qualifiers.external_information.external_information import (
+ ExternalInformationQualifier,
+)
+
+create_component = registry.factory.register(
+ "eds.external_information_qualifier",
+ assigns=["doc.spans"],
+)(ExternalInformationQualifier)
diff --git a/edsnlp/pipes/trainable/span_classifier/span_classifier.py b/edsnlp/pipes/trainable/span_classifier/span_classifier.py
index 3a4b0a4b6f..a597cf9fca 100644
--- a/edsnlp/pipes/trainable/span_classifier/span_classifier.py
+++ b/edsnlp/pipes/trainable/span_classifier/span_classifier.py
@@ -343,20 +343,36 @@ def set_extensions(self):
if not Span.has_extension(qlf):
Span.set_extension(qlf, default=None)
- def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
+ def post_init(
+ self,
+ gold_data: Iterable[Doc],
+ exclude: Set[str],
+ ):
super().post_init(gold_data, exclude=exclude)
bindings = [
(qlf, labels, dict.fromkeys(vals)) for qlf, labels, vals in self.bindings
]
+        soft_label_attrs = set()
+
for doc in gold_data:
spans = list(get_spans(doc, self.span_getter))
for span in spans:
                for attr, labels, values in bindings:
if labels is True or span.label_ in labels:
value = BINDING_GETTERS[attr](span)
if value is not None or self.keep_none:
- values[value] = None
+                            if isinstance(value, dict):
+                                # Soft labels come as {value: prob} dicts:
+                                # register every candidate value and remember
+                                # that this attribute needs vector targets
+                                soft_label_attrs.add(attr)
+                                for k in value.keys():
+                                    values[k] = None
+                            else:
+                                values[value] = None
+
+        # An attribute needs probability-vector targets as soon as one gold
+        # value is a dict of soft labels; otherwise a single class index is
+        # enough.
+        self.binding_target_shape = {
+            attr: len(values) if attr in soft_label_attrs else 1
+            for attr, labels, values in bindings
+        }
bindings = [
(attr, labels, sorted(values, key=str)) for attr, labels, values in bindings
@@ -368,7 +384,7 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
f"Attribute {attr} for labels {labels} should have at "
f"least 2 values but found {len(values)}: {values}."
)
-
+        self.exist_soft_labels = (
+            max(self.binding_target_shape.values(), default=1) > 1
+        )
self.update_bindings(bindings)
def update_bindings(self, bindings: List[Tuple[str, SpanFilter, List[Any]]]):
@@ -454,18 +470,77 @@ def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]:
def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
preps = self.preprocess(doc)
- return {
- **preps,
- "targets": [
- [
- values_to_idx.get(BINDING_GETTERS[qlf](span), -100)
- if labels is True or span.label_ in labels
- else -100
- for qlf, labels, values_to_idx in self.bindings_to_idx
- ]
- for span in preps["$spans"]
- ],
- }
+ targets = []
+ for span in preps["$spans"]:
+ span_targets = []
+ for qlf, labels, values_to_idx in self.bindings_to_idx:
+ if labels is True or span.label_ in labels:
+ value = BINDING_GETTERS[qlf](span)
+                    if isinstance(value, dict):
+                        # Probabilities dict: convert it to a vector following
+                        # the order of values_to_idx. A value missing from the
+                        # dict yields -100, which excludes the span from the
+                        # loss mask in `forward`.
+                        prob_vec = [
+                            float(value.get(val, -100)) for val in values_to_idx
+                        ]
+                        span_targets.append(prob_vec)
+ else:
+ idx = values_to_idx.get(value, -100)
+ if self.exist_soft_labels:
+ if idx != -100:
+ target = F.one_hot(
+ torch.tensor(idx),
+ num_classes=len(values_to_idx),
+ ).tolist()
+ else:
+ target = [idx] * len(values_to_idx)
+ else:
+ target = idx
+ span_targets.append(target)
+ else:
+ if self.exist_soft_labels:
+ ignore_value = [-100] * len(values_to_idx)
+ span_targets.append(ignore_value)
+ else:
+ span_targets.append(-100)
+ targets.append(span_targets)
+ return {**preps, "targets": targets}
+
def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanClassifierBatchInput:
collated: SpanClassifierBatchInput = {
@@ -474,15 +549,22 @@ def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanClassifierBatchInput:
if "targets" in batch:
targets = ft.as_folded_tensor(
batch["targets"],
- dtype=torch.long,
+ dtype=torch.float if self.exist_soft_labels else torch.long,
full_names=("sample", "span", "group"),
data_dims=("span", "group"),
).as_tensor()
- collated["targets"] = targets.view(len(targets), len(self.bindings))
+ if self.exist_soft_labels:
+ collated["targets"] = targets
+ else:
+ collated["targets"] = targets.view(len(targets), len(self.bindings))
+
return collated
# noinspection SpellCheckingInspection
- def forward(self, batch: SpanClassifierBatchInput) -> BatchOutput:
+ def forward(
+ self,
+ batch: SpanClassifierBatchInput,
+ ) -> BatchOutput:
"""
Apply the span classifier module to the document embeddings and given spans to:
- compute the loss
@@ -514,10 +596,14 @@ def forward(self, batch: SpanClassifierBatchInput) -> BatchOutput:
# - `negated=False` and `negated=True`
for group_idx, bindings_indexer in enumerate(self.bindings_indexers):
if "targets" in batch:
+                if self.exist_soft_labels:
+                    mask = torch.all(
+                        batch["targets"][:, group_idx] != -100, dim=1
+                    )
+                else:
+                    mask = batch["targets"][:, group_idx] != -100
losses.append(
F.cross_entropy(
- binding_scores[:, bindings_indexer],
- batch["targets"][:, group_idx],
+ binding_scores[mask, bindings_indexer],
+ batch["targets"][mask, group_idx],
reduction="sum",
weight=torch.tensor(self.label_weights, dtype=torch.float)[
bindings_indexer
diff --git a/mkdocs.yml b/mkdocs.yml
index 8cc5cbaaf8..1db5551aa7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -56,6 +56,7 @@ nav:
- tutorials/training-ner.md
- tutorials/training-span-classifier.md
- tutorials/tuning.md
+ - advanced-tutorials/distant_annotation.md
- Pipes:
- Overview: pipes/index.md
- Core Pipelines:
@@ -73,6 +74,8 @@ nav:
- pipes/qualifiers/hypothesis.md
- pipes/qualifiers/reported-speech.md
- pipes/qualifiers/history.md
+ - pipes/qualifiers/contextual.md
+ - pipes/qualifiers/external_information.md
- Miscellaneous:
- pipes/misc/index.md
- pipes/misc/dates.md
@@ -179,6 +182,16 @@ extra:
extra_css:
- assets/stylesheets/extra.css
+ - assets/stylesheets/cards.css
+ - assets/termynal/termynal.css
+
+extra_javascript:
+ - https://cdn.jsdelivr.net/npm/vega@5
+ - https://cdn.jsdelivr.net/npm/vega-lite@5
+ - https://cdn.jsdelivr.net/npm/vega-embed@6
+ - assets/termynal/termynal.js
+ - https://polyfill.io/v3/polyfill.min.js?features=es6
+ - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
watch:
- contributing.md
@@ -256,9 +269,8 @@ markdown_extensions:
slugify: !!python/object/apply:pymdownx.slugs.slugify
kwds:
case: lower
- #- pymdownx.arithmatex:
- # generic: true
- - markdown_grid_tables
+ - pymdownx.arithmatex:
+ generic: true
- footnotes
- md_in_html
- attr_list
diff --git a/pyproject.toml b/pyproject.toml
index 1ccd9df4bf..76e4f2b0eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -178,6 +178,8 @@ where = ["."]
"eds.hypothesis" = "edsnlp.pipes.qualifiers.hypothesis.factory:create_component"
"eds.negation" = "edsnlp.pipes.qualifiers.negation.factory:create_component"
"eds.reported_speech" = "edsnlp.pipes.qualifiers.reported_speech.factory:create_component"
+"eds.contextual_qualifier" = "edsnlp.pipes.qualifiers.contextual.factory:create_component"
+"eds.external_information_qualifier" = "edsnlp.pipes.qualifiers.external_information.factory:create_component"
# Misc
"eds.consultation_dates" = "edsnlp.pipes.misc.consultation_dates.factory:create_component"
diff --git a/tests/conftest.py b/tests/conftest.py
index 29e9944042..a9ac031dda 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,6 +37,9 @@ def pytest_collection_modifyitems(items):
items[:] = first_tests + last_tests
+EDS_SENTENCES_PIPE = "eds.sentences"
+
+
@fixture(scope="session", params=["eds", "fr"])
def lang(request):
return request.param
@@ -58,13 +61,23 @@ def blank_nlp(lang):
model = spacy.blank("eds")
else:
model = edsnlp.blank("fr")
- model.add_pipe("eds.sentences")
+ model.add_pipe(EDS_SENTENCES_PIPE)
+ return model
+
+
+@fixture
+def edsnlp_blank_nlp(lang):
+ if lang == "eds":
+ model = edsnlp.blank("eds")
+ else:
+ model = edsnlp.blank("fr")
+ model.add_pipe(EDS_SENTENCES_PIPE)
return model
def make_ml_pipeline():
nlp = edsnlp.blank("eds")
- nlp.add_pipe("eds.sentences", name="sentences")
+ nlp.add_pipe(EDS_SENTENCES_PIPE, name="sentences")
nlp.add_pipe(
"eds.transformer",
name="transformer",
diff --git a/tests/pipelines/qualifiers/test_contextual.py b/tests/pipelines/qualifiers/test_contextual.py
new file mode 100644
index 0000000000..2b79d63d3e
--- /dev/null
+++ b/tests/pipelines/qualifiers/test_contextual.py
@@ -0,0 +1,68 @@
+from edsnlp.utils.examples import Entity, parse_example
+
+text = """
+RCP du 18/12/2024 : DUPONT Jean
+
+Homme de 68 ans adressé en consultation d’oncologie pour prise en charge d’une tumeur du colon.
+Antécédents : HTA, diabète de type 2, dyslipidémie, tabagisme actif (30 PA), alcoolisme chronique (60 g/jour).
+
+Examen clinique : patient en bon état général, poids 80 kg, taille 1m75.
+
+
+HISTOIRE DE LA MALADIE :
+Lors du PET-CT (14/02/2024), des dépôts pathologiques ont été observés qui coïncidaient avec les résultats du scanner.
+Le 15/02/2024, une IRM a été réalisée pour évaluer l’extension de la tumeur.
+Une colonoscopie a été réalisée le 17/02/2024 avec une biopsie d'adénopathie sous-carinale.
+Une deuxième a été biopsié le 18/02/2024. Les résultats de la biopsie ont confirmé un adénocarcinome du colon.
+Il a été opéré le 20/02/2024. L’examen anatomopathologique de la pièce opératoire a confirmé un adénocarcinome du colon stade IV avec métastases hépatiques et pulmonaires.
+Trois mois après la fin du traitement de chimiothérapie (avril 2024), le patient a signalé une aggravation progressive des symptômes
+
+CONCLUSION : Adénocarcinome du colon stade IV avec métastases hépatiques et pulmonaires.
+""" # noqa: E501
+
+examples = [
+ text,
+]
+
+
+def test_contextual_qualifier(edsnlp_blank_nlp):
+ edsnlp_blank_nlp.add_pipe("eds.dates")
+
+ edsnlp_blank_nlp.add_pipe(
+ "eds.contextual_qualifier",
+ config=dict(
+ span_getter="dates",
+ patterns={
+ "lf1": {
+ "Magnetic resonance imaging (procedure)": {
+ "terms": {"irm": ["IRM", "imagerie par résonance magnétique"]},
+ "regex": None,
+ "context_words": 0,
+ "context_sents": 1,
+ "attr": "TEXT",
+ }
+ },
+ "lf2": {
+ "Biopsy (procedure)": {
+ "regex": {"biopsy": ["biopsie", "biopsié"]},
+ "context_words": (10, 10),
+ "context_sents": 0,
+ "attr": "TEXT",
+ }
+ },
+ },
+ ),
+ )
+
+ for text, entities in map(parse_example, examples):
+ doc = edsnlp_blank_nlp(text)
+
+ dates = doc.spans["dates"]
+
+ assert len(dates) == len(entities)
+
+ for ent, entity in zip(dates, entities):
+ entity: Entity
+ assert ent.text == text[entity.start_char : entity.end_char]
+ assert ent._.lf1 == entity.modifiers_dict.get("lf1")
+ assert ent._.lf2 == entity.modifiers_dict.get("lf2")
diff --git a/tests/pipelines/qualifiers/test_external_information.py b/tests/pipelines/qualifiers/test_external_information.py
new file mode 100644
index 0000000000..f99572d8e7
--- /dev/null
+++ b/tests/pipelines/qualifiers/test_external_information.py
@@ -0,0 +1,94 @@
+import datetime
+
+import pandas as pd
+
+import edsnlp
+from edsnlp.utils.examples import Entity, parse_example
+
+text = """
+RCP du 18/12/2024 : DUPONT Jean
+
+Homme de 68 ans adressé en consultation d’oncologie pour prise en charge d’une tumeur du colon.
+Antécédents : HTA, diabète de type 2, dyslipidémie, tabagisme actif (30 PA), alcoolisme chronique (60 g/jour).
+
+Examen clinique : patient en bon état général, poids 80 kg, taille 1m75.
+
+
+HISTOIRE DE LA MALADIE :
+Lors du PET-CT (14/02/2024), des dépôts pathologiques ont été observés qui coïncidaient avec les résultats du scanner.
+Le 15/02/2024, une IRM a été réalisée pour évaluer l’extension de la tumeur.
+Une colonoscopie a été réalisée le 17/02/2024 avec une biopsie d'adénopathie sous-carinale.
+Une deuxième a été biopsié le 18/02/2024. Les résultats de la biopsie ont confirmé un adénocarcinome du colon.
+Il a été opéré le 20/02/2024. L’examen anatomopathologique de la pièce opératoire a confirmé un adénocarcinome du colon stade IV avec métastases hépatiques et pulmonaires.
+Trois mois après la fin du traitement de chimiothérapie (avril 2024), le patient a signalé une aggravation progressive des symptômes
+
+CONCLUSION : Adénocarcinome du colon stade IV avec métastases hépatiques et pulmonaires.
+""" # noqa: E501
+
+# Create context dates
+# The elements under this attribute should be a list of dicts with keys value and class
+context_dates = [
+ {
+ "value": datetime.datetime(2024, 2, 15),
+ "class": "Magnetic resonance imaging (procedure)",
+ },
+ {"value": datetime.datetime(2024, 2, 17), "class": "Biopsy (procedure)"},
+ {"value": datetime.datetime(2024, 2, 17), "class": "Colonoscopy (procedure)"},
+]
+
+examples = [
+ (text, context_dates),
+]
+
+
+def test_external_information_qualifier(edsnlp_blank_nlp):
+ edsnlp_blank_nlp.add_pipe("eds.dates")
+
+ edsnlp_blank_nlp.add_pipe(
+ "eds.external_information_qualifier",
+ config=dict(
+ span_getter="dates",
+ external_information={
+ "lf1": dict(
+ doc_attr="_.context_dates",
+ span_attribute="_.date.to_datetime()",
+ threshold=datetime.timedelta(days=0),
+ )
+ },
+ ),
+ )
+
+    texts = [t for t, _ in examples]
+    contexts = [c for _, c in examples]
+    for (text, entities), context in zip(map(parse_example, texts), contexts):
+ # Create a dataframe
+ df = pd.DataFrame.from_records(
+ [
+ {
+ "person_id": 1,
+ "note_id": 1,
+ "note_text": text,
+ "context_dates": context,
+ }
+ ]
+ )
+
+ doc_iterator = edsnlp.data.from_pandas(
+ df, converter="omop", doc_attributes=["context_dates"]
+ )
+
+ docs = list(edsnlp_blank_nlp.pipe(doc_iterator))
+
+ doc = docs[0]
+
+ dates = doc.spans["dates"]
+
+ assert len(dates) == len(entities)
+
+ for ent, entity in zip(dates, entities):
+ entity: Entity
+ assert ent.text == text[entity.start_char : entity.end_char]
+ for e in entity.modifiers:
+ value = e.value
+ if value:
+ assert value in ent._.lf1
diff --git a/tests/pipelines/trainable/test_span_qualifier.py b/tests/pipelines/trainable/test_span_qualifier.py
index 66a75abc65..2a02f1ce56 100644
--- a/tests/pipelines/trainable/test_span_qualifier.py
+++ b/tests/pipelines/trainable/test_span_qualifier.py
@@ -27,7 +27,7 @@ def gold():
Span(doc1, 0, 1, "event"), # criteria = "si"
Span(doc1, 3, 4, "criteria"),
]
- doc1.spans["sc"][0]._.test_negated = False
+ doc1.spans["sc"][0]._.test_negated = {False: 0.6, True: 0.4}
doc1.spans["sc"][1]._.test_negated = True
doc1.spans["sc"][2]._.test_negated = False
doc1.spans["sc"][1]._.event_type = "stop"
diff --git a/tests/training/dataset2.jsonl b/tests/training/dataset2.jsonl
new file mode 100644
index 0000000000..01f313722b
--- /dev/null
+++ b/tests/training/dataset2.jsonl
@@ -0,0 +1,2 @@
+{"note_id": "1", "note_text": "Pas de cancer chez le patient ou sa famille.\nOn trouve un nodule superieur centimétrique droit évocateur de fibroanédome.", "entities": [{"start": 7, "end": 13, "label": "sosy", "negation": {"1":0.7,"0":0.3}}, {"start": 58, "end": 64, "label": "sosy", "negation": {"1":0.2,"0":0.8},"speciality":1}, {"start": 75, "end": 88, "label": "measure", "unit": "cm","speciality":0}, {"start": 108, "end": 120, "label": "sosy", "negation": {"1":0.2,"0":0.8},"speciality":1}]}
+{"note_id": "2", "note_text": "La patiente a un gros rhume, sans fièvre ou douleur thoracique. Elle fait 30 kg.", "entities": [{"start": 22, "end": 27, "label": "sosy", "negation": {"1":0.2,"0":0.8}}, {"start": 34, "end": 40, "label": "sosy", "negation": {"1":0.7,"0":0.3}}, {"start": 44, "end": 62, "label": "sosy", "negation": {"1":0.7,"0":0.3},"speciality":0}, {"start": 74, "end": 79, "label": "measure", "unit": "kg","speciality":1}]}
diff --git a/tests/training/qlf_config.yml b/tests/training/qlf_config.yml
index afd6da65e2..491c17f4a2 100644
--- a/tests/training/qlf_config.yml
+++ b/tests/training/qlf_config.yml
@@ -16,7 +16,8 @@ nlp:
qualifier:
'@factory': eds.span_classifier
- attributes: { "_.negation": [ "sosy" ] }
+ attributes: { "_.negation": [ "sosy" ], "speciality":True }
+ # attributes: ["_.negation"]
span_getter: ["ents", "gold_spans"]
context_getter: { '@misc': eds.span_context_getter, "context_words": 30, "context_sents": 1 }
@@ -63,7 +64,7 @@ scorer:
train_data:
data:
"@readers": json
- path: ./dataset.jsonl
+ path: ./dataset2.jsonl
converter:
- '@factory': 'myproject.custom_dict2doc'
span_setter : 'gold_spans'
@@ -80,17 +81,17 @@ train_data:
# regex: '\\n{2,}'
shuffle: dataset
batch_size: 4 docs
- pipe_names: [ "qualifier" ]
+ pipe_names: []
sub_batch_size: 10 words
val_data:
"@readers": json
- path: ./dataset.jsonl
+ path: ./dataset2.jsonl
converter:
- '@factory': myproject.custom_dict2doc
span_setter : 'gold_spans'
span_attributes : ['negation']
- bool_attributes : ['negation'] # default standoff to doc converter
+ bool_attributes : [] # default standoff to doc converter
# 🚀 TRAIN SCRIPT OPTIONS
train: