Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add spacy NER, fix some UI. #425

Merged
merged 7 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/router_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .signals.concept_labels import ConceptLabelsSignal
from .signals.concept_scorer import ConceptScoreSignal
from .signals.default_signals import register_default_signals
from .signals.semantic_similarity import SemanticSimilaritySignal
from .signals.signal import (
Signal,
TextEmbeddingModelSignal,
Expand Down Expand Up @@ -202,7 +203,8 @@ class ListFilter(BaseModel):
Filter = Union[BinaryFilter, UnaryFilter, ListFilter]

AllSignalTypes = Union[ConceptScoreSignal, ConceptLabelsSignal, SubstringSignal,
TextEmbeddingModelSignal, TextEmbeddingSignal, TextSignal, Signal]
SemanticSimilaritySignal, TextEmbeddingModelSignal, TextEmbeddingSignal,
TextSignal, Signal]


# We override the `Column` class so we can explicitly add all signal types for a better OpenAPI spec.
Expand Down
2 changes: 2 additions & 0 deletions src/signals/default_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ..embeddings.cohere import Cohere
from ..embeddings.sbert import SBERT
from .concept_scorer import ConceptScoreSignal
from .ner import SpacyNER
from .pii import PIISignal
from .signal import register_signal
from .text_statistics import TextStatisticsSignal
Expand All @@ -15,6 +16,7 @@ def register_default_signals() -> None:
# Text.
register_signal(PIISignal)
register_signal(TextStatisticsSignal)
register_signal(SpacyNER)

# Embeddings.
register_signal(Cohere)
Expand Down
50 changes: 50 additions & 0 deletions src/signals/ner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Compute named entity recognition with SpaCy."""
from typing import Iterable, Optional

import spacy
from pydantic import Field as PydanticField
from typing_extensions import override

from ..data.dataset_utils import lilac_span
from ..schema import Field, Item, RichData, SignalInputType, field
from .signal import TextSignal


class SpacyNER(TextSignal):
"""Named entity recognition with SpaCy.

For details see: [spacy.io/models](https://spacy.io/models).
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can remove noqa if you end the first line with a period

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

name = 'spacy_ner'
display_name = 'Named Entity Recognition'

model: Optional[str] = PydanticField(
title='SpaCy package name or model path.', default='en_core_web_sm')

input_type = SignalInputType.TEXT
compute_type = SignalInputType.TEXT

_nlp: spacy.language.Language

@override
def setup(self) -> None:
  """Load the SpaCy pipeline named by `self.model` and store it on `self._nlp`.

  The `model` field is honored so callers can select a different SpaCy package or model path.
  When it is unset, 'en_core_web_sm' is loaded, which is the model this signal always used.
  """
  self._nlp = spacy.load(
    self.model or 'en_core_web_sm',
    # Disable everything except the NER component. See: https://spacy.io/models
    # NOTE(review): these component names match the standard English pipelines; confirm they
    # are appropriate for any other model passed in via `model`.
    disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])

@override
def fields(self) -> Field:
  """Return the output schema: a 'string_span' field with a string 'label' sub-field."""
  entity_span = field('string_span', fields={'label': 'string'})
  return field(fields=[entity_span])

@override
def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
  """Yield, for each input row, the labeled entity spans found by the pipeline.

  Rows that are not `str` are fed to the pipeline as the empty string. A row for which the
  pipeline reports no entities yields `None` instead of an empty list.
  """
  texts = ('' if not isinstance(row, str) else row for row in data)

  for doc in self._nlp.pipe(texts):
    spans = [
      lilac_span(entity.start_char, entity.end_char, {'label': entity.label_})
      for entity in doc.ents
    ]
    # An empty span list is reported as None.
    yield spans or None
40 changes: 40 additions & 0 deletions src/signals/ner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Test the Spacy NER signal."""

from ..schema import field
from .ner import SpacyNER
from .splitters.text_splitter_test_utils import text_to_expected_spans


def test_spacy_ner_fields() -> None:
  """The signal's schema is a 'string_span' field carrying a string 'label'."""
  ner = SpacyNER()
  ner.setup()
  expected_schema = field(fields=[field('string_span', fields={'label': 'string'})])
  assert ner.fields() == expected_schema


