From 418a0e3f26dfb979ac7667add145647e0021891c Mon Sep 17 00:00:00 2001 From: Oskar Stark Date: Sat, 28 Dec 2024 22:50:24 +0100 Subject: [PATCH] `MessageBag` not depend on `\ArrayObject` anymore Follows * https://github.com/php-llm/llm-chain/pull/166/files#r1898145794 --- src/Bridge/Meta/LlamaPromptConverter.php | 2 +- src/Chain/ToolBox/ChainProcessor.php | 4 +- src/Model/Message/MessageBag.php | 52 +++++++++++++------ .../Bridge/Meta/LlamaPromptConverterTest.php | 2 +- tests/Model/Message/MessageBagTest.php | 12 ++--- 5 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/Bridge/Meta/LlamaPromptConverter.php b/src/Bridge/Meta/LlamaPromptConverter.php index 195ad04a..389c9d30 100644 --- a/src/Bridge/Meta/LlamaPromptConverter.php +++ b/src/Bridge/Meta/LlamaPromptConverter.php @@ -19,7 +19,7 @@ public function convertToPrompt(MessageBag $messageBag): string $messages = []; /** @var UserMessage|SystemMessage|AssistantMessage $message */ - foreach ($messageBag->getIterator() as $message) { + foreach ($messageBag->getMessages() as $message) { $messages[] = self::convertMessage($message); } diff --git a/src/Chain/ToolBox/ChainProcessor.php b/src/Chain/ToolBox/ChainProcessor.php index 15293d9c..4c25bf29 100644 --- a/src/Chain/ToolBox/ChainProcessor.php +++ b/src/Chain/ToolBox/ChainProcessor.php @@ -45,11 +45,11 @@ public function processOutput(Output $output): void while ($output->response instanceof ToolCallResponse) { $toolCalls = $output->response->getContent(); - $messages[] = Message::ofAssistant(toolCalls: $toolCalls); + $messages->add(Message::ofAssistant(toolCalls: $toolCalls)); foreach ($toolCalls as $toolCall) { $result = $this->toolBox->execute($toolCall); - $messages[] = Message::ofToolCall($toolCall, $result); + $messages->add(Message::ofToolCall($toolCall, $result)); } $output->response = $this->chain->call($messages, $output->options); diff --git a/src/Model/Message/MessageBag.php b/src/Model/Message/MessageBag.php index cd7bdcf5..63c33900 100644 --- a/src/Model/Message/MessageBag.php +++ b/src/Model/Message/MessageBag.php @@ -4,19 +4,34 @@ namespace PhpLlm\LlmChain\Model\Message; -/** - * @template-extends \ArrayObject - */ -final class MessageBag extends \ArrayObject implements \JsonSerializable +final class MessageBag implements \Countable, \JsonSerializable { + /** + * @var MessageInterface[] + */ + private array $messages; + public function __construct(MessageInterface ...$messages) { - parent::__construct(array_values($messages)); + $this->messages = array_values($messages); + } + + public function add(MessageInterface $message): void + { + $this->messages[] = $message; + } + + /** + * @return MessageInterface[] + */ + public function getMessages(): array + { + return $this->messages; } public function getSystemMessage(): ?SystemMessage { - foreach ($this as $message) { + foreach ($this->messages as $message) { if ($message instanceof SystemMessage) { return $message; } @@ -28,7 +43,7 @@ public function getSystemMessage(): ?SystemMessage public function with(MessageInterface $message): self { $messages = clone $this; - $messages->append($message); + $messages->add($message); return $messages; } @@ -36,7 +51,7 @@ public function with(MessageInterface $message): self public function merge(MessageBag $messageBag): self { $messages = clone $this; - $messages->exchangeArray(array_merge($messages->getArrayCopy(), $messageBag->getArrayCopy())); + $messages->messages = array_merge($messages->messages, $messageBag->messages); return $messages; } @@ -44,12 +59,10 @@ public function merge(MessageBag $messageBag): self public function withoutSystemMessage(): self { $messages = clone $this; - $messages->exchangeArray( - array_values(array_filter( - $messages->getArrayCopy(), - static fn (MessageInterface $message) => !$message instanceof SystemMessage, - )) - ); + $messages->messages = array_values(array_filter( + $messages->messages, + static fn (MessageInterface $message) => !$message instanceof SystemMessage, + )); return $messages; } @@ -57,14 +70,14 @@ public function withoutSystemMessage(): self public function prepend(MessageInterface $message): self { $messages = clone $this; - $messages->exchangeArray(array_merge([$message], $messages->getArrayCopy())); + $messages->messages = array_merge([$message], $messages->messages); return $messages; } public function containsImage(): bool { - foreach ($this as $message) { + foreach ($this->messages as $message) { if ($message instanceof UserMessage && $message->hasImageContent()) { return true; } @@ -73,11 +86,16 @@ public function containsImage(): bool return false; } + public function count(): int + { + return count($this->messages); + } + /** * @return MessageInterface[] */ public function jsonSerialize(): array { - return $this->getArrayCopy(); + return $this->messages; } } diff --git a/tests/Bridge/Meta/LlamaPromptConverterTest.php b/tests/Bridge/Meta/LlamaPromptConverterTest.php index 8db2d986..7d9c3149 100644 --- a/tests/Bridge/Meta/LlamaPromptConverterTest.php +++ b/tests/Bridge/Meta/LlamaPromptConverterTest.php @@ -26,7 +26,7 @@ public function convertMessages(): void { $messageBag = new MessageBag(); foreach (self::provideMessages() as $message) { - $messageBag->append($message[1]); + $messageBag->add($message[1]); } self::assertSame(<<getMessages()[3]; self::assertInstanceOf(AssistantMessage::class, $newMessageFromBag); self::assertSame('It is time to wake up.', $newMessageFromBag->content); @@ -91,7 +91,7 @@ public function merge(): void self::assertCount(4, $messageBag); - $messageFromBag = $messageBag[3]; + $messageFromBag = $messageBag->getMessages()[3]; self::assertInstanceOf(AssistantMessage::class, $messageFromBag); self::assertSame('It is time to wake up.', $messageFromBag->content); @@ -111,7 +111,7 @@ public function withoutSystemMessage(): void self::assertCount(3, $messageBag); self::assertCount(2, $newMessageBag); - $messageFromNewBag = $newMessageBag[0]; + $messageFromNewBag = $newMessageBag->getMessages()[0]; self::assertInstanceOf(AssistantMessage::class, $messageFromNewBag); self::assertSame('It is time to sleep.', $messageFromNewBag->content); @@ -131,14 +131,14 @@ public function prepend(): void self::assertCount(2, $messageBag); self::assertCount(3, $newMessageBag); - $newMessageBagMessage = $newMessageBag[0]; + $newMessageBagMessage = $newMessageBag->getMessages()[0]; self::assertInstanceOf(SystemMessage::class, $newMessageBagMessage); self::assertSame('My amazing system prompt.', $newMessageBagMessage->content); } #[Test] - public function containsImageWithoutImage(): void + public function containsImageReturnsFalseWithoutImage(): void { $messageBag = new MessageBag( Message::ofAssistant('It is time to sleep.'), @@ -149,7 +149,7 @@ public function containsImageWithoutImage(): void } #[Test] - public function containsImageWithImage(): void + public function containsImageReturnsTrueWithImage(): void { $messageBag = new MessageBag( Message::ofAssistant('It is time to sleep.'),