Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions packages/extension/src/ChatViewProvider.js
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,32 @@ class ChatViewProvider {
})
break;
}
case VsCodeMessageTypes.getPrompts: {
this.getResponse(alitaService, 'getPrompts', {})
break;
}
case VsCodeMessageTypes.getPromptDetail: {
this.getResponse(alitaService, 'getPromptDetail', message.data)
break;
}
case VsCodeMessageTypes.getDatasources: {
this.getResponse(alitaService, 'getDatasources')
break;
}
case VsCodeMessageTypes.getDatasourceDetail: {
this.getResponse(alitaService, 'getDatasourceDetail', message.data)
break;
}
// case VsCodeMessageTypes.getPrompts: {
// this.getResponse(alitaService, 'getPrompts', {})
// break;
// }
// case VsCodeMessageTypes.getPromptDetail: {
// this.getResponse(alitaService, 'getPromptDetail', message.data)
// break;
// }
// case VsCodeMessageTypes.getDatasources: {
// this.getResponse(alitaService, 'getDatasources')
// break;
// }
// case VsCodeMessageTypes.getDatasourceDetail: {
// this.getResponse(alitaService, 'getDatasourceDetail', message.data)
// break;
// }
case VsCodeMessageTypes.getApplications: {
this.getResponse(alitaService, 'getApplications')
break;
}
case VsCodeMessageTypes.getDeployments: {
this.getResponse(alitaService, 'getDeployments')
case VsCodeMessageTypes.getEmbeddings: {
this.getResponse(alitaService, 'getEmbeddings')
break;
}
case VsCodeMessageTypes.createConversation: {
this.getResponse(alitaService, 'createConversation', message.data)
break;
}
case VsCodeMessageTypes.getApplicationDetail: {
Expand Down
15 changes: 4 additions & 11 deletions packages/shared/index.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
export const VsCodeMessageTypes = {
getSelectedText: 'extension.getSelectedText',
getChatResponse: 'extension.getChatResponse',
getPrompts: 'extension.getPrompts',
getDatasources: 'extension.getDatasources',
getPromptDetail: 'extension.getPromptDetail',
getDatasourceDetail: 'extension.getDatasourceDetail',
getSocketConfig: 'extension.getSocketConfig',
getModelSettings: 'extension.getModelSettings',
createConversation: 'extension.createConversation',
getApplicationDetail: 'extension.getApplicationDetail',
getApplications: 'extension.getApplications',
getDeployments: 'extension.getDeployments',
getEmbeddings: 'extension.getEmbeddings',
copyCodeToEditor: 'extension.copyCodeToEditor',
stopDatasourceTask: 'extension.stopDatasourceTask',
stopApplicationTask: 'extension.stopApplicationTask',
Expand All @@ -22,21 +19,17 @@ export const UiMessageTypes = {
error: 'ui.error',
startLoading: 'ui.startLoading',
stopLoading: 'ui.stopLoading',
getPrompts: 'ui.getPrompts',
getPromptDetail: 'ui.getPromptDetail',
getDatasourceDetail: 'ui.getDatasourceDetail',
getDatasources: 'ui.getDatasources',
getChatResponse: 'ui.getChatResponse',
getProviderConfig: 'ui.getProviderConfig',
getSocketConfig: 'ui.getSocketConfig',
getModelSettings: 'ui.getModelSettings',
getApplicationDetail: 'ui.getApplicationDetail',
getApplications: 'ui.getApplications',
getDeployments: 'ui.getDeployments',
createConversation: 'ui.createConversation',
getEmbeddings: 'ui.getEmbeddings',
settingsChanged: 'ui.settingsChanged',
stopDatasourceTask: 'ui.stopDatasourceTask',
stopApplicationTask: 'ui.stopApplicationTask',
getPromptVersionDetail: 'ui.getPromptVersionDetail',
getApplicationVersionDetail: 'ui.getApplicationVersionDetail',
copyMessageToClipboard: 'ui.copyMessageToClipboard',
getIdeSettings: 'ui.getIdeSettings',
Expand Down
4 changes: 2 additions & 2 deletions packages/ui/src/common/constants.js
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,8 @@ export const APIKeyTypes = {
export const sioEvents = {
application_predict: 'application_predict',
application_leave_rooms: 'application_leave_rooms',
promptlib_predict: 'promptlib_predict',
promptlib_leave_rooms: 'promptlib_leave_rooms',
chat_predict: 'chat_predict',
chat_leave_rooms: 'chat_leave_rooms',
datasource_predict: 'datasource_predict',
datasource_dataset_status: 'datasource_dataset_status',
datasource_leave_rooms: 'datasource_leave_rooms'
Expand Down
100 changes: 38 additions & 62 deletions packages/ui/src/components/ChatBox/ChatBox.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ const getDefaultModel = (model = {}, modelsList) => {


const generatePayload = ({
projectId, prompt_id, type, name, variables, currentVersionId, model_settings
projectId, prompt_id, type, name, variables, currentVersionId, llm_settings
}) => ({
prompt_id,
projectId,

user_name: name,
project_id: projectId,
prompt_version_id: currentVersionId,
model_settings,
llm_settings,

type,
variables: variables ? variables.map(({ name: key, value }) => ({
Expand All @@ -142,10 +142,10 @@ const generatePayload = ({
})

const generateChatPayload = ({
projectId, prompt_id, question, messages, variables, chatHistory, name, currentVersionId, model_settings, interaction_uuid
projectId, prompt_id, question, messages, variables, chatHistory, name, currentVersionId, llm_settings, interaction_uuid, conversation_uuid
}) => {
const payload = generatePayload({
projectId, prompt_id, type: 'chat', variables, name, currentVersionId, model_settings
projectId, prompt_id, type: 'chat', variables, name, currentVersionId, llm_settings
})
payload.chat_history = chatHistory ? chatHistory.map((message) => {
const { role, content, name: userName } = message;
Expand All @@ -160,6 +160,7 @@ const generateChatPayload = ({
payload.messages = messages
}
payload.interaction_uuid = interaction_uuid;
payload.conversation_uuid = conversation_uuid;
return payload
}

Expand All @@ -180,8 +181,9 @@ const ChatBox = forwardRef(({
deployments,
providerConfig,
modelSettings,
conversationUuid,
sendMessage
} = useContext(DataContext);
} = useContext(DataContext);
const chatInput = useRef(null);
const listRefs = useRef([]);
const messagesEndRef = useRef();
Expand All @@ -201,7 +203,6 @@ const ChatBox = forwardRef(({
participantRef.current = participant
}, [participant])


useEffect(() => {
chatWithRef.current = chatWith
}, [chatWith])
Expand Down Expand Up @@ -269,8 +270,7 @@ const ChatBox = forwardRef(({
const handleSocketEvent = useCallback(async message => {
const { stream_id, message_id, type: socketMessageType, message_type, response_metadata } = message
const { task_id } = message.content instanceof Object ? message.content : {}
const [msgIndex, msg] = getMessage(stream_id || message_id, message_type)

const [msgIndex, msg] = getMessage(/*stream_id || */ message_id, message_type)
const scrollToMessageBottom = () => {
if (sessionStorage.getItem(AUTO_SCROLL_KEY) === 'true') {
const messageElement = listRefs.current[msgIndex]
Expand Down Expand Up @@ -384,18 +384,16 @@ const ChatBox = forwardRef(({
prevState[msgIndex] = msg
return [...prevState]
})
}, [getMessage, handleError, scrollToMessageListEnd, setChatHistory])
}, [getMessage, handleError, scrollToMessageListEnd, setChatHistory, chatHistory])

const dataContext = useContext(DataContext);
const { emit, error } = useSocket(
chatWith === ChatTypes.datasource ?
sioEvents.datasource_predict :
chatWith === ChatTypes.application ?
sioEvents.application_predict :
sioEvents.promptlib_predict,
chatWith === ChatTypes.application ?
sioEvents.application_predict :
sioEvents.chat_predict,
handleSocketEvent
)

)
useEffect(() => {
if (error) {
toastError(error);
Expand Down Expand Up @@ -426,7 +424,7 @@ const ChatBox = forwardRef(({
const onPredictStream = useCallback(async data => {
setTimeout(scrollToMessageListEnd, 0);

const { modelSettings, socketConfig, sendMessage } = dataContext
const { socketConfig } = dataContext

const selectedText = await sendMessage({
type: VsCodeMessageTypes.getSelectedText
Expand All @@ -451,7 +449,6 @@ const ChatBox = forwardRef(({
sentTo: participantRef.current ?? {}
}]
})

const projectId = socketConfig?.projectId
if (!projectId) {
toastError('Elitea Code extension Project ID setting is missing. Please set it to continue chat.');
Expand All @@ -464,10 +461,6 @@ const ChatBox = forwardRef(({
}

if (data.application_id) {
if (!data.llm_settings?.integration_uid) {
toastError('Application chat model is missing. Please select another one for chat.');
return
}
const payload = generateApplicationStreamingPayload({
...data,
...data.llm_settings,
Expand All @@ -477,26 +470,9 @@ const ChatBox = forwardRef(({
chatHistory,
interaction_uuid
})

emit(payload)
return
} else if (data.datasource_id) {
if (!data.chat_settings_ai?.integration_uid ||
!data.chat_settings_embedding?.integration_uid) {
toastError('Datasource chat model and/or embedding setting is missing. Please select another one for chat.');
return
}
emit({
project_id: projectId,
version_id: data.currentVersionId || data.datasource_id,
input: question,
chat_history: chatHistory.filter(i => i.role !== MESSAGE_REFERENCE_ROLE).concat(messages),
context: data.context,
chat_settings_ai: data.chat_settings_ai,
chat_settings_embedding: data.chat_settings_embedding,
interaction_uuid
})
return

} else {
const payloadData = {
projectId,
Expand All @@ -506,30 +482,32 @@ const ChatBox = forwardRef(({
messages,
interaction_uuid
}
if (data.prompt_id && data.currentVersionId) {
payloadData.prompt_id = data.prompt_id
payloadData.currentVersionId = data.currentVersionId
payloadData.variables = data.variables
} else {
if (modelSettings) {
payloadData.model_settings = modelSettings
payloadData.model_settings.model = getDefaultModel(modelSettings.model, deployments)
if (!payloadData.model_settings.model.model_name) {
toastError('Elitea Code extension model settings are missing.');
return
} else if (!payloadData.model_settings.model.integration_uid) {
toastError('Elitea Code extension integration Uid is missing.');
return
}
} else {

if (modelSettings) {

payloadData.llm_settings = modelSettings
payloadData.llm_settings = getDefaultModel(modelSettings.model, deployments)
payloadData.llm_settings.model_project_id = 1
payloadData.llm_settings.temperature = modelSettings.temperature
// payloadData.llm_settings.top_p = modelSettings.top_p
payloadData.llm_settings.max_tokens = modelSettings.max_tokens
if (!payloadData.llm_settings.model_name) {
toastError('Elitea Code extension model settings are missing.');
return
} else if (!payloadData.llm_settings.integration_uid) {
toastError('Elitea Code extension integration Uid is missing.');
return
}
} else {
toastError('Elitea Code extension model settings are missing.');
return
}
// }
const payload = generateChatPayload(payloadData)
payload.conversation_uuid = conversationUuid
emit(payload)
}
},
},
[scrollToMessageListEnd, dataContext, setChatHistory, emit, chatHistory, deployments, error, toastError, interaction_uuid])

const onSend = useCallback(
Expand All @@ -544,11 +522,9 @@ const ChatBox = forwardRef(({
);

const { emit: manualEmit } = useManualSocket(
chatWith === ChatTypes.datasource ?
sioEvents.datasource_leave_rooms :
chatWith === ChatTypes.application ?
sioEvents.application_leave_rooms :
sioEvents.promptlib_leave_rooms
chatWith === ChatTypes.application ?
sioEvents.application_leave_rooms :
sioEvents.chat_leave_rooms
);
const {
isStreaming,
Expand Down
Loading
Loading