Skip to content

Commit 14f172c

Browse files
committed
code
1 parent 55495e7 commit 14f172c

File tree

3 files changed

+243
-41
lines changed

3 files changed

+243
-41
lines changed

core/bin/qdrant_migrator.rs

Lines changed: 199 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@ use clap::{Parser, Subcommand};
55
use dust::{
66
data_sources::qdrant::{QdrantClients, QdrantCluster, QdrantDataSourceConfig},
77
project,
8+
run::Credentials,
89
stores::postgres,
910
stores::store,
1011
utils,
1112
};
13+
use qdrant_client::{
14+
prelude::Payload,
15+
qdrant::{self, PointId, ScrollPoints},
16+
};
1217

1318
#[derive(Debug, Subcommand)]
1419
enum Commands {
@@ -19,14 +24,16 @@ enum Commands {
1924
data_source_id: String,
2025
},
2126
#[command(arg_required_else_help = true)]
22-
#[command(about = "Set `shadow_write_cluster` (!!! creates collection on `shadow_write_cluster`)", long_about = None)]
27+
#[command(about = "Set `shadow_write_cluster` \
28+
(!!! creates collection on `shadow_write_cluster`)", long_about = None)]
2329
SetShadowWrite {
2430
project_id: i64,
2531
data_source_id: String,
2632
cluster: String,
2733
},
2834
#[command(arg_required_else_help = true)]
29-
#[command(about = "Clear `shadow_write_cluster` (!!! deletes collection from `shadow_write_cluster`)", long_about = None)]
35+
#[command(about = "Clear `shadow_write_cluster` \
36+
(!!! deletes collection from `shadow_write_cluster`)", long_about = None)]
3037
ClearShadowWrite {
3138
project_id: i64,
3239
data_source_id: String,
@@ -38,7 +45,8 @@ enum Commands {
3845
data_source_id: String,
3946
},
4047
#[command(arg_required_else_help = true)]
41-
#[command(about = "Switch `shadow_write_cluster` and `cluster` (!!! moves read traffic to `shadow_write_cluster`)", long_about = None)]
48+
#[command(about = "Switch `shadow_write_cluster` and `cluster` \
49+
(!!! moves read traffic to `shadow_write_cluster`)", long_about = None)]
4250
CommitShadowWrite {
4351
project_id: i64,
4452
data_source_id: String,
@@ -102,6 +110,53 @@ fn main() -> Result<()> {
102110
}
103111
));
104112

113+
let qdrant_client = qdrant_clients.main_client(&ds.config().qdrant_config);
114+
match qdrant_client
115+
.collection_info(ds.qdrant_collection())
116+
.await?
117+
.result
118+
{
119+
Some(info) => {
120+
utils::info(&format!(
121+
"[MAIN] Qdrant collection: cluster={} collection={} status={} \
122+
points_count={}",
123+
qdrant_clients
124+
.main_cluster(&ds.config().qdrant_config)
125+
.to_string(),
126+
ds.qdrant_collection(),
127+
info.status.to_string(),
128+
info.points_count,
129+
));
130+
}
131+
None => Err(anyhow!("Qdrant collection not found"))?,
132+
}
133+
134+
match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) {
135+
Some(shadow_write_cluster) => {
136+
let shadow_write_qdrant_client = qdrant_clients
137+
.shadow_write_client(&ds.config().qdrant_config)
138+
.unwrap();
139+
match shadow_write_qdrant_client
140+
.collection_info(ds.qdrant_collection())
141+
.await?
142+
.result
143+
{
144+
Some(info) => {
145+
utils::info(&format!(
146+
"[SHADOW] Qdrant collection: cluster={} collection={} status={}\
147+
points_count={}",
148+
shadow_write_cluster.to_string(),
149+
ds.qdrant_collection(),
150+
info.status.to_string(),
151+
info.points_count,
152+
));
153+
}
154+
None => Err(anyhow!("Qdrant collection not found"))?,
155+
}
156+
}
157+
None => (),
158+
};
159+
105160
Ok::<(), anyhow::Error>(())
106161
}
107162
Commands::SetShadowWrite {
@@ -128,11 +183,36 @@ fn main() -> Result<()> {
128183
}),
129184
};
130185

131-
// TODO(spolu): Create collection on shadow_write_cluster
186+
// Create collection on shadow_write_cluster.
187+
let shadow_write_qdrant_client =
188+
match qdrant_clients.shadow_write_client(&config.qdrant_config) {
189+
Some(client) => client,
190+
None => unreachable!(),
191+
};
192+
193+
// We send fake credentials here since they are not really used for OpenAI to get
194+
// the embedding size (which is what happens here). May need to be revisited in
195+
// the future.
196+
let mut credentials = Credentials::new();
197+
credentials.insert("OPENAI_API_KEY".to_string(), "foo".to_string());
198+
199+
ds.create_qdrant_collection(credentials, shadow_write_qdrant_client.clone())
200+
.await?;
201+
202+
utils::done(&format!(
203+
"Created qdrant shadow_write_cluster collection: \
204+
collection={} shadow_write_cluster={}",
205+
ds.qdrant_collection(),
206+
match qdrant_clients.shadow_write_cluster(&config.qdrant_config) {
207+
Some(cluster) => cluster.to_string(),
208+
None => "none".to_string(),
209+
}
210+
));
132211

