diff --git a/application/ui/src/features/annotator/tools/tools.component.tsx b/application/ui/src/features/annotator/tools/tools.component.tsx index 781873fe53..63acc246d8 100644 --- a/application/ui/src/features/annotator/tools/tools.component.tsx +++ b/application/ui/src/features/annotator/tools/tools.component.tsx @@ -1,7 +1,7 @@ // Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -import { ActionButton } from '@geti/ui'; +import { ActionButton, Tooltip, TooltipTrigger } from '@geti/ui'; import { useHotkeys } from 'react-hotkeys-hook'; import { Fragment } from 'react/jsx-runtime'; @@ -18,16 +18,21 @@ const Tool = ({ tool, activeTool, setActiveTool }: ToolProps) => { useHotkeys(tool.hotkey, () => setActiveTool(tool.type), [setActiveTool]); return ( - setActiveTool(tool.type)} - aria-label={`${tool.type} tool`} - > - - - - + + setActiveTool(tool.type)} + aria-label={`${tool.type} tool`} + > + + + + + + {tool.label} ({tool.hotkey}) + + ); }; diff --git a/application/ui/src/features/dataset/media-preview/primary-toolbar/toggle-annotations-visibility.component.tsx b/application/ui/src/features/dataset/media-preview/primary-toolbar/toggle-annotations-visibility.component.tsx index 9e6b50140c..f9865b637d 100644 --- a/application/ui/src/features/dataset/media-preview/primary-toolbar/toggle-annotations-visibility.component.tsx +++ b/application/ui/src/features/dataset/media-preview/primary-toolbar/toggle-annotations-visibility.component.tsx @@ -14,7 +14,7 @@ export const ToggleAnnotationsVisibility = () => { useHotkeys(HOTKEYS.toggleAnnotationsVisibility, toggleVisibility, [toggleVisibility]); return ( - + {isVisible ? : } diff --git a/application/ui/src/features/models/train-model/select-dataset-revision.component.tsx b/application/ui/src/features/models/train-model/select-dataset-revision.component.tsx index 888925c872..111e6d3700 100644 --- a/application/ui/src/features/models/train-model/select-dataset-revision.component.tsx +++ b/application/ui/src/features/models/train-model/select-dataset-revision.component.tsx @@ -1,7 +1,7 @@ // Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -import { Item, Picker } from '@geti/ui'; +import { Content, ContextualHelp, Heading, Item, Picker } from '@geti/ui'; import { useTrainModel } from './train-model-provider.component'; @@ -9,14 +9,28 @@ export const SelectDatasetRevision = () => { const { datasetRevisions, selectedDatasetRevisionId, onSelectDatasetRevisionId } = useTrainModel(); return ( - onSelectDatasetRevisionId(String(key))} - > - {(item) => {item.name}} - + <> + onSelectDatasetRevisionId(String(key))} + contextualHelp={ + + Selecting a dataset + + {`Choose the version of the dataset to use for training. If you want to train the new model + on the exact same data (media and annotations) as another model, please select the + corresponding dataset revision. Conversely, if you want to train on the most recent version + of the data (what you see in the "Dataset" page), please select "Use + current dataset".`} + + + } + > + {(item) => {item.name}} + + ); }; diff --git a/application/ui/src/features/models/train-model/train-model-dialog.component.tsx b/application/ui/src/features/models/train-model/train-model-dialog.component.tsx index 3bde19fb7a..65fea78021 100644 --- a/application/ui/src/features/models/train-model/train-model-dialog.component.tsx +++ b/application/ui/src/features/models/train-model/train-model-dialog.component.tsx @@ -1,12 +1,26 @@ // Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -import { Button, ButtonGroup, Content, Dialog, Divider, Flex, Heading, Link, Text, toast } from '@geti/ui'; +import { + Button, + ButtonGroup, + Content, + Dialog, + Divider, + Flex, + Footer, + Heading, + InlineAlert, + Link, + Text, + toast, +} from '@geti/ui'; import { useProjectIdentifier } from 'hooks/use-project-identifier.hook'; import { useMatch } from 'react-router'; import { paths } from '../../../constants/paths'; import { useTrainModelMutation } from '../hooks/api/use-train-model-mutation'; +import { useIsTrainingButtonDisabled } from '../hooks/use-is-training-button-disabled'; import { TrainModelDialogContent } from './train-model-dialog-content'; import { useTrainModel } from './train-model-provider.component'; @@ -20,8 +34,10 @@ export const TrainModelDialog = ({ onClose }: TrainModelDialogProps) => { const trainModelMutation = useTrainModelMutation(); const projectId = useProjectIdentifier(); const isModelsPage = useMatch(paths.project.models.pattern); + const isTrainingDisabled = useIsTrainingButtonDisabled(); - const isStartButtonDisabled = selectedModelArchitectureId === null || selectedTrainingDevice === null; + const isStartButtonDisabled = + isTrainingDisabled || selectedModelArchitectureId === null || selectedTrainingDevice === null; const trainModel = () => { if (isStartButtonDisabled) return; @@ -57,21 +73,39 @@ export const TrainModelDialog = ({ onClose }: TrainModelDialogProps) => { }; return ( - + Select a model to train + + + - - - - + +
+ + {isTrainingDisabled ? ( + + Why can I not start training? + + In order to train a model, you need to annotate at least 3 items in your dataset, + although we recommend annotating at least 10 for better results. + + + ) : null} + + + + + + +
); }; diff --git a/application/ui/src/features/models/train-model/train-model-provider.component.tsx b/application/ui/src/features/models/train-model/train-model-provider.component.tsx index bf204c05fe..34fb07c064 100644 --- a/application/ui/src/features/models/train-model/train-model-provider.component.tsx +++ b/application/ui/src/features/models/train-model/train-model-provider.component.tsx @@ -44,7 +44,7 @@ const useDatasetRevisions = () => { const { data: datasetRevisions } = useGetDatasetRevisions(); return { datasetRevisions: [ - { id: 'use-current-dataset-revision', name: 'Use current revision', value: null }, + { id: 'use-current-dataset-revision', name: 'Use current dataset', value: null }, ...(datasetRevisions?.map(({ id, name }) => ({ id, name, value: String(id) })) ?? []), ], }; diff --git a/application/ui/src/features/models/train-model/train-model.test.tsx b/application/ui/src/features/models/train-model/train-model.component.test.tsx similarity index 54% rename from application/ui/src/features/models/train-model/train-model.test.tsx rename to application/ui/src/features/models/train-model/train-model.component.test.tsx index 101933325b..3ee581b8ea 100644 --- a/application/ui/src/features/models/train-model/train-model.test.tsx +++ b/application/ui/src/features/models/train-model/train-model.component.test.tsx @@ -1,7 +1,9 @@ // Copyright (C) 2025-2026 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -import { screen, waitFor } from '@testing-library/react'; +import { fireEvent, screen, waitFor } from '@testing-library/react'; +import { getMockedPipeline } from 'mocks/mock-pipeline'; +import { getMockedProject } from 'mocks/mock-project'; import { HttpResponse } from 'msw'; import { render } from 'test-utils/render'; @@ -10,7 +12,30 @@ import { server } from '../../../msw-node-setup'; import { TrainModel } from './train-model.component'; describe('TrainModel', () => { - it('disables train model button when there are no enough annotated media items', async () => { + beforeEach(() => { + server.use( + http.get('/api/projects/{project_id}', () => { + return HttpResponse.json(getMockedProject({ id: '123' })); + }), + http.get('/api/projects/{project_id}/pipeline', () => { + return HttpResponse.json(getMockedPipeline({})); + }), + http.get('/api/projects/{project_id}/dataset_revisions', () => { + return HttpResponse.json([]); + }), + http.get('/api/model_architectures', () => { + return HttpResponse.json({ + model_architectures: [], + top_picks: null, + }); + }), + http.get('/api/system/devices/training', () => { + return HttpResponse.json([{ type: 'cpu', name: 'CPU' }]); + }) + ); + }); + + it('shows warning message when there are not enough annotated media items', async () => { server.use( http.get('/api/projects/{project_id}/dataset/items', () => { return HttpResponse.json({ @@ -23,14 +48,10 @@ describe('TrainModel', () => { id: '2', subset: 'unassigned', }, - { - id: '3', - subset: 'unassigned', - }, ], pagination: { - total: 3, - count: 3, + total: 2, + count: 2, limit: 10, offset: 0, }, @@ -40,12 +61,14 @@ describe('TrainModel', () => { render(); - await waitFor(() => { - expect(screen.getByRole('button', { name: 'Train model' })).toBeDisabled(); - }); + fireEvent.click(screen.getByRole('button', { name: 'Train model' })); + + expect( + await screen.findByText(/In order to train a model, you need to annotate at least 3 items/) + ).toBeVisible(); }); - it('enables train model button when there are enough annotated media items', async () => { + it('does not show warning message when there are enough annotated media items', async () => { server.use( http.get('/api/projects/{project_id}/dataset/items', () => { return HttpResponse.json({ @@ -79,8 +102,12 @@ describe('TrainModel', () => { render(); + fireEvent.click(screen.getByRole('button', { name: 'Train model' })); + await waitFor(() => { - expect(screen.getByRole('button', { name: 'Train model' })).toBeEnabled(); + expect( + screen.queryByText(/In order to train a model, you need to annotate at least 3 items/) + ).not.toBeInTheDocument(); }); }); }); diff --git a/application/ui/src/features/models/train-model/train-model.component.tsx b/application/ui/src/features/models/train-model/train-model.component.tsx index 543ad52e34..37e88ca22d 100644 --- a/application/ui/src/features/models/train-model/train-model.component.tsx +++ b/application/ui/src/features/models/train-model/train-model.component.tsx @@ -5,7 +5,6 @@ import { Suspense } from 'react'; import { Button, DialogTrigger, Loading, View } from '@geti/ui'; -import { useIsTrainingButtonDisabled } from '../hooks/use-is-training-button-disabled'; import { TrainModelDialog } from './train-model-dialog.component'; import { TrainModelProvider } from './train-model-provider.component'; @@ -14,11 +13,9 @@ type TrainModelProps = { }; export const TrainModel = ({ preSelectedDatasetRevisionId }: TrainModelProps) => { - const isTrainingDisabled = useIsTrainingButtonDisabled(); - return ( - + {(close) => (