Skip to content

Commit

Permalink
Merge pull request #80 from VisActor/fix/badcases
Browse files Browse the repository at this point in the history
fix: chart generation badcases
  • Loading branch information
da730 authored Apr 19, 2024
2 parents c53712d + 1f9b37e commit 00b7c84
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 43 deletions.
12 changes: 8 additions & 4 deletions packages/vmind/__tests__/performance/performanceTest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import {

const TEST_GPT = false;
const TEST_SKYLARK = true;
const ShowThoughts = false;
const EnableDataQuery = true;

const demoDataList: { [key: string]: any } = {
pie: mockUserInput2,
Expand Down Expand Up @@ -61,7 +63,7 @@ const modelResultMap = {
[Model.SKYLARK2]: { totalCount: 0, successCount: 0, totalTime: 0 }
};

const testPerformance = (model: Model, vmind: any) => {
const testPerformance = (model: Model, vmind: VMind) => {
dataList.some((dataName, index) => {
if (index >= START_INDEX) {
it(dataName, async done => {
Expand All @@ -70,9 +72,11 @@ const testPerformance = (model: Model, vmind: any) => {
const { fieldInfo, dataset } = vmind.parseCSVData(csv);
//const { fieldInfo, dataset } = await vmind.parseCSVDataWithLLM(csv, describe);
const startTime = new Date().getTime();
const { spec, time, chartSource } = await vmind.generateChart(input, fieldInfo, dataset);
const { spec, time, chartSource, chartType } = await vmind.generateChart(input, fieldInfo, dataset, {
enableDataQuery: EnableDataQuery
});
const endTime = new Date().getTime();
log('generated chart type: ' + spec.type);
log('generated chart type: ' + chartType);
if (chartSource !== 'chartAdvisor') {
const costTime = endTime - startTime;
log('time cost: ' + costTime / 1000 + 's');
Expand All @@ -97,7 +101,7 @@ if (gptKey && gptURL && TEST_GPT) {
url: gptURL,
model: Model.GPT3_5,
cache: true,
showThoughts: false,
showThoughts: ShowThoughts,
headers: {
'api-key': gptKey
}
Expand Down
3 changes: 2 additions & 1 deletion packages/vmind/jest.performance.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ module.exports = {
testTimeout: 60000,
moduleNameMapper: {
axios: 'axios/dist/node/axios.cjs',
'd3-hierarchy': 'd3-hierarchy/dist/d3-hierarchy.min.js'
'd3-hierarchy': 'd3-hierarchy/dist/d3-hierarchy.min.js',
'^src/(.*)$': '<rootDir>/src/$1',
},
verbose: true,
// 在测试之前设置环境变量
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type { Transformer } from 'src/base/tools/transformer';
import type { ChartAdvisorContext, ChartAdvisorOutput } from './types';
import type { Cell } from '../../types';
import { isValidDataset } from 'src/common/dataProcess';

import { ChartType as VMindChartType } from 'src/common/typings';
/**
* call @visactor/chart-advisor to get the list of advised charts
* sorted by scores of each chart type
Expand Down Expand Up @@ -91,7 +91,7 @@ const getTop1AdvisedChart: Transformer<getAdvisedListOutput, ChartAdvisorOutput>
// call rule-based method to get recommended chart type and fieldMap(cell)
if (advisedList.length === 0) {
return {
chartType: 'BAR CHART',
chartType: VMindChartType.BarChart.toUpperCase() as VMindChartType,
cell: {},
dataset: undefined,
chartSource,
Expand All @@ -100,7 +100,7 @@ const getTop1AdvisedChart: Transformer<getAdvisedListOutput, ChartAdvisorOutput>
}
const result = advisedList[0];
return {
chartType: result.chartType,
chartType: result.chartType as VMindChartType,
cell: getCell(result.cell),
dataset: result.dataset,
chartSource,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { Transformer } from 'src/base/tools/transformer';
import type { GenerateFieldMapContext, GenerateFieldMapOutput } from '../../types';
import { isArray, isString } from 'lodash';
import { matchFieldWithoutPunctuation } from './utils';
import { ChartType } from 'src/common/typings';
import { DataType, ROLE } from 'src/common/typings';
import { calculateTokenUsage, foldDatasetByYField } from 'src/common/utils/utils';
import { FOLD_NAME, FOLD_VALUE } from '@visactor/chart-advisor';
Expand Down Expand Up @@ -62,8 +63,8 @@ const patchColorField: Transformer<PatchContext, Partial<GenerateFieldMapOutput>
cellNew.color = undefined;
if (['BAR CHART', 'LINE CHART', 'DUAL AXIS CHART'].includes(chartTypeNew)) {
cellNew.y = [cellNew.y, color].flat();
if (chartTypeNew === 'DUAL AXIS CHART' && cellNew.y.length > 2) {
chartTypeNew = 'BAR CHART';
if (chartTypeNew === ChartType.DualAxisChart.toUpperCase() && cellNew.y.length > 2) {
chartTypeNew = ChartType.BarChart.toUpperCase() as ChartType;
}
}
}
Expand All @@ -77,7 +78,7 @@ const patchColorField: Transformer<PatchContext, Partial<GenerateFieldMapOutput>
const patchRadarChart: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const { chartType, cell } = context;

if (chartType === 'RADAR CHART') {
if (chartType === ChartType.RadarChart.toUpperCase()) {
const cellNew = {
x: cell.angle,
y: cell.value,
Expand All @@ -94,7 +95,7 @@ const patchRadarChart: Transformer<PatchContext, Partial<GenerateFieldMapOutput>
const patchBoxPlot: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const { chartType, cell } = context;

if (chartType === 'BOX PLOT') {
if (chartType === ChartType.BoxPlot.toUpperCase()) {
const { x, min, q1, median, q3, max } = cell as any;
const cellNew = {
x,
Expand All @@ -107,12 +108,16 @@ const patchBoxPlot: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> =
return context;
};

const patchBarChart: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const patchFoldField: Transformer<PatchContext, Partial<GenerateFieldMapOutput>> = (context: PatchContext) => {
const { chartType, cell, fieldInfo, dataset } = context;
const chartTypeNew = chartType;
const cellNew = { ...cell };
let datasetNew = dataset;
if (chartTypeNew === 'BAR CHART' || chartTypeNew === 'LINE CHART') {
if (
chartTypeNew === ChartType.BarChart.toUpperCase() ||
chartTypeNew === ChartType.LineChart.toUpperCase() ||
chartTypeNew === ChartType.RadarChart.toUpperCase()
) {
if (isValidDataset(datasetNew) && isArray(cellNew.y) && cellNew.y.length > 1) {
datasetNew = foldDatasetByYField(datasetNew, cellNew.y, fieldInfo);
cellNew.y = FOLD_VALUE.toString();
Expand All @@ -131,7 +136,7 @@ const patchDualAxisChart: Transformer<PatchContext, Partial<GenerateFieldMapOutp
const cellNew: any = { ...cell };
//Dual-axis drawing yLeft and yRight

if (chartType === 'DUAL AXIS CHART') {
if (chartType === ChartType.DualAxisChart.toUpperCase()) {
cellNew.y = [
...(isArray(cellNew.y) ? cellNew.y : []),
cellNew.leftAxis,
Expand All @@ -151,7 +156,7 @@ const patchDynamicBarChart: Transformer<PatchContext, Partial<GenerateFieldMapOu
};
let chartTypeNew = chartType;

if (chartType === 'DYNAMIC BAR CHART') {
if (chartType === ChartType.DynamicBarChart.toUpperCase()) {
if (!cellNew.time || cellNew.time === '' || cellNew.time.length === 0) {
const flattenedXField = Array.isArray(cellNew.x) ? cellNew.x : [cellNew.x];
const usedFields = Object.values(cellNew).filter(f => !Array.isArray(f));
Expand All @@ -172,7 +177,7 @@ const patchDynamicBarChart: Transformer<PatchContext, Partial<GenerateFieldMapOu
cellNew.time = stringField.fieldName;
} else {
//no available field, set chart type to bar chart
chartTypeNew = 'BAR CHART';
chartTypeNew = ChartType.BarChart.toUpperCase() as ChartType;
}
}
}
Expand Down Expand Up @@ -217,7 +222,7 @@ export const patchPipelines: Transformer<PatchContext, Partial<GenerateFieldMapO
patchColorField,
patchRadarChart,
patchBoxPlot,
patchBarChart,
patchFoldField,
patchDualAxisChart,
patchDynamicBarChart,
patchArrayField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { isArray, isNil } from 'lodash';

import type { Transformer } from 'src/base/tools/transformer';
import { foldDatasetByYField, getFieldByDataType, getFieldByRole, getRemainedFields } from 'src/common/utils/utils';
import type { ChartType } from 'src/common/typings';
import { ChartType } from 'src/common/typings';
import { DataType, ROLE } from 'src/common/typings';
import type { GenerateChartAndFieldMapContext, GenerateChartAndFieldMapOutput } from '../../types';
import { isValidDataset } from 'src/common/dataProcess';
Expand Down Expand Up @@ -99,9 +99,10 @@ export const patchYField: Transformer<
}

if (
chartTypeNew === ('BAR CHART' as ChartType) ||
chartTypeNew === ('LINE CHART' as ChartType) ||
chartTypeNew === ('DUAL AXIS CHART' as ChartType)
chartTypeNew === ChartType.BarChart.toUpperCase() ||
chartTypeNew === ChartType.LineChart.toUpperCase() ||
chartTypeNew === ChartType.DualAxisChart.toUpperCase() ||
chartTypeNew === ChartType.RadarChart.toUpperCase()
) {
//use fold to visualize more than 2 y fields
if (isValidDataset(datasetNew)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import { generateRandomString } from './utils';

export const alasqlKeywordList = [
'ABSOLUTE',
'ACTION',
Expand Down Expand Up @@ -181,17 +179,3 @@ export const alasqlKeywordList = [
'WITH',
'WORK'
];

export const operatorList = [
['+', `_${generateRandomString(3)}_PLUS_${generateRandomString(3)}_`],
['-', `_${generateRandomString(3)}_DASH_${generateRandomString(3)}_`],
['*', `_${generateRandomString(3)}_ASTERISK_${generateRandomString(3)}_`],
['/', `_${generateRandomString(3)}_SLASH_${generateRandomString(3)}_`]
];

export const operators = operatorList.map(op => op[0]);

export const RESERVE_REPLACE_MAP = new Map<string, string>([
...operatorList,
...(alasqlKeywordList.map(keyword => [keyword, generateRandomString(10)]) as any)
]);
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import { DataType, ROLE } from '../../../../common/typings';
import dayjs from 'dayjs';
import { uniqArray } from '@visactor/vutils';
import alasql from 'alasql';
import { RESERVE_REPLACE_MAP, operators } from './constants';

import { replaceAll } from 'src/common/utils/utils';
import { alasqlKeywordList } from './constants';

export const readTopNLine = (csvFile: string, n: number) => {
// get top n lines of a csv file
Expand Down Expand Up @@ -134,6 +135,20 @@ export function generateRandomString(len: number) {
return result;
}

const operatorList = [
['+', `_${generateRandomString(3)}_PLUS_${generateRandomString(3)}_`],
['-', `_${generateRandomString(3)}_DASH_${generateRandomString(3)}_`],
['*', `_${generateRandomString(3)}_ASTERISK_${generateRandomString(3)}_`],
['/', `_${generateRandomString(3)}_SLASH_${generateRandomString(3)}_`]
];

const operators = operatorList.map(op => op[0]);

const RESERVE_REPLACE_MAP = new Map<string, string>([
...operatorList,
...(alasqlKeywordList.map(keyword => [keyword, generateRandomString(10)]) as any)
]);

export const swapMap = (map: Map<string, string>) => {
//swap the map
const swappedMap = new Map();
Expand All @@ -150,7 +165,7 @@ export const swapMap = (map: Map<string, string>) => {
* @param str
* @returns
*/
export const replaceNonASCIICharacters = (str: string) => {
const replaceNonASCIICharacters = (str: string) => {
const nonAsciiCharMap = new Map();

const newStr = str.replace(/([^\x00-\x7F]+)/g, m => {
Expand Down
6 changes: 3 additions & 3 deletions packages/vmind/src/core/VMind.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class VMind {
private _applicationMap: VMindApplicationMap;

constructor(options?: ILLMOptions) {
this._options = { ...(options ?? {}), showThoughts: options.showThoughts ?? true }; //apply default settings
this._model = options.model ?? Model.GPT3_5;
this._options = { ...(options ?? {}), showThoughts: options?.showThoughts ?? true }; //apply default settings
this._model = options?.model ?? Model.GPT3_5;
this.registerApplications();
}

Expand Down Expand Up @@ -133,7 +133,7 @@ class VMind {
let finalFieldInfo = fieldInfo;

let queryDatasetUsage;
const { enableDataQuery, colorPalette, animationDuration, chartTypeList } = options;
const { enableDataQuery, colorPalette, animationDuration, chartTypeList } = options ?? {};
try {
if (!isNil(dataset) && (isNil(enableDataQuery) || enableDataQuery) && modelType !== ModelType.CHART_ADVISOR) {
//run data aggregation first
Expand Down

0 comments on commit 00b7c84

Please sign in to comment.