diff --git a/src/Chain.php b/src/Chain.php index 3fc12a5..8cc8500 100644 --- a/src/Chain.php +++ b/src/Chain.php @@ -47,25 +47,18 @@ public function __construct( */ public function call(MessageBagInterface $messages, array $options = []): ResponseInterface { - $llm = $this->llm; - - if (array_key_exists('llm', $options)) { - if (!$options['llm'] instanceof LanguageModel) { - throw new InvalidArgumentException(sprintf('Option "llm" must be an instance of %s.', LanguageModel::class)); - } - - $llm = $options['llm']; - unset($options['llm']); - } - - $input = new Input($llm, $messages, $options); + $input = new Input($this->llm, $messages, $options); array_map(fn (InputProcessor $processor) => $processor->processInput($input), $this->inputProcessor); + $llm = $input->llm; + $messages = $input->messages; + $options = $input->getOptions(); + if ($messages->containsImage() && !$llm->supportsImageInput()) { throw MissingModelSupport::forImageInput($llm::class); } - $response = $this->platform->request($llm, $messages, $options = $input->getOptions()); + $response = $this->platform->request($llm, $messages, $options); if ($response instanceof AsyncResponse) { $response = $response->unwrap(); diff --git a/src/Chain/Input.php b/src/Chain/Input.php index a99b507..9aa4d8d 100644 --- a/src/Chain/Input.php +++ b/src/Chain/Input.php @@ -13,8 +13,8 @@ final class Input * @param array $options */ public function __construct( - public readonly LanguageModel $llm, - public readonly MessageBagInterface $messages, + public LanguageModel $llm, + public MessageBagInterface $messages, private array $options, ) { } diff --git a/src/Chain/LlmOverrideInputProcessor.php b/src/Chain/LlmOverrideInputProcessor.php new file mode 100644 index 0000000..83241e3 --- /dev/null +++ b/src/Chain/LlmOverrideInputProcessor.php @@ -0,0 +1,26 @@ +getOptions(); + + if (!array_key_exists('llm', $options)) { + return; + } + + if (!$options['llm'] instanceof LanguageModel) { + throw new InvalidArgumentException(sprintf('Option "llm" must be an instance of %s.', LanguageModel::class)); + } + + $input->llm = $options['llm']; + } +} diff --git a/tests/Chain/LlmOverrideInputProcessorTest.php b/tests/Chain/LlmOverrideInputProcessorTest.php new file mode 100644 index 0000000..84476cb --- /dev/null +++ b/tests/Chain/LlmOverrideInputProcessorTest.php @@ -0,0 +1,66 @@ + $claude]); + + $processor = new LlmOverrideInputProcessor(); + $processor->processInput($input); + + self::assertSame($claude, $input->llm); + } + + #[Test] + public function processInputWithoutLlmOption(): void + { + $gpt = new GPT(); + $input = new Input($gpt, new MessageBag(), []); + + $processor = new LlmOverrideInputProcessor(); + $processor->processInput($input); + + self::assertSame($gpt, $input->llm); + } + + #[Test] + public function processInputWithInvalidLlmOption(): void + { + self::expectException(InvalidArgumentException::class); + self::expectExceptionMessage('Option "llm" must be an instance of PhpLlm\LlmChain\Model\LanguageModel.'); + + $gpt = new GPT(); + $model = new Embeddings(); + $input = new Input($gpt, new MessageBag(), ['llm' => $model]); + + $processor = new LlmOverrideInputProcessor(); + $processor->processInput($input); + } +}