diff --git a/data/users/morgan-platinum.txt b/data/users/morgan-platinum.txt new file mode 100644 index 0000000..c5a83db --- /dev/null +++ b/data/users/morgan-platinum.txt @@ -0,0 +1,14 @@ + +Notes about Morgan Platinum: + +- He lives in San Francisco +- He is a software engineer at a tech company +- He doesn't rise early +- He prefers redeyes for flights to the East Coast +- He dislikes transiting in Chicago O'Hare +- He's happy to drive from a more distant airport at his destination +- He's a frequent traveler +- He prefers to fly business class, unless the cost is prohibitive +- He prefers aisle seats +- He is allergic to peanuts +- He is celiac and cannot eat gluten \ No newline at end of file diff --git a/pom.xml b/pom.xml index 2f3e0d5..2c27e2f 100644 --- a/pom.xml +++ b/pom.xml @@ -20,6 +20,7 @@ 21 0.3.5-SNAPSHOT 0.1.0-SNAPSHOT + 0.1.0-SNAPSHOT 24.6.4 @@ -56,6 +57,19 @@ ${embabel-rag-pgvector.version} + + com.embabel.dice + dice + ${dice.version} + + + + + com.embabel + vaadin-components + 0.1.0-SNAPSHOT + + org.springframework.boot diff --git a/src/main/bundles/README.md b/src/main/bundles/README.md new file mode 100644 index 0000000..c9737d1 --- /dev/null +++ b/src/main/bundles/README.md @@ -0,0 +1,32 @@ +This directory is automatically generated by Vaadin and contains the pre-compiled +frontend files/resources for your project (frontend development bundle). + +It should be added to Version Control System and committed, so that other developers +do not have to compile it again. + +Frontend development bundle is automatically updated when needed: +- an npm/pnpm package is added with @NpmPackage or directly into package.json +- CSS, JavaScript or TypeScript files are added with @CssImport, @JsModule or @JavaScript +- Vaadin add-on with front-end customizations is added +- Custom theme imports/assets added into 'theme.json' file +- Exported web component is added. + +If your project development needs a hot deployment of the frontend changes, +you can switch Flow to use Vite development server (default in Vaadin 23.3 and earlier versions): +- set `vaadin.frontend.hotdeploy=true` in `application.properties` +- configure `vaadin-maven-plugin`: +``` + + true + +``` +- configure `jetty-maven-plugin`: +``` + + + true + + +``` + +Read more [about Vaadin development mode](https://vaadin.com/docs/next/flow/configuration/development-mode#precompiled-bundle). \ No newline at end of file diff --git a/src/main/bundles/dev.bundle b/src/main/bundles/dev.bundle new file mode 100644 index 0000000..0583abc Binary files /dev/null and b/src/main/bundles/dev.bundle differ diff --git a/src/main/frontend/themes/embabel-air/styles.css b/src/main/frontend/themes/embabel-air/styles.css index 9a86033..755acae 100644 --- a/src/main/frontend/themes/embabel-air/styles.css +++ b/src/main/frontend/themes/embabel-air/styles.css @@ -34,488 +34,7 @@ --lumo-font-family: system-ui, -apple-system, sans-serif; } -/* Body */ -html { - background: var(--sb-bg-dark); - min-height: 100vh; -} - -body { - background: transparent; -} - -/* Main layout */ -vaadin-vertical-layout { - background: transparent; -} - -/* Header */ -.chat-title { - color: var(--sb-text-primary); - font-weight: 600; - margin: 0; -} - -.chat-subtitle { - color: var(--sb-text-secondary); - font-size: var(--lumo-font-size-s); -} - -/* Chat scroller */ -.chat-scroller { - background: var(--sb-bg-medium); - border: 1px solid var(--sb-border); - border-radius: var(--lumo-border-radius-l); -} - -/* Chat bubbles */ -.chat-bubble-container { - display: flex; - width: 100%; - margin-bottom: var(--lumo-space-s); -} - -.chat-bubble-container.user { - justify-content: flex-end; -} - -.chat-bubble-container.assistant { - justify-content: flex-start; -} - -.chat-bubble { - max-width: 80%; - padding: var(--lumo-space-m); - border-radius: var(--lumo-border-radius-l); -} - -.chat-bubble.user { - background: var(--sb-accent); - color: white; -} - -.chat-bubble.assistant { - background: var(--sb-bg-light); - border: 1px solid var(--sb-border); - color: var(--sb-text-primary); -} - -.chat-bubble-sender { - display: block; - font-size: var(--lumo-font-size-xs); - font-weight: 600; - text-transform: uppercase; - letter-spacing: 0.05em; - margin-bottom: var(--lumo-space-xs); - opacity: 0.8; -} - -.chat-bubble-text { - line-height: 1.5; -} - -/* Markdown rendering in assistant messages */ -.chat-bubble.assistant .chat-bubble-text p { - margin: 0 0 var(--lumo-space-s) 0; -} - -.chat-bubble.assistant .chat-bubble-text p:last-child { - margin-bottom: 0; -} - -.chat-bubble.assistant .chat-bubble-text code { - background: rgba(0, 0, 0, 0.3); - padding: 2px 6px; - border-radius: 4px; - font-family: monospace; - font-size: 0.9em; -} - -.chat-bubble.assistant .chat-bubble-text pre { - background: rgba(0, 0, 0, 0.3); - padding: var(--lumo-space-s); - border-radius: var(--lumo-border-radius-m); - overflow-x: auto; - margin: var(--lumo-space-s) 0; -} - -.chat-bubble.assistant .chat-bubble-text pre code { - background: transparent; - padding: 0; -} - -.chat-bubble.assistant .chat-bubble-text ul, -.chat-bubble.assistant .chat-bubble-text ol { - margin: var(--lumo-space-s) 0; - padding-left: var(--lumo-space-l); -} - -.chat-bubble.assistant .chat-bubble-text blockquote { - border-left: 3px solid var(--sb-accent); - margin: var(--lumo-space-s) 0; - padding-left: var(--lumo-space-m); - color: var(--sb-text-secondary); -} - -/* Error message */ -.chat-bubble-error { - background: rgba(239, 68, 68, 0.2); - border: 1px solid var(--sb-error); - color: var(--sb-error); - padding: var(--lumo-space-m); - border-radius: var(--lumo-border-radius-m); - text-align: center; -} - -/* Text input */ -vaadin-text-field::part(input-field) { - background: var(--sb-bg-light); - border: 1px solid var(--sb-border); - border-radius: var(--lumo-border-radius-m); -} - -vaadin-text-field:focus-within::part(input-field) { - border-color: var(--sb-accent); -} - -/* Send button */ -vaadin-button[theme~="primary"] { - background: var(--sb-accent); - color: white; - border-radius: var(--lumo-border-radius-m); -} - -vaadin-button[theme~="primary"]:hover { - background: var(--sb-accent-light); -} - -/* Scrollbar styling */ -::-webkit-scrollbar { - width: 8px; - height: 8px; -} - -::-webkit-scrollbar-track { - background: transparent; -} - -::-webkit-scrollbar-thumb { - background: var(--sb-border); - border-radius: 4px; -} - -::-webkit-scrollbar-thumb:hover { - background: var(--sb-text-muted); -} - -/* Footer */ -.app-footer { - padding-top: var(--lumo-space-s); - border-top: 1px solid var(--sb-border); - margin-top: var(--lumo-space-s); -} - -.footer-copyright, -.footer-stats { - color: var(--sb-text-muted); - font-size: var(--lumo-font-size-xs); -} - -.footer-separator { - color: var(--sb-text-muted); - opacity: 0.5; -} - -/* Side Panel / Drawer */ -.side-panel { - position: fixed; - top: 0; - right: 0; - bottom: 0; - width: 400px; - max-width: 90vw; - background: linear-gradient(180deg, var(--sb-bg-medium) 0%, var(--sb-bg-dark) 100%); - border-left: 1px solid var(--sb-border); - box-shadow: -8px 0 32px rgba(0, 0, 0, 0.5); - transform: translateX(100%); - transition: transform 0.3s ease; - z-index: 200; - display: flex; - flex-direction: column; -} - -.side-panel.open { - transform: translateX(0); -} - -.side-panel-header { - padding: var(--lumo-space-m); - border-bottom: 1px solid var(--sb-border); - align-items: center; -} - -.side-panel-title { - font-size: var(--lumo-font-size-l); - font-weight: 600; - color: var(--sb-text-primary); -} - -.side-panel-close { - background: transparent; - border: none; - color: var(--sb-text-secondary); - cursor: pointer; - padding: var(--lumo-space-xs); -} - -.side-panel-close:hover { - color: var(--sb-text-primary); -} - -.side-panel-content { - flex: 1; - overflow-y: auto; - padding: var(--lumo-space-m); -} - -/* Side Panel Toggle Button */ -.side-panel-toggle { - position: fixed; - right: var(--lumo-space-m); - top: 50%; - transform: translateY(-50%); - z-index: 150; - background: linear-gradient(135deg, var(--sb-bg-light) 0%, var(--sb-bg-medium) 100%); - border: 1px solid var(--sb-border); - border-radius: var(--lumo-border-radius-m) 0 0 var(--lumo-border-radius-m); - color: var(--sb-accent); - padding: var(--lumo-space-s) var(--lumo-space-xs); - cursor: pointer; - transition: all 0.2s ease; -} - -.side-panel-toggle:hover { - background: var(--sb-bg-light); - border-color: var(--sb-accent); -} - -.side-panel-toggle.hidden { - opacity: 0; - pointer-events: none; -} - -/* Side Panel Backdrop */ -.side-panel-backdrop { - position: fixed; - top: 0; - left: 0; - right: 0; - bottom: 0; - background: rgba(0, 0, 0, 0.3); - z-index: 190; - opacity: 0; - pointer-events: none; - transition: opacity 0.3s ease; -} - -.side-panel-backdrop.visible { - opacity: 1; - pointer-events: auto; -} - -/* Section Styles */ -.section-instructions { - color: var(--sb-text-secondary); - font-size: var(--lumo-font-size-s); - margin-bottom: var(--lumo-space-m); - display: block; -} - -.section-title { - color: var(--sb-text-primary); - margin: 0 0 var(--lumo-space-m) 0; - font-size: var(--lumo-font-size-l); -} - -/* Stats Container */ -.stats-container { - background: var(--sb-bg-light); - border: 1px solid var(--sb-border); - border-radius: var(--lumo-border-radius-m); - padding: var(--lumo-space-m); -} - -.stat-row { - display: flex; - justify-content: space-between; - align-items: center; - gap: var(--lumo-space-l); - padding: var(--lumo-space-s) 0; - border-bottom: 1px solid var(--sb-border); -} - -.stat-row .stat-label { - flex-shrink: 0; -} - -.stat-row .stat-value { - text-align: right; -} - -.stat-row:last-child { - border-bottom: none; -} - -.stat-label { - color: var(--sb-text-secondary); -} - -.stat-value { - color: var(--sb-text-primary); - font-weight: 600; -} - -/* Tool Call Indicator */ -.tool-call-indicator { - color: var(--sb-text-muted); - font-size: var(--lumo-font-size-s); - padding: var(--lumo-space-xs) var(--lumo-space-m); - background: var(--sb-bg-light); - border-left: 3px solid var(--sb-accent); - border-radius: var(--lumo-border-radius-s); - margin: var(--lumo-space-xs) 0; - font-family: monospace; - animation: pulse 1.5s ease-in-out infinite; -} - -@keyframes pulse { - 0%, 100% { opacity: 0.7; } - 50% { opacity: 1; } -} - -/* Login View */ -.login-view { - background: var(--sb-bg-dark); -} - -.login-title { - color: var(--sb-text-primary); - font-weight: 600; - margin-bottom: 0; -} - -.login-subtitle { - color: var(--sb-text-secondary); - font-size: var(--lumo-font-size-s); - margin-bottom: var(--lumo-space-l); -} - -/* User Section */ -.profile-chip { - background: var(--sb-bg-light); - border: 1px solid var(--sb-border); - border-radius: 20px; - padding: var(--lumo-space-xs) var(--lumo-space-m) var(--lumo-space-xs) var(--lumo-space-xs); - gap: var(--lumo-space-s); -} - -.user-avatar { - width: 28px; - height: 28px; - border-radius: 50%; - background: var(--sb-accent); - color: white; - display: flex; - align-items: center; - justify-content: center; - font-size: var(--lumo-font-size-xs); - font-weight: 600; -} - -.user-name { - color: var(--sb-text-primary); - font-size: var(--lumo-font-size-s); -} - -.logout-button { - color: var(--sb-text-secondary); - font-size: var(--lumo-font-size-s); -} - -.logout-button:hover { - color: var(--sb-text-primary); -} - -/* Context Selector */ -.context-select { - --vaadin-select-text-field-width: auto; - min-width: 100px; -} - -.context-select::part(input-field) { - background: var(--sb-bg-light); - border: 1px solid var(--sb-border); - border-radius: var(--lumo-border-radius-m); - padding: 0 var(--lumo-space-s); - min-height: 32px; -} - -.context-select::part(toggle-button) { - color: var(--sb-text-secondary); -} - -.context-select:hover::part(input-field) { - border-color: var(--sb-accent); -} - -/* Documents List */ -.documents-list { - max-height: 400px; - overflow-y: auto; -} - -.document-row { - padding: var(--lumo-space-s) var(--lumo-space-m); - border-bottom: 1px solid var(--sb-border); -} - -.document-row:last-child { - border-bottom: none; -} - -.document-row:hover { - background: var(--sb-bg-light); -} - -.document-title { - color: var(--sb-text-primary); - font-size: var(--lumo-font-size-s); - white-space: nowrap; - overflow: hidden; - text-overflow: ellipsis; - max-width: 280px; - display: block; -} - -.context-badge { - display: inline-block; - background: var(--sb-accent); - color: white; - font-size: var(--lumo-font-size-xs); - padding: 2px 8px; - border-radius: 10px; - margin-top: var(--lumo-space-xs); -} - -.empty-list-label { - color: var(--sb-text-muted); - font-size: var(--lumo-font-size-s); - padding: var(--lumo-space-m); - text-align: center; - display: block; -} - -/* Session Panel Container */ +/* Session Panel — embabel-air-specific */ .session-panel-container { position: fixed; top: 0; @@ -525,48 +44,6 @@ vaadin-button[theme~="primary"]:hover { z-index: 185; } -/* Session Panel specific styles */ .session-panel { background: linear-gradient(180deg, var(--sb-bg-medium) 0%, var(--sb-bg-dark) 100%); } - -/* Vaadin Tabs in side panel */ -.side-panel vaadin-tabs { - background: var(--sb-bg-light); - border-bottom: 1px solid var(--sb-border); -} - -.side-panel vaadin-tab { - color: var(--sb-text-secondary); - padding: var(--lumo-space-s) var(--lumo-space-m); -} - -.side-panel vaadin-tab[selected] { - color: var(--sb-accent); -} - -.side-panel vaadin-tab:hover { - color: var(--sb-text-primary); -} - -/* Asset card styles */ -.asset-card { - background: var(--sb-bg-light); - border: 1px solid var(--sb-border); - border-left: 3px solid var(--sb-accent); - border-radius: var(--lumo-border-radius-m); - padding: var(--lumo-space-s); - margin-bottom: var(--lumo-space-s); -} - -/* Mobile Responsive */ -@media (max-width: 768px) { - .side-panel { - width: 100%; - max-width: 100%; - } - - .side-panel-toggle { - right: var(--lumo-space-s); - } -} diff --git a/src/main/frontend/themes/embabel-air/theme.json b/src/main/frontend/themes/embabel-air/theme.json new file mode 100644 index 0000000..2d57719 --- /dev/null +++ b/src/main/frontend/themes/embabel-air/theme.json @@ -0,0 +1,4 @@ +{ + "parent": "embabel-base", + "lumoImports": ["typography", "color", "spacing", "badge", "utility"] +} diff --git a/src/main/java/com/embabel/air/ai/AirProperties.java b/src/main/java/com/embabel/air/ai/AirProperties.java index 563b6d6..0de2d62 100644 --- a/src/main/java/com/embabel/air/ai/AirProperties.java +++ b/src/main/java/com/embabel/air/ai/AirProperties.java @@ -2,6 +2,7 @@ import com.embabel.agent.rag.ingestion.ContentChunker; import com.embabel.common.ai.model.LlmOptions; +import com.embabel.dice.proposition.extraction.PropositionExtractionProperties; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -10,13 +11,15 @@ * * @param chatLlm LLM model and hyperparameters to use * @param chunkerConfig configuration for ingestion + * @param memory memory extraction properties */ @ConfigurationProperties(prefix = "embabel-air") public record AirProperties( boolean showChatPrompts, @NestedConfigurationProperty LlmOptions chatLlm, @NestedConfigurationProperty LlmOptions triageLlm, - @NestedConfigurationProperty ContentChunker.Config chunkerConfig + @NestedConfigurationProperty ContentChunker.Config chunkerConfig, + @NestedConfigurationProperty PropositionExtractionProperties memory ) { } diff --git a/src/main/java/com/embabel/air/ai/agent/ChatActions.java b/src/main/java/com/embabel/air/ai/agent/ChatActions.java index 11292cc..1bb5a04 100644 --- a/src/main/java/com/embabel/air/ai/agent/ChatActions.java +++ b/src/main/java/com/embabel/air/ai/agent/ChatActions.java @@ -5,17 +5,29 @@ import com.embabel.agent.api.tool.Tool; import com.embabel.air.ai.AirProperties; import com.embabel.air.ai.rag.RagConfiguration.AirlinePolicies; +import com.embabel.air.backend.BookingService; import com.embabel.air.backend.Customer; import com.embabel.air.backend.Reservation; import com.embabel.air.backend.ReservationRepository; import com.embabel.chat.AssistantMessage; import com.embabel.chat.Conversation; +import com.embabel.chat.Message; +import com.embabel.dice.agent.Memory; +import com.embabel.dice.common.ConversationAnalysisRequestEvent; +import com.embabel.dice.projection.memory.MemoryProjector; +import com.embabel.dice.proposition.PropositionRepository; import com.embabel.springdata.EntityViewService; +import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.context.ApplicationEventPublisher; +import java.time.LocalDate; +import java.util.Arrays; +import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; /** * The platform can use any action to respond to user messages. @@ -23,8 +35,19 @@ @EmbabelComponent public class ChatActions { - private final static Logger logger = LoggerFactory.getLogger(ChatActions.class); + private static final Logger logger = LoggerFactory.getLogger(ChatActions.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + private final ApplicationEventPublisher eventPublisher; + private final AirProperties airProperties; + private final BookingService bookingService; + + public ChatActions(ApplicationEventPublisher eventPublisher, AirProperties airProperties, + BookingService bookingService) { + this.eventPublisher = eventPublisher; + this.airProperties = airProperties; + this.bookingService = bookingService; + } /** * Condition: true if the last message in the conversation is from the user, @@ -60,7 +83,7 @@ ChitchatState greetCustomer( } else { logger.warn("greetCustomer: forUser is not a Customer: {}", forUser); } - return new ChitchatState(); + return new ChitchatState(eventPublisher, airProperties, bookingService); } /** @@ -69,6 +92,17 @@ ChitchatState greetCustomer( @State static class ChitchatState implements AirState { + private final ApplicationEventPublisher eventPublisher; + private final AirProperties airProperties; + private final BookingService bookingService; + + ChitchatState(ApplicationEventPublisher eventPublisher, AirProperties airProperties, + BookingService bookingService) { + this.eventPublisher = eventPublisher; + this.airProperties = airProperties; + this.bookingService = bookingService; + } + @Action( pre = "shouldRespond", canRerun = true @@ -79,32 +113,41 @@ AirState respond( ActionContext context, @Provided AirProperties properties, @Provided AirlinePolicies airlinePolicies, - @Provided EntityViewService entityViewService) { + @Provided EntityViewService entityViewService, + @Provided PropositionRepository propositionRepository, + @Provided MemoryProjector memoryProjector) { - var assets = conversation.getAssetTracker().mostRecentlyAdded(1).references(); + var references = new LinkedList<>( + conversation.getAssetTracker().mostRecentlyAdded(1).references()); + references.add(airlinePolicies.reference()); + references.add(entityViewService.entityReferenceFor(customer)); + if (properties.memory().getEnabled()) { + String recentContext = conversation.getMessages().stream() + // TODO fix hardcoding + .skip(Math.max(0, conversation.getMessages().size() - 15)) + .map(Message::getContent) + .collect(Collectors.joining("\n")); + references.add( + Memory.forContext(customer.getId()) + .withRepository(propositionRepository) + .withProjector(memoryProjector) + .withEagerSearchAbout(recentContext, properties.memory().getExistingPropositionsToShow()) + ); + } + + var tools = new LinkedList<>(commonTools()); + tools.addAll(bookingTools(bookingService, customer)); + tools.addAll(conversation.getAssetTracker().addAnyReturnedAssets( + entityViewService.repositoryToolsFor(ReservationRepository.class))); + tools.addAll(conversation.getAssetTracker().addAnyReturnedAssets( + List.of(entityViewService.finderFor(Reservation.class)))); var assistantMessage = context. ai() .withLlm(properties.chatLlm()) .withId("chitchat.respond") - .withReferences( - airlinePolicies.reference(), - entityViewService.entityReferenceFor(customer) - ) - .withReferences(assets) - .withTools(commonTools()) - .withTools( - conversation.getAssetTracker().addAnyReturnedAssets( - entityViewService.repositoryToolsFor(ReservationRepository.class) - ) - ) - .withTools( - conversation.getAssetTracker().addAnyReturnedAssets( - List.of( - entityViewService.finderFor(Reservation.class) - ) - ) - ) + .withReferences(references) + .withTools(tools) .rendering("air") .respondWithSystemPrompt( conversation, @@ -112,6 +155,12 @@ AirState respond( "properties", properties )); context.sendAndSave(assistantMessage); + + if (airProperties.memory() != null && airProperties.memory().getEnabled()) { + eventPublisher.publishEvent( + new ConversationAnalysisRequestEvent(this, customer, conversation)); + } + return this; } } @@ -185,4 +234,135 @@ private static List commonTools() { completionTool() ); } + + private static Tool searchFlightsTool(BookingService bookingService) { + var description = """ + Search for available flights between two airports on a given date. + Returns direct and one-stop connecting itineraries with prices. + Input JSON: {"fromAirport": "JFK", "toAirport": "LAX", "date": "2026-03-15"} + - fromAirport: 3-letter IATA airport code for departure + - toAirport: 3-letter IATA airport code for arrival + - date: departure date in YYYY-MM-DD format + """; + return Tool.create("search_flights", description, input -> { + try { + var node = objectMapper.readTree(input); + var from = node.get("fromAirport").asText(); + var to = node.get("toAirport").asText(); + var date = LocalDate.parse(node.get("date").asText()); + + var itineraries = bookingService.searchRoutes(from, to, date, 1); + if (itineraries.isEmpty()) { + return Tool.Result.text("No flights found from %s to %s on %s.".formatted(from, to, date)); + } + + var sb = new StringBuilder(); + sb.append("Found %d itinerary(ies) from %s to %s on %s:\n\n".formatted( + itineraries.size(), from, to, date)); + for (int i = 0; i < itineraries.size(); i++) { + sb.append("Option %d: %s\n\n".formatted(i + 1, itineraries.get(i).summary())); + } + return Tool.Result.text(sb.toString()); + } catch (Exception e) { + logger.error("search_flights error", e); + return Tool.Result.text("Error searching flights: " + e.getMessage()); + } + }); + } + + private static Tool bookFlightTool(BookingService bookingService, Customer customer) { + var description = """ + Book flights for the current customer using flight segment IDs from search results. + Input JSON: {"flightSegmentIds": "id1,id2"} + - flightSegmentIds: comma-separated flight segment IDs to book + """; + return Tool.create("book_flight", description, input -> { + try { + var node = objectMapper.readTree(input); + var idsRaw = node.get("flightSegmentIds").asText(); + var ids = Arrays.stream(idsRaw.split(",")) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .toList(); + + var reservation = bookingService.book(customer, ids); + var segments = reservation.getFlightSegments(); + var sb = new StringBuilder(); + sb.append("Booking confirmed! Reference: %s\n".formatted(reservation.getBookingReference())); + sb.append("Total: $%s\n".formatted(reservation.getPaidAmount())); + for (var seg : segments) { + sb.append(" %s → %s departing %s arriving %s\n".formatted( + seg.getDepartureAirportCode(), seg.getArrivalAirportCode(), + seg.getDepartureDateTime(), seg.getArrivalDateTime())); + } + return Tool.Result.text(sb.toString()); + } catch (Exception e) { + logger.error("book_flight error", e); + return Tool.Result.text("Error booking flight: " + e.getMessage()); + } + }); + } + + private static Tool cancelBookingTool(BookingService bookingService) { + var description = """ + Cancel an existing reservation by its booking reference. + Input JSON: {"bookingReference": "ABC123"} + - bookingReference: the 6-character booking reference code + """; + return Tool.create("cancel_booking", description, input -> { + try { + var node = objectMapper.readTree(input); + var ref = node.get("bookingReference").asText(); + bookingService.cancel(ref); + return Tool.Result.text("Reservation %s has been cancelled successfully.".formatted(ref)); + } catch (Exception e) { + logger.error("cancel_booking error", e); + return Tool.Result.text("Error cancelling booking: " + e.getMessage()); + } + }); + } + + private static Tool rebookFlightTool(BookingService bookingService) { + var description = """ + Rebook an existing reservation onto different flights. + Cancels the old flight segments and books new ones, preserving the booking reference. + Input JSON: {"bookingReference": "ABC123", "newFlightSegmentIds": "id1,id2"} + - bookingReference: the 6-character booking reference to rebook + - newFlightSegmentIds: comma-separated new flight segment IDs + """; + return Tool.create("rebook_flight", description, input -> { + try { + var node = objectMapper.readTree(input); + var ref = node.get("bookingReference").asText(); + var idsRaw = node.get("newFlightSegmentIds").asText(); + var ids = Arrays.stream(idsRaw.split(",")) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .toList(); + + var reservation = bookingService.rebook(ref, ids); + var sb = new StringBuilder(); + sb.append("Rebooked! Reference: %s\n".formatted(reservation.getBookingReference())); + sb.append("New total: $%s\n".formatted(reservation.getPaidAmount())); + for (var seg : reservation.getFlightSegments()) { + sb.append(" %s → %s departing %s arriving %s\n".formatted( + seg.getDepartureAirportCode(), seg.getArrivalAirportCode(), + seg.getDepartureDateTime(), seg.getArrivalDateTime())); + } + return Tool.Result.text(sb.toString()); + } catch (Exception e) { + logger.error("rebook_flight error", e); + return Tool.Result.text("Error rebooking: " + e.getMessage()); + } + }); + } + + private static List bookingTools(BookingService bookingService, Customer customer) { + return List.of( + searchFlightsTool(bookingService), + bookFlightTool(bookingService, customer), + cancelBookingTool(bookingService), + rebookFlightTool(bookingService) + ); + } } \ No newline at end of file diff --git a/src/main/java/com/embabel/air/ai/memory/MemoryConfiguration.java b/src/main/java/com/embabel/air/ai/memory/MemoryConfiguration.java new file mode 100644 index 0000000..3ec4b57 --- /dev/null +++ b/src/main/java/com/embabel/air/ai/memory/MemoryConfiguration.java @@ -0,0 +1,195 @@ +package com.embabel.air.ai.memory; + +import com.embabel.agent.api.common.Ai; +import com.embabel.agent.core.DataDictionary; +import com.embabel.agent.rag.model.NamedEntity; +import com.embabel.agent.rag.pgvector.JdbcNamedEntityDataRepository; +import com.embabel.agent.rag.pgvector.NativeEntityLookup; +import com.embabel.agent.rag.service.NamedEntityDataRepository; +import com.embabel.air.ai.AirProperties; +import com.embabel.air.backend.City; +import com.embabel.air.backend.Country; +import com.embabel.air.backend.Customer; +import com.embabel.air.backend.CustomerRepository; +import com.embabel.dice.common.EntityResolver; +import com.embabel.dice.common.KnowledgeType; +import com.embabel.dice.common.Relations; +import com.embabel.dice.common.resolver.BakeoffPromptStrategies; +import com.embabel.dice.common.resolver.EscalatingEntityResolver; +import com.embabel.dice.common.resolver.LlmCandidateBakeoff; +import com.embabel.dice.incremental.ChunkHistoryStore; +import com.embabel.dice.incremental.InMemoryChunkHistoryStore; +import com.embabel.dice.pipeline.PropositionPipeline; +import com.embabel.dice.projection.graph.*; +import com.embabel.dice.projection.memory.MemoryProjector; +import com.embabel.dice.projection.memory.support.DefaultMemoryProjector; +import com.embabel.dice.projection.memory.support.RelationBasedKnowledgeTypeClassifier; +import com.embabel.dice.proposition.PropositionRepository; +import com.embabel.dice.proposition.extraction.IncrementalPropositionExtraction; +import com.embabel.dice.proposition.extraction.LlmPropositionExtractor; +import com.embabel.dice.proposition.jdbc.JdbcPropositionRepository; +import com.embabel.dice.proposition.revision.LlmPropositionReviser; +import com.embabel.dice.proposition.revision.PropositionReviser; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.simple.JdbcClient; +import org.springframework.scheduling.annotation.EnableAsync; + +/** + * Assembles the full DICE memory pipeline for Embabel Air: + * proposition extraction, entity resolution, graph projection, and memory recall. + */ +@Configuration +@EnableAsync +public class MemoryConfiguration { + + @Bean + @Primary + DataDictionary airSchema() { + return DataDictionary.fromClasses("embabel-air", Customer.class) + .plus(NamedEntity.dataDictionaryFromPackages("com.embabel.air.backend")); + } + + @Bean + Relations airRelations() { + return Relations.empty() + .withPredicatesForSubject(Customer.class, KnowledgeType.SEMANTIC, + "prefers", "likes", "dislikes", "lives_in", "travels_to", "is_member_of") + .withPredicatesForSubject(City.class, KnowledgeType.SEMANTIC, + "is_in", "has_airport") + .withPredicatesForSubject(Country.class, KnowledgeType.SEMANTIC, + "contains"); + } + + @Bean + JdbcPropositionRepository jpaPropositionRepository(JdbcClient jdbcClient, Ai ai) { + return new JdbcPropositionRepository(jdbcClient, ai.withDefaultEmbeddingService()); + } + + @Bean + NamedEntityDataRepository namedEntityDataRepository( + JdbcClient jdbcClient, Ai ai, DataDictionary dataDictionary, + CustomerRepository customerRepository) { + return JdbcNamedEntityDataRepository.builder() + .withJdbcClient(jdbcClient) + .withDataDictionary(dataDictionary) + .withEmbeddingService(ai.withDefaultEmbeddingService()) + .withNativeLookup(Customer.class, new NativeEntityLookup<>() { + @Override + public Customer findById(String id) { + return customerRepository.findById(id).orElse(null); + } + + @Override + public java.util.List findAll() { + return customerRepository.findAll(); + } + }) + .build(); + } + + @Bean + EntityResolver entityResolver(NamedEntityDataRepository repository, Ai ai, AirProperties properties) { + var bakeoff = LlmCandidateBakeoff + .withLlm(properties.memory() != null && properties.memory().getExtractionLlm() != null + ? properties.memory().getExtractionLlm() + : properties.chatLlm()) + .withAi(ai) + .withPromptStrategy(BakeoffPromptStrategies.FULL); + return EscalatingEntityResolver.create(repository, bakeoff); + } + + @Bean + LlmPropositionExtractor llmPropositionExtractor(Ai ai, AirProperties properties) { + var llm = properties.memory() != null && properties.memory().getExtractionLlm() != null + ? properties.memory().getExtractionLlm() + : properties.chatLlm(); + return LlmPropositionExtractor + .withLlm(llm) + .withAi(ai) + .withTemplate("dice/extract_air_propositions"); + } + + @Bean + PropositionReviser propositionReviser(Ai ai, AirProperties properties) { + var llm = properties.memory() != null && properties.memory().getExtractionLlm() != null + ? properties.memory().getExtractionLlm() + : properties.chatLlm(); + return LlmPropositionReviser + .withLlm(llm) + .withAi(ai); + } + + @Bean + PropositionPipeline propositionPipeline( + LlmPropositionExtractor extractor, + PropositionReviser reviser, + JdbcPropositionRepository repository) { + return PropositionPipeline + .withExtractor(extractor) + .withRevision(reviser, repository); + } + + @Bean + GraphProjector graphProjector(Relations relations, Ai ai, AirProperties properties) { + var llm = properties.memory() != null && properties.memory().getExtractionLlm() != null + ? properties.memory().getExtractionLlm() + : properties.chatLlm(); + return LlmGraphProjector + .withLlm(llm) + .withAi(ai) + .withRelations(relations) + .withLenientPolicy(); + } + + @Bean + GraphRelationshipPersister graphRelationshipPersister(NamedEntityDataRepository repository) { + return new NamedEntityDataRepositoryGraphRelationshipPersister(repository); + } + + @Bean + GraphProjectionService graphProjectionService( + GraphProjector graphProjector, + GraphRelationshipPersister persister, + DataDictionary dataDictionary) { + return GraphProjectionService.create(graphProjector, persister, dataDictionary); + } + + @Bean + MemoryProjector memoryProjector(Relations relations) { + return DefaultMemoryProjector + .withKnowledgeTypeClassifier(new RelationBasedKnowledgeTypeClassifier(relations)); + } + + @Bean + ChunkHistoryStore chunkHistoryStore() { + return new InMemoryChunkHistoryStore(); + } + + @Bean + IncrementalPropositionExtraction incrementalPropositionExtraction( + PropositionPipeline pipeline, + ChunkHistoryStore chunkHistoryStore, + DataDictionary dataDictionary, + Relations relations, + PropositionRepository propositionRepository, + NamedEntityDataRepository entityRepository, + EntityResolver entityResolver, + GraphProjectionService graphProjectionService, + AirProperties properties) { + return new IncrementalPropositionExtraction( + pipeline, + chunkHistoryStore, + dataDictionary, + relations, + propositionRepository, + entityRepository, + entityResolver, + graphProjectionService, + properties.memory(), + user -> user.getId(), + user -> java.util.Map.of("customer", user) + ); + } +} diff --git a/src/main/java/com/embabel/air/ai/rag/DocumentService.java b/src/main/java/com/embabel/air/ai/rag/DocumentService.java index 78cecaa..e6b9aba 100644 --- a/src/main/java/com/embabel/air/ai/rag/DocumentService.java +++ b/src/main/java/com/embabel/air/ai/rag/DocumentService.java @@ -5,31 +5,26 @@ import com.embabel.agent.rag.model.ContentRoot; import com.embabel.agent.rag.model.NavigableDocument; import com.embabel.agent.rag.store.ChunkingContentElementRepository; -import io.vavr.collection.List; +import com.embabel.vaadin.document.DocumentInfoProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; -import java.time.Instant; +import java.util.List; import java.util.Map; +import java.util.stream.StreamSupport; /** * Service for managing document retrieval. */ @Service -public class DocumentService { +public class DocumentService implements DocumentInfoProvider { private static final Logger logger = LoggerFactory.getLogger(DocumentService.class); private final HierarchicalContentReader contentReader = new TikaHierarchicalContentReader(); private final ChunkingContentElementRepository contentElementRepository; - /** - * Summary info about an ingested document. - */ - public record DocumentInfo(String uri, String title, String context, Instant ingestedAt) { - } - public DocumentService( ChunkingContentElementRepository chunkingContentElementRepository) { this.contentElementRepository = chunkingContentElementRepository; @@ -43,17 +38,17 @@ public NavigableDocument ingestUrl(String url) { return document; } - /** - * Get list of all ingested documents. - */ - public List getDocuments() { - return List.ofAll(contentElementRepository.findAll(ContentRoot.class)) - .map(doc -> new DocumentInfo( + @Override + public List getDocuments() { + return StreamSupport.stream(contentElementRepository.findAll(ContentRoot.class).spliterator(), false) + .map(doc -> new DocumentInfoProvider.DocumentInfo( doc.getUri(), doc.getTitle(), extractContext(doc.getMetadata()), + 0, doc.getIngestionTimestamp() - )); + )) + .toList(); } private String extractContext(Map metadata) { @@ -64,25 +59,19 @@ private String extractContext(Map metadata) { return context != null ? context.toString() : ""; } - /** - * Delete a document by its URI. - */ + @Override public boolean deleteDocument(String uri) { logger.info("Deleting document: {}", uri); var result = contentElementRepository.deleteRootAndDescendants(uri); return result != null; } - /** - * Get total document count. - */ + @Override public int getDocumentCount() { return contentElementRepository.info().getDocumentCount(); } - /** - * Get total chunk count. - */ + @Override public int getChunkCount() { return contentElementRepository.info().getChunkCount(); } diff --git a/src/main/java/com/embabel/air/ai/rag/RagConfiguration.java b/src/main/java/com/embabel/air/ai/rag/RagConfiguration.java index 2539158..f10ff74 100644 --- a/src/main/java/com/embabel/air/ai/rag/RagConfiguration.java +++ b/src/main/java/com/embabel/air/ai/rag/RagConfiguration.java @@ -5,7 +5,6 @@ import com.embabel.agent.rag.ingestion.transform.AddTitlesChunkTransformer; import com.embabel.agent.rag.pgvector.PgVectorStore; import com.embabel.agent.rag.pgvector.PgVectorStoreBuilder; -import com.embabel.agent.rag.service.SearchOperations; import com.embabel.agent.rag.tools.ToolishRag; import com.embabel.air.ai.AirProperties; import org.slf4j.Logger; @@ -39,7 +38,7 @@ PgVectorStore pgVectorStore( } @Bean - AirlinePolicies airlinePolicies(SearchOperations searchOperations) { + AirlinePolicies airlinePolicies(PgVectorStore searchOperations) { return new AirlinePolicies( new ToolishRag("policies", "Embabel Air policies", searchOperations) .asMatryoshka() diff --git a/src/main/java/com/embabel/air/ai/vaadin/ChatMessageBubble.java b/src/main/java/com/embabel/air/ai/vaadin/ChatMessageBubble.java deleted file mode 100644 index b26bb66..0000000 --- a/src/main/java/com/embabel/air/ai/vaadin/ChatMessageBubble.java +++ /dev/null @@ -1,66 +0,0 @@ -package com.embabel.air.ai.vaadin; - -import com.vaadin.flow.component.Html; -import com.vaadin.flow.component.html.Div; -import com.vaadin.flow.component.html.Span; -import org.commonmark.parser.Parser; -import org.commonmark.renderer.html.HtmlRenderer; - -/** - * Chat message bubble component with sender name and text content. - * Styled differently for user vs assistant messages. - * Assistant messages render markdown as HTML. - */ -public class ChatMessageBubble extends Div { - - private static final Parser MARKDOWN_PARSER = Parser.builder().build(); - private static final HtmlRenderer HTML_RENDERER = HtmlRenderer.builder().build(); - - public ChatMessageBubble(String sender, String text, boolean isUser) { - addClassName("chat-bubble-container"); - addClassName(isUser ? "user" : "assistant"); - - var messageDiv = new Div(); - messageDiv.addClassName("chat-bubble"); - messageDiv.addClassName(isUser ? "user" : "assistant"); - - var senderSpan = new Span(sender); - senderSpan.addClassName("chat-bubble-sender"); - - if (isUser) { - var textSpan = new Span(text); - textSpan.addClassName("chat-bubble-text"); - messageDiv.add(senderSpan, textSpan); - } else { - var contentDiv = new Div(); - contentDiv.addClassName("chat-bubble-text"); - contentDiv.add(new Html("
" + renderMarkdown(text) + "
")); - messageDiv.add(senderSpan, contentDiv); - } - - add(messageDiv); - } - - private static String renderMarkdown(String markdown) { - if (markdown == null || markdown.isBlank()) { - return ""; - } - var document = MARKDOWN_PARSER.parse(markdown.strip()); - return HTML_RENDERER.render(document).strip(); - } - - public static ChatMessageBubble user(String text) { - return new ChatMessageBubble("You", text, true); - } - - public static ChatMessageBubble assistant(String persona, String text) { - return new ChatMessageBubble(persona, text, false); - } - - public static Div error(String text) { - var messageDiv = new Div(); - messageDiv.addClassName("chat-bubble-error"); - messageDiv.setText(text); - return messageDiv; - } -} diff --git a/src/main/java/com/embabel/air/ai/vaadin/ChatView.java b/src/main/java/com/embabel/air/ai/vaadin/ChatView.java index a1af276..1f855a0 100644 --- a/src/main/java/com/embabel/air/ai/vaadin/ChatView.java +++ b/src/main/java/com/embabel/air/ai/vaadin/ChatView.java @@ -13,6 +13,11 @@ import com.embabel.chat.ChatSession; import com.embabel.chat.Chatbot; import com.embabel.chat.UserMessage; +import com.embabel.dice.proposition.PropositionRepository; +import com.embabel.dice.proposition.extraction.IncrementalPropositionExtraction; +import com.embabel.vaadin.component.ChatMessageBubble; +import com.embabel.vaadin.component.Footer; +import com.embabel.vaadin.component.UserSection; import com.vaadin.flow.component.Key; import com.vaadin.flow.component.UI; import com.vaadin.flow.component.button.Button; @@ -45,12 +50,15 @@ public class ChatView extends VerticalLayout { private static final Logger logger = LoggerFactory.getLogger(ChatView.class); + private static final String SESSION_DATA_KEY = "sessionData"; + private final Chatbot chatbot; private final String persona; private final AirProperties properties; private final DocumentService documentService; private final Customer currentUser; private final AgentPlatform agentPlatform; + private final PropositionRepository propositionRepository; private VerticalLayout messagesLayout; private Scroller messagesScroller; @@ -59,13 +67,19 @@ public class ChatView extends VerticalLayout { private Footer footer; private SessionPanel sessionPanel; + private final IncrementalPropositionExtraction propositionExtraction; + public ChatView(Chatbot chatbot, AirProperties properties, DocumentService documentService, - CustomerService userService, AgentPlatform agentPlatform) { + CustomerService userService, AgentPlatform agentPlatform, + PropositionRepository propositionRepository, + IncrementalPropositionExtraction propositionExtraction) { this.chatbot = chatbot; this.properties = properties; this.documentService = documentService; this.currentUser = userService.getAuthenticatedUser(); this.agentPlatform = agentPlatform; + this.propositionRepository = propositionRepository; + this.propositionExtraction = propositionExtraction; this.persona = "Emmie"; setSizeFull(); @@ -86,11 +100,16 @@ public ChatView(Chatbot chatbot, AirProperties properties, DocumentService docum // User section (right) - clicking opens session panel var userSection = new UserSection(currentUser, this::toggleSessionPanel); - headerRow.add(headerImage, userSection); + var logoutButton = new Button("Logout", e -> getUI().ifPresent(ui -> ui.getPage().setLocation("/logout"))); + logoutButton.addThemeVariants(ButtonVariant.LUMO_TERTIARY, ButtonVariant.LUMO_SMALL); + logoutButton.addClassName("logout-button"); + var userArea = new HorizontalLayout(userSection, logoutButton); + userArea.setAlignItems(Alignment.CENTER); + headerRow.add(headerImage, userArea); add(headerRow); // Session panel (drawer from right) - sessionPanel = new SessionPanel(currentUser, this::getCurrentSession, agentPlatform); + sessionPanel = new SessionPanel(currentUser, this::getCurrentSession, agentPlatform, propositionRepository, propositionExtraction); getElement().appendChild(sessionPanel.getElement()); // Messages container @@ -106,11 +125,14 @@ public ChatView(Chatbot chatbot, AirProperties properties, DocumentService docum add(messagesScroller); setFlexGrow(1, messagesScroller); + // Restore previous messages if session exists + restorePreviousMessages(); + // Input section add(createInputSection()); // Footer - footer = new Footer(documentService.getDocumentCount(), documentService.getChunkCount()); + footer = new Footer(documentService.getDocumentCount() + " documents \u00b7 " + documentService.getChunkCount() + " chunks"); add(footer); // Documents drawer @@ -120,33 +142,33 @@ public ChatView(Chatbot chatbot, AirProperties properties, DocumentService docum // Initialize session on attach (kicks off agent process and greeting) addAttachListener(event -> { var ui = event.getUI(); - restorePreviousMessages(ui); - initializeSession(); + initializeSession(ui); }); } - private void initializeSession() { - var ui = getUI().orElse(null); - if (ui == null) return; - + private void initializeSession(UI ui) { var vaadinSession = VaadinSession.getCurrent(); - var sessionKey = getSessionKey(ui); - if (vaadinSession.getAttribute(sessionKey) != null) { - return; // Session already exists for this UI + var sessionData = (SessionData) vaadinSession.getAttribute(SESSION_DATA_KEY); + + if (sessionData != null) { + // Session already exists — update the output channel's UI reference + // in case the UI was recreated (e.g., page refresh, reconnect) + sessionData.outputChannel().updateUI(ui); + return; } // Create session with output channel that directly updates UI var outputChannel = new VaadinOutputChannel(ui); var chatSession = chatbot.createSession(currentUser, outputChannel, null, UUID.randomUUID().toString()); - var sessionData = new SessionData(chatSession, outputChannel); - vaadinSession.setAttribute(sessionKey, sessionData); - logger.info("Created new chat session for UI {}", ui.getUIId()); + sessionData = new SessionData(chatSession, outputChannel); + vaadinSession.setAttribute(SESSION_DATA_KEY, sessionData); + logger.info("Created new chat session"); // Greeting will be displayed automatically when it arrives via the output channel } private void refreshFooter() { remove(footer); - footer = new Footer(documentService.getDocumentCount(), documentService.getChunkCount()); + footer = new Footer(documentService.getDocumentCount() + " documents \u00b7 " + documentService.getChunkCount() + " chunks"); add(footer); } @@ -159,36 +181,24 @@ private void toggleSessionPanel() { } private ChatSession getCurrentSession() { - var ui = getUI().orElse(null); - if (ui == null) return null; var vaadinSession = VaadinSession.getCurrent(); - var sessionKey = getSessionKey(ui); - var sessionData = (SessionData) vaadinSession.getAttribute(sessionKey); + var sessionData = (SessionData) vaadinSession.getAttribute(SESSION_DATA_KEY); return sessionData != null ? sessionData.chatSession() : null; } private record SessionData(ChatSession chatSession, VaadinOutputChannel outputChannel) { } - /** - * Get the session attribute key for this UI instance. - * Each browser tab (UI) gets its own chat session to prevent cross-talk. - */ - private String getSessionKey(UI ui) { - return "sessionData-" + ui.getUIId(); - } - private SessionData getOrCreateSession(UI ui) { var vaadinSession = VaadinSession.getCurrent(); - var sessionKey = getSessionKey(ui); - var sessionData = (SessionData) vaadinSession.getAttribute(sessionKey); + var sessionData = (SessionData) vaadinSession.getAttribute(SESSION_DATA_KEY); if (sessionData == null) { var outputChannel = new VaadinOutputChannel(ui); var chatSession = chatbot.createSession(currentUser, outputChannel, null, UUID.randomUUID().toString()); sessionData = new SessionData(chatSession, outputChannel); - vaadinSession.setAttribute(sessionKey, sessionData); - logger.info("Created new chat session for UI {}", ui.getUIId()); + vaadinSession.setAttribute(SESSION_DATA_KEY, sessionData); + logger.info("Created new chat session"); } return sessionData; @@ -285,10 +295,9 @@ private void scrollToBottom() { messagesScroller.getElement().executeJs("this.scrollTop = this.scrollHeight"); } - private void restorePreviousMessages(UI ui) { + private void restorePreviousMessages() { var vaadinSession = VaadinSession.getCurrent(); - var sessionKey = getSessionKey(ui); - var sessionData = (SessionData) vaadinSession.getAttribute(sessionKey); + var sessionData = (SessionData) vaadinSession.getAttribute(SESSION_DATA_KEY); if (sessionData == null) { return; } @@ -312,7 +321,7 @@ private void restorePreviousMessages(UI ui) { * Uses CompletableFuture to signal when a response to a user message has been received. */ private class VaadinOutputChannel implements OutputChannel { - private final UI ui; + private volatile UI ui; private final AtomicReference> pendingResponse = new AtomicReference<>(); volatile Div currentToolCallIndicator; // package-private for access from sendMessage @@ -320,6 +329,13 @@ private class VaadinOutputChannel implements OutputChannel { this.ui = ui; } + /** + * Update the UI reference when the UI is recreated (e.g., page refresh, reconnect). + */ + void updateUI(UI ui) { + this.ui = ui; + } + /** * Set a future that will be completed when the next assistant message arrives. */ diff --git a/src/main/java/com/embabel/air/ai/vaadin/DocumentListSection.java b/src/main/java/com/embabel/air/ai/vaadin/DocumentListSection.java deleted file mode 100644 index 6cbc12e..0000000 --- a/src/main/java/com/embabel/air/ai/vaadin/DocumentListSection.java +++ /dev/null @@ -1,129 +0,0 @@ -package com.embabel.air.ai.vaadin; - -import com.embabel.air.ai.rag.DocumentService; -import com.vaadin.flow.component.button.Button; -import com.vaadin.flow.component.button.ButtonVariant; -import com.vaadin.flow.component.html.Div; -import com.vaadin.flow.component.html.H4; -import com.vaadin.flow.component.html.Span; -import com.vaadin.flow.component.icon.VaadinIcon; -import com.vaadin.flow.component.notification.Notification; -import com.vaadin.flow.component.notification.NotificationVariant; -import com.vaadin.flow.component.orderedlayout.HorizontalLayout; -import com.vaadin.flow.component.orderedlayout.VerticalLayout; - -/** - * Document list section for the documents drawer. - * Shows list of indexed documents with their context. - */ -public class DocumentListSection extends VerticalLayout { - - private final DocumentService documentService; - private final Runnable onDocumentsChanged; - private final VerticalLayout documentsList; - private final Span documentCountSpan; - private final Span chunkCountSpan; - - public DocumentListSection(DocumentService documentService, Runnable onDocumentsChanged) { - this.documentService = documentService; - this.onDocumentsChanged = onDocumentsChanged; - - setPadding(true); - setSpacing(true); - - // Stats section - var statsTitle = new H4("Statistics"); - statsTitle.addClassName("section-title"); - - var statsContainer = new Div(); - statsContainer.addClassName("stats-container"); - - documentCountSpan = new Span(); - documentCountSpan.addClassName("stat-value"); - - chunkCountSpan = new Span(); - chunkCountSpan.addClassName("stat-value"); - - statsContainer.add(createStatRow("Documents", documentCountSpan), createStatRow("Chunks", chunkCountSpan)); - - // Documents list section - var docsTitle = new H4("Documents"); - docsTitle.addClassName("section-title"); - docsTitle.getStyle().set("margin-top", "var(--lumo-space-m)"); - - documentsList = new VerticalLayout(); - documentsList.setPadding(false); - documentsList.setSpacing(false); - documentsList.addClassName("documents-list"); - - add(statsTitle, statsContainer, docsTitle, documentsList); - - refresh(); - } - - private Div createStatRow(String label, Span valueSpan) { - var row = new Div(); - row.addClassName("stat-row"); - - var labelSpan = new Span(label); - labelSpan.addClassName("stat-label"); - - row.add(labelSpan, valueSpan); - return row; - } - - public void refresh() { - documentCountSpan.setText(String.valueOf(documentService.getDocumentCount())); - chunkCountSpan.setText(String.valueOf(documentService.getChunkCount())); - - documentsList.removeAll(); - - var documents = documentService.getDocuments(); - if (documents.isEmpty()) { - var emptyLabel = new Span("No documents indexed yet"); - emptyLabel.addClassName("empty-list-label"); - documentsList.add(emptyLabel); - } else { - for (var doc : documents) { - documentsList.add(createDocumentRow(doc)); - } - } - } - - private HorizontalLayout createDocumentRow(DocumentService.DocumentInfo doc) { - var row = new HorizontalLayout(); - row.setWidthFull(); - row.setAlignItems(Alignment.CENTER); - row.addClassName("document-row"); - - var infoSection = new VerticalLayout(); - infoSection.setPadding(false); - infoSection.setSpacing(false); - - var title = new Span(doc.title() != null ? doc.title() : doc.uri()); - title.addClassName("document-title"); - - var contextBadge = new Span(doc.context()); - contextBadge.addClassName("context-badge"); - - infoSection.add(title, contextBadge); - - var deleteButton = new Button(VaadinIcon.TRASH.create()); - deleteButton.addThemeVariants(ButtonVariant.LUMO_TERTIARY, ButtonVariant.LUMO_ERROR, ButtonVariant.LUMO_SMALL); - deleteButton.addClickListener(e -> { - if (documentService.deleteDocument(doc.uri())) { - Notification.show("Deleted: " + doc.title(), 3000, Notification.Position.BOTTOM_CENTER); - refresh(); - onDocumentsChanged.run(); - } else { - Notification.show("Failed to delete", 3000, Notification.Position.BOTTOM_CENTER) - .addThemeVariants(NotificationVariant.LUMO_ERROR); - } - }); - - row.add(infoSection, deleteButton); - row.setFlexGrow(1, infoSection); - - return row; - } -} diff --git a/src/main/java/com/embabel/air/ai/vaadin/DocumentsDrawer.java b/src/main/java/com/embabel/air/ai/vaadin/DocumentsDrawer.java index 2faf0a3..341ef9d 100644 --- a/src/main/java/com/embabel/air/ai/vaadin/DocumentsDrawer.java +++ b/src/main/java/com/embabel/air/ai/vaadin/DocumentsDrawer.java @@ -2,6 +2,7 @@ import com.embabel.air.ai.rag.DocumentService; import com.embabel.air.backend.Customer; +import com.embabel.vaadin.document.DocumentListSection; import com.vaadin.flow.component.Key; import com.vaadin.flow.component.ShortcutRegistration; import com.vaadin.flow.component.button.Button; diff --git a/src/main/java/com/embabel/air/ai/vaadin/Footer.java b/src/main/java/com/embabel/air/ai/vaadin/Footer.java deleted file mode 100644 index 56404f8..0000000 --- a/src/main/java/com/embabel/air/ai/vaadin/Footer.java +++ /dev/null @@ -1,29 +0,0 @@ -package com.embabel.air.ai.vaadin; - -import com.vaadin.flow.component.html.Span; -import com.vaadin.flow.component.orderedlayout.HorizontalLayout; - -/** - * Footer component showing copyright and document statistics. - */ -public class Footer extends HorizontalLayout { - - public Footer(int documentCount, int chunkCount) { - setWidthFull(); - setPadding(false); - setSpacing(true); - setJustifyContentMode(JustifyContentMode.CENTER); - addClassName("app-footer"); - - var copyright = new Span("© Embabel 2025"); - copyright.addClassName("footer-copyright"); - - var separator = new Span("·"); - separator.addClassName("footer-separator"); - - var stats = new Span(documentCount + " documents · " + chunkCount + " chunks"); - stats.addClassName("footer-stats"); - - add(copyright, separator, stats); - } -} diff --git a/src/main/java/com/embabel/air/ai/vaadin/SessionPanel.java b/src/main/java/com/embabel/air/ai/vaadin/SessionPanel.java index 97b0ed2..2e7636a 100644 --- a/src/main/java/com/embabel/air/ai/vaadin/SessionPanel.java +++ b/src/main/java/com/embabel/air/ai/vaadin/SessionPanel.java @@ -7,6 +7,9 @@ import com.embabel.chat.Asset; import com.embabel.chat.AssetView; import com.embabel.chat.ChatSession; +import com.embabel.dice.proposition.Proposition; +import com.embabel.dice.proposition.PropositionRepository; +import com.embabel.dice.proposition.extraction.IncrementalPropositionExtraction; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; @@ -16,14 +19,22 @@ import com.vaadin.flow.component.html.Span; import com.vaadin.flow.component.icon.Icon; import com.vaadin.flow.component.icon.VaadinIcon; +import com.vaadin.flow.component.notification.Notification; +import com.vaadin.flow.component.notification.NotificationVariant; import com.vaadin.flow.component.orderedlayout.FlexComponent; import com.vaadin.flow.component.orderedlayout.HorizontalLayout; import com.vaadin.flow.component.orderedlayout.VerticalLayout; +import com.vaadin.flow.component.progressbar.ProgressBar; import com.vaadin.flow.component.tabs.Tab; import com.vaadin.flow.component.tabs.Tabs; +import com.vaadin.flow.component.upload.Upload; +import com.vaadin.flow.component.upload.receivers.MemoryBuffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.time.ZoneId; import java.time.format.DateTimeFormatter; +import java.util.Comparator; import java.util.function.Supplier; /** @@ -32,6 +43,7 @@ */ class SessionPanel extends Div { + private static final Logger logger = LoggerFactory.getLogger(SessionPanel.class); private static final DateTimeFormatter TIME_FORMAT = DateTimeFormatter.ofPattern("HH:mm:ss"); private static final ObjectMapper objectMapper = new ObjectMapper() .enable(SerializationFeature.INDENT_OUTPUT); @@ -42,14 +54,23 @@ class SessionPanel extends Div { private final Supplier sessionSupplier; private final AgentPlatform agentPlatform; + private final PropositionRepository propositionRepository; + private final IncrementalPropositionExtraction propositionExtraction; + private final Customer customer; // Content panels private final VerticalLayout assetsContent; private final VerticalLayout stateContent; + private final VerticalLayout memoryContent; - SessionPanel(Customer user, Supplier sessionSupplier, AgentPlatform agentPlatform) { + SessionPanel(Customer user, Supplier sessionSupplier, AgentPlatform agentPlatform, + PropositionRepository propositionRepository, + IncrementalPropositionExtraction propositionExtraction) { this.sessionSupplier = sessionSupplier; this.agentPlatform = agentPlatform; + this.propositionRepository = propositionRepository; + this.propositionExtraction = propositionExtraction; + this.customer = user; addClassName("session-panel-container"); @@ -106,8 +127,9 @@ class SessionPanel extends Div { // Tabs var assetsTab = new Tab(VaadinIcon.CUBE.create(), new Span("Assets")); var stateTab = new Tab(VaadinIcon.COG.create(), new Span("State")); + var memoryTab = new Tab(VaadinIcon.LIGHTBULB.create(), new Span("Memory")); - var tabs = new Tabs(assetsTab, stateTab); + var tabs = new Tabs(assetsTab, stateTab, memoryTab); tabs.setWidthFull(); sidePanel.add(tabs); @@ -128,7 +150,13 @@ class SessionPanel extends Div { stateContent.setSpacing(true); stateContent.setVisible(false); - contentArea.add(assetsContent, stateContent); + // Memory content + memoryContent = new VerticalLayout(); + memoryContent.setPadding(true); + memoryContent.setSpacing(true); + memoryContent.setVisible(false); + + contentArea.add(assetsContent, stateContent, memoryContent); sidePanel.add(contentArea); sidePanel.setFlexGrow(1, contentArea); @@ -137,11 +165,14 @@ class SessionPanel extends Div { var selected = event.getSelectedTab(); assetsContent.setVisible(selected == assetsTab); stateContent.setVisible(selected == stateTab); + memoryContent.setVisible(selected == memoryTab); if (selected == assetsTab) { refreshAssets(); } else if (selected == stateTab) { refreshState(); + } else if (selected == memoryTab) { + refreshMemory(); } }); @@ -351,6 +382,197 @@ private void refreshState() { stateContent.add(objectsList); } + private void refreshMemory() { + memoryContent.removeAll(); + + // Learn button row + memoryContent.add(createLearnUpload()); + + var memoryTitle = new H4("Remembered Facts"); + memoryTitle.addClassName("section-title"); + memoryContent.add(memoryTitle); + + try { + var propositions = propositionRepository.findByContextIdValue(customer.getId()); + if (propositions.isEmpty()) { + var emptyMessage = new Span("No memories yet. Facts are extracted from conversations automatically, or upload a document with Learn."); + emptyMessage.addClassName("empty-list-label"); + memoryContent.add(emptyMessage); + return; + } + + // Sort by created descending (most recent first) + var sorted = propositions.stream() + .sorted(Comparator.comparing(Proposition::getCreated).reversed()) + .toList(); + + for (var proposition : sorted) { + memoryContent.add(createPropositionCard(proposition)); + } + } catch (Exception e) { + memoryContent.add(new Span("Error loading memories: " + e.getMessage())); + } + } + + private VerticalLayout createLearnUpload() { + var wrapper = new VerticalLayout(); + wrapper.setPadding(false); + wrapper.setSpacing(true); + wrapper.setWidthFull(); + + var statusRow = new HorizontalLayout(); + statusRow.setWidthFull(); + statusRow.setAlignItems(FlexComponent.Alignment.CENTER); + statusRow.setSpacing(true); + statusRow.setPadding(false); + statusRow.setVisible(false); + + var statusLabel = new Span(); + var statusBar = new ProgressBar(); + statusRow.add(statusLabel, statusBar); + statusRow.setFlexGrow(1, statusBar); + + var buffer = new MemoryBuffer(); + var upload = new Upload(buffer); + upload.setDropAllowed(false); + var learnButton = new Button("Learn", VaadinIcon.BOOK.create()); + upload.setUploadButton(learnButton); + upload.setAcceptedFileTypes( + ".pdf", ".txt", ".md", ".html", ".htm", + ".doc", ".docx", ".odt", ".rtf", + "application/pdf", "text/plain", "text/markdown", "text/html", + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ); + upload.setMaxFileSize(10 * 1024 * 1024); + upload.setMaxFiles(1); + upload.addClassName("learn-upload"); + + // Clear file list so it never shows inline + upload.getElement().addEventListener("upload-start", e -> + upload.getElement().executeJs("this.files = []")); + upload.getElement().addEventListener("upload-success", e -> + upload.getElement().executeJs("this.files = []")); + + upload.addStartedListener(event -> { + statusLabel.setText("Uploading: " + event.getFileName()); + statusBar.setIndeterminate(false); + statusBar.setValue(0); + statusRow.setVisible(true); + }); + + upload.addProgressListener(event -> { + if (event.getContentLength() > 0) { + statusBar.setIndeterminate(false); + statusBar.setValue((double) event.getReadBytes() / event.getContentLength()); + } else { + statusBar.setIndeterminate(true); + } + }); + + upload.addSucceededListener(event -> { + var filename = event.getFileName(); + statusLabel.setText("Extracting memories from: " + filename); + statusBar.setIndeterminate(true); + try { + propositionExtraction.rememberFile(buffer.getInputStream(), filename, customer); + getUI().ifPresent(ui -> new Thread(() -> { + try { + Thread.sleep(5000); + ui.access(() -> { + statusRow.setVisible(false); + refreshMemory(); + }); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + }).start()); + } catch (Exception e) { + logger.error("Failed to learn file: {}", filename, e); + statusLabel.setText("Error: " + e.getMessage()); + statusBar.setVisible(false); + } + }); + + upload.addFailedListener(event -> { + logger.error("Upload failed: {}", event.getReason().getMessage()); + statusRow.setVisible(false); + Notification.show("Upload failed: " + event.getReason().getMessage(), + 5000, Notification.Position.BOTTOM_CENTER) + .addThemeVariants(NotificationVariant.LUMO_ERROR); + }); + + wrapper.add(upload, statusRow); + return wrapper; + } + + private Div createPropositionCard(Proposition proposition) { + var card = new Div(); + card.getStyle() + .set("background", "var(--sb-bg-light)") + .set("border", "1px solid var(--sb-border)") + .set("border-left", "3px solid var(--sb-accent)") + .set("border-radius", "var(--lumo-border-radius-m)") + .set("padding", "var(--lumo-space-s)") + .set("margin-bottom", "var(--lumo-space-s)"); + + // Proposition text + var text = new Span(proposition.getText()); + text.getStyle() + .set("display", "block") + .set("color", "var(--sb-text-primary)") + .set("font-size", "var(--lumo-font-size-s)"); + card.add(text); + + // Metadata row: confidence + timestamp + var metaRow = new HorizontalLayout(); + metaRow.setWidthFull(); + metaRow.setJustifyContentMode(FlexComponent.JustifyContentMode.BETWEEN); + metaRow.setAlignItems(FlexComponent.Alignment.CENTER); + metaRow.setPadding(false); + metaRow.setSpacing(false); + metaRow.getStyle().set("margin-top", "var(--lumo-space-xs)"); + + var confidence = new Span("%.0f%% confidence".formatted(proposition.getConfidence() * 100)); + confidence.getStyle() + .set("color", "var(--sb-text-muted)") + .set("font-size", "var(--lumo-font-size-xs)"); + + var timestamp = proposition.getCreated() + .atZone(ZoneId.systemDefault()) + .format(TIME_FORMAT); + var time = new Span(timestamp); + time.getStyle() + .set("color", "var(--sb-text-muted)") + .set("font-size", "var(--lumo-font-size-xs)"); + + metaRow.add(confidence, time); + card.add(metaRow); + + // Entity mentions + var mentions = proposition.getMentions(); + if (mentions != null && !mentions.isEmpty()) { + var mentionsRow = new HorizontalLayout(); + mentionsRow.setPadding(false); + mentionsRow.setSpacing(true); + mentionsRow.getStyle().set("margin-top", "var(--lumo-space-xs)"); + + for (var mention : mentions) { + var badge = new Span(mention.getSpan() + " (" + mention.getType() + ")"); + badge.getStyle() + .set("background", "var(--sb-bg-medium)") + .set("color", "var(--sb-text-secondary)") + .set("padding", "2px 6px") + .set("border-radius", "4px") + .set("font-size", "var(--lumo-font-size-xs)"); + mentionsRow.add(badge); + } + card.add(mentionsRow); + } + + return card; + } + private String getInitials(String name) { if (name == null || name.isBlank()) return "?"; var parts = name.trim().split("\\s+"); diff --git a/src/main/java/com/embabel/air/ai/vaadin/UserSection.java b/src/main/java/com/embabel/air/ai/vaadin/UserSection.java deleted file mode 100644 index 3efa808..0000000 --- a/src/main/java/com/embabel/air/ai/vaadin/UserSection.java +++ /dev/null @@ -1,59 +0,0 @@ -package com.embabel.air.ai.vaadin; - -import com.embabel.air.backend.Customer; -import com.vaadin.flow.component.button.Button; -import com.vaadin.flow.component.button.ButtonVariant; -import com.vaadin.flow.component.html.Div; -import com.vaadin.flow.component.html.Span; -import com.vaadin.flow.component.orderedlayout.FlexComponent; -import com.vaadin.flow.component.orderedlayout.HorizontalLayout; - -/** - * User section component showing avatar, name, and logout button. - */ -class UserSection extends HorizontalLayout { - - UserSection(Customer user, Runnable onProfileClick) { - setAlignItems(FlexComponent.Alignment.CENTER); - setSpacing(true); - - // Profile chip with avatar and name - var profileChip = new HorizontalLayout(); - profileChip.addClassName("profile-chip"); - profileChip.setAlignItems(FlexComponent.Alignment.CENTER); - profileChip.setSpacing(false); - - // Avatar with initials - var initials = getInitials(user.getDisplayName()); - var avatar = new Div(); - avatar.setText(initials); - avatar.addClassName("user-avatar"); - - var userName = new Span(user.getDisplayName()); - userName.addClassName("user-name"); - - profileChip.add(avatar, userName); - - // Make profile chip clickable - profileChip.getStyle().set("cursor", "pointer"); - profileChip.addClickListener(e -> onProfileClick.run()); - - // Logout button - var logoutButton = new Button("Logout", e -> { - getUI().ifPresent(ui -> ui.getPage().setLocation("/logout")); - }); - logoutButton.addThemeVariants(ButtonVariant.LUMO_TERTIARY, ButtonVariant.LUMO_SMALL); - logoutButton.addClassName("logout-button"); - - add(profileChip, logoutButton); - } - - private String getInitials(String name) { - if (name == null || name.isBlank()) return "?"; - var parts = name.trim().split("\\s+"); - if (parts.length >= 2) { - return (parts[0].substring(0, 1) + parts[parts.length - 1].substring(0, 1)).toUpperCase(); - } - return name.substring(0, Math.min(2, name.length())).toUpperCase(); - } -} diff --git a/src/main/java/com/embabel/air/backend/Airport.java b/src/main/java/com/embabel/air/backend/Airport.java new file mode 100644 index 0000000..e03fa7d --- /dev/null +++ b/src/main/java/com/embabel/air/backend/Airport.java @@ -0,0 +1,9 @@ +package com.embabel.air.backend; + +import com.embabel.agent.rag.model.NamedEntity; +import org.jspecify.annotations.NonNull; + +public interface Airport extends NamedEntity { + + @NonNull String getCode(); +} diff --git a/src/main/java/com/embabel/air/backend/BookingService.java b/src/main/java/com/embabel/air/backend/BookingService.java new file mode 100644 index 0000000..623f48e --- /dev/null +++ b/src/main/java/com/embabel/air/backend/BookingService.java @@ -0,0 +1,35 @@ +package com.embabel.air.backend; + +import java.time.LocalDate; +import java.util.List; + +public interface BookingService { + + /** + * Search for available direct flights between two airports on a given date. + */ + List searchDirectFlights(String fromAirport, String toAirport, LocalDate date); + + /** + * Search for available routes (direct or connecting) between two airports. + * Returns itineraries, each containing one or more flight segments. + */ + List searchRoutes(String fromAirport, String toAirport, LocalDate date, int maxConnections); + + /** + * Book a trip for a customer given selected flight segment IDs. + * Creates a Reservation, assigns segments, decrements seats. + */ + Reservation book(Customer customer, List flightSegmentIds); + + /** + * Cancel a reservation. Releases seats back to available inventory. + */ + void cancel(String bookingReference); + + /** + * Rebook an existing reservation onto different flights. + * Cancels old segments, books new ones. + */ + Reservation rebook(String bookingReference, List newFlightSegmentIds); +} diff --git a/src/main/java/com/embabel/air/backend/City.java b/src/main/java/com/embabel/air/backend/City.java new file mode 100644 index 0000000..aaa089f --- /dev/null +++ b/src/main/java/com/embabel/air/backend/City.java @@ -0,0 +1,6 @@ +package com.embabel.air.backend; + +import com.embabel.agent.rag.model.NamedEntity; + +public interface City extends NamedEntity { +} diff --git a/src/main/java/com/embabel/air/backend/Country.java b/src/main/java/com/embabel/air/backend/Country.java new file mode 100644 index 0000000..e9f7541 --- /dev/null +++ b/src/main/java/com/embabel/air/backend/Country.java @@ -0,0 +1,6 @@ +package com.embabel.air.backend; + +import com.embabel.agent.rag.model.NamedEntity; + +public interface Country extends NamedEntity { +} diff --git a/src/main/java/com/embabel/air/backend/Customer.java b/src/main/java/com/embabel/air/backend/Customer.java index 83d6f25..54335a5 100644 --- a/src/main/java/com/embabel/air/backend/Customer.java +++ b/src/main/java/com/embabel/air/backend/Customer.java @@ -2,6 +2,7 @@ import com.embabel.agent.api.annotation.LlmTool; import com.embabel.agent.api.identity.User; +import com.embabel.agent.rag.model.NamedEntity; import jakarta.persistence.*; import org.jetbrains.annotations.ApiStatus; import org.jetbrains.annotations.Nullable; @@ -9,6 +10,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Set; /** * User model for Embabel Air. @@ -18,7 +21,7 @@ @Index(name = "idx_customer_username", columnList = "username", unique = true), @Index(name = "idx_customer_email", columnList = "email") }) -public class Customer implements User { +public class Customer implements User, NamedEntity { @Id @GeneratedValue(strategy = GenerationType.UUID) @@ -135,4 +138,21 @@ public void addReservation(Reservation reservation) { public void setStatus(SkyPointsStatus skyPointsStatus) { this.skyPointsStatus = skyPointsStatus; } + + // NamedEntity implementation + + @Override + public @NonNull String getName() { + return displayName; + } + + @Override + public @NonNull String getDescription() { + return "Customer: " + displayName; + } + + @Override + public @NonNull String embeddableValue() { + return displayName; + } } diff --git a/src/main/java/com/embabel/air/backend/DefaultBookingService.java b/src/main/java/com/embabel/air/backend/DefaultBookingService.java new file mode 100644 index 0000000..4c590b5 --- /dev/null +++ b/src/main/java/com/embabel/air/backend/DefaultBookingService.java @@ -0,0 +1,244 @@ +package com.embabel.air.backend; + +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.time.DayOfWeek; +import java.time.Duration; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Random; + +@Service +@Transactional +public class DefaultBookingService implements BookingService { + + private final FlightSegmentRepository flightSegmentRepository; + private final ReservationRepository reservationRepository; + private final Random random = new Random(42); + + public DefaultBookingService(FlightSegmentRepository flightSegmentRepository, + ReservationRepository reservationRepository) { + this.flightSegmentRepository = flightSegmentRepository; + this.reservationRepository = reservationRepository; + } + + @Override + @Transactional(readOnly = true) + public List searchDirectFlights(String fromAirport, String toAirport, LocalDate date) { + var start = date.atStartOfDay(); + var end = date.plusDays(1).atStartOfDay(); + return flightSegmentRepository + .findByDepartureAirportCodeAndArrivalAirportCodeAndDepartureDateTimeBetween( + fromAirport.toUpperCase(), toAirport.toUpperCase(), start, end) + .stream() + .filter(s -> s.getReservation() == null && s.getSeatsLeft() > 0) + .toList(); + } + + @Override + @Transactional(readOnly = true) + public List searchRoutes(String fromAirport, String toAirport, LocalDate date, int maxConnections) { + var from = fromAirport.toUpperCase(); + var to = toAirport.toUpperCase(); + var results = new ArrayList(); + + // Direct flights + var directFlights = searchDirectFlights(from, to, date); + for (var flight : directFlights) { + results.add(buildItinerary(List.of(flight))); + } + + // Connecting flights (1 stop) + if (maxConnections >= 1) { + var dayStart = date.atStartOfDay(); + var dayEnd = date.plusDays(1).atStartOfDay(); + + var firstLegs = flightSegmentRepository + .findByDepartureAirportCodeAndDepartureDateTimeBetween(from, dayStart, dayEnd) + .stream() + .filter(s -> s.getReservation() == null && s.getSeatsLeft() > 0) + .filter(s -> !s.getArrivalAirportCode().equals(from)) + .toList(); + + for (var first : firstLegs) { + var connectionAirport = first.getArrivalAirportCode(); + if (connectionAirport.equals(to)) { + continue; // already covered by direct + } + + // Look for connecting flights 1-6 hours after arrival + var minConnect = first.getArrivalDateTime().plusHours(1); + var maxConnect = first.getArrivalDateTime().plusHours(6); + + var secondLegs = flightSegmentRepository + .findByDepartureAirportCodeAndArrivalAirportCodeAndDepartureDateTimeBetween( + connectionAirport, to, minConnect, maxConnect) + .stream() + .filter(s -> s.getReservation() == null && s.getSeatsLeft() > 0) + .toList(); + + for (var second : secondLegs) { + results.add(buildItinerary(List.of(first, second))); + } + } + } + + results.sort(Comparator.comparing(Itinerary::totalTravelTime)); + return results; + } + + @Override + public Reservation book(Customer customer, List flightSegmentIds) { + var segments = flightSegmentIds.stream() + .map(id -> flightSegmentRepository.findById(id) + .orElseThrow(() -> new IllegalArgumentException("Flight segment not found: " + id))) + .toList(); + + for (var seg : segments) { + if (seg.getReservation() != null) { + throw new IllegalStateException( + "Flight segment %s is already booked".formatted(seg.getId())); + } + if (seg.getSeatsLeft() <= 0) { + throw new IllegalStateException( + "No seats left on flight %s → %s departing %s".formatted( + seg.getDepartureAirportCode(), seg.getArrivalAirportCode(), + seg.getDepartureDateTime())); + } + } + + var reservation = new Reservation(generateBookingReference()); + var totalPrice = BigDecimal.ZERO; + + for (var seg : segments) { + reservation.addFlightSegment(seg); + seg.setSeatsLeft(seg.getSeatsLeft() - 1); + totalPrice = totalPrice.add(estimateSegmentPrice(seg)); + } + + reservation.setPaidAmount(totalPrice); + customer.addReservation(reservation); + reservationRepository.save(reservation); + return reservation; + } + + @Override + public void cancel(String bookingReference) { + var reservation = reservationRepository.findByBookingReference(bookingReference); + if (reservation == null) { + throw new IllegalArgumentException("Reservation not found: " + bookingReference); + } + + for (var seg : new ArrayList<>(reservation.getFlightSegments())) { + seg.setSeatsLeft(seg.getSeatsLeft() + 1); + reservation.removeFlightSegment(seg); + } + + var customer = reservation.getCustomer(); + if (customer != null) { + customer.getReservations().remove(reservation); + } + reservationRepository.delete(reservation); + } + + @Override + public Reservation rebook(String bookingReference, List newFlightSegmentIds) { + var reservation = reservationRepository.findByBookingReference(bookingReference); + if (reservation == null) { + throw new IllegalArgumentException("Reservation not found: " + bookingReference); + } + + var customer = reservation.getCustomer(); + + // Release old segments + for (var seg : new ArrayList<>(reservation.getFlightSegments())) { + seg.setSeatsLeft(seg.getSeatsLeft() + 1); + reservation.removeFlightSegment(seg); + } + + // Load and validate new segments + var newSegments = newFlightSegmentIds.stream() + .map(id -> flightSegmentRepository.findById(id) + .orElseThrow(() -> new IllegalArgumentException("Flight segment not found: " + id))) + .toList(); + + var totalPrice = BigDecimal.ZERO; + for (var seg : newSegments) { + if (seg.getReservation() != null) { + throw new IllegalStateException( + "Flight segment %s is already booked".formatted(seg.getId())); + } + if (seg.getSeatsLeft() <= 0) { + throw new IllegalStateException( + "No seats left on flight %s → %s departing %s".formatted( + seg.getDepartureAirportCode(), seg.getArrivalAirportCode(), + seg.getDepartureDateTime())); + } + reservation.addFlightSegment(seg); + seg.setSeatsLeft(seg.getSeatsLeft() - 1); + totalPrice = totalPrice.add(estimateSegmentPrice(seg)); + } + + reservation.setPaidAmount(totalPrice); + return reservation; + } + + private Itinerary buildItinerary(List segments) { + var first = segments.getFirst(); + var last = segments.getLast(); + var totalTravel = Duration.between(first.getDepartureDateTime(), last.getArrivalDateTime()); + + var totalLayover = Duration.ZERO; + for (int i = 0; i < segments.size() - 1; i++) { + var arrivalTime = segments.get(i).getArrivalDateTime(); + var nextDeparture = segments.get(i + 1).getDepartureDateTime(); + totalLayover = totalLayover.plus(Duration.between(arrivalTime, nextDeparture)); + } + + var price = BigDecimal.ZERO; + for (var seg : segments) { + price = price.add(estimateSegmentPrice(seg)); + } + + return new Itinerary(segments, totalTravel, totalLayover, price); + } + + /** + * Simple pricing heuristic: + * Base: $50 + $0.50 per minute of flight time. + * Weekend surcharge: +30% for Saturday/Sunday departures. + * Lucky-3 discount: -20% if departure minute is divisible by 3. + */ + BigDecimal estimateSegmentPrice(FlightSegment segment) { + long flightMinutes = segment.getDuration().toMinutes(); + var base = BigDecimal.valueOf(50).add( + BigDecimal.valueOf(flightMinutes).multiply(BigDecimal.valueOf(0.50))); + + DayOfWeek day = segment.getDepartureDateTime().getDayOfWeek(); + if (day == DayOfWeek.SATURDAY || day == DayOfWeek.SUNDAY) { + base = base.multiply(BigDecimal.valueOf(1.30)); + } + + int minute = segment.getDepartureDateTime().getMinute(); + if (minute % 3 == 0) { + base = base.multiply(BigDecimal.valueOf(0.80)); + } + + return base.setScale(2, RoundingMode.HALF_UP); + } + + private String generateBookingReference() { + var chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + var sb = new StringBuilder(6); + for (int i = 0; i < 6; i++) { + sb.append(chars.charAt(random.nextInt(chars.length()))); + } + return sb.toString(); + } +} diff --git a/src/main/java/com/embabel/air/backend/DevDataLoader.java b/src/main/java/com/embabel/air/backend/DevDataLoader.java index aa9fe96..6129c3a 100644 --- a/src/main/java/com/embabel/air/backend/DevDataLoader.java +++ b/src/main/java/com/embabel/air/backend/DevDataLoader.java @@ -183,11 +183,34 @@ private String generateBookingReference() { return sb.toString(); } + private static final List POPULAR_ROUTES = List.of( + new String[]{"JFK", "LAX"}, new String[]{"LAX", "JFK"}, + new String[]{"JFK", "SFO"}, new String[]{"SFO", "JFK"}, + new String[]{"JFK", "MIA"}, new String[]{"MIA", "JFK"}, + new String[]{"JFK", "ORD"}, new String[]{"ORD", "JFK"}, + new String[]{"LAX", "ORD"}, new String[]{"ORD", "LAX"}, + new String[]{"LAX", "SEA"}, new String[]{"SEA", "LAX"}, + new String[]{"ORD", "MIA"}, new String[]{"MIA", "ORD"}, + new String[]{"DFW", "ATL"}, new String[]{"ATL", "DFW"}, + new String[]{"SFO", "SEA"}, new String[]{"SEA", "SFO"}, + new String[]{"BOS", "DCA"}, new String[]{"DCA", "BOS"} + ); + private void seedAvailableFlights() { var now = LocalDateTime.now(); - // Create available flights for the next 30 days - for (int i = 0; i < 100; i++) { + // Seed popular routes: 2 flights per day for 30 days on each route + for (var route : POPULAR_ROUTES) { + for (int day = 0; day < 30; day++) { + // Morning flight + seedFlight(route[0], route[1], now.plusDays(day), 7 + random.nextInt(4)); + // Afternoon/evening flight + seedFlight(route[0], route[1], now.plusDays(day), 14 + random.nextInt(6)); + } + } + + // Additional random flights for variety and connecting options + for (int i = 0; i < 300; i++) { var departure = AIRPORTS.get(random.nextInt(AIRPORTS.size())); var arrival = AIRPORTS.get(random.nextInt(AIRPORTS.size())); while (arrival.equals(departure)) { @@ -196,21 +219,24 @@ private void seedAvailableFlights() { var daysOffset = random.nextInt(30); var hour = 6 + random.nextInt(16); - var departureTime = now.plusDays(daysOffset).withHour(hour).withMinute(random.nextInt(4) * 15); + seedFlight(departure, arrival, now.plusDays(daysOffset), hour); + } + log.info("Created {} available flight segments", flightSegmentRepository.count()); + } - var flightDurationMinutes = 60 + random.nextInt(300); - var arrivalTime = departureTime.plusMinutes(flightDurationMinutes); + private void seedFlight(String departure, String arrival, LocalDateTime baseDate, int hour) { + var departureTime = baseDate.withHour(hour).withMinute(random.nextInt(4) * 15); + var flightDurationMinutes = 60 + random.nextInt(300); + var arrivalTime = departureTime.plusMinutes(flightDurationMinutes); - var airline = AIRLINES.get(random.nextInt(AIRLINES.size())); - var equipment = EQUIPMENT.get(random.nextInt(EQUIPMENT.size())); - var seatsLeft = 10 + random.nextInt(150); + var airline = AIRLINES.get(random.nextInt(AIRLINES.size())); + var equipment = EQUIPMENT.get(random.nextInt(EQUIPMENT.size())); + var seatsLeft = 10 + random.nextInt(150); - var segment = new FlightSegment(departure, departureTime, arrival, arrivalTime, airline); - segment.setEquipment(equipment); - segment.setSeatsLeft(seatsLeft); + var segment = new FlightSegment(departure, departureTime, arrival, arrivalTime, airline); + segment.setEquipment(equipment); + segment.setSeatsLeft(seatsLeft); - flightSegmentRepository.save(segment); - } - log.info("Created {} available flight segments", flightSegmentRepository.count()); + flightSegmentRepository.save(segment); } } diff --git a/src/main/java/com/embabel/air/backend/FlightSegmentRepository.java b/src/main/java/com/embabel/air/backend/FlightSegmentRepository.java index b8f8b03..810dc12 100644 --- a/src/main/java/com/embabel/air/backend/FlightSegmentRepository.java +++ b/src/main/java/com/embabel/air/backend/FlightSegmentRepository.java @@ -13,4 +13,10 @@ List findByDepartureAirportCodeAndArrivalAirportCodeAndDepartureD LocalDateTime startDateTime, LocalDateTime endDateTime ); + + List findByDepartureAirportCodeAndDepartureDateTimeBetween( + String departureAirportCode, + LocalDateTime startDateTime, + LocalDateTime endDateTime + ); } diff --git a/src/main/java/com/embabel/air/backend/Itinerary.java b/src/main/java/com/embabel/air/backend/Itinerary.java new file mode 100644 index 0000000..a43fdf7 --- /dev/null +++ b/src/main/java/com/embabel/air/backend/Itinerary.java @@ -0,0 +1,55 @@ +package com.embabel.air.backend; + +import java.math.BigDecimal; +import java.time.Duration; +import java.util.List; +import java.util.stream.Collectors; + +public record Itinerary( + List segments, + Duration totalTravelTime, + Duration totalLayoverTime, + BigDecimal estimatedPrice +) { + + public int connections() { + return segments.size() - 1; + } + + public String summary() { + var route = segments.stream() + .map(FlightSegment::getDepartureAirportCode) + .collect(Collectors.joining(" → ")); + route += " → " + segments.getLast().getArrivalAirportCode(); + + var sb = new StringBuilder(); + sb.append(route); + sb.append(" | Travel time: ").append(formatDuration(totalTravelTime)); + if (connections() > 0) { + sb.append(" | Layover: ").append(formatDuration(totalLayoverTime)); + } + sb.append(" | Est. price: $").append(estimatedPrice); + sb.append("\n"); + + for (var seg : segments) { + sb.append(" ").append(seg.getDepartureAirportCode()) + .append(" → ").append(seg.getArrivalAirportCode()) + .append(" | ").append(seg.getAirline()) + .append(" | ").append(seg.getEquipment()) + .append(" | Departs ").append(seg.getDepartureDateTime()) + .append(" Arrives ").append(seg.getArrivalDateTime()) + .append(" (").append(formatDuration(seg.getDuration())).append(")") + .append(" | ").append(seg.getSeatsLeft()).append(" seats left") + .append(" | ID: ").append(seg.getId()) + .append("\n"); + } + + return sb.toString().trim(); + } + + private static String formatDuration(Duration d) { + long hours = d.toHours(); + long minutes = d.toMinutesPart(); + return hours + "h " + minutes + "m"; + } +} diff --git a/src/main/java/com/embabel/air/security/LoginView.java b/src/main/java/com/embabel/air/security/LoginView.java index 04b8045..bf0743f 100644 --- a/src/main/java/com/embabel/air/security/LoginView.java +++ b/src/main/java/com/embabel/air/security/LoginView.java @@ -6,6 +6,7 @@ import com.vaadin.flow.component.html.H1; import com.vaadin.flow.component.html.Span; import com.vaadin.flow.component.login.LoginForm; +import com.vaadin.flow.component.login.LoginI18n; import com.vaadin.flow.component.orderedlayout.VerticalLayout; import com.vaadin.flow.router.BeforeEnterEvent; import com.vaadin.flow.router.BeforeEnterObserver; @@ -30,6 +31,9 @@ public LoginView() { setJustifyContentMode(JustifyContentMode.CENTER); loginForm.setAction("login"); + var i18n = LoginI18n.createDefault(); + i18n.getForm().setTitle("Embabel Air"); + loginForm.setI18n(i18n); var title = new H1("Embabel Air"); title.addClassName("login-title"); @@ -39,7 +43,7 @@ public LoginView() { var demoSection = createDemoSection(); - add(title, subtitle, loginForm, demoSection); + add(/*title, subtitle,*/ loginForm, demoSection); var topUser = DevDataLoader.DEMO_USERS.stream().filter(u -> u.level() == SkyPointsStatus.Level.PLATINUM) .findFirst().orElseThrow(); diff --git a/src/main/java/com/embabel/dice/proposition/jdbc/JdbcPropositionRepository.java b/src/main/java/com/embabel/dice/proposition/jdbc/JdbcPropositionRepository.java new file mode 100644 index 0000000..ac63d08 --- /dev/null +++ b/src/main/java/com/embabel/dice/proposition/jdbc/JdbcPropositionRepository.java @@ -0,0 +1,315 @@ +package com.embabel.dice.proposition.jdbc; + +import com.embabel.agent.core.ContextId; +import com.embabel.agent.rag.service.RetrievableIdentifier; +import com.embabel.common.ai.model.EmbeddingService; +import com.embabel.common.core.types.SimilarityResult; +import com.embabel.common.core.types.TextSimilaritySearchRequest; +import com.embabel.dice.proposition.EntityMention; +import com.embabel.dice.proposition.MentionRole; +import com.embabel.dice.proposition.Proposition; +import com.embabel.dice.proposition.PropositionRepository; +import com.embabel.dice.proposition.PropositionStatus; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.simple.JdbcClient; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.*; +import java.util.stream.Collectors; + +/** + * JDBC/pgvector-backed implementation of {@link PropositionRepository}. + * Uses {@link JdbcClient} for direct SQL operations against the propositions table. + */ +public class JdbcPropositionRepository implements PropositionRepository { + + private static final Logger logger = LoggerFactory.getLogger(JdbcPropositionRepository.class); + + private final JdbcClient jdbcClient; + private final EmbeddingService embeddingService; + private final ObjectMapper objectMapper; + + public JdbcPropositionRepository(JdbcClient jdbcClient, @Nullable EmbeddingService embeddingService) { + this.jdbcClient = jdbcClient; + this.embeddingService = embeddingService; + this.objectMapper = new ObjectMapper(); + this.objectMapper.registerModule(new JavaTimeModule()); + } + + @Override + public Proposition save(Proposition proposition) { + String embedding = null; + if (embeddingService != null) { + float[] vec = embeddingService.embed(proposition.getText()); + embedding = floatArrayToString(vec); + } + + String mentionsJson; + String metadataJson; + try { + mentionsJson = objectMapper.writeValueAsString(proposition.getMentions()); + metadataJson = objectMapper.writeValueAsString(proposition.getMetadata()); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize proposition data", e); + } + + jdbcClient.sql(""" + INSERT INTO propositions (id, context_id, text, confidence, decay, importance, reasoning, + status, level, reinforce_count, created, revised, last_accessed, + mentions, source_ids, grounding, metadata, uri, embedding) + VALUES (:id, :contextId, :text, :confidence, :decay, :importance, :reasoning, + :status, :level, :reinforceCount, :created, :revised, :lastAccessed, + :mentions::jsonb, :sourceIds, :grounding, :metadata::jsonb, :uri, CAST(:embedding AS vector)) + ON CONFLICT (id) DO UPDATE SET + context_id = EXCLUDED.context_id, + text = EXCLUDED.text, + confidence = EXCLUDED.confidence, + decay = EXCLUDED.decay, + importance = EXCLUDED.importance, + reasoning = EXCLUDED.reasoning, + status = EXCLUDED.status, + level = EXCLUDED.level, + reinforce_count = EXCLUDED.reinforce_count, + revised = EXCLUDED.revised, + last_accessed = EXCLUDED.last_accessed, + mentions = EXCLUDED.mentions, + source_ids = EXCLUDED.source_ids, + grounding = EXCLUDED.grounding, + metadata = EXCLUDED.metadata, + uri = EXCLUDED.uri, + embedding = EXCLUDED.embedding + """) + .param("id", proposition.getId()) + .param("contextId", proposition.getContextIdValue()) + .param("text", proposition.getText()) + .param("confidence", proposition.getConfidence()) + .param("decay", proposition.getDecay()) + .param("importance", proposition.getImportance()) + .param("reasoning", proposition.getReasoning()) + .param("status", proposition.getStatus().name()) + .param("level", proposition.getLevel()) + .param("reinforceCount", proposition.getReinforceCount()) + .param("created", Timestamp.from(proposition.getCreated())) + .param("revised", Timestamp.from(proposition.getRevised())) + .param("lastAccessed", Timestamp.from(proposition.getLastAccessed())) + .param("mentions", mentionsJson) + .param("sourceIds", proposition.getSourceIds().toArray(new String[0])) + .param("grounding", proposition.getGrounding().toArray(new String[0])) + .param("metadata", metadataJson) + .param("uri", proposition.getUri()) + .param("embedding", embedding) + .update(); + + return proposition; + } + + @Override + public Proposition findById(String id) { + return jdbcClient.sql("SELECT * FROM propositions WHERE id = :id") + .param("id", id) + .query(new PropositionRowMapper()) + .optional() + .orElse(null); + } + + @Override + public List findAll() { + return jdbcClient.sql("SELECT * FROM propositions ORDER BY created DESC") + .query(new PropositionRowMapper()) + .list(); + } + + @Override + public boolean delete(String id) { + int count = jdbcClient.sql("DELETE FROM propositions WHERE id = :id") + .param("id", id) + .update(); + return count > 0; + } + + @Override + public int count() { + return jdbcClient.sql("SELECT COUNT(*) FROM propositions") + .query(Integer.class) + .single(); + } + + @Override + public List findByEntity(RetrievableIdentifier entityIdentifier) { + return jdbcClient.sql(""" + SELECT * FROM propositions + WHERE mentions @> :mentionFilter::jsonb + ORDER BY created DESC + """) + .param("mentionFilter", "[{\"resolvedId\":\"" + entityIdentifier.getId() + "\"}]") + .query(new PropositionRowMapper()) + .list(); + } + + @Override + public List> findSimilarWithScores(TextSimilaritySearchRequest request) { + if (embeddingService == null) { + logger.warn("Vector search requested but no embedding service configured"); + return Collections.emptyList(); + } + + float[] queryVec = embeddingService.embed(request.getQuery()); + String embedding = floatArrayToString(queryVec); + + return jdbcClient.sql(""" + SELECT *, (1 - (embedding <=> CAST(:embedding AS vector))) AS score + FROM propositions + WHERE embedding IS NOT NULL + ORDER BY embedding <=> CAST(:embedding AS vector) + LIMIT :topK + """) + .param("embedding", embedding) + .param("topK", request.getTopK()) + .query((rs, rowNum) -> { + Proposition prop = new PropositionRowMapper().mapRow(rs, rowNum); + double score = rs.getDouble("score"); + return SimilarityResult.create(prop, score); + }) + .list() + .stream() + .filter(r -> r.getScore() >= request.getSimilarityThreshold()) + .collect(Collectors.toList()); + } + + @Override + public List findByStatus(PropositionStatus status) { + return jdbcClient.sql("SELECT * FROM propositions WHERE status = :status ORDER BY created DESC") + .param("status", status.name()) + .query(new PropositionRowMapper()) + .list(); + } + + @Override + public List findByGrounding(String chunkId) { + return jdbcClient.sql("SELECT * FROM propositions WHERE :chunkId = ANY(grounding) ORDER BY created DESC") + .param("chunkId", chunkId) + .query(new PropositionRowMapper()) + .list(); + } + + @Override + public List findByMinLevel(int minLevel) { + return jdbcClient.sql("SELECT * FROM propositions WHERE level >= :minLevel ORDER BY created DESC") + .param("minLevel", minLevel) + .query(new PropositionRowMapper()) + .list(); + } + + @Override + public List findByContextIdValue(String contextIdValue) { + return jdbcClient.sql("SELECT * FROM propositions WHERE context_id = :contextId ORDER BY created DESC") + .param("contextId", contextIdValue) + .query(new PropositionRowMapper()) + .list(); + } + + @Override + public boolean supportsType(String type) { + return "Proposition".equals(type); + } + + @Override + public String getLuceneSyntaxNotes() { + return "PostgreSQL pgvector cosine similarity search on proposition text embeddings"; + } + + // === Helpers === + + private String floatArrayToString(float[] arr) { + StringBuilder sb = new StringBuilder("["); + for (int i = 0; i < arr.length; i++) { + if (i > 0) sb.append(","); + sb.append(arr[i]); + } + sb.append("]"); + return sb.toString(); + } + + private class PropositionRowMapper implements RowMapper { + + @Override + public Proposition mapRow(ResultSet rs, int rowNum) throws SQLException { + String mentionsJson = rs.getString("mentions"); + List mentions; + try { + if (mentionsJson == null || mentionsJson.isBlank()) { + mentions = Collections.emptyList(); + } else { + List> mentionMaps = objectMapper.readValue( + mentionsJson, new TypeReference<>() {}); + mentions = mentionMaps.stream().map(m -> new EntityMention( + (String) m.getOrDefault("span", ""), + (String) m.getOrDefault("type", "Entity"), + (String) m.get("resolvedId"), + MentionRole.valueOf( + (String) m.getOrDefault("role", "OTHER")), + Collections.emptyMap() + )).collect(Collectors.toList()); + } + } catch (Exception e) { + logger.warn("Failed to parse mentions JSON: {}", e.getMessage()); + mentions = Collections.emptyList(); + } + + java.sql.Array sourceIdsArray = rs.getArray("source_ids"); + List sourceIds = sourceIdsArray != null + ? Arrays.asList((String[]) sourceIdsArray.getArray()) + : Collections.emptyList(); + + java.sql.Array groundingArray = rs.getArray("grounding"); + List grounding = groundingArray != null + ? Arrays.asList((String[]) groundingArray.getArray()) + : Collections.emptyList(); + + String metadataJson = rs.getString("metadata"); + Map metadata; + try { + metadata = (metadataJson == null || metadataJson.isBlank()) + ? Collections.emptyMap() + : objectMapper.readValue(metadataJson, new TypeReference<>() {}); + } catch (Exception e) { + metadata = Collections.emptyMap(); + } + + Timestamp created = rs.getTimestamp("created"); + Timestamp revised = rs.getTimestamp("revised"); + Timestamp lastAccessed = rs.getTimestamp("last_accessed"); + + return Proposition.create( + rs.getString("id"), + rs.getString("context_id"), + rs.getString("text"), + mentions, + rs.getDouble("confidence"), + rs.getDouble("decay"), + rs.getDouble("importance"), + rs.getString("reasoning"), + grounding, + created != null ? created.toInstant() : Instant.now(), + revised != null ? revised.toInstant() : Instant.now(), + lastAccessed != null ? lastAccessed.toInstant() : Instant.now(), + PropositionStatus.valueOf(rs.getString("status")), + rs.getInt("level"), + sourceIds, + rs.getInt("reinforce_count"), + metadata, + rs.getString("uri") + ); + } + } +} diff --git a/src/main/java/com/embabel/springdata/EntityViewService.java b/src/main/java/com/embabel/springdata/EntityViewService.java index 34a1c5c..d599f6d 100644 --- a/src/main/java/com/embabel/springdata/EntityViewService.java +++ b/src/main/java/com/embabel/springdata/EntityViewService.java @@ -870,7 +870,7 @@ private Object convertToType(Object value, Class targetType) { } private Result convertResult(Object result) { - if (result == null) return Result.text(""); + if (result == null) return Result.text("No result"); if (result instanceof String s) return Result.text(s); if (result instanceof Result r) return r; @@ -1237,7 +1237,7 @@ private Object convertToType(Object value, Class targetType) { } private Result convertResult(Object result) { - if (result == null) return Result.text(""); + if (result == null) return Result.text("No result"); if (result instanceof String s) return Result.text(s); if (result instanceof Result r) return r; diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 4186788..0105516 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -4,6 +4,9 @@ server: vaadin: heartbeat-interval: 60 close-idle-sessions: false + devmode: + copilot: + enabled: false embabel-air: chunker-config: @@ -20,6 +23,14 @@ embabel-air: show-chat-prompts: true + memory: + enabled: true + extraction-llm: + model: gpt-4.1-mini + temperature: 0.0 + window-size: 10 + trigger-interval: 6 + embabel: models: diff --git a/src/main/resources/db/migration/V6__add_propositions.sql b/src/main/resources/db/migration/V6__add_propositions.sql new file mode 100644 index 0000000..fb66d10 --- /dev/null +++ b/src/main/resources/db/migration/V6__add_propositions.sql @@ -0,0 +1,26 @@ +CREATE TABLE IF NOT EXISTS propositions ( + id VARCHAR(255) PRIMARY KEY, + context_id VARCHAR(255) NOT NULL, + text TEXT NOT NULL, + confidence DOUBLE PRECISION NOT NULL DEFAULT 0.5, + decay DOUBLE PRECISION NOT NULL DEFAULT 0.0, + importance DOUBLE PRECISION NOT NULL DEFAULT 0.5, + reasoning TEXT, + status VARCHAR(50) NOT NULL DEFAULT 'ACTIVE', + level INTEGER NOT NULL DEFAULT 0, + reinforce_count INTEGER NOT NULL DEFAULT 0, + created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + revised TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + last_accessed TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + mentions JSONB NOT NULL DEFAULT '[]', + source_ids TEXT[] NOT NULL DEFAULT ARRAY[]::TEXT[], + grounding TEXT[] NOT NULL DEFAULT ARRAY[]::TEXT[], + metadata JSONB NOT NULL DEFAULT '{}', + uri TEXT, + embedding vector(1536) +); + +CREATE INDEX IF NOT EXISTS idx_propositions_context ON propositions(context_id); +CREATE INDEX IF NOT EXISTS idx_propositions_status ON propositions(status); +CREATE INDEX IF NOT EXISTS idx_propositions_embedding ON propositions USING hnsw (embedding vector_cosine_ops); +CREATE INDEX IF NOT EXISTS idx_propositions_mentions ON propositions USING gin (mentions); diff --git a/src/main/resources/db/migration/V7__add_named_entities.sql b/src/main/resources/db/migration/V7__add_named_entities.sql new file mode 100644 index 0000000..99c8614 --- /dev/null +++ b/src/main/resources/db/migration/V7__add_named_entities.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS named_entities ( + id VARCHAR(255) PRIMARY KEY, + name TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + uri TEXT, + labels TEXT[] NOT NULL DEFAULT ARRAY[]::TEXT[], + properties JSONB NOT NULL DEFAULT '{}', + metadata JSONB NOT NULL DEFAULT '{}', + context_id VARCHAR(255), + embedding vector(1536) +); + +CREATE TABLE IF NOT EXISTS entity_relationships ( + id SERIAL PRIMARY KEY, + source_id VARCHAR(255) NOT NULL REFERENCES named_entities(id) ON DELETE CASCADE, + target_id VARCHAR(255) NOT NULL REFERENCES named_entities(id) ON DELETE CASCADE, + relationship_name VARCHAR(255) NOT NULL, + properties JSONB NOT NULL DEFAULT '{}', + UNIQUE(source_id, target_id, relationship_name) +); + +CREATE INDEX IF NOT EXISTS idx_entities_labels ON named_entities USING gin (labels); +CREATE INDEX IF NOT EXISTS idx_entities_embedding ON named_entities USING hnsw (embedding vector_cosine_ops); +CREATE INDEX IF NOT EXISTS idx_entities_properties ON named_entities USING gin (properties); +CREATE INDEX IF NOT EXISTS idx_entities_context ON named_entities(context_id); +CREATE INDEX IF NOT EXISTS idx_relationships_source ON entity_relationships(source_id); +CREATE INDEX IF NOT EXISTS idx_relationships_target ON entity_relationships(target_id); diff --git a/src/main/resources/prompts/dice/extract_air_propositions.jinja b/src/main/resources/prompts/dice/extract_air_propositions.jinja new file mode 100644 index 0000000..da09652 --- /dev/null +++ b/src/main/resources/prompts/dice/extract_air_propositions.jinja @@ -0,0 +1,44 @@ +You are an expert at extracting structured facts about airline customers. + +Extract propositions (factual statements) from the following text. +Focus on facts about or relevant to the customer. + +## Entity Types in Schema +{% for type in model.context.schema.domainTypes %} +- {{ type.ownLabel }}: {{ type.description | default(type.ownLabel) }} +{% endfor %} + +## What to Extract +- Travel preferences (seat preferences, meal preferences, airline preferences) +- Loyalty/membership information +- Destinations they mention traveling to or from +- Companions (family, business colleagues) +- Dietary needs or restrictions +- Frequent routes or travel patterns +- Home city or country +- Any other factual information about the customer + +## What NOT to Extract +- Greetings or pleasantries +- Questions the user asks (unless they reveal preferences) +- Assistant responses or opinions +- Transient booking details already in the system + +## Output Format +For each proposition, provide: +- text: A clear, factual statement starting with "{{ customer.name }}" as the subject (e.g., "{{ customer.name }} prefers window seats") +- mentions: Entity references with type and role +- confidence: How certain you are (0.0-1.0) +- decay: How quickly this fact becomes stale (0.0 for permanent facts, higher for temporal ones) +- importance: How useful this fact is for future interactions (0.0-1.0) +- reasoning: Brief explanation of why you extracted this + +## Text +{{ model.chunk.text }} + +{% if model.existingPropositions is defined and model.existingPropositions %} +## Already Known Facts (avoid duplicates) +{% for prop in model.existingPropositions %} +- {{ prop.text }} +{% endfor %} +{% endif %}