Skip to content

Commit 5c7430d

Browse files
authored
fix(ai): Fix generateContentStream returning wrong inferenceSource. (#9381)
1 parent f5fc6bf commit 5c7430d

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

.changeset/gorgeous-rice-carry.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@firebase/ai': patch
3+
---
4+
5+
Fix `generateContentStream` returning wrong `inferenceSource`.

packages/ai/src/methods/generate-content.test.ts

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import {
2626
import * as request from '../requests/request';
2727
import {
2828
generateContent,
29+
generateContentStream,
2930
templateGenerateContent,
3031
templateGenerateContentStream
3132
} from './generate-content';
@@ -35,6 +36,7 @@ import {
3536
HarmBlockMethod,
3637
HarmBlockThreshold,
3738
HarmCategory,
39+
InferenceSource,
3840
Language,
3941
Outcome
4042
} from '../types';
@@ -548,8 +550,7 @@ describe('generateContent()', () => {
548550
);
549551
});
550552
});
551-
// TODO: define a similar test for generateContentStream
552-
it('on-device', async () => {
553+
it('generateContent on-device', async () => {
553554
const chromeAdapter = fakeChromeAdapter;
554555
const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
555556
const mockResponse = getMockResponse(
@@ -566,9 +567,35 @@ describe('generateContent()', () => {
566567
chromeAdapter
567568
);
568569
expect(result.response.text()).to.include('Mountain View, California');
570+
expect(result.response.inferenceSource).to.equal(InferenceSource.ON_DEVICE);
569571
expect(isAvailableStub).to.be.called;
570572
expect(generateContentStub).to.be.calledWith(fakeRequestParams);
571573
});
574+
it('generateContentStream on-device', async () => {
575+
const chromeAdapter = fakeChromeAdapter;
576+
const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
577+
const mockResponse = getMockResponseStreaming(
578+
'vertexAI',
579+
'streaming-success-basic-reply-short.txt'
580+
);
581+
const generateContentStreamStub = stub(
582+
chromeAdapter,
583+
'generateContentStream'
584+
).resolves(mockResponse as Response);
585+
const result = await generateContentStream(
586+
fakeApiSettings,
587+
'model',
588+
fakeRequestParams,
589+
chromeAdapter
590+
);
591+
const aggregatedResponse = await result.response;
592+
expect(aggregatedResponse.text()).to.include('Cheyenne');
593+
expect(aggregatedResponse.inferenceSource).to.equal(
594+
InferenceSource.ON_DEVICE
595+
);
596+
expect(isAvailableStub).to.be.called;
597+
expect(generateContentStreamStub).to.be.calledWith(fakeRequestParams);
598+
});
572599
});
573600

574601
describe('templateGenerateContent', () => {

packages/ai/src/methods/generate-content.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ export async function generateContentStream(
7070
() =>
7171
generateContentStreamOnCloud(apiSettings, model, params, requestOptions)
7272
);
73-
return processStream(callResult.response, apiSettings); // TODO: Map streaming responses
73+
return processStream(
74+
callResult.response,
75+
apiSettings,
76+
callResult.inferenceSource
77+
);
7478
}
7579

7680
async function generateContentOnCloud(

0 commit comments

Comments (0)