-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathembeddings_model.py
82 lines (63 loc) · 2.7 KB
/
embeddings_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, PyPDFDirectoryLoader
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.schema import Document
from langchain.vectorstores import FAISS
import boto3
import pullpdfs
import os
import json
from sagemaker.jumpstart.model import JumpStartModel
from typing import Dict, List
import config
import streamlit as st
# SageMaker endpoint name for the embeddings model, read from the project's
# config module; build() passes it as `endpoint_name` to SagemakerEndpointEmbeddings.
embedding_model_endpoint_name = config.aws_sagemaker_embeddings_model_endpoint
class CustomEmbeddingsContentHandler(EmbeddingsContentHandler):
    """Translate between Python values and the JSON wire format of the embeddings endpoint.

    Requests are sent as a JSON object carrying the texts under ``"text_inputs"``;
    responses are read as a JSON object carrying the vectors under ``"embedding"``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        """Encode the texts plus any extra model arguments as a UTF-8 JSON body.

        Args:
            inputs: Texts to embed.
            model_kwargs: Extra key/value pairs merged into the top level of the
                payload. Because they are merged last, a ``"text_inputs"`` key
                here would override ``inputs``.

        Returns:
            The JSON payload encoded as UTF-8 bytes.
        """
        payload = {"text_inputs": inputs, **model_kwargs}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """Decode the endpoint response and return its ``"embedding"`` entry.

        Args:
            output: The response body. Either a readable stream (anything with
                a ``read()`` method returning bytes) or the raw bytes themselves.

        Returns:
            The value stored under ``"embedding"`` in the decoded JSON.

        Raises:
            KeyError: If the decoded JSON has no ``"embedding"`` key.
        """
        # The annotation says bytes, but the original code only worked on a
        # stream-like object; accept both so the method honours its signature.
        raw = output.read() if hasattr(output, "read") else output
        response_json = json.loads(raw.decode("utf-8"))
        return response_json["embedding"]
@st.cache_data
def makeChunks(filenames, metadata, data_root):
    """Load every PDF, stamp its pages with the caller's metadata, and split into chunks.

    The function is wrapped in ``st.cache_data``.

    Args:
        filenames: PDF file names; each one is appended directly to ``data_root``.
        metadata: Sequence indexed in parallel with ``filenames``. Entry ``i``
            replaces the metadata of every fragment loaded from file ``i``.
        data_root: Prefix placed in front of each file name with plain string
            concatenation (no separator is inserted).

    Returns:
        The chunks produced by a ``RecursiveCharacterTextSplitter`` configured
        with a chunk size of 512 and an overlap of 100.
    """
    print("Splitting retrieved documents into chunks.")

    pages = []
    for position, pdf_name in enumerate(filenames):
        loaded = PyPDFLoader(data_root + pdf_name).load()
        for fragment in loaded:
            # Replaces (does not merge with) whatever metadata the loader set.
            fragment.metadata = metadata[position]
        pages.extend(loaded)

    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=100)
    chunks = splitter.split_documents(pages)

    print("\nDocumens split into chunks.")
    print(f'# of document pages {len(pages)}')
    print(f'# of document chunks: {len(chunks)}')
    return chunks
@st.cache_resource
def build():
    """Pull the PDFs, chunk them, and index the chunks in a FAISS vector store.

    The function is wrapped in ``st.cache_resource``. Embeddings are produced
    through ``SagemakerEndpointEmbeddings`` using the module-level endpoint
    name, the region of a newly created ``boto3.Session``, and a
    ``CustomEmbeddingsContentHandler``.

    Returns:
        The store returned by ``FAISS.from_documents`` for the chunked documents.
    """
    print("Building embedding model.")

    pdf_names, pdf_metadata, pdf_root = pullpdfs.pull()
    chunks = makeChunks(pdf_names, pdf_metadata, pdf_root)

    handler = CustomEmbeddingsContentHandler()
    region = boto3.Session().region_name
    embedder = SagemakerEndpointEmbeddings(
        endpoint_name=embedding_model_endpoint_name,
        region_name=region,
        content_handler=handler,
    )

    vector_store = FAISS.from_documents(chunks, embedder)
    print("Embedding model completed and vector DB completed")
    return vector_store
def test_db(db):
    """Run one fixed similarity query against ``db`` and print every hit.

    Args:
        db: Object exposing ``similarity_search_with_score(query)`` that yields
            ``(document, score)`` pairs, where each document has
            ``page_content`` and ``metadata`` attributes.
    """
    print("Testing vector DB functionality.")
    hits = db.similarity_search_with_score("Why is Amazon successful?")
    for hit, hit_score in hits:
        print(f"Content: {hit.page_content}\nMetadata: {hit.metadata}\nScore: {hit_score}\n\n")