2929import com .google .genai .types .FunctionDeclaration ;
3030import com .google .genai .types .FunctionResponse ;
3131import com .google .genai .types .GenerateContentConfig ;
32+ import com .google .genai .types .GenerateContentResponseUsageMetadata ;
3233import com .google .genai .types .Part ;
3334import com .google .genai .types .Schema ;
3435import com .google .genai .types .ToolConfig ;
5152import dev .langchain4j .data .pdf .PdfFile ;
5253import dev .langchain4j .data .video .Video ;
5354import dev .langchain4j .exception .UnsupportedFeatureException ;
55+ import dev .langchain4j .model .TokenCountEstimator ;
5456import dev .langchain4j .model .chat .ChatModel ;
5557import dev .langchain4j .model .chat .StreamingChatModel ;
5658import dev .langchain4j .model .chat .request .ChatRequest ;
6466import dev .langchain4j .model .chat .request .json .JsonStringSchema ;
6567import dev .langchain4j .model .chat .response .ChatResponse ;
6668import dev .langchain4j .model .chat .response .StreamingChatResponseHandler ;
69+ import dev .langchain4j .model .output .TokenUsage ;
6770import io .reactivex .rxjava3 .core .BackpressureStrategy ;
6871import io .reactivex .rxjava3 .core .Flowable ;
6972import java .util .ArrayList ;
@@ -83,24 +86,109 @@ public class LangChain4j extends BaseLlm {
  // Blocking chat backend; null when only a streaming model was supplied.
  private final ChatModel chatModel;
  // Streaming chat backend; null when only a blocking model was supplied.
  private final StreamingChatModel streamingChatModel;
  private final ObjectMapper objectMapper;
  // Optional local token estimator; when set it takes precedence over
  // provider-reported usage when building response usage metadata.
  private final TokenCountEstimator tokenCountEstimator;

91+ public static Builder builder () {
92+ return new Builder ();
93+ }
94+
95+ public static class Builder {
96+ private ChatModel chatModel ;
97+ private StreamingChatModel streamingChatModel ;
98+ private String modelName ;
99+ private TokenCountEstimator tokenCountEstimator ;
100+
101+ private Builder () {}
102+
103+ public Builder chatModel (ChatModel chatModel ) {
104+ this .chatModel = chatModel ;
105+ return this ;
106+ }
107+
108+ public Builder streamingChatModel (StreamingChatModel streamingChatModel ) {
109+ this .streamingChatModel = streamingChatModel ;
110+ return this ;
111+ }
112+
113+ public Builder modelName (String modelName ) {
114+ this .modelName = modelName ;
115+ return this ;
116+ }
117+
118+ public Builder tokenCountEstimator (TokenCountEstimator tokenCountEstimator ) {
119+ this .tokenCountEstimator = tokenCountEstimator ;
120+ return this ;
121+ }
122+
123+ public LangChain4j build () {
124+ if (chatModel == null && streamingChatModel == null ) {
125+ throw new IllegalStateException (
126+ "At least one of chatModel or streamingChatModel must be provided" );
127+ }
128+
129+ String effectiveModelName = modelName ;
130+ if (effectiveModelName == null ) {
131+ if (chatModel != null ) {
132+ effectiveModelName = chatModel .defaultRequestParameters ().modelName ();
133+ } else {
134+ effectiveModelName = streamingChatModel .defaultRequestParameters ().modelName ();
135+ }
136+ }
137+
138+ if (effectiveModelName == null ) {
139+ throw new IllegalStateException ("Model name cannot be null" );
140+ }
141+
142+ return new LangChain4j (
143+ chatModel , streamingChatModel , effectiveModelName , tokenCountEstimator );
144+ }
145+ }
146+
147+ private LangChain4j (
148+ ChatModel chatModel ,
149+ StreamingChatModel streamingChatModel ,
150+ String modelName ,
151+ TokenCountEstimator tokenCountEstimator ) {
152+ super (Objects .requireNonNull (modelName , "model name cannot be null" ));
153+ this .chatModel = chatModel ;
154+ this .streamingChatModel = streamingChatModel ;
155+ this .objectMapper = new ObjectMapper ();
156+ this .tokenCountEstimator = tokenCountEstimator ;
157+ }
86158
87159 public LangChain4j (ChatModel chatModel ) {
160+ this (chatModel , (TokenCountEstimator ) null );
161+ }
162+
163+ public LangChain4j (ChatModel chatModel , TokenCountEstimator tokenCountEstimator ) {
88164 super (
89165 Objects .requireNonNull (
90166 chatModel .defaultRequestParameters ().modelName (), "chat model name cannot be null" ));
91167 this .chatModel = Objects .requireNonNull (chatModel , "chatModel cannot be null" );
92168 this .streamingChatModel = null ;
93169 this .objectMapper = new ObjectMapper ();
170+ this .tokenCountEstimator = tokenCountEstimator ;
94171 }
95172
96173 public LangChain4j (ChatModel chatModel , String modelName ) {
174+ this (chatModel , modelName , (TokenCountEstimator ) null );
175+ }
176+
177+ public LangChain4j (
178+ ChatModel chatModel , String modelName , TokenCountEstimator tokenCountEstimator ) {
97179 super (Objects .requireNonNull (modelName , "chat model name cannot be null" ));
98180 this .chatModel = Objects .requireNonNull (chatModel , "chatModel cannot be null" );
99181 this .streamingChatModel = null ;
100182 this .objectMapper = new ObjectMapper ();
183+ this .tokenCountEstimator = tokenCountEstimator ;
101184 }
102185
103186 public LangChain4j (StreamingChatModel streamingChatModel ) {
187+ this (streamingChatModel , (TokenCountEstimator ) null );
188+ }
189+
190+ public LangChain4j (
191+ StreamingChatModel streamingChatModel , TokenCountEstimator tokenCountEstimator ) {
104192 super (
105193 Objects .requireNonNull (
106194 streamingChatModel .defaultRequestParameters ().modelName (),
@@ -109,22 +197,23 @@ public LangChain4j(StreamingChatModel streamingChatModel) {
109197 this .streamingChatModel =
110198 Objects .requireNonNull (streamingChatModel , "streamingChatModel cannot be null" );
111199 this .objectMapper = new ObjectMapper ();
200+ this .tokenCountEstimator = tokenCountEstimator ;
112201 }
113202
114203 public LangChain4j (StreamingChatModel streamingChatModel , String modelName ) {
115- super (Objects .requireNonNull (modelName , "streaming chat model name cannot be null" ));
116- this .chatModel = null ;
117- this .streamingChatModel =
118- Objects .requireNonNull (streamingChatModel , "streamingChatModel cannot be null" );
119- this .objectMapper = new ObjectMapper ();
204+ this (streamingChatModel , modelName , (TokenCountEstimator ) null );
120205 }
121206
122- public LangChain4j (ChatModel chatModel , StreamingChatModel streamingChatModel , String modelName ) {
123- super (Objects .requireNonNull (modelName , "model name cannot be null" ));
124- this .chatModel = Objects .requireNonNull (chatModel , "chatModel cannot be null" );
207+ public LangChain4j (
208+ StreamingChatModel streamingChatModel ,
209+ String modelName ,
210+ TokenCountEstimator tokenCountEstimator ) {
211+ super (Objects .requireNonNull (modelName , "streaming chat model name cannot be null" ));
212+ this .chatModel = null ;
125213 this .streamingChatModel =
126214 Objects .requireNonNull (streamingChatModel , "streamingChatModel cannot be null" );
127215 this .objectMapper = new ObjectMapper ();
216+ this .tokenCountEstimator = tokenCountEstimator ;
128217 }
129218
130219 @ Override
@@ -185,7 +274,7 @@ public void onError(Throwable throwable) {
185274
186275 ChatRequest chatRequest = toChatRequest (llmRequest );
187276 ChatResponse chatResponse = chatModel .chat (chatRequest );
188- LlmResponse llmResponse = toLlmResponse (chatResponse );
277+ LlmResponse llmResponse = toLlmResponse (chatResponse , chatRequest );
189278
190279 return Flowable .just (llmResponse );
191280 }
@@ -496,11 +585,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
496585 }
497586 }
498587
499- private LlmResponse toLlmResponse (ChatResponse chatResponse ) {
588+ private LlmResponse toLlmResponse (ChatResponse chatResponse , ChatRequest chatRequest ) {
500589 Content content =
501590 Content .builder ().role ("model" ).parts (toParts (chatResponse .aiMessage ())).build ();
502591
503- return LlmResponse .builder ().content (content ).build ();
592+ LlmResponse .Builder builder = LlmResponse .builder ().content (content );
593+ TokenUsage tokenUsage = chatResponse .tokenUsage ();
594+ if (tokenCountEstimator != null ) {
595+ try {
596+ int estimatedInput =
597+ tokenCountEstimator .estimateTokenCountInMessages (chatRequest .messages ());
598+ int estimatedOutput =
599+ tokenCountEstimator .estimateTokenCountInText (chatResponse .aiMessage ().text ());
600+ int estimatedTotal = estimatedInput + estimatedOutput ;
601+ builder .usageMetadata (
602+ GenerateContentResponseUsageMetadata .builder ()
603+ .promptTokenCount (estimatedInput )
604+ .candidatesTokenCount (estimatedOutput )
605+ .totalTokenCount (estimatedTotal )
606+ .build ());
607+ } catch (Exception e ) {
608+ e .printStackTrace ();
609+ }
610+ } else if (tokenUsage != null ) {
611+ builder .usageMetadata (
612+ GenerateContentResponseUsageMetadata .builder ()
613+ .promptTokenCount (tokenUsage .inputTokenCount ())
614+ .candidatesTokenCount (tokenUsage .outputTokenCount ())
615+ .totalTokenCount (tokenUsage .totalTokenCount ())
616+ .build ());
617+ }
618+
619+ return builder .build ();
504620 }
505621
506622 private List <Part > toParts (AiMessage aiMessage ) {
0 commit comments