Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"@google/genai": "1.32.0",
"@modelcontextprotocol/sdk": "^1.24.0",
"google-auth-library": "^10.3.0",
"zod": "3.25.76"
"zod": "^4.2.1"
},
"devDependencies": {
"openapi-types": "^12.1.3"
Expand Down
38 changes: 18 additions & 20 deletions core/src/tools/function_tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,33 @@
* SPDX-License-Identifier: Apache-2.0
*/

import {FunctionDeclaration, Schema, Type} from '@google/genai';
import {
type infer as zInfer,
ZodObject,
type ZodRawShape,
} from 'zod';
import { FunctionDeclaration, Schema, Type } from '@google/genai';
import { z as z3 } from 'zod/v3';
import { z as z4 } from 'zod/v4';

import {isZodObject, zodObjectToSchema} from '../utils/simple_zod_to_json.js';
import { isZodObject, zodObjectToSchema } from '../utils/simple_zod_to_json.js';

import {BaseTool, RunAsyncToolRequest} from './base_tool.js';
import {ToolContext} from './tool_context.js';
import { BaseTool, RunAsyncToolRequest } from './base_tool.js';
import { ToolContext } from './tool_context.js';

/**
* Input parameters of the function tool.
*/
export type ToolInputParameters =
| undefined
| ZodObject<ZodRawShape>
| Schema;
export type ToolInputParameters = | undefined | z3.ZodObject<z3.ZodRawShape> | z4.ZodObject | Schema;

type ZodObject<T extends Record<string, any>> = z3.ZodObject<z3.ZodRawShape> | z4.ZodObject<T>;

/*
* The arguments of the function tool.
*/
export type ToolExecuteArgument<TParameters extends ToolInputParameters> =
TParameters extends ZodObject<infer T, infer U, infer V>
? zInfer<ZodObject<T, U, V>>
: TParameters extends Schema
? unknown
: string;
TParameters extends z3.ZodObject<infer T, infer U, infer V>
? z3.infer<z3.ZodObject<T, U, V>>
: TParameters extends z4.ZodObject<infer T>
? z4.infer<z4.ZodObject<T>>
: TParameters extends Schema
? unknown
: string;

/*
* The function to execute by the tool.
Expand Down Expand Up @@ -65,7 +63,7 @@ export type ToolOptions<
function toSchema<TParameters extends ToolInputParameters>(
parameters: TParameters): Schema {
if (parameters === undefined) {
return {type: Type.OBJECT, properties: {}};
return { type: Type.OBJECT, properties: {} };
}

if (isZodObject(parameters)) {
Expand Down Expand Up @@ -120,7 +118,7 @@ export class FunctionTool<
override async runAsync(req: RunAsyncToolRequest): Promise<unknown> {
try {
let validatedArgs: unknown = req.args;
if (this.parameters instanceof ZodObject) {
if (isZodObject(this.parameters)) {
validatedArgs = this.parameters.parse(req.args);
}
return await this.execute(
Expand Down
2 changes: 1 addition & 1 deletion core/src/utils/gemini_schema_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {z} from 'zod';

const MCPToolSchema = z.object({
type: z.literal('object'),
properties: z.record(z.unknown()).optional(),
properties: z.record(z.string(), z.unknown()).optional(),
required: z.string().array().optional(),
});
type MCPToolSchema = z.infer<typeof MCPToolSchema>;
Expand Down
160 changes: 107 additions & 53 deletions core/src/utils/simple_zod_to_json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,57 @@
* SPDX-License-Identifier: Apache-2.0
*/

import {Schema, Type} from '@google/genai';
import {z, ZodObject, ZodTypeAny} from 'zod';
import { Schema, Type } from '@google/genai';
import { z as z3 } from 'zod/v3';
import { z as z4, toJSONSchema as toJSONSchemaV4 } from 'zod/v4';

type ZodSchema<T = any> = z3.ZodType<T> | z4.ZodType<T>;

type SchemaLike = ZodSchema | Schema;

function isZodSchema(obj: unknown): obj is ZodSchema {
return (
obj !== null &&
typeof obj === "object" &&
"parse" in obj &&
typeof (obj as { parse: unknown }).parse === "function" &&
"safeParse" in obj &&
typeof (obj as { safeParse: unknown }).safeParse === "function"
);
}

function isZodV3Schema(obj: unknown): obj is z3.ZodTypeAny {
return isZodSchema(obj) && !("_zod" in obj);
}

function isZodV4Schema(obj: unknown): obj is z4.ZodType {
return isZodSchema(obj) && "_zod" in obj;
}

function getZodTypeName(schema: z3.ZodTypeAny | z4.ZodType): string | undefined {
const schemaAny = schema as any;

if (schemaAny._def?.typeName) {
return schemaAny._def.typeName;
}

const zod4Type = schemaAny._def?.type;
if (typeof zod4Type === 'string' && zod4Type) {
return 'Zod' + zod4Type.charAt(0).toUpperCase() + zod4Type.slice(1);
}

return undefined;
}

/**
* Returns true if the given object is a V3 ZodObject.
* Returns true if the given object is a ZodObject (supports both Zod v3 and v4).
*/
export function isZodObject(obj: unknown): obj is ZodObject<any> {
return (
obj !== null && typeof obj === 'object' &&
(obj as any)._def?.typeName === 'ZodObject');
export function isZodObject(obj: unknown): obj is z3.ZodObject<any> | z4.ZodObject<any> {
return isZodSchema(obj) && getZodTypeName(obj) === 'ZodObject';
}

// TODO(b/425992518): consider conversion to FunctionDeclaration directly.

function parseZodType(zodType: ZodTypeAny): Schema|undefined {
function parseZodV3Type(zodType: z3.ZodTypeAny): Schema | undefined {
const def = zodType._def;
if (!def) {
return {};
Expand All @@ -35,7 +71,7 @@ function parseZodType(zodType: ZodTypeAny): Schema|undefined {
};

switch (def.typeName) {
case z.ZodFirstPartyTypeKind.ZodString:
case z3.ZodFirstPartyTypeKind.ZodString:
result.type = Type.STRING;
for (const check of def.checks || []) {
if (check.kind === 'min')
Expand All @@ -53,7 +89,7 @@ function parseZodType(zodType: ZodTypeAny): Schema|undefined {
}
return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodNumber:
case z3.ZodFirstPartyTypeKind.ZodNumber:
result.type = Type.NUMBER;
for (const check of def.checks || []) {
if (check.kind === 'min')
Expand All @@ -65,23 +101,23 @@ function parseZodType(zodType: ZodTypeAny): Schema|undefined {
}
return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodBoolean:
case z3.ZodFirstPartyTypeKind.ZodBoolean:
result.type = Type.BOOLEAN;
return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodArray:
case z3.ZodFirstPartyTypeKind.ZodArray:
result.type = Type.ARRAY;
result.items = parseZodType(def.type);
result.items = parseZodV3Type(def.type);
if (def.minLength) result.minItems = def.minLength.value.toString();
if (def.maxLength) result.maxItems = def.maxLength.value.toString();
return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodObject: {
const nestedSchema = zodObjectToSchema(zodType as ZodObject<any>);
case z3.ZodFirstPartyTypeKind.ZodObject: {
const nestedSchema = zodObjectToSchema(zodType as z3.ZodObject<any>);
return nestedSchema as Schema;
}

case z.ZodFirstPartyTypeKind.ZodLiteral:
case z3.ZodFirstPartyTypeKind.ZodLiteral:
const literalType = typeof def.value;
result.enum = [def.value.toString()];

Expand All @@ -99,71 +135,67 @@ function parseZodType(zodType: ZodTypeAny): Schema|undefined {

return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodEnum:
case z3.ZodFirstPartyTypeKind.ZodEnum:
result.type = Type.STRING;
result.enum = def.values;
return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodNativeEnum:
case z3.ZodFirstPartyTypeKind.ZodNativeEnum:
result.type = Type.STRING;
result.enum = Object.values(def.values);
return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodUnion:
result.anyOf = def.options.map(parseZodType);
case z3.ZodFirstPartyTypeKind.ZodUnion:
result.anyOf = def.options.map(parseZodV3Type);
return returnResult(result);

case z.ZodFirstPartyTypeKind.ZodOptional:
return parseZodType(def.innerType);
case z.ZodFirstPartyTypeKind.ZodNullable:
const nullableInner = parseZodType(def.innerType);
case z3.ZodFirstPartyTypeKind.ZodOptional:
return parseZodV3Type(def.innerType);
case z3.ZodFirstPartyTypeKind.ZodNullable:
const nullableInner = parseZodV3Type(def.innerType);
return nullableInner ?
returnResult({
anyOf: [nullableInner, {type: Type.NULL}],
...(description && {description})
}) :
returnResult({type: Type.NULL, ...(description && {description})});
case z.ZodFirstPartyTypeKind.ZodDefault:
const defaultInner = parseZodType(def.innerType);
returnResult({
anyOf: [nullableInner, { type: Type.NULL }],
...(description && { description })
}) :
returnResult({ type: Type.NULL, ...(description && { description }) });
case z3.ZodFirstPartyTypeKind.ZodDefault:
const defaultInner = parseZodV3Type(def.innerType);
if (defaultInner) defaultInner.default = def.defaultValue();
return defaultInner;
case z.ZodFirstPartyTypeKind.ZodBranded:
return parseZodType(def.type);
case z.ZodFirstPartyTypeKind.ZodReadonly:
return parseZodType(def.innerType);
case z.ZodFirstPartyTypeKind.ZodNull:
case z3.ZodFirstPartyTypeKind.ZodBranded:
return parseZodV3Type(def.type);
case z3.ZodFirstPartyTypeKind.ZodReadonly:
return parseZodV3Type(def.innerType);
case z3.ZodFirstPartyTypeKind.ZodNull:
result.type = Type.NULL;
return returnResult(result);
case z.ZodFirstPartyTypeKind.ZodAny:
case z.ZodFirstPartyTypeKind.ZodUnknown:
return returnResult({...(description && {description})});
case z3.ZodFirstPartyTypeKind.ZodAny:
case z3.ZodFirstPartyTypeKind.ZodUnknown:
return returnResult({ ...(description && { description }) });
default:
throw new Error(`Unsupported Zod type: ${def.typeName}`);
}
}

export function zodObjectToSchema(schema: ZodObject<any>): Schema {
if (schema._def.typeName !== z.ZodFirstPartyTypeKind.ZodObject) {
throw new Error('Expected a ZodObject');
}

function toJsonSchemaZ3(schema: z3.ZodObject<z3.ZodRawShape>): Schema {
const shape = schema.shape;
const properties: Record<string, Schema> = {};
const required: string[] = [];

for (const key in shape) {
const fieldSchema = shape[key];
const parsedField = parseZodType(fieldSchema);
const parsedField = parseZodV3Type(fieldSchema);
if (parsedField) {
properties[key] = parsedField;
}

let currentSchema = fieldSchema;
let isOptional = false;
while (currentSchema._def.typeName ===
z.ZodFirstPartyTypeKind.ZodOptional ||
currentSchema._def.typeName === z.ZodFirstPartyTypeKind.ZodDefault) {
isOptional = true;
z3.ZodFirstPartyTypeKind.ZodOptional ||
currentSchema._def.typeName === z3.ZodFirstPartyTypeKind.ZodDefault) {
isOptional = true;
currentSchema = currentSchema._def.innerType;
}
if (!isOptional) {
Expand All @@ -172,16 +204,38 @@ export function zodObjectToSchema(schema: ZodObject<any>): Schema {
}

const catchall = schema._def.catchall;
let additionalProperties: boolean|Schema = false;
if (catchall && catchall._def.typeName !== z.ZodFirstPartyTypeKind.ZodNever) {
additionalProperties = parseZodType(catchall) || true;
let additionalProperties: boolean | Schema = false;
if (catchall && catchall._def.typeName !== z3.ZodFirstPartyTypeKind.ZodNever) {
additionalProperties = parseZodV3Type(catchall) || true;
} else {
additionalProperties = schema._def.unknownKeys === 'passthrough';
}
return {
type: Type.OBJECT,
properties,
required: required.length > 0 ? required : [],
...(schema._def.description ? {description: schema._def.description} : {}),
...(schema._def.description ? { description: schema._def.description } : {}),
};
}

export function zodObjectToSchema(schema: z3.ZodObject<z3.ZodRawShape> | z4.ZodObject<z4.ZodRawShape>): Schema {
if (!isZodObject(schema)) {
throw new Error('Expected a Zod Object');
}

if (isZodV4Schema(schema)) {
return toJSONSchemaV4(schema, {
target: 'openapi-3.0', override(ctx) {
if (ctx.jsonSchema.additionalProperties !== undefined) {
delete ctx.jsonSchema.additionalProperties;
}
},
}) as Schema;
}

if (isZodV3Schema(schema)) {
return toJsonSchemaZ3(schema);
}

throw new Error('Unsupported Zod schema version.');
}
Loading
Loading