diff --git a/packages/extension/src/ChatViewProvider.js b/packages/extension/src/ChatViewProvider.js index 8730c3d..cb24f7f 100644 --- a/packages/extension/src/ChatViewProvider.js +++ b/packages/extension/src/ChatViewProvider.js @@ -54,28 +54,32 @@ class ChatViewProvider { }) break; } - case VsCodeMessageTypes.getPrompts: { - this.getResponse(alitaService, 'getPrompts', {}) - break; - } - case VsCodeMessageTypes.getPromptDetail: { - this.getResponse(alitaService, 'getPromptDetail', message.data) - break; - } - case VsCodeMessageTypes.getDatasources: { - this.getResponse(alitaService, 'getDatasources') - break; - } - case VsCodeMessageTypes.getDatasourceDetail: { - this.getResponse(alitaService, 'getDatasourceDetail', message.data) - break; - } + // case VsCodeMessageTypes.getPrompts: { + // this.getResponse(alitaService, 'getPrompts', {}) + // break; + // } + // case VsCodeMessageTypes.getPromptDetail: { + // this.getResponse(alitaService, 'getPromptDetail', message.data) + // break; + // } + // case VsCodeMessageTypes.getDatasources: { + // this.getResponse(alitaService, 'getDatasources') + // break; + // } + // case VsCodeMessageTypes.getDatasourceDetail: { + // this.getResponse(alitaService, 'getDatasourceDetail', message.data) + // break; + // } case VsCodeMessageTypes.getApplications: { this.getResponse(alitaService, 'getApplications') break; } - case VsCodeMessageTypes.getDeployments: { - this.getResponse(alitaService, 'getDeployments') + case VsCodeMessageTypes.getEmbeddings: { + this.getResponse(alitaService, 'getEmbeddings') + break; + } + case VsCodeMessageTypes.createConversation: { + this.getResponse(alitaService, 'createConversation', message.data) break; } case VsCodeMessageTypes.getApplicationDetail: { diff --git a/packages/shared/index.js b/packages/shared/index.js index 15753cc..539671f 100644 --- a/packages/shared/index.js +++ b/packages/shared/index.js @@ -1,15 +1,12 @@ export const VsCodeMessageTypes = { getSelectedText: 'extension.getSelectedText', getChatResponse: 'extension.getChatResponse', - getPrompts: 'extension.getPrompts', - getDatasources: 'extension.getDatasources', - getPromptDetail: 'extension.getPromptDetail', - getDatasourceDetail: 'extension.getDatasourceDetail', getSocketConfig: 'extension.getSocketConfig', getModelSettings: 'extension.getModelSettings', + createConversation: 'extension.createConversation', getApplicationDetail: 'extension.getApplicationDetail', getApplications: 'extension.getApplications', - getDeployments: 'extension.getDeployments', + getEmbeddings: 'extension.getEmbeddings', copyCodeToEditor: 'extension.copyCodeToEditor', stopDatasourceTask: 'extension.stopDatasourceTask', stopApplicationTask: 'extension.stopApplicationTask', @@ -22,21 +19,17 @@ export const UiMessageTypes = { error: 'ui.error', startLoading: 'ui.startLoading', stopLoading: 'ui.stopLoading', - getPrompts: 'ui.getPrompts', - getPromptDetail: 'ui.getPromptDetail', - getDatasourceDetail: 'ui.getDatasourceDetail', - getDatasources: 'ui.getDatasources', getChatResponse: 'ui.getChatResponse', getProviderConfig: 'ui.getProviderConfig', getSocketConfig: 'ui.getSocketConfig', getModelSettings: 'ui.getModelSettings', getApplicationDetail: 'ui.getApplicationDetail', getApplications: 'ui.getApplications', - getDeployments: 'ui.getDeployments', + createConversation: 'ui.createConversation', + getEmbeddings: 'ui.getEmbeddings', settingsChanged: 'ui.settingsChanged', stopDatasourceTask: 'ui.stopDatasourceTask', stopApplicationTask: 'ui.stopApplicationTask', - getPromptVersionDetail: 'ui.getPromptVersionDetail', getApplicationVersionDetail: 'ui.getApplicationVersionDetail', copyMessageToClipboard: 'ui.copyMessageToClipboard', getIdeSettings: 'ui.getIdeSettings', diff --git a/packages/ui/src/common/constants.js b/packages/ui/src/common/constants.js index e2c32ca..8fb3190 100644 --- a/packages/ui/src/common/constants.js +++ b/packages/ui/src/common/constants.js @@ -610,8 +610,8 @@ export const APIKeyTypes = { export const sioEvents = { application_predict: 'application_predict', application_leave_rooms: 'application_leave_rooms', - promptlib_predict: 'promptlib_predict', - promptlib_leave_rooms: 'promptlib_leave_rooms', + chat_predict: 'chat_predict', + chat_leave_rooms: 'chat_leave_rooms', datasource_predict: 'datasource_predict', datasource_dataset_status: 'datasource_dataset_status', datasource_leave_rooms: 'datasource_leave_rooms' diff --git a/packages/ui/src/components/ChatBox/ChatBox.jsx b/packages/ui/src/components/ChatBox/ChatBox.jsx index 8de1869..1757911 100644 --- a/packages/ui/src/components/ChatBox/ChatBox.jsx +++ b/packages/ui/src/components/ChatBox/ChatBox.jsx @@ -123,7 +123,7 @@ const getDefaultModel = (model = {}, modelsList) => { const generatePayload = ({ - projectId, prompt_id, type, name, variables, currentVersionId, model_settings + projectId, prompt_id, type, name, variables, currentVersionId, llm_settings }) => ({ prompt_id, projectId, @@ -131,7 +131,7 @@ const generatePayload = ({ user_name: name, project_id: projectId, prompt_version_id: currentVersionId, - model_settings, + llm_settings, type, variables: variables ? variables.map(({ name: key, value }) => ({ @@ -142,10 +142,10 @@ const generatePayload = ({ }) const generateChatPayload = ({ - projectId, prompt_id, question, messages, variables, chatHistory, name, currentVersionId, model_settings, interaction_uuid + projectId, prompt_id, question, messages, variables, chatHistory, name, currentVersionId, llm_settings, interaction_uuid, conversation_uuid }) => { const payload = generatePayload({ - projectId, prompt_id, type: 'chat', variables, name, currentVersionId, model_settings + projectId, prompt_id, type: 'chat', variables, name, currentVersionId, llm_settings }) payload.chat_history = chatHistory ? chatHistory.map((message) => { const { role, content, name: userName } = message; @@ -160,6 +160,7 @@ const generateChatPayload = ({ payload.messages = messages } payload.interaction_uuid = interaction_uuid; + payload.conversation_uuid = conversation_uuid; return payload } @@ -180,8 +181,9 @@ const ChatBox = forwardRef(({ deployments, providerConfig, modelSettings, + conversationUuid, sendMessage - } = useContext(DataContext); + } = useContext(DataContext); const chatInput = useRef(null); const listRefs = useRef([]); const messagesEndRef = useRef(); @@ -201,7 +203,6 @@ const ChatBox = forwardRef(({ participantRef.current = participant }, [participant]) - useEffect(() => { chatWithRef.current = chatWith }, [chatWith]) @@ -269,8 +270,7 @@ const ChatBox = forwardRef(({ const handleSocketEvent = useCallback(async message => { const { stream_id, message_id, type: socketMessageType, message_type, response_metadata } = message const { task_id } = message.content instanceof Object ? message.content : {} - const [msgIndex, msg] = getMessage(stream_id || message_id, message_type) - + const [msgIndex, msg] = getMessage(/*stream_id || */ message_id, message_type) const scrollToMessageBottom = () => { if (sessionStorage.getItem(AUTO_SCROLL_KEY) === 'true') { const messageElement = listRefs.current[msgIndex] @@ -384,18 +384,16 @@ const ChatBox = forwardRef(({ prevState[msgIndex] = msg return [...prevState] }) - }, [getMessage, handleError, scrollToMessageListEnd, setChatHistory]) + }, [getMessage, handleError, scrollToMessageListEnd, setChatHistory, chatHistory]) const dataContext = useContext(DataContext); const { emit, error } = useSocket( - chatWith === ChatTypes.datasource ? - sioEvents.datasource_predict : - chatWith === ChatTypes.application ? - sioEvents.application_predict : - sioEvents.promptlib_predict, + chatWith === ChatTypes.application ? + sioEvents.application_predict : + sioEvents.chat_predict, handleSocketEvent - ) - + ) + useEffect(() => { if (error) { toastError(error); @@ -426,7 +424,7 @@ const ChatBox = forwardRef(({ const onPredictStream = useCallback(async data => { setTimeout(scrollToMessageListEnd, 0); - const { modelSettings, socketConfig, sendMessage } = dataContext + const { socketConfig } = dataContext const selectedText = await sendMessage({ type: VsCodeMessageTypes.getSelectedText @@ -451,7 +449,6 @@ const ChatBox = forwardRef(({ sentTo: participantRef.current ?? {} }] }) - const projectId = socketConfig?.projectId if (!projectId) { toastError('Elitea Code extension Project ID setting is missing. Please set it to continue chat.'); @@ -464,10 +461,6 @@ const ChatBox = forwardRef(({ } if (data.application_id) { - if (!data.llm_settings?.integration_uid) { - toastError('Application chat model is missing. Please select another one for chat.'); - return - } const payload = generateApplicationStreamingPayload({ ...data, ...data.llm_settings, @@ -477,26 +470,9 @@ const ChatBox = forwardRef(({ chatHistory, interaction_uuid }) - emit(payload) return - } else if (data.datasource_id) { - if (!data.chat_settings_ai?.integration_uid || - !data.chat_settings_embedding?.integration_uid) { - toastError('Datasource chat model and/or embedding setting is missing. Please select another one for chat.'); - return - } - emit({ - project_id: projectId, - version_id: data.currentVersionId || data.datasource_id, - input: question, - chat_history: chatHistory.filter(i => i.role !== MESSAGE_REFERENCE_ROLE).concat(messages), - context: data.context, - chat_settings_ai: data.chat_settings_ai, - chat_settings_embedding: data.chat_settings_embedding, - interaction_uuid - }) - return + } else { const payloadData = { projectId, @@ -506,30 +482,32 @@ const ChatBox = forwardRef(({ messages, interaction_uuid } - if (data.prompt_id && data.currentVersionId) { - payloadData.prompt_id = data.prompt_id - payloadData.currentVersionId = data.currentVersionId - payloadData.variables = data.variables - } else { - if (modelSettings) { - payloadData.model_settings = modelSettings - payloadData.model_settings.model = getDefaultModel(modelSettings.model, deployments) - if (!payloadData.model_settings.model.model_name) { - toastError('Elitea Code extension model settings are missing.'); - return - } else if (!payloadData.model_settings.model.integration_uid) { - toastError('Elitea Code extension integration Uid is missing.'); - return - } - } else { + + if (modelSettings) { + + payloadData.llm_settings = modelSettings + payloadData.llm_settings = getDefaultModel(modelSettings.model, deployments) + payloadData.llm_settings.model_project_id = 1 + payloadData.llm_settings.temperature = modelSettings.temperature + // payloadData.llm_settings.top_p = modelSettings.top_p + payloadData.llm_settings.max_tokens = modelSettings.max_tokens + if (!payloadData.llm_settings.model_name) { toastError('Elitea Code extension model settings are missing.'); return + } else if (!payloadData.llm_settings.integration_uid) { + toastError('Elitea Code extension integration Uid is missing.'); + return } + } else { + toastError('Elitea Code extension model settings are missing.'); + return } + // } const payload = generateChatPayload(payloadData) + payload.conversation_uuid = conversationUuid emit(payload) } - }, + }, [scrollToMessageListEnd, dataContext, setChatHistory, emit, chatHistory, deployments, error, toastError, interaction_uuid]) const onSend = useCallback( @@ -544,11 +522,9 @@ const ChatBox = forwardRef(({ ); const { emit: manualEmit } = useManualSocket( - chatWith === ChatTypes.datasource ? - sioEvents.datasource_leave_rooms : - chatWith === ChatTypes.application ? - sioEvents.application_leave_rooms : - sioEvents.promptlib_leave_rooms + chatWith === ChatTypes.application ? + sioEvents.application_leave_rooms : + sioEvents.chat_leave_rooms ); const { isStreaming, diff --git a/packages/ui/src/context/DataContext.jsx b/packages/ui/src/context/DataContext.jsx index fc073be..873c1fa 100644 --- a/packages/ui/src/context/DataContext.jsx +++ b/packages/ui/src/context/DataContext.jsx @@ -18,34 +18,24 @@ const filterByCodeTag = (list) => { )); } -const getFilteredModels = (integration, capabilities) => { - return (integration?.settings?.models || []) - .filter(modelItem => { - return capabilities.some(capability => modelItem.capabilities[capability]); - }) - .map(({ name }) => ({ - model_name: name, - integration_uid: integration.uid, - })); +const getIntegrationOptions = (integrations) => { + const result = integrations.shared.items.map((integration) => { + return { + model_name: integration?.data?.name, + integration_uid: integration?.uuid + }; + }); + return result; } -const getIntegrationOptions = (integrations) => integrations.reduce((accumulator, integration) => { - const filteredModels = getFilteredModels(integration, ['chat_completion', 'completion']); - if (filteredModels.length > 0) { - accumulator = [...accumulator, ...filteredModels]; - } - return accumulator; -}, []); - export const DataProvider = ({ children }) => { const [providerConfig, setProviderConfig] = useState(null); const [socketConfig, setSocketConfig] = useState(null); const [chatHistory, setChatHistory] = useState([]); const [isLoading, setIsLoading] = useState(false); - const [prompts, setPrompts] = useState([]); - const [datasources, setDatasources] = useState([]); const [applications, setApplications] = useState([]); - const [deployments, setDeployments] = useState([]) + const [deployments, setDeployments] = useState([]); + const [conversationUuid, setConversationUuid] = useState(null); const vscodeRef = useRef(null); const [modelSettings, setModelSettings] = useState(null); @@ -69,22 +59,20 @@ export const DataProvider = ({ children }) => { if (!vscodeRef.current) return vscodeRef.current?.postMessage({ - type: VsCodeMessageTypes.getSocketConfig, - }); - vscodeRef.current?.postMessage({ - type: VsCodeMessageTypes.getModelSettings, + type: VsCodeMessageTypes.createConversation, + data: 'IDE Chat' }); vscodeRef.current?.postMessage({ - type: VsCodeMessageTypes.getPrompts, + type: VsCodeMessageTypes.getSocketConfig, }); vscodeRef.current?.postMessage({ - type: VsCodeMessageTypes.getDatasources, + type: VsCodeMessageTypes.getModelSettings, }); vscodeRef.current?.postMessage({ type: VsCodeMessageTypes.getApplications, }); vscodeRef.current?.postMessage({ - type: VsCodeMessageTypes.getDeployments, + type: VsCodeMessageTypes.getEmbeddings, }); vscodeRef.current?.postMessage({ type: VsCodeMessageTypes.getIdeSettings, @@ -137,16 +125,13 @@ export const DataProvider = ({ children }) => { case UiMessageTypes.stopLoading: setIsLoading(false); break; - case UiMessageTypes.getPrompts: - setPrompts(filterByCodeTag(message.data)); - break; - case UiMessageTypes.getDatasources: - setDatasources(filterByCodeTag(message.data)); - break; case UiMessageTypes.getApplications: setApplications(filterByCodeTag(message.data)); break; - case UiMessageTypes.getDeployments: + case UiMessageTypes.createConversation: + setConversationUuid(message.data.uuid); + break + case UiMessageTypes.getEmbeddings: setDeployments(getIntegrationOptions(message.data || [])); break; case UiMessageTypes.settingsChanged: @@ -174,13 +159,13 @@ export const DataProvider = ({ children }) => { const urlSrcObj = new URL(socketHostSrc.concat(socketPathSrc)); const socketHost = urlSrcObj.origin; const socketPath = urlSrcObj.pathname; + urlSrcObj.protocol = urlSrcObj.protocol.replace("ws", "http"); urlSrcObj.pathname = urlSrcObj.pathname.replace("/socket.io", ""); const url = removeTrailingSlashes(urlSrcObj.toString()); const apiUrl = url.concat("/api/v1"); const projectId = message.data.projectId; const token = message.data.token; - setProviderConfig({url, apiUrl, socketHost, socketPath, projectId, token}); } break; @@ -258,10 +243,11 @@ export const DataProvider = ({ children }) => { chatHistory, setChatHistory, isLoading, - prompts, - datasources, + // prompts, + // datasources, applications, deployments, + conversationUuid, sendMessage, loadCoreData, callProvider, diff --git a/packages/ui/src/context/SocketProvider.jsx b/packages/ui/src/context/SocketProvider.jsx index 6dfa19e..5e0f147 100644 --- a/packages/ui/src/context/SocketProvider.jsx +++ b/packages/ui/src/context/SocketProvider.jsx @@ -13,7 +13,6 @@ export function SocketProvider({ children }) { const createSocket = useCallback(() => { if (!providerConfig || !providerConfig.socketHost) return; const { socketHost, socketPath, token } = providerConfig - const socketIo = io(socketHost, { path: socketPath, reconnectionDelayMax: 2000, diff --git a/packages/ui/test/plugin/server.js b/packages/ui/test/plugin/server.js index d45c8f9..66d30d8 100644 --- a/packages/ui/test/plugin/server.js +++ b/packages/ui/test/plugin/server.js @@ -70,26 +70,14 @@ app.get('/', async (req, res) => { ); data.model.integration_uid = program.opts().modelUid; break; - case 'extension.getPrompts': - data = (await getData(`prompt_lib/prompts/prompt_lib/${projectId}`)).rows; - break; - case 'extension.getDatasources': - data = (await getData(`datasources/datasources/prompt_lib/${projectId}`)).rows; - break; case 'extension.getApplications': data = (await getData(`applications/applications/prompt_lib/${projectId}?offset=0&limit=1000`)).rows; break; - case 'extension.getPromptDetail': - data = (await getData(`prompt_lib/prompt/prompt_lib/${projectId}/${chatData}`)); - break; - case 'extension.getDatasourceDetail': - data = (await getData(`datasources/datasource/prompt_lib/${projectId}/${chatData}`)); - break; case 'extension.getApplicationDetail': data = (await getData(`applications/application/prompt_lib/${projectId}/${chatData}`)); break; - case 'extension.getDeployments': - data = (await getData(`integrations/integrations/default/${projectId}`)); + case 'extension.getEmbeddings': + data = (await getData(`configurations/configurations/${projectId}?include_shared=true§ion=llm`)).shared.items; break; case 'extension.copyCodeToEditor': data = ''; diff --git a/packages/ui/test/ui_test/tests/selectedText.spec.js b/packages/ui/test/ui_test/tests/selectedText.spec.js index 020d7c1..cb5c613 100644 --- a/packages/ui/test/ui_test/tests/selectedText.spec.js +++ b/packages/ui/test/ui_test/tests/selectedText.spec.js @@ -49,7 +49,7 @@ test('Verify selected text is sent in system message', async ({ page }) => { }); const promptlibMessage = socketMessages.find(msg => - msg.startsWith('42[') && msg.includes('promptlib_predict') + msg.startsWith('42[') && msg.includes('chat_predict') ); expect(promptlibMessage).toBeDefined();