Skip to content

Commit

Permalink
feat(ai-model): remove dom info in assertion to make it reliable (#284)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: zhouxiao.shaw <zhouxiao.shaw@bytedance.com>
  • Loading branch information
yuyutaotao and zhoushaw authored Jan 16, 2025
1 parent 69ce6ec commit 4cad2e1
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 18 deletions.
1 change: 1 addition & 0 deletions packages/midscene/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"test:ai": "AITEST=true npm run test",
"computer": "TEST_COMPUTER=true npm run test:ai -- tests/ai/evaluate/computer.test.ts",
"evaluate": "npm run test:ai -- tests/ai/evaluate/inspect.test.ts",
"evaluate:assertion": "npm run test:ai -- tests/ai/evaluate/assertion.test.ts",
"prompt": "npm run test:ai -- tests/ai/parse-action.test.ts",
"evaluate:update": "UPDATE_AI_DATA=true npm run test:ai -- tests/ai/evaluate/inspect.test.ts",
"prepublishOnly": "npm run build"
Expand Down
10 changes: 4 additions & 6 deletions packages/midscene/src/ai-model/inspect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,10 @@ export async function AiAssert<
{
type: 'text',
text: `
pageDescription: \n
${description}
Here is the description of the assertion. Just go ahead:
=====================================
${assertion}
=====================================
Here is the description of the assertion. Just go ahead:
=====================================
${assertion}
=====================================
`,
},
],
Expand Down
7 changes: 2 additions & 5 deletions packages/midscene/src/ai-model/prompt/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ DATA_DEMAND start:
{dataKeys}
{dataQuery}
=====================================
DATA_DEMAND ends.
`,
Expand Down Expand Up @@ -117,14 +116,12 @@ export const extractDataSchema: ResponseFormatJSONSchema = {
/**
 * Builds the system prompt sent to the model for assertion requests.
 *
 * The returned template string interpolates `characteristic` (and, as listed
 * here, `contextFormatIntro`), tells the model that the user will supply an
 * assertion plus page information, and asks for a JSON answer carrying a
 * `thought` string and a boolean `pass`.
 *
 * NOTE(review): this listing appears to interleave removed and added diff
 * lines (two instruction sentences and two `pass:` lines) — confirm the exact
 * wording against the checked-out file before relying on it.
 *
 * @returns The system prompt text.
 */
export function systemPromptToAssert() {
return `
${characteristic}
${contextFormatIntro}
Based on the information you get, Return assertion judgment:
User will give an assertion, and some information about the page. Based on the information you get, tell whether the assertion is truthy.
Return in the following JSON format:
{
thought: string, // string, the thought of the assertion. Should in the same language as the assertion.
pass: true, // true or false, whether the assertion is passed
pass: true, // true or false, whether the assertion is truthy
}
`;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"testDataPath": "test-data/online_order",
"testCases": [
{
"prompt": "the 'select option' button is yellow",
"expected": true
},
{
"prompt": "there are three tabs in the page, named 'Menu', 'Reviews', 'Merchant'",
"expected": true
},
{
"prompt": "the 'select option' button is blue",
"expected": false
},
{
"prompt": "there are three tabs in the page, named 'Home', 'Order', 'Profile'",
"expected": false
},
{
"prompt": "there is a shopping bag icon on the top right of the page",
"expected": true
},
{
"prompt": "The shopping bag icon on the top left of the page",
"expected": false
},
{
"prompt": "There is a homepage icon on the top right of the page",
"expected": false
}
]
}
80 changes: 80 additions & 0 deletions packages/midscene/tests/ai/evaluate/assertion.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { readFileSync } from 'node:fs';
import path from 'node:path';
import { describe } from 'node:test';
import { AiAssert } from '@/ai-model';
import { afterAll, expect, test } from 'vitest';
import {
type InspectAiTestCase,
getPageTestData,
repeatFile,
} from './test-suite/util';
import 'dotenv/config';

/** How many times every source file is evaluated. */
const repeatTime = 2;

/**
 * Names of the ai-data files (without the `.json` extension) to evaluate.
 * Un-comment an entry to include that source in the run.
 */
const testSources: Array<string> = [
  // 'todo',
  'online_order',
  // 'online_order_list',
  // 'taobao',
  // 'aweme_login',
  // 'aweme_play',
];

// NOTE(review): `describe` is imported from 'node:test' at the top of this
// file while `test`, `expect` and `afterAll` come from 'vitest' — presumably
// `describe` should come from 'vitest' as well; verify and fix the import.
describe('ai assertion', () => {
  /** Running totals for one ai-data file, accumulated over every repeat. */
  type SourceStats = {
    successCount: number;
    failCount: number;
    totalTimeMs: number;
  };

  // Keyed by the ai-data file path so repeated runs of one source share a row.
  const statsByPath = new Map<string, SourceStats>();

  /** Records the outcome and wall-clock duration of a single test case. */
  const record = (aiDataPath: string, passed: boolean, spentMs: number) => {
    const stats = statsByPath.get(aiDataPath) ?? {
      successCount: 0,
      failCount: 0,
      totalTimeMs: 0,
    };
    if (passed) {
      stats.successCount += 1;
    } else {
      stats.failCount += 1;
    }
    stats.totalTimeMs += spentMs;
    statsByPath.set(aiDataPath, stats);
  };

  // Print one summary row per ai-data file once every case has finished.
  afterAll(async () => {
    const testResult: {
      path: string;
      result: {
        score: number;
        averageTime: string;
        successCount: number;
        failCount: number;
      };
    }[] = [];

    for (const [aiDataPath, stats] of statsByPath) {
      // A path is only recorded after at least one case ran, so total >= 1.
      const total = stats.successCount + stats.failCount;
      testResult.push({
        path: aiDataPath,
        result: {
          // Percentage of cases whose verdict matched `expected`.
          score: (stats.successCount / total) * 100,
          averageTime: `${(stats.totalTimeMs / total / 1000).toFixed(2)}s`,
          successCount: stats.successCount,
          failCount: stats.failCount,
        },
      });
    }

    console.table(
      testResult.map((r) => {
        return {
          path: r.path,
          ...r.result,
        };
      }),
    );
  });

  repeatFile(testSources, repeatTime, (source, repeatIndex) => {
    const aiDataPath = path.join(__dirname, `ai-data/assertion/${source}.json`);
    const aiData = JSON.parse(
      readFileSync(aiDataPath, 'utf-8'),
    ) as InspectAiTestCase;

    for (const testCase of aiData.testCases) {
      const { prompt, expected } = testCase;
      test(
        `${source}-${repeatIndex}: assertion-${prompt.slice(0, 30)}...`,
        async () => {
          const startTime = Date.now();
          let passed = false;
          try {
            const { context } = await getPageTestData(
              path.join(__dirname, aiData.testDataPath),
            );

            const result = await AiAssert({
              assertion: prompt,
              context,
            });

            expect(typeof result?.content?.pass).toBe('boolean');
            if (result?.content?.pass !== expected) {
              throw new Error(
                `assertion failed: ${prompt} expected: ${expected}, actual: ${result?.content?.pass}, thought: ${result?.content?.thought}`,
              );
            }

            passed = true;
            console.log('assertion passed, thought:', result?.content?.thought);
          } finally {
            // Count thrown errors (model failure, mismatch) as failures too.
            record(aiDataPath, passed, Date.now() - startTime);
          }
        },
        // Numeric timeout in ms; accepted as the third argument by Vitest.
        30 * 1000,
      );
    }
  });
});
22 changes: 15 additions & 7 deletions packages/midscene/tests/ai/evaluate/test-suite/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import { base64Encoded, imageInfoOfBase64 } from '@/image';

/** One entry of the `testCases` array in an ai-data JSON file. */
type TestCase = {
// Natural-language input: passed as `description` to the inspect callback
// and as `assertion` to `AiAssert` in the assertion evaluation.
prompt: string;
// NOTE(review): only read via `Boolean(testCase.multi)` in lines that look
// like removed diff lines — confirm whether this field still exists.
multi?: boolean;
// Element references as `{ id, indexId }` pairs — presumably the expected
// elements for inspect cases; verify against the inspect evaluation.
response: Array<{ id: string; indexId: number }>;
// Expected verdict for assertion cases; compared against the model's `pass`.
expected?: boolean;
};

export type InspectAiTestCase = {
Expand All @@ -33,7 +33,6 @@ export interface AiElementsResponse {
}

export interface TextAiElementResponse extends AiElementsResponse {
multi: boolean;
response: Array<
| {
id: string;
Expand All @@ -58,7 +57,6 @@ export async function runTestCases(
context: any,
getAiResponse: (options: {
description: string;
multi: boolean;
}) => Promise<AiElementsResponse>,
) {
let aiResponse: Array<TextAiElementResponse> = [];
Expand All @@ -68,7 +66,6 @@ export async function runTestCases(
const startTime = Date.now();
const msg = await getAiResponse({
description: testCase.prompt,
multi: Boolean(testCase.multi),
});
const endTime = Date.now();
const spendTime = endTime - startTime;
Expand All @@ -77,7 +74,6 @@ export async function runTestCases(
...msg,
prompt: testCase.prompt,
response: msg.elements,
multi: Boolean(testCase.multi),
caseIndex,
spendTime,
elementsSnapshot: msg.elements.map((element) => {
Expand Down Expand Up @@ -135,6 +131,18 @@ export const repeat = (times: number, fn: (index: number) => void) => {
}
};

/**
 * Walks `files` in order and, for each entry, hands `times` plus a per-file
 * callback to `repeat`; that callback forwards the current file and the index
 * it receives to `fn`.
 *
 * @param files - File identifiers to iterate over.
 * @param times - Repeat count passed through to `repeat` for every file.
 * @param fn - Callback receiving the current file and the index from `repeat`.
 */
export const repeatFile = (
  files: Array<string>,
  times: number,
  fn: (file: string, index: number) => void,
) => {
  for (let position = 0; position < files.length; position += 1) {
    const currentFile = files[position];
    const runRound = (round: number): void => {
      fn(currentFile, round);
    };
    repeat(times, runRound);
  }
};

function ensureDirectoryExistence(filePath: string) {
const dirname = path.dirname(filePath);
if (existsSync(dirname)) {
Expand Down Expand Up @@ -172,12 +180,12 @@ export async function getPageTestData(targetDir: string): Promise<{
}> {
// Note: this is the magic
const resizeOutputImgP = path.join(targetDir, 'output_without_text.png');
const originalInputputImgP = path.join(targetDir, 'input.png');
const originalInputImgP = path.join(targetDir, 'input.png');
const snapshotJsonPath = path.join(targetDir, 'element-snapshot.json');
const snapshotJson = readFileSync(snapshotJsonPath, { encoding: 'utf-8' });
const elementSnapshot = JSON.parse(snapshotJson);
const screenshotBase64 = base64Encoded(resizeOutputImgP);
const originalScreenshotBase64 = base64Encoded(originalInputputImgP);
const originalScreenshotBase64 = base64Encoded(originalInputImgP);
const size = await imageInfoOfBase64(screenshotBase64);
const baseContext = {
size,
Expand Down

0 comments on commit 4cad2e1

Please sign in to comment.