@@ -393,6 +393,32 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
393393 logger .debug (f"Messages: { request .messages } " )
394394 logger .debug (f"Params: { request .params } " )
395395
396+ # Validate persona data if provided
397+ if request .persona_id or request .persona_system_prompt or request .persona_model_settings :
398+ logger .info (f"Persona data provided - persona_id: { request .persona_id } " )
399+
400+ # Basic validation: if persona_id is provided, persona_system_prompt should also be provided
401+ if request .persona_id and not request .persona_system_prompt :
402+ logger .error (f"Invalid persona data: persona_id provided but persona_system_prompt is missing" )
403+ raise HTTPException (
404+ status_code = 400 ,
405+ detail = "Invalid persona data: persona_system_prompt is required when persona_id is provided"
406+ )
407+
408+ # Validate persona model settings if provided
409+ if request .persona_model_settings :
410+ try :
411+ # Import here to avoid circular imports
412+ from app .schemas .persona import ModelSettings
413+ ModelSettings (** request .persona_model_settings )
414+ logger .debug (f"Persona model settings validated successfully" )
415+ except Exception as validation_error :
416+ logger .error (f"Invalid persona model settings: { validation_error } " )
417+ raise HTTPException (
418+ status_code = 400 ,
419+ detail = f"Invalid persona model settings: { str (validation_error )} "
420+ )
421+
396422 try :
397423 # Get provider instance using the helper function
398424 logger .info ("Getting provider instance from request" )
@@ -403,8 +429,39 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
403429 current_messages = [message .model_dump () for message in request .messages ]
404430 print (f"Current messages: { current_messages } " )
405431
406- # Initialize combined_messages with current messages
407- combined_messages = current_messages .copy ()
432+ # Handle persona system prompt injection
433+ messages_with_persona = current_messages .copy ()
434+
435+ # If persona is provided, inject system prompt at the beginning
436+ if request .persona_system_prompt :
437+ logger .info (f"Applying persona system prompt for persona_id: { request .persona_id } " )
438+ system_message = {
439+ "role" : "system" ,
440+ "content" : request .persona_system_prompt
441+ }
442+ # Insert system message at the beginning, but after any existing system messages
443+ system_messages = [msg for msg in messages_with_persona if msg .get ("role" ) == "system" ]
444+ non_system_messages = [msg for msg in messages_with_persona if msg .get ("role" ) != "system" ]
445+
446+ # If there are existing system messages, replace the first one with persona system prompt
447+ # Otherwise, add persona system prompt as the first message
448+ if system_messages :
449+ messages_with_persona = [system_message ] + system_messages [1 :] + non_system_messages
450+ else :
451+ messages_with_persona = [system_message ] + non_system_messages
452+
453+ logger .debug (f"Messages after persona system prompt injection: { len (messages_with_persona )} messages" )
454+
455+ # Apply persona model settings to params
456+ enhanced_params = request .params .copy () if request .params else {}
457+ if request .persona_model_settings :
458+ logger .info (f"Applying persona model settings: { request .persona_model_settings } " )
459+ # Merge persona settings with request params (persona takes precedence)
460+ enhanced_params .update (request .persona_model_settings )
461+ logger .debug (f"Enhanced params with persona settings: { enhanced_params } " )
462+
463+ # Initialize combined_messages with persona-enhanced messages
464+ combined_messages = messages_with_persona .copy ()
408465
409466 # Get or create a conversation
410467 from app .models .conversation import Conversation
@@ -446,6 +503,13 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
446503 print (f"Conversation owner: { conversation .user_id } , Request user: { user_id } , Original request user_id: { request .user_id } " )
447504 raise HTTPException (status_code = 403 , detail = "Not authorized to access this conversation" )
448505
506+ # Update conversation with persona_id if provided and different from current
507+ if request .persona_id and conversation .persona_id != request .persona_id :
508+ logger .info (f"Updating conversation { conversation_id } with persona_id: { request .persona_id } " )
509+ conversation .persona_id = request .persona_id
510+ await db .commit ()
511+ await db .refresh (conversation )
512+
449513 # Get previous messages for this conversation
450514 print (f"Retrieving previous messages for conversation { conversation_id } " )
451515 previous_messages = await conversation .get_messages (db )
@@ -500,12 +564,32 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
500564 page_id = request .page_id , # NEW FIELD - ID of the page this conversation belongs to
501565 model = request .model ,
502566 server = provider_instance .server_name ,
503- conversation_type = request .conversation_type or "chat" # New field with default
567+ conversation_type = request .conversation_type or "chat" , # New field with default
568+ persona_id = request .persona_id # Store persona_id when creating conversation
504569 )
505570 db .add (conversation )
506571 await db .commit ()
507572 await db .refresh (conversation )
508573 print (f"Created new conversation with ID: { conversation .id } " )
574+
575+ # If persona has a sample greeting, add it as the first assistant message
576+ if request .persona_sample_greeting :
577+ logger .info (f"Adding persona sample greeting for persona_id: { request .persona_id } " )
578+ greeting_message = Message (
579+ id = str (uuid .uuid4 ()),
580+ conversation_id = conversation .id ,
581+ sender = "llm" ,
582+ message = request .persona_sample_greeting ,
583+ message_metadata = {
584+ "persona_id" : request .persona_id ,
585+ "persona_greeting" : True ,
586+ "model" : request .model ,
587+ "temperature" : 0.0 # Greeting is static, not generated
588+ }
589+ )
590+ db .add (greeting_message )
591+ await db .commit ()
592+ print (f"Added persona sample greeting: { request .persona_sample_greeting [:50 ]} ..." )
509593
510594 # Store user messages in the database
511595 for msg in request .messages :
@@ -536,7 +620,7 @@ async def stream_generator():
536620 async for chunk in provider_instance .chat_completion_stream (
537621 combined_messages ,
538622 request .model ,
539- request . params
623+ enhanced_params
540624 ):
541625 # Extract content from the chunk
542626 content = ""
@@ -559,18 +643,29 @@ async def stream_generator():
559643 elapsed_time = time .time () - start_time
560644 tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
561645
562- # Store the LLM response in the database
646+ # Store the LLM response in the database with persona metadata
647+ message_metadata = {
648+ "token_count" : token_count ,
649+ "tokens_per_second" : round (tokens_per_second , 1 ),
650+ "model" : request .model ,
651+ "temperature" : enhanced_params .get ("temperature" , 0.7 ),
652+ "streaming" : True
653+ }
654+
655+ # Add persona metadata if persona was used
656+ if request .persona_id :
657+ message_metadata .update ({
658+ "persona_id" : request .persona_id ,
659+ "persona_applied" : bool (request .persona_system_prompt ),
660+ "persona_model_settings_applied" : bool (request .persona_model_settings )
661+ })
662+
563663 db_message = Message (
564664 id = str (uuid .uuid4 ()),
565665 conversation_id = conversation .id ,
566666 sender = "llm" ,
567667 message = full_response ,
568- message_metadata = {
569- "token_count" : token_count ,
570- "tokens_per_second" : round (tokens_per_second , 1 ),
571- "model" : request .model ,
572- "temperature" : request .params .get ("temperature" , 0.7 ) if request .params else 0.7
573- }
668+ message_metadata = message_metadata
574669 )
575670 db .add (db_message )
576671
@@ -582,9 +677,17 @@ async def stream_generator():
582677 yield "data: [DONE]\n \n "
583678 except Exception as stream_error :
584679 print (f"Error in stream_generator: { stream_error } " )
680+ logger .error (f"Streaming error with persona_id { request .persona_id } : { stream_error } " )
681+
682+ # Enhanced error message for persona-related errors
683+ error_message = f"Streaming error: { str (stream_error )} "
684+ if request .persona_id :
685+ error_message += f" (Persona ID: { request .persona_id } )"
686+
585687 error_json = json .dumps ({
586688 "error" : True ,
587- "message" : f"Streaming error: { str (stream_error )} "
689+ "message" : error_message ,
690+ "persona_id" : request .persona_id if request .persona_id else None
588691 })
589692 yield f"data: { error_json } \n \n "
590693 yield "data: [DONE]\n \n "
@@ -613,7 +716,7 @@ async def stream_generator():
613716 result = await provider_instance .chat_completion (
614717 combined_messages ,
615718 request .model ,
616- request . params
719+ enhanced_params
617720 )
618721 elapsed_time = time .time () - start_time
619722
@@ -631,18 +734,29 @@ async def stream_generator():
631734 token_count = len (response_content .split ()) * 1.3 # Rough estimate: words * 1.3
632735 tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
633736
634- # Store the LLM response in the database
737+ # Store the LLM response in the database with persona metadata
738+ message_metadata = {
739+ "token_count" : int (token_count ),
740+ "tokens_per_second" : round (tokens_per_second , 1 ),
741+ "model" : request .model ,
742+ "temperature" : enhanced_params .get ("temperature" , 0.7 ),
743+ "streaming" : False
744+ }
745+
746+ # Add persona metadata if persona was used
747+ if request .persona_id :
748+ message_metadata .update ({
749+ "persona_id" : request .persona_id ,
750+ "persona_applied" : bool (request .persona_system_prompt ),
751+ "persona_model_settings_applied" : bool (request .persona_model_settings )
752+ })
753+
635754 db_message = Message (
636755 id = str (uuid .uuid4 ()),
637756 conversation_id = conversation .id ,
638757 sender = "llm" ,
639758 message = response_content ,
640- message_metadata = {
641- "token_count" : int (token_count ),
642- "tokens_per_second" : round (tokens_per_second , 1 ),
643- "model" : request .model ,
644- "temperature" : request .params .get ("temperature" , 0.7 ) if request .params else 0.7
645- }
759+ message_metadata = message_metadata
646760 )
647761 db .add (db_message )
648762
@@ -660,10 +774,12 @@ async def stream_generator():
660774 import traceback
661775 logger .error (traceback .format_exc ())
662776
663- # Provide a more detailed error message
777+ # Provide a more detailed error message with persona context
664778 error_message = str (inner_e )
665779 if "provider_instance" in locals ():
666780 error_message += f" Provider: { request .provider } , Server: { request .server_id } "
781+ if request .persona_id :
782+ error_message += f" Persona: { request .persona_id } "
667783
668784 raise HTTPException (
669785 status_code = 500 ,
0 commit comments