Skip to content

Commit

Permalink
Add spacy NER, fix some UI.
Browse files Browse the repository at this point in the history
  • Loading branch information
nsthorat committed Jul 6, 2023
1 parent ee2a654 commit acfbe6d
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 93 deletions.
5 changes: 4 additions & 1 deletion src/router_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from fastapi.responses import ORJSONResponse
from pydantic import BaseModel, validator

from .signals.semantic_similarity import SemanticSimilaritySignal

from .config import data_path
from .data.dataset import BinaryOp
from .data.dataset import Column as DBColumn
Expand Down Expand Up @@ -202,7 +204,8 @@ class ListFilter(BaseModel):
Filter = Union[BinaryFilter, UnaryFilter, ListFilter]

AllSignalTypes = Union[ConceptScoreSignal, ConceptLabelsSignal, SubstringSignal,
TextEmbeddingModelSignal, TextEmbeddingSignal, TextSignal, Signal]
SemanticSimilaritySignal, TextEmbeddingModelSignal, TextEmbeddingSignal,
TextSignal, Signal]


# We override the `Column` class so we can add explicitly all signal types for better OpenAPI spec.
Expand Down
2 changes: 2 additions & 0 deletions src/signals/default_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ..embeddings.cohere import Cohere
from ..embeddings.sbert import SBERT
from .concept_scorer import ConceptScoreSignal
from .ner import SpacyNER
from .pii import PIISignal
from .signal import register_signal
from .text_statistics import TextStatisticsSignal
Expand All @@ -15,6 +16,7 @@ def register_default_signals() -> None:
# Text.
register_signal(PIISignal)
register_signal(TextStatisticsSignal)
register_signal(SpacyNER)

