Skip to content
5 changes: 4 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ dependencies {
// OpenAI Java SDK
implementation("com.openai:openai-java:4.16.1")

// JSON Schema Generator
implementation 'com.github.victools:jsonschema-generator:4.38.0'
implementation 'com.github.victools:jsonschema-module-jackson:4.38.0'

// 모니터링용
implementation 'org.springframework.boot:spring-boot-starter-actuator'
runtimeOnly 'io.micrometer:micrometer-registry-prometheus'
Expand Down Expand Up @@ -90,4 +94,3 @@ tasks.withType(JavaCompile).configureEach {
clean.doLast {
file(querydslDir).deleteDir()
}

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ public interface AiBatchService {
*/
String createBatch(Map<String, String> prompts);

/**
* 배치 작업을 요청한다.
* DTO 클래스를 지정하여 LLM 응답이 해당 형식을 준수하도록 보장한다.
*
* @param prompts 임의의 ID를 key로, prompt 내용을 value로 갖는 Map<br/>
* Batch API 결과는 순서가 보장되지 않으므로 구분을 위해 prompt마다 고유한 ID를 부여해야 한다.
* @param resultDtoClass 결과 DTO 클래스
* @return 배치 작업 ID
*/
String createBatch(Map<String, String> prompts, Class<?> resultDtoClass);

/**
* 배치 작업 완료 여부를 확인한다.
* 배치 작업에 실패한 경우 AiException 하위 예외를 던진다.
Expand All @@ -31,4 +42,13 @@ public interface AiBatchService {
*/
Map<String, String> getResults(String batchId);

/**
* 배치 작업의 응답을 가져와 지정한 DTO로 파싱한다.
*
* @param batchId 배치 작업 ID
* @param resultDtoClass 결과 DTO 클래스
* @return 요청에서 지정된 ID를 key로, 결과 DTO를 value로 갖는 Map
*/
<T> Map<String, T> getResults(String batchId, Class<T> resultDtoClass);

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.nova.nova_server.domain.ai.exception.AiException;
import com.nova.nova_server.global.config.OpenAIConfig;
import com.openai.client.OpenAIClient;
Expand All @@ -27,6 +28,7 @@
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static com.openai.models.batches.Batch.Status.*;

Expand All @@ -38,6 +40,7 @@ public class OpenAiBatchService implements AiBatchService {
private final OpenAIClient client;
private final OpenAIConfig config;
private final ObjectMapper objectMapper;
private final SchemaGenerator schemaGenerator;

@Override
public String createBatch(Map<String, String> prompts) {
Expand All @@ -53,6 +56,22 @@ public String createBatch(Map<String, String> prompts) {
return requestBatch(inputFileId);
}

@Override
public String createBatch(Map<String, String> prompts, Class<?> resultDtoClass) {
validatePrompts(prompts);
validateResultDtoClass(resultDtoClass);

String batchInput = createBatchInput(
prompts,
config.getModel(),
config.getTemperature(),
resultDtoClass
);
String inputFileId = uploadBatchInput(batchInput);

return requestBatch(inputFileId);
}

@Override
public boolean isCompleted(String batchId) {
validateBatchId(batchId);
Expand All @@ -70,16 +89,19 @@ public Map<String, String> getResults(String batchId) {
}

String batchOutput = fetchBatchOutput(batch);
long total = batch.requestCounts()
.orElseThrow(() -> {
log.error("배치 요청 수를 확인할 수 없습니다. batchId={}", batchId);
return new AiException.InvalidBatchOutputException("배치 요청 수를 확인할 수 없습니다.");
})
.total();

return parseBatchOutput(batchOutput);
}

@Override
public <T> Map<String, T> getResults(String batchId, Class<T> resultDtoClass) {
validateResultDtoClass(resultDtoClass);

Map<String, String> rawResults = getResults(batchId);

return parseBatchResults(rawResults, resultDtoClass);
}

/**
* OpenAI 배치 작업에 필요한 jsonl 형식의 입력 문자열을 생성한다.
*
Expand All @@ -89,7 +111,21 @@ public Map<String, String> getResults(String batchId) {
* @return 배치 입력 문자열 (jsonl 형식)
*/
private String createBatchInput(Map<String, String> prompts, String model, double temperature) {
return createBatchInput(prompts, model, temperature, null);
}

/**
* OpenAI 배치 작업에 필요한 jsonl 형식의 입력 문자열을 생성한다.
*
* @param prompts prompt map
* @param model OpenAI LLM 모델 이름
* @param temperature temperature 값
* @param resultDtoClass 결과 DTO 클래스 (Structured Outputs 설정에 사용)
* @return 배치 입력 문자열 (jsonl 형식)
*/
private String createBatchInput(Map<String, String> prompts, String model, double temperature, Class<?> resultDtoClass) {
StringBuilder jsonlBuilder = new StringBuilder();
ObjectNode responseFormatNode = Optional.ofNullable(resultDtoClass).map(this::createResponseFormatNode).orElse(null);

for (String key : prompts.keySet()) {
ObjectNode requestNode = objectMapper.createObjectNode();
Expand All @@ -100,6 +136,9 @@ private String createBatchInput(Map<String, String> prompts, String model, doubl
ObjectNode bodyNode = objectMapper.createObjectNode();
bodyNode.put("model", model);
bodyNode.put("temperature", temperature);
if (responseFormatNode != null) {
bodyNode.set("response_format", responseFormatNode);
}

ObjectNode messageNode = objectMapper.createObjectNode();
messageNode.put("role", "user");
Expand All @@ -114,6 +153,25 @@ private String createBatchInput(Map<String, String> prompts, String model, doubl
return jsonlBuilder.toString();
}

/**
* Structured Outputs 설정을 위한 response_format 노드를 생성한다.
*
* @param resultDtoClass 결과 DTO 클래스
* @return response_format 노드
*/
private ObjectNode createResponseFormatNode(Class<?> resultDtoClass) {
ObjectNode responseFormatNode = objectMapper.createObjectNode();
responseFormatNode.put("type", "json_schema");

ObjectNode jsonSchemaNode = objectMapper.createObjectNode();
jsonSchemaNode.put("name", resultDtoClass.getSimpleName());
jsonSchemaNode.put("strict", true);
jsonSchemaNode.set("schema", schemaGenerator.generateSchema(resultDtoClass));

responseFormatNode.set("json_schema", jsonSchemaNode);
return responseFormatNode;
}

/**
* 배치 입력 파일을 업로드한다.
*
Expand Down Expand Up @@ -182,13 +240,12 @@ private boolean isCompleted(Batch batch) {
* @return 배치 작업 결과 문자열 (jsonl 형식)
*/
private String fetchBatchOutput(Batch batch) {
StringBuffer outputBuffer = new StringBuffer();

batch.outputFileId().ifPresent(fileId -> {
outputBuffer.append(fetchBatchOutputFile(fileId));
});

return outputBuffer.toString();
return batch.outputFileId()
.map(this::fetchBatchOutputFile)
.orElseThrow(() -> {
log.error("배치 결과 파일이 존재하지 않습니다. batchId={}", batch.id());
return new AiException.InvalidBatchOutputException("배치 결과 파일이 존재하지 않습니다.");
});
}

/**
Expand Down Expand Up @@ -235,13 +292,44 @@ private Map<String, String> parseBatchOutput(String batchOutput) {
return resultMap;
}

/**
* 배치 작업 결과 Map의 value를 지정한 DTO로 파싱한다.
*
* @param rawResults 배치 작업 결과 Map
* @return 요청에서 지정된 ID를 key로, 결과 DTO를 value로 갖는 Map
*/
private <T> Map<String, T> parseBatchResults(Map<String, String> rawResults, Class<T> resultDtoClass) {
Map<String, T> parsedResults = new HashMap<>();

for (String customId : rawResults.keySet()) {
String rawContent = rawResults.get(customId);
if (!StringUtils.hasText(rawContent)) {
continue;
}

try {
T parsed = objectMapper.readValue(rawContent, resultDtoClass);
parsedResults.put(customId, parsed);
} catch (Exception e) {
log.warn("배치 결과 DTO 파싱에 실패했습니다. customId={}, content={}", customId, rawContent);
throw new AiException.InvalidBatchOutputException("배치 결과 DTO 파싱에 실패했습니다.");
}
}
return parsedResults;
}

private void validatePrompts(Map<String, String> prompts) {
if (CollectionUtils.isEmpty(prompts))
throw new AiException.InvalidBatchInputException("배치 입력이 누락되었습니다.");
if (prompts.size() > config.getMaxRequestPerBatch())
throw new AiException.InvalidBatchInputException("배치 당 최대 요청수를 초과했습니다.");
}

private void validateResultDtoClass(Class<?> resultDtoClass) {
if (resultDtoClass == null)
throw new AiException.InvalidBatchInputException("결과 DTO 클래스가 누락되었습니다.");
}

private void validateBatchId(String batchId) {
if (!StringUtils.hasText(batchId))
throw new AiException.InvalidBatchIdException("배치 ID가 누락되었습니다.");
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.nova.nova_server.domain.ai.exception.AiException;
import com.nova.nova_server.domain.ai.service.AiBatchService;
import com.nova.nova_server.domain.batch.cardnews.converter.AiResponseConverter;
import com.nova.nova_server.domain.batch.common.entity.AiBatchEntity;
import com.nova.nova_server.domain.batch.common.entity.AiBatchState;
import com.nova.nova_server.domain.batch.common.entity.ArticleEntity;
Expand Down Expand Up @@ -46,8 +45,11 @@ public void processBatchResult(AiBatchEntity entity) {
}

private void onBatchSuccess(String batchId) {
Map<String, String> batchResult = aiBatchService.getResults(batchId);
Map<Long, LlmSummaryResult> summaryResult = AiResponseConverter.fromBatchResult(batchResult);
Map<Long, LlmSummaryResult> summaryResult = aiBatchService.getResults(batchId, LlmSummaryResult.class)
.entrySet().stream().collect(Collectors.toMap(
entry -> Long.parseLong(entry.getKey()),
Map.Entry::getValue
));
Map<Long, ArticleEntity> entities = articleEntityRepository.findAllByIdIn(summaryResult.keySet())
.stream()
.collect(Collectors.toMap(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.nova.nova_server.domain.batch.summary.dto;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;

Expand All @@ -9,7 +10,8 @@
*/
@JsonIgnoreProperties(ignoreUnknown = true)
public record LlmSummaryResult(
String summary,
List<String> evidence,
List<String> keywords) {
@JsonProperty(required = true) String summary,
@JsonProperty(required = true) List<String> evidence,
@JsonProperty(required = true) List<String> keywords
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public void write(Chunk<? extends ArticleEntity> chunk) {
log.info("ArticleSummaryWriter: Processing chunk of {} items", chunk.size());

Map<String, String> prompts = PromptConverter.toPromptMap(chunk.getItems());
String batchId = aiBatchService.createBatch(prompts);
String batchId = aiBatchService.createBatch(prompts, LlmSummaryResult.class);
aiBatchRepository.save(AiBatchEntity.fromBatchId(batchId));
log.info("Batch submitted. BatchId: {}, Count: {}", batchId, chunk.size());

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.nova.nova_server.global.config;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.victools.jsonschema.generator.*;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class JsonSchemaConfig {

@Bean
public SchemaGenerator schemaGenerator(ObjectMapper objectMapper) {
SchemaGeneratorConfig config = new SchemaGeneratorConfigBuilder(
objectMapper,
SchemaVersion.DRAFT_2020_12,
OptionPreset.PLAIN_JSON
)
.with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT)
.with(new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED))
.build();

return new SchemaGenerator(config);
}

}
4 changes: 4 additions & 0 deletions src/test/resources/application-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ oauth:
client-secret: test-secret
redirect-uri: http://localhost:8080/auth/github/callback

batch:
article-ingestion:
cron: "0 0 18 * * ?"

ai:
openai:
key: test-key
Expand Down