diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 97c072d0e..c51610318 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,6 +4,7 @@ import logging import random from collections.abc import AsyncGenerator +from dataclasses import dataclass from typing import cast from google import genai @@ -32,6 +33,21 @@ def filter(self, record): logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning()) +@dataclass +class ChunkView: + """流式响应 chunk 的结构化视图对象 + + 提供对 Gemini API 流式响应的统一访问接口, + 避免在多处重复进行防御性检查。 + """ + candidate: types.Candidate | None + parts: list[types.Part] | None + reasoning_text: str | None + visible_text: str | None + has_function_call: bool + finish_reason: types.FinishReason | None + + @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器", @@ -398,15 +414,85 @@ def append_or_extend( return gemini_contents - def _extract_reasoning_content(self, candidate: types.Candidate) -> str: - """Extract reasoning content from candidate parts""" - if not candidate.content or not candidate.content.parts: - return "" + def _split_chunk_content( + self, chunk: types.GenerateContentResponse + ) -> ChunkView: + """ + 从流式响应 chunk 中提取结构化视图。 + + 添加防御性检查,安全访问可能不存在的属性。 + + Args: + chunk: Gemini API 返回的流式响应 chunk + + Returns: + ChunkView: 包含 candidate, parts, reasoning_text, visible_text, + has_function_call, finish_reason 的结构化视图 + """ + # 防御性检查:candidates 是否存在且非空 + if not chunk.candidates: + return ChunkView( + candidate=None, + parts=None, + reasoning_text=None, + visible_text=None, + has_function_call=False, + finish_reason=None, + ) - thought_buf: list[str] = [ - (p.text or "") for p in candidate.content.parts if p.thought - ] - return "".join(thought_buf).strip() + candidate = chunk.candidates[0] + + # 防御性检查:使用 getattr 安全访问 content + content = getattr(candidate, "content", None) + if content is None: + return ChunkView( + candidate=candidate, + parts=None, + reasoning_text=None, + visible_text=None, + has_function_call=False, + finish_reason=getattr(candidate, "finish_reason", None), + ) + + # 防御性检查:使用 getattr 安全访问 parts + parts = getattr(content, "parts", None) + + reasoning_text: str | None = None + visible_text: str | None = None + has_function_call = False + + if parts: + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + + for part in parts: + # 安全访问 thought, text, function_call 属性 + is_thought = getattr(part, "thought", False) + part_text = getattr(part, "text", None) + + if getattr(part, "function_call", None): + has_function_call = True + + if is_thought and part_text: + reasoning_parts.append(part_text) + elif not is_thought and part_text: + text_parts.append(part_text) + + reasoning_text = "".join(reasoning_parts) if reasoning_parts else None + visible_text = "".join(text_parts) if text_parts else None + else: + # 回退:当 parts 为空但 chunk.text 存在时 + chunk_text = getattr(chunk, "text", None) + visible_text = chunk_text or None + + return ChunkView( + candidate=candidate, + parts=parts, + reasoning_text=reasoning_text, + visible_text=visible_text, + has_function_call=has_function_call, + finish_reason=getattr(candidate, "finish_reason", None), + ) def _extract_usage( self, usage_metadata: types.GenerateContentResponseUsageMetadata @@ -451,7 +537,12 @@ def _process_content_parts( raise Exception("API 返回的 candidate.content.parts 为空。") # 提取 reasoning content - reasoning = self._extract_reasoning_content(candidate) + thought_buf: list[str] = [ + (p.text or "") + for p in result_parts + if getattr(p, "thought", False) + ] + reasoning = "".join(thought_buf).strip() if reasoning: llm_response.reasoning_content = reasoning @@ -467,6 +558,9 @@ def _process_content_parts( ): chain.append(Comp.Plain("这是图片")) for part in result_parts: + # 跳过思考内容(thought=True),只处理实际输出 + if getattr(part, "thought", False): + continue if part.text: chain.append(Comp.Plain(part.text)) @@ -635,20 +729,25 @@ async def _query_stream( async for chunk in result: llm_response = LLMResponse("assistant", is_chunk=True) - if not chunk.candidates: + # 使用辅助函数进行防御性检查和内容提取 + chunk_view = self._split_chunk_content(chunk) + + # 如果 candidate 为空,跳过 + if chunk_view.candidate is None: logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}") continue - if not chunk.candidates[0].content: + + # 如果 parts 为空(content 为空),跳过 + if chunk_view.parts is None: logger.warning(f"收到的 chunk 中 content 为空: {chunk}") continue - if chunk.candidates[0].content.parts and any( - part.function_call for part in chunk.candidates[0].content.parts - ): + # 检查是否包含函数调用 + if chunk_view.has_function_call: llm_response = LLMResponse("assistant", is_chunk=False) llm_response.raw_completion = chunk llm_response.result_chain = self._process_content_parts( - chunk.candidates[0], + chunk_view.candidate, llm_response, ) llm_response.id = chunk.response_id @@ -657,28 +756,33 @@ async def _query_stream( yield llm_response return - _f = False - - # 提取 reasoning content - reasoning = self._extract_reasoning_content(chunk.candidates[0]) - if reasoning: - _f = True - accumulated_reasoning += reasoning - llm_response.reasoning_content = reasoning - if chunk.text: - _f = True - accumulated_text += chunk.text - llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) - if _f: + has_content = False + + # 处理思维链内容 + if chunk_view.reasoning_text: + has_content = True + accumulated_reasoning += chunk_view.reasoning_text + llm_response.reasoning_content = chunk_view.reasoning_text + + # 处理实际输出内容 + if chunk_view.visible_text: + has_content = True + accumulated_text += chunk_view.visible_text + llm_response.result_chain = MessageChain( + chain=[Comp.Plain(chunk_view.visible_text)] + ) + + if has_content: yield llm_response - if chunk.candidates[0].finish_reason: + # 检查是否为最终 chunk + if chunk_view.finish_reason: # Process the final chunk for potential tool calls or other content - if chunk.candidates[0].content.parts: + if chunk_view.parts: final_response = LLMResponse("assistant", is_chunk=False) final_response.raw_completion = chunk final_response.result_chain = self._process_content_parts( - chunk.candidates[0], + chunk_view.candidate, final_response, ) final_response.id = chunk.response_id