212+
// Add shadow_write_cluster to config.
133213
ds.update_config(store, &config).await?;
134214

135-
utils::info(&format!(
215+
utils::done(&format!(
136216
"Updated data source: collection={} cluster={} shadow_write_cluster={}",
137217
ds.qdrant_collection(),
138218
qdrant_clients
@@ -150,15 +230,62 @@ fn main() -> Result<()> {
150230
project_id,
151231
data_source_id,
152232
} => {
233+
// This is the most dangerous command of all as it is the only one to actually
234+
// delete data in an unrecoverable way.
153235
let project = project::Project::new_from_id(project_id);
154236
let mut ds = match store.load_data_source(&project, &data_source_id).await? {
155237
Some(ds) => ds,
156238
None => Err(anyhow!("Data source not found"))?,
157239
};
158240

159-
let mut config = ds.config().clone();
241+
// Delete collection on shadow_write_cluster.
242+
let shadow_write_qdrant_client =
243+
match qdrant_clients.shadow_write_client(&ds.config().qdrant_config) {
244+
Some(client) => client,
245+
None => Err(anyhow!("No shadow write cluster to clear"))?,
246+
};
247+
248+
match shadow_write_qdrant_client
249+
.collection_info(ds.qdrant_collection())
250+
.await?
251+
.result
252+
{
253+
Some(info) => {
254+
// confirm
255+
match utils::confirm(&format!(
256+
"[DANGER] Are you sure you want to delete this qdrant \
257+
shadow_write_cluster collection? \
258+
(this is definitive) shadow_write_cluster={} points_count={}",
259+
match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) {
260+
Some(cluster) => cluster.to_string(),
261+
None => "none".to_string(),
262+
}
263+
.to_string(),
264+
info.points_count,
265+
))? {
266+
true => (),
267+
false => Err(anyhow!("Aborted"))?,
268+
}
269+
}
270+
None => Err(anyhow!("Qdrant collection not found"))?,
271+
};
272+
273+
shadow_write_qdrant_client
274+
.delete_collection(ds.qdrant_collection())
275+
.await?;
276+
277+
utils::done(&format!(
278+
"Deleted qdrant shadow_write_cluster collection: \
279+
collection={} shadow_write_cluster={}",
280+
ds.qdrant_collection(),
281+
match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) {
282+
Some(cluster) => cluster.to_string(),
283+
None => "none".to_string(),
284+
}
285+
));
160286

161-
// TODO(spolu): delete collection from shadow_write_cluster
287+
// Remove shadow_write_cluster from config.
288+
let mut config = ds.config().clone();
162289

163290
config.qdrant_config = match config.qdrant_config {
164291
Some(c) => Some(QdrantDataSourceConfig {
@@ -173,7 +300,7 @@ fn main() -> Result<()> {
173300

174301
ds.update_config(store, &config).await?;
175302

176-
utils::info(&format!(
303+
utils::done(&format!(
177304
"Updated data source: collection={} cluster={} shadow_write_cluster={}",
178305
ds.qdrant_collection(),
179306
qdrant_clients
@@ -190,7 +317,69 @@ fn main() -> Result<()> {
190317
Commands::MigrateShadowWrite {
191318
project_id,
192319
data_source_id,
193-
} => Ok::<(), anyhow::Error>(()),
320+
} => {
321+
// Copies every point (vectors and payload) from the main cluster collection to the
322+
// shadow_write_cluster collection by upserting them; nothing is deleted here.
323+
let project = project::Project::new_from_id(project_id);
324+
let ds = match store.load_data_source(&project, &data_source_id).await? {
325+
Some(ds) => ds,
326+
None => Err(anyhow!("Data source not found"))?,
327+
};
328+
329+
let qdrant_client = qdrant_clients.main_client(&ds.config().qdrant_config);
330+
331+
// Retrieve the shadow_write_cluster client (the migration target).
332+
let shadow_write_qdrant_client =
333+
match qdrant_clients.shadow_write_client(&ds.config().qdrant_config) {
334+
Some(client) => client,
335+
None => Err(anyhow!("No shadow write cluster to migrate to"))?,
336+
};
337+
338+
let mut page_offset: Option<PointId> = None;
339+
let mut total: usize = 0;
340+
loop {
341+
let scroll_results = qdrant_client
342+
.scroll(&ScrollPoints {
343+
collection_name: ds.qdrant_collection(),
344+
with_vectors: Some(true.into()),
345+
with_payload: Some(true.into()),
346+
limit: Some(256),
347+
offset: page_offset,
348+
..Default::default()
349+
})
350+
.await?;
351+
352+
let count = scroll_results.result.len();
353+
354+
let points = scroll_results
355+
.result
356+
.into_iter()
357+
.map(|r| {
358+
qdrant::PointStruct::new(
359+
r.id.unwrap(),
360+
r.vectors.unwrap(),
361+
Payload::new_from_hashmap(r.payload),
362+
)
363+
})
364+
.collect::<Vec<_>>();
365+
366+
shadow_write_qdrant_client
367+
.upsert_points(ds.qdrant_collection(), points, None)
368+
.await?;
369+
370+
total += count;
371+
utils::info(&format!("Migrated points: count={} total={}", count, total));
372+
373+
page_offset = scroll_results.next_page_offset;
374+
if page_offset.is_none() {
375+
break;
376+
}
377+
}
378+
379+
utils::info(&format!("Done migrating: total={}", total));
380+
381+
Ok::<(), anyhow::Error>(())
382+
}
194383
Commands::CommitShadowWrite {
195384
project_id,
196385
data_source_id,
@@ -207,7 +396,7 @@ fn main() -> Result<()> {
207396
Some(c) => match c.shadow_write_cluster {
208397
Some(cluster) => Some(QdrantDataSourceConfig {
209398
cluster: cluster,
210-
shadow_write_cluster: None,
399+
shadow_write_cluster: Some(c.cluster),
211400
}),
212401
None => Err(anyhow!("No shadow write cluster to commit"))?,
213402
},

core/src/data_sources/data_source.rs

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -312,39 +312,14 @@ impl DataSource {
312312
Ok(())
313313
}
314314

315-
pub async fn setup(
315+
pub async fn create_qdrant_collection(
316316
&self,
317317
credentials: Credentials,
318-
qdrant_clients: QdrantClients,
318+
qdrant_client: Arc<QdrantClient>,
319319
) -> Result<()> {
320-
let qdrant_client = qdrant_clients.main_client(&self.config.qdrant_config);
321-
322320
let mut embedder = provider(self.config.provider_id).embedder(self.config.model_id.clone());
323321
embedder.initialize(credentials).await?;
324322

325-
// GCP store created data to test GCP.
326-
let bucket = match std::env::var("DUST_DATA_SOURCES_BUCKET") {
327-
Ok(bucket) => bucket,
328-
Err(_) => Err(anyhow!("DUST_DATA_SOURCES_BUCKET is not set"))?,
329-
};
330-
331-
let bucket_path = format!("{}/{}", self.project.project_id(), self.internal_id);
332-
let data_source_created_path = format!("{}/created.txt", bucket_path);
333-
334-
Object::create(
335-
&bucket,
336-
format!("{}", self.created).as_bytes().to_vec(),
337-
&data_source_created_path,
338-
"application/text",
339-
)
340-
.await?;
341-
342-
utils::done(&format!(
343-
"Created GCP bucket for data_source `{}`",
344-
self.data_source_id
345-
));
346-
347-
// Qdrant create collection.
348323
qdrant_client
349324
.create_collection(&qdrant::CreateCollection {
350325
collection_name: self.qdrant_collection(),
@@ -381,6 +356,41 @@ impl DataSource {
381356
..Default::default()
382357
})
383358
.await?;
359+
Ok(())
360+
}
361+
362+
pub async fn setup(
363+
&self,
364+
credentials: Credentials,
365+
qdrant_clients: QdrantClients,
366+
) -> Result<()> {
367+
let qdrant_client = qdrant_clients.main_client(&self.config.qdrant_config);
368+
369+
// GCP store created data to test GCP.
370+
let bucket = match std::env::var("DUST_DATA_SOURCES_BUCKET") {
371+
Ok(bucket) => bucket,
372+
Err(_) => Err(anyhow!("DUST_DATA_SOURCES_BUCKET is not set"))?,
373+
};
374+
375+
let bucket_path = format!("{}/{}", self.project.project_id(), self.internal_id);
376+
let data_source_created_path = format!("{}/created.txt", bucket_path);
377+
378+
Object::create(
379+
&bucket,
380+
format!("{}", self.created).as_bytes().to_vec(),
381+
&data_source_created_path,
382+
"application/text",
383+
)
384+
.await?;
385+
386+
utils::done(&format!(
387+
"Created GCP bucket for data_source `{}`",
388+
self.data_source_id
389+
));
390+
391+
// Qdrant create collection.
392+
self.create_qdrant_collection(credentials, qdrant_client.clone())
393+
.await?;
384394

385395
let _ = qdrant_client
386396
.create_field_index(

0 commit comments

Comments
 (0)