diff --git a/docs/en/learn/llm-connections.mdx b/docs/en/learn/llm-connections.mdx
index daedc21a27..574c1ec714 100644
--- a/docs/en/learn/llm-connections.mdx
+++ b/docs/en/learn/llm-connections.mdx
@@ -14,7 +14,64 @@ CrewAI uses LiteLLM to connect to a wide variety of Language Models (LLMs). This
You can easily configure your agents to use a different model or provider as described in this guide.
-## Supported Providers
+## Native NVIDIA Provider
+
+CrewAI includes native support for NVIDIA NIM (NVIDIA Inference Microservices), providing direct access to 180+ high-performance models including Qwen, LLaMA, DeepSeek R1, and Mistral.
+
+
+**Auto-Detection**: When `NVIDIA_API_KEY` (or `NVIDIA_NIM_API_KEY`) is set, any model name containing "/" (e.g., `qwen/qwen3-next-80b-a3b-instruct`) is checked against NVIDIA's model catalog and automatically routed to the native NVIDIA provider when found. No other configuration is needed.
+
+
+### Quick Start
+
+1. Get a free API key from [NVIDIA Build](https://build.nvidia.com/) (format: `nvapi-...`).
+
+2. Export it as an environment variable:
+
+   ```bash
+   export NVIDIA_API_KEY="nvapi-your-key-here"
+   ```
+
+3. Point your agents at any NVIDIA-hosted model through the standard `LLM` class:
+
+   ```python
+   from crewai import Agent, LLM
+
+   # Models found in NVIDIA's catalog are routed to the native provider automatically
+   llm = LLM(model="qwen/qwen3-next-80b-a3b-instruct", temperature=0.7)
+
+   agent = Agent(
+       role="Research Analyst",
+       goal="Analyze data and provide insights",
+       backstory="Expert in data analysis",
+       llm=llm
+   )
+   ```
+
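+The provider also accepts `NVIDIA_NIM_API_KEY` as an alternative to `NVIDIA_API_KEY`, and you can set `NVIDIA_BASE_URL` to target a self-hosted or cloud NIM endpoint instead of the default `https://integrate.api.nvidia.com/v1`.
+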
+### Key Features
+
+- **180+ Models**: Chat, code, reasoning, vision, and safety models
+- **Auto-Detection**: Automatic catalog-based routing for models with "/" in the name (when an NVIDIA API key is set)
+- **Streaming Support**: Real-time response streaming (see the sketch below)
+- **Vision Models**: Llama 3.2 Vision (11B/90B), Phi-4 Vision
+- **Reasoning Models**: DeepSeek R1, QwQ-32B with chain-of-thought
+- **Built-in Security**: Input validation and resource management
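+
+A minimal streaming setup might look like this (the model choice and token limit are illustrative):
+
+```python
+from crewai import LLM
+
+# Stream tokens as they arrive; max_tokens caps the response length,
+# which is useful for reasoning models such as DeepSeek R1
+llm = LLM(
+    model="deepseek-ai/deepseek-r1",
+    stream=True,
+    max_tokens=2048,
+)
+```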
+
+### Popular Models
+
+| Category | Model | Best For |
+|----------|-------|----------|
+| **Chat** | `qwen/qwen3-next-80b-a3b-instruct` | General conversation & analysis |
+| **Chat** | `meta/llama-3.1-70b-instruct` | High-quality responses |
+| **Code** | `qwen/qwen2.5-coder-32b-instruct` | Code generation & debugging |
+| **Reasoning** | `deepseek-ai/deepseek-r1` | Complex problem solving |
+| **Vision** | `meta/llama-3.2-90b-vision-instruct` | Image analysis & understanding |
+
+See [NVIDIA Build](https://build.nvidia.com/) for the complete model catalog.
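+
+If a slash-prefixed model is not found in NVIDIA's catalog (or no NVIDIA API key is set), routing falls back to the standard provider prefixes and then to LiteLLM, so existing configurations keep working. A quick sketch (the Groq model name is just an example, and Groq's own API key is still required):
+
+```python
+from crewai import LLM
+
+# Not in NVIDIA's catalog, so this resolves through LiteLLM as before
+llm = LLM(model="groq/llama-3.1-8b-instant")
+```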
+
+## LiteLLM Providers
LiteLLM supports a wide range of providers, including but not limited to:
@@ -36,7 +93,6 @@ LiteLLM supports a wide range of providers, including but not limited to:
- Groq
- SambaNova
- Nebius AI Studio
-- [NVIDIA NIMs](https://docs.api.nvidia.com/nim/reference/models-1)
- And many more!
For a complete and up-to-date list of supported providers, please refer to the [LiteLLM Providers documentation](https://docs.litellm.ai/docs/providers).
diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py
index 8bc1fe6486..57c9e4e618 100644
--- a/lib/crewai/src/crewai/llm.py
+++ b/lib/crewai/src/crewai/llm.py
@@ -9,6 +9,7 @@
import os
import sys
import threading
+import time
from typing import (
TYPE_CHECKING,
Any,
@@ -24,6 +25,12 @@
from pydantic import BaseModel, Field
from typing_extensions import Self
+# Cache for NVIDIA model list to avoid repeated API calls
+_nvidia_models_cache: set[str] | None = None
+_nvidia_cache_timestamp: float | None = None
+_NVIDIA_CACHE_TTL = 3600 # 1 hour cache expiration
+_nvidia_cache_lock = threading.Lock()
+
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
@@ -316,6 +323,7 @@ def writable(self) -> bool:
"gemini",
"bedrock",
"aws",
+ "nvidia",
]
@@ -339,6 +347,75 @@ class AccumulatedToolArgs(BaseModel):
function: FunctionArgs = Field(default_factory=FunctionArgs)
+def _get_nvidia_models() -> set[str]:
+ """Fetch and cache the list of models available from NVIDIA NIM API.
+
+ Returns:
+ Set of model IDs available in NVIDIA's catalog
+ """
+ global _nvidia_models_cache, _nvidia_cache_timestamp
+
+ # Check if cache exists and hasn't expired
+ if _nvidia_models_cache is not None and _nvidia_cache_timestamp is not None:
+ if time.time() - _nvidia_cache_timestamp < _NVIDIA_CACHE_TTL:
+ return _nvidia_models_cache
+ # Cache expired - will refresh below
+
+ # Thread-safe cache initialization
+ with _nvidia_cache_lock:
+ # Double-check after acquiring lock (with TTL check)
+ if _nvidia_models_cache is not None and _nvidia_cache_timestamp is not None:
+ if time.time() - _nvidia_cache_timestamp < _NVIDIA_CACHE_TTL:
+ return _nvidia_models_cache
+ # Cache expired - proceed with refresh
+
+ # Accept both NVIDIA_API_KEY (build.nvidia.com) and NVIDIA_NIM_API_KEY (cloud endpoints)
+ api_key = os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY")
+ if not api_key:
+ _nvidia_models_cache = set()
+ _nvidia_cache_timestamp = time.time()
+ return _nvidia_models_cache
+
+ try:
+ # Use httpx instead of requests for better security and async support
+ # All HTTP logic inside lock to prevent race conditions
+ with httpx.Client(timeout=5.0) as client:
+ response = client.get(
+ "https://integrate.api.nvidia.com/v1/models",
+ headers={"Authorization": f"Bearer {api_key}"},
+ )
+
+ if response.status_code == 200:
+ models = response.json().get("data", [])
+ # Dedupe model IDs (NVIDIA API has some duplicates)
+ _nvidia_models_cache = set([m["id"] for m in models])
+ _nvidia_cache_timestamp = time.time()
+ else:
+ logging.warning(
+ f"NVIDIA API returned status {response.status_code}"
+ )
+ _nvidia_models_cache = set()
+ _nvidia_cache_timestamp = time.time()
+ except httpx.TimeoutException:
+ logging.warning("NVIDIA API request timed out")
+ _nvidia_models_cache = set()
+ _nvidia_cache_timestamp = time.time()
+ except httpx.HTTPError as e:
+ # Sanitize error message to avoid leaking API keys
+ error_msg = str(e).replace(api_key, "***")
+ logging.warning(f"NVIDIA API request failed: {error_msg}")
+ _nvidia_models_cache = set()
+ _nvidia_cache_timestamp = time.time()
+ except Exception as e:
+ # Catch-all for unexpected errors, with API key sanitization
+ error_msg = str(e).replace(api_key, "***") if api_key else str(e)
+ logging.warning(f"Failed to fetch NVIDIA models: {error_msg}")
+ _nvidia_models_cache = set()
+ _nvidia_cache_timestamp = time.time()
+
+ return _nvidia_models_cache
+
+
class LLM(BaseLLM):
completion_cost: float | None = None
@@ -363,32 +440,75 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
use_native = True
model_string = model
elif "/" in model:
- prefix, _, model_part = model.partition("/")
-
- provider_mapping = {
- "openai": "openai",
- "anthropic": "anthropic",
- "claude": "anthropic",
- "azure": "azure",
- "azure_openai": "azure",
- "google": "gemini",
- "gemini": "gemini",
- "bedrock": "bedrock",
- "aws": "bedrock",
- }
-
- canonical_provider = provider_mapping.get(prefix.lower())
-
- if canonical_provider and cls._validate_model_in_constants(
- model_part, canonical_provider
- ):
- provider = canonical_provider
- use_native = True
- model_string = model_part
+            # If an NVIDIA API key is set, check NVIDIA's catalog first and route
+            # to the native NVIDIA provider when the model is listed there.
+            # Accepts both NVIDIA_API_KEY and NVIDIA_NIM_API_KEY.
+            if (
+                os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY")
+            ) and model in _get_nvidia_models():
+                provider = "nvidia"
+                use_native = True
+                model_string = model
else:
- provider = prefix
- use_native = False
- model_string = model_part
+                prefix, _, model_part = model.partition("/")
+
+                provider_mapping = {
+                    "openai": "openai",
+                    "anthropic": "anthropic",
+                    "claude": "anthropic",
+                    "azure": "azure",
+                    "azure_openai": "azure",
+                    "google": "gemini",
+                    "gemini": "gemini",
+                    "bedrock": "bedrock",
+                    "aws": "bedrock",
+                }
+
+                canonical_provider = provider_mapping.get(prefix.lower())
+
+                if canonical_provider and cls._validate_model_in_constants(
+                    model_part, canonical_provider
+                ):
+                    provider = canonical_provider
+                    use_native = True
+                    model_string = model_part
+                else:
+                    # Not in NVIDIA's catalog and not a recognized native provider -
+                    # fall back to LiteLLM routing.
+                    provider = prefix
+                    use_native = False
+                    model_string = model_part
else:
provider = cls._infer_provider_from_model(model)
use_native = True
@@ -446,10 +566,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
)
if provider == "gemini" or provider == "google":
- return any(
- model_lower.startswith(prefix)
- for prefix in ["gemini-", "gemma-", "learnlm-"]
- )
+ # Only match Gemini-specific models, not open models like Gemma
+ # Gemma can be hosted on NVIDIA/other providers
+ return model_lower.startswith("gemini-") or model_lower.startswith("learnlm-")
if provider == "bedrock":
return "." in model_lower
@@ -460,6 +579,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
)
+ # NVIDIA routing is handled by dynamic catalog check in __new__
+ # No static pattern matching needed - always use catalog lookup
+
return False
@classmethod
@@ -559,6 +681,11 @@ def _get_native_provider(cls, provider: str) -> type | None:
return BedrockCompletion
+ if provider == "nvidia":
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+
+ return NvidiaCompletion
+
return None
def __init__(
diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llms/constants.py
index 9552efada0..b0a36fa584 100644
--- a/lib/crewai/src/crewai/llms/constants.py
+++ b/lib/crewai/src/crewai/llms/constants.py
@@ -566,3 +566,12 @@
"qwen.qwen3-coder-30b-a3b-v1:0",
"twelvelabs.pegasus-1-2-v1:0",
]
+
+
+# NVIDIA models (Jan 2026) - routing for models with "/" in the name is handled by
+# the dynamic catalog check in LLM.__new__ rather than by this list
+NVIDIA_MODELS = [
+ "qwen/qwen3-next-80b-a3b-instruct", # Latest Qwen3-Next, excellent tool calling (Jan 2026)
+ "qwen/qwen2.5-7b-instruct", # Efficient general-purpose model
+ "deepseek-ai/deepseek-r1-distill-qwen-14b", # Reasoning with Qwen base (Jan 2026)
+ "nvidia/cosmos-reason2-8b", # Vision + reasoning model (Jan 2026)
+]
diff --git a/lib/crewai/src/crewai/llms/providers/nvidia/__init__.py b/lib/crewai/src/crewai/llms/providers/nvidia/__init__.py
new file mode 100644
index 0000000000..69401db179
--- /dev/null
+++ b/lib/crewai/src/crewai/llms/providers/nvidia/__init__.py
@@ -0,0 +1,5 @@
+"""NVIDIA provider for CrewAI."""
+
+from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+
+__all__ = ["NvidiaCompletion"]
diff --git a/lib/crewai/src/crewai/llms/providers/nvidia/completion.py b/lib/crewai/src/crewai/llms/providers/nvidia/completion.py
new file mode 100644
index 0000000000..ac2005cdbe
--- /dev/null
+++ b/lib/crewai/src/crewai/llms/providers/nvidia/completion.py
@@ -0,0 +1,1486 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+import json
+import logging
+import os
+import re
+import time
+from typing import TYPE_CHECKING, Any, Self
+
+import httpx
+from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
+from openai.lib.streaming.chat import AsyncChatCompletionStream, ChatCompletionStream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_chunk import ChoiceDelta
+from pydantic import BaseModel
+
+from crewai.events.types.llm_events import LLMCallType
+from crewai.llms.base_llm import BaseLLM
+from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
+from crewai.utilities.agent_utils import is_context_length_exceeded
+from crewai.utilities.exceptions.context_window_exceeding_exception import (
+ LLMContextLengthExceededError,
+)
+from crewai.utilities.pydantic_schema_utils import generate_model_description
+from crewai.utilities.types import LLMMessage
+
+
+if TYPE_CHECKING:
+ from crewai.agent.core import Agent
+ from crewai.llms.hooks.base import BaseInterceptor
+ from crewai.task import Task
+ from crewai.tools.base_tool import BaseTool
+
+# Cache TTL for model metadata (1 hour)
+METADATA_CACHE_TTL = 3600
+
+# Valid model name pattern: alphanumeric, hyphens, underscores, slashes, dots
+# Examples: meta/llama-3.1-70b-instruct, nvidia/nemo-retriever-embedding-v1
+MODEL_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_.\-/]+$")
+
+
+class NvidiaCompletion(BaseLLM):
+ """NVIDIA native completion implementation.
+
+ This class provides direct integration with NVIDIA using the OpenAI-compatible API,
+ offering native structured outputs, function calling, and streaming support.
+
+ NVIDIA uses the OpenAI Python SDK since their API is OpenAI-compatible.
+ Default base URL: https://integrate.api.nvidia.com/v1
+ """
+
+ def __init__(
+ self,
+ model: str = "meta/llama-3.1-70b-instruct",
+ api_key: str | None = None,
+ base_url: str | None = None,
+ timeout: float | None = None,
+ max_retries: int = 2,
+ default_headers: dict[str, str] | None = None,
+ default_query: dict[str, Any] | None = None,
+ client_params: dict[str, Any] | None = None,
+ temperature: float | None = None,
+ top_p: float | None = None,
+ frequency_penalty: float | None = None,
+ presence_penalty: float | None = None,
+ max_tokens: int | None = None,
+ seed: int | None = None,
+ stream: bool = False,
+ response_format: dict[str, Any] | type[BaseModel] | None = None,
+ logprobs: bool | None = None,
+ top_logprobs: int | None = None,
+ provider: str | None = None,
+ interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize NVIDIA chat completion client.
+
+ Args:
+ model: NVIDIA model name (e.g., 'meta/llama-3.1-70b-instruct')
+ api_key: NVIDIA API key (defaults to NVIDIA_API_KEY env var)
+ base_url: NVIDIA base URL (defaults to https://integrate.api.nvidia.com/v1)
+ timeout: Request timeout in seconds
+ max_retries: Maximum number of retries
+ temperature: Sampling temperature (0-1)
+ top_p: Nucleus sampling parameter
+ frequency_penalty: Frequency penalty
+ presence_penalty: Presence penalty
+ max_tokens: Maximum tokens in response
+ seed: Random seed for reproducibility
+ stream: Enable streaming responses
+ response_format: Response format configuration
+ logprobs: Include log probabilities
+ top_logprobs: Number of top log probabilities
+ provider: Provider name (defaults to 'nvidia')
+ interceptor: HTTP interceptor for transport-level modifications
+ **kwargs: Additional parameters
+ """
+
+ if provider is None:
+ provider = kwargs.pop("provider", "nvidia")
+
+ # Validate model name to prevent injection attacks
+ if not MODEL_NAME_PATTERN.match(model):
+ raise ValueError(
+ f"Invalid model name: '{model}'. Model names must only contain "
+ "alphanumeric characters, hyphens, underscores, slashes, and dots."
+ )
+
+ self.interceptor = interceptor
+ # Client configuration attributes
+ self.max_retries = max_retries
+ self.default_headers = default_headers
+ self.default_query = default_query
+ self.client_params = client_params
+ self.timeout = timeout
+ self.base_url = base_url
+ self.api_base = kwargs.pop("api_base", None)
+
+ super().__init__(
+ model=model,
+ temperature=temperature,
+ api_key=api_key or os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY"),
+ base_url=base_url,
+ timeout=timeout,
+ provider=provider,
+ **kwargs,
+ )
+
+ # Initialize clients without requiring API key (deferred to actual API calls)
+ client_config = self._get_client_params(require_key=False)
+ if self.interceptor:
+ transport = HTTPTransport(interceptor=self.interceptor)
+ http_client = httpx.Client(transport=transport)
+ client_config["http_client"] = http_client
+
+ self.client = OpenAI(**client_config)
+
+ async_client_config = self._get_client_params(require_key=False)
+ if self.interceptor:
+ async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
+ async_http_client = httpx.AsyncClient(transport=async_transport)
+ async_client_config["http_client"] = async_http_client
+
+ self.async_client = AsyncOpenAI(**async_client_config)
+
+ # Completion parameters
+ self.top_p = top_p
+ self.frequency_penalty = frequency_penalty
+ self.presence_penalty = presence_penalty
+ self.max_tokens = max_tokens
+ self.seed = seed
+ self.stream = stream
+ self.response_format = response_format
+ self.logprobs = logprobs
+ self.top_logprobs = top_logprobs
+
+ # Detect model capabilities
+ model_lower = model.lower()
+ self.is_vision_model = any(indicator in model_lower for indicator in [
+ "-vl", # Vision-Language suffix
+ "deplot", # Google chart understanding
+ "fuyu", # Adept multimodal
+ "kosmos", # Microsoft multimodal grounding
+ "mistral-large-3", # Mistral Large 3 vision support
+ "multimodal", # Explicit multimodal designation
+ "nemotron-nano", # NVIDIA Nemotron Nano series
+ "neva", # NVIDIA vision-language
+ "nvclip", # NVIDIA CLIP
+ "paligemma", # Google vision-language
+ "streampetr", # NVIDIA perception model
+ "vila", # NVIDIA Visual Language Assistant
+ "vision", # Contains 'vision' in name
+ "vlm", # Vision-Language Model designation
+ ])
+ self.supports_tools = self._check_tool_support(model)
+
+ # Cache for model metadata from API with TTL
+ self._model_metadata: dict[str, Any] | None = None
+ self._model_metadata_timestamp: float = 0.0
+
+ def close(self) -> None:
+ """Close HTTP clients to release resources.
+
+ This method should be called when the NvidiaCompletion instance is no longer needed
+ to properly clean up HTTP connections and release resources.
+
+ Usage:
+ completion = NvidiaCompletion(model="meta/llama-3.1-70b-instruct")
+ try:
+ # Use completion
+ result = completion.call(messages)
+ finally:
+ completion.close()
+ """
+ try:
+ if hasattr(self, "client") and self.client:
+ self.client.close()
+ except Exception as e:
+ logging.debug(f"Error closing sync client: {e}")
+
+ try:
+ if hasattr(self, "async_client") and self.async_client:
+ # AsyncOpenAI client needs to be awaited, but __del__ is sync
+ # So we just close the underlying HTTP client if it exists
+ if hasattr(self.async_client, "_client"):
+ self.async_client._client.close()
+ except Exception as e:
+ logging.debug(f"Error closing async client: {e}")
+
+ def __del__(self) -> None:
+ """Destructor to ensure HTTP clients are closed."""
+ self.close()
+
+ def __enter__(self) -> Self:
+ """Enter context manager."""
+ return self
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ """Exit context manager and close clients."""
+ self.close()
+
+ def _get_client_params(self, require_key: bool = True) -> dict[str, Any]:
+ """Get NVIDIA client parameters.
+
+ Args:
+ require_key: If True, raises error when API key is missing.
+ If False, returns params with None API key (for non-API operations).
+ """
+
+        if self.api_key is None:
+            self.api_key = os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY")
+        if self.api_key is None and require_key:
+            raise ValueError(
+                "NVIDIA_API_KEY or NVIDIA_NIM_API_KEY is required. Get your API key from https://build.nvidia.com/"
+            )
+
+ # Default to NVIDIA's integrated API endpoint
+ default_base_url = "https://integrate.api.nvidia.com/v1"
+
+ base_params = {
+ "api_key": self.api_key or "placeholder", # Placeholder for initialization
+ "base_url": self.base_url
+ or self.api_base
+ or os.getenv("NVIDIA_BASE_URL")
+ or default_base_url,
+ "timeout": self.timeout,
+ "max_retries": self.max_retries,
+ "default_headers": self.default_headers,
+ "default_query": self.default_query,
+ }
+
+ client_params = {k: v for k, v in base_params.items() if v is not None}
+
+ if self.client_params:
+ client_params.update(self.client_params)
+
+ return client_params
+
+ def _check_tool_support(self, model: str) -> bool:
+ """Check if the model supports tool/function calling.
+
+ Most modern NVIDIA models support tools. If a model doesn't support tools,
+ the API will return an appropriate error which we handle gracefully.
+ """
+ return True # Default to True, let API handle unsupported models
+
+ def _fetch_model_metadata(self) -> dict[str, Any] | None:
+ """Fetch model metadata from NVIDIA API with TTL-based caching.
+
+ Queries the /v1/models endpoint to get model information including max_model_len.
+ Results are cached with a TTL to prevent cache poisoning and stale data.
+
+ Returns:
+ Model metadata dict with fields like max_model_len, or None if fetch fails
+ """
+ current_time = time.time()
+
+ # Check if cache is valid (exists and not expired)
+ if self._model_metadata is not None:
+ cache_age = current_time - self._model_metadata_timestamp
+ if cache_age < METADATA_CACHE_TTL:
+ return self._model_metadata
+ else:
+ logging.debug(
+ f"Model metadata cache expired ({cache_age:.1f}s > {METADATA_CACHE_TTL}s), refreshing..."
+ )
+
+ try:
+ # Query /v1/models endpoint to get model list
+ models = self.client.models.list()
+
+ # Find our specific model in the list
+ for model_obj in models.data:
+ if model_obj.id == self.model:
+ # Convert to dict
+ if hasattr(model_obj, "model_dump"):
+ metadata = model_obj.model_dump()
+ else:
+ metadata = model_obj.__dict__
+
+ # Validate metadata structure before caching
+ if not isinstance(metadata, dict):
+ logging.warning(
+ f"Invalid metadata type for {self.model}: {type(metadata)}"
+ )
+ self._model_metadata = {}
+ self._model_metadata_timestamp = current_time
+ return None
+
+ # Cache validated metadata with timestamp
+ self._model_metadata = metadata
+ self._model_metadata_timestamp = current_time
+
+ logging.debug(
+ f"Fetched and cached metadata for {self.model}: {self._model_metadata}"
+ )
+ return self._model_metadata
+
+ # Model not found in list - cache empty dict to avoid repeated lookups
+ logging.debug(f"Model {self.model} not found in /v1/models response")
+ self._model_metadata = {}
+ self._model_metadata_timestamp = current_time
+ return None
+
+ except Exception as e:
+ # API call failed - cache empty dict and return None
+ logging.debug(f"Failed to fetch model metadata: {e}")
+ self._model_metadata = {}
+ self._model_metadata_timestamp = current_time
+ return None
+
+ def call(
+ self,
+ messages: str | list[LLMMessage],
+ tools: list[dict[str, BaseTool]] | None = None,
+ callbacks: list[Any] | None = None,
+ available_functions: dict[str, Any] | None = None,
+ from_task: Task | None = None,
+ from_agent: Agent | None = None,
+ response_model: type[BaseModel] | None = None,
+ ) -> str | Any:
+ """Call NVIDIA NIM chat completion API.
+
+ Args:
+ messages: Input messages for the chat completion
+ tools: list of tool/function definitions
+ callbacks: Callback functions (not used in native implementation)
+ available_functions: Available functions for tool calling
+ from_task: Task that initiated the call
+ from_agent: Agent that initiated the call
+ response_model: Response model for structured output.
+
+ Returns:
+ Chat completion response or tool call result
+ """
+ # Validate API key before making actual API call
+ if not self.api_key and not (os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY")):
+ raise ValueError(
+ "NVIDIA_API_KEY or NVIDIA_NIM_API_KEY is required for API calls. Get your API key from https://build.nvidia.com/"
+ )
+
+ try:
+ self._emit_call_started_event(
+ messages=messages,
+ tools=tools,
+ callbacks=callbacks,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ formatted_messages = self._format_messages(messages)
+
+ if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
+ raise ValueError("LLM call blocked by before_llm_call hook")
+
+ completion_params = self._prepare_completion_params(
+ messages=formatted_messages, tools=tools
+ )
+
+ if self.stream:
+ return self._handle_streaming_completion(
+ params=completion_params,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ response_model=response_model,
+ )
+
+ return self._handle_completion(
+ params=completion_params,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ response_model=response_model,
+ )
+
+ except Exception as e:
+ error_msg = f"NVIDIA NIM API call failed: {e!s}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise
+
+ async def acall(
+ self,
+ messages: str | list[LLMMessage],
+ tools: list[dict[str, BaseTool]] | None = None,
+ callbacks: list[Any] | None = None,
+ available_functions: dict[str, Any] | None = None,
+ from_task: Task | None = None,
+ from_agent: Agent | None = None,
+ response_model: type[BaseModel] | None = None,
+ ) -> str | Any:
+ """Async call to NVIDIA NIM chat completion API.
+
+ Args:
+ messages: Input messages for the chat completion
+ tools: list of tool/function definitions
+ callbacks: Callback functions (not used in native implementation)
+ available_functions: Available functions for tool calling
+ from_task: Task that initiated the call
+ from_agent: Agent that initiated the call
+ response_model: Response model for structured output.
+
+ Returns:
+ Chat completion response or tool call result
+ """
+ # Validate API key before making actual API call
+ if not self.api_key and not (os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY")):
+ raise ValueError(
+ "NVIDIA_API_KEY or NVIDIA_NIM_API_KEY is required for API calls. Get your API key from https://build.nvidia.com/"
+ )
+
+ try:
+ self._emit_call_started_event(
+ messages=messages,
+ tools=tools,
+ callbacks=callbacks,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ formatted_messages = self._format_messages(messages)
+
+ if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
+ raise ValueError("LLM call blocked by before_llm_call hook")
+
+ completion_params = self._prepare_completion_params(
+ messages=formatted_messages, tools=tools
+ )
+
+ if self.stream:
+ return await self._ahandle_streaming_completion(
+ params=completion_params,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ response_model=response_model,
+ )
+
+ return await self._ahandle_completion(
+ params=completion_params,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ response_model=response_model,
+ )
+
+ except Exception as e:
+ error_msg = f"NVIDIA NIM API call failed: {e!s}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise
+
+ def _prepare_completion_params(
+ self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
+ ) -> dict[str, Any]:
+ """Prepare parameters for NVIDIA NIM chat completion."""
+ params: dict[str, Any] = {
+ "model": self.model,
+ "messages": messages,
+ }
+ if self.stream:
+ params["stream"] = self.stream
+ # Note: stream_options not supported by all NVIDIA models
+
+ params.update(self.additional_params)
+
+ if self.temperature is not None:
+ params["temperature"] = self.temperature
+ if self.top_p is not None:
+ params["top_p"] = self.top_p
+ if self.frequency_penalty is not None:
+ params["frequency_penalty"] = self.frequency_penalty
+ if self.presence_penalty is not None:
+ params["presence_penalty"] = self.presence_penalty
+ if self.max_tokens is not None:
+ params["max_tokens"] = self.max_tokens
+ elif self._is_reasoning_model():
+ # Reasoning models (DeepSeek R1, V3, etc.) request entire context window
+ # when max_tokens is not specified, causing API errors. Set sensible default.
+ params["max_tokens"] = 4096
+ if self.seed is not None:
+ params["seed"] = self.seed
+ if self.logprobs is not None:
+ params["logprobs"] = self.logprobs
+ if self.top_logprobs is not None:
+ params["top_logprobs"] = self.top_logprobs
+
+ if self.response_format is not None:
+ if isinstance(self.response_format, type) and issubclass(
+ self.response_format, BaseModel
+ ):
+ params["response_format"] = generate_model_description(
+ self.response_format
+ )
+ elif isinstance(self.response_format, dict):
+ params["response_format"] = self.response_format
+
+ if tools and self.supports_tools:
+ params["tools"] = self._convert_tools_for_interference(tools)
+ params["tool_choice"] = "auto"
+
+ # Filter out CrewAI-specific parameters that shouldn't go to the API
+ crewai_specific_params = {
+ "callbacks",
+ "available_functions",
+ "from_task",
+ "from_agent",
+ "provider",
+ "api_key",
+ "base_url",
+ "api_base",
+ "timeout",
+ }
+
+ return {k: v for k, v in params.items() if k not in crewai_specific_params}
+
+ def _is_reasoning_model(self) -> bool:
+ """Detect if the current model is a reasoning model (DeepSeek R1, V3, etc.).
+
+ Reasoning models have special behavior where they request the entire context window
+ when max_tokens is not specified, which can cause API errors.
+
+ Returns:
+ True if the model is a reasoning model, False otherwise.
+ """
+ model_lower = self.model.lower()
+ reasoning_patterns = [
+ "deepseek-r1",
+ "deepseek-v3",
+ "deepseek-ai/deepseek-r1",
+ "deepseek-ai/deepseek-v3",
+ "gpt-oss", # OpenAI GPT-OSS models exhibit similar behavior
+ ]
+ return any(pattern in model_lower for pattern in reasoning_patterns)
+
+ def _convert_tools_for_interference(
+ self, tools: list[dict[str, BaseTool]]
+ ) -> list[dict[str, Any]]:
+ """Convert CrewAI tool format to OpenAI function calling format (NVIDIA NIM compatible)."""
+ from crewai.llms.providers.utils.common import safe_tool_conversion
+
+ nvidia_tools = []
+
+ for tool in tools:
+ name, description, parameters = safe_tool_conversion(tool, "NVIDIA NIM")
+
+ nvidia_tool = {
+ "type": "function",
+ "function": {
+ "name": name,
+ "description": description,
+ },
+ }
+
+ if parameters:
+ if isinstance(parameters, dict):
+ nvidia_tool["function"]["parameters"] = parameters # type: ignore
+ else:
+ nvidia_tool["function"]["parameters"] = dict(parameters)
+
+ nvidia_tools.append(nvidia_tool)
+ return nvidia_tools
+
+ def _handle_completion(
+ self,
+ params: dict[str, Any],
+ available_functions: dict[str, Any] | None = None,
+ from_task: Any | None = None,
+ from_agent: Any | None = None,
+ response_model: type[BaseModel] | None = None,
+ ) -> str | Any:
+ """Handle non-streaming chat completion."""
+ try:
+ if response_model:
+ parse_params = {
+ k: v for k, v in params.items() if k != "response_format"
+ }
+ parsed_response = self.client.beta.chat.completions.parse(
+ **parse_params,
+ response_format=response_model,
+ )
+                parsed_message = parsed_response.choices[0].message
+                if parsed_message.refusal:
+                    logging.warning(
+                        f"NVIDIA NIM structured output refused: {parsed_message.refusal}"
+                    )
+
+ usage = self._extract_token_usage(parsed_response)
+ self._track_token_usage_internal(usage)
+
+ parsed_object = parsed_response.choices[0].message.parsed
+ if parsed_object:
+ structured_json = parsed_object.model_dump_json()
+ self._emit_call_completed_event(
+ response=structured_json,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+ return structured_json
+ else:
+ raise ValueError(f"Structured output parsing returned None for response_model {response_model.__name__}")
+
+ response: ChatCompletion = self.client.chat.completions.create(**params)
+
+ usage = self._extract_token_usage(response)
+
+ self._track_token_usage_internal(usage)
+
+ choice: Choice = response.choices[0]
+ message = choice.message
+
+ if message.tool_calls and available_functions:
+ tool_call = message.tool_calls[0]
+ function_name = tool_call.function.name
+
+ try:
+ function_args = json.loads(tool_call.function.arguments)
+ except json.JSONDecodeError as e:
+ logging.error(f"Failed to parse tool arguments: {e}")
+ function_args = {}
+
+ result = self._handle_tool_execution(
+ function_name=function_name,
+ function_args=function_args,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ if result is not None:
+ return result
+
+ # Return final answer (message.content) first, fallback to reasoning_content
+ # For reasoning models like DeepSeek R1, message.content is the answer
+ content = message.content or getattr(message, 'reasoning_content', None) or ""
+ content = self._apply_stop_words(content)
+
+ if self.response_format and isinstance(self.response_format, type):
+ try:
+ structured_result = self._validate_structured_output(
+ content, self.response_format
+ )
+ self._emit_call_completed_event(
+ response=structured_result,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+ return structured_result
+ except ValueError as e:
+ logging.warning(f"Structured output validation failed: {e}")
+
+ self._emit_call_completed_event(
+ response=content,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+
+ if usage.get("total_tokens", 0) > 0:
+ logging.info(f"NVIDIA NIM API usage: {usage}")
+
+ content = self._invoke_after_llm_call_hooks(
+ params["messages"], content, from_agent
+ )
+ except NotFoundError as e:
+ error_msg = f"Model {self.model} not found: {e}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise ValueError(error_msg) from e
+ except APIConnectionError as e:
+ error_msg = f"Failed to connect to NVIDIA NIM API: {e}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise ConnectionError(error_msg) from e
+ except Exception as e:
+ # Handle context length exceeded and other errors
+ if is_context_length_exceeded(e):
+ logging.error(f"Context window exceeded: {e}")
+ raise LLMContextLengthExceededError(str(e)) from e
+
+ error_msg = f"NVIDIA NIM API call failed: {e!s}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise e from e
+
+ return content
+
+ def _handle_streaming_completion(
+ self,
+ params: dict[str, Any],
+ available_functions: dict[str, Any] | None = None,
+ from_task: Any | None = None,
+ from_agent: Any | None = None,
+ response_model: type[BaseModel] | None = None,
+ ) -> str:
+ """Handle streaming chat completion."""
+ full_response = ""
+ tool_calls: dict[int, dict[str, Any]] = {}
+
+ if response_model:
+ parse_params = {
+ k: v
+ for k, v in params.items()
+ if k not in ("response_format", "stream")
+ }
+
+ stream: ChatCompletionStream[BaseModel]
+ with self.client.beta.chat.completions.stream(
+ **parse_params, response_format=response_model
+ ) as stream:
+ for chunk in stream:
+ if chunk.type == "content.delta":
+ delta_content = chunk.delta
+ if delta_content:
+ self._emit_stream_chunk_event(
+ chunk=delta_content,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ final_completion = stream.get_final_completion()
+ if final_completion:
+ usage = self._extract_token_usage(final_completion)
+ self._track_token_usage_internal(usage)
+ if final_completion.choices:
+ parsed_result = final_completion.choices[0].message.parsed
+ if parsed_result:
+ structured_json = parsed_result.model_dump_json()
+ self._emit_call_completed_event(
+ response=structured_json,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+ return structured_json
+
+ logging.error("Failed to get parsed result from stream")
+ return ""
+
+ completion_stream: Stream[ChatCompletionChunk] = (
+ self.client.chat.completions.create(**params)
+ )
+
+ usage_data = {"total_tokens": 0}
+
+ for completion_chunk in completion_stream:
+ if hasattr(completion_chunk, "usage") and completion_chunk.usage:
+ usage_data = self._extract_token_usage(completion_chunk)
+ continue
+
+ if not completion_chunk.choices:
+ continue
+
+ choice = completion_chunk.choices[0]
+ chunk_delta: ChoiceDelta = choice.delta
+
+ if chunk_delta.content:
+ full_response += chunk_delta.content
+ self._emit_stream_chunk_event(
+ chunk=chunk_delta.content,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ if chunk_delta.tool_calls:
+ for tool_call in chunk_delta.tool_calls:
+ tool_index = tool_call.index if tool_call.index is not None else 0
+ if tool_index not in tool_calls:
+ tool_calls[tool_index] = {
+ "id": tool_call.id,
+ "name": "",
+ "arguments": "",
+ "index": tool_index,
+ }
+ elif tool_call.id and not tool_calls[tool_index]["id"]:
+ tool_calls[tool_index]["id"] = tool_call.id
+
+ if tool_call.function and tool_call.function.name:
+ tool_calls[tool_index]["name"] = tool_call.function.name
+ if tool_call.function and tool_call.function.arguments:
+ tool_calls[tool_index]["arguments"] += (
+ tool_call.function.arguments
+ )
+
+ self._emit_stream_chunk_event(
+ chunk=tool_call.function.arguments
+ if tool_call.function and tool_call.function.arguments
+ else "",
+ from_task=from_task,
+ from_agent=from_agent,
+ tool_call={
+ "id": tool_calls[tool_index]["id"],
+ "function": {
+ "name": tool_calls[tool_index]["name"],
+ "arguments": tool_calls[tool_index]["arguments"],
+ },
+ "type": "function",
+ "index": tool_calls[tool_index]["index"],
+ },
+ call_type=LLMCallType.TOOL_CALL,
+ )
+
+ self._track_token_usage_internal(usage_data)
+
+ if tool_calls and available_functions:
+ for call_data in tool_calls.values():
+ function_name = call_data["name"]
+ arguments = call_data["arguments"]
+
+ # Skip if function name is empty or arguments are empty
+ if not function_name or not arguments:
+ continue
+
+ # Check if function exists in available functions
+ if function_name not in available_functions:
+ logging.warning(
+ f"Function '{function_name}' not found in available functions"
+ )
+ continue
+
+ try:
+ function_args = json.loads(arguments)
+ except json.JSONDecodeError as e:
+ logging.error(f"Failed to parse streamed tool arguments: {e}")
+ continue
+
+ result = self._handle_tool_execution(
+ function_name=function_name,
+ function_args=function_args,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ if result is not None:
+ return result
+
+ full_response = self._apply_stop_words(full_response)
+
+ self._emit_call_completed_event(
+ response=full_response,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+
+ return self._invoke_after_llm_call_hooks(
+ params["messages"], full_response, from_agent
+ )
+
+ async def _ahandle_completion(
+ self,
+ params: dict[str, Any],
+ available_functions: dict[str, Any] | None = None,
+ from_task: Any | None = None,
+ from_agent: Any | None = None,
+ response_model: type[BaseModel] | None = None,
+ ) -> str | Any:
+ """Handle non-streaming async chat completion."""
+ try:
+ if response_model:
+ parse_params = {
+ k: v for k, v in params.items() if k != "response_format"
+ }
+ parsed_response = await self.async_client.beta.chat.completions.parse(
+ **parse_params,
+ response_format=response_model,
+ )
+                parsed_message = parsed_response.choices[0].message
+                if parsed_message.refusal:
+                    logging.warning(
+                        f"NVIDIA NIM structured output refused: {parsed_message.refusal}"
+                    )
+
+ usage = self._extract_token_usage(parsed_response)
+ self._track_token_usage_internal(usage)
+
+ parsed_object = parsed_response.choices[0].message.parsed
+ if parsed_object:
+ structured_json = parsed_object.model_dump_json()
+ self._emit_call_completed_event(
+ response=structured_json,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+ return structured_json
+ else:
+ raise ValueError(f"Structured output parsing returned None for response_model {response_model.__name__}")
+
+ response: ChatCompletion = await self.async_client.chat.completions.create(
+ **params
+ )
+
+ usage = self._extract_token_usage(response)
+
+ self._track_token_usage_internal(usage)
+
+ choice: Choice = response.choices[0]
+ message = choice.message
+
+ if message.tool_calls and available_functions:
+ tool_call = message.tool_calls[0]
+ function_name = tool_call.function.name
+
+ try:
+ function_args = json.loads(tool_call.function.arguments)
+ except json.JSONDecodeError as e:
+ logging.error(f"Failed to parse tool arguments: {e}")
+ function_args = {}
+
+ result = self._handle_tool_execution(
+ function_name=function_name,
+ function_args=function_args,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ if result is not None:
+ return result
+
+ # Return final answer (message.content) first, fallback to reasoning_content
+ # For reasoning models like DeepSeek R1, message.content is the answer
+ content = message.content or getattr(message, 'reasoning_content', None) or ""
+ content = self._apply_stop_words(content)
+
+ if self.response_format and isinstance(self.response_format, type):
+ try:
+ structured_result = self._validate_structured_output(
+ content, self.response_format
+ )
+ self._emit_call_completed_event(
+ response=structured_result,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+ return structured_result
+ except ValueError as e:
+ logging.warning(f"Structured output validation failed: {e}")
+
+ self._emit_call_completed_event(
+ response=content,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+
+ if usage.get("total_tokens", 0) > 0:
+ logging.info(f"NVIDIA NIM API usage: {usage}")
+
+ content = self._invoke_after_llm_call_hooks(
+ params["messages"], content, from_agent
+ )
+ except NotFoundError as e:
+ error_msg = f"Model {self.model} not found: {e}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise ValueError(error_msg) from e
+ except APIConnectionError as e:
+ error_msg = f"Failed to connect to NVIDIA NIM API: {e}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise ConnectionError(error_msg) from e
+ except Exception as e:
+ if is_context_length_exceeded(e):
+ logging.error(f"Context window exceeded: {e}")
+ raise LLMContextLengthExceededError(str(e)) from e
+
+ error_msg = f"NVIDIA NIM API call failed: {e!s}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise e from e
+
+ return content
+
+ async def _ahandle_streaming_completion(
+ self,
+ params: dict[str, Any],
+ available_functions: dict[str, Any] | None = None,
+ from_task: Any | None = None,
+ from_agent: Any | None = None,
+ response_model: type[BaseModel] | None = None,
+ ) -> str:
+ """Handle async streaming chat completion."""
+ full_response = ""
+ tool_calls: dict[int, dict[str, Any]] = {}
+
+ if response_model:
+ parse_params = {
+ k: v
+ for k, v in params.items()
+ if k not in ("response_format", "stream")
+ }
+
+ stream: AsyncChatCompletionStream[BaseModel]
+ async with self.async_client.beta.chat.completions.stream(
+ **parse_params, response_format=response_model
+ ) as stream:
+ async for chunk in stream:
+ if chunk.type == "content.delta":
+ delta_content = chunk.delta
+ if delta_content:
+ self._emit_stream_chunk_event(
+ chunk=delta_content,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ final_completion = await stream.get_final_completion()
+ if final_completion:
+ usage = self._extract_token_usage(final_completion)
+ self._track_token_usage_internal(usage)
+ if final_completion.choices:
+ parsed_result = final_completion.choices[0].message.parsed
+ if parsed_result:
+ structured_json = parsed_result.model_dump_json()
+ self._emit_call_completed_event(
+ response=structured_json,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+ return structured_json
+
+ logging.error("Failed to get parsed result from stream")
+ return ""
+
+        completion_stream: AsyncIterator[
+            ChatCompletionChunk
+        ] = await self.async_client.chat.completions.create(**params)
+
+ usage_data = {"total_tokens": 0}
+
+        async for chunk in completion_stream:
+ if hasattr(chunk, "usage") and chunk.usage:
+ usage_data = self._extract_token_usage(chunk)
+ continue
+
+ if not chunk.choices:
+ continue
+
+ choice = chunk.choices[0]
+ chunk_delta: ChoiceDelta = choice.delta
+
+ if chunk_delta.content:
+ full_response += chunk_delta.content
+ self._emit_stream_chunk_event(
+ chunk=chunk_delta.content,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ if chunk_delta.tool_calls:
+ for tool_call in chunk_delta.tool_calls:
+ tool_index = tool_call.index if tool_call.index is not None else 0
+ if tool_index not in tool_calls:
+ tool_calls[tool_index] = {
+ "id": tool_call.id,
+ "name": "",
+ "arguments": "",
+ "index": tool_index,
+ }
+ elif tool_call.id and not tool_calls[tool_index]["id"]:
+ tool_calls[tool_index]["id"] = tool_call.id
+
+ if tool_call.function and tool_call.function.name:
+ tool_calls[tool_index]["name"] = tool_call.function.name
+ if tool_call.function and tool_call.function.arguments:
+ tool_calls[tool_index]["arguments"] += (
+ tool_call.function.arguments
+ )
+
+ self._emit_stream_chunk_event(
+ chunk=tool_call.function.arguments
+ if tool_call.function and tool_call.function.arguments
+ else "",
+ from_task=from_task,
+ from_agent=from_agent,
+ tool_call={
+ "id": tool_calls[tool_index]["id"],
+ "function": {
+ "name": tool_calls[tool_index]["name"],
+ "arguments": tool_calls[tool_index]["arguments"],
+ },
+ "type": "function",
+ "index": tool_calls[tool_index]["index"],
+ },
+ call_type=LLMCallType.TOOL_CALL,
+ )
+
+ self._track_token_usage_internal(usage_data)
+
+ if tool_calls and available_functions:
+ for call_data in tool_calls.values():
+ function_name = call_data["name"]
+ arguments = call_data["arguments"]
+
+ if not function_name or not arguments:
+ continue
+
+ if function_name not in available_functions:
+ logging.warning(
+ f"Function '{function_name}' not found in available functions"
+ )
+ continue
+
+ try:
+ function_args = json.loads(arguments)
+ except json.JSONDecodeError as e:
+ logging.error(f"Failed to parse streamed tool arguments: {e}")
+ continue
+
+ result = self._handle_tool_execution(
+ function_name=function_name,
+ function_args=function_args,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ if result is not None:
+ return result
+
+ full_response = self._apply_stop_words(full_response)
+
+ self._emit_call_completed_event(
+ response=full_response,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=params["messages"],
+ )
+
+ return self._invoke_after_llm_call_hooks(
+ params["messages"], full_response, from_agent
+ )
+
+ async def astream(
+ self,
+ messages: str | list[LLMMessage],
+ tools: list[dict[str, BaseTool]] | None = None,
+ callbacks: list[Any] | None = None,
+ available_functions: dict[str, Any] | None = None,
+ from_task: Task | None = None,
+ from_agent: Agent | None = None,
+ response_model: type[BaseModel] | None = None,
+ ) -> AsyncIterator[str]:
+ """Stream responses from NVIDIA NIM chat completion API.
+
+ This method provides an async generator that yields text chunks as they
+ are received from the NVIDIA API, enabling real-time streaming responses.
+
+ Args:
+ messages: Input messages for the chat completion
+ tools: list of tool/function definitions
+ callbacks: Callback functions (not used in native implementation)
+ available_functions: Available functions for tool calling
+ from_task: Task that initiated the call
+ from_agent: Agent that initiated the call
+ response_model: Response model for structured output
+
+ Yields:
+ Text chunks as they are received from the API
+
+ Raises:
+ ValueError: If API key is missing or if LLM call is blocked by hook
+ NotFoundError: If the model is not found
+ APIConnectionError: If connection to NVIDIA API fails
+ LLMContextLengthExceededError: If context window is exceeded
+ """
+ # Validate API key before making actual API call
+ if not self.api_key and not (os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY")):
+ raise ValueError(
+ "NVIDIA_API_KEY or NVIDIA_NIM_API_KEY is required for API calls. Get your API key from https://build.nvidia.com/"
+ )
+
+ try:
+ self._emit_call_started_event(
+ messages=messages,
+ tools=tools,
+ callbacks=callbacks,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ formatted_messages = self._format_messages(messages)
+
+ if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
+ raise ValueError("LLM call blocked by before_llm_call hook")
+
+ completion_params = self._prepare_completion_params(
+ messages=formatted_messages, tools=tools
+ )
+
+ # Force streaming mode for this method
+ completion_params["stream"] = True
+
+ # Handle structured output with response_model
+ if response_model:
+ parse_params = {
+ k: v
+ for k, v in completion_params.items()
+ if k not in ("response_format", "stream")
+ }
+
+ stream: AsyncChatCompletionStream[BaseModel]
+ async with self.async_client.beta.chat.completions.stream(
+ **parse_params, response_format=response_model
+ ) as stream:
+ async for chunk in stream:
+ if chunk.type == "content.delta":
+ delta_content = chunk.delta
+ if delta_content:
+ self._emit_stream_chunk_event(
+ chunk=delta_content,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+ yield delta_content
+
+ final_completion = await stream.get_final_completion()
+ if final_completion:
+ usage = self._extract_token_usage(final_completion)
+ self._track_token_usage_internal(usage)
+ if final_completion.choices:
+ parsed_result = final_completion.choices[0].message.parsed
+ if parsed_result:
+ structured_json = parsed_result.model_dump_json()
+ self._emit_call_completed_event(
+ response=structured_json,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=completion_params["messages"],
+ )
+
+ return
+
+ # Standard streaming without response_model
+        completion_stream: AsyncIterator[
+            ChatCompletionChunk
+        ] = await self.async_client.chat.completions.create(**completion_params)
+
+ full_response = ""
+ tool_calls: dict[int, dict[str, Any]] = {}
+ usage_data = {"total_tokens": 0}
+
+        async for chunk in completion_stream:
+ if hasattr(chunk, "usage") and chunk.usage:
+ usage_data = self._extract_token_usage(chunk)
+ continue
+
+ if not chunk.choices:
+ continue
+
+ choice = chunk.choices[0]
+ chunk_delta: ChoiceDelta = choice.delta
+
+ if chunk_delta.content:
+ full_response += chunk_delta.content
+ self._emit_stream_chunk_event(
+ chunk=chunk_delta.content,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+ yield chunk_delta.content
+
+ if chunk_delta.tool_calls:
+ for tool_call in chunk_delta.tool_calls:
+ tool_index = tool_call.index if tool_call.index is not None else 0
+ if tool_index not in tool_calls:
+ tool_calls[tool_index] = {
+ "id": tool_call.id,
+ "name": "",
+ "arguments": "",
+ "index": tool_index,
+ }
+ elif tool_call.id and not tool_calls[tool_index]["id"]:
+ tool_calls[tool_index]["id"] = tool_call.id
+
+ if tool_call.function and tool_call.function.name:
+ tool_calls[tool_index]["name"] = tool_call.function.name
+ if tool_call.function and tool_call.function.arguments:
+ tool_calls[tool_index]["arguments"] += (
+ tool_call.function.arguments
+ )
+
+ self._emit_stream_chunk_event(
+ chunk=tool_call.function.arguments
+ if tool_call.function and tool_call.function.arguments
+ else "",
+ from_task=from_task,
+ from_agent=from_agent,
+ tool_call={
+ "id": tool_calls[tool_index]["id"],
+ "function": {
+ "name": tool_calls[tool_index]["name"],
+ "arguments": tool_calls[tool_index]["arguments"],
+ },
+ "type": "function",
+ "index": tool_calls[tool_index]["index"],
+ },
+ call_type=LLMCallType.TOOL_CALL,
+ )
+
+ self._track_token_usage_internal(usage_data)
+
+ # Handle tool calls if present
+ if tool_calls and available_functions:
+ for call_data in tool_calls.values():
+ function_name = call_data["name"]
+ arguments = call_data["arguments"]
+
+ if not function_name or not arguments:
+ continue
+
+ if function_name not in available_functions:
+ logging.warning(
+ f"Function '{function_name}' not found in available functions"
+ )
+ continue
+
+ try:
+ function_args = json.loads(arguments)
+ except json.JSONDecodeError as e:
+ logging.error(f"Failed to parse streamed tool arguments: {e}")
+ continue
+
+ result = self._handle_tool_execution(
+ function_name=function_name,
+ function_args=function_args,
+ available_functions=available_functions,
+ from_task=from_task,
+ from_agent=from_agent,
+ )
+
+ if result is not None:
+ yield str(result)
+ return
+
+ # Apply stop words and emit completion event
+ full_response = self._apply_stop_words(full_response)
+
+ self._emit_call_completed_event(
+ response=full_response,
+ call_type=LLMCallType.LLM_CALL,
+ from_task=from_task,
+ from_agent=from_agent,
+ messages=completion_params["messages"],
+ )
+
+ except NotFoundError as e:
+ error_msg = f"Model {self.model} not found: {e}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise ValueError(error_msg) from e
+ except APIConnectionError as e:
+ error_msg = f"Failed to connect to NVIDIA NIM API: {e}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise ConnectionError(error_msg) from e
+ except Exception as e:
+ if is_context_length_exceeded(e):
+ logging.error(f"Context window exceeded: {e}")
+ raise LLMContextLengthExceededError(str(e)) from e
+
+ error_msg = f"NVIDIA NIM API call failed: {e!s}"
+ logging.error(error_msg)
+ self._emit_call_failed_event(
+ error=error_msg, from_task=from_task, from_agent=from_agent
+ )
+ raise
+
+ def supports_function_calling(self) -> bool:
+ """Check if the model supports function calling."""
+ return self.supports_tools
+
+ def supports_stop_words(self) -> bool:
+ """Check if the model supports stop words."""
+ return True # NVIDIA models support stop sequences
+
+ def get_context_window_size(self) -> int:
+ """Get the context window size for the model.
+
+ Tries to fetch max_model_len from NVIDIA API, falls back to pattern-based
+ defaults if API is unavailable or doesn't return the information.
+ """
+ from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
+
+ # Try to get from API first
+ metadata = self._fetch_model_metadata()
+ if metadata and "max_model_len" in metadata:
+ max_len = metadata["max_model_len"]
+            logging.debug(
+                f"Using API-provided context window for {self.model}: {max_len}"
+            )
+ return int(max_len * CONTEXT_WINDOW_USAGE_RATIO)
+
+ # Fallback to pattern-based defaults
+ model_lower = self.model.lower()
+
+ # Modern models with large context windows (128K)
+ if any(
+ indicator in model_lower
+ for indicator in ["llama-3.1", "llama-3.2", "llama-3.3", "qwen3", "phi-3"]
+ ):
+ logging.debug(f"Using pattern-based context window (128K) for {self.model}")
+ return int(128000 * CONTEXT_WINDOW_USAGE_RATIO)
+
+ # Default for all NVIDIA models (32K - safe for most modern models)
+ logging.debug(f"Using default context window (32K) for {self.model}")
+ return int(32768 * CONTEXT_WINDOW_USAGE_RATIO)
+
+ def _extract_token_usage(
+ self, response: ChatCompletion | ChatCompletionChunk
+ ) -> dict[str, Any]:
+ """Extract token usage from NVIDIA NIM ChatCompletion or ChatCompletionChunk response."""
+ if hasattr(response, "usage") and response.usage:
+ usage = response.usage
+ return {
+ "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+ "completion_tokens": getattr(usage, "completion_tokens", 0),
+ "total_tokens": getattr(usage, "total_tokens", 0),
+ }
+ return {"total_tokens": 0}
+
+ def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
+ """Format messages for NVIDIA NIM API."""
+ return super()._format_messages(messages)
diff --git a/lib/crewai/src/crewai/rag/embeddings/factory.py b/lib/crewai/src/crewai/rag/embeddings/factory.py
index eacf67b825..69a16b92da 100644
--- a/lib/crewai/src/crewai/rag/embeddings/factory.py
+++ b/lib/crewai/src/crewai/rag/embeddings/factory.py
@@ -67,6 +67,10 @@
)
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
+ from crewai.rag.embeddings.providers.nvidia.embedding_callable import (
+ NvidiaEmbeddingFunction,
+ )
+ from crewai.rag.embeddings.providers.nvidia.types import NvidiaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
@@ -96,6 +100,7 @@
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
"jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
+ "nvidia": "crewai.rag.embeddings.providers.nvidia.nvidia_provider.NvidiaProvider",
"ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
"onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
"openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
@@ -192,6 +197,10 @@ def build_embedder_from_dict(
def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
+@overload
+def build_embedder_from_dict(spec: NvidiaProviderSpec) -> NvidiaEmbeddingFunction: ...
+
+
@overload
def build_embedder_from_dict(
spec: RoboflowProviderSpec,
@@ -321,6 +330,10 @@ def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction:
def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
+@overload
+def build_embedder(spec: NvidiaProviderSpec) -> NvidiaEmbeddingFunction: ...
+
+
@overload
def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...
diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/__init__.py b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/__init__.py
new file mode 100644
index 0000000000..297564405c
--- /dev/null
+++ b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/__init__.py
@@ -0,0 +1,19 @@
+"""NVIDIA embeddings provider."""
+
+from crewai.rag.embeddings.providers.nvidia.embedding_callable import (
+ NvidiaEmbeddingFunction,
+)
+from crewai.rag.embeddings.providers.nvidia.nvidia_provider import NvidiaProvider
+from crewai.rag.embeddings.providers.nvidia.types import (
+ NvidiaEmbeddingModels,
+ NvidiaProviderConfig,
+ NvidiaProviderSpec,
+)
+
+__all__ = [
+ "NvidiaProvider",
+ "NvidiaEmbeddingFunction",
+ "NvidiaEmbeddingModels",
+ "NvidiaProviderConfig",
+ "NvidiaProviderSpec",
+]
diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/embedding_callable.py
new file mode 100644
index 0000000000..108b21dfe0
--- /dev/null
+++ b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/embedding_callable.py
@@ -0,0 +1,137 @@
+"""NVIDIA embedding callable implementation."""
+
+import os
+from typing import cast
+
+import httpx
+import numpy as np
+
+from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
+from crewai.rag.core.types import Documents, Embeddings
+
+
+class NvidiaEmbeddingFunction(EmbeddingFunction[Documents]):
+ """NVIDIA embedding function using the /v1/embeddings endpoint.
+
+ Supports NVIDIA's embedding models through the OpenAI-compatible API.
+ Default base URL: https://integrate.api.nvidia.com/v1
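+
+    Example (illustrative sketch; requires NVIDIA_API_KEY in the environment):
+        fn = NvidiaEmbeddingFunction(model_name="nvidia/nv-embed-v1", input_type="passage")
+        vectors = fn(["some document text"])  # list of float32 numpy arrays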
+ """
+
+ def __init__(
+ self,
+ api_key: str | None = None,
+ model_name: str = "nvidia/nv-embed-v1",
+ api_base: str = "https://integrate.api.nvidia.com/v1",
+ input_type: str = "query",
+ truncate: str = "NONE",
+ **kwargs: dict,
+ ) -> None:
+ """Initialize NVIDIA embedding function.
+
+ Args:
+ api_key: NVIDIA API key (or use NVIDIA_API_KEY/NVIDIA_NIM_API_KEY environment variable)
+ model_name: NVIDIA embedding model name (e.g., 'nvidia/nv-embed-v1')
+ api_base: Base URL for NVIDIA API
+ input_type: Type of input for asymmetric models ('query' or 'passage')
+ - 'query': For search queries or questions
+ - 'passage': For documents/passages to be searched
+ truncate: Truncation strategy ('NONE', 'START', 'END')
+ **kwargs: Additional parameters
+ """
+ # Validate and get API key from environment if not provided
+ if api_key is None:
+ api_key = os.getenv("NVIDIA_API_KEY") or os.getenv("NVIDIA_NIM_API_KEY")
+ if api_key is None:
+ raise ValueError(
+ "NVIDIA_API_KEY or NVIDIA_NIM_API_KEY is required for embeddings. "
+ "Get your API key from https://build.nvidia.com/"
+ )
+
+ self._api_key = api_key
+ self._model_name = model_name
+ self._api_base = api_base.rstrip("/")
+ self._input_type = input_type
+ self._truncate = truncate
+ self._session = httpx.Client()
+
+ # Models that require input_type parameter
+ self._requires_input_type = any(
+ keyword in model_name.lower()
+ for keyword in ["embedqa", "embedcode", "nemoretriever"]
+ )
+
+ @staticmethod
+ def name() -> str:
+ """Return the name of the embedding function for ChromaDB compatibility."""
+ return "nvidia"
+
+ def __call__(self, input: Documents) -> Embeddings:
+ """Generate embeddings for the given documents.
+
+ Args:
+ input: List of documents to embed
+
+ Returns:
+ List of embedding vectors as numpy arrays
+ """
+ # Build request payload
+ payload = {
+ "model": self._model_name,
+ "input": input,
+ }
+
+ # Add input_type and truncate for models that require them
+ if self._requires_input_type:
+ payload["input_type"] = self._input_type
+ payload["truncate"] = self._truncate
+
+ # NVIDIA embeddings API (OpenAI-compatible)
+ response = self._session.post(
+ f"{self._api_base}/embeddings",
+ headers={
+ "Authorization": f"Bearer {self._api_key}",
+ "Content-Type": "application/json",
+ },
+ json=payload,
+ timeout=60.0,
+ )
+
+ # Handle errors
+ if response.status_code != 200:
+ error_detail = ""
+ try:
+ error_data = response.json()
+ error_detail = error_data.get("detail", "") or error_data.get("error", {}).get("message", "")
+ except Exception:
+ error_detail = response.text[:500]
+
+ raise RuntimeError(
+ f"NVIDIA embeddings API returned status {response.status_code}: {error_detail}"
+ )
+
+ # Parse response
+ result = response.json()
+ embeddings_data = result.get("data", [])
+
+ if not embeddings_data:
+ raise ValueError(f"No embeddings returned from NVIDIA API for {len(input)} documents")
+
+ # Sort by index and extract embeddings
+ embeddings_data = sorted(embeddings_data, key=lambda x: x.get("index", 0))
+
+ # Extract embeddings with error handling for malformed responses
+ embeddings = []
+ for idx, item in enumerate(embeddings_data):
+ if "embedding" not in item:
+ raise ValueError(
+ f"NVIDIA API returned malformed response: item at index {idx} missing 'embedding' key. "
+ f"Available keys: {list(item.keys())}"
+ )
+ embeddings.append(np.array(item["embedding"], dtype=np.float32))
+
+ return cast(Embeddings, embeddings)
+
+ def __del__(self) -> None:
+ """Clean up HTTP session."""
+ if hasattr(self, "_session"):
+ self._session.close()
diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/nvidia_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/nvidia_provider.py
new file mode 100644
index 0000000000..68e7186b27
--- /dev/null
+++ b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/nvidia_provider.py
@@ -0,0 +1,93 @@
+"""NVIDIA embeddings provider."""
+
+from pydantic import AliasChoices, Field
+
+from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
+from crewai.rag.embeddings.providers.nvidia.embedding_callable import (
+ NvidiaEmbeddingFunction,
+)
+
+
+class NvidiaProvider(BaseEmbeddingsProvider[NvidiaEmbeddingFunction]):
+ """NVIDIA embeddings provider for RAG systems.
+
+ Provides access to NVIDIA's embedding models through the native API.
+    Supported models include:
+ - nvidia/nv-embed-v1 (4096 dimensions)
+ - nvidia/nv-embedqa-mistral-7b-v2
+ - nvidia/nv-embedcode-7b-v1
+ - nvidia/embed-qa-4
+ - nvidia/llama-3.2-nemoretriever-*
+ - And more...
+
+ Example:
+ ```python
+ from crewai.rag.embeddings.providers.nvidia import NvidiaProvider
+
+ embeddings = NvidiaProvider(
+ api_key="nvapi-...",
+ model_name="nvidia/nv-embed-v1"
+ )
+ ```
+ """
+
+ embedding_callable: type[NvidiaEmbeddingFunction] = Field(
+ default=NvidiaEmbeddingFunction,
+ description="NVIDIA embedding function class",
+ )
+
+ api_key: str | None = Field(
+ default=None,
+ description="NVIDIA API key",
+ validation_alias=AliasChoices(
+ "EMBEDDINGS_NVIDIA_API_KEY",
+ "NVIDIA_API_KEY",
+ ),
+ )
+
+ model_name: str = Field(
+ default="nvidia/nv-embed-v1",
+ description="NVIDIA embedding model name",
+ validation_alias=AliasChoices(
+ "EMBEDDINGS_NVIDIA_MODEL_NAME",
+ "NVIDIA_EMBEDDING_MODEL",
+ "model",
+ ),
+ )
+
+ api_base: str = Field(
+ default="https://integrate.api.nvidia.com/v1",
+ description="Base URL for NVIDIA API",
+ validation_alias=AliasChoices(
+ "EMBEDDINGS_NVIDIA_API_BASE",
+ "NVIDIA_API_BASE",
+ ),
+ )
+
+ input_type: str = Field(
+ default="query",
+ description="Input type for asymmetric models: 'query' for questions, 'passage' for documents",
+ validation_alias=AliasChoices(
+ "EMBEDDINGS_NVIDIA_INPUT_TYPE",
+ "NVIDIA_INPUT_TYPE",
+ ),
+ )
+
+ truncate: str = Field(
+ default="NONE",
+ description="Truncation strategy: 'NONE', 'START', or 'END'",
+ validation_alias=AliasChoices(
+ "EMBEDDINGS_NVIDIA_TRUNCATE",
+ "NVIDIA_TRUNCATE",
+ ),
+ )
+
+ def _create_embedding_function(self) -> NvidiaEmbeddingFunction:
+ """Create an NVIDIA embedding function instance from this provider's configuration.
+
+ Returns:
+ An initialized NvidiaEmbeddingFunction instance.
+ """
+ return self.embedding_callable(
+ **self.model_dump(exclude={"embedding_callable"})
+ )
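+
+
+# Illustrative environment-based configuration (a sketch; assumes the provider
+# resolves its fields from the environment via the validation aliases above):
+#
+#   os.environ["NVIDIA_API_KEY"] = "nvapi-..."
+#   os.environ["EMBEDDINGS_NVIDIA_MODEL_NAME"] = "nvidia/nv-embed-v1"
+#   os.environ["EMBEDDINGS_NVIDIA_INPUT_TYPE"] = "passage"
+#
+#   provider = NvidiaProvider()  # fields populated from the environment
+#   fn = provider._create_embedding_function()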
diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/types.py b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/types.py
new file mode 100644
index 0000000000..7ec7729e07
--- /dev/null
+++ b/lib/crewai/src/crewai/rag/embeddings/providers/nvidia/types.py
@@ -0,0 +1,34 @@
+"""Type definitions for NVIDIA embeddings provider."""
+
+from typing import Annotated, Literal
+
+from typing_extensions import Required, TypedDict
+
+# NVIDIA embedding models verified accessible via API testing
+# Last verified: 2026-01-06 (7 of 13 models in catalog are accessible)
+NvidiaEmbeddingModels = Literal[
+ "baai/bge-m3", # 1024 dimensions - General purpose embedding
+ "nvidia/llama-3.2-nemoretriever-300m-embed-v1", # 2048 dimensions - Compact retriever
+ "nvidia/llama-3.2-nemoretriever-300m-embed-v2", # 2048 dimensions - Compact retriever v2
+ "nvidia/llama-3.2-nv-embedqa-1b-v2", # 2048 dimensions - QA embedding
+ "nvidia/nv-embed-v1", # 4096 dimensions - NVIDIA's flagship (recommended)
+ "nvidia/nv-embedcode-7b-v1", # 4096 dimensions - Code embedding specialist
+ "nvidia/nv-embedqa-e5-v5", # 1024 dimensions - QA embedding based on E5
+]
+
+
+class NvidiaProviderConfig(TypedDict, total=False):
+ """Configuration for NVIDIA provider."""
+
+ api_key: str
+ model_name: str
+ api_base: str
+ input_type: str # 'query' or 'passage' for asymmetric models
+ truncate: str # 'NONE', 'START', or 'END'
+
+
+class NvidiaProviderSpec(TypedDict, total=False):
+ """NVIDIA provider specification."""
+
+ provider: Required[Literal["nvidia"]]
+ config: NvidiaProviderConfig
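+
+
+# Illustrative spec (a sketch) as accepted by the embeddings factory's
+# build_embedder_from_dict overload for NvidiaProviderSpec:
+#
+#   spec: NvidiaProviderSpec = {
+#       "provider": "nvidia",
+#       "config": {"model_name": "nvidia/nv-embed-v1", "input_type": "passage"},
+#   }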
diff --git a/lib/crewai/src/crewai/rag/embeddings/types.py b/lib/crewai/src/crewai/rag/embeddings/types.py
index 794f4c6f9a..79fb934462 100644
--- a/lib/crewai/src/crewai/rag/embeddings/types.py
+++ b/lib/crewai/src/crewai/rag/embeddings/types.py
@@ -16,6 +16,7 @@
)
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
+from crewai.rag.embeddings.providers.nvidia.types import NvidiaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
@@ -38,6 +39,7 @@
| HuggingFaceProviderSpec
| InstructorProviderSpec
| JinaProviderSpec
+ | NvidiaProviderSpec
| OllamaProviderSpec
| ONNXProviderSpec
| OpenAIProviderSpec
@@ -60,6 +62,7 @@
"huggingface",
"instructor",
"jina",
+ "nvidia",
"ollama",
"onnx",
"openai",
diff --git a/lib/crewai/tests/llms/nvidia/__init__.py b/lib/crewai/tests/llms/nvidia/__init__.py
new file mode 100644
index 0000000000..c9073ec3b6
--- /dev/null
+++ b/lib/crewai/tests/llms/nvidia/__init__.py
@@ -0,0 +1 @@
+# NVIDIA LLM tests
\ No newline at end of file
diff --git a/lib/crewai/tests/llms/nvidia/test_nvidia.py b/lib/crewai/tests/llms/nvidia/test_nvidia.py
new file mode 100644
index 0000000000..d2456abbcc
--- /dev/null
+++ b/lib/crewai/tests/llms/nvidia/test_nvidia.py
@@ -0,0 +1,702 @@
+import os
+import sys
+import types
+from unittest.mock import patch, MagicMock
+import pytest
+
+from crewai.llm import LLM
+from crewai.crew import Crew
+from crewai.agent import Agent
+from crewai.task import Task
+
+
+@pytest.fixture(autouse=True)
+def mock_nvidia_api_key():
+ """Automatically mock NVIDIA_API_KEY for all tests in this module."""
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ yield
+
+
+def test_nvidia_completion_is_used_when_nvidia_provider():
+ """
+ Test that NvidiaCompletion from completion.py is used when LLM uses provider 'nvidia'
+ """
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ assert llm.__class__.__name__ == "NvidiaCompletion"
+ assert llm.provider == "nvidia"
+ assert llm.model == "llama-3.1-70b-instruct"
+
+
+def test_nvidia_completion_is_used_when_model_has_slash():
+ """
+ Test that NvidiaCompletion is used when model contains '/' and NVIDIA_API_KEY is set
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="meta/llama-3.1-70b-instruct")
+
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion)
+ assert llm.provider == "nvidia"
+ assert llm.model == "meta/llama-3.1-70b-instruct"
+
+
+def test_nvidia_falls_back_when_no_api_key():
+ """
+ Test that NVIDIA models fall back to LiteLLM when no NVIDIA_API_KEY is set
+ """
+ # Ensure no NVIDIA API key
+ with patch.dict(os.environ, {}, clear=True):
+ llm = LLM(model="meta/llama-3.1-70b-instruct")
+
+ # Should not be NvidiaCompletion
+ assert llm.__class__.__name__ != "NvidiaCompletion"
+
+
+def test_nvidia_tool_use_conversation_flow():
+ """
+ Test that the NVIDIA completion properly handles tool use conversation flow
+ """
+ from unittest.mock import Mock, patch
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+
+ # Create NvidiaCompletion instance
+ completion = NvidiaCompletion(model="meta/llama-3.1-70b-instruct")
+
+ # Mock tool function
+ def mock_weather_tool(location: str) -> str:
+ return f"The weather in {location} is sunny and 75°F"
+
+ available_functions = {"get_weather": mock_weather_tool}
+
+ # Mock the OpenAI client responses
+ with patch.object(completion.client.chat.completions, 'create') as mock_create:
+ # Mock function call in response
+ mock_function_call = Mock()
+ mock_function_call.name = "get_weather"
+ mock_function_call.arguments = '{"location": "San Francisco"}'
+
+ mock_tool_call = Mock()
+ mock_tool_call.id = "call_123"
+ mock_tool_call.function = mock_function_call
+
+ mock_choice = Mock()
+ mock_choice.message.tool_calls = [mock_tool_call]
+ mock_choice.message.content = None
+
+ mock_response = Mock()
+ mock_response.choices = [mock_choice]
+ mock_response.usage = Mock()
+ mock_response.usage.prompt_tokens = 100
+ mock_response.usage.completion_tokens = 50
+ mock_response.usage.total_tokens = 150
+
+ mock_create.return_value = mock_response
+
+ # Test the call
+ messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
+ result = completion.call(
+ messages=messages,
+ available_functions=available_functions
+ )
+
+ # Verify the tool was executed and returned the result
+ assert result == "The weather in San Francisco is sunny and 75°F"
+
+ # Verify that the API was called
+ assert mock_create.called
+
+
+def test_nvidia_completion_module_is_imported():
+ """
+ Test that the completion module is properly imported when using NVIDIA provider
+ """
+ module_name = "crewai.llms.providers.nvidia.completion"
+
+ # Remove module from cache if it exists
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+ # Create LLM instance - this should trigger the import
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Verify the module was imported
+ assert module_name in sys.modules
+ completion_mod = sys.modules[module_name]
+ assert isinstance(completion_mod, types.ModuleType)
+
+ # Verify the class exists in the module
+ assert hasattr(completion_mod, 'NvidiaCompletion')
+
+
+def test_native_nvidia_raises_error_when_initialization_fails():
+ """
+ Test that LLM raises ImportError when native NVIDIA completion fails.
+
+ With the new behavior, when a native provider is in SUPPORTED_NATIVE_PROVIDERS
+ but fails to instantiate, we raise an ImportError instead of silently falling back.
+ This provides clearer error messages to users about missing dependencies.
+ """
+ # Mock the _get_native_provider to return a failing class
+ with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider:
+
+ class FailingCompletion:
+ def __init__(self, *args, **kwargs):
+ raise Exception("Native NVIDIA SDK failed")
+
+ mock_get_provider.return_value = FailingCompletion
+
+ # This should raise ImportError with clear message
+ with pytest.raises(ImportError) as excinfo:
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Verify the error message is helpful
+ assert "Error importing native provider" in str(excinfo.value)
+ assert "Native NVIDIA SDK failed" in str(excinfo.value)
+
+
+def test_nvidia_completion_initialization_parameters():
+ """
+ Test that NvidiaCompletion is initialized with correct parameters
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(
+ model="nvidia/llama-3.1-70b-instruct",
+ temperature=0.7,
+ max_tokens=2000,
+ top_p=0.9,
+ frequency_penalty=0.1,
+ api_key="test-key"
+ )
+
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion)
+ assert llm.model == "llama-3.1-70b-instruct"
+ assert llm.temperature == 0.7
+ assert llm.max_tokens == 2000
+ assert llm.top_p == 0.9
+ assert llm.frequency_penalty == 0.1
+
+
+def test_nvidia_specific_parameters():
+ """
+ Test NVIDIA-specific parameters like seed, stream, and response_format
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(
+ model="nvidia/llama-3.1-70b-instruct",
+ seed=42,
+ stream=True,
+ response_format={"type": "json_object"},
+ logprobs=True,
+ top_logprobs=5
+ )
+
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion)
+ assert llm.seed == 42
+ assert llm.stream == True
+ assert llm.response_format == {"type": "json_object"}
+ assert llm.logprobs == True
+ assert llm.top_logprobs == 5
+
+
+def test_nvidia_completion_call():
+ """
+ Test that NvidiaCompletion call method works
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Mock the call method on the instance
+ with patch.object(llm, 'call', return_value="Hello! I'm NVIDIA Llama, ready to help.") as mock_call:
+ result = llm.call("Hello, how are you?")
+
+ assert result == "Hello! I'm NVIDIA Llama, ready to help."
+ mock_call.assert_called_once_with("Hello, how are you?")
+
+
+def test_nvidia_completion_called_during_crew_execution():
+ """
+ Test that NvidiaCompletion.call is actually invoked when running a crew
+ """
+ # Create the LLM instance first
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ nvidia_llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Mock the call method on the specific instance
+ with patch.object(nvidia_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call:
+
+ # Create agent with explicit LLM configuration
+ agent = Agent(
+ role="Research Assistant",
+ goal="Find population info",
+ backstory="You research populations.",
+ llm=nvidia_llm,
+ )
+
+ task = Task(
+ description="Find Tokyo population",
+ expected_output="Population number",
+ agent=agent,
+ )
+
+ crew = Crew(agents=[agent], tasks=[task])
+ result = crew.kickoff()
+
+ # Verify mock was called
+ assert mock_call.called
+ assert "14 million" in str(result)
+
+
+def test_nvidia_completion_call_arguments():
+ """
+ Test that NvidiaCompletion.call is invoked with correct arguments
+ """
+ # Create LLM instance first
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ nvidia_llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Mock the instance method
+ with patch.object(nvidia_llm, 'call') as mock_call:
+ mock_call.return_value = "Task completed successfully."
+
+ agent = Agent(
+ role="Test Agent",
+ goal="Complete a simple task",
+ backstory="You are a test agent.",
+ llm=nvidia_llm # Use same instance
+ )
+
+ task = Task(
+ description="Say hello world",
+ expected_output="Hello world",
+ agent=agent,
+ )
+
+ crew = Crew(agents=[agent], tasks=[task])
+ crew.kickoff()
+
+ # Verify call was made
+ assert mock_call.called
+
+ # Check the arguments passed to the call method
+ call_args = mock_call.call_args
+ assert call_args is not None
+
+ # The first argument should be the messages
+ messages = call_args[0][0] # First positional argument
+ assert isinstance(messages, (str, list))
+
+ # Verify that the task description appears in the messages
+ if isinstance(messages, str):
+ assert "hello world" in messages.lower()
+ elif isinstance(messages, list):
+ message_content = str(messages).lower()
+ assert "hello world" in message_content
+
+
+def test_multiple_nvidia_calls_in_crew():
+ """
+ Test that NvidiaCompletion.call is invoked multiple times for multiple tasks
+ """
+ # Create LLM instance first
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ nvidia_llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Mock the instance method
+ with patch.object(nvidia_llm, 'call') as mock_call:
+ mock_call.return_value = "Task completed."
+
+ agent = Agent(
+ role="Multi-task Agent",
+ goal="Complete multiple tasks",
+ backstory="You can handle multiple tasks.",
+ llm=nvidia_llm # Use same instance
+ )
+
+ task1 = Task(
+ description="First task",
+ expected_output="First result",
+ agent=agent,
+ )
+
+ task2 = Task(
+ description="Second task",
+ expected_output="Second result",
+ agent=agent,
+ )
+
+ crew = Crew(
+ agents=[agent],
+ tasks=[task1, task2]
+ )
+ crew.kickoff()
+
+ # Verify multiple calls were made
+ assert mock_call.call_count >= 2 # At least one call per task
+
+ # Verify each call had proper arguments
+ for call in mock_call.call_args_list:
+ assert len(call[0]) > 0 # Has positional arguments
+ messages = call[0][0]
+ assert messages is not None
+
+
+def test_nvidia_completion_with_tools():
+ """
+ Test that NvidiaCompletion.call is invoked with tools when agent has tools
+ """
+ from crewai.tools import tool
+
+ @tool
+ def sample_tool(query: str) -> str:
+ """A sample tool for testing"""
+ return f"Tool result for: {query}"
+
+ # Create LLM instance first
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ nvidia_llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Mock the instance method
+ with patch.object(nvidia_llm, 'call') as mock_call:
+ mock_call.return_value = "Task completed with tools."
+
+ agent = Agent(
+ role="Tool User",
+ goal="Use tools to complete tasks",
+ backstory="You can use tools.",
+ llm=nvidia_llm, # Use same instance
+ tools=[sample_tool]
+ )
+
+ task = Task(
+ description="Use the sample tool",
+ expected_output="Tool usage result",
+ agent=agent,
+ )
+
+ crew = Crew(agents=[agent], tasks=[task])
+ crew.kickoff()
+
+ assert mock_call.called
+
+ call_args = mock_call.call_args
+ call_kwargs = call_args[1] if len(call_args) > 1 else {}
+
+ if 'tools' in call_kwargs:
+ assert call_kwargs['tools'] is not None
+ assert len(call_kwargs['tools']) > 0
+
+
+def test_nvidia_raises_error_when_model_not_supported():
+ """Test that NvidiaCompletion raises ValueError when model not supported"""
+
+ # Mock the OpenAI client to raise an error
+ with patch('crewai.llms.providers.nvidia.completion.OpenAI') as mock_openai:
+ mock_client = MagicMock()
+ mock_openai.return_value = mock_client
+
+ mock_response = MagicMock()
+ mock_response.status_code = 404
+
+ from openai import NotFoundError
+ mock_client.chat.completions.create.side_effect = NotFoundError("Model not found", response=mock_response, body=None)
+
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="nvidia/model-doesnt-exist")
+
+ with pytest.raises(ValueError): # Should raise ValueError for unsupported model
+ llm.call("Hello")
+
+
+def test_nvidia_api_key_configuration():
+ """
+ Test that API key configuration works for both NVIDIA_API_KEY and NVIDIA_NIM_API_KEY
+ """
+ # Test with NVIDIA_API_KEY
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-nvidia-key"}):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion)
+ assert llm.api_key == "test-nvidia-key"
+
+ # Test with NVIDIA_NIM_API_KEY
+ with patch.dict(os.environ, {"NVIDIA_NIM_API_KEY": "test-nim-key"}, clear=True):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ assert isinstance(llm, NvidiaCompletion)
+ assert llm.api_key == "test-nim-key"
+
+
+def test_nvidia_model_capabilities():
+ """
+ Test that model capabilities are correctly identified
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ # Test Llama 3.1 model
+ llm_llama = LLM(model="meta/llama-3.1-70b-instruct")
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm_llama, NvidiaCompletion)
+ assert llm_llama.supports_tools == True
+
+ # Test vision model
+ llm_vision = LLM(model="meta/llama-3.2-90b-vision-instruct")
+ assert isinstance(llm_vision, NvidiaCompletion)
+ assert llm_vision.is_vision_model == True
+
+
+def test_nvidia_generation_config():
+ """
+ Test that generation config is properly prepared
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(
+ model="nvidia/llama-3.1-70b-instruct",
+ temperature=0.7,
+ top_p=0.9,
+ frequency_penalty=0.1,
+ max_tokens=1000
+ )
+
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion)
+
+ # Test config preparation
+ params = llm._prepare_completion_params([])
+
+ # Verify config has the expected parameters
+ assert "temperature" in params
+ assert params["temperature"] == 0.7
+ assert "top_p" in params
+ assert params["top_p"] == 0.9
+ assert "frequency_penalty" in params
+ assert params["frequency_penalty"] == 0.1
+ assert "max_tokens" in params
+ assert params["max_tokens"] == 1000
+
+
+def test_nvidia_model_detection():
+ """
+ Test that various NVIDIA model formats are properly detected
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ # Test NVIDIA model naming patterns that actually work with provider detection
+ nvidia_test_cases = [
+ "nvidia/llama-3.1-70b-instruct",
+ "meta/llama-3.1-70b-instruct",
+ "qwen/qwen3-next-80b-a3b-instruct",
+ "deepseek-ai/deepseek-r1",
+ "google/gemma-2-27b-it",
+ "mistralai/mistral-large-3-675b-instruct-2512"
+ ]
+
+ for model_name in nvidia_test_cases:
+ llm = LLM(model=model_name)
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion), f"Failed for model: {model_name}"
+
+
+def test_nvidia_supports_stop_words():
+ """
+ Test that NVIDIA models support stop sequences
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+ assert llm.supports_stop_words() == True
+
+
+def test_nvidia_context_window_size():
+ """
+ Test that NVIDIA models return correct context window sizes
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ # Test Llama 3.1 model
+ llm_llama = LLM(model="meta/llama-3.1-70b-instruct")
+ context_size_llama = llm_llama.get_context_window_size()
+ assert context_size_llama > 100000 # Should be substantial
+
+ # Test DeepSeek R1 model
+ llm_deepseek = LLM(model="deepseek-ai/deepseek-r1")
+ context_size_deepseek = llm_deepseek.get_context_window_size()
+ assert context_size_deepseek > 100000 # Should be large
+
+
+def test_nvidia_message_formatting():
+ """
+ Test that messages are properly formatted for NVIDIA API
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Test message formatting
+ test_messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"},
+ {"role": "user", "content": "How are you?"}
+ ]
+
+ formatted_messages = llm._format_messages(test_messages)
+
+ # Should have all messages
+ assert len(formatted_messages) == 4
+
+ # Check roles are preserved
+ assert formatted_messages[0]["role"] == "system"
+ assert formatted_messages[1]["role"] == "user"
+ assert formatted_messages[2]["role"] == "assistant"
+ assert formatted_messages[3]["role"] == "user"
+
+
+def test_nvidia_streaming_parameter():
+ """
+ Test that streaming parameter is properly handled
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ # Test non-streaming
+ llm_no_stream = LLM(model="nvidia/llama-3.1-70b-instruct", stream=False)
+ assert llm_no_stream.stream == False
+
+ # Test streaming
+ llm_stream = LLM(model="nvidia/llama-3.1-70b-instruct", stream=True)
+ assert llm_stream.stream == True
+
+
+def test_nvidia_tool_conversion():
+ """
+ Test that tools are properly converted to OpenAI format for NVIDIA
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Mock tool in CrewAI format
+ crewai_tools = [{
+ "type": "function",
+ "function": {
+ "name": "test_tool",
+ "description": "A test tool",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "query": {"type": "string", "description": "Search query"}
+ },
+ "required": ["query"]
+ }
+ }
+ }]
+
+ # Test tool conversion
+ nvidia_tools = llm._convert_tools_for_interference(crewai_tools)
+
+ assert len(nvidia_tools) == 1
+ # NVIDIA tools are in OpenAI format
+ assert nvidia_tools[0]["type"] == "function"
+ assert nvidia_tools[0]["function"]["name"] == "test_tool"
+ assert nvidia_tools[0]["function"]["description"] == "A test tool"
+
+
+def test_nvidia_environment_variable_api_key():
+ """
+ Test that NVIDIA API key is properly loaded from environment
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-nvidia-key"}):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ assert llm.client is not None
+ assert hasattr(llm.client, 'chat')
+ assert llm.api_key == "test-nvidia-key"
+
+
+def test_nvidia_token_usage_tracking():
+ """
+ Test that token usage is properly tracked for NVIDIA responses
+ """
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="nvidia/llama-3.1-70b-instruct")
+
+ # Mock the OpenAI response with usage information
+ with patch.object(llm.client.chat.completions, 'create') as mock_create:
+ mock_response = MagicMock()
+ mock_response.choices = [MagicMock()]
+ mock_response.choices[0].message.content = "test response"
+ mock_response.usage = MagicMock(
+ prompt_tokens=50,
+ completion_tokens=25,
+ total_tokens=75
+ )
+ mock_create.return_value = mock_response
+
+ result = llm.call("Hello")
+
+ # Verify the response
+ assert result == "test response"
+
+ # Verify token usage was extracted
+ usage = llm._extract_token_usage(mock_response)
+ assert usage["prompt_tokens"] == 50
+ assert usage["completion_tokens"] == 25
+ assert usage["total_tokens"] == 75
+
+
+def test_nvidia_reasoning_model_detection():
+ """Test that reasoning models like DeepSeek R1 are properly detected."""
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ llm = LLM(model="deepseek-ai/deepseek-r1")
+
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion)
+
+ # Test that reasoning models get default max_tokens when not specified
+ config = llm._prepare_completion_params([])
+ assert "max_tokens" in config
+ assert config["max_tokens"] == 4096 # Default for reasoning models
+
+
+def test_nvidia_vision_model_detection():
+ """Test that vision models are properly detected."""
+ with patch.dict(os.environ, {"NVIDIA_API_KEY": "test-key"}):
+ vision_models = [
+ "meta/llama-3.2-90b-vision-instruct",
+ "meta/llama-3.2-11b-vision-instruct",
+ "microsoft/phi-3-vision-128k-instruct"
+ ]
+
+ for model in vision_models:
+ llm = LLM(model=model)
+ from crewai.llms.providers.nvidia.completion import NvidiaCompletion
+ assert isinstance(llm, NvidiaCompletion)
+ assert llm.is_vision_model == True
+
+
+@pytest.mark.vcr()
+@pytest.mark.skip(reason="VCR cannot replay SSE streaming responses")
+def test_nvidia_streaming_returns_usage_metrics():
+ """
+ Test that NVIDIA streaming calls return proper token usage metrics.
+ """
+ agent = Agent(
+ role="Research Assistant",
+ goal="Find information about the capital of Japan",
+ backstory="You are a helpful research assistant.",
+ llm=LLM(model="meta/llama-3.1-70b-instruct", stream=True),
+ verbose=True,
+ )
+
+ task = Task(
+ description="What is the capital of Japan?",
+ expected_output="The capital of Japan",
+ agent=agent,
+ )
+
+ crew = Crew(agents=[agent], tasks=[task])
+ result = crew.kickoff()
+
+ assert result.token_usage is not None
+ assert result.token_usage.total_tokens > 0
+ assert result.token_usage.prompt_tokens > 0
+ assert result.token_usage.completion_tokens > 0
+ assert result.token_usage.successful_requests >= 1
\ No newline at end of file