Skip to content

Commit ad893f3

Browse files
Apply PR feedback (3)
1 parent e5e469f commit ad893f3

File tree

5 files changed

+120
-35
lines changed

5 files changed

+120
-35
lines changed

dd-java-agent/agent-aiguard/build.gradle

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar
2+
13
plugins {
24
id 'com.gradleup.shadow'
35
}
@@ -20,19 +22,15 @@ dependencies {
2022
implementation project(':communication')
2123

2224
testImplementation project(':utils:test-utils')
23-
testImplementation('org.skyscreamer:jsonassert:1.5.1')
25+
testImplementation('org.skyscreamer:jsonassert:1.5.3')
26+
testImplementation('com.fasterxml.jackson.core:jackson-databind:2.20.0')
2427
}
2528

26-
shadowJar {
29+
tasks.named("shadowJar", ShadowJar) {
2730
dependencies deps.excludeShared
2831
}
2932

30-
jar {
33+
tasks.named("jar", Jar) {
3134
archiveClassifier = 'unbundled'
3235
}
3336

34-
spotless {
35-
java {
36-
target 'src/**/*.java'
37-
}
38-
}

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 91 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
package com.datadog.aiguard;
22

3+
import static java.util.Collections.singletonMap;
4+
5+
import com.squareup.moshi.JsonAdapter;
36
import com.squareup.moshi.JsonReader;
47
import com.squareup.moshi.JsonWriter;
58
import com.squareup.moshi.Moshi;
9+
import com.squareup.moshi.Types;
610
import datadog.communication.http.OkHttpUtils;
711
import datadog.trace.api.Config;
812
import datadog.trace.api.aiguard.AIGuard;
@@ -20,11 +24,13 @@
2024
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
2125
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
2226
import java.io.IOException;
27+
import java.lang.annotation.Annotation;
28+
import java.lang.reflect.Type;
2329
import java.util.Collection;
24-
import java.util.Collections;
2530
import java.util.HashMap;
2631
import java.util.List;
2732
import java.util.Map;
33+
import java.util.Set;
2834
import java.util.stream.Collectors;
2935
import javax.annotation.Nullable;
3036
import okhttp3.HttpUrl;
@@ -36,6 +42,12 @@
3642
import okhttp3.ResponseBody;
3743
import okio.BufferedSink;
3844

45+
/**
46+
* Concrete implementation of the SDK used to interact with the AIGuard REST API.
47+
*
48+
* <p>An instance of this class is initialized and configured automatically during agent startup
49+
* through {@link AIGuardSystem#start()}.
50+
*/
3951
public class AIGuardInternal implements Evaluator {
4052

4153
public static class BadConfigurationException extends RuntimeException {
@@ -87,7 +99,7 @@ static void uninstall() {
8799
this.url = url;
88100
this.headers = headers;
89101
this.client = client;
90-
this.moshi = new Moshi.Builder().build();
102+
this.moshi = new Moshi.Builder().add(new AIGuardFactory()).build();
91103
final Config config = Config.get();
92104
this.meta = mapOf("service", config.getServiceName(), "env", config.getEnv());
93105
}
@@ -126,21 +138,20 @@ private static String getToolName(final Message current, final List<Message> mes
126138
.map(ToolCall::getFunction)
127139
.map(Function::getName)
128140
.collect(Collectors.joining(","));
129-
} else {
130-
// assistant message with tool output (search the linked tool call in reverse order)
131-
final String id = current.getToolCallId();
132-
for (int i = messages.size() - 1; i >= 0; i--) {
133-
final Message message = messages.get(i);
134-
if (message.getToolCalls() != null) {
135-
for (final ToolCall toolCall : message.getToolCalls()) {
136-
if (toolCall.getId().equals(id)) {
137-
return toolCall.getFunction() == null ? null : toolCall.getFunction().getName();
138-
}
141+
}
142+
// assistant message with tool output (search the linked tool call in reverse order)
143+
final String id = current.getToolCallId();
144+
for (int i = messages.size() - 1; i >= 0; i--) {
145+
final Message message = messages.get(i);
146+
if (message.getToolCalls() != null) {
147+
for (final ToolCall toolCall : message.getToolCalls()) {
148+
if (toolCall.getId().equals(id)) {
149+
return toolCall.getFunction() == null ? null : toolCall.getFunction().getName();
139150
}
140151
}
141152
}
142-
return null;
143153
}
154+
return null;
144155
}
145156

146157
private boolean isBlockingEnabled(final Object isBlockingEnabled) {
@@ -155,18 +166,17 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
155166
final AgentTracer.TracerAPI tracer = AgentTracer.get();
156167
final AgentSpan span = tracer.buildSpan(SPAN_NAME, SPAN_NAME).start();
157168
try (final AgentScope scope = tracer.activateSpan(span)) {
158-
final Message current = messages.get(messages.size() - 1);
159-
if (isToolCall(current)) {
169+
final Message last = messages.get(messages.size() - 1);
170+
if (isToolCall(last)) {
160171
span.setTag(TARGET_TAG, "tool");
161-
final String toolName = getToolName(current, messages);
172+
final String toolName = getToolName(last, messages);
162173
if (toolName != null) {
163174
span.setTag(TOOL_TAG, toolName);
164175
}
165176
} else {
166177
span.setTag(TARGET_TAG, "prompt");
167178
}
168-
final Map<String, Object> metaStruct =
169-
Collections.singletonMap(META_STRUCT_KEY, truncate(messages));
179+
final Map<String, Object> metaStruct = singletonMap(META_STRUCT_KEY, truncate(messages));
170180
span.setMetaStruct(META_STRUCT_TAG, metaStruct);
171181
final Request.Builder request =
172182
new Request.Builder()
@@ -243,6 +253,69 @@ public static void install(final Evaluator evaluator) {
243253
}
244254
}
245255

256+
static class AIGuardFactory implements JsonAdapter.Factory {
257+
258+
@Nullable
259+
@Override
260+
public JsonAdapter<?> create(
261+
final Type type, final Set<? extends Annotation> annotations, final Moshi moshi) {
262+
final Class<?> rawType = Types.getRawType(type);
263+
if (rawType != AIGuard.Message.class) {
264+
return null;
265+
}
266+
return new MessageAdapter(moshi.adapter(AIGuard.ToolCall.class));
267+
}
268+
}
269+
270+
static class MessageAdapter extends JsonAdapter<Message> {
271+
272+
private final JsonAdapter<AIGuard.ToolCall> toolCallAdapter;
273+
274+
MessageAdapter(final JsonAdapter<ToolCall> toolCallAdapter) {
275+
this.toolCallAdapter = toolCallAdapter;
276+
}
277+
278+
@Nullable
279+
@Override
280+
public Message fromJson(JsonReader reader) throws IOException {
281+
throw new UnsupportedOperationException("Serializing only adapter");
282+
}
283+
284+
@Override
285+
public void toJson(final JsonWriter writer, @Nullable final Message value) throws IOException {
286+
if (value == null) {
287+
writer.nullValue();
288+
return;
289+
}
290+
writer.beginObject();
291+
writeValue(writer, "role", value.getRole());
292+
writeValue(writer, "content", value.getContent());
293+
writeArray(writer, "tool_calls", value.getToolCalls());
294+
writeValue(writer, "tool_call_id", value.getToolCallId());
295+
writer.endObject();
296+
}
297+
298+
private void writeValue(final JsonWriter writer, final String name, final Object value)
299+
throws IOException {
300+
if (value != null) {
301+
writer.name(name);
302+
writer.jsonValue(value);
303+
}
304+
}
305+
306+
private void writeArray(final JsonWriter writer, final String name, final List<ToolCall> value)
307+
throws IOException {
308+
if (value != null) {
309+
writer.name(name);
310+
writer.beginArray();
311+
for (final ToolCall toolCall : value) {
312+
toolCallAdapter.toJson(writer, toolCall);
313+
}
314+
writer.endArray();
315+
}
316+
}
317+
}
318+
246319
static class MoshiJsonRequestBody extends RequestBody {
247320

248321
private static final MediaType JSON = MediaType.parse("application/json");

dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package com.datadog.aiguard
22

3+
import com.fasterxml.jackson.annotation.JsonInclude
4+
import com.fasterxml.jackson.databind.ObjectMapper
5+
import com.fasterxml.jackson.databind.PropertyNamingStrategies
36
import com.squareup.moshi.Moshi
47
import datadog.trace.api.Config
58
import datadog.trace.api.aiguard.AIGuard
@@ -40,6 +43,13 @@ class AIGuardInternalTests extends DDSpecification {
4043
@Shared
4144
protected static final MOSHI = new Moshi.Builder().build()
4245

46+
@Shared
47+
protected static final MAPPER = new ObjectMapper()
48+
.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE)
49+
.setDefaultPropertyInclusion(
50+
JsonInclude.Value.construct(JsonInclude.Include.NON_NULL, JsonInclude.Include.NON_NULL)
51+
)
52+
4353
@Shared
4454
protected static final TOOL_CALL = [
4555
AIGuard.Message.message('system', 'You are a beautiful AI assistant'),
@@ -423,11 +433,15 @@ class AIGuardInternalTests extends DDSpecification {
423433
}
424434
assert request.body().contentType().toString().contains('application/json')
425435
final receivedBody = readRequestBody(request.body())
426-
final expectedBody = MOSHI.adapter(Map).toJson([data: [attributes: [messages: messages, meta: [service: 'ai_guard_test', env: 'test']]]])
427-
JSONAssert.assertEquals(expectedBody, receivedBody, JSONCompareMode.LENIENT)
436+
final expectedBody = snakeCaseJson([data: [attributes: [messages: messages, meta: [service: 'ai_guard_test', env: 'test']]]])
437+
JSONAssert.assertEquals(expectedBody, receivedBody, JSONCompareMode.NON_EXTENSIBLE)
428438
return true
429439
}
430440

441+
private static String snakeCaseJson(final Object value) {
442+
MAPPER.writeValueAsString(value)
443+
}
444+
431445
private static String readRequestBody(final RequestBody body) {
432446
final output = new ByteArrayOutputStream()
433447
final buffer = Okio.buffer(Okio.sink(output))

dd-trace-api/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ val excludedClassesCoverage by extra(
2525
"datadog.trace.api.internal.TraceSegment",
2626
"datadog.trace.api.internal.TraceSegment.NoOp",
2727
"datadog.trace.api.aiguard.AIGuard",
28-
"datadog.trace.api.aiguard.AIGuard.AIGuardClientError",
2928
"datadog.trace.api.aiguard.AIGuard.AIGuardAbortError",
29+
"datadog.trace.api.aiguard.AIGuard.AIGuardClientError",
3030
"datadog.trace.api.aiguard.AIGuard.Options",
3131
"datadog.trace.api.civisibility.CIVisibility",
3232
"datadog.trace.api.civisibility.DDTestModule",

dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ public static class Message {
191191

192192
private final String role;
193193
private final String content;
194-
private final List<ToolCall> tool_calls;
195-
private final String tool_call_id;
194+
private final List<ToolCall> toolCalls;
195+
private final String toolCallId;
196196

197197
/**
198198
* Creates a new message with the specified parameters.
@@ -211,8 +211,8 @@ public Message(
211211
final String toolCallId) {
212212
this.role = role;
213213
this.content = content;
214-
this.tool_calls = toolCalls;
215-
this.tool_call_id = toolCallId;
214+
this.toolCalls = toolCalls;
215+
this.toolCallId = toolCallId;
216216
}
217217

218218
/**
@@ -239,7 +239,7 @@ public String getContent() {
239239
* @return list of tool calls, or null if this message has no tool calls
240240
*/
241241
public List<ToolCall> getToolCalls() {
242-
return tool_calls;
242+
return toolCalls;
243243
}
244244

245245
/**
@@ -248,7 +248,7 @@ public List<ToolCall> getToolCalls() {
248248
* @return the tool call ID, or null if this is not a tool response message
249249
*/
250250
public String getToolCallId() {
251-
return tool_call_id;
251+
return toolCallId;
252252
}
253253

254254
/**

0 commit comments

Comments
 (0)