Skip to content

Commit

Permalink
Avoid R2RException for Server Errors (#1560)
Browse files Browse the repository at this point in the history
* Clean up exceptions

* Bump version
  • Loading branch information
NolanTrem authored Nov 6, 2024
1 parent 1467be8 commit 22c0e26
Show file tree
Hide file tree
Showing 17 changed files with 107 additions and 77 deletions.
19 changes: 14 additions & 5 deletions py/core/main/api/ingestion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
from uuid import UUID

import yaml
from fastapi import Body, Depends, File, Form, Path, Query, UploadFile
from fastapi import (
Body,
Depends,
File,
Form,
Path,
Query,
UploadFile,
HTTPException,
)
from pydantic import Json

from core.base import R2RException, RawChunk, Workflow, generate_document_id
Expand Down Expand Up @@ -484,9 +493,9 @@ async def update_document_metadata_app(
}

except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error updating document metadata: {str(e)}",
detail=f"Error updating document metadata: {str(e)}",
)

@self.router.put(
Expand Down Expand Up @@ -548,8 +557,8 @@ async def update_chunk_app(
}

except Exception as e:
raise R2RException(
status_code=500, message=f"Error updating chunk: {str(e)}"
raise HTTPException(
status_code=500, detail=f"Error updating chunk: {str(e)}"
)

create_vector_index_extras = self.openapi_extras.get(
Expand Down
13 changes: 7 additions & 6 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import HTTPException

from hatchet_sdk import ConcurrencyLimitStrategy, Context
from litellm import AuthenticationError
Expand Down Expand Up @@ -246,9 +247,9 @@ async def parse(self, context: Context) -> dict:
message="Authentication error: Invalid API key or credentials.",
)
except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during ingestion: {str(e)}",
detail=f"Error during ingestion: {str(e)}",
)

@orchestration_provider.failure()
Expand Down Expand Up @@ -605,9 +606,9 @@ async def update_chunk(self, context: Context) -> dict:
}

except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during chunk update: {str(e)}",
detail=f"Error during chunk update: {str(e)}",
)

@orchestration_provider.failure()
Expand Down Expand Up @@ -692,9 +693,9 @@ async def update_document_metadata(self, context: Context) -> dict:
}

except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during document metadata update: {str(e)}",
detail=f"Error during document metadata update: {str(e)}",
)

@orchestration_provider.failure()
Expand Down
25 changes: 13 additions & 12 deletions py/core/main/orchestration/simple/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from litellm import AuthenticationError

from fastapi import HTTPException
from core.base import DocumentExtraction, R2RException, increment_version
from core.utils import (
generate_default_user_collection_id,
Expand Down Expand Up @@ -123,8 +124,8 @@ async def ingest_files(input_data):
await service.update_document_status(
document_info, status=IngestionStatus.FAILED
)
raise R2RException(
status_code=500, message=f"Error during ingestion: {str(e)}"
raise HTTPException(
status_code=500, detail=f"Error during ingestion: {str(e)}"
)

async def update_files(input_data):
Expand Down Expand Up @@ -302,9 +303,9 @@ async def ingest_chunks(input_data):
await service.update_document_status(
document_info, status=IngestionStatus.FAILED
)
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during chunk ingestion: {str(e)}",
detail=f"Error during chunk ingestion: {str(e)}",
)

async def update_chunk(input_data):
Expand Down Expand Up @@ -335,9 +336,9 @@ async def update_chunk(input_data):
)

except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during chunk update: {str(e)}",
detail=f"Error during chunk update: {str(e)}",
)

async def create_vector_index(input_data):
Expand All @@ -354,9 +355,9 @@ async def create_vector_index(input_data):
await service.providers.database.create_index(**parsed_data)

except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during vector index creation: {str(e)}",
detail=f"Error during vector index creation: {str(e)}",
)

async def delete_vector_index(input_data):
Expand All @@ -374,9 +375,9 @@ async def delete_vector_index(input_data):
return {"status": "Vector index deleted successfully."}

