Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions nilai-api/src/nilai_api/handlers/nilrag.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import numpy as np
import time
import sys

import nilql
import nilrag
Expand All @@ -20,6 +22,11 @@
"sentence-transformers/all-MiniLM-L6-v2", device="cpu"
) # FIXME: Use a GPU model and move to a separate container

def get_size_in_MB(obj):
return sys.getsizeof(obj) / (1024 * 1024)

def get_size_in_KB(obj):
return sys.getsizeof(obj) / 1024

def generate_embeddings_huggingface(
chunks_or_query: Union[str, list],
Expand Down Expand Up @@ -73,43 +80,64 @@ def handle_nilrag(req: ChatRequest):
nilDB = nilrag.NilDB(nodes)

# Initialize secret keys
start_time = time.time()
num_parties = len(nilDB.nodes)
additive_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"sum": True})
xor_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"store": True})
end_time = time.time()
secret_keys_initialization_time = round(end_time - start_time, 2)

# Step 2: Secret share query
logger.debug("Secret sharing query and sending to NilDB...")
# 2.1 Extract the user query
query = None
start_time = time.time()
for message in req.messages:
if message.role == "user":
query = message.content
break

if query is None:
raise HTTPException(status_code=400, detail="No user query found")
end_time = time.time()
extract_user_query_time = round(end_time - start_time, 2)

# 2.2 Generate query embeddings: one string query is assumed.
start_time = time.time()
query_embedding = generate_embeddings_huggingface([query])[0]
nilql_query_embedding = encrypt_float_list(additive_key, query_embedding)
end_time = time.time()
embedding_generation_time = round(end_time - start_time, 2)
query_size = round(get_size_in_KB(nilql_query_embedding),2)

# Step 3: Ask NilDB to compute the differences
logger.debug("Requesting computation from NilDB...")
start_time = time.time()
difference_shares = nilDB.diff_query_execute(nilql_query_embedding)
end_time = time.time()
asking_nilDB_time = round(end_time - start_time, 2)
difference_shares_size = round(get_size_in_KB(difference_shares),2)

# Step 4: Compute distances and sort
logger.debug("Compute distances and sort...")
# 4.1 Group difference shares by ID
start_time = time.time()
difference_shares_by_id = group_shares_by_id(
difference_shares, # type: ignore
lambda share: share["difference"],
)
end_time = time.time()
group_shares_by_id_time = round(end_time - start_time, 2)
# 4.2 Transpose the lists for each _id
start_time = time.time()
difference_shares_by_id = {
id: np.array(differences).T.tolist()
for id, differences in difference_shares_by_id.items()
}
end_time = time.time()
transpose_lists_time = round(end_time - start_time, 2)
# 4.3 Decrypt and compute distances
start_time = time.time()
reconstructed = [
{
"_id": id,
Expand All @@ -119,36 +147,55 @@ def handle_nilrag(req: ChatRequest):
}
for id, difference_shares in difference_shares_by_id.items()
]
end_time = time.time()
decryption_time = round(end_time - start_time, 2)

# 4.4 Sort id list based on the corresponding distances
start_time = time.time()
sorted_ids = sorted(reconstructed, key=lambda x: x["distances"])
end_time = time.time()
sort_id_list_time = round(end_time - start_time, 2)

# Step 5: Query the top k
logger.debug("Query top k chunks...")
top_k = 2
top_k_ids = [item["_id"] for item in sorted_ids[:top_k]]

# 5.1 Query top k
start_time = time.time()
chunk_shares = nilDB.chunk_query_execute(top_k_ids)

end_time = time.time()
query_top_chunks_time = round(end_time - start_time, 2)
chunks_shares_size = round(get_size_in_KB(chunk_shares), 2)
# 5.2 Group chunk shares by ID
start_time = time.time()
chunk_shares_by_id = group_shares_by_id(
chunk_shares, # type: ignore
lambda share: share["chunk"],
)
end_time = time.time()
group_chunks_time = round(end_time - start_time, 2)

# 5.3 Decrypt chunks
start_time = time.time()
top_results = [
{"_id": id, "distances": nilql.decrypt(xor_key, chunk_shares)}
for id, chunk_shares in chunk_shares_by_id.items()
]
end_time = time.time()
decrypt_chunks_time = round(end_time - start_time, 2)

# Step 6: Format top results
start_time = time.time()
formatted_results = "\n".join(
f"- {str(result['distances'])}" for result in top_results
)
relevant_context = f"\n\nRelevant Context:\n{formatted_results}"
end_time = time.time()
format_results_time = round(end_time - start_time, 2)

# Step 7: Update system message
start_time = time.time()
for message in req.messages:
if message.role == "system":
if message.content is None:
Expand All @@ -163,11 +210,30 @@ def handle_nilrag(req: ChatRequest):
else:
# If no system message exists, add one
req.messages.insert(0, Message(role="system", content=relevant_context))

end_time = time.time()
update_system_message_time = round(end_time - start_time, 2)
logger.debug(f"System message updated with relevant context:\n {req.messages}")
return {
"secret_keys_initialization_seconds": secret_keys_initialization_time,
"extract_user_query_seconds": extract_user_query_time,
"embedding_generation_seconds": embedding_generation_time,
"asking_nilDB_seconds": asking_nilDB_time,
"group_shares_by_id_seconds": group_shares_by_id_time,
"transpose_lists_seconds": transpose_lists_time,
"decryption_seconds": decryption_time,
"sort_id_list_seconds": sort_id_list_time,
"query_top_chunks_seconds": query_top_chunks_time,
"group_chunks_seconds": group_chunks_time,
"decrypt_chunks_seconds": decrypt_chunks_time,
"format_results_seconds": format_results_time,
"update_system_message_seconds": update_system_message_time,
"query_size_kbs": query_size,
"difference_shares_size_kbs": difference_shares_size,
"chunks_shares_size_kbs": chunks_shares_size,
}

except Exception as e:
logger.error("An error occurred within nilrag: %s", str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
)
7 changes: 5 additions & 2 deletions nilai-api/src/nilai_api/routers/private.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,10 @@ async def chat_completion(
f"Chat completion request for model {model_name} from user {user.userid} on url: {model_url}"
)

nilrag_metrics = {}
if req.nilrag:
handle_nilrag(req)
nilrag_metrics = handle_nilrag(req)
logger.info(f"nilRag metrics: {nilrag_metrics}")

if req.stream:
client = AsyncOpenAI(base_url=model_url, api_key="<not-needed>")
Expand Down Expand Up @@ -261,6 +263,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
model_response = SignedChatCompletion(
**response.model_dump(),
signature="",
metrics=nilrag_metrics,
)
if model_response.usage is None:
raise HTTPException(
Expand All @@ -286,4 +289,4 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
signature = sign_message(state.private_key, response_json)
model_response.signature = b64encode(signature).decode()

return model_response
return model_response
3 changes: 2 additions & 1 deletion packages/nilai-common/src/nilai_common/api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class ChatRequest(BaseModel):

class SignedChatCompletion(ChatCompletion):
signature: str
metrics: Optional[dict] = {}


class AttestationResponse(BaseModel):
Expand Down Expand Up @@ -69,4 +70,4 @@ class ModelEndpoint(BaseModel):

class HealthCheckResponse(BaseModel):
status: str
uptime: str
uptime: str
Loading