diff --git a/src/main/java/us/ajg0702/bots/ajsupport/SupportBot.java b/src/main/java/us/ajg0702/bots/ajsupport/SupportBot.java index e93086c..4c97e21 100644 --- a/src/main/java/us/ajg0702/bots/ajsupport/SupportBot.java +++ b/src/main/java/us/ajg0702/bots/ajsupport/SupportBot.java @@ -78,7 +78,7 @@ private SupportBot(JDA jda) throws InterruptedException { return; } json = new Gson().fromJson(jsonRaw.toString(), JsonObject.class); - jda.addEventListener(new AutoRespondManager(json)); + jda.addEventListener(new AutoRespondManager(json, this)); new Thread(() -> { diff --git a/src/main/java/us/ajg0702/bots/ajsupport/autorespond/AutoRespondManager.java b/src/main/java/us/ajg0702/bots/ajsupport/autorespond/AutoRespondManager.java index c8027bd..34568f6 100644 --- a/src/main/java/us/ajg0702/bots/ajsupport/autorespond/AutoRespondManager.java +++ b/src/main/java/us/ajg0702/bots/ajsupport/autorespond/AutoRespondManager.java @@ -2,9 +2,14 @@ import com.google.gson.JsonObject; import net.dv8tion.jda.api.EmbedBuilder; +import net.dv8tion.jda.api.JDA; import net.dv8tion.jda.api.entities.Member; +import net.dv8tion.jda.api.entities.Message; +import net.dv8tion.jda.api.entities.channel.concrete.TextChannel; import net.dv8tion.jda.api.events.message.MessageReceivedEvent; import net.dv8tion.jda.api.hooks.ListenerAdapter; +import us.ajg0702.bots.ajsupport.EchoException; +import us.ajg0702.bots.ajsupport.SupportBot; import us.ajg0702.bots.ajsupport.autorespond.responders.*; import us.ajg0702.bots.ajsupport.autorespond.responders.ajlb.BDNEResponder; import us.ajg0702.bots.ajsupport.autorespond.responders.ajlb.DontUpdatePermResponse; @@ -12,13 +17,19 @@ import us.ajg0702.bots.ajsupport.autorespond.responders.ajlb.OnlineOnlyResponse; import us.ajg0702.bots.ajsupport.autorespond.responders.ajq.SpigotForwardingResponse; +import java.io.IOException; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; public class AutoRespondManager extends ListenerAdapter { + private final SupportBot bot; + List responders = Arrays.asList( // ajlb new BDNEResponder(), @@ -31,7 +42,8 @@ public class AutoRespondManager extends ListenerAdapter { new SpigotForwardingResponse() ); - public AutoRespondManager(JsonObject responses) { + public AutoRespondManager(JsonObject responses, SupportBot bot) { + this.bot = bot; Responder.RESPONSES = responses; } @Override @@ -46,21 +58,65 @@ public void onMessageReceived(MessageReceivedEvent event) { if(r != null) responses.add(r); } - if(responses.isEmpty()) return; + if(responses.isEmpty()) { + autoRespondEmbeddings(event.getMessage()); + return; + } responses.sort(Comparator.comparingInt(Response::getConfidence).reversed()); Response bestResponse = responses.get(0); - if(bestResponse == null) return; // I don't think this could happen, but just to be safe + if(bestResponse == null) { // I don't think this could happen, but just to be safe + autoRespondEmbeddings(event.getMessage()); + return; + } event.getMessage() .reply(bestResponse.getMessage()) .addEmbeds( new EmbedBuilder() .setDescription("The message above is an automated response. If it is not helpful, please state that it was not helpful so that a human can help you when they are available. Otherwise they may assume this message solved your issue") - .setFooter("ajSupport • Response confidence: " + bestResponse.getConfidence() + "%") + .setFooter("ajSupport • Selection: static • Response confidence: " + bestResponse.getConfidence() + "%") .build() ) .queue(); } + + public void autoRespondEmbeddings(Message message) { + String categoryID = message.getChannel().asTextChannel().getParentCategoryId(); + if(categoryID == null || (!categoryID.equals("804502763547000893") && !message.getChannelId().equals("700885801352822825"))) return; + + try { + BigDecimal[] vec = EmbeddingUtils.embed(message.getContentStripped().replaceAll("\\n", " ")); + EmbeddingUtils.VectorizeResponse topResult = EmbeddingUtils.queryVectorize(vec); + if(topResult == null || topResult.getScore() < 0.8) return; + + String responseKey = topResult.getMetadata().get("response").getAsString(); + String originalChannelId = topResult.getMetadata().get("channelId").getAsString(); + String originalMessageLink = "https://discord.com/channels/615715762912362565/" + originalChannelId + "/" + topResult.getId(); + + TextChannel channel = bot.getJDA().getTextChannelById(698756204801032202L); + if(channel == null) { + bot.getLogger().error("Cannot find logger-log channel for aj's plugins!"); + throw new IOException("Cannot find log channel!"); + } + channel.sendMessageEmbeds( + new EmbedBuilder() + .setDescription("vectorize is replying with " + responseKey + " due to result from " + originalMessageLink + "\n" + + "Reply to "+message.getAuthor().getName()+": " + + SupportBot.cutString( + message.getContentStripped().replaceAll("\n", " "), + 100 + ) + ) + .build() + ).queue(); + + message.reply(bot.getJson().get(responseKey).getAsString()).queue(); + + } catch (IOException e) { + bot.getLogger().warn("Error while embedding for auto-response:", e); + } + + } } diff --git a/src/main/java/us/ajg0702/bots/ajsupport/autorespond/EmbeddingUtils.java b/src/main/java/us/ajg0702/bots/ajsupport/autorespond/EmbeddingUtils.java new file mode 100644 index 0000000..298ac74 --- /dev/null +++ b/src/main/java/us/ajg0702/bots/ajsupport/autorespond/EmbeddingUtils.java @@ -0,0 +1,188 @@ +package us.ajg0702.bots.ajsupport.autorespond; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.net.HttpURLConnection; +import java.net.URL; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import org.jetbrains.annotations.Nullable; + +public class EmbeddingUtils { + + public static BigDecimal[] embed(String string) throws IOException { + final String token = System.getenv("CF_TOKEN"); + if(token == null || token.isEmpty()) throw new IOException("Missing CF Token!"); + + URL url = new URL("https://api.cloudflare.com/client/v4/accounts/f55b85c8a963663b11036975203c63c0/ai/run/@cf/baai/bge-base-en-v1.5"); + HttpURLConnection con = (HttpURLConnection) url.openConnection(); + con.setRequestMethod("POST"); + con.setRequestProperty("Content-Type", "application/json"); + con.setRequestProperty("Authorization", "Bearer " + token); // Replace YOUR_BEARER_TOKEN with your actual token + con.setDoOutput(true); + + + String payload = String.format("{\"text\": \"%s\"}", string); + try (var outputStream = con.getOutputStream()) { + outputStream.write(payload.getBytes()); + outputStream.flush(); + } + + var inputStream = con.getInputStream(); + var reader = new InputStreamReader(inputStream); + + Gson gson = new Gson(); + JsonObject json = gson.fromJson(reader, JsonObject.class).getAsJsonObject("result"); + + BigDecimal[] vec = new BigDecimal[768]; + + int i = 0; + for (JsonElement vectorElement : json.getAsJsonArray("data").get(0).getAsJsonArray()) { + vec[i++] = vectorElement.getAsBigDecimal(); + } + + return vec; + } + + public static @Nullable VectorizeResponse queryVectorize(BigDecimal[] query) throws IOException { + List results = queryVectorize(query, 1); + if(results.size() == 0) return null; + return results.get(0); + } + + public static List queryVectorize(BigDecimal[] query, int topK) throws IOException { + final String token = System.getenv("CF_TOKEN"); + if(token == null || token.isEmpty()) throw new IOException("Missing CF Token!"); + + URL url = new URL("https://api.cloudflare.com/client/v4/accounts/f55b85c8a963663b11036975203c63c0/vectorize/v2/indexes/support-autoresponse/query"); + HttpURLConnection con = (HttpURLConnection) url.openConnection(); + con.setRequestMethod("POST"); + con.setRequestProperty("Content-Type", "application/json"); + con.setRequestProperty("Authorization", "Bearer " + token); // Replace YOUR_BEARER_TOKEN with your actual token + con.setDoOutput(true); + + String payload = String.format( + "{\"vector\": %s, \"returnMetadata\": \"all\", \"returnValues\": false, \"topK\": " + topK + "}", + new Gson().toJson(query) + ); + try (var outputStream = con.getOutputStream()) { + outputStream.write(payload.getBytes()); + outputStream.flush(); + } + + var inputStream = con.getInputStream(); + var reader = new InputStreamReader(inputStream); + + Gson gson = new Gson(); + JsonObject json = gson.fromJson(reader, JsonObject.class); + var responses = new ArrayList(); + System.out.println(json.toString()); + + for (JsonElement element : json.getAsJsonObject("result").getAsJsonArray("matches")) { + responses.add(new VectorizeResponse( + element.getAsJsonObject().get("id").getAsString(), + element.getAsJsonObject().get("metadata").getAsJsonObject(), + element.getAsJsonObject().get("score").getAsDouble() + )); + } + + return responses; + } + + public static void insertIntoVectorize(String id, BigDecimal[] vector, String channelId, String response) throws IOException { + final String token = System.getenv("CF_TOKEN"); + if(token == null || token.isEmpty()) throw new IOException("Missing CF Token!"); + + final String boundary = new BigInteger(128, new Random()).toString(); + + URL url = new URL("https://api.cloudflare.com/client/v4/accounts/f55b85c8a963663b11036975203c63c0/vectorize/v2/indexes/support-autoresponse/upsert"); + HttpURLConnection con = (HttpURLConnection) url.openConnection(); + con.setRequestMethod("POST"); + con.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + boundary); + con.setRequestProperty("Authorization", "Bearer " + token); // Replace YOUR_BEARER_TOKEN with your actual token + con.setDoOutput(true); + + final String end = boundary + "\r\n"; + + String payload = "--" + end + "Content-Disposition: form-data; name=\"vectors\"; filename=\"upsert.ndjson\"\r\nContent-Type: application/x-ndjson\r\n\r\n" + + String.format( + "{\"id\": \"%s\", \"values\": %s, \"metadata\": { \"channelId\":\"%s\", \"response\": \"%s\" }}", + id, + new Gson().toJson(vector), + channelId, + response + ) + "\r\n--" + boundary + "--"; + + + try (var outputStream = con.getOutputStream()) { + outputStream.write(payload.getBytes()); + outputStream.flush(); + } + + System.out.println("(Vectorize) Upsert request returned status " + con.getResponseCode()); + try (var inputStream = con.getResponseCode() < HttpURLConnection.HTTP_BAD_REQUEST ? con.getInputStream() : con.getErrorStream(); + var reader = new InputStreamReader(inputStream)) { + StringBuilder responseBuilder = new StringBuilder(); + int c; + while ((c = reader.read()) != -1) { + responseBuilder.append((char) c); + } + System.out.println("(Vectorize) Response body: " + responseBuilder); + } + + + + + } + + public static class VectorizeResponse { + private final String id; + private final JsonObject metadata; + private final double score; + + public VectorizeResponse(String id, JsonObject metadata, double score) { + this.id = id; + this.metadata = metadata; + this.score = score; + } + + public String getId() { + return id; + } + + public JsonObject getMetadata() { + return metadata; + } + + public double getScore() { + return score; + } + + @Override + public String toString() { + return score + " " + id + ": " + metadata.toString(); + } + } + + public static void main(String[] args) { + try { +// insertIntoVectorize("1326592604510617620", embed("can anyone tell me how to reset AjLeaderboards"), "810277316298408007", "reset"); +// insertIntoVectorize("1326229772476485633", embed("how do i reset people on leaderboard even tho they arent online?"), "810277316298408007", "reset"); + BigDecimal[] vec = embed("how do I reset my kills leaderboard"); +// System.out.println(vec); + + List responses = queryVectorize(vec); + System.out.println(responses); + } catch (IOException e) { + e.printStackTrace(); + } + } + +}