except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during vector index deletion: {str(e)}",
detail=f"Error during vector index deletion: {str(e)}",
)

async def update_document_metadata(input_data):
Expand Down Expand Up @@ -406,9 +407,9 @@ async def update_document_metadata(input_data):
}

except Exception as e:
raise R2RException(
raise HTTPException(
status_code=500,
message=f"Error during document metadata update: {str(e)}",
detail=f"Error during document metadata update: {str(e)}",
)

return {
Expand Down
5 changes: 3 additions & 2 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from typing import Any, AsyncGenerator, Optional, Sequence, Union
from uuid import UUID
from fastapi import HTTPException

from core.base import (
Document,
Expand Down Expand Up @@ -135,8 +136,8 @@ async def ingest_file_ingress(
logger.error(f"R2RException in ingest_file_ingress: {str(e)}")
raise
except Exception as e:
raise R2RException(
status_code=500, message=f"Error during ingestion: {str(e)}"
raise HTTPException(
status_code=500, detail=f"Error during ingestion: {str(e)}"
)

def _create_document_info_from_file(
Expand Down
5 changes: 3 additions & 2 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from typing import AsyncGenerator, Optional
from uuid import UUID
from fastapi import HTTPException

from core.base import KGExtractionStatus, RunManager
from core.base.abstractions import (
Expand Down Expand Up @@ -512,9 +513,9 @@ async def tune_prompt(
results.append(result)

if not results:
raise R2RException(
message="No results generated from prompt tuning",
raise HTTPException(
status_code=500,
detail="No results generated from prompt tuning",
)

return results[0]
18 changes: 10 additions & 8 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from typing import Optional
from uuid import UUID
from fastapi import HTTPException

from core import R2RStreamingRAGAgent
from core.base import (
Expand Down Expand Up @@ -203,12 +204,12 @@ async def rag(
except Exception as e:
logger.error(f"Pipeline error: {str(e)}")
if "NoneType" in str(e):
raise R2RException(
raise HTTPException(
status_code=502,
message="Remote server not reachable or returned an invalid response",
detail="Remote server not reachable or returned an invalid response",
) from e
raise R2RException(
status_code=500, message="Internal Server Error"
raise HTTPException(
status_code=500, detail="Internal Server Error"
) from e

async def stream_rag_response(
Expand Down Expand Up @@ -398,12 +399,13 @@ async def stream_response():
except Exception as e:
logger.error(f"Pipeline error: {str(e)}")
if "NoneType" in str(e):
raise R2RException(
raise HTTPException(
status_code=502,
message="Server not reachable or returned an invalid response",
detail="Server not reachable or returned an invalid response",
)
raise R2RException(
status_code=500, message="Internal Server Error"
raise HTTPException(
status_code=500,
detail="Internal Server Error",
)


Expand Down
9 changes: 5 additions & 4 deletions py/core/pipes/kg/deduplication.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging
from typing import Any, Optional, Union
from typing import Any, Union
from uuid import UUID
from fastapi import HTTPException

from core.base import AsyncState, R2RException
from core.base import AsyncState
from core.base.abstractions import Entity, KGEntityDeduplicationType
from core.base.pipes import AsyncPipe
from core.providers import (
Expand Down Expand Up @@ -133,9 +134,9 @@ async def kg_named_entity_deduplication(
logger.error(
f"KGEntityDeduplicationPipe: Error in entity deduplication: {str(e)}"
)
raise R2RException(
message=f"KGEntityDeduplicationPipe: Error deduplicating entities: {str(e)}",
raise HTTPException(
status_code=500,
detail=f"KGEntityDeduplicationPipe: Error deduplicating entities: {str(e)}",
)

async def kg_description_entity_deduplication(
Expand Down
14 changes: 8 additions & 6 deletions py/core/pipes/kg/prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
"""

import logging
from typing import Any, Optional
from typing import Any
from uuid import UUID
from fastapi import HTTPException

from core.base import (
AsyncState,
Expand Down Expand Up @@ -76,14 +77,15 @@ async def _run_logic(
)

if not tuned_prompt:
raise R2RException(
message="Failed to generate tuned prompt", status_code=500
raise HTTPException(
status_code=500,
detail="Failed to generate tuned prompt",
)

yield {"tuned_prompt": tuned_prompt.choices[0].message.content}

except Exception as e:
logger.error(f"Error in prompt tuning: {str(e)}")
raise R2RException(
message=f"Error tuning prompt: {str(e)}", status_code=500
raise HTTPException(
status_code=500,
detail=f"Error tuning prompt: {str(e)}",
)
29 changes: 16 additions & 13 deletions py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Dict
from fastapi import HTTPException

import jwt
from fastapi import Depends
Expand Down Expand Up @@ -192,7 +192,7 @@ async def verify_email(
)
return {"message": "Email verified successfully"}

async def login(self, email: str, password: str) -> Dict[str, Token]:
async def login(self, email: str, password: str) -> dict[str, Token]:
logger = logging.getLogger()
logger.debug(f"Attempting login for email: {email}")

Expand All @@ -209,8 +209,9 @@ async def login(self, email: str, password: str) -> Dict[str, Token]:
logger.error(
f"Invalid hashed_password type: {type(user.hashed_password)}"
)
raise R2RException(
status_code=500, message="Invalid password hash in database"
raise HTTPException(
status_code=500,
detail="Invalid password hash in database",
)

try:
Expand All @@ -219,8 +220,9 @@ async def login(self, email: str, password: str) -> Dict[str, Token]:
)
except Exception as e:
logger.error(f"Error during password verification: {str(e)}")
raise R2RException(
status_code=500, message="Error during password verification"
raise HTTPException(
status_code=500,
detail="Error during password verification",
) from e

if not password_verified:
Expand All @@ -242,7 +244,7 @@ async def login(self, email: str, password: str) -> Dict[str, Token]:

async def refresh_access_token(
self, refresh_token: str
) -> Dict[str, Token]:
) -> dict[str, Token]:
token_data = await self.decode_token(refresh_token)
if token_data.token_type != "refresh":
raise R2RException(
Expand All @@ -267,13 +269,14 @@ async def refresh_access_token(

async def change_password(
self, user: UserResponse, current_password: str, new_password: str
) -> Dict[str, str]:
) -> dict[str, str]:
if not isinstance(user.hashed_password, str):
logger.error(
f"Invalid hashed_password type: {type(user.hashed_password)}"
)
raise R2RException(
status_code=500, message="Invalid password hash in database"
raise HTTPException(
status_code=500,
detail="Invalid password hash in database",
)

if not self.crypto_provider.verify_password(
Expand All @@ -291,7 +294,7 @@ async def change_password(
)
return {"message": "Password changed successfully"}

async def request_password_reset(self, email: str) -> Dict[str, str]:
async def request_password_reset(self, email: str) -> dict[str, str]:
user = await self.database_provider.get_user_by_email(email)
if not user:
# To prevent email enumeration, always return a success message
Expand All @@ -312,7 +315,7 @@ async def request_password_reset(self, email: str) -> Dict[str, str]:

async def confirm_password_reset(
self, reset_token: str, new_password: str
) -> Dict[str, str]:
) -> dict[str, str]:
user_id = await self.database_provider.get_user_id_by_reset_token(
reset_token
)
Expand All @@ -330,7 +333,7 @@ async def confirm_password_reset(
await self.database_provider.remove_reset_token(user_id)
return {"message": "Password reset successfully"}

async def logout(self, token: str) -> Dict[str, str]:
async def logout(self, token: str) -> dict[str, str]:
# Add the token to a blacklist
await self.database_provider.blacklist_token(token)
return {"message": "Logged out successfully"}
Expand Down
Loading

0 comments on commit 22c0e26

Please sign in to comment.