Skip to content

Commit 05fdbb6

Browse files
authored
variable contentfile chunk sizes (#1980)
* adding a setting to be able to override the chunk size when embedding
* adding collection name override param
* updating spec
* adding test
* added test; switched back to tiktokenizer with chunk size param
* fixing test
* fix test
* catching exception
* fixing tests
1 parent 40bef0e commit 05fdbb6

File tree

9 files changed

+126
-21
lines changed

9 files changed

+126
-21
lines changed

frontends/api/src/generated/v0/api.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10901,6 +10901,7 @@ export const VectorContentFilesSearchApiAxiosParamCreator = function (
1090110901
/**
1090210902
* Vector Search for content
1090310903
* @summary Content File Vector Search
10904+
* @param {string} [collection_name] Manually specify the name of the Qdrant collection to query
1090410905
* @param {Array<string>} [content_feature_type] The feature type of the content file. Possible options are at api/v1/course_features/
1090510906
* @param {Array<string>} [course_number] Course number of the content file
1090610907
* @param {Array<string>} [file_extension] The extension of the content file.
@@ -10917,6 +10918,7 @@ export const VectorContentFilesSearchApiAxiosParamCreator = function (
1091710918
* @throws {RequiredError}
1091810919
*/
1091910920
vectorContentFilesSearchRetrieve: async (
10921+
collection_name?: string,
1092010922
content_feature_type?: Array<string>,
1092110923
course_number?: Array<string>,
1092210924
file_extension?: Array<string>,
@@ -10947,6 +10949,10 @@ export const VectorContentFilesSearchApiAxiosParamCreator = function (
1094710949
const localVarHeaderParameter = {} as any
1094810950
const localVarQueryParameter = {} as any
1094910951

10952+
if (collection_name !== undefined) {
10953+
localVarQueryParameter["collection_name"] = collection_name
10954+
}
10955+
1095010956
if (content_feature_type) {
1095110957
localVarQueryParameter["content_feature_type"] = content_feature_type
1095210958
}
@@ -11025,6 +11031,7 @@ export const VectorContentFilesSearchApiFp = function (
1102511031
/**
1102611032
* Vector Search for content
1102711033
* @summary Content File Vector Search
11034+
* @param {string} [collection_name] Manually specify the name of the Qdrant collection to query
1102811035
* @param {Array<string>} [content_feature_type] The feature type of the content file. Possible options are at api/v1/course_features/
1102911036
* @param {Array<string>} [course_number] Course number of the content file
1103011037
* @param {Array<string>} [file_extension] The extension of the content file.
@@ -11041,6 +11048,7 @@ export const VectorContentFilesSearchApiFp = function (
1104111048
* @throws {RequiredError}
1104211049
*/
1104311050
async vectorContentFilesSearchRetrieve(
11051+
collection_name?: string,
1104411052
content_feature_type?: Array<string>,
1104511053
course_number?: Array<string>,
1104611054
file_extension?: Array<string>,
@@ -11062,6 +11070,7 @@ export const VectorContentFilesSearchApiFp = function (
1106211070
> {
1106311071
const localVarAxiosArgs =
1106411072
await localVarAxiosParamCreator.vectorContentFilesSearchRetrieve(
11073+
collection_name,
1106511074
content_feature_type,
1106611075
course_number,
1106711076
file_extension,
@@ -11116,6 +11125,7 @@ export const VectorContentFilesSearchApiFactory = function (
1111611125
): AxiosPromise<ContentFileVectorSearchResponse> {
1111711126
return localVarFp
1111811127
.vectorContentFilesSearchRetrieve(
11128+
requestParameters.collection_name,
1111911129
requestParameters.content_feature_type,
1112011130
requestParameters.course_number,
1112111131
requestParameters.file_extension,
@@ -11141,6 +11151,13 @@ export const VectorContentFilesSearchApiFactory = function (
1114111151
* @interface VectorContentFilesSearchApiVectorContentFilesSearchRetrieveRequest
1114211152
*/
1114311153
export interface VectorContentFilesSearchApiVectorContentFilesSearchRetrieveRequest {
11154+
/**
11155+
* Manually specify the name of the Qdrant collection to query
11156+
* @type {string}
11157+
* @memberof VectorContentFilesSearchApiVectorContentFilesSearchRetrieve
11158+
*/
11159+
readonly collection_name?: string
11160+
1114411161
/**
1114511162
* The feature type of the content file. Possible options are at api/v1/course_features/
1114611163
* @type {Array<string>}
@@ -11247,6 +11264,7 @@ export class VectorContentFilesSearchApi extends BaseAPI {
1124711264
) {
1124811265
return VectorContentFilesSearchApiFp(this.configuration)
1124911266
.vectorContentFilesSearchRetrieve(
11267+
requestParameters.collection_name,
1125011268
requestParameters.content_feature_type,
1125111269
requestParameters.course_number,
1125211270
requestParameters.file_extension,

main/settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,3 +848,10 @@ def get_all_config_keys():
848848
AI_BUDGET_DURATION = get_string(name="AI_BUDGET_DURATION", default="60m")
849849
AI_MAX_BUDGET = get_float(name="AI_MAX_BUDGET", default=0.05)
850850
AI_ANON_LIMIT_MULTIPLIER = get_float(name="AI_ANON_LIMIT_MULTIPLIER", default=10.0)
851+
CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = get_int(
852+
name="CONTENT_FILE_EMBEDDING_CHUNK_SIZE", default=None
853+
)
854+
CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = get_int(
855+
name="CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP",
856+
default=200, # default that the tokenizer uses
857+
)

openapi/specs/v0.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,12 @@ paths:
831831
description: Vector Search for content
832832
summary: Content File Vector Search
833833
parameters:
834+
- in: query
835+
name: collection_name
836+
schema:
837+
type: string
838+
minLength: 1
839+
description: Manually specify the name of the Qdrant collection to query
834840
- in: query
835841
name: content_feature_type
836842
schema:

vector_search/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pytest
3+
from langchain.text_splitter import RecursiveCharacterTextSplitter
34
from qdrant_client.http.models.models import CountResult
45

56
from vector_search.encoders.base import BaseEncoder
@@ -30,11 +31,14 @@ def _use_dummy_encoder(settings):
3031
def _use_test_qdrant_settings(settings, mocker):
3132
settings.QDRANT_HOST = "https://test"
3233
settings.QDRANT_BASE_COLLECTION_NAME = "test"
34+
settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = 0
3335
mock_qdrant = mocker.patch("qdrant_client.QdrantClient")
3436
mock_qdrant.scroll.return_value = [
3537
[],
3638
None,
3739
]
40+
get_text_splitter_patch = mocker.patch("vector_search.utils._get_text_splitter")
41+
get_text_splitter_patch.return_value = RecursiveCharacterTextSplitter()
3842
mock_qdrant.count.return_value = CountResult(count=10)
3943
mocker.patch(
4044
"vector_search.utils.qdrant_client",

vector_search/serializers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,10 @@ class ContentFileVectorSearchRequestSerializer(serializers.Serializer):
229229
"The readable_id value of the parent learning resource for the content file"
230230
),
231231
)
232+
collection_name = serializers.CharField(
233+
required=False,
234+
help_text=("Manually specify the name of the Qdrant collection to query"),
235+
)
232236

233237

234238
class ContentFileVectorSearchResponseSerializer(SearchResponseSerializer):

vector_search/utils.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import logging
12
import uuid
23

34
from django.conf import settings
4-
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
5+
from langchain.text_splitter import TokenTextSplitter
56
from qdrant_client import QdrantClient, models
67

78
from learning_resources.models import LearningResource
@@ -21,6 +22,8 @@
2122
)
2223
from vector_search.encoders.utils import dense_encoder
2324

25+
logger = logging.getLogger(__name__)
26+
2427

2528
def qdrant_client():
2629
return QdrantClient(
@@ -178,17 +181,17 @@ def _get_text_splitter(encoder):
178181
"""
179182
Get the text splitter to use based on the encoder
180183
"""
184+
chunk_params = {
185+
"chunk_overlap": settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP,
186+
}
181187
if hasattr(encoder, "token_encoding_name") and encoder.token_encoding_name:
182-
# leverage tiktoken to ensure we stay within token limits
183-
return TokenTextSplitter(encoding_name=encoder.token_encoding_name)
184-
else:
185-
# default for use with fastembed
186-
return RecursiveCharacterTextSplitter(
187-
chunk_size=512,
188-
chunk_overlap=50,
189-
add_start_index=True,
190-
separators=["\n\n", "\n", ".", " ", ""],
191-
)
188+
chunk_params["encoding_name"] = encoder.token_encoding_name
189+
190+
if settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE:
191+
chunk_params["chunk_size"] = settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE
192+
193+
# leverage tiktoken to ensure we stay within token limits
194+
return TokenTextSplitter(**chunk_params)
192195

193196

194197
def _process_content_embeddings(serialized_content):
@@ -254,10 +257,16 @@ def _process_content_embeddings(serialized_content):
254257
metadata.extend(split_metadatas)
255258
ids.extend(split_ids)
256259
if len(resource_points) > 0:
257-
client.update_vectors(
258-
collection_name=RESOURCES_COLLECTION_NAME,
259-
points=resource_points,
260-
)
260+
try:
261+
# sometimes we can't update the multi-vector if max size is exceeded
262+
263+
client.update_vectors(
264+
collection_name=RESOURCES_COLLECTION_NAME,
265+
points=resource_points,
266+
)
267+
except Exception as e: # noqa: BLE001
268+
msg = f"Exceeded multi-vector max size: {e}"
269+
logger.warning(msg)
261270
return points_generator(ids, metadata, embeddings, vector_name)
262271

263272

vector_search/utils_test.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from langchain.text_splitter import RecursiveCharacterTextSplitter
2+
from django.conf import settings
33
from qdrant_client import models
44
from qdrant_client.models import PointStruct
55

@@ -234,11 +234,30 @@ def test_get_text_splitter(mocker):
234234
"""
235235
Test that the correct splitter is returned based on encoder
236236
"""
237+
settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = None
237238
encoder = dense_encoder()
238239
encoder.token_encoding_name = None
239240
mocked_splitter = mocker.patch("vector_search.utils.TokenTextSplitter")
240-
splitter = _get_text_splitter(encoder)
241-
assert isinstance(splitter, RecursiveCharacterTextSplitter)
241+
_get_text_splitter(encoder)
242+
assert "encoding_name" not in mocked_splitter.mock_calls[0].kwargs
242243
encoder.token_encoding_name = "cl100k_base" # noqa: S105
243-
splitter = _get_text_splitter(encoder)
244-
mocked_splitter.assert_called()
244+
_get_text_splitter(encoder)
245+
assert "encoding_name" in mocked_splitter.mock_calls[1].kwargs
246+
247+
248+
def test_text_splitter_chunk_size_override(mocker):
249+
"""
250+
Test that the splitter is given the chunk size override when it is set
251+
"""
252+
chunk_size = 100
253+
settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = chunk_size
254+
settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = chunk_size / 10
255+
encoder = dense_encoder()
256+
mocked_splitter = mocker.patch("vector_search.utils.TokenTextSplitter")
257+
encoder.token_encoding_name = "cl100k_base" # noqa: S105
258+
_get_text_splitter(encoder)
259+
assert mocked_splitter.mock_calls[0].kwargs["chunk_size"] == 100
260+
mocked_splitter = mocker.patch("vector_search.utils.TokenTextSplitter")
261+
settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = None
262+
_get_text_splitter(encoder)
263+
assert "chunk_size" not in mocked_splitter.mock_calls[0].kwargs

vector_search/views.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,19 @@ def get(self, request):
114114
query_text = request_data.data.get("q", "")
115115
limit = request_data.data.get("limit", 10)
116116
offset = request_data.data.get("offset", 0)
117+
collection_name_override = request_data.data.get("collection_name")
118+
collection_name = CONTENT_FILES_COLLECTION_NAME
119+
if collection_name_override:
120+
collection_name = (
121+
f"{settings.QDRANT_BASE_COLLECTION_NAME}.{collection_name_override}"
122+
)
123+
117124
response = vector_search(
118125
query_text,
119126
limit=limit,
120127
offset=offset,
121128
params=request_data.data,
122-
search_collection=CONTENT_FILES_COLLECTION_NAME,
129+
search_collection=collection_name,
123130
)
124131
if request_data.data.get("dev_mode"):
125132
return Response(response)

vector_search/views_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,34 @@ def test_content_file_vector_search_filters_empty_query(mocker, client):
188188
),
189189
]
190190
)
191+
192+
193+
def test_content_file_vector_search_filters_custom_collection(mocker, client):
194+
"""Test content file vector search uses custom collection if specified"""
195+
196+
mock_qdrant = mocker.patch("qdrant_client.QdrantClient")
197+
custom_collection_name = "foo_bar_collection"
198+
mock_qdrant.scroll.return_value = [[]]
199+
mocker.patch(
200+
"vector_search.utils.qdrant_client",
201+
return_value=mock_qdrant,
202+
)
203+
mock_qdrant.count.return_value = CountResult(count=10)
204+
# omit the q param
205+
params = {
206+
"offered_by": ["ocw"],
207+
"platform": ["edx"],
208+
"key": ["testfilename.pdf"],
209+
"course_number": ["test"],
210+
"content_feature_type": ["test_feature"],
211+
"run_readable_id": ["test_run_id"],
212+
"resource_readable_id": ["test_resource_id_1", "test_resource_id_2"],
213+
"collection_name": custom_collection_name,
214+
}
215+
216+
client.get(reverse("vector_search:v0:vector_content_files_search"), data=params)
217+
assert (
218+
mock_qdrant.scroll.mock_calls[0]
219+
.kwargs["collection_name"]
220+
.endswith(custom_collection_name)
221+
)

0 commit comments

Comments
 (0)