From e024e6fac01788cf5e8bda0d0c24a5bb67e9a4d3 Mon Sep 17 00:00:00 2001 From: Alex Osborne Date: Fri, 19 Jul 2024 20:44:15 +0900 Subject: [PATCH] Subject suggestions: add query param for experiments --- ui/src/pandas/render/PageInfoController.java | 29 ++++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/ui/src/pandas/render/PageInfoController.java b/ui/src/pandas/render/PageInfoController.java index 6b7cc07..22d30a0 100644 --- a/ui/src/pandas/render/PageInfoController.java +++ b/ui/src/pandas/render/PageInfoController.java @@ -46,9 +46,11 @@ public class PageInfoController { private final LoadingCache pageInfoCache; private final LoadingCache> subjectsCache; private final LLMClient llm; + private final SubjectRepository subjectRepository; public PageInfoController(@Autowired(required = false) LLMClient llm, SubjectRepository subjectRepository) { this.llm = llm; + this.subjectRepository = subjectRepository; this.httpClient = HttpClient.newBuilder() .followRedirects(HttpClient.Redirect.ALWAYS) .build(); @@ -70,7 +72,7 @@ public PageInfo load(String url) throws IOException, InterruptedException { public List load(String url) throws Exception { List subjects = subjectRepository.findAllSubjectNames(); //subjects = subjectRepository.findAllSubjectNamesNested(); - return suggestSubjects(subjects, subjectRepository.findAllSubjectNamesNested(), url, pageInfoCache.get(url)); + return suggestSubjects(subjects, subjects, url, pageInfoCache.get(url), null); } }); } @@ -86,10 +88,24 @@ public PageInfo load(@RequestParam(name = "url") String url) throws Exception { @GetMapping(value = "/subjects/suggest", produces = "application/json") @ResponseBody - public List getSubjectSuggestions(@RequestParam(name = "url") String url) throws Exception { + public List getSubjectSuggestions(@RequestParam(name = "url") String url, + @RequestParam(name = "nested", defaultValue = "false") boolean nested, + @RequestParam(name = "cache", defaultValue = "true") boolean cache, + @RequestParam(name = "prompt", required = false) String prompt) throws Exception { if (!url.startsWith("http://") && !url.startsWith("https://")) { throw new IllegalArgumentException("bad url"); } + + if (nested) { + List subjects = subjectRepository.findAllSubjectNames(); + List nestedSubjects = subjectRepository.findAllSubjectNamesNested(); + return suggestSubjects(subjects, nestedSubjects, url, pageInfoCache.get(url), prompt); + } + + if (!cache) { + List subjects = subjectRepository.findAllSubjectNames(); + return suggestSubjects(subjects, subjects, url, pageInfoCache.get(url), prompt); + } return subjectsCache.get(url); } @@ -150,11 +166,14 @@ private PageInfo fetchPageInfo(String url) throws IOException, InterruptedExcept } } - List suggestSubjects(List subjects, List subjectsIndented, String url, PageInfo pageInfo) { + List suggestSubjects(List subjects, List subjectsIndented, String url, PageInfo pageInfo, String instructions) { if (llm == null) return List.of(); StringBuilder prompt = new StringBuilder(); // prompt.append("Please categorise the website below into three of the following categories. Output only the categories as a JSON array of strings and no other text.\n\n\n"); - prompt.append("Please classify the website below into at most three of the following subject categories, making sure to select the most specific subcategory where applicable. Output the classification as a JSON array.\n\n\n\n"); + if (instructions == null) { + instructions = "Please classify the website below into at most three of the following subject categories, making sure to select the most specific subcategory where applicable. Output the classification as a JSON array."; + } + prompt.append(instructions).append("\n\n\n\n"); for (var subject: subjectsIndented) { prompt.append(subject).append('\n'); } @@ -337,6 +356,6 @@ Festivals & Events (Cultural) Water Women Youth""".split("\n")); - System.out.println(controller.suggestSubjects(subjects, subjects, args[0], pageInfo)); + System.out.println(controller.suggestSubjects(subjects, subjects, args[0], pageInfo, null)); } }