Skip to content

Commit 41d4705

Browse files
authored
Merge pull request #19 from BrainDriveAI/feature/personas
Feature: Persona System Integration
2 parents 0edfdb4 + 4fa67b2 commit 41d4705

29 files changed

+2982
-26
lines changed

backend/app/api/v1/api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from fastapi import APIRouter
2-
from app.api.v1.endpoints import auth, settings, ollama, ai_providers, ai_provider_settings, navigation_routes, components, conversations, tags
2+
from app.api.v1.endpoints import auth, settings, ollama, ai_providers, ai_provider_settings, navigation_routes, components, conversations, tags, personas
33
from app.routers import plugins
44
from app.routes.pages import router as pages_router
55

@@ -13,6 +13,7 @@
1313
api_router.include_router(components.router, prefix="/components", tags=["components"])
1414
api_router.include_router(conversations.router, tags=["conversations"])
1515
api_router.include_router(tags.router, tags=["tags"])
16+
api_router.include_router(personas.router, tags=["personas"])
1617
# Include the plugins router (which already includes the lifecycle router)
1718
api_router.include_router(plugins.router, tags=["plugins"])
18-
api_router.include_router(pages_router)
19+
api_router.include_router(pages_router)

backend/app/api/v1/endpoints/ai_providers.py

Lines changed: 137 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,32 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
393393
logger.debug(f"Messages: {request.messages}")
394394
logger.debug(f"Params: {request.params}")
395395

396+
# Validate persona data if provided
397+
if request.persona_id or request.persona_system_prompt or request.persona_model_settings:
398+
logger.info(f"Persona data provided - persona_id: {request.persona_id}")
399+
400+
# Basic validation: if persona_id is provided, persona_system_prompt should also be provided
401+
if request.persona_id and not request.persona_system_prompt:
402+
logger.error(f"Invalid persona data: persona_id provided but persona_system_prompt is missing")
403+
raise HTTPException(
404+
status_code=400,
405+
detail="Invalid persona data: persona_system_prompt is required when persona_id is provided"
406+
)
407+
408+
# Validate persona model settings if provided
409+
if request.persona_model_settings:
410+
try:
411+
# Import here to avoid circular imports
412+
from app.schemas.persona import ModelSettings
413+
ModelSettings(**request.persona_model_settings)
414+
logger.debug(f"Persona model settings validated successfully")
415+
except Exception as validation_error:
416+
logger.error(f"Invalid persona model settings: {validation_error}")
417+
raise HTTPException(
418+
status_code=400,
419+
detail=f"Invalid persona model settings: {str(validation_error)}"
420+
)
421+
396422
try:
397423
# Get provider instance using the helper function
398424
logger.info("Getting provider instance from request")
@@ -403,8 +429,39 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
403429
current_messages = [message.model_dump() for message in request.messages]
404430
print(f"Current messages: {current_messages}")
405431

406-
# Initialize combined_messages with current messages
407-
combined_messages = current_messages.copy()
432+
# Handle persona system prompt injection
433+
messages_with_persona = current_messages.copy()
434+
435+
# If persona is provided, inject system prompt at the beginning
436+
if request.persona_system_prompt:
437+
logger.info(f"Applying persona system prompt for persona_id: {request.persona_id}")
438+
system_message = {
439+
"role": "system",
440+
"content": request.persona_system_prompt
441+
}
442+
# Partition existing messages by role so the persona system prompt can be placed first
443+
system_messages = [msg for msg in messages_with_persona if msg.get("role") == "system"]
444+
non_system_messages = [msg for msg in messages_with_persona if msg.get("role") != "system"]
445+
446+
# If there are existing system messages, replace the first one with persona system prompt
447+
# Otherwise, add persona system prompt as the first message
448+
if system_messages:
449+
messages_with_persona = [system_message] + system_messages[1:] + non_system_messages
450+
else:
451+
messages_with_persona = [system_message] + non_system_messages
452+
453+
logger.debug(f"Messages after persona system prompt injection: {len(messages_with_persona)} messages")
454+
455+
# Apply persona model settings to params
456+
enhanced_params = request.params.copy() if request.params else {}
457+
if request.persona_model_settings:
458+
logger.info(f"Applying persona model settings: {request.persona_model_settings}")
459+
# Merge persona settings with request params (persona takes precedence)
460+
enhanced_params.update(request.persona_model_settings)
461+
logger.debug(f"Enhanced params with persona settings: {enhanced_params}")
462+
463+
# Initialize combined_messages with persona-enhanced messages
464+
combined_messages = messages_with_persona.copy()
408465

409466
# Get or create a conversation
410467
from app.models.conversation import Conversation
@@ -446,6 +503,13 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
446503
print(f"Conversation owner: {conversation.user_id}, Request user: {user_id}, Original request user_id: {request.user_id}")
447504
raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
448505

