diff --git a/application/ui/src/components/stream/stream.tsx b/application/ui/src/components/stream/stream.tsx deleted file mode 100644 index 78c240f4dc..0000000000 --- a/application/ui/src/components/stream/stream.tsx +++ /dev/null @@ -1,120 +0,0 @@ -import { Dispatch, RefObject, SetStateAction, useCallback, useEffect, useRef } from 'react'; - -import { useWebRTCConnection } from '../../components/stream/web-rtc-connection-provider'; -import { ZoomTransform } from '../zoom/zoom-transform'; - -const useSetTargetSizeBasedOnVideo = ( - setSize: Dispatch>, - videoRef: RefObject -) => { - useEffect(() => { - const video = videoRef.current; - - const onsize = video?.addEventListener('loadedmetadata', (event) => { - const target = event.currentTarget as HTMLVideoElement; - - if (target.videoWidth && target.videoHeight) { - setSize({ width: target.videoWidth, height: target.videoHeight }); - } - }); - - const onresize = video?.addEventListener('resize', (event) => { - const target = event.currentTarget as HTMLVideoElement; - - if (target.videoWidth && target.videoHeight) { - setSize({ width: target.videoWidth, height: target.videoHeight }); - } - }); - - return () => { - if (onsize) { - video?.removeEventListener('loadedmetadata', onsize); - } - - if (onresize) { - video?.removeEventListener('resize', onresize); - } - }; - }, [setSize, videoRef]); -}; - -const useStreamToVideo = () => { - const videoRef = useRef(null); - - const { status, webRTCConnectionRef } = useWebRTCConnection(); - - const connect = useCallback(async () => { - const videoOutput = videoRef.current; - const webrtcConnection = webRTCConnectionRef.current; - const peerConnection = webrtcConnection?.getPeerConnection(); - - if (!peerConnection) { - return; - } - - const receivers = peerConnection.getReceivers() ?? []; - const stream = new MediaStream(receivers.map((receiver) => receiver.track)); - - if (videoOutput && videoOutput.srcObject !== stream) { - videoOutput.srcObject = stream; - } - }, [videoRef, webRTCConnectionRef]); - - useEffect(() => { - if (status === 'connected') { - connect(); - } - }, [status, connect]); - - useEffect(() => { - const webrtcConnection = webRTCConnectionRef.current; - const peerConnection = webrtcConnection?.getPeerConnection(); - - if (!peerConnection) { - return; - } - - peerConnection.addEventListener('track', connect); - - return () => { - peerConnection.removeEventListener('track', connect); - }; - }, [webRTCConnectionRef, connect]); - - return videoRef; -}; - -export const Stream = ({ - size, - setSize, -}: { - size: { width: number; height: number }; - setSize: Dispatch>; -}) => { - const videoRef = useStreamToVideo(); - - useSetTargetSizeBasedOnVideo(setSize, videoRef); - - const { status } = useWebRTCConnection(); - - return ( - -
- {status === 'connected' && ( - // eslint-disable-next-line jsx-a11y/media-has-caption -
-
- ); -}; diff --git a/application/ui/src/components/stream/web-rtc-connection.ts b/application/ui/src/components/stream/web-rtc-connection.ts index 1022cc5433..30ae4b1a60 100644 --- a/application/ui/src/components/stream/web-rtc-connection.ts +++ b/application/ui/src/components/stream/web-rtc-connection.ts @@ -1,3 +1,5 @@ +import { fetchClient } from '../../api/client'; + export type WebRTCConnectionStatus = 'idle' | 'connecting' | 'connected' | 'disconnected' | 'failed'; type WebRTCConnectionEvent = @@ -32,7 +34,6 @@ export class WebRTCConnection { private timeoutId?: ReturnType; constructor() { - // TODO: replace with uuid this.webrtcId = Math.random().toString(36).substring(7); } @@ -123,7 +124,15 @@ export class WebRTCConnection { private async sendOffer(): Promise { if (!this.peerConnection) return; - throw new Error('Work in progress: not implemented'); + const { data } = await fetchClient.POST('/api/webrtc/offer', { + body: { + sdp: this.peerConnection.localDescription?.sdp ?? '', + type: this.peerConnection.localDescription?.type ?? '', + webrtc_id: this.webrtcId, + }, + }); + + return data as SessionData; } private async handleOfferResponse(data: SessionData | undefined): Promise { diff --git a/application/ui/src/features/inspect/inference-provider.component.tsx b/application/ui/src/features/inspect/inference-provider.component.tsx index 40776690f3..184fff65ac 100644 --- a/application/ui/src/features/inspect/inference-provider.component.tsx +++ b/application/ui/src/features/inspect/inference-provider.component.tsx @@ -2,6 +2,7 @@ import { createContext, ReactNode, use, useState } from 'react'; import { $api } from '@geti-inspect/api'; import { components } from '@geti-inspect/api/spec'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; import { MediaItem } from './dataset/types'; import { useSelectedMediaItem } from './selected-media-item-provider.component'; @@ -61,6 +62,9 @@ interface InferenceProviderProps { } export const InferenceProvider = ({ children }: InferenceProviderProps) => { + const { projectId } = useProjectIdentifier(); + const updatePipeline = $api.useMutation('patch', '/api/projects/{project_id}/pipeline'); + const { inferenceResult, onInference, isPending } = useInferenceMutation(); const [selectedModelId, setSelectedModelId] = useState(undefined); const [inferenceOpacity, setInferenceOpacity] = useState(0.75); @@ -69,6 +73,10 @@ export const InferenceProvider = ({ children }: InferenceProviderProps) => { const onSetSelectedModelId = (modelId: string | undefined) => { setSelectedModelId(modelId); + updatePipeline.mutate({ + params: { path: { project_id: projectId } }, + body: { model_id: modelId }, + }); if (modelId && selectedMediaItem) { onInference(selectedMediaItem, modelId); diff --git a/application/ui/src/features/inspect/inference-result.component.tsx b/application/ui/src/features/inspect/inference-result.component.tsx index f51a59ee40..12ef18386a 100644 --- a/application/ui/src/features/inspect/inference-result.component.tsx +++ b/application/ui/src/features/inspect/inference-result.component.tsx @@ -8,6 +8,7 @@ import { useSpinDelay } from 'spin-delay'; import { useInference } from './inference-provider.component'; import { useSelectedMediaItem } from './selected-media-item-provider.component'; +import { StreamContainer } from './stream/stream-container'; import styles from './inference.module.scss'; @@ -56,6 +57,10 @@ export const InferenceResult = () => { const isInferenceAvailable = useIsInferenceAvailable(); const isLoadingInference = useSpinDelay(isPending, { delay: 300 }); + if (selectedMediaItem === undefined) { + return ; + } + if (!isInferenceAvailable && selectedMediaItem === undefined) { return ( { const { status, webRTCConnectionRef } = useWebRTCConnection(); const connect = useCallback(async () => { + console.log('feature connect'); const videoOutput = videoRef.current; const webrtcConnection = webRTCConnectionRef.current; const peerConnection = webrtcConnection?.getPeerConnection(); diff --git a/application/ui/src/features/inspect/toolbar/pipeline-switch/pipeline-switch.component.tsx b/application/ui/src/features/inspect/toolbar/pipeline-switch/pipeline-switch.component.tsx new file mode 100644 index 0000000000..42489179a4 --- /dev/null +++ b/application/ui/src/features/inspect/toolbar/pipeline-switch/pipeline-switch.component.tsx @@ -0,0 +1,41 @@ +import { $api } from '@geti-inspect/api'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { Switch, toast } from '@geti/ui'; + +export const PipelineSwitch = () => { + const { projectId } = useProjectIdentifier(); + const { data: pipeline } = $api.useSuspenseQuery('get', '/api/projects/{project_id}/pipeline', { + params: { path: { project_id: projectId } }, + }); + + const enablePipeline = $api.useMutation('post', '/api/projects/{project_id}/pipeline:enable', { + onError: (error) => { + if (error) { + toast({ type: 'error', message: String(error.detail) }); + } + }, + meta: { + invalidates: [ + ['get', '/api/projects/{project_id}/pipeline', { params: { path: { project_id: projectId } } }], + ], + }, + }); + const disablePipeline = $api.useMutation('post', '/api/projects/{project_id}/pipeline:disable', { + meta: { + invalidates: [ + ['get', '/api/projects/{project_id}/pipeline', { params: { path: { project_id: projectId } } }], + ], + }, + }); + + const handleChange = (isSelected: boolean) => { + const handler = isSelected ? enablePipeline.mutate : disablePipeline.mutate; + handler({ params: { path: { project_id: projectId } } }); + }; + + return ( + + Enabled + + ); +}; diff --git a/application/ui/src/features/inspect/toolbar/sinks/hooks/use-sink-mutation.hook.tsx b/application/ui/src/features/inspect/toolbar/sinks/hooks/use-sink-mutation.hook.tsx index bf371c92a8..3e0d65729f 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/hooks/use-sink-mutation.hook.tsx +++ b/application/ui/src/features/inspect/toolbar/sinks/hooks/use-sink-mutation.hook.tsx @@ -1,7 +1,6 @@ import { $api } from '@geti-inspect/api'; import { useProjectIdentifier } from '@geti-inspect/hooks'; import { omit } from 'lodash-es'; -import { v4 as uuid } from 'uuid'; import { SinkConfig } from '../utils'; @@ -21,15 +20,12 @@ export const useSinkMutation = (isNewSink: boolean) => { return async (body: SinkConfig) => { if (isNewSink) { - const id = uuid(); - const sinkPayload = { ...body, id }; - await addSink.mutateAsync({ - body: sinkPayload, + body, params: { path: { project_id: projectId } }, }); - return id; + return body.id; } const response = await updateSink.mutateAsync({ diff --git a/application/ui/src/features/inspect/toolbar/sinks/local-folder-fields/utils.ts b/application/ui/src/features/inspect/toolbar/sinks/local-folder-fields/utils.ts index dee4a17be8..0d9b4b88d5 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/local-folder-fields/utils.ts +++ b/application/ui/src/features/inspect/toolbar/sinks/local-folder-fields/utils.ts @@ -1,7 +1,9 @@ +import { v4 as uuid } from 'uuid'; + import { LocalFolderSinkConfig, SinkOutputFormats } from '../utils'; export const getLocalFolderInitialConfig = (project_id: string): LocalFolderSinkConfig => ({ - id: '', + id: uuid(), name: '', project_id, sink_type: 'folder', diff --git a/application/ui/src/features/inspect/toolbar/sinks/mqtt-fields copy/mqtt-fields.component.tsx b/application/ui/src/features/inspect/toolbar/sinks/mqtt-fields copy/mqtt-fields.component.tsx deleted file mode 100644 index bc3d2daad3..0000000000 --- a/application/ui/src/features/inspect/toolbar/sinks/mqtt-fields copy/mqtt-fields.component.tsx +++ /dev/null @@ -1,51 +0,0 @@ -import { Flex, NumberField, Switch, TextField } from '@geti/ui'; - -import { OutputFormats } from '../output-formats/output-formats.component'; -import { MqttSinkConfig } from '../utils'; - -interface MqttFieldsProps { - defaultState: MqttSinkConfig; -} - -export const MqttFields = ({ defaultState }: MqttFieldsProps) => { - return ( - - - - - - - - - - - - - - - Auth Required - - - - - - ); -}; diff --git a/application/ui/src/features/inspect/toolbar/sinks/mqtt-fields/utils.ts b/application/ui/src/features/inspect/toolbar/sinks/mqtt-fields/utils.ts index 29d4c30755..6629a1e560 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/mqtt-fields/utils.ts +++ b/application/ui/src/features/inspect/toolbar/sinks/mqtt-fields/utils.ts @@ -1,7 +1,9 @@ +import { v4 as uuid } from 'uuid'; + import { MqttSinkConfig, SinkOutputFormats } from '../utils'; export const getMqttInitialConfig = (project_id: string): MqttSinkConfig => ({ - id: '', + id: uuid(), name: '', project_id, topic: '', diff --git a/application/ui/src/features/inspect/toolbar/sinks/ros-fields/utils.ts b/application/ui/src/features/inspect/toolbar/sinks/ros-fields/utils.ts index 3f6a5c8834..cde89d3c4e 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/ros-fields/utils.ts +++ b/application/ui/src/features/inspect/toolbar/sinks/ros-fields/utils.ts @@ -1,7 +1,9 @@ +import { v4 as uuid } from 'uuid'; + import { RosSinkConfig, SinkOutputFormats } from '../utils'; export const getRosInitialConfig = (project_id: string): RosSinkConfig => ({ - id: '', + id: uuid(), name: '', project_id, topic: '', diff --git a/application/ui/src/features/inspect/toolbar/sinks/sink-list/sink-menu/sink-menu.component.tsx b/application/ui/src/features/inspect/toolbar/sinks/sink-list/sink-menu/sink-menu.component.tsx index ca787287e4..312b3a5baf 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/sink-list/sink-menu/sink-menu.component.tsx +++ b/application/ui/src/features/inspect/toolbar/sinks/sink-list/sink-menu/sink-menu.component.tsx @@ -21,7 +21,6 @@ export const SinkMenu = ({ id, name, isConnected, onEdit }: SinkMenuProps) => { const updatePipeline = $api.useMutation('patch', '/api/projects/{project_id}/pipeline', { meta: { invalidates: [ - ['get', '/api/projects/{project_id}/sinks', { params: { path: { project_id: projectId } } }], ['get', '/api/projects/{project_id}/pipeline', { params: { path: { project_id: projectId } } }], ], }, diff --git a/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/key-value-builder.component.tsx b/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/header-key-value-builder.component.tsx similarity index 96% rename from application/ui/src/features/inspect/toolbar/sinks/webhook-fields/key-value-builder.component.tsx rename to application/ui/src/features/inspect/toolbar/sinks/webhook-fields/header-key-value-builder.component.tsx index 3ced7be887..0c69a7b920 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/key-value-builder.component.tsx +++ b/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/header-key-value-builder.component.tsx @@ -7,7 +7,7 @@ import { isEmpty } from 'lodash-es'; import { RequiredTextField } from '../../../../../components/required-text-field/required-text-field.component'; import { Fields, getPairsFromObject, Pair } from './utils'; -type KeyValueBuilderProps = { +type HeaderKeyValueBuilderProps = { title: string; keysName: string; valuesName: string; @@ -17,7 +17,7 @@ type KeyValueBuilderProps = { const updatePairAtIndex = (indexToUpdate: number, field: Fields, value: string) => (pair: Pair, index: number) => index === indexToUpdate ? { ...pair, [field]: value } : pair; -export const KeyValueBuilder = ({ title, keysName, valuesName, config = {} }: KeyValueBuilderProps) => { +export const HeaderKeyValueBuilder = ({ title, keysName, valuesName, config = {} }: HeaderKeyValueBuilderProps) => { const [pairs, setPairs] = useState(getPairsFromObject(config)); const addPair = () => { diff --git a/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/utils.ts b/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/utils.ts index 3b22ae60ad..bea07cd1ca 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/utils.ts +++ b/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/utils.ts @@ -1,3 +1,5 @@ +import { v4 as uuid } from 'uuid'; + import { getObjectFromFormData, SinkOutputFormats, WebhookHttpMethod, WebhookSinkConfig } from '../utils'; export type Pair = Record; @@ -12,7 +14,7 @@ export const getPairsFromObject = (obj: Record): Pair[] => { }; export const getWebhookInitialConfig = (project_id: string): WebhookSinkConfig => ({ - id: '', + id: uuid(), name: '', timeout: 0, project_id, diff --git a/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/webhook-fields.component.tsx b/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/webhook-fields.component.tsx index d8c2b7403e..8f7323db7d 100644 --- a/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/webhook-fields.component.tsx +++ b/application/ui/src/features/inspect/toolbar/sinks/webhook-fields/webhook-fields.component.tsx @@ -2,7 +2,7 @@ import { Flex, Item, NumberField, Picker, TextField } from '@geti/ui'; import { OutputFormats } from '../output-formats/output-formats.component'; import { WebhookHttpMethod, WebhookSinkConfig } from '../utils'; -import { KeyValueBuilder } from './key-value-builder.component'; +import { HeaderKeyValueBuilder } from './header-key-value-builder.component'; interface WebhookFieldsProps { defaultState: WebhookSinkConfig; @@ -37,7 +37,7 @@ export const WebhookFields = ({ defaultState }: WebhookFieldsProps) => { - { const { status, stop } = useWebRTCConnection(); @@ -82,9 +83,8 @@ const useTrainedModels = () => { }; const ModelsPicker = () => { - const { selectedModelId, onSetSelectedModelId } = useInference(); - const models = useTrainedModels(); + const { selectedModelId, onSetSelectedModelId } = useInference(); useEffect(() => { if (selectedModelId !== undefined || models.length === 0) { @@ -131,7 +131,9 @@ export const Toolbar = () => { + + diff --git a/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx b/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx index 67adbf578e..1208c5db53 100644 --- a/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx +++ b/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx @@ -14,10 +14,9 @@ export const TrainModelDialog = ({ close }: { close: () => void }) => { const [searchParams, setSearchParams] = useSearchParams(); const { projectId } = useProjectIdentifier(); const startTrainingMutation = $api.useMutation('post', '/api/jobs:train', { - meta: { - invalidates: [['get', '/api/jobs']], - }, + meta: { invalidates: [['get', '/api/jobs']] }, }); + const startTraining = async () => { if (selectedModel === null) { return;