# Embeddings.
register_signal(Cohere)
Expand Down
60 changes: 60 additions & 0 deletions src/signals/ner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Compute text statistics for a document."""
import re
from typing import Iterable, Optional

import spacy
from pydantic import Field as PydanticField
from typing_extensions import override

from ..data.dataset_utils import lilac_span
from ..schema import Field, Item, RichData, SignalInputType, field
from .signal import TextSignal

# NOTE(review): EMAILS_KEY, NUM_EMAILS_KEY and EMAIL_REGEX are not referenced by the SpacyNER
# code visible in this module. They look like leftovers from an e-mail/PII signal -- confirm
# nothing imports them from here, then remove.
EMAILS_KEY = 'emails'
NUM_EMAILS_KEY = 'num_emails'

# This regex is a fully RFC 5322 regex for email addresses.
# https://uibakery.io/regex-library/email-regex-python
# Compiled once at import time; matching is case-insensitive (re.IGNORECASE).
EMAIL_REGEX = re.compile(
"(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|\"(?:[\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x21\\x23-\\x5b\\x5d-\\x7f]|\\\\[\\x01-\\x09\\x0b\\x0c\\x0e-\\x7f])*\")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\\[(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?|[a-z0-9-]*[a-z0-9]:(?:[\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x21-\\x5a\\x53-\\x7f]|\\\\[\\x01-\\x09\\x0b\\x0c\\x0e-\\x7f])+)\\])",
re.IGNORECASE)


# Signal that tags named entities in text with a spaCy pipeline. `setup()` loads the pipeline
# selected by `model`; `compute()` yields the entity spans found in each input document.
class SpacyNER(TextSignal):
  """Named entity recognition with spacy
  For details see: [spacy.io/models](https://spacy.io/models).
  """ # noqa: D415, D400
  name = 'spacy_ner'
  display_name = 'Spacy Named Entity Recognition'

  # spaCy package name or path of the pipeline to load in `setup()`.
  model: Optional[str] = PydanticField(
    title='SpaCy package name or model path.', default='en_core_web_sm', description='')

  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  # The loaded spaCy pipeline; only available after `setup()` has run.
  _nlp: spacy.language.Language

  @override
  def setup(self) -> None:
    """Load the spaCy pipeline named by `model`, keeping only what NER needs."""
    self._nlp = spacy.load(
      # Honor the user-selected model. Fall back to the field default when `model` was
      # explicitly set to None, since the field is declared Optional.
      self.model or 'en_core_web_sm',
      # Disable everything except the NER component. See: https://spacy.io/models
      disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])

  @override
  def fields(self) -> Field:
    """Return the output schema: a list of string spans, each with a string `label`."""
    return field(fields=[field('string_span', fields={'label': 'string'})])

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield the labeled entity spans for each input row.

    Rows that are not strings are processed as empty text. For every row this yields a list
    of spans (character offsets plus the spaCy entity label), or None when no entity is found.
    """
    text_data = (row if isinstance(row, str) else '' for row in data)

    for doc in self._nlp.pipe(text_data):
      result = [lilac_span(ent.start_char, ent.end_char, {'label': ent.label_}) for ent in doc.ents]
      # An empty result is reported as None, not as an empty list.
      yield result or None
40 changes: 40 additions & 0 deletions src/signals/ner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Test the Spacy NER signal."""

from ..schema import field
from .ner import SpacyNER
from .splitters.text_splitter_test_utils import text_to_expected_spans


def test_spacy_ner_fields() -> None:
  """The signal's schema is a list of string spans that each carry a string `label`."""
  expected_schema = field(fields=[field('string_span', fields={'label': 'string'})])

  ner_signal = SpacyNER()
  ner_signal.setup()

  assert ner_signal.fields() == expected_schema


def test_ner() -> None:
  """Entities in a short financial text come back as labeled spans, in document order."""
  ner_signal = SpacyNER()
  ner_signal.setup()

  text = ('Net income was $9.4 million compared to the prior year of $2.7 million.'
          'Revenue exceeded twelve billion dollars, with a loss of $1b.')

  # (entity text, expected spaCy label), listed in the order they occur in `text`.
  labeled_entities = [
    ('$9.4 million', 'MONEY'),
    ('the prior year', 'DATE'),
    ('$2.7 million', 'MONEY'),
    ('twelve billion dollars', 'MONEY'),
    ('1b', 'MONEY'),
  ]
  expected_spans = text_to_expected_spans(
    text, [(entity_text, {'label': label}) for entity_text, label in labeled_entities])

  entities = list(ner_signal.compute([text]))

  assert entities == [expected_spans]
12 changes: 9 additions & 3 deletions src/signals/splitters/text_splitter_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utilities for testing text splitters."""

from typing import Optional
from typing import Optional, Union

from ...data.dataset_utils import lilac_span
from ...schema import TEXT_SPAN_END_FEATURE, TEXT_SPAN_START_FEATURE, VALUE_KEY, Item
Expand All @@ -16,14 +16,20 @@ def spans_to_text(text: str, spans: Optional[list[Item]]) -> list[str]:
]


def text_to_expected_spans(text: str, splits: list[str]) -> list[Item]:
def text_to_expected_spans(text: str, splits: list[Union[str, tuple[str, Item]]]) -> list[Item]:
"""Convert text and a list of splits to a list of expected spans."""
start_offset = 0
expected_spans: list[Item] = []
for split in splits:
if isinstance(split, str):
split, item = split, None
elif isinstance(split, tuple):
split, item = split
else:
raise ValueError('Split should be a string or a tuple of (string, item dict).')
start = text.find(split, start_offset)
end = start + len(split)
expected_spans.append(lilac_span(start=start, end=end))
expected_spans.append(lilac_span(start=start, end=end, metadata=item))
start_offset = end

return expected_spans
9 changes: 5 additions & 4 deletions src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def progress(it: Iterable[TProgress],
task_step_id: Optional[TaskStepId],
estimated_len: Optional[int],
step_description: Optional[str] = None,
emit_every_frac: float = .01) -> Iterable[TProgress]:
emit_every_s: float = 1.) -> Iterable[TProgress]:
"""An iterable wrapper that emits progress and yields the original iterable."""
if not task_step_id:
yield from it
Expand All @@ -223,15 +223,15 @@ def progress(it: Iterable[TProgress],

estimated_len = max(1, estimated_len) if estimated_len else None

emit_every = max(1, int(estimated_len * emit_every_frac)) if estimated_len else None

task_info: TaskInfo = get_worker().state.tasks[task_id].annotations['task_info']

it_idx = 0
start_time = time.time()
last_emit = time.time() - emit_every_s
with tqdm(it, desc=task_info.name, total=estimated_len) as tq:
for t in tq:
if estimated_len and emit_every and it_idx % emit_every == 0:
cur_time = time.time()
if estimated_len and cur_time - last_emit > emit_every_s:
it_per_sec = tq.format_dict['rate'] or 0.0
set_worker_task_progress(
task_step_id=task_step_id,
Expand All @@ -240,6 +240,7 @@ def progress(it: Iterable[TProgress],
it_per_sec=it_per_sec or 0.0,
estimated_total_sec=((estimated_len) / it_per_sec if it_per_sec else 0),
estimated_len=estimated_len)
last_emit = cur_time
yield t
it_idx += 1

Expand Down
59 changes: 59 additions & 0 deletions web/blueprint/src/lib/components/datasetView/SpanClick.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import type {DatasetState} from '$lib/stores/datasetStore';
import type {DatasetViewStore} from '$lib/stores/datasetViewStore';
import type {SignalInfoWithTypedSchema} from '$lilac';
import type {SvelteComponent} from 'svelte';
import type {Readable} from 'svelte/store';
import type {SpanDetails} from './StringSpanDetails.svelte';
import StringSpanDetails from './StringSpanDetails.svelte';

/**
 * Parameters of the `spanClick` action. Apart from `details`, every member is forwarded
 * unchanged as a prop to the `StringSpanDetails` popup that the action mounts on click.
 */
export interface SpanClickInfo {
// Called lazily (on click, and again on action update while a popup is open) to get the
// details to display.
details: () => SpanDetails;
// Stores are passed in explicitly and handed to the popup as props.
datasetViewStore: DatasetViewStore;
datasetStore: Readable<DatasetState>;
// Embedding signal infos forwarded to the popup.
embeddings: SignalInfoWithTypedSchema[];
// Callback forwarded to the popup for adding a concept label.
// NOTE(review): the call site visible in StringSpanDetails passes
// (conceptNamespace, conceptName, text, label) -- the reverse of the first two parameter
// names declared here. Both are strings so this type-checks; confirm the intended order.
addConceptLabel: (
conceptName: string,
conceptNamespace: string,
text: string,
label: boolean
) => void;
}

/**
 * Svelte action that opens a `StringSpanDetails` popup, mounted on `document.body`, when the
 * span element is clicked.
 *
 * At most one popup per element is kept alive: a new click tears down the previous popup
 * before mounting the next one, and destroying the action removes both the popup and the
 * click listener.
 *
 * @param element The span element to attach the click handler to.
 * @param clickInfo Data and callbacks forwarded to the popup as props.
 * @returns The action object with `update` and `destroy` hooks.
 */
export function spanClick(element: HTMLSpanElement, clickInfo: SpanClickInfo) {
  let spanDetailsComponent: SvelteComponent | undefined;
  let curClickInfo = clickInfo;

  function destroyClickInfo() {
    spanDetailsComponent?.$destroy();
    spanDetailsComponent = undefined;
  }

  function showClickDetails(e: MouseEvent) {
    // Destroy any popup still tracked from an earlier click. Overwriting the reference
    // without destroying it would leave an orphaned component on `document.body`.
    destroyClickInfo();
    spanDetailsComponent = new StringSpanDetails({
      props: {
        details: curClickInfo.details(),
        clickPosition: {x: e.clientX, y: e.clientY},
        datasetViewStore: curClickInfo.datasetViewStore,
        datasetStore: curClickInfo.datasetStore,
        embeddings: curClickInfo.embeddings,
        addConceptLabel: curClickInfo.addConceptLabel
      },
      target: document.body
    });
    // The popup emits 'close' and 'click'; either one dismisses it.
    spanDetailsComponent.$on('close', destroyClickInfo);
    spanDetailsComponent.$on('click', destroyClickInfo);
  }

  // Register the named handler directly so the same reference can be removed in `destroy`.
  element.addEventListener('click', showClickDetails);

  return {
    update(clickInfo: SpanClickInfo) {
      curClickInfo = clickInfo;

      // Refresh an open popup with the latest details. Optional chaining short-circuits, so
      // `details()` is not evaluated when no popup is mounted.
      spanDetailsComponent?.$set({
        details: curClickInfo.details()
      });
    },
    destroy() {
      element.removeEventListener('click', showClickDetails);
      destroyClickInfo();
    }
  };
}
10 changes: 10 additions & 0 deletions web/blueprint/src/lib/components/datasetView/SpanHover.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import {deserializePath, pathIsMatching} from '$lilac';
import type {SvelteComponent} from 'svelte';
import SpanHoverTooltip, {type SpanHoverNamedValue} from './SpanHoverTooltip.svelte';

export interface SpanHoverInfo {
namedValues: SpanHoverNamedValue[];
spansHovered: string[];
isHovered: boolean;
itemScrollContainer: HTMLDivElement | null;
}
Expand All @@ -15,6 +17,14 @@ export function spanHover(element: HTMLSpanElement, spanHoverInfo: SpanHoverInfo
if (!curSpanHoverInfo.isHovered) {
return;
}
curSpanHoverInfo.namedValues = spanHoverInfo.namedValues.filter(namedValue =>
curSpanHoverInfo.spansHovered.some(path =>
pathIsMatching(deserializePath(namedValue.spanPath), deserializePath(path))
)
);
if (curSpanHoverInfo.namedValues.length === 0) {
return;
}
if (curSpanHoverInfo.itemScrollContainer != null) {
curSpanHoverInfo.itemScrollContainer.addEventListener('scroll', itemScrollListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
export interface SpanHoverNamedValue {
name: string;
value: DataTypeCasted;
spanPath: string;
isConcept?: boolean;
isKeywordSearch?: boolean;
isSemanticSearch?: boolean;
isNonNumericMetadata?: boolean;
}
</script>

Expand Down Expand Up @@ -37,10 +40,11 @@
<div class="named-value-name table-cell max-w-xs truncate pr-2">{namedValue.name}</div>
<div class="table-cell rounded text-right">
<span
style:background-color={namedValue.isConcept && typeof namedValue.value === 'number'
style:background-color={(namedValue.isConcept || namedValue.isSemanticSearch) &&
typeof namedValue.value === 'number'
? colorFromScore(namedValue.value)
: ''}
class:font-bold={namedValue.isKeywordSearch}
class:font-bold={namedValue.isKeywordSearch || namedValue.isNonNumericMetadata}
class="px-1"
>
{typeof namedValue.value === 'number'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,59 @@
<script lang="ts">
import {fade} from 'svelte/transition';
import {editConceptMutation} from '$lib/queries/conceptQueries';
import {queryEmbeddings} from '$lib/queries/signalQueries';
import {getDatasetContext} from '$lib/stores/datasetStore';
import {getDatasetViewContext} from '$lib/stores/datasetViewStore';
import type {DatasetState} from '$lib/stores/datasetStore';
import type {DatasetViewStore} from '$lib/stores/datasetViewStore';
import {getComputedEmbeddings, getSearchEmbedding, getSearchPath} from '$lib/view_utils';
import {serializePath} from '$lilac';
import {serializePath, type SignalInfoWithTypedSchema} from '$lilac';
import {Button} from 'carbon-components-svelte';
import ThumbsDownFilled from 'carbon-icons-svelte/lib/ThumbsDownFilled.svelte';
import ThumbsUpFilled from 'carbon-icons-svelte/lib/ThumbsUpFilled.svelte';
import {createEventDispatcher} from 'svelte';
import type {Readable} from 'svelte/store';
import {clickOutside} from '../common/clickOutside';
import EmbeddingBadge from './EmbeddingBadge.svelte';
export let details: SpanDetails;
// The coordinates of the click so we can position the popup next to the cursor.
export let clickPosition: {x: number; y: number} | undefined;
let datasetViewStore = getDatasetViewContext();
let datasetStore = getDatasetContext();
export let datasetViewStore: DatasetViewStore;
export let datasetStore: Readable<DatasetState>;
export let embeddings: SignalInfoWithTypedSchema[];
$: console.log(details);
// We cant create mutations from this component since it is hoisted so we pass the function in.
export let addConceptLabel: (
conceptName: string,
conceptNamespace: string,
text: string,
label: boolean
) => void;
$: searchPath = getSearchPath($datasetViewStore, $datasetStore);
const conceptEdit = editConceptMutation();
const dispatch = createEventDispatcher();
function addLabel(label: boolean) {
if (!details.conceptName || !details.conceptNamespace)
throw Error('Label could not be added, no active concept.');
$conceptEdit.mutate([
details.conceptNamespace,
details.conceptName,
{insert: [{text: details.text, label}]}
]);
addConceptLabel(details.conceptNamespace, details.conceptName, details.text, label);
dispatch('click');
}
// Get the embeddings.
const embeddings = queryEmbeddings();
$: searchEmbedding = getSearchEmbedding(
$datasetViewStore,
$datasetStore,
searchPath,
($embeddings.data || []).map(e => e.name)
embeddings.map(e => e.name)
);
$: computedEmbeddings = getComputedEmbeddings($datasetStore, searchPath);
const findSimilar = (embedding: string) => {
if (searchPath == null || searchEmbedding == null) return;
console.log('finding similar', details);
datasetViewStore.addSearch({
path: [serializePath(searchPath)],
query: {
Expand Down
Loading

0 comments on commit acfbe6d

Please sign in to comment.