Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { CSSProperties } from 'react';
import { CSSProperties, Fragment } from 'react';

import { Divider, Flex, Text } from '@geti/ui';
import { clsx } from 'clsx';
Expand Down Expand Up @@ -167,16 +167,15 @@ export const Labels = ({ isClassification = false, isMultiLabel = false, isReadO
aria-disabled={isReadOnly}
>
{labels.map((label) => (
<>
<Fragment key={label.id}>
{label.id === EMPTY_LABEL_ID && <Divider size={'S'} orientation={'vertical'} />}
<LabelBadge
key={label.id}
label={label}
isSelected={isLabelSelected(label)}
isDisabled={isReadOnly}
onClick={() => handleLabelClick(label)}
/>
</>
</Fragment>
))}
</div>
</Flex>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

import { dimensionValue, Flex, Grid, Heading, Text } from '@geti/ui';
import { Image, Tag } from '@geti/ui/icons';
import { useNumberFormatter } from 'react-aria';

import { TrainModel } from '../../../train-model/train-model.component';
import type { DatasetGroup } from '../../types';
import { DatasetActions } from '../dataset-actions/dataset-actions.component';
import { ModelBadge } from '../model-row/model-badge.component';
import { ThreeSectionRange } from '../three-section-range/three-section-range.component';

import classes from './group-headers.module.scss';

type DatasetGroupHeaderProps = {
dataset: DatasetGroup;
};
Expand All @@ -20,6 +20,7 @@ export const DatasetGroupHeader = ({ dataset }: DatasetGroupHeaderProps) => {
const gridColumns = hasDatasetRevisionData
? ['auto', '1fr', 'auto', '1fr', 'auto']
: ['auto', '1fr', 'auto', 'auto'];
const formatter = useNumberFormatter();

return (
<Grid columns={gridColumns} alignItems={'center'} marginBottom={'size-225'} gap={'size-200'}>
Expand All @@ -39,12 +40,12 @@ export const DatasetGroupHeader = ({ dataset }: DatasetGroupHeaderProps) => {
</Text>

<Flex gap={'size-50'} justifyContent={'center'}>
<Flex UNSAFE_className={classes.tag}>
<ModelBadge>
<Tag /> {dataset.labelCount}
</Flex>
<Flex UNSAFE_className={classes.tag}>
<Image /> {dataset.imageCount.toLocaleString()}
</Flex>
</ModelBadge>
<ModelBadge>
<Image /> {formatter.format(dataset.imageCount)}
</ModelBadge>
</Flex>

{hasDatasetRevisionData && (
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { Model } from '../../../../../constants/shared-types';
import { ModelDetailsTabs } from '../../model-details/model-details-tabs.component';
import { useModelListing } from '../../provider/model-listing-provider';
import { ArchitectureGroup, DatasetGroup } from '../../types';
import { isFailedModel } from '../../utils/utils';
import { GroupHeader } from '../group-headers/group-header.component';
import { ModelRowContainer } from '../model-row/model-row.container';
import { ModelsTableHeader } from '../models-table-header.component';
Expand Down Expand Up @@ -35,6 +36,7 @@ export const GroupModelsContainer = ({ group, models }: GroupModelsContainerProp
isQuiet
UNSAFE_className={classes.disclosure}
isExpanded={expandedModelIds.has(modelId)}
isDisabled={isFailedModel(model)}
onExpandedChange={() => onExpandModel(modelId)}
data-testid={`model-disclosure-${modelId}`}
>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import type { Model } from '../../../../../constants/shared-types';
import { usePatchPipeline } from '../../../../../hooks/api/pipeline.hook';
import { useDeleteModel } from '../../../hooks/api/use-delete-model.hook';
import { useRenameModel } from '../../../hooks/api/use-rename-model.hook';
import { isFailedModel } from '../../utils/utils';
import { RenameModelDialog } from '../model-row/rename-model-dialog.component';

const MODEL_ACTIONS = {
ACTIVE: 'active',
RENAME: 'rename',
DELETE: 'delete',
EXPORT: 'export',
};

type ModelActionsProps = {
Expand All @@ -33,6 +33,8 @@ export const ModelActions = ({ model }: ModelActionsProps) => {
const [isRenameDialogOpen, setIsRenameDialogOpen] = useState(false);
const [isDeleteDialogOpen, setIsDeleteDialogOpen] = useState(false);

const disabledKeys = isFailedModel(model) ? [MODEL_ACTIONS.ACTIVE, MODEL_ACTIONS.RENAME] : [];

const handleAction = (key: Key) => {
if (key === MODEL_ACTIONS.ACTIVE) {
patchPipelineMutation.mutate({
Expand Down Expand Up @@ -70,7 +72,7 @@ export const ModelActions = ({ model }: ModelActionsProps) => {
<ActionButton isQuiet aria-label={'Model actions'}>
<MoreMenu />
</ActionButton>
<Menu onAction={handleAction} aria-label={'Model actions menu'}>
<Menu onAction={handleAction} aria-label={'Model actions menu'} disabledKeys={disabledKeys}>
<Item key={MODEL_ACTIONS.ACTIVE}>Set as active</Item>
<Item key={MODEL_ACTIONS.RENAME}>Rename</Item>
<Item key={MODEL_ACTIONS.DELETE}>Delete</Item>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2025-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { Flex, Text } from '@geti/ui';

import { ReactComponent as ThumbsUp } from '../../../../../assets/icons/thumbs-up.svg';
import { ModelBadge } from './model-badge.component';

import classes from './model-row.module.scss';

type ArchitectureColumnProps = {
architecture: string;
};

export const ArchitectureColumn = ({ architecture }: ArchitectureColumnProps) => {
return (
<Flex direction={'column'} gap={'size-100'}>
<Text UNSAFE_className={classes.smallText}>{architecture} (Apache 2.0)</Text>
{/* TODO: Speed is hardcoded for now, once the backend is update we need to update this */}
<ModelBadge id={'architecture-name'}>
<ThumbsUp />
<Text>Speed</Text>
</ModelBadge>
</Flex>
);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (C) 2025-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { Flex, Text } from '@geti/ui';
import { Image, Tag } from '@geti/ui/icons';
import { useNumberFormatter } from 'react-aria';

import type { DatasetRevision } from '../../../../../constants/shared-types';
import { formatDatasetRevisionDate } from '../../utils/date-formatting';
import { ModelBadge } from './model-badge.component';

import styles from './model-row.module.scss';

type DatasetColumnProps = {
datasetRevision: DatasetRevision | undefined;
labelsCount: number | undefined;
};

export const DatasetColumn = ({ datasetRevision, labelsCount }: DatasetColumnProps) => {
const totalCount = datasetRevision?.item_counts?.total;
const formatter = useNumberFormatter();

// Should never happen, but just in case
if (datasetRevision === undefined) {
return (
<Flex alignItems={'center'} justifyContent={'center'}>
Unknown
</Flex>
);
}

return (
<Flex direction={'column'} gap={'size-50'}>
<Text UNSAFE_className={styles.datasetRevisionName}>{datasetRevision.name}</Text>
<Text UNSAFE_className={styles.datasetRevisionDate}>
{formatDatasetRevisionDate(datasetRevision.created_at)}
</Text>
<Flex gap={'size-100'}>
{labelsCount !== undefined && (
<ModelBadge id={'labels-count'}>
<Tag />
<Text>{labelsCount}</Text>
</ModelBadge>
)}
{totalCount !== undefined && (
<ModelBadge id={'dataset-count'}>
<Image />
<Text>{formatter.format(totalCount)}</Text>
</ModelBadge>
)}
</Flex>
</Flex>
);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2025-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { ReactNode } from 'react';

import { Badge, Flex, Text } from '@geti/ui';

import styles from './model-row.module.scss';

type ModelBadgeProps = {
children: ReactNode;
id?: string;
};

export const ModelBadge = ({ children, id }: ModelBadgeProps) => {
return (
<Badge variant={'neutral'} UNSAFE_className={styles.modelBadge} data-testid={id}>
<Text>
<Flex alignItems={'center'} gap={'size-50'}>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need these inside the badge? I thought it should support an icon and text out of the box.
https://react-spectrum.adobe.com/v3/Badge.html#content

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't use spectrum icons and this causes icon scaling issues (icon is too small)

{children}
</Flex>
</Text>
</Badge>
);
};
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { screen } from '@testing-library/react';
import { screen, within } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { getMockedDatasetRevision } from 'mocks/mock-dataset-revision';
import { render } from 'test-utils/render';

import { getMockedModel } from '../../../../../../mocks/mock-model';
Expand All @@ -25,41 +26,102 @@ describe('ModelRow', () => {
},
});

const datasetRevision = getMockedDatasetRevision({
id: 'dataset-123',
name: 'Dataset 1',
item_counts: {
total: 10,
testing: 4,
training: 4,
validation: 2,
},
});

describe('basic rendering', () => {
it('should render all model information correctly', () => {
render(<ModelRow model={defaultModel} />);
it('should render all model information correctly when grouped by architecture', () => {
render(<ModelRow model={defaultModel} datasetRevision={datasetRevision} groupBy={'architecture'} />);

expect(screen.getByTestId('model-name')).toHaveTextContent('Test Model');

const datasetBadge = screen.getByTestId('dataset-count');
const labelsBadge = screen.getByTestId('labels-count');
const labelSchemaRevision = defaultModel.training_info.label_schema_revision ?? {};
const labelsCount =
'labels' in labelSchemaRevision && Array.isArray(labelSchemaRevision.labels)
? labelSchemaRevision.labels.length
: '';

expect(screen.getByText(datasetRevision.name)).toBeInTheDocument();
expect(within(datasetBadge).getByText(datasetRevision.item_counts?.total?.toString() ?? ''));
expect(within(labelsBadge).getByText(labelsCount));
expect(screen.queryByText(/YOLOX/)).not.toBeInTheDocument();
expect(screen.queryByText('Speed')).not.toBeInTheDocument();
});

it('should render all model information correctly when grouped by dataset', () => {
render(<ModelRow model={defaultModel} datasetRevision={datasetRevision} groupBy={'dataset'} />);

expect(screen.getByTestId('model-name')).toHaveTextContent('Test Model');
expect(screen.getByText(/YOLOX \(Apache 2\.0\)/)).toBeInTheDocument();

expect(screen.getByText(/YOLOX/)).toBeInTheDocument();
expect(screen.getByText('Speed')).toBeInTheDocument();
expect(screen.queryByText(datasetRevision.name)).not.toBeInTheDocument();
expect(screen.queryByTestId('dataset-count')).not.toBeInTheDocument();
expect(screen.queryByTestId('labels-count')).not.toBeInTheDocument();
});

it('should render "Unnamed Model" when model name is null or undefined', () => {
const modelWithoutName = getMockedModel({ name: undefined });

render(<ModelRow model={modelWithoutName} />);
render(<ModelRow model={modelWithoutName} datasetRevision={datasetRevision} groupBy={'dataset'} />);

expect(screen.getByTestId('model-name')).toHaveTextContent('Unnamed Model');
});

it('should render "-" when model size is 0 or negative', () => {
const modelWithZeroSize = getMockedModel({ size: 0 });

render(<ModelRow model={modelWithZeroSize} />);
render(<ModelRow model={modelWithZeroSize} datasetRevision={datasetRevision} groupBy={'dataset'} />);

expect(screen.getByText('-')).toBeInTheDocument();
});

it('renders "Failed" badge when training status is failed', () => {
const failedModel = getMockedModel({
training_info: {
status: 'failed',
},
});

render(<ModelRow model={failedModel} datasetRevision={datasetRevision} groupBy={'dataset'} />);

expect(screen.getByText('Failed')).toBeInTheDocument();
});
});

describe('active model tag', () => {
it('should show active tag only when model id matches activeModelArchitectureId', () => {
const { rerender } = render(<ModelRow model={defaultModel} activeModelArchitectureId='model-123' />);
const { rerender } = render(
<ModelRow
model={defaultModel}
activeModelArchitectureId='model-123'
datasetRevision={datasetRevision}
groupBy={'dataset'}
/>
);
expect(screen.getByText('Active')).toBeInTheDocument();

rerender(<ModelRow model={defaultModel} activeModelArchitectureId={'different-id'} />);
rerender(
<ModelRow
model={defaultModel}
activeModelArchitectureId={'different-id'}
datasetRevision={datasetRevision}
groupBy={'dataset'}
/>
);
expect(screen.queryByText('Active')).not.toBeInTheDocument();

rerender(<ModelRow model={defaultModel} />);
rerender(<ModelRow model={defaultModel} datasetRevision={datasetRevision} groupBy={'dataset'} />);
expect(screen.queryByText('Active')).not.toBeInTheDocument();
});
});
Expand All @@ -72,7 +134,15 @@ describe('ModelRow', () => {
name: 'Parent Model',
});

render(<ModelRow model={defaultModel} parentRevisionModel={parentModel} onExpandModel={onExpandModel} />);
render(
<ModelRow
model={defaultModel}
parentRevisionModel={parentModel}
onExpandModel={onExpandModel}
datasetRevision={datasetRevision}
groupBy={'dataset'}
/>
);

expect(screen.getByText('Fine-tuned from')).toBeInTheDocument();
const parentLink = screen.getByRole('link', { name: 'Parent Model' });
Expand All @@ -83,7 +153,7 @@ describe('ModelRow', () => {
});

it('should not render parent revision model when not provided', () => {
render(<ModelRow model={defaultModel} />);
render(<ModelRow model={defaultModel} datasetRevision={datasetRevision} groupBy={'dataset'} />);

expect(screen.queryByText('Fine-tuned from')).not.toBeInTheDocument();
});
Expand Down
Loading
Loading