Skip to content

Commit 8f95da4

Browse files
Erick Friis and ccurme authored
multiple: structured output tracing standard metadata (#29421)
Co-authored-by: Chester Curme <chester.curme@gmail.com>
1 parent 284c935 commit 8f95da4

File tree

9 files changed

+288
-28
lines changed

9 files changed

+288
-28
lines changed

libs/core/langchain_core/language_models/chat_models.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,11 +365,28 @@ def stream(
365365
else:
366366
config = ensure_config(config)
367367
messages = self._convert_input(input).to_messages()
368+
structured_output_format = kwargs.pop("structured_output_format", None)
369+
if structured_output_format:
370+
try:
371+
structured_output_format_dict = {
372+
"structured_output_format": {
373+
"kwargs": structured_output_format.get("kwargs", {}),
374+
"schema": convert_to_openai_tool(
375+
structured_output_format["schema"]
376+
),
377+
}
378+
}
379+
except ValueError:
380+
structured_output_format_dict = {}
381+
else:
382+
structured_output_format_dict = {}
383+
368384
params = self._get_invocation_params(stop=stop, **kwargs)
369385
options = {"stop": stop, **kwargs}
370386
inheritable_metadata = {
371387
**(config.get("metadata") or {}),
372388
**self._get_ls_params(stop=stop, **kwargs),
389+
**structured_output_format_dict,
373390
}
374391
callback_manager = CallbackManager.configure(
375392
config.get("callbacks"),
@@ -441,11 +458,29 @@ async def astream(
441458

442459
config = ensure_config(config)
443460
messages = self._convert_input(input).to_messages()
461+
462+
structured_output_format = kwargs.pop("structured_output_format", None)
463+
if structured_output_format:
464+
try:
465+
structured_output_format_dict = {
466+
"structured_output_format": {
467+
"kwargs": structured_output_format.get("kwargs", {}),
468+
"schema": convert_to_openai_tool(
469+
structured_output_format["schema"]
470+
),
471+
}
472+
}
473+
except ValueError:
474+
structured_output_format_dict = {}
475+
else:
476+
structured_output_format_dict = {}
477+
444478
params = self._get_invocation_params(stop=stop, **kwargs)
445479
options = {"stop": stop, **kwargs}
446480
inheritable_metadata = {
447481
**(config.get("metadata") or {}),
448482
**self._get_ls_params(stop=stop, **kwargs),
483+
**structured_output_format_dict,
449484
}
450485
callback_manager = AsyncCallbackManager.configure(
451486
config.get("callbacks"),
@@ -606,11 +641,28 @@ def generate(
606641
An LLMResult, which contains a list of candidate Generations for each input
607642
prompt and additional model provider-specific output.
608643
"""
644+
structured_output_format = kwargs.pop("structured_output_format", None)
645+
if structured_output_format:
646+
try:
647+
structured_output_format_dict = {
648+
"structured_output_format": {
649+
"kwargs": structured_output_format.get("kwargs", {}),
650+
"schema": convert_to_openai_tool(
651+
structured_output_format["schema"]
652+
),
653+
}
654+
}
655+
except ValueError:
656+
structured_output_format_dict = {}
657+
else:
658+
structured_output_format_dict = {}
659+
609660
params = self._get_invocation_params(stop=stop, **kwargs)
610661
options = {"stop": stop}
611662
inheritable_metadata = {
612663
**(metadata or {}),
613664
**self._get_ls_params(stop=stop, **kwargs),
665+
**structured_output_format_dict,
614666
}
615667

616668
callback_manager = CallbackManager.configure(
@@ -697,11 +749,28 @@ async def agenerate(
697749
An LLMResult, which contains a list of candidate Generations for each input
698750
prompt and additional model provider-specific output.
699751
"""
752+
structured_output_format = kwargs.pop("structured_output_format", None)
753+
if structured_output_format:
754+
try:
755+
structured_output_format_dict = {
756+
"structured_output_format": {
757+
"kwargs": structured_output_format.get("kwargs", {}),
758+
"schema": convert_to_openai_tool(
759+
structured_output_format["schema"]
760+
),
761+
}
762+
}
763+
except ValueError:
764+
structured_output_format_dict = {}
765+
else:
766+
structured_output_format_dict = {}
767+
700768
params = self._get_invocation_params(stop=stop, **kwargs)
701769
options = {"stop": stop}
702770
inheritable_metadata = {
703771
**(metadata or {}),
704772
**self._get_ls_params(stop=stop, **kwargs),
773+
**structured_output_format_dict,
705774
}
706775

707776
callback_manager = AsyncCallbackManager.configure(
@@ -1240,7 +1309,12 @@ class AnswerWithJustification(BaseModel):
12401309
if self.bind_tools is BaseChatModel.bind_tools:
12411310
msg = "with_structured_output is not implemented for this model."
12421311
raise NotImplementedError(msg)
1243-
llm = self.bind_tools([schema], tool_choice="any")
1312+
1313+
llm = self.bind_tools(
1314+
[schema],
1315+
tool_choice="any",
1316+
structured_output_format={"kwargs": {}, "schema": schema},
1317+
)
12441318
if isinstance(schema, type) and is_basemodel_subclass(schema):
12451319
output_parser: OutputParserLike = PydanticToolsParser(
12461320
tools=[cast(TypeBaseModel, schema)], first_tool_only=True

libs/partners/anthropic/langchain_anthropic/chat_models.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,9 +1111,13 @@ class AnswerWithJustification(BaseModel):
11111111
Added support for TypedDict class as `schema`.
11121112
11131113
""" # noqa: E501
1114-
1115-
tool_name = convert_to_anthropic_tool(schema)["name"]
1116-
llm = self.bind_tools([schema], tool_choice=tool_name)
1114+
formatted_tool = convert_to_anthropic_tool(schema)
1115+
tool_name = formatted_tool["name"]
1116+
llm = self.bind_tools(
1117+
[schema],
1118+
tool_choice=tool_name,
1119+
structured_output_format={"kwargs": {}, "schema": formatted_tool},
1120+
)
11171121
if isinstance(schema, type) and is_basemodel_subclass(schema):
11181122
output_parser: OutputParserLike = PydanticToolsParser(
11191123
tools=[schema], first_tool_only=True

libs/partners/fireworks/langchain_fireworks/chat_models.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -965,8 +965,16 @@ class AnswerWithJustification(BaseModel):
965965
"schema must be specified when method is 'function_calling'. "
966966
"Received None."
967967
)
968-
tool_name = convert_to_openai_tool(schema)["function"]["name"]
969-
llm = self.bind_tools([schema], tool_choice=tool_name)
968+
formatted_tool = convert_to_openai_tool(schema)
969+
tool_name = formatted_tool["function"]["name"]
970+
llm = self.bind_tools(
971+
[schema],
972+
tool_choice=tool_name,
973+
structured_output_format={
974+
"kwargs": {"method": "function_calling"},
975+
"schema": formatted_tool,
976+
},
977+
)
970978
if is_pydantic_schema:
971979
output_parser: OutputParserLike = PydanticToolsParser(
972980
tools=[schema], # type: ignore[list-item]
@@ -977,7 +985,13 @@ class AnswerWithJustification(BaseModel):
977985
key_name=tool_name, first_tool_only=True
978986
)
979987
elif method == "json_mode":
980-
llm = self.bind(response_format={"type": "json_object"})
988+
llm = self.bind(
989+
response_format={"type": "json_object"},
990+
structured_output_format={
991+
"kwargs": {"method": "json_mode"},
992+
"schema": schema,
993+
},
994+
)
981995
output_parser = (
982996
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
983997
if is_pydantic_schema

libs/partners/groq/langchain_groq/chat_models.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -996,8 +996,16 @@ class AnswerWithJustification(BaseModel):
996996
"schema must be specified when method is 'function_calling'. "
997997
"Received None."
998998
)
999-
tool_name = convert_to_openai_tool(schema)["function"]["name"]
1000-
llm = self.bind_tools([schema], tool_choice=tool_name)
999+
formatted_tool = convert_to_openai_tool(schema)
1000+
tool_name = formatted_tool["function"]["name"]
1001+
llm = self.bind_tools(
1002+
[schema],
1003+
tool_choice=tool_name,
1004+
structured_output_format={
1005+
"kwargs": {"method": "function_calling"},
1006+
"schema": formatted_tool,
1007+
},
1008+
)
10011009
if is_pydantic_schema:
10021010
output_parser: OutputParserLike = PydanticToolsParser(
10031011
tools=[schema], # type: ignore[list-item]
@@ -1008,7 +1016,13 @@ class AnswerWithJustification(BaseModel):
10081016
key_name=tool_name, first_tool_only=True
10091017
)
10101018
elif method == "json_mode":
1011-
llm = self.bind(response_format={"type": "json_object"})
1019+
llm = self.bind(
1020+
response_format={"type": "json_object"},
1021+
structured_output_format={
1022+
"kwargs": {"method": "json_mode"},
1023+
"schema": schema,
1024+
},
1025+
)
10121026
output_parser = (
10131027
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
10141028
if is_pydantic_schema

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,14 @@ class AnswerWithJustification(BaseModel):
931931
)
932932
# TODO: Update to pass in tool name as tool_choice if/when Mistral supports
933933
# specifying a tool.
934-
llm = self.bind_tools([schema], tool_choice="any")
934+
llm = self.bind_tools(
935+
[schema],
936+
tool_choice="any",
937+
structured_output_format={
938+
"kwargs": {"method": "function_calling"},
939+
"schema": schema,
940+
},
941+
)
935942
if is_pydantic_schema:
936943
output_parser: OutputParserLike = PydanticToolsParser(
937944
tools=[schema], # type: ignore[list-item]
@@ -943,7 +950,16 @@ class AnswerWithJustification(BaseModel):
943950
key_name=key_name, first_tool_only=True
944951
)
945952
elif method == "json_mode":
946-
llm = self.bind(response_format={"type": "json_object"})
953+
llm = self.bind(
954+
response_format={"type": "json_object"},
955+
structured_output_format={
956+
"kwargs": {
957+
# this is correct - name difference with mistral api
958+
"method": "json_mode"
959+
},
960+
"schema": schema,
961+
},
962+
)
947963
output_parser = (
948964
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
949965
if is_pydantic_schema
@@ -956,7 +972,13 @@ class AnswerWithJustification(BaseModel):
956972
"Received None."
957973
)
958974
response_format = _convert_to_openai_response_format(schema, strict=True)
959-
llm = self.bind(response_format=response_format)
975+
llm = self.bind(
976+
response_format=response_format,
977+
structured_output_format={
978+
"kwargs": {"method": "json_schema"},
979+
"schema": schema,
980+
},
981+
)
960982

961983
output_parser = (
962984
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]

libs/partners/ollama/langchain_ollama/chat_models.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,8 +1085,16 @@ class AnswerWithJustification(BaseModel):
10851085
"schema must be specified when method is not 'json_mode'. "
10861086
"Received None."
10871087
)
1088-
tool_name = convert_to_openai_tool(schema)["function"]["name"]
1089-
llm = self.bind_tools([schema], tool_choice=tool_name)
1088+
formatted_tool = convert_to_openai_tool(schema)
1089+
tool_name = formatted_tool["function"]["name"]
1090+
llm = self.bind_tools(
1091+
[schema],
1092+
tool_choice=tool_name,
1093+
structured_output_format={
1094+
"kwargs": {"method": method},
1095+
"schema": formatted_tool,
1096+
},
1097+
)
10901098
if is_pydantic_schema:
10911099
output_parser: Runnable = PydanticToolsParser(
10921100
tools=[schema], # type: ignore[list-item]
@@ -1097,7 +1105,13 @@ class AnswerWithJustification(BaseModel):
10971105
key_name=tool_name, first_tool_only=True
10981106
)
10991107
elif method == "json_mode":
1100-
llm = self.bind(format="json")
1108+
llm = self.bind(
1109+
format="json",
1110+
structured_output_format={
1111+
"kwargs": {"method": method},
1112+
"schema": schema,
1113+
},
1114+
)
11011115
output_parser = (
11021116
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
11031117
if is_pydantic_schema
@@ -1111,7 +1125,13 @@ class AnswerWithJustification(BaseModel):
11111125
)
11121126
if is_pydantic_schema:
11131127
schema = cast(TypeBaseModel, schema)
1114-
llm = self.bind(format=schema.model_json_schema())
1128+
llm = self.bind(
1129+
format=schema.model_json_schema(),
1130+
structured_output_format={
1131+
"kwargs": {"method": method},
1132+
"schema": schema,
1133+
},
1134+
)
11151135
output_parser = PydanticOutputParser(pydantic_object=schema)
11161136
else:
11171137
if is_typeddict(schema):
@@ -1126,7 +1146,13 @@ class AnswerWithJustification(BaseModel):
11261146
else:
11271147
# is JSON schema
11281148
response_format = schema
1129-
llm = self.bind(format=response_format)
1149+
llm = self.bind(
1150+
format=response_format,
1151+
structured_output_format={
1152+
"kwargs": {"method": method},
1153+
"schema": response_format,
1154+
},
1155+
)
11301156
output_parser = JsonOutputParser()
11311157
else:
11321158
raise ValueError(

libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def supports_json_mode(self) -> bool:
3131
"Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
3232
)
3333
)
34-
def test_structured_output(self, model: BaseChatModel) -> None:
35-
super().test_structured_output(model)
34+
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
35+
super().test_structured_output(model, schema_type)
3636

3737
@pytest.mark.xfail(
3838
reason=(

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,7 +1390,13 @@ def with_structured_output(
13901390
)
13911391
tool_name = convert_to_openai_tool(schema)["function"]["name"]
13921392
bind_kwargs = self._filter_disabled_params(
1393-
tool_choice=tool_name, parallel_tool_calls=False, strict=strict
1393+
tool_choice=tool_name,
1394+
parallel_tool_calls=False,
1395+
strict=strict,
1396+
structured_output_format={
1397+
"kwargs": {"method": method},
1398+
"schema": schema,
1399+
},
13941400
)
13951401

13961402
llm = self.bind_tools([schema], **bind_kwargs)
@@ -1404,7 +1410,13 @@ def with_structured_output(
14041410
key_name=tool_name, first_tool_only=True
14051411
)
14061412
elif method == "json_mode":
1407-
llm = self.bind(response_format={"type": "json_object"})
1413+
llm = self.bind(
1414+
response_format={"type": "json_object"},
1415+
structured_output_format={
1416+
"kwargs": {"method": method},
1417+
"schema": schema,
1418+
},
1419+
)
14081420
output_parser = (
14091421
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
14101422
if is_pydantic_schema
@@ -1417,7 +1429,13 @@ def with_structured_output(
14171429
"Received None."
14181430
)
14191431
response_format = _convert_to_openai_response_format(schema, strict=strict)
1420-
llm = self.bind(response_format=response_format)
1432+
llm = self.bind(
1433+
response_format=response_format,
1434+
structured_output_format={
1435+
"kwargs": {"method": method},
1436+
"schema": convert_to_openai_tool(schema),
1437+
},
1438+
)
14211439
if is_pydantic_schema:
14221440
output_parser = _oai_structured_outputs_parser.with_types(
14231441
output_type=cast(type, schema)

0 commit comments

Comments (0)