|
11 | 11 | _VectorIndexConfigHNSW,
|
12 | 12 | _VectorIndexConfigFlat,
|
13 | 13 | Vectorizers,
|
| 14 | + ReferenceProperty, |
14 | 15 | )
|
15 | 16 | from weaviate.collections.classes.data import DataObject
|
16 | 17 | from weaviate.collections.classes.grpc import _MultiTargetVectorJoin
|
17 | 18 | from weaviate.exceptions import WeaviateInvalidInputError
|
| 19 | +from weaviate.types import INCLUDE_VECTOR |
18 | 20 |
|
19 | 21 |
|
20 | 22 | def test_create_named_vectors_throws_error_in_old_version(
|
@@ -756,3 +758,46 @@ def test_deprecated_syntax(collection_factory: CollectionFactory):
|
756 | 758 | return_metadata=wvc.query.MetadataQuery.full(),
|
757 | 759 | )
|
758 | 760 | assert "Providing lists of lists has been deprecated" in str(e)
|
| 761 | + |
| 762 | + |
| 763 | +@pytest.mark.parametrize( |
| 764 | + "include_vector, expected", |
| 765 | + [ |
| 766 | + (False, {}), |
| 767 | + (["bringYourOwn1"], {"bringYourOwn1": [0, 1, 2]}), |
| 768 | + # TODO: to be uncommented when https://github.com/weaviate/weaviate/issues/6279 is resolved |
| 769 | + # (True, {"bringYourOwn1": [0, 1, 2], "bringYourOwn2": [3, 4, 5]}) |
| 770 | + ], |
| 771 | +) |
| 772 | +def test_include_vector_on_references( |
| 773 | + collection_factory: CollectionFactory, include_vector: INCLUDE_VECTOR, expected: dict |
| 774 | +) -> None: |
| 775 | + """Test include vector on reference""" |
| 776 | + dummy = collection_factory() |
| 777 | + if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): |
| 778 | + pytest.skip("Named vectorizers are only supported in Weaviate v1.24.0 and higher.") |
| 779 | + |
| 780 | + ref_collection = collection_factory( |
| 781 | + name="Target", |
| 782 | + vectorizer_config=[ |
| 783 | + wvc.config.Configure.NamedVectors.none(name="bringYourOwn1"), |
| 784 | + wvc.config.Configure.NamedVectors.none(name="bringYourOwn2"), |
| 785 | + ], |
| 786 | + ) |
| 787 | + |
| 788 | + TO_UUID = ref_collection.data.insert( |
| 789 | + properties={}, vector={"bringYourOwn1": [0, 1, 2], "bringYourOwn2": [3, 4, 5]} |
| 790 | + ) |
| 791 | + |
| 792 | + collection = collection_factory( |
| 793 | + name="Source", |
| 794 | + references=[ReferenceProperty(name="hasRef", target_collection=ref_collection.name)], |
| 795 | + ) |
| 796 | + |
| 797 | + collection.data.insert({}, references={"hasRef": TO_UUID}) |
| 798 | + |
| 799 | + objs = collection.query.fetch_objects( |
| 800 | + return_references=wvc.query.QueryReference(link_on="hasRef", include_vector=include_vector) |
| 801 | + ).objects |
| 802 | + |
| 803 | + assert objs[0].references["hasRef"].objects[0].vector == expected |
0 commit comments