1
1
package com .datadog .aiguard ;
2
2
3
+ import static java .util .Collections .singletonMap ;
4
+
5
+ import com .squareup .moshi .JsonAdapter ;
3
6
import com .squareup .moshi .JsonReader ;
4
7
import com .squareup .moshi .JsonWriter ;
5
8
import com .squareup .moshi .Moshi ;
9
+ import com .squareup .moshi .Types ;
6
10
import datadog .communication .http .OkHttpUtils ;
7
11
import datadog .trace .api .Config ;
8
12
import datadog .trace .api .aiguard .AIGuard ;
20
24
import datadog .trace .bootstrap .instrumentation .api .AgentSpan ;
21
25
import datadog .trace .bootstrap .instrumentation .api .AgentTracer ;
22
26
import java .io .IOException ;
27
+ import java .lang .annotation .Annotation ;
28
+ import java .lang .reflect .Type ;
23
29
import java .util .Collection ;
24
- import java .util .Collections ;
25
30
import java .util .HashMap ;
26
31
import java .util .List ;
27
32
import java .util .Map ;
33
+ import java .util .Set ;
28
34
import java .util .stream .Collectors ;
29
35
import javax .annotation .Nullable ;
30
36
import okhttp3 .HttpUrl ;
36
42
import okhttp3 .ResponseBody ;
37
43
import okio .BufferedSink ;
38
44
45
+ /**
46
+ * Concrete implementation of the SDK used to interact with the AIGuard REST API.
47
+ *
48
+ * <p>An instance of this class is initialized and configured automatically during agent startup
49
+ * through {@link AIGuardSystem#start()}.
50
+ */
39
51
public class AIGuardInternal implements Evaluator {
40
52
41
53
public static class BadConfigurationException extends RuntimeException {
@@ -87,7 +99,7 @@ static void uninstall() {
87
99
this .url = url ;
88
100
this .headers = headers ;
89
101
this .client = client ;
90
- this .moshi = new Moshi .Builder ().build ();
102
+ this .moshi = new Moshi .Builder ().add ( new AIGuardFactory ()). build ();
91
103
final Config config = Config .get ();
92
104
this .meta = mapOf ("service" , config .getServiceName (), "env" , config .getEnv ());
93
105
}
@@ -126,21 +138,20 @@ private static String getToolName(final Message current, final List<Message> mes
126
138
.map (ToolCall ::getFunction )
127
139
.map (Function ::getName )
128
140
.collect (Collectors .joining ("," ));
129
- } else {
130
- // assistant message with tool output (search the linked tool call in reverse order)
131
- final String id = current .getToolCallId ();
132
- for (int i = messages .size () - 1 ; i >= 0 ; i --) {
133
- final Message message = messages .get (i );
134
- if (message .getToolCalls () != null ) {
135
- for (final ToolCall toolCall : message .getToolCalls ()) {
136
- if (toolCall .getId ().equals (id )) {
137
- return toolCall .getFunction () == null ? null : toolCall .getFunction ().getName ();
138
- }
141
+ }
142
+ // assistant message with tool output (search the linked tool call in reverse order)
143
+ final String id = current .getToolCallId ();
144
+ for (int i = messages .size () - 1 ; i >= 0 ; i --) {
145
+ final Message message = messages .get (i );
146
+ if (message .getToolCalls () != null ) {
147
+ for (final ToolCall toolCall : message .getToolCalls ()) {
148
+ if (toolCall .getId ().equals (id )) {
149
+ return toolCall .getFunction () == null ? null : toolCall .getFunction ().getName ();
139
150
}
140
151
}
141
152
}
142
- return null ;
143
153
}
154
+ return null ;
144
155
}
145
156
146
157
private boolean isBlockingEnabled (final Object isBlockingEnabled ) {
@@ -155,18 +166,17 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
155
166
final AgentTracer .TracerAPI tracer = AgentTracer .get ();
156
167
final AgentSpan span = tracer .buildSpan (SPAN_NAME , SPAN_NAME ).start ();
157
168
try (final AgentScope scope = tracer .activateSpan (span )) {
158
- final Message current = messages .get (messages .size () - 1 );
159
- if (isToolCall (current )) {
169
+ final Message last = messages .get (messages .size () - 1 );
170
+ if (isToolCall (last )) {
160
171
span .setTag (TARGET_TAG , "tool" );
161
- final String toolName = getToolName (current , messages );
172
+ final String toolName = getToolName (last , messages );
162
173
if (toolName != null ) {
163
174
span .setTag (TOOL_TAG , toolName );
164
175
}
165
176
} else {
166
177
span .setTag (TARGET_TAG , "prompt" );
167
178
}
168
- final Map <String , Object > metaStruct =
169
- Collections .singletonMap (META_STRUCT_KEY , truncate (messages ));
179
+ final Map <String , Object > metaStruct = singletonMap (META_STRUCT_KEY , truncate (messages ));
170
180
span .setMetaStruct (META_STRUCT_TAG , metaStruct );
171
181
final Request .Builder request =
172
182
new Request .Builder ()
@@ -243,6 +253,69 @@ public static void install(final Evaluator evaluator) {
243
253
}
244
254
}
245
255
256
+ static class AIGuardFactory implements JsonAdapter .Factory {
257
+
258
+ @ Nullable
259
+ @ Override
260
+ public JsonAdapter <?> create (
261
+ final Type type , final Set <? extends Annotation > annotations , final Moshi moshi ) {
262
+ final Class <?> rawType = Types .getRawType (type );
263
+ if (rawType != AIGuard .Message .class ) {
264
+ return null ;
265
+ }
266
+ return new MessageAdapter (moshi .adapter (AIGuard .ToolCall .class ));
267
+ }
268
+ }
269
+
270
+ static class MessageAdapter extends JsonAdapter <Message > {
271
+
272
+ private final JsonAdapter <AIGuard .ToolCall > toolCallAdapter ;
273
+
274
+ MessageAdapter (final JsonAdapter <ToolCall > toolCallAdapter ) {
275
+ this .toolCallAdapter = toolCallAdapter ;
276
+ }
277
+
278
+ @ Nullable
279
+ @ Override
280
+ public Message fromJson (JsonReader reader ) throws IOException {
281
+ throw new UnsupportedOperationException ("Serializing only adapter" );
282
+ }
283
+
284
+ @ Override
285
+ public void toJson (final JsonWriter writer , @ Nullable final Message value ) throws IOException {
286
+ if (value == null ) {
287
+ writer .nullValue ();
288
+ return ;
289
+ }
290
+ writer .beginObject ();
291
+ writeValue (writer , "role" , value .getRole ());
292
+ writeValue (writer , "content" , value .getContent ());
293
+ writeArray (writer , "tool_calls" , value .getToolCalls ());
294
+ writeValue (writer , "tool_call_id" , value .getToolCallId ());
295
+ writer .endObject ();
296
+ }
297
+
298
+ private void writeValue (final JsonWriter writer , final String name , final Object value )
299
+ throws IOException {
300
+ if (value != null ) {
301
+ writer .name (name );
302
+ writer .jsonValue (value );
303
+ }
304
+ }
305
+
306
+ private void writeArray (final JsonWriter writer , final String name , final List <ToolCall > value )
307
+ throws IOException {
308
+ if (value != null ) {
309
+ writer .name (name );
310
+ writer .beginArray ();
311
+ for (final ToolCall toolCall : value ) {
312
+ toolCallAdapter .toJson (writer , toolCall );
313
+ }
314
+ writer .endArray ();
315
+ }
316
+ }
317
+ }
318
+
246
319
static class MoshiJsonRequestBody extends RequestBody {
247
320
248
321
private static final MediaType JSON = MediaType .parse ("application/json" );
0 commit comments