Skip to content

Commit

Permalink
refactor: MessageBag not depend on \ArrayObject anymore (#174)
Browse files Browse the repository at this point in the history
Follows
* https://github.com/php-llm/llm-chain/pull/166/files#r1898145794

BREAKING CHANGE: MessageBag is not an instance of ArrayObject anymore and therefore array access and iteration not possible anymore. Use `getMessages()` instead.
  • Loading branch information
OskarStark authored Dec 28, 2024
1 parent f7d8aa0 commit 3c48f5a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/Bridge/Meta/LlamaPromptConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public function convertToPrompt(MessageBag $messageBag): string
$messages = [];

/** @var UserMessage|SystemMessage|AssistantMessage $message */
foreach ($messageBag->getIterator() as $message) {
foreach ($messageBag->getMessages() as $message) {
$messages[] = self::convertMessage($message);
}

Expand Down
4 changes: 2 additions & 2 deletions src/Chain/ToolBox/ChainProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ public function processOutput(Output $output): void

while ($output->response instanceof ToolCallResponse) {
$toolCalls = $output->response->getContent();
$messages[] = Message::ofAssistant(toolCalls: $toolCalls);
$messages->add(Message::ofAssistant(toolCalls: $toolCalls));

foreach ($toolCalls as $toolCall) {
$result = $this->toolBox->execute($toolCall);
$messages[] = Message::ofToolCall($toolCall, $result);
$messages->add(Message::ofToolCall($toolCall, $result));
}

$output->response = $this->chain->call($messages, $output->options);
Expand Down
52 changes: 35 additions & 17 deletions src/Model/Message/MessageBag.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,34 @@

namespace PhpLlm\LlmChain\Model\Message;

/**
* @template-extends \ArrayObject<int, MessageInterface>
*/
final class MessageBag extends \ArrayObject implements \JsonSerializable
final class MessageBag implements \Countable, \JsonSerializable
{
/**
* @var MessageInterface[]
*/
private array $messages;

public function __construct(MessageInterface ...$messages)
{
parent::__construct(array_values($messages));
$this->messages = array_values($messages);
}

public function add(MessageInterface $message): void
{
$this->messages[] = $message;
}

/**
* @return MessageInterface[]
*/
public function getMessages(): array
{
return $this->messages;
}

public function getSystemMessage(): ?SystemMessage
{
foreach ($this as $message) {
foreach ($this->messages as $message) {
if ($message instanceof SystemMessage) {
return $message;
}
Expand All @@ -28,43 +43,41 @@ public function getSystemMessage(): ?SystemMessage
public function with(MessageInterface $message): self
{
$messages = clone $this;
$messages->append($message);
$messages->add($message);

return $messages;
}

public function merge(MessageBag $messageBag): self
{
$messages = clone $this;
$messages->exchangeArray(array_merge($messages->getArrayCopy(), $messageBag->getArrayCopy()));
$messages->messages = array_merge($messages->messages, $messageBag->messages);

return $messages;
}

public function withoutSystemMessage(): self
{
$messages = clone $this;
$messages->exchangeArray(
array_values(array_filter(
$messages->getArrayCopy(),
static fn (MessageInterface $message) => !$message instanceof SystemMessage,
))
);
$messages->messages = array_values(array_filter(
$messages->messages,
static fn (MessageInterface $message) => !$message instanceof SystemMessage,
));

return $messages;
}

public function prepend(MessageInterface $message): self
{
$messages = clone $this;
$messages->exchangeArray(array_merge([$message], $messages->getArrayCopy()));
$messages->messages = array_merge([$message], $messages->messages);

return $messages;
}

public function containsImage(): bool
{
foreach ($this as $message) {
foreach ($this->messages as $message) {
if ($message instanceof UserMessage && $message->hasImageContent()) {
return true;
}
Expand All @@ -73,11 +86,16 @@ public function containsImage(): bool
return false;
}

public function count(): int
{
return count($this->messages);
}

/**
* @return MessageInterface[]
*/
public function jsonSerialize(): array
{
return $this->getArrayCopy();
return $this->messages;
}
}
2 changes: 1 addition & 1 deletion tests/Bridge/Meta/LlamaPromptConverterTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public function convertMessages(): void
{
$messageBag = new MessageBag();
foreach (self::provideMessages() as $message) {
$messageBag->append($message[1]);
$messageBag->add($message[1]);
}

self::assertSame(<<<EXPECTED
Expand Down
12 changes: 6 additions & 6 deletions tests/Model/Message/MessageBagTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public function with(): void
self::assertCount(3, $messageBag);
self::assertCount(4, $newMessageBag);

$newMessageFromBag = $newMessageBag[3];
$newMessageFromBag = $newMessageBag->getMessages()[3];

self::assertInstanceOf(AssistantMessage::class, $newMessageFromBag);
self::assertSame('It is time to wake up.', $newMessageFromBag->content);
Expand All @@ -91,7 +91,7 @@ public function merge(): void

self::assertCount(4, $messageBag);

$messageFromBag = $messageBag[3];
$messageFromBag = $messageBag->getMessages()[3];

self::assertInstanceOf(AssistantMessage::class, $messageFromBag);
self::assertSame('It is time to wake up.', $messageFromBag->content);
Expand All @@ -111,7 +111,7 @@ public function withoutSystemMessage(): void
self::assertCount(3, $messageBag);
self::assertCount(2, $newMessageBag);

$messageFromNewBag = $newMessageBag[0];
$messageFromNewBag = $newMessageBag->getMessages()[0];

self::assertInstanceOf(AssistantMessage::class, $messageFromNewBag);
self::assertSame('It is time to sleep.', $messageFromNewBag->content);
Expand All @@ -131,14 +131,14 @@ public function prepend(): void
self::assertCount(2, $messageBag);
self::assertCount(3, $newMessageBag);

$newMessageBagMessage = $newMessageBag[0];
$newMessageBagMessage = $newMessageBag->getMessages()[0];

self::assertInstanceOf(SystemMessage::class, $newMessageBagMessage);
self::assertSame('My amazing system prompt.', $newMessageBagMessage->content);
}

#[Test]
public function containsImageWithoutImage(): void
public function containsImageReturnsFalseWithoutImage(): void
{
$messageBag = new MessageBag(
Message::ofAssistant('It is time to sleep.'),
Expand All @@ -149,7 +149,7 @@ public function containsImageWithoutImage(): void
}

#[Test]
public function containsImageWithImage(): void
public function containsImageReturnsTrueWithImage(): void
{
$messageBag = new MessageBag(
Message::ofAssistant('It is time to sleep.'),
Expand Down

0 comments on commit 3c48f5a

Please sign in to comment.