506+
# Update conversation with persona_id if provided and different from current
507+
if request.persona_id and conversation.persona_id != request.persona_id:
508+
logger.info(f"Updating conversation {conversation_id} with persona_id: {request.persona_id}")
509+
conversation.persona_id = request.persona_id
510+
await db.commit()
511+
await db.refresh(conversation)
512+
449513
# Get previous messages for this conversation
450514
print(f"Retrieving previous messages for conversation {conversation_id}")
451515
previous_messages = await conversation.get_messages(db)
@@ -500,12 +564,32 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
500564
page_id=request.page_id, # NEW FIELD - ID of the page this conversation belongs to
501565
model=request.model,
502566
server=provider_instance.server_name,
503-
conversation_type=request.conversation_type or "chat" # New field with default
567+
conversation_type=request.conversation_type or "chat", # New field with default
568+
persona_id=request.persona_id # Store persona_id when creating conversation
504569
)
505570
db.add(conversation)
506571
await db.commit()
507572
await db.refresh(conversation)
508573
print(f"Created new conversation with ID: {conversation.id}")
574+
575+
# If persona has a sample greeting, add it as the first assistant message
576+
if request.persona_sample_greeting:
577+
logger.info(f"Adding persona sample greeting for persona_id: {request.persona_id}")
578+
greeting_message = Message(
579+
id=str(uuid.uuid4()),
580+
conversation_id=conversation.id,
581+
sender="llm",
582+
message=request.persona_sample_greeting,
583+
message_metadata={
584+
"persona_id": request.persona_id,
585+
"persona_greeting": True,
586+
"model": request.model,
587+
"temperature": 0.0 # Greeting is static, not generated
588+
}
589+
)
590+
db.add(greeting_message)
591+
await db.commit()
592+
print(f"Added persona sample greeting: {request.persona_sample_greeting[:50]}...")
509593

510594
# Store user messages in the database
511595
for msg in request.messages:
@@ -536,7 +620,7 @@ async def stream_generator():
536620
async for chunk in provider_instance.chat_completion_stream(
537621
combined_messages,
538622
request.model,
539-
request.params
623+
enhanced_params
540624
):
541625
# Extract content from the chunk
542626
content = ""
@@ -559,18 +643,29 @@ async def stream_generator():
559643
elapsed_time = time.time() - start_time
560644
tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
561645

562-
# Store the LLM response in the database
646+
# Store the LLM response in the database with persona metadata
647+
message_metadata = {
648+
"token_count": token_count,
649+
"tokens_per_second": round(tokens_per_second, 1),
650+
"model": request.model,
651+
"temperature": enhanced_params.get("temperature", 0.7),
652+
"streaming": True
653+
}
654+
655+
# Add persona metadata if persona was used
656+
if request.persona_id:
657+
message_metadata.update({
658+
"persona_id": request.persona_id,
659+
"persona_applied": bool(request.persona_system_prompt),
660+
"persona_model_settings_applied": bool(request.persona_model_settings)
661+
})
662+
563663
db_message = Message(
564664
id=str(uuid.uuid4()),
565665
conversation_id=conversation.id,
566666
sender="llm",
567667
message=full_response,
568-
message_metadata={
569-
"token_count": token_count,
570-
"tokens_per_second": round(tokens_per_second, 1),
571-
"model": request.model,
572-
"temperature": request.params.get("temperature", 0.7) if request.params else 0.7
573-
}
668+
message_metadata=message_metadata
574669
)
575670
db.add(db_message)
576671

@@ -582,9 +677,17 @@ async def stream_generator():
582677
yield "data: [DONE]\n\n"
583678
except Exception as stream_error:
584679
print(f"Error in stream_generator: {stream_error}")
680+
logger.error(f"Streaming error with persona_id {request.persona_id}: {stream_error}")
681+
682+
# Append the persona ID to the streaming error message when a persona was in use
683+
error_message = f"Streaming error: {str(stream_error)}"
684+
if request.persona_id:
685+
error_message += f" (Persona ID: {request.persona_id})"
686+
585687
error_json = json.dumps({
586688
"error": True,
587-
"message": f"Streaming error: {str(stream_error)}"
689+
"message": error_message,
690+
"persona_id": request.persona_id if request.persona_id else None
588691
})
589692
yield f"data: {error_json}\n\n"
590693
yield "data: [DONE]\n\n"
@@ -613,7 +716,7 @@ async def stream_generator():
613716
result = await provider_instance.chat_completion(
614717
combined_messages,
615718
request.model,
616-
request.params
719+
enhanced_params
617720
)
618721
elapsed_time = time.time() - start_time
619722

@@ -631,18 +734,29 @@ async def stream_generator():
631734
token_count = len(response_content.split()) * 1.3 # Rough estimate: words * 1.3
632735
tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
633736

634-
# Store the LLM response in the database
737+
# Store the LLM response in the database with persona metadata
738+
message_metadata = {
739+
"token_count": int(token_count),
740+
"tokens_per_second": round(tokens_per_second, 1),
741+
"model": request.model,
742+
"temperature": enhanced_params.get("temperature", 0.7),
743+
"streaming": False
744+
}
745+
746+
# Add persona metadata if persona was used
747+
if request.persona_id:
748+
message_metadata.update({
749+
"persona_id": request.persona_id,
750+
"persona_applied": bool(request.persona_system_prompt),
751+
"persona_model_settings_applied": bool(request.persona_model_settings)
752+
})
753+
635754
db_message = Message(
636755
id=str(uuid.uuid4()),
637756
conversation_id=conversation.id,
638757
sender="llm",
639758
message=response_content,
640-
message_metadata={
641-
"token_count": int(token_count),
642-
"tokens_per_second": round(tokens_per_second, 1),
643-
"model": request.model,
644-
"temperature": request.params.get("temperature", 0.7) if request.params else 0.7
645-
}
759+
message_metadata=message_metadata
646760
)
647761
db.add(db_message)
648762

@@ -660,10 +774,12 @@ async def stream_generator():
660774
import traceback
661775
logger.error(traceback.format_exc())
662776

663-
# Provide a more detailed error message
777+
# Provide a more detailed error message with persona context
664778
error_message = str(inner_e)
665779
if "provider_instance" in locals():
666780
error_message += f" Provider: {request.provider}, Server: {request.server_id}"
781+
if request.persona_id:
782+
error_message += f" Persona: {request.persona_id}"
667783

668784
raise HTTPException(
669785
status_code=500,

0 commit comments

Comments
 (0)