diff --git a/.changeset/heavy-dots-applaud.md b/.changeset/heavy-dots-applaud.md new file mode 100644 index 00000000000..26b303ed917 --- /dev/null +++ b/.changeset/heavy-dots-applaud.md @@ -0,0 +1,16 @@ +--- +"@aws-amplify/ui-react-ai": minor +--- + +feat(ai): add message renderer + +```tsx + {text}, + }} +/> +``` diff --git a/examples/next/pages/ui/components/ai/ai-conversation-renderer/amplify_outputs.js b/examples/next/pages/ui/components/ai/ai-conversation-renderer/amplify_outputs.js new file mode 100644 index 00000000000..2f1016412fd --- /dev/null +++ b/examples/next/pages/ui/components/ai/ai-conversation-renderer/amplify_outputs.js @@ -0,0 +1,2 @@ +import amplifyOutputs from '@environments/ai/gen2/amplify_outputs'; +export default amplifyOutputs; diff --git a/examples/next/pages/ui/components/ai/ai-conversation-renderer/index.page.tsx b/examples/next/pages/ui/components/ai/ai-conversation-renderer/index.page.tsx new file mode 100644 index 00000000000..42409609fdd --- /dev/null +++ b/examples/next/pages/ui/components/ai/ai-conversation-renderer/index.page.tsx @@ -0,0 +1,89 @@ +import { Amplify } from 'aws-amplify'; +import { createAIHooks, AIConversation } from '@aws-amplify/ui-react-ai'; +import { generateClient } from 'aws-amplify/api'; +import '@aws-amplify/ui-react/styles.css'; +import '@aws-amplify/ui-react-ai/ai-conversation-styles.css'; + +import outputs from './amplify_outputs'; +import type { Schema } from '@environments/ai/gen2/amplify/data/resource'; +import { Authenticator, Card, Text } from '@aws-amplify/ui-react'; +import Image from 'next/image'; + +const client = generateClient({ authMode: 'userPool' }); +const { useAIConversation } = createAIHooks(client); + +Amplify.configure(outputs); + +function arrayBufferToBase64(buffer: ArrayBuffer) { + let binary = ''; + const bytes = new Uint8Array(buffer); + const len = bytes.byteLength; + for (let i = 0; i < len; i++) { + binary += String.fromCharCode(bytes[i]); + } + return window.btoa(binary); +} + +function convertBufferToBase64(buffer: ArrayBuffer, format: string): string { + let base64string = ''; + // Use node-based buffer if available + // fall back on browser if not + if (typeof Buffer !== 'undefined') { + base64string = Buffer.from(new Uint8Array(buffer)).toString('base64'); + } else { + base64string = arrayBufferToBase64(buffer); + } + return `data:image/${format};base64,${base64string}`; +} + +function Chat() { + const [ + { + data: { messages }, + isLoading, + }, + sendMessage, + ] = useAIConversation('pirateChat'); + + return ( + + {text}, + image: ({ image }) => ( + + ), + }} + suggestedPrompts={[ + { + inputText: 'hello', + header: 'hello', + }, + { + inputText: 'how are you?', + header: 'how are you?', + }, + ]} + variant="bubble" + /> + + ); +} + +export default function Example() { + return ( + + + + ); +} diff --git a/packages/react-ai/src/components/AIConversation/AIConversation.tsx b/packages/react-ai/src/components/AIConversation/AIConversation.tsx index 382d1b380b3..a17bc811c86 100644 --- a/packages/react-ai/src/components/AIConversation/AIConversation.tsx +++ b/packages/react-ai/src/components/AIConversation/AIConversation.tsx @@ -33,6 +33,7 @@ function AIConversationBase({ isLoading, displayText, allowAttachments, + messageRenderer, }: AIConversationBaseProps): JSX.Element { useSetUserAgent({ componentName: 'AIConversation', @@ -78,6 +79,7 @@ function AIConversationBase({ }, displayText, allowAttachments, + messageRenderer, }; return ( diff --git a/packages/react-ai/src/components/AIConversation/context/MessageRenderContext.tsx b/packages/react-ai/src/components/AIConversation/context/MessageRenderContext.tsx new file mode 100644 index 00000000000..7336e74f95c --- /dev/null +++ b/packages/react-ai/src/components/AIConversation/context/MessageRenderContext.tsx @@ -0,0 +1,13 @@ +import { createContextUtilities } from '@aws-amplify/ui-react-core'; +import { MessageRenderer } from '../types'; + +export const { + MessageRendererContext, + MessageRendererProvider, + useMessageRenderer, +} = createContextUtilities({ + contextName: 'MessageRenderer', + defaultValue: undefined, + errorMessage: + '`useMessageRenderer` must be used with an AIConversation component', +}); diff --git a/packages/react-ai/src/components/AIConversation/context/index.ts b/packages/react-ai/src/components/AIConversation/context/index.ts index 5c56c750018..ca32e10d824 100644 --- a/packages/react-ai/src/components/AIConversation/context/index.ts +++ b/packages/react-ai/src/components/AIConversation/context/index.ts @@ -34,5 +34,11 @@ export { RESPONSE_COMPONENT_PREFIX, } from './ResponseComponentsContext'; export { SendMessageContextProvider } from './SendMessageContext'; +export { + MessageRendererProvider, + MessageRendererContext, + useMessageRenderer, +} from './MessageRenderContext'; +export { AttachmentProvider, AttachmentContext } from './AttachmentContext'; export * from './elements'; diff --git a/packages/react-ai/src/components/AIConversation/createAIConversation.tsx b/packages/react-ai/src/components/AIConversation/createAIConversation.tsx index 83bad22362f..e672db0ef50 100644 --- a/packages/react-ai/src/components/AIConversation/createAIConversation.tsx +++ b/packages/react-ai/src/components/AIConversation/createAIConversation.tsx @@ -31,6 +31,7 @@ export function createAIConversation(input: AIConversationInput = {}): { controls, displayText, allowAttachments, + messageRenderer, } = input; function AIConversation(props: AIConversationProps): JSX.Element { @@ -48,6 +49,7 @@ export function createAIConversation(input: AIConversationInput = {}): { avatars, handleSendMessage, isLoading, + messageRenderer, }; return ( diff --git a/packages/react-ai/src/components/AIConversation/types.ts b/packages/react-ai/src/components/AIConversation/types.ts index ba987b8dce1..1a6ccac8487 100644 --- a/packages/react-ai/src/components/AIConversation/types.ts +++ b/packages/react-ai/src/components/AIConversation/types.ts @@ -11,7 +11,12 @@ import { } from './views'; import { DisplayTextTemplate } from '@aws-amplify/ui'; import { AIConversationDisplayText } from './displayText'; -import { ConversationMessage, SendMessage } from '../../types'; +import { + ConversationMessage, + ImageContentBlock, + SendMessage, + TextContentBlock, +} from '../../types'; import { ControlsContextProps } from './context/ControlsContext'; export interface Controls { @@ -32,6 +37,7 @@ export interface AIConversationInput { variant?: MessageVariant; controls?: ControlsContextProps; allowAttachments?: boolean; + messageRenderer?: MessageRenderer; } export interface AIConversationProps { @@ -54,6 +60,11 @@ export interface AIConversation { export type MessageVariant = 'bubble' | 'default'; +export interface MessageRenderer { + text?: (input: { text: TextContentBlock }) => React.JSX.Element; + image?: (input: { image: ImageContentBlock }) => React.JSX.Element; +} + export interface Avatar { username?: string; avatar?: React.ReactNode; diff --git a/packages/react-ai/src/components/AIConversation/utils.ts b/packages/react-ai/src/components/AIConversation/utils.ts index d93401baa66..023cec0be22 100644 --- a/packages/react-ai/src/components/AIConversation/utils.ts +++ b/packages/react-ai/src/components/AIConversation/utils.ts @@ -1,4 +1,4 @@ -import { ImageContent } from '../../types'; +import { ImageContentBlock } from '../../types'; export function formatDate(date: Date): string { const dateString = date.toLocaleDateString('en-US', { @@ -27,7 +27,7 @@ function arrayBufferToBase64(buffer: ArrayBuffer) { export function convertBufferToBase64( buffer: ArrayBuffer, - format: ImageContent['format'] + format: ImageContentBlock['format'] ): string { let base64string = ''; // Use node-based buffer if available diff --git a/packages/react-ai/src/components/AIConversation/views/Controls/MessagesControl.tsx b/packages/react-ai/src/components/AIConversation/views/Controls/MessagesControl.tsx index deface4b773..5de1fd3f1dc 100644 --- a/packages/react-ai/src/components/AIConversation/views/Controls/MessagesControl.tsx +++ b/packages/react-ai/src/components/AIConversation/views/Controls/MessagesControl.tsx @@ -2,6 +2,7 @@ import React from 'react'; import { withBaseElementProps } from '@aws-amplify/ui-react-core/elements'; import { + MessageRendererContext, MessagesContext, MessageVariantContext, RoleContext, @@ -63,17 +64,22 @@ const ContentContainer: typeof View = React.forwardRef( export const MessageControl: MessageControl = ({ message }) => { const responseComponents = React.useContext(ResponseComponentsContext); + const messageRenderer = React.useContext(MessageRendererContext); return ( {message.content.map((content, index) => { if (content.text) { - return ( + return messageRenderer?.text ? ( + messageRenderer.text({ text: content.text }) + ) : ( {content.text} ); } else if (content.image) { - return ( + return messageRenderer?.image ? ( + messageRenderer?.image({ image: content.image }) + ) : ( { content.image?.source.bytes, content.image?.format )} - > + /> ); } else if (content.toolUse) { // For now tool use is limited to custom response components @@ -164,7 +170,7 @@ const Layout: typeof View = React.forwardRef(function Layout(props, ref) { ); }); -export const MessagesControl: MessagesControl = ({ renderMessage }) => { +export const MessagesControl: MessagesControl = () => { const messages = React.useContext(MessagesContext); const controls = React.useContext(ControlsContext); const { getMessageTimestampText } = useConversationDisplayText(); @@ -226,9 +232,7 @@ export const MessagesControl: MessagesControl = ({ renderMessage }) => { return ( {messagesWithRenderableContent?.map((message, index) => { - return renderMessage ? ( - renderMessage(message) - ) : ( + return ( React.ReactNode; - }): JSX.Element; + (): JSX.Element; ActionsBar: ActionsBarControl; Avatar: AvatarControl; Container: AIConversationElements['View']; diff --git a/packages/react-ai/src/components/AIConversation/views/Controls/__tests__/MessagesControl.spec.tsx b/packages/react-ai/src/components/AIConversation/views/Controls/__tests__/MessagesControl.spec.tsx index be5ed0206fd..a054835a569 100644 --- a/packages/react-ai/src/components/AIConversation/views/Controls/__tests__/MessagesControl.spec.tsx +++ b/packages/react-ai/src/components/AIConversation/views/Controls/__tests__/MessagesControl.spec.tsx @@ -13,6 +13,7 @@ import { MessagesControl, MessageControl } from '../MessagesControl'; import { convertBufferToBase64 } from '../../../utils'; import { ConversationMessage } from '../../../../../types'; import { ResponseComponentsProvider } from '../../../context/ResponseComponentsContext'; +import { MessageRendererProvider } from '../../../context'; const AITextMessage: ConversationMessage = { conversationId: 'foobar', @@ -212,44 +213,6 @@ describe('MessagesControl', () => { expect(actionElements).toHaveLength(2); }); - it('renders a MessagesControl element with a custom renderMessage function', () => { - const customMessage = jest.fn((message: ConversationMessage) => ( -
- {message.content.map((content, index) => { - if (content.text) { - return

{content.text}

; - } else if (content.image) { - return ( - - ); - } - })} -
- )); - - render( - - - - ); - - expect(customMessage).toHaveBeenCalledTimes(3); - - const defaultMessageElements = screen.queryAllByTestId('message'); - expect(defaultMessageElements).toHaveLength(0); - - const customMessageElements = screen.queryAllByTestId('custom-message'); - expect(customMessageElements).toHaveLength(3); - }); - it('renders avatars and actions appropriately if the same user sends multiple messages', () => { const { rerender } = render( @@ -387,4 +350,33 @@ describe('MessageControl', () => { const { container } = render(); expect(container.firstChild).toBeEmptyDOMElement(); }); + + it('uses text message renderer if passed', () => { + render( +
{text}
} + > + +
+ ); + const message = screen.getByTestId('custom-message'); + expect(message).toBeInTheDocument(); + }); + + it('uses image message renderer if passed', () => { + render( + ( + + )} + > + + + ); + const message = screen.getByTestId('custom-message'); + expect(message).toBeInTheDocument(); + }); }); diff --git a/packages/react-ai/src/types.ts b/packages/react-ai/src/types.ts index ec7c108ece4..8a887c5020e 100644 --- a/packages/react-ai/src/types.ts +++ b/packages/react-ai/src/types.ts @@ -11,9 +11,11 @@ export type ConversationMessage = NonNullable< export type ConversationMessageContent = ConversationMessage['content'][number]; -export type TextContent = NonNullable; +export type TextContentBlock = NonNullable; -export type ImageContent = NonNullable; +export type ImageContentBlock = NonNullable< + ConversationMessageContent['image'] +>; // Note: the conversation sendMessage function is an overload // that accepts a string OR an object