def test_ner() -> None:
  """The signal labels the money and date mentions in a short financial snippet."""
  ner = SpacyNER()
  ner.setup()

  text = ('Net income was $9.4 million compared to the prior year of $2.7 million.'
          'Revenue exceeded twelve billion dollars, with a loss of $1b.')
  entities = list(ner.compute([text]))

  # (mention, entity label) pairs, in the order they appear in `text`.
  labeled_mentions = [
    ('$9.4 million', 'MONEY'),
    ('the prior year', 'DATE'),
    ('$2.7 million', 'MONEY'),
    ('twelve billion dollars', 'MONEY'),
    ('1b', 'MONEY'),
  ]
  expected_spans = text_to_expected_spans(
    text, [(mention, {'label': label}) for mention, label in labeled_mentions])

  assert entities == [expected_spans]
14 changes: 11 additions & 3 deletions src/signals/splitters/text_splitter_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utilities for testing text splitters."""

from typing import Optional
from typing import Optional, Union

from ...data.dataset_utils import lilac_span
from ...schema import TEXT_SPAN_END_FEATURE, TEXT_SPAN_START_FEATURE, VALUE_KEY, Item
Expand All @@ -16,14 +16,22 @@ def spans_to_text(text: str, spans: Optional[list[Item]]) -> list[str]:
]


def text_to_expected_spans(
    text: str, splits: Union[list[str], list[tuple[str, Item]]]) -> list[Item]:
  """Convert text and a list of splits to a list of expected spans.

  Each split is either a substring of `text` or a `(substring, item)` tuple; the item is passed
  to `lilac_span` as the span's metadata (an empty dict for plain strings). Splits are located
  left to right, each search starting where the previous match ended.

  Raises:
    ValueError: If a split is neither a string nor a tuple, or if it cannot be found in `text`
      at or after the end of the previous split.
  """
  start_offset = 0
  expected_spans: list[Item] = []
  for split in splits:
    item: Item
    if isinstance(split, str):
      split_text, item = split, {}
    elif isinstance(split, tuple):
      split_text, item = split
    else:
      raise ValueError('Split should be a string or a tuple of (string, item dict).')
    start = text.find(split_text, start_offset)
    if start < 0:
      # `str.find` signals "not found" with -1; fail loudly instead of emitting a bogus span.
      raise ValueError(
        f'Split {split_text!r} was not found in the text at or after offset {start_offset}.')
    end = start + len(split_text)
    expected_spans.append(lilac_span(start=start, end=end, metadata=item))
    start_offset = end

  return expected_spans
9 changes: 5 additions & 4 deletions src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def progress(it: Iterable[TProgress],
task_step_id: Optional[TaskStepId],
estimated_len: Optional[int],
step_description: Optional[str] = None,
emit_every_frac: float = .01) -> Iterable[TProgress]:
emit_every_s: float = 1.) -> Iterable[TProgress]:
"""An iterable wrapper that emits progress and yields the original iterable."""
if not task_step_id:
yield from it
Expand All @@ -223,15 +223,15 @@ def progress(it: Iterable[TProgress],

estimated_len = max(1, estimated_len) if estimated_len else None

emit_every = max(1, int(estimated_len * emit_every_frac)) if estimated_len else None

task_info: TaskInfo = get_worker().state.tasks[task_id].annotations['task_info']

it_idx = 0
start_time = time.time()
last_emit = time.time() - emit_every_s
with tqdm(it, desc=task_info.name, total=estimated_len) as tq:
for t in tq:
if estimated_len and emit_every and it_idx % emit_every == 0:
cur_time = time.time()
if estimated_len and cur_time - last_emit > emit_every_s:
it_per_sec = tq.format_dict['rate'] or 0.0
set_worker_task_progress(
task_step_id=task_step_id,
Expand All @@ -240,6 +240,7 @@ def progress(it: Iterable[TProgress],
it_per_sec=it_per_sec or 0.0,
estimated_total_sec=((estimated_len) / it_per_sec if it_per_sec else 0),
estimated_len=estimated_len)
last_emit = cur_time
yield t
it_idx += 1

Expand Down
59 changes: 59 additions & 0 deletions web/blueprint/src/lib/components/datasetView/SpanClick.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import type {DatasetState} from '$lib/stores/datasetStore';
import type {DatasetViewStore} from '$lib/stores/datasetViewStore';
import type {SignalInfoWithTypedSchema} from '$lilac';
import type {SvelteComponent} from 'svelte';
import type {Readable} from 'svelte/store';
import type {SpanDetails} from './StringSpanDetails.svelte';
import StringSpanDetails from './StringSpanDetails.svelte';

/** Parameters consumed by `spanClick` to open a `StringSpanDetails` popup for a clicked span. */
export interface SpanClickInfo {
// Invoked when the span is clicked, and again on `update` while a popup is open.
details: () => SpanDetails;
// The remaining members are forwarded unchanged to the popup as props of the same name.
datasetViewStore: DatasetViewStore;
datasetStore: Readable<DatasetState>;
embeddings: SignalInfoWithTypedSchema[];
// NOTE(review): `StringSpanDetails` calls this as (conceptNamespace, conceptName, text, label),
// i.e. with the first two arguments in the opposite order to the parameter names declared
// here. Confirm which order the supplied implementation expects.
addConceptLabel: (
conceptName: string,
conceptNamespace: string,
text: string,
label: boolean
) => void;
}

/**
 * Opens a `StringSpanDetails` popup, mounted on `document.body`, when `element` is clicked.
 *
 * Returns `update`/`destroy` hooks (the shape Svelte's `use:` actions expect). `update` swaps
 * in new click info and refreshes the details of an open popup; `destroy` removes the click
 * listener and tears down any open popup.
 *
 * At most one popup exists per element: a click while a popup is open destroys the old popup
 * before creating the new one, so repeated clicks cannot leave orphaned components behind.
 */
export function spanClick(element: HTMLSpanElement, clickInfo: SpanClickInfo) {
  let spanDetailsComponent: SvelteComponent | undefined;
  let curClickInfo = clickInfo;
  // Register the named handler (not an anonymous wrapper) so `destroy` can remove it again.
  element.addEventListener('click', showClickDetails);

  function showClickDetails(e: MouseEvent) {
    // Drop any popup left over from a previous click before mounting a new one.
    destroyClickInfo();
    spanDetailsComponent = new StringSpanDetails({
      props: {
        details: curClickInfo.details(),
        clickPosition: {x: e.clientX, y: e.clientY},
        datasetViewStore: curClickInfo.datasetViewStore,
        datasetStore: curClickInfo.datasetStore,
        embeddings: curClickInfo.embeddings,
        addConceptLabel: curClickInfo.addConceptLabel
      },
      target: document.body
    });
    // The popup is dismissed both when it asks to close and after a label button is clicked.
    spanDetailsComponent.$on('close', destroyClickInfo);
    spanDetailsComponent.$on('click', destroyClickInfo);
  }

  function destroyClickInfo() {
    spanDetailsComponent?.$destroy();
    spanDetailsComponent = undefined;
  }

  return {
    update(clickInfo: SpanClickInfo) {
      curClickInfo = clickInfo;

      // Only recomputes the details when a popup is currently open.
      spanDetailsComponent?.$set({
        details: curClickInfo.details()
      });
    },
    destroy() {
      element.removeEventListener('click', showClickDetails);
      destroyClickInfo();
    }
  };
}
10 changes: 10 additions & 0 deletions web/blueprint/src/lib/components/datasetView/SpanHover.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import {deserializePath, pathIsMatching} from '$lilac';
import type {SvelteComponent} from 'svelte';
import SpanHoverTooltip, {type SpanHoverNamedValue} from './SpanHoverTooltip.svelte';

/** Parameters consumed by `spanHover` to decide whether and what to show for a hovered span. */
export interface SpanHoverInfo {
// Candidate values; `spanHover` keeps only those whose `spanPath` matches a hovered span path.
namedValues: SpanHoverNamedValue[];
// Serialized paths of the spans currently hovered (deserialized before matching).
spansHovered: string[];
// `spanHover` does nothing further when this is false.
isHovered: boolean;
// When non-null, `spanHover` attaches a scroll listener to this container.
itemScrollContainer: HTMLDivElement | null;
}
Expand All @@ -15,6 +17,14 @@ export function spanHover(element: HTMLSpanElement, spanHoverInfo: SpanHoverInfo
if (!curSpanHoverInfo.isHovered) {
return;
}
curSpanHoverInfo.namedValues = spanHoverInfo.namedValues.filter(namedValue =>
curSpanHoverInfo.spansHovered.some(path =>
pathIsMatching(deserializePath(namedValue.spanPath), deserializePath(path))
)
);
if (curSpanHoverInfo.namedValues.length === 0) {
return;
}
if (curSpanHoverInfo.itemScrollContainer != null) {
curSpanHoverInfo.itemScrollContainer.addEventListener('scroll', itemScrollListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
/** One name/value row rendered in the span hover tooltip. */
export interface SpanHoverNamedValue {
name: string;
value: DataTypeCasted;
// Serialized path of the span this value belongs to; `spanHover` matches it against hovered spans.
spanPath: string;
// Numeric values flagged as concept or semantic-search scores get a score-based background color.
isConcept?: boolean;
// Keyword-search and non-numeric-metadata values are rendered in bold.
isKeywordSearch?: boolean;
isSemanticSearch?: boolean;
isNonNumericMetadata?: boolean;
}
</script>

Expand Down Expand Up @@ -37,10 +40,11 @@
<div class="named-value-name table-cell max-w-xs truncate pr-2">{namedValue.name}</div>
<div class="table-cell rounded text-right">
<span
style:background-color={namedValue.isConcept && typeof namedValue.value === 'number'
style:background-color={(namedValue.isConcept || namedValue.isSemanticSearch) &&
typeof namedValue.value === 'number'
? colorFromScore(namedValue.value)
: ''}
class:font-bold={namedValue.isKeywordSearch}
class:font-bold={namedValue.isKeywordSearch || namedValue.isNonNumericMetadata}
class="px-1"
>
{typeof namedValue.value === 'number'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,50 @@
<script lang="ts">
import {fade} from 'svelte/transition';

import {editConceptMutation} from '$lib/queries/conceptQueries';
import {queryEmbeddings} from '$lib/queries/signalQueries';
import {getDatasetContext} from '$lib/stores/datasetStore';
import {getDatasetViewContext} from '$lib/stores/datasetViewStore';
import type {DatasetState} from '$lib/stores/datasetStore';
import type {DatasetViewStore} from '$lib/stores/datasetViewStore';
import {getComputedEmbeddings, getSearchEmbedding, getSearchPath} from '$lib/view_utils';
import {serializePath} from '$lilac';
import {serializePath, type SignalInfoWithTypedSchema} from '$lilac';
import {Button} from 'carbon-components-svelte';
import ThumbsDownFilled from 'carbon-icons-svelte/lib/ThumbsDownFilled.svelte';
import ThumbsUpFilled from 'carbon-icons-svelte/lib/ThumbsUpFilled.svelte';
import {createEventDispatcher} from 'svelte';
import type {Readable} from 'svelte/store';
import {clickOutside} from '../common/clickOutside';
import EmbeddingBadge from './EmbeddingBadge.svelte';

export let details: SpanDetails;
// The coordinates of the click so we can position the popup next to the cursor.
export let clickPosition: {x: number; y: number} | undefined;

let datasetViewStore = getDatasetViewContext();
let datasetStore = getDatasetContext();
export let datasetViewStore: DatasetViewStore;
export let datasetStore: Readable<DatasetState>;

export let embeddings: SignalInfoWithTypedSchema[];

// We can't create mutations from this component since it is hoisted, so we pass the function in.
export let addConceptLabel: (
conceptName: string,
conceptNamespace: string,
text: string,
label: boolean
) => void;

$: searchPath = getSearchPath($datasetViewStore, $datasetStore);

const conceptEdit = editConceptMutation();
const dispatch = createEventDispatcher();
function addLabel(label: boolean) {
if (!details.conceptName || !details.conceptNamespace)
throw Error('Label could not be added, no active concept.');
$conceptEdit.mutate([
details.conceptNamespace,
details.conceptName,
{insert: [{text: details.text, label}]}
]);
addConceptLabel(details.conceptNamespace, details.conceptName, details.text, label);
dispatch('click');
}

// Get the embeddings.
const embeddings = queryEmbeddings();

$: searchEmbedding = getSearchEmbedding(
$datasetViewStore,
$datasetStore,
searchPath,
($embeddings.data || []).map(e => e.name)
embeddings.map(e => e.name)
);

$: computedEmbeddings = getComputedEmbeddings($datasetStore, searchPath);
Expand Down
Loading
Loading