Commit c780e7c
feat(langchain4j): add UsageMetadata extraction from TokenCountEstimator or TokenUsage
Signed-off-by: Rhuan Rocha <rhuan080@gmail.com>

feat(langchain4j): fixing exception treatment

Signed-off-by: Rhuan Rocha <rhuan080@gmail.com>

feat(langchain4j): fixing exception treatment

Signed-off-by: Rhuan Rocha <rhuan080@gmail.com>

feat(langchain4j): refactoring constructor

Signed-off-by: Rhuan Rocha <rhuan080@gmail.com>
1 parent ad901e2 commit c780e7c

7 files changed: +441 additions, -11 deletions
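The upshot for callers, sketched below. This is a minimal wiring sketch, not code from the commit: the BaseLlm import path is assumed from the ADK package layout, and the chatModel and estimator instances are provider-specific values supplied by the caller.

import com.google.adk.models.BaseLlm;
import com.google.adk.models.langchain4j.LangChain4j;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.ChatModel;

public final class UsageMetadataWiring {

  // Wraps a LangChain4j ChatModel as an ADK BaseLlm that reports usageMetadata
  // even when the provider returns no TokenUsage: the estimator fills the gap.
  public static BaseLlm withEstimatedUsage(ChatModel chatModel, TokenCountEstimator estimator) {
    return LangChain4j.builder()
        .chatModel(chatModel) // modelName falls back to the model's default request parameters
        .tokenCountEstimator(estimator) // optional; omit to rely on provider-reported TokenUsage
        .build();
  }
}

Passing the estimator is what switches the new toLlmResponse logic into estimation mode; the builder also accepts a streamingChatModel and an explicit modelName, as the diff below shows.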

contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java

Lines changed: 127 additions & 11 deletions
@@ -29,6 +29,7 @@
 import com.google.genai.types.FunctionDeclaration;
 import com.google.genai.types.FunctionResponse;
 import com.google.genai.types.GenerateContentConfig;
+import com.google.genai.types.GenerateContentResponseUsageMetadata;
 import com.google.genai.types.Part;
 import com.google.genai.types.Schema;
 import com.google.genai.types.ToolConfig;
@@ -51,6 +52,7 @@
 import dev.langchain4j.data.pdf.PdfFile;
 import dev.langchain4j.data.video.Video;
 import dev.langchain4j.exception.UnsupportedFeatureException;
+import dev.langchain4j.model.TokenCountEstimator;
 import dev.langchain4j.model.chat.ChatModel;
 import dev.langchain4j.model.chat.StreamingChatModel;
 import dev.langchain4j.model.chat.request.ChatRequest;
@@ -64,6 +66,7 @@
 import dev.langchain4j.model.chat.request.json.JsonStringSchema;
 import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.model.output.TokenUsage;
 import io.reactivex.rxjava3.core.BackpressureStrategy;
 import io.reactivex.rxjava3.core.Flowable;
 import java.util.ArrayList;
@@ -83,24 +86,109 @@ public class LangChain4j extends BaseLlm {
   private final ChatModel chatModel;
   private final StreamingChatModel streamingChatModel;
   private final ObjectMapper objectMapper;
+  private final TokenCountEstimator tokenCountEstimator;
+
+  public static Builder builder() {
+    return new Builder();
+  }
+
+  public static class Builder {
+    private ChatModel chatModel;
+    private StreamingChatModel streamingChatModel;
+    private String modelName;
+    private TokenCountEstimator tokenCountEstimator;
+
+    private Builder() {}
+
+    public Builder chatModel(ChatModel chatModel) {
+      this.chatModel = chatModel;
+      return this;
+    }
+
+    public Builder streamingChatModel(StreamingChatModel streamingChatModel) {
+      this.streamingChatModel = streamingChatModel;
+      return this;
+    }
+
+    public Builder modelName(String modelName) {
+      this.modelName = modelName;
+      return this;
+    }
+
+    public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) {
+      this.tokenCountEstimator = tokenCountEstimator;
+      return this;
+    }
+
+    public LangChain4j build() {
+      if (chatModel == null && streamingChatModel == null) {
+        throw new IllegalStateException(
+            "At least one of chatModel or streamingChatModel must be provided");
+      }
+
+      String effectiveModelName = modelName;
+      if (effectiveModelName == null) {
+        if (chatModel != null) {
+          effectiveModelName = chatModel.defaultRequestParameters().modelName();
+        } else {
+          effectiveModelName = streamingChatModel.defaultRequestParameters().modelName();
+        }
+      }
+
+      if (effectiveModelName == null) {
+        throw new IllegalStateException("Model name cannot be null");
+      }
+
+      return new LangChain4j(
+          chatModel, streamingChatModel, effectiveModelName, tokenCountEstimator);
+    }
+  }
+
+  private LangChain4j(
+      ChatModel chatModel,
+      StreamingChatModel streamingChatModel,
+      String modelName,
+      TokenCountEstimator tokenCountEstimator) {
+    super(Objects.requireNonNull(modelName, "model name cannot be null"));
+    this.chatModel = chatModel;
+    this.streamingChatModel = streamingChatModel;
+    this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
+  }

   public LangChain4j(ChatModel chatModel) {
+    this(chatModel, (TokenCountEstimator) null);
+  }
+
+  public LangChain4j(ChatModel chatModel, TokenCountEstimator tokenCountEstimator) {
     super(
         Objects.requireNonNull(
             chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null"));
     this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
     this.streamingChatModel = null;
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }

   public LangChain4j(ChatModel chatModel, String modelName) {
+    this(chatModel, modelName, (TokenCountEstimator) null);
+  }
+
+  public LangChain4j(
+      ChatModel chatModel, String modelName, TokenCountEstimator tokenCountEstimator) {
     super(Objects.requireNonNull(modelName, "chat model name cannot be null"));
     this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
     this.streamingChatModel = null;
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }

   public LangChain4j(StreamingChatModel streamingChatModel) {
+    this(streamingChatModel, (TokenCountEstimator) null);
+  }
+
+  public LangChain4j(
+      StreamingChatModel streamingChatModel, TokenCountEstimator tokenCountEstimator) {
     super(
         Objects.requireNonNull(
             streamingChatModel.defaultRequestParameters().modelName(),
@@ -109,22 +197,23 @@ public LangChain4j(StreamingChatModel streamingChatModel) {
     this.streamingChatModel =
         Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }

   public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
-    super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
-    this.chatModel = null;
-    this.streamingChatModel =
-        Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
-    this.objectMapper = new ObjectMapper();
+    this(streamingChatModel, modelName, (TokenCountEstimator) null);
   }

-  public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
-    super(Objects.requireNonNull(modelName, "model name cannot be null"));
-    this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
+  public LangChain4j(
+      StreamingChatModel streamingChatModel,
+      String modelName,
+      TokenCountEstimator tokenCountEstimator) {
+    super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
+    this.chatModel = null;
     this.streamingChatModel =
         Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }

   @Override
@@ -185,7 +274,7 @@ public void onError(Throwable throwable) {

     ChatRequest chatRequest = toChatRequest(llmRequest);
     ChatResponse chatResponse = chatModel.chat(chatRequest);
-    LlmResponse llmResponse = toLlmResponse(chatResponse);
+    LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);

     return Flowable.just(llmResponse);
   }
@@ -496,11 +585,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
     }
   }

-  private LlmResponse toLlmResponse(ChatResponse chatResponse) {
+  private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) {
     Content content =
         Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build();

-    return LlmResponse.builder().content(content).build();
+    LlmResponse.Builder builder = LlmResponse.builder().content(content);
+    TokenUsage tokenUsage = chatResponse.tokenUsage();
+    if (tokenCountEstimator != null) {
+      try {
+        int estimatedInput =
+            tokenCountEstimator.estimateTokenCountInMessages(chatRequest.messages());
+        int estimatedOutput =
+            tokenCountEstimator.estimateTokenCountInText(chatResponse.aiMessage().text());
+        int estimatedTotal = estimatedInput + estimatedOutput;
+        builder.usageMetadata(
+            GenerateContentResponseUsageMetadata.builder()
+                .promptTokenCount(estimatedInput)
+                .candidatesTokenCount(estimatedOutput)
+                .totalTokenCount(estimatedTotal)
+                .build());
+      } catch (Exception e) {
+        e.printStackTrace();
+      }
+    } else if (tokenUsage != null) {
+      builder.usageMetadata(
+          GenerateContentResponseUsageMetadata.builder()
+              .promptTokenCount(tokenUsage.inputTokenCount())
+              .candidatesTokenCount(tokenUsage.outputTokenCount())
+              .totalTokenCount(tokenUsage.totalTokenCount())
+              .build());
+    }
+
+    return builder.build();
   }

   private List<Part> toParts(AiMessage aiMessage) {
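Net effect of the toLlmResponse change above: a configured estimator takes priority over provider-reported TokenUsage, and usageMetadata stays empty when neither is available. A caller-side sketch of how the metadata surfaces (import paths assumed from the ADK package layout; accessor names follow the tests below):

import com.google.adk.models.BaseLlm;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;

final class UsageInspection {

  // Blocks on the first (non-streaming) response and prints whatever usage
  // metadata made it through: estimated, provider-reported, or none.
  static void printUsage(BaseLlm model, LlmRequest llmRequest) {
    LlmResponse response = model.generateContent(llmRequest, false).blockingFirst();
    response
        .usageMetadata() // empty when no estimator is set and the provider sent no TokenUsage
        .ifPresent(
            usage ->
                System.out.printf(
                    "prompt=%s candidates=%s total=%s%n",
                    usage.promptTokenCount(), // Optional<Integer>, estimated or reported
                    usage.candidatesTokenCount(),
                    usage.totalTokenCount()));
  }
}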

contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java

Lines changed: 138 additions & 0 deletions
@@ -26,13 +26,15 @@
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.TokenCountEstimator;
 import dev.langchain4j.model.chat.ChatModel;
 import dev.langchain4j.model.chat.StreamingChatModel;
 import dev.langchain4j.model.chat.request.ChatRequest;
 import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
 import dev.langchain4j.model.chat.request.json.JsonStringSchema;
 import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.model.output.TokenUsage;
 import io.reactivex.rxjava3.core.Flowable;
 import java.util.ArrayList;
 import java.util.List;
@@ -688,4 +690,140 @@ void testGenerateContentWithStructuredResponseJsonSchema() {
     final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0);
     assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe");
   }
+
+  @Test
+  @DisplayName(
+      "Should use TokenCountEstimator to estimate token usage when TokenUsage is not available")
+  void testTokenCountEstimatorFallback() {
+    // Given
+    // Create a mock TokenCountEstimator
+    final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
+    when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens
+    when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens
+
+    // Create LangChain4j with the TokenCountEstimator using Builder
+    final LangChain4j langChain4jWithEstimator =
+        LangChain4j.builder()
+            .chatModel(chatModel)
+            .modelName(MODEL_NAME)
+            .tokenCountEstimator(tokenCountEstimator)
+            .build();
+
+    // Create a LlmRequest
+    final LlmRequest llmRequest =
+        LlmRequest.builder()
+            .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
+            .build();
+
+    // Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts)
+    final ChatResponse chatResponse = mock(ChatResponse.class);
+    final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
+    when(chatResponse.aiMessage()).thenReturn(aiMessage);
+    when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
+    when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);
+
+    // When
+    final LlmResponse response =
+        langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();
+
+    // Then
+    // Verify the response has usage metadata estimated by TokenCountEstimator
+    assertThat(response).isNotNull();
+    assertThat(response.content()).isPresent();
+    assertThat(response.content().get().text()).isEqualTo("The weather is sunny today.");
+
+    // IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator
+    assertThat(response.usageMetadata()).isPresent();
+    final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
+    assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator
+    assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator
+    assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20
+
+    // Verify the estimator was actually called
+    verify(tokenCountEstimator).estimateTokenCountInMessages(any());
+    verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
+  }
+
+  @Test
+  @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided")
+  void testTokenCountEstimatorPriority() {
+    // Given
+    // Create a mock TokenCountEstimator
+    final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
+    when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator
+    when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator
+
+    // Create LangChain4j with the TokenCountEstimator using Builder
+    final LangChain4j langChain4jWithEstimator =
+        LangChain4j.builder()
+            .chatModel(chatModel)
+            .modelName(MODEL_NAME)
+            .tokenCountEstimator(tokenCountEstimator)
+            .build();
+
+    // Create a LlmRequest
+    final LlmRequest llmRequest =
+        LlmRequest.builder()
+            .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
+            .build();
+
+    // Mock ChatResponse WITH actual TokenUsage from the LLM
+    final ChatResponse chatResponse = mock(ChatResponse.class);
+    final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
+    final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM
+    when(chatResponse.aiMessage()).thenReturn(aiMessage);
+    when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage
+    when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);
+
+    // When
+    final LlmResponse response =
+        langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();
+
+    // Then
+    // IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage
+    assertThat(response).isNotNull();
+    assertThat(response.usageMetadata()).isPresent();
+    final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
+    assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator
+    assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator
+    assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50
+
+    // Verify the estimator was called (it takes priority)
+    verify(tokenCountEstimator).estimateTokenCountInMessages(any());
+    verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
+  }
+
+  @Test
+  @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided")
+  void testNoUsageMetadataWithoutEstimator() {
+    // Given
+    // Create LangChain4j WITHOUT TokenCountEstimator (default behavior)
+    final LangChain4j langChain4jNoEstimator = new LangChain4j(chatModel, MODEL_NAME);
+
+    // Create a LlmRequest
+    final LlmRequest llmRequest =
+        LlmRequest.builder()
+            .contents(List.of(Content.fromParts(Part.fromText("Hello, world!"))))
+            .build();
+
+    // Mock ChatResponse WITHOUT TokenUsage
+    final ChatResponse chatResponse = mock(ChatResponse.class);
+    final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?");
+    when(chatResponse.aiMessage()).thenReturn(aiMessage);
+    when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
+    when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);
+
+    // When
+    final LlmResponse response =
+        langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst();
+
+    // Then
+    // Verify the response does NOT have usage metadata
+    assertThat(response).isNotNull();
+    assertThat(response.content()).isPresent();
+    assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?");
+
+    // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator
+    assertThat(response.usageMetadata()).isEmpty();
+  }
 }
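Outside of Mockito, the estimator path can be exercised with a throwaway implementation like the sketch below: a naive four-characters-per-token heuristic, not a real tokenizer. The abstract method set of TokenCountEstimator is assumed from recent LangChain4j versions; this integration only calls the text and messages variants.

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.TokenCountEstimator;

// Rough character-budget estimator for demos and tests; swap in a provider
// tokenizer for meaningful numbers.
class CharBudgetEstimator implements TokenCountEstimator {

  @Override
  public int estimateTokenCountInText(String text) {
    return Math.max(1, text.length() / 4); // roughly 4 characters per token, floor of 1
  }

  @Override
  public int estimateTokenCountInMessage(ChatMessage message) {
    return estimateTokenCountInText(message.toString());
  }

  @Override
  public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
    int total = 0;
    for (ChatMessage message : messages) {
      total += estimateTokenCountInMessage(message);
    }
    return total;
  }
}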
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<projectDescription>
+	<name>google-adk-sample-a2a-basic</name>
+	<comment></comment>
+	<projects>
+	</projects>
+	<buildSpec>
+		<buildCommand>
+			<name>org.eclipse.jdt.core.javabuilder</name>
+			<arguments>
+			</arguments>
+		</buildCommand>
+		<buildCommand>
+			<name>org.eclipse.m2e.core.maven2Builder</name>
+			<arguments>
+			</arguments>
+		</buildCommand>
+	</buildSpec>
+	<natures>
+		<nature>org.eclipse.jdt.core.javanature</nature>
+		<nature>org.eclipse.m2e.core.maven2Nature</nature>
+	</natures>
+</projectDescription>
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+eclipse.preferences.version=1
+encoding/<project>=UTF-8
