Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/Gateway/Prism/PrismGateway.php
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,13 @@ public function generateEmbeddings(

/**
* Map the given Laravel AI provider to a Prism provider.
*
* Built-in drivers are mapped to their corresponding Prism provider enum.
* Unknown drivers fall through as strings, allowing external Prism providers
* registered via PrismManager::extend() (e.g. prism-php/bedrock) to be
* resolved by Prism's own provider resolution.
*/
protected static function toPrismProvider(Provider $provider): PrismProvider
protected static function toPrismProvider(Provider $provider): PrismProvider|string
{
return match ($provider->driver()) {
'anthropic' => PrismProvider::Anthropic,
Expand All @@ -393,7 +398,7 @@ protected static function toPrismProvider(Provider $provider): PrismProvider
'openrouter' => PrismProvider::OpenRouter,
'voyageai' => PrismProvider::VoyageAI,
'xai' => PrismProvider::XAI,
default => throw new InvalidArgumentException('Gateway does not support provider ['.$provider.'].'),
default => $provider->driver(),
};
}

Expand Down
81 changes: 81 additions & 0 deletions tests/Unit/Gateway/Prism/ToPrismProviderTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
<?php

namespace Tests\Unit\Gateway\Prism;

use InvalidArgumentException;
use Laravel\Ai\Gateway\Prism\PrismGateway;
use Laravel\Ai\Providers\Provider;
use Prism\Prism\Enums\Provider as PrismProvider;
use ReflectionMethod;
use Tests\TestCase;

class ToPrismProviderTest extends TestCase
{
/**
* Call the protected static toPrismProvider method via reflection.
*/
protected function callToPrismProvider(Provider $provider): PrismProvider|string
{
$method = new ReflectionMethod(PrismGateway::class, 'toPrismProvider');

return $method->invoke(null, $provider);
}

/**
* Create a minimal Provider stub with the given driver name.
*/
protected function makeProvider(string $driver): Provider
{
return new class($driver) extends Provider
{
public function __construct(string $driver)
{
parent::__construct(
new \Laravel\Ai\Gateway\Prism\PrismGateway(app('events')),
['driver' => $driver, 'key' => 'test-key', 'name' => $driver],
app('events'),
);
}
};
}

public function test_built_in_drivers_return_prism_provider_enum(): void
{
$mappings = [
'anthropic' => PrismProvider::Anthropic,
'azure' => PrismProvider::OpenAI,
'deepseek' => PrismProvider::DeepSeek,
'gemini' => PrismProvider::Gemini,
'groq' => PrismProvider::Groq,
'mistral' => PrismProvider::Mistral,
'ollama' => PrismProvider::Ollama,
'openai' => PrismProvider::OpenAI,
'openrouter' => PrismProvider::OpenRouter,
'voyageai' => PrismProvider::VoyageAI,
'xai' => PrismProvider::XAI,
];

foreach ($mappings as $driver => $expectedEnum) {
$result = $this->callToPrismProvider($this->makeProvider($driver));

$this->assertInstanceOf(PrismProvider::class, $result, "Driver '{$driver}' should return a PrismProvider enum");
$this->assertSame($expectedEnum, $result, "Driver '{$driver}' mapped to wrong enum");
}
}

public function test_unknown_driver_returns_string_for_custom_prism_providers(): void
{
$result = $this->callToPrismProvider($this->makeProvider('bedrock'));

$this->assertIsString($result);
$this->assertSame('bedrock', $result);
}

public function test_custom_driver_name_is_returned_as_is(): void
{
$result = $this->callToPrismProvider($this->makeProvider('workers-ai'));

$this->assertIsString($result);
$this->assertSame('workers-ai', $result);
}
}
Loading