Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/crewai/src/crewai/llms/providers/bedrock/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,11 @@ def __init__(
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
**kwargs: Additional parameters
**kwargs: Additional parameters (including model_id for cross-region inference)
"""
# Extract model_id from kwargs if provided (for cross-region inference profiles)
custom_model_id = kwargs.pop("model_id", None)

# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)

Expand Down Expand Up @@ -230,7 +233,7 @@ def __init__(
self.supports_streaming = True

# Handle inference profiles for newer models
self.model_id = model
self.model_id = custom_model_id if custom_model_id else model

def call(
self,
Expand Down
74 changes: 74 additions & 0 deletions lib/crewai/tests/llms/bedrock/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def mock_aws_credentials():
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
"AWS_DEFAULT_REGION": "us-east-1"
}):
import crewai.llms.providers.bedrock.completion
# Mock boto3 Session to prevent actual AWS connections
with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class:
# Create mock session instance
Expand Down Expand Up @@ -736,3 +737,76 @@ def test_bedrock_client_error_handling():
with pytest.raises(RuntimeError) as exc_info:
llm.call("Hello")
assert "throttled" in str(exc_info.value).lower()


def test_bedrock_cross_region_inference_profile():
"""
Test that Bedrock supports cross-region inference profiles with model_id parameter.

This tests the fix for issue #3791 where cross-region inference profiles
(which require using ARN as model_id) were not working in version 1.20.0.

When using cross-region inference profiles, users need to:
1. Set model to the base model name (e.g., "bedrock/anthropic.claude-sonnet-4-20250514-v1:0")
2. Set model_id to the inference profile ARN

The BedrockCompletion should use the model_id parameter when provided,
not the model parameter, for the actual API call.
"""
# Test with cross-region inference profile ARN
inference_profile_arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0"

llm = LLM(
model="bedrock/anthropic.claude-sonnet-4-20250514-v1:0",
model_id=inference_profile_arn,
temperature=0.3,
max_tokens=4000,
)

from crewai.llms.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)

assert llm.model_id == inference_profile_arn
assert llm.model == "anthropic.claude-sonnet-4-20250514-v1:0"

# Verify that the client.converse call would use the correct model_id
with patch.object(llm.client, 'converse') as mock_converse:
mock_converse.return_value = {
'output': {
'message': {
'role': 'assistant',
'content': [{'text': 'Test response'}]
}
},
'usage': {
'inputTokens': 10,
'outputTokens': 5,
'totalTokens': 15
}
}

llm.call("Test message")

# Verify the converse call was made with the inference profile ARN
mock_converse.assert_called_once()
call_kwargs = mock_converse.call_args[1]
assert call_kwargs['modelId'] == inference_profile_arn


def test_bedrock_model_id_parameter_takes_precedence():
"""
Test that when both model and model_id are provided, model_id takes precedence
for the actual API call, while model is used for internal identification.
"""
custom_model_id = "custom-model-identifier"

llm = LLM(
model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
model_id=custom_model_id,
)

from crewai.llms.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)

assert llm.model_id == custom_model_id
assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0"
Loading
Loading