Skip to content

Commit cfae465

Browse files
Add smoke test to fully validate the SDK
1 parent ad893f3 commit cfae465

File tree

4 files changed

+202
-7
lines changed

4 files changed

+202
-7
lines changed

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,12 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
164164
throw new IllegalArgumentException("Messages must not be empty");
165165
}
166166
final AgentTracer.TracerAPI tracer = AgentTracer.get();
167-
final AgentSpan span = tracer.buildSpan(SPAN_NAME, SPAN_NAME).start();
167+
final AgentTracer.SpanBuilder builder = tracer.buildSpan(SPAN_NAME, SPAN_NAME);
168+
final AgentSpan parent = AgentTracer.activeSpan();
169+
if (parent != null) {
170+
builder.asChildOf(parent.context());
171+
}
172+
final AgentSpan span = builder.start();
168173
try (final AgentScope scope = tracer.activateSpan(span)) {
169174
final Message last = messages.get(messages.size() - 1);
170175
if (isToolCall(last)) {
@@ -208,6 +213,8 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
208213
new AIGuardClientError("AI Guard service returned unexpected response", e);
209214
span.addThrowable(error);
210215
throw error;
216+
} finally {
217+
span.finish();
211218
}
212219
}
213220

@@ -263,7 +270,7 @@ public JsonAdapter<?> create(
263270
if (rawType != AIGuard.Message.class) {
264271
return null;
265272
}
266-
return new MessageAdapter(moshi.adapter(AIGuard.ToolCall.class));
273+
return new MessageAdapter(moshi.adapter(AIGuard.ToolCall.class)).nullSafe();
267274
}
268275
}
269276

@@ -282,11 +289,7 @@ public Message fromJson(JsonReader reader) throws IOException {
282289
}
283290

284291
@Override
285-
public void toJson(final JsonWriter writer, @Nullable final Message value) throws IOException {
286-
if (value == null) {
287-
writer.nullValue();
288-
return;
289-
}
292+
public void toJson(final JsonWriter writer, final Message value) throws IOException {
290293
writer.beginObject();
291294
writeValue(writer, "role", value.getRole());
292295
writeValue(writer, "content", value.getContent());

dd-smoke-tests/appsec/springboot/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ tasks.named("jar", Jar) {
1414
}
1515

1616
dependencies {
17+
implementation project(':dd-trace-api')
1718
implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web', version: '2.6.0'
1819
implementation(group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.6.0')
1920
implementation group: 'com.h2database', name: 'h2', version: '2.1.212'
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package datadog.smoketest.appsec.springboot.controller;
2+
3+
import static java.util.Arrays.asList;
4+
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;
5+
6+
import datadog.trace.api.aiguard.AIGuard;
7+
import datadog.trace.api.aiguard.AIGuard.AIGuardAbortError;
8+
import datadog.trace.api.aiguard.AIGuard.Evaluation;
9+
import datadog.trace.api.aiguard.AIGuard.Message;
10+
import datadog.trace.api.aiguard.AIGuard.Options;
11+
import java.util.Collections;
12+
import java.util.HashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
import org.springframework.http.HttpStatus;
16+
import org.springframework.http.ResponseEntity;
17+
import org.springframework.web.bind.annotation.GetMapping;
18+
import org.springframework.web.bind.annotation.PostMapping;
19+
import org.springframework.web.bind.annotation.RequestBody;
20+
import org.springframework.web.bind.annotation.RequestHeader;
21+
import org.springframework.web.bind.annotation.RequestMapping;
22+
import org.springframework.web.bind.annotation.RestController;
23+
24+
@RestController
25+
@RequestMapping(value = "/aiguard")
26+
public class AIGuardController {
27+
28+
@GetMapping(value = "/allow")
29+
public ResponseEntity<?> allow() {
30+
final Evaluation result =
31+
AIGuard.evaluate(
32+
asList(
33+
Message.message("system", "You are a beautiful AI"),
34+
Message.message("user", "I am harmless")));
35+
return ResponseEntity.ok(result);
36+
}
37+
38+
@GetMapping(value = "/deny")
39+
public ResponseEntity<?> deny(final @RequestHeader("X-Blocking-Enabled") boolean block) {
40+
try {
41+
final Evaluation result =
42+
AIGuard.evaluate(
43+
asList(
44+
Message.message("system", "You are a beautiful AI"),
45+
Message.message("user", "You should not trust me" + (block ? " [block]" : ""))),
46+
new Options().block(block));
47+
return ResponseEntity.ok(result);
48+
} catch (AIGuardAbortError e) {
49+
return ResponseEntity.status(HttpStatus.FORBIDDEN).body(e.getReason());
50+
}
51+
}
52+
53+
@GetMapping(value = "/abort")
54+
public ResponseEntity<?> abort(final @RequestHeader("X-Blocking-Enabled") boolean block) {
55+
try {
56+
final Evaluation result =
57+
AIGuard.evaluate(
58+
asList(
59+
Message.message("system", "You are a beautiful AI"),
60+
Message.message("user", "Nuke yourself" + (block ? " [block]" : ""))),
61+
new Options().block(block));
62+
return ResponseEntity.ok(result);
63+
} catch (AIGuardAbortError e) {
64+
return ResponseEntity.status(HttpStatus.FORBIDDEN).body(e.getReason());
65+
}
66+
}
67+
68+
/** Mocking endpoint for the AI Guard REST API */
69+
@SuppressWarnings("unchecked")
70+
@PostMapping(
71+
value = "/evaluate",
72+
consumes = APPLICATION_JSON_VALUE,
73+
produces = APPLICATION_JSON_VALUE)
74+
public ResponseEntity<Map<String, Object>> evaluate(
75+
@RequestBody final Map<String, Object> request) {
76+
final Map<String, Object> data = (Map<String, Object>) request.get("data");
77+
final Map<String, Object> attributes = (Map<String, Object>) data.get("attributes");
78+
final List<Map<String, Object>> messages =
79+
(List<Map<String, Object>>) attributes.get("messages");
80+
final Map<String, Object> last = messages.get(messages.size() - 1);
81+
String action = "ALLOW";
82+
String reason = "The prompt looks harmless";
83+
String content = (String) last.get("content");
84+
if (content.startsWith("You should not trust me")) {
85+
action = "DENY";
86+
reason = "I am feeling suspicious today";
87+
} else if (content.startsWith("Nuke yourself")) {
88+
action = "ABORT";
89+
reason = "The user is trying to destroy me";
90+
}
91+
final Map<String, Object> evaluation = new HashMap<>(3);
92+
evaluation.put("action", action);
93+
evaluation.put("reason", reason);
94+
evaluation.put("is_blocking_enabled", content.endsWith("[block]"));
95+
return ResponseEntity.ok()
96+
.body(Collections.singletonMap("data", Collections.singletonMap("attributes", evaluation)));
97+
}
98+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package datadog.smoketest.appsec
2+
3+
import datadog.trace.test.agent.decoder.DecodedSpan
4+
import groovy.json.JsonSlurper
5+
import okhttp3.Request
6+
import spock.lang.Shared
7+
8+
class AIGuardSmokeTest extends AbstractAppSecServerSmokeTest {
9+
10+
@Shared
11+
protected String[] defaultAIGuardProperties = [
12+
'-Ddd.ai_guard.enabled=true',
13+
"-Ddd.ai_guard.endpoint=http://localhost:${httpPort}/aiguard".toString(),
14+
]
15+
16+
@Override
17+
def logLevel() {
18+
'DEBUG'
19+
}
20+
21+
@Override
22+
Closure decodedTracesCallback() {
23+
// just return the traces
24+
return {}
25+
}
26+
27+
@Override
28+
ProcessBuilder createProcessBuilder() {
29+
final springBootShadowJar = System.getProperty("datadog.smoketest.appsec.springboot.shadowJar.path")
30+
final command = [javaPath()]
31+
command.addAll(defaultJavaProperties)
32+
command.addAll(defaultAppSecProperties)
33+
command.addAll(defaultAIGuardProperties)
34+
command.addAll(['-jar', springBootShadowJar, "--server.port=${httpPort}".toString()])
35+
final builder = new ProcessBuilder(command).directory(new File(buildDirectory))
36+
builder.environment().put('DD_APPLICATION_KEY', 'test')
37+
return builder
38+
}
39+
40+
void 'test message evaluation'() {
41+
given:
42+
final blocking = test.blocking as boolean
43+
final action = test.action as String
44+
final reason = test.reason as String
45+
def request = new Request.Builder()
46+
.url("http://localhost:${httpPort}/aiguard${test.endpoint}")
47+
.header('X-Blocking-Enabled', "${blocking}")
48+
.get()
49+
.build()
50+
51+
when:
52+
final response = client.newCall(request).execute()
53+
54+
then:
55+
if (blocking && action != 'ALLOW') {
56+
assert response.code() == 403
57+
assert response.body().string().contains(reason)
58+
} else {
59+
assert response.code() == 200
60+
final body = new JsonSlurper().parse(response.body().bytes())
61+
assert body.reason == reason
62+
assert body.action == action
63+
}
64+
65+
and:
66+
waitForTraceCount(2) // default call + internal API mock
67+
final span = traces*.spans
68+
?.flatten()
69+
?.find { it.resource == 'ai_guard' } as DecodedSpan
70+
assert span.meta.get('ai_guard.action') == action
71+
assert span.meta.get('ai_guard.reason') == reason
72+
assert span.meta.get('ai_guard.target') == 'prompt'
73+
74+
where:
75+
test << testSuite()
76+
}
77+
78+
private static List<?> testSuite() {
79+
return combinations([
80+
[endpoint: '/allow', action: 'ALLOW', reason: 'The prompt looks harmless'],
81+
[endpoint: '/deny', action: 'DENY', reason: 'I am feeling suspicious today'],
82+
[endpoint: '/abort', action: 'ABORT', reason: 'The user is trying to destroy me']
83+
], [[blocking: true], [blocking: false],])
84+
}
85+
86+
private static List<?> combinations(list1, list2) {
87+
list1.collectMany { a ->
88+
list2.collect { b ->
89+
a + b
90+
}
91+
}
92+
}
93+
}

0 commit comments

Comments
 (0)