Skip to content

Commit

Permalink
feat(ai): add message renderer (#5873)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbanksdesign authored Oct 18, 2024
1 parent 7f42be2 commit 3a697ea
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 53 deletions.
16 changes: 16 additions & 0 deletions .changeset/heavy-dots-applaud.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
"@aws-amplify/ui-react-ai": minor
---

feat(ai): add message renderer

```tsx
<AIConversation
messages={messages}
handleSendMessage={sendMessage}
isLoading={isLoading}
messageRenderer={{
text: ({text}) => <ReactMarkdown>{text}</ReactMarkdown>,
}}
/>
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import amplifyOutputs from '@environments/ai/gen2/amplify_outputs';
export default amplifyOutputs;
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { Amplify } from 'aws-amplify';
import { createAIHooks, AIConversation } from '@aws-amplify/ui-react-ai';
import { generateClient } from 'aws-amplify/api';
import '@aws-amplify/ui-react/styles.css';
import '@aws-amplify/ui-react-ai/ai-conversation-styles.css';

import outputs from './amplify_outputs';
import type { Schema } from '@environments/ai/gen2/amplify/data/resource';
import { Authenticator, Card, Text } from '@aws-amplify/ui-react';
import Image from 'next/image';

const client = generateClient<Schema>({ authMode: 'userPool' });
const { useAIConversation } = createAIHooks(client);

Amplify.configure(outputs);

function arrayBufferToBase64(buffer: ArrayBuffer) {
let binary = '';
const bytes = new Uint8Array(buffer);
const len = bytes.byteLength;
for (let i = 0; i < len; i++) {
binary += String.fromCharCode(bytes[i]);
}
return window.btoa(binary);
}

function convertBufferToBase64(buffer: ArrayBuffer, format: string): string {
let base64string = '';
// Use node-based buffer if available
// fall back on browser if not
if (typeof Buffer !== 'undefined') {
base64string = Buffer.from(new Uint8Array(buffer)).toString('base64');
} else {
base64string = arrayBufferToBase64(buffer);
}
return `data:image/${format};base64,${base64string}`;
}

function Chat() {
const [
{
data: { messages },
isLoading,
},
sendMessage,
] = useAIConversation('pirateChat');

return (
<Card variation="outlined" width="50%" height="300px" margin="0 auto">
<AIConversation
messages={messages}
handleSendMessage={sendMessage}
isLoading={isLoading}
allowAttachments
messageRenderer={{
text: ({ text }) => <Text className="testing">{text}</Text>,
image: ({ image }) => (
<Image
className="testing"
width={200}
height={200}
src={convertBufferToBase64(image.source.bytes, image.format)}
alt=""
/>
),
}}
suggestedPrompts={[
{
inputText: 'hello',
header: 'hello',
},
{
inputText: 'how are you?',
header: 'how are you?',
},
]}
variant="bubble"
/>
</Card>
);
}

export default function Example() {
return (
<Authenticator>
<Chat />
</Authenticator>
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function AIConversationBase({
isLoading,
displayText,
allowAttachments,
messageRenderer,
}: AIConversationBaseProps): JSX.Element {
useSetUserAgent({
componentName: 'AIConversation',
Expand Down Expand Up @@ -78,6 +79,7 @@ function AIConversationBase({
},
displayText,
allowAttachments,
messageRenderer,
};

return (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { createContextUtilities } from '@aws-amplify/ui-react-core';
import { MessageRenderer } from '../types';

export const {
MessageRendererContext,
MessageRendererProvider,
useMessageRenderer,
} = createContextUtilities<MessageRenderer>({
contextName: 'MessageRenderer',
defaultValue: undefined,
errorMessage:
'`useMessageRenderer` must be used with an AIConversation component',
});
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,11 @@ export {
RESPONSE_COMPONENT_PREFIX,
} from './ResponseComponentsContext';
export { SendMessageContextProvider } from './SendMessageContext';
export {
MessageRendererProvider,
MessageRendererContext,
useMessageRenderer,
} from './MessageRenderContext';
export { AttachmentProvider, AttachmentContext } from './AttachmentContext';

export * from './elements';
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export function createAIConversation(input: AIConversationInput = {}): {
controls,
displayText,
allowAttachments,
messageRenderer,
} = input;

function AIConversation(props: AIConversationProps): JSX.Element {
Expand All @@ -48,6 +49,7 @@ export function createAIConversation(input: AIConversationInput = {}): {
avatars,
handleSendMessage,
isLoading,
messageRenderer,
};
return (
<AIConversationProvider {...providerProps}>
Expand Down
13 changes: 12 additions & 1 deletion packages/react-ai/src/components/AIConversation/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ import {
} from './views';
import { DisplayTextTemplate } from '@aws-amplify/ui';
import { AIConversationDisplayText } from './displayText';
import { ConversationMessage, SendMessage } from '../../types';
import {
ConversationMessage,
ImageContentBlock,
SendMessage,
TextContentBlock,
} from '../../types';
import { ControlsContextProps } from './context/ControlsContext';

export interface Controls {
Expand All @@ -32,6 +37,7 @@ export interface AIConversationInput {
variant?: MessageVariant;
controls?: ControlsContextProps;
allowAttachments?: boolean;
messageRenderer?: MessageRenderer;
}

export interface AIConversationProps {
Expand All @@ -54,6 +60,11 @@ export interface AIConversation {

export type MessageVariant = 'bubble' | 'default';

export interface MessageRenderer {
text?: (input: { text: TextContentBlock }) => React.JSX.Element;
image?: (input: { image: ImageContentBlock }) => React.JSX.Element;
}

export interface Avatar {
username?: string;
avatar?: React.ReactNode;
Expand Down
4 changes: 2 additions & 2 deletions packages/react-ai/src/components/AIConversation/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ImageContent } from '../../types';
import { ImageContentBlock } from '../../types';

export function formatDate(date: Date): string {
const dateString = date.toLocaleDateString('en-US', {
Expand Down Expand Up @@ -27,7 +27,7 @@ function arrayBufferToBase64(buffer: ArrayBuffer) {

export function convertBufferToBase64(
buffer: ArrayBuffer,
format: ImageContent['format']
format: ImageContentBlock['format']
): string {
let base64string = '';
// Use node-based buffer if available
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import React from 'react';
import { withBaseElementProps } from '@aws-amplify/ui-react-core/elements';

import {
MessageRendererContext,
MessagesContext,
MessageVariantContext,
RoleContext,
Expand Down Expand Up @@ -63,25 +64,30 @@ const ContentContainer: typeof View = React.forwardRef(

export const MessageControl: MessageControl = ({ message }) => {
const responseComponents = React.useContext(ResponseComponentsContext);
const messageRenderer = React.useContext(MessageRendererContext);
return (
<ContentContainer>
{message.content.map((content, index) => {
if (content.text) {
return (
return messageRenderer?.text ? (
messageRenderer.text({ text: content.text })
) : (
<TextContent data-testid={'text-content'} key={index}>
{content.text}
</TextContent>
);
} else if (content.image) {
return (
return messageRenderer?.image ? (
messageRenderer?.image({ image: content.image })
) : (
<MediaContent
data-testid={'image-content'}
key={index}
src={convertBufferToBase64(
content.image?.source.bytes,
content.image?.format
)}
></MediaContent>
/>
);
} else if (content.toolUse) {
// For now tool use is limited to custom response components
Expand Down Expand Up @@ -164,7 +170,7 @@ const Layout: typeof View = React.forwardRef(function Layout(props, ref) {
);
});

export const MessagesControl: MessagesControl = ({ renderMessage }) => {
export const MessagesControl: MessagesControl = () => {
const messages = React.useContext(MessagesContext);
const controls = React.useContext(ControlsContext);
const { getMessageTimestampText } = useConversationDisplayText();
Expand Down Expand Up @@ -226,9 +232,7 @@ export const MessagesControl: MessagesControl = ({ renderMessage }) => {
return (
<Layout>
{messagesWithRenderableContent?.map((message, index) => {
return renderMessage ? (
renderMessage(message)
) : (
return (
<RoleContext.Provider value={message.role} key={`message-${index}`}>
<MessageContainer
data-testid={`message`}
Expand Down Expand Up @@ -269,9 +273,7 @@ MessagesControl.Message = MessageControl;
MessagesControl.Separator = Separator;

export interface MessagesControl {
(props: {
renderMessage?: (message: ConversationMessage) => React.ReactNode;
}): JSX.Element;
(): JSX.Element;
ActionsBar: ActionsBarControl;
Avatar: AvatarControl;
Container: AIConversationElements['View'];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { MessagesControl, MessageControl } from '../MessagesControl';
import { convertBufferToBase64 } from '../../../utils';
import { ConversationMessage } from '../../../../../types';
import { ResponseComponentsProvider } from '../../../context/ResponseComponentsContext';
import { MessageRendererProvider } from '../../../context';

const AITextMessage: ConversationMessage = {
conversationId: 'foobar',
Expand Down Expand Up @@ -212,44 +213,6 @@ describe('MessagesControl', () => {
expect(actionElements).toHaveLength(2);
});

it('renders a MessagesControl element with a custom renderMessage function', () => {
const customMessage = jest.fn((message: ConversationMessage) => (
<div key={message.id} data-testid="custom-message">
{message.content.map((content, index) => {
if (content.text) {
return <p key={index}>{content.text}</p>;
} else if (content.image) {
return (
<img
key={index}
src={convertBufferToBase64(
content.image?.source.bytes,
content.image?.format
)}
></img>
);
}
})}
</div>
));

render(
<MessagesProvider
messages={[AITextMessage, userTextMessage, AIImageMessage]}
>
<MessagesControl renderMessage={customMessage} />
</MessagesProvider>
);

expect(customMessage).toHaveBeenCalledTimes(3);

const defaultMessageElements = screen.queryAllByTestId('message');
expect(defaultMessageElements).toHaveLength(0);

const customMessageElements = screen.queryAllByTestId('custom-message');
expect(customMessageElements).toHaveLength(3);
});

it('renders avatars and actions appropriately if the same user sends multiple messages', () => {
const { rerender } = render(
<AvatarsProvider avatars={avatars}>
Expand Down Expand Up @@ -387,4 +350,33 @@ describe('MessageControl', () => {
const { container } = render(<MessageControl message={ToolUseMessage} />);
expect(container.firstChild).toBeEmptyDOMElement();
});

it('uses text message renderer if passed', () => {
render(
<MessageRendererProvider
text={({ text }) => <div data-testid="custom-message">{text}</div>}
>
<MessageControl message={AITextMessage} />
</MessageRendererProvider>
);
const message = screen.getByTestId('custom-message');
expect(message).toBeInTheDocument();
});

it('uses image message renderer if passed', () => {
render(
<MessageRendererProvider
image={({ image }) => (
<img
data-testid="custom-message"
src={convertBufferToBase64(image.source.bytes, image.format)}
/>
)}
>
<MessageControl message={AIImageMessage} />
</MessageRendererProvider>
);
const message = screen.getByTestId('custom-message');
expect(message).toBeInTheDocument();
});
});
6 changes: 4 additions & 2 deletions packages/react-ai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ export type ConversationMessage = NonNullable<

export type ConversationMessageContent = ConversationMessage['content'][number];

export type TextContent = NonNullable<ConversationMessageContent['text']>;
export type TextContentBlock = NonNullable<ConversationMessageContent['text']>;

export type ImageContent = NonNullable<ConversationMessageContent['image']>;
export type ImageContentBlock = NonNullable<
ConversationMessageContent['image']
>;

// Note: the conversation sendMessage function is an overload
// that accepts a string OR an object
Expand Down

0 comments on commit 3a697ea

Please sign in to comment.