Commit d14b470: Added new vector index api
MatsMoll committed Jun 11, 2024 (parent: f80c389)
Showing 25 changed files with 575 additions and 164 deletions.
6 changes: 4 additions & 2 deletions aligned/compiler/feature_factory.py
@@ -1376,8 +1376,10 @@ def indexed(
if self.indexes is None:
self.indexes = []

-        if not self.embedding_size:
-            assert embedding_size, 'An embedding size is needed in order to create a vector index'
+        if not embedding_size:
+            embedding_size = self.embedding_size
+
+        assert embedding_size, 'An embedding size is needed in order to create a vector index'

self.indexes.append(
VectorIndexFactory(
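
A minimal standalone sketch of the resolution order this hunk introduces (the helper name is illustrative, not part of the library): an explicit `embedding_size` argument wins, then the factory's stored `self.embedding_size`, and the assert still fires when neither is set.

```python
# Hedged sketch of the fallback introduced above; not library code.
def resolve_embedding_size(explicit: int | None, stored: int | None) -> int:
    size = explicit if explicit else stored
    assert size, 'An embedding size is needed in order to create a vector index'
    return size

assert resolve_embedding_size(None, 768) == 768  # falls back to the stored size
assert resolve_embedding_size(512, 768) == 512   # an explicit size takes precedence
```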
4 changes: 2 additions & 2 deletions aligned/data_source/batch_data_source.py
@@ -211,7 +211,7 @@ def _deserialize(cls, value: dict) -> BatchDataSource:
return data_class.from_dict(value)

def all_columns(self, limit: int | None = None) -> RetrivalJob:
-        return self.all(RequestResult.empty(), limit=limit)
+        return self.all(RequestResult(set(), set(), None), limit=limit)

def all(self, result: RequestResult, limit: int | None = None) -> RetrivalJob:
return self.all_data(
@@ -228,7 +228,7 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:

return FileFullJob(self, request=request, limit=limit)

-        raise NotImplementedError()
+        raise NotImplementedError(type(self))

def all_between_dates(
self,
47 changes: 46 additions & 1 deletion aligned/exposed_model/interface.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import polars as pl
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Coroutine
from dataclasses import dataclass
from aligned.retrival_job import RetrivalJob
from aligned.schemas.codable import Codable
@@ -93,6 +93,18 @@ def _deserialize(cls, value: dict) -> ExposedModel:
data_class = PredictorFactory.shared().supported_predictors[name_type]
return data_class.from_dict(value)

+    @staticmethod
+    def polars_predictor(
+        callable: Callable[[pl.DataFrame, ModelFeatureStore], Coroutine[None, None, pl.DataFrame]]
+    ) -> 'ExposedModel':
+        import dill
+
+        async def function_wrapper(values: RetrivalJob, store: ModelFeatureStore) -> pl.DataFrame:
+            features = await store.features_for(values).to_polars()
+            return await callable(features, store)
+
+        return DillPredictor(function=dill.dumps(function_wrapper))

@staticmethod
def ollama_generate(
endpoint: str,
@@ -160,6 +172,39 @@ def mlflow_server(
)


+@dataclass
+class DillPredictor(ExposedModel):
+
+    function: bytes
+
+    model_type: str = 'dill_predictor'
+
+    @property
+    def exposed_at_url(self) -> str | None:
+        return None
+
+    @property
+    def as_markdown(self) -> str:
+        return 'A function stored in a dill file.'
+
+    async def needed_features(self, store: ModelFeatureStore) -> list[FeatureReference]:
+        default = store.model.features.default_version
+        return store.feature_references_for(store.selected_version or default)
+
+    async def needed_entities(self, store: ModelFeatureStore) -> set[Feature]:
+        return store.request().request_result.entities
+
+    async def run_polars(self, values: RetrivalJob, store: ModelFeatureStore) -> pl.DataFrame:
+        import dill
+        import inspect
+
+        function = dill.loads(self.function)
+        if inspect.iscoroutinefunction(function):
+            return await function(values, store)
+        else:
+            return function(values, store)


@dataclass
class EnitityPredictor(ExposedModel):

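
The new `ExposedModel.polars_predictor` factory wraps an async Polars function into a `DillPredictor`, so a model can be exposed without a server. A hedged usage sketch (the feature and output names are illustrative):

```python
import polars as pl

from aligned.exposed_model.interface import ExposedModel

async def predict(features: pl.DataFrame, store) -> pl.DataFrame:
    # Illustrative transform: derive a prediction from one fetched feature.
    return features.with_columns((pl.col('hours_studied') * 0.1).alias('predicted_score'))

# The callable is dill-serialized, so it travels with the contract definition.
exposed = ExposedModel.polars_predictor(predict)
```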
6 changes: 3 additions & 3 deletions aligned/feature_source.py
@@ -30,13 +30,13 @@ async def freshness_for(


class WritableFeatureSource:
-    async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+    async def insert(self, job: RetrivalJob, request: RetrivalRequest) -> None:
        raise NotImplementedError(f'Append is not implemented for {type(self)}.')

-    async def upsert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+    async def upsert(self, job: RetrivalJob, request: RetrivalRequest) -> None:
        raise NotImplementedError(f'Upsert write is not implemented for {type(self)}.')

-    async def overwrite(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
+    async def overwrite(self, job: RetrivalJob, request: RetrivalRequest) -> None:
        raise NotImplementedError(f'Overwrite write is not implemented for {type(self)}.')


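
Since `insert`, `upsert`, and `overwrite` now take a single `RetrivalRequest` instead of a list, custom sources lose one level of nesting. A minimal in-memory sketch under that assumption (the class is illustrative, not part of the library):

```python
import polars as pl

from aligned.feature_source import WritableFeatureSource
from aligned.request.retrival_request import RetrivalRequest
from aligned.retrival_job import RetrivalJob

class InMemorySource(WritableFeatureSource):
    """Illustrative sink that collects inserted frames in a list."""

    def __init__(self) -> None:
        self.frames: list[pl.DataFrame] = []

    async def insert(self, job: RetrivalJob, request: RetrivalRequest) -> None:
        df = await job.to_polars()
        # The request arrives unwrapped, so its columns can be used directly.
        self.frames.append(df.select(request.all_returned_columns))
```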
90 changes: 83 additions & 7 deletions aligned/feature_store.py
@@ -42,6 +42,7 @@
from aligned.schemas.model import EventTrigger
from aligned.schemas.model import Model as ModelSchema
from aligned.schemas.repo_definition import EnricherReference, RepoDefinition, RepoMetadata
+from aligned.sources.vector_index import VectorIndex

logger = logging.getLogger(__name__)

@@ -104,6 +105,7 @@ class ContractStore:
feature_views: dict[str, CompiledFeatureView]
combined_feature_views: dict[str, CompiledCombinedFeatureView]
models: dict[str, ModelSchema]
+    vector_indexes: dict[str, ModelSchema]

@property
def all_models(self) -> list[str]:
@@ -115,11 +117,13 @@ def __init__(
combined_feature_views: dict[str, CompiledCombinedFeatureView],
models: dict[str, ModelSchema],
feature_source: FeatureSource,
+        vector_indexes: dict[str, ModelSchema] | None = None,
) -> None:
self.feature_source = feature_source
self.combined_feature_views = combined_feature_views
self.feature_views = feature_views
self.models = models
+        self.vector_indexes = vector_indexes or {}

@staticmethod
def empty() -> ContractStore:
@@ -458,12 +462,20 @@ def model(self, name: str) -> ModelFeatureStore:
"""
        Selects a model for ease of use.
```python
entities = {"trip_id": [1, 2, 3, ...]}
preds = await store.model("my_model").predict_over(entities).to_polars()
```
Returns:
            ModelFeatureStore: A new store that contains the selected model
"""
model = self.models[name]
return ModelFeatureStore(model, self)

+    def vector_index(self, name: str) -> VectorIndexStore:
+        return VectorIndexStore(self, self.vector_indexes[name])

def event_triggers_for(self, feature_view: str) -> set[EventTrigger]:
triggers = self.feature_views[feature_view].event_triggers or set()
for model in self.models.values():
@@ -607,7 +619,7 @@ def feature_view(self, view: str) -> FeatureViewStore:
feature_view = self.feature_views[view]
return FeatureViewStore(self, feature_view, self.event_triggers_for(view))

-    def add_view(self, view: CompiledFeatureView) -> None:
+    def add_view(self, view: CompiledFeatureView | FeatureView | FeatureViewWrapper) -> None:
"""
Compiles and adds the feature view to the store
@@ -625,7 +637,7 @@ class MyFeatureView:
Args:
view (CompiledFeatureView): The feature view to add
"""
-        self.add_compiled_view(view)
+        self.add_feature_view(view)

def add_compiled_view(self, view: CompiledFeatureView) -> None:
"""
@@ -654,11 +666,13 @@ class MyFeatureView:
view.materialized_source or view.source
)

-    def add_feature_view(self, feature_view: FeatureView | FeatureViewWrapper) -> None:
+    def add_feature_view(self, feature_view: FeatureView | FeatureViewWrapper | CompiledFeatureView) -> None:
        if isinstance(feature_view, FeatureViewWrapper):
            self.add_compiled_view(feature_view.compile())
-        else:
+        elif isinstance(feature_view, FeatureView):
            self.add_compiled_view(feature_view.compile_instance())
+        else:
+            self.add_compiled_view(feature_view)

def add_combined_feature_view(self, feature_view: CombinedFeatureView) -> None:
compiled_view = type(feature_view).compile()
@@ -691,6 +705,10 @@ def add_compiled_model(self, model: ModelSchema) -> None:

source = PredictModelSource(self.model(model.name))

+        if isinstance(model.predictions_view.source, VectorIndex):
+            index_name = model.predictions_view.source.vector_index_name() or model.name
+            self.vector_indexes[index_name] = model

if isinstance(self.feature_source, BatchFeatureSource) and source is not None:
self.feature_source.sources[FeatureLocation.model(model.name).identifier] = source

@@ -863,7 +881,7 @@ async def insert_into(
values = RetrivalJob.from_convertable(values, write_request)

if isinstance(source, WritableFeatureSource):
-            await source.insert(values, [write_request])
+            await source.insert(values, write_request)
elif isinstance(source, DataFileReference):
import polars as pl

@@ -908,7 +926,7 @@ async def upsert_into(
values = RetrivalJob.from_convertable(values, write_request)

if isinstance(source, WritableFeatureSource):
-            await source.upsert(values, [write_request])
+            await source.upsert(values, write_request)
elif isinstance(source, DataFileReference):
new_df = (await values.to_lazy_polars()).select(write_request.all_returned_columns)
entities = list(write_request.entity_names)
@@ -942,7 +960,7 @@ async def overwrite(
values = RetrivalJob.from_convertable(values, write_request)

if isinstance(source, WritableFeatureSource):
-            await source.overwrite(values, [write_request])
+            await source.overwrite(values, write_request)
elif isinstance(source, DataFileReference):
df = (await values.to_lazy_polars()).select(write_request.all_returned_columns)
await source.write_polars(df)
@@ -1781,3 +1799,61 @@ async def freshness(self) -> datetime | None:
location = FeatureLocation.feature_view(view.name)

return (await self.source.freshness_for({location: view.event_timestamp}))[location]


+class VectorIndexStore:
+
+    store: ContractStore
+    model: ModelSchema
+
+    def __init__(self, store: ContractStore, model: ModelSchema):
+        if model.predictions_view.source is None:
+            raise ValueError(f"An output source on the model {model.name} is needed")
+
+        if not isinstance(model.predictions_view.source, VectorIndex):
+            message = (
+                f"An output source on the model {model.name} needs to be of type VectorIndex, "
+                f"got {type(model.predictions_view.source)}"
+            )
+            raise ValueError(message)
+
+        self.store = store
+        self.model = model
+
+    def nearest_n_to(
+        self, entities: RetrivalJob | ConvertableToRetrivalJob, number_of_records: int
+    ) -> RetrivalJob:
+        source = self.model.predictions_view.source
+        assert isinstance(source, VectorIndex)
+
+        embeddings = self.model.predictions_view.embeddings()
+        n_embeddings = len(embeddings)
+
+        if n_embeddings == 0:
+            raise ValueError(f"Need at least one embedding to search. Got {n_embeddings}")
+        if n_embeddings > 1:
+            raise ValueError('Got more than one embedding, it is therefore unclear which to use.')
+
+        embedding = embeddings[0]
+        response = self.model.predictions_view.request(self.model.name)
+
+        def contains_embedding() -> bool:
+            if isinstance(entities, RetrivalJob):
+                return embedding.name in entities.loaded_columns
+            elif isinstance(entities, dict):
+                return embedding.name in entities
+            elif isinstance(entities, (pl.DataFrame, pd.DataFrame, pl.LazyFrame)):
+                return embedding.name in entities.columns
+            raise ValueError('Unable to determine if the entities contain the embedding')
+
+        if self.model.exposed_model and not contains_embedding():
+            model_store = self.store.model(self.model.name)
+            features: RetrivalJob = model_store.predict_over(entities)
+        else:
+            # Assumes that we can look up the embeddings from the source
+            feature_ref = FeatureReference(
+                embedding.name, FeatureLocation.model(self.model.name), dtype=embedding.dtype
+            )
+            features = self.store.features_for(entities, features=[feature_ref.identifier])
+
+        return source.nearest_n_to(features, number_of_records, response)
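
Tying the pieces together, a hedged sketch of the new lookup path (the index name 'documents' and the 'text' entity column are illustrative; `store` is an already-built ContractStore whose model writes to a VectorIndex source):

```python
import polars as pl

async def find_similar(store, query: str) -> pl.DataFrame:
    index = store.vector_index('documents')
    # With an exposed model and no precomputed embedding column,
    # nearest_n_to routes through predict_over(...) to embed the query.
    job = index.nearest_n_to(entities={'text': [query]}, number_of_records=5)
    return await job.to_polars()
```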
2 changes: 1 addition & 1 deletion aligned/redis/tests/test_redis_job.py
@@ -146,7 +146,7 @@ async def test_write_job(mocker, retrival_request: RetrivalRequest) -> None: #
config = RedisConfig.localhost()
source = RedisSource(config)

-    await source.insert(insert_facts, [retrival_request])
+    await source.insert(insert_facts, retrival_request)

job = FactualRedisJob(RedisConfig.localhost(), requests=[retrival_request], facts=facts)
data = await job.to_lazy_polars()
33 changes: 22 additions & 11 deletions aligned/request/retrival_request.py
@@ -1,6 +1,8 @@
from collections import defaultdict
from dataclasses import dataclass, field

+import pyarrow as pa

from aligned.schemas.codable import Codable
from aligned.schemas.derivied_feature import AggregatedFeature, AggregateOver, DerivedFeature
from aligned.schemas.feature import EventTimestamp, Feature, FeatureLocation
@@ -29,13 +31,12 @@ class RetrivalRequest(Codable):
derived_features: set[DerivedFeature]
aggregated_features: set[AggregatedFeature] = field(default_factory=set)
event_timestamp_request: EventTimestampRequest | None = field(default=None)
+    features_to_include: set[str] = field(default_factory=set)

@property
def event_timestamp(self) -> EventTimestamp | None:
return self.event_timestamp_request.event_timestamp if self.event_timestamp_request else None

-    features_to_include: set[str] = field(default_factory=set)

def __init__(
self,
name: str,
@@ -76,19 +77,19 @@ def filter_features(self, feature_names: set[str]) -> 'RetrivalRequest':
)

    @property
-    def all_returned_columns(self) -> list[str]:
-
-        result = self.entity_names
+    def all_returned_features(self) -> list[Feature]:
+        result = self.entities

if self.event_timestamp and (
all(agg.aggregate_over.window is not None for agg in self.aggregated_features)
or len(self.aggregated_features) == 0
):
-            result = result.union({self.event_timestamp.name})
+            result = result.union({self.event_timestamp.as_feature()})

if self.aggregated_features:
-            agg_names = [feat.name for feat in self.aggregated_features]
-            derived_after_aggs_name: set[str] = set()
+            agg_features = [feat.derived_feature for feat in self.aggregated_features]
+            agg_names = list(self.aggregated_features)
+            derived_after_aggs: set[Feature] = set()
derived_features = {der.name: der for der in self.derived_features}

def is_dependent_on_agg_feature(feature: DerivedFeature) -> bool:
@@ -104,11 +105,15 @@ def is_dependent_on_agg_feature(feature: DerivedFeature) -> bool:

for feat in self.derived_features:
if is_dependent_on_agg_feature(feat):
-                derived_after_aggs_name.add(feat.name)
+                derived_after_aggs.add(feat)

-            return agg_names + list(derived_after_aggs_name) + list(result)
+            return agg_features + list(derived_after_aggs) + list(result)

-        return list(result.union(self.all_feature_names))
+        return list(result.union(self.all_features))
+
+    @property
+    def all_returned_columns(self) -> list[str]:
+        return [feature.name for feature in self.all_returned_features]

@property
def returned_features(self) -> set[Feature]:
@@ -182,6 +187,12 @@ def derived_features_order(self) -> list[set[DerivedFeature]]:

return feature_orders

+    def pyarrow_schema(self) -> pa.Schema:
+        from aligned.schemas.vector_storage import pyarrow_schema
+
+        sorted_features = sorted(self.all_returned_features, key=lambda feature: feature.name)
+        return pyarrow_schema(sorted_features)

def aggregate_over(self) -> dict[AggregateOver, set[AggregatedFeature]]:
features = defaultdict(set)
for feature in self.aggregated_features:
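
The new `pyarrow_schema()` sorts the returned features by name before building the schema, giving writers a deterministic layout to check against. A hedged sketch of one such check (`request` stands in for any compiled `RetrivalRequest`):

```python
import pyarrow as pa

def validate_batch(batch: pa.Table, request) -> None:
    # Compare against the deterministic, name-sorted schema of the request.
    expected = request.pyarrow_schema()
    if batch.schema != expected:
        raise ValueError(f'Schema mismatch: got {batch.schema}, expected {expected}')
```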
