@@ -52,8 +52,8 @@ Libraries:
52
52
53
53
54
54
55
- ``` python
56
- ! pip install -- quiet datasets tqdm cohere pymongo
55
+ ``` sh
56
+ pip install --quiet datasets tqdm cohere pymongo
57
57
```
58
58
59
59
@@ -183,11 +183,11 @@ def combine_attributes(row):
183
183
combined = f " { row[' company' ]} { row[' sector' ]} "
184
184
185
185
# Add reports information
186
- for report in row[' reports' ]:
186
+ for report in row[" reports" ]:
187
187
combined += f " { report[' year' ]} { report[' title' ]} { report[' author' ]} { report[' content' ]} "
188
188
189
189
# Add recent news information
190
- for news in row[' recent_news' ]:
190
+ for news in row[" recent_news" ]:
191
191
combined += f " { news[' headline' ]} { news[' summary' ]} "
192
192
193
193
return combined.strip()
@@ -196,15 +196,15 @@ def combine_attributes(row):
196
196
197
197
``` python
198
198
# Add the new column 'combined_attributes'
199
- dataset_df[' combined_attributes' ] = dataset_df.apply(
199
+ dataset_df[" combined_attributes" ] = dataset_df.apply(
200
200
combine_attributes, axis = 1
201
201
)
202
202
```
203
203
204
204
205
205
``` python
206
206
# Display the first few rows of the updated dataframe
207
- dataset_df[[' company' , ' ticker' , ' combined_attributes' ]].head()
207
+ dataset_df[[" company" , " ticker" , " combined_attributes" ]].head()
208
208
```
209
209
210
210
<div >
@@ -270,7 +270,7 @@ def get_embedding(
270
270
texts = [text],
271
271
model = model,
272
272
input_type = input_type, # Used for embeddings of search queries run against a vector DB to find relevant documents
273
- embedding_types = [' float' ],
273
+ embedding_types = [" float" ],
274
274
)
275
275
276
276
return response.embeddings.float[0 ]
@@ -279,7 +279,7 @@ def get_embedding(
279
279
# Apply the embedding function with a progress bar
280
280
tqdm.pandas(desc = " Generating embeddings" )
281
281
dataset_df[" embedding" ] = dataset_df[
282
- ' combined_attributes'
282
+ " combined_attributes"
283
283
].progress_apply(get_embedding)
284
284
285
285
print (f " We just computed { len (dataset_df[' embedding' ])} embeddings. " )
@@ -421,8 +421,8 @@ def get_mongo_client(mongo_uri):
421
421
)
422
422
423
423
# Validate the connection
424
- ping_result = client.admin.command(' ping' )
425
- if ping_result.get(' ok ' ) == 1.0 :
424
+ ping_result = client.admin.command(" ping" )
425
+ if ping_result.get(" ok " ) == 1.0 :
426
426
# Connection successful
427
427
print (" Connection to MongoDB successful" )
428
428
return client
@@ -478,7 +478,7 @@ MongoDB's Document model and its compatibility with Python dictionaries offer se
478
478
![ ] ( ../../assets/images/rag-cohere-mongodb-4.png )
479
479
480
480
``` python
481
- documents = dataset_df.to_dict(' records' )
481
+ documents = dataset_df.to_dict(" records" )
482
482
collection.insert_many(documents)
483
483
484
484
print (" Data ingestion into MongoDB completed" )
@@ -592,13 +592,13 @@ def rerank_documents(query: str, documents, top_n: int = 3):
592
592
original_doc = documents[result.index]
593
593
top_documents_after_rerank.append(
594
594
{
595
- ' company' : original_doc[' company' ],
596
- ' combined_attributes' : original_doc[
597
- ' combined_attributes'
595
+ " company" : original_doc[" company" ],
596
+ " combined_attributes" : original_doc[
597
+ " combined_attributes"
598
598
],
599
- ' reports' : original_doc[' reports' ],
600
- ' vector_search_score' : original_doc[' score' ],
601
- ' relevance_score' : result.relevance_score,
599
+ " reports" : original_doc[" reports" ],
600
+ " vector_search_score" : original_doc[" score" ],
601
+ " relevance_score" : result.relevance_score,
602
602
}
603
603
)
604
604
@@ -724,9 +724,9 @@ pd.DataFrame(reranked_documents).head()
724
724
def format_documents_for_chat (documents ):
725
725
return [
726
726
{
727
- " company" : doc[' company' ],
727
+ " company" : doc[" company" ],
728
728
# "reports": doc['reports'],
729
- " combined_attributes" : doc[' combined_attributes' ],
729
+ " combined_attributes" : doc[" combined_attributes" ],
730
730
}
731
731
for doc in documents
732
732
]
@@ -825,7 +825,7 @@ class CohereChat:
825
825
# Use the connection string from history_params
826
826
self .client = pymongo.MongoClient(
827
827
self .history_params.get(
828
- ' connection_string' , ' mongodb://localhost:27017/'
828
+ " connection_string" , " mongodb://localhost:27017/"
829
829
)
830
830
)
831
831
@@ -838,34 +838,34 @@ class CohereChat:
838
838
# Use the history_collection from history_params, or default to "chat_history"
839
839
self .history_collection = self .db[
840
840
self .history_params.get(
841
- ' history_collection' , ' chat_history'
841
+ " history_collection" , " chat_history"
842
842
)
843
843
]
844
844
845
845
# Use the session_id from history_params, or default to "default_session"
846
846
self .session_id = self .history_params.get(
847
- ' session_id' , ' default_session'
847
+ " session_id" , " default_session"
848
848
)
849
849
850
850
def add_to_history (self , message : str , prefix : str = " " ):
851
851
self .history_collection.insert_one(
852
852
{
853
- ' session_id' : self .session_id,
854
- ' message' : message,
855
- ' prefix' : prefix,
853
+ " session_id" : self .session_id,
854
+ " message" : message,
855
+ " prefix" : prefix,
856
856
}
857
857
)
858
858
859
859
def get_chat_history (self ) -> List[Dict[str , str ]]:
860
860
history = self .history_collection.find(
861
- {' session_id' : self .session_id}
862
- ).sort(' _id' , 1 )
861
+ {" session_id" : self .session_id}
862
+ ).sort(" _id" , 1 )
863
863
return [
864
864
{
865
865
" role" : (
866
- " user" if item[' prefix' ] == " USER" else " chatbot"
866
+ " user" if item[" prefix" ] == " USER" else " chatbot"
867
867
),
868
- " message" : item[' message' ],
868
+ " message" : item[" message" ],
869
869
}
870
870
for item in history
871
871
]
@@ -875,11 +875,11 @@ class CohereChat:
875
875
) -> List[Dict]:
876
876
rerank_docs = [
877
877
{
878
- ' company' : doc[' company' ],
879
- ' combined_attributes' : doc[' combined_attributes' ],
878
+ " company" : doc[" company" ],
879
+ " combined_attributes" : doc[" combined_attributes" ],
880
880
}
881
881
for doc in documents
882
- if doc[' combined_attributes' ].strip()
882
+ if doc[" combined_attributes" ].strip()
883
883
]
884
884
885
885
if not rerank_docs:
@@ -897,11 +897,11 @@ class CohereChat:
897
897
898
898
top_documents_after_rerank = [
899
899
{
900
- ' company' : rerank_docs[result.index][' company' ],
901
- ' combined_attributes' : rerank_docs[result.index][
902
- ' combined_attributes'
900
+ " company" : rerank_docs[result.index][" company" ],
901
+ " combined_attributes" : rerank_docs[result.index][
902
+ " combined_attributes"
903
903
],
904
- ' relevance_score' : result.relevance_score,
904
+ " relevance_score" : result.relevance_score,
905
905
}
906
906
for result in response.results
907
907
]
@@ -925,8 +925,8 @@ class CohereChat:
925
925
) -> List[Dict]:
926
926
return [
927
927
{
928
- " company" : doc[' company' ],
929
- " combined_attributes" : doc[' combined_attributes' ],
928
+ " company" : doc[" company" ],
929
+ " combined_attributes" : doc[" combined_attributes" ],
930
930
}
931
931
for doc in documents
932
932
]
@@ -972,8 +972,8 @@ class CohereChat:
972
972
973
973
def show_history (self ):
974
974
history = self .history_collection.find(
975
- {' session_id' : self .session_id}
976
- ).sort(' _id' , 1 )
975
+ {" session_id" : self .session_id}
976
+ ).sort(" _id" , 1 )
977
977
for item in history:
978
978
print (f " { item[' prefix' ]} : { item[' message' ]} " )
979
979
print (" -------------------------" )
@@ -988,9 +988,9 @@ chat = CohereChat(
988
988
database = DB_NAME ,
989
989
main_collection = COLLECTION_NAME ,
990
990
history_params = {
991
- ' connection_string' : MONGO_URI ,
992
- ' history_collection' : " chat_history" ,
993
- ' session_id' : 2 ,
991
+ " connection_string" : MONGO_URI ,
992
+ " history_collection" : " chat_history" ,
993
+ " session_id" : 2 ,
994
994
},
995
995
)
996
996
0 commit comments