From c821a3349418a955c0c960e079cf63d71159ae4c Mon Sep 17 00:00:00 2001 From: ayushaga14 Date: Wed, 8 Oct 2025 14:06:25 +0530 Subject: [PATCH] use malicious event dao in main threat queries --- .../java/com/akto/threat/backend/Main.java | 7 +- .../threat/backend/dao/MaliciousEventDao.java | 15 ++++ .../service/MaliciousEventService.java | 79 ++++++------------- .../backend/service/ThreatActorService.java | 73 +++++++---------- .../backend/service/ThreatApiService.java | 32 ++------ .../threat/backend/utils/ThreatUtils.java | 31 ++------ 6 files changed, 89 insertions(+), 148 deletions(-) diff --git a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java index 9b0cc3a5ca..52d627a24d 100644 --- a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java +++ b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java @@ -9,6 +9,7 @@ import com.akto.kafka.KafkaProducerConfig; import com.akto.kafka.Serializer; import com.akto.log.LoggerMaker; +import com.akto.threat.backend.dao.MaliciousEventDao; import com.akto.threat.backend.dao.ThreatDetectionDaoInit; import com.akto.threat.backend.service.ApiDistributionDataService; import com.akto.threat.backend.service.MaliciousEventService; @@ -72,10 +73,10 @@ public static void main(String[] args) throws Exception { new FlushMessagesToDB(internalKafkaConfig, threatProtectionMongo).run(); MaliciousEventService maliciousEventService = - new MaliciousEventService(internalKafkaConfig, threatProtectionMongo); + new MaliciousEventService(internalKafkaConfig, MaliciousEventDao.instance); - ThreatActorService threatActorService = new ThreatActorService(threatProtectionMongo); - ThreatApiService threatApiService = new ThreatApiService(threatProtectionMongo); + ThreatActorService threatActorService = new ThreatActorService(threatProtectionMongo, MaliciousEventDao.instance); + ThreatApiService threatApiService = new ThreatApiService(MaliciousEventDao.instance); ApiDistributionDataService apiDistributionDataService = new ApiDistributionDataService(threatProtectionMongo); new BackendVerticle(maliciousEventService, threatActorService, threatApiService, apiDistributionDataService).start(); diff --git a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/dao/MaliciousEventDao.java b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/dao/MaliciousEventDao.java index 4c5499ebdf..7f61206581 100644 --- a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/dao/MaliciousEventDao.java +++ b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/dao/MaliciousEventDao.java @@ -2,7 +2,12 @@ import com.akto.dto.threat_detection_backend.MaliciousEventDto; import com.akto.threat.backend.constants.MongoDBCollection; +import com.mongodb.client.AggregateIterable; import com.mongodb.client.MongoCollection; +import org.bson.Document; +import org.bson.conversions.Bson; + +import java.util.List; public class MaliciousEventDao extends AccountBasedDao { @@ -27,4 +32,14 @@ public void insertOne(String accountId, MaliciousEventDto event) { public MongoCollection getCollection(String accountId) { return super.getCollection(accountId); } + + public AggregateIterable aggregateRaw(String accountId, List pipeline) { + return getDatabase(accountId) + .getCollection(getCollectionName(), Document.class) + .aggregate(pipeline); + } + + public long countDocuments(String accountId, Bson filter) { + return getCollection(accountId).countDocuments(filter); + } } diff --git a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/MaliciousEventService.java b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/MaliciousEventService.java index 403e142775..f2160fb09f 100644 --- a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/MaliciousEventService.java +++ b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/MaliciousEventService.java @@ -17,14 +17,16 @@ import com.akto.proto.generated.threat_detection.service.malicious_alert_service.v1.RecordMaliciousEventRequest; import com.akto.threat.backend.constants.KafkaTopic; import com.akto.threat.backend.constants.MongoDBCollection; +import com.akto.threat.backend.dao.MaliciousEventDao; import com.akto.threat.backend.utils.KafkaUtils; import com.akto.threat.backend.utils.ThreatUtils; import com.akto.threat.backend.constants.StatusConstants; -import com.mongodb.client.*; -import com.mongodb.client.model.Filters; +import com.mongodb.client.DistinctIterable; +import com.mongodb.client.MongoCursor; import java.util.*; +import com.mongodb.client.model.Filters; import com.mongodb.client.model.Updates; import org.bson.Document; import org.bson.conversions.Bson; @@ -32,15 +34,15 @@ public class MaliciousEventService { private final Kafka kafka; - private final MongoClient mongoClient; + private final MaliciousEventDao maliciousEventDao; private static final LoggerMaker logger = new LoggerMaker(MaliciousEventService.class); private static final HashMap shouldNotCreateIndexes = new HashMap<>(); public MaliciousEventService( - KafkaConfig kafkaConfig, MongoClient mongoClient) { + KafkaConfig kafkaConfig, MaliciousEventDao maliciousEventDao) { this.kafka = new Kafka(kafkaConfig); - this.mongoClient = mongoClient; + this.maliciousEventDao = maliciousEventDao; } // Convert string label to model Label enum @@ -147,9 +149,9 @@ public void recordMaliciousEvent(String accountId, RecordMaliciousEventRequest r KafkaTopic.ThreatDetection.INTERNAL_DB_MESSAGES); } - private static Set findDistinctFields( - MongoCollection coll, String fieldName, Class tClass, Bson filters) { - DistinctIterable r = coll.distinct(fieldName, filters, tClass); + private Set findDistinctFields( + String accountId, String fieldName, Class tClass, Bson filters) { + DistinctIterable r = maliciousEventDao.getCollection(accountId).distinct(fieldName, filters, tClass); Set result = new HashSet<>(); MongoCursor cursor = r.cursor(); while (cursor.hasNext()) { @@ -160,42 +162,28 @@ private static Set findDistinctFields( public ThreatActorFilterResponse fetchThreatActorFilters( String accountId, ThreatActorFilterRequest request) { - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection("malicious_events", MaliciousEventDto.class); Set latestAttack = - MaliciousEventService.findDistinctFields( - coll, "filterId", String.class, Filters.empty()); + this.findDistinctFields(accountId, "filterId", String.class, Filters.empty()); Set countries = - MaliciousEventService.findDistinctFields( - coll, "country", String.class, Filters.empty()); + this.findDistinctFields(accountId, "country", String.class, Filters.empty()); Set actorIds = - MaliciousEventService.findDistinctFields( - coll, "actor", String.class, Filters.empty()); + this.findDistinctFields(accountId, "actor", String.class, Filters.empty()); return ThreatActorFilterResponse.newBuilder().addAllSubCategories(latestAttack).addAllCountries(countries).addAllActorId(actorIds).build(); } public FetchAlertFiltersResponse fetchAlertFilters( String accountId, FetchAlertFiltersRequest request) { - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection("malicious_events", MaliciousEventDto.class); Set actors = - MaliciousEventService.findDistinctFields( - coll, "actor", String.class, Filters.empty()); + this.findDistinctFields(accountId, "actor", String.class, Filters.empty()); Set urls = - MaliciousEventService.findDistinctFields( - coll, "latestApiEndpoint", String.class, Filters.empty()); + this.findDistinctFields(accountId, "latestApiEndpoint", String.class, Filters.empty()); Set subCategories = - MaliciousEventService.findDistinctFields( - coll, "filterId", String.class, Filters.empty()); + this.findDistinctFields(accountId, "filterId", String.class, Filters.empty()); return FetchAlertFiltersResponse.newBuilder().addAllActors(actors).addAllUrls(urls).addAllSubCategory(subCategories).build(); } @@ -216,12 +204,6 @@ public ListMaliciousRequestsResponse listMaliciousRequests( return ListMaliciousRequestsResponse.newBuilder().build(); } - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection( - MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, MaliciousEventDto.class); - Document query = new Document(); if (!filter.getActorsList().isEmpty()) { query.append("actor", new Document("$in", filter.getActorsList())); @@ -283,9 +265,10 @@ public ListMaliciousRequestsResponse listMaliciousRequests( applyLabelFilter(query, labelEnum); } - long total = coll.countDocuments(query); + long total = maliciousEventDao.countDocuments(accountId, query); try (MongoCursor cursor = - coll.find(query) + maliciousEventDao.getCollection(accountId) + .find(query) .sort(new Document("detectedAt", sort.getOrDefault("detectedAt", -1))) .skip(skip) .limit(limit) @@ -327,13 +310,12 @@ public ListMaliciousRequestsResponse listMaliciousRequests( } public void createIndexIfAbsent(String accountId) { - ThreatUtils.createIndexIfAbsent(accountId, mongoClient); + ThreatUtils.createIndexIfAbsent(accountId, maliciousEventDao); shouldNotCreateIndexes.put(accountId, true); } public int updateMaliciousEventStatus(String accountId, List eventIds, Map filterMap, String status) { try { - MongoCollection coll = getMaliciousEventCollection(accountId); MaliciousEventDto.Status eventStatus = MaliciousEventDto.Status.valueOf(status.toUpperCase()); Bson update = Updates.set("status", eventStatus.toString()); @@ -342,11 +324,11 @@ public int updateMaliciousEventStatus(String accountId, List eventIds, M return 0; } - String logMessage = String.format("Updating events %s to status: %s", + String logMessage = String.format("Updating events %s to status: %s", getQueryDescription(eventIds, filterMap), status); logger.info(logMessage); - - long modifiedCount = coll.updateMany(query, update).getModifiedCount(); + + long modifiedCount = maliciousEventDao.getCollection(accountId).updateMany(query, update).getModifiedCount(); return (int) modifiedCount; } catch (Exception e) { logger.error("Error updating malicious event status", e); @@ -356,8 +338,6 @@ public int updateMaliciousEventStatus(String accountId, List eventIds, M public int deleteMaliciousEvents(String accountId, List eventIds, Map filterMap) { try { - MongoCollection coll = getMaliciousEventCollection(accountId); - Document query = buildQuery(eventIds, filterMap, "delete"); if (query == null) { return 0; @@ -365,10 +345,10 @@ public int deleteMaliciousEvents(String accountId, List eventIds, Map eventIds, Map getMaliciousEventCollection(String accountId) { - return this.mongoClient - .getDatabase(accountId) - .getCollection( - MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, MaliciousEventDto.class); - } - private Document buildQuery(List eventIds, Map filterMap, String operation) { if (eventIds != null && !eventIds.isEmpty()) { // Query by event IDs diff --git a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatActorService.java b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatActorService.java index 8bf4e251e3..786924beb7 100644 --- a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatActorService.java +++ b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatActorService.java @@ -25,6 +25,7 @@ import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ListThreatActorResponse.ActivityData; import com.akto.ProtoMessageUtils; import com.akto.threat.backend.constants.MongoDBCollection; +import com.akto.threat.backend.dao.MaliciousEventDao; import com.akto.threat.backend.db.ActorInfoModel; import com.akto.threat.backend.dto.RateLimitConfigDTO; import com.akto.threat.backend.db.SplunkIntegrationModel; @@ -48,10 +49,12 @@ public class ThreatActorService { private final MongoClient mongoClient; + private final MaliciousEventDao maliciousEventDao; private static final LoggerMaker loggerMaker = new LoggerMaker(ThreatActorService.class, LoggerMaker.LogDb.THREAT_DETECTION); - public ThreatActorService(MongoClient mongoClient) { + public ThreatActorService(MongoClient mongoClient, MaliciousEventDao maliciousEventDao) { this.mongoClient = mongoClient; + this.maliciousEventDao = maliciousEventDao; } public ThreatConfiguration fetchThreatConfiguration(String accountId) { @@ -142,12 +145,8 @@ public ThreatConfiguration modifyThreatConfiguration(String accountId, ThreatCon public void deleteAllMaliciousEvents(String accountId) { loggerMaker.infoAndAddToDb("Deleting all malicious events for accountId: " + accountId); - MongoCollection coll = this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); - - coll.drop(); - ThreatUtils.createIndexIfAbsent(accountId, mongoClient); + maliciousEventDao.getCollection(accountId).drop(); + ThreatUtils.createIndexIfAbsent(accountId, maliciousEventDao); loggerMaker.infoAndAddToDb("Deleted all malicious events for accountId: " + accountId); } @@ -157,10 +156,6 @@ public ListThreatActorResponse listThreatActors(String accountId, ListThreatActo int limit = request.getLimit(); Map sort = request.getSortMap(); - MongoCollection coll = this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); - ListThreatActorsRequest.Filter filter = request.getFilter(); Document match = new Document(); @@ -207,7 +202,7 @@ public ListThreatActorResponse listThreatActors(String accountId, ListThreatActo .append("count", Arrays.asList(new Document("$count", "total"))) )); - Document result = coll.aggregate(pipeline).first(); + Document result = maliciousEventDao.aggregateRaw(accountId, pipeline).first(); List paginated = result.getList("paginated", Document.class, Collections.emptyList()); List countList = result.getList("count", Document.class, Collections.emptyList()); long total = countList.isEmpty() ? 0 : countList.get(0).getInteger("total"); @@ -218,18 +213,19 @@ public ListThreatActorResponse listThreatActors(String accountId, ListThreatActo String actorId = doc.getString("_id"); List activityDataList = new ArrayList<>(); - try (MongoCursor cursor2 = coll.find(Filters.eq("actor", actorId)) + try (MongoCursor cursor2 = maliciousEventDao.getCollection(accountId) + .find(Filters.eq("actor", actorId)) .sort(Sorts.descending("detectedAt")) .limit(40) .cursor()) { while (cursor2.hasNext()) { - Document doc2 = cursor2.next(); + MaliciousEventDto event = cursor2.next(); activityDataList.add(ActivityData.newBuilder() - .setUrl(doc2.getString("latestApiEndpoint")) - .setDetectedAt(doc2.getLong("detectedAt")) - .setSubCategory(doc2.getString("filterId")) - .setSeverity(doc2.getString("severity")) - .setMethod(doc2.getString("latestApiMethod")) + .setUrl(event.getLatestApiEndpoint()) + .setDetectedAt(event.getDetectedAt()) + .setSubCategory(event.getFilterId()) + .setSeverity(event.getSeverity()) + .setMethod(event.getLatestApiMethod().name()) .build()); } } @@ -256,10 +252,6 @@ public DailyActorsCountResponse getDailyActorCounts(String accountId, long start } List actors = new ArrayList<>(); - MongoCollection coll = this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); - List pipeline = new ArrayList<>(); @@ -318,8 +310,8 @@ public DailyActorsCountResponse getDailyActorCounts(String accountId, long start new Document("$eq", Arrays.asList("$severity", "HIGH")), 1, 0)))))); - - try (MongoCursor cursor = coll.aggregate(pipeline).cursor()) { + + try (MongoCursor cursor = maliciousEventDao.aggregateRaw(accountId, pipeline).cursor()) { while (cursor.hasNext()) { Document doc = cursor.next(); // Convert dayStart from Date (ms) back to seconds @@ -351,9 +343,6 @@ public ThreatActivityTimelineResponse getThreatActivityTimeline(String accountId // if (startTs < endTs - sevenDaysInSeconds) { // startTs = endTs - sevenDaysInSeconds; // } - MongoCollection coll = this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); Document match = new Document(); @@ -385,7 +374,7 @@ public ThreatActivityTimelineResponse getThreatActivityTimeline(String accountId new Document("$push", new Document("subCategory", "$_id.subCategory").append("count", "$count")))) ); - try (MongoCursor cursor = coll.aggregate(pipeline).cursor()) { + try (MongoCursor cursor = maliciousEventDao.aggregateRaw(accountId, pipeline).cursor()) { while (cursor.hasNext()) { Document doc = cursor.next(); System.out.print(doc); @@ -437,17 +426,17 @@ private String fetchMetadataString(Document doc){ return metadataStr; } - private List fetchMaliciousPayloadsResponse(FindIterable respList){ + private List fetchMaliciousPayloadsResponse(FindIterable respList){ if (respList == null) { return Collections.emptyList(); } List maliciousPayloadsResponse = new ArrayList<>(); - for (Document doc: respList) { + for (MaliciousEventDto event: respList) { maliciousPayloadsResponse.add( FetchMaliciousEventsResponse.MaliciousPayloadsResponse.newBuilder(). - setOrig(HttpResponseParams.getSampleStringFromProtoString(doc.getString("latestApiOrig"))). - setMetadata(fetchMetadataString(doc)). - setTs(doc.getLong("detectedAt")).build()); + setOrig(HttpResponseParams.getSampleStringFromProtoString(event.getLatestApiOrig())). + setMetadata(event.getMetadata() != null ? event.getMetadata() : ""). + setTs(event.getDetectedAt()).build()); } return maliciousPayloadsResponse; } @@ -457,9 +446,8 @@ public FetchMaliciousEventsResponse fetchAggregateMaliciousRequests( List maliciousPayloadsResponse = new ArrayList<>(); String refId = request.getRefId(); - MongoCollection coll = this.mongoClient.getDatabase(accountId).getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); Bson filters = Filters.eq("refId", refId); - FindIterable respList; + FindIterable respList; if (request.getEventType().equalsIgnoreCase(MaliciousEventDto.EventType.AGGREGATED.name())) { Bson matchConditions = Filters.and( @@ -470,12 +458,12 @@ public FetchMaliciousEventsResponse fetchAggregateMaliciousRequests( matchConditions, filters ); - respList = (FindIterable) coll.find(matchConditions).sort(Sorts.descending("detectedAt")).limit(10); + respList = maliciousEventDao.getCollection(accountId).find(matchConditions).sort(Sorts.descending("detectedAt")).limit(10); maliciousPayloadsResponse.addAll(this.fetchMaliciousPayloadsResponse(respList)); // TODO: Handle case where aggregate was satisfied only once. } else { - respList = (FindIterable) coll.find(filters); - maliciousPayloadsResponse = this.fetchMaliciousPayloadsResponse(respList); + respList = maliciousEventDao.getCollection(accountId).find(filters); + maliciousPayloadsResponse = this.fetchMaliciousPayloadsResponse(respList); } return FetchMaliciousEventsResponse.newBuilder().addAllMaliciousPayloadsResponse(maliciousPayloadsResponse).build(); @@ -488,11 +476,6 @@ public ThreatActorByCountryResponse getThreatActorByCountry( return ThreatActorByCountryResponse.newBuilder().build(); } - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); - List pipeline = new ArrayList<>(); Document match = new Document(); @@ -525,7 +508,7 @@ public ThreatActorByCountryResponse getThreatActorByCountry( List actorsByCountryCount = new ArrayList<>(); - try (MongoCursor cursor = coll.aggregate(pipeline).batchSize(1000).cursor()) { + try (MongoCursor cursor = maliciousEventDao.aggregateRaw(accountId, pipeline).batchSize(1000).cursor()) { while (cursor.hasNext()) { Document doc = cursor.next(); actorsByCountryCount.add( diff --git a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatApiService.java b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatApiService.java index b732f0bf0e..37db2fd675 100644 --- a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatApiService.java +++ b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/service/ThreatApiService.java @@ -8,9 +8,7 @@ import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ThreatCategoryWiseCountResponse; import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ThreatSeverityWiseCountRequest; import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ThreatSeverityWiseCountResponse; -import com.akto.threat.backend.constants.MongoDBCollection; -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoCollection; +import com.akto.threat.backend.dao.MaliciousEventDao; import com.mongodb.client.MongoCursor; import com.mongodb.client.model.Filters; @@ -23,11 +21,11 @@ public class ThreatApiService { - private final MongoClient mongoClient; + private final MaliciousEventDao maliciousEventDao; private static final LoggerMaker loggerMaker = new LoggerMaker(ThreatApiService.class); - public ThreatApiService(MongoClient mongoClient) { - this.mongoClient = mongoClient; + public ThreatApiService(MaliciousEventDao maliciousEventDao) { + this.maliciousEventDao = maliciousEventDao; } public ListThreatApiResponse listThreatApis(String accountId, ListThreatApiRequest request) { @@ -37,10 +35,6 @@ public ListThreatApiResponse listThreatApis(String accountId, ListThreatApiReque int skip = request.hasSkip() ? request.getSkip() : 0; int limit = request.getLimit(); Map sort = request.getSortMap(); - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); List base = new ArrayList<>(); ListThreatApiRequest.Filter filter = request.getFilter(); @@ -90,7 +84,7 @@ public ListThreatApiResponse listThreatApis(String accountId, ListThreatApiReque List countPipeline = new ArrayList<>(base); countPipeline.add(new Document("$count", "total")); - Document result = coll.aggregate(countPipeline).first(); + Document result = maliciousEventDao.aggregateRaw(accountId, countPipeline).first(); long total = result != null ? result.getInteger("total", 0) : 0; List pipeline = new ArrayList<>(base); @@ -107,7 +101,7 @@ public ListThreatApiResponse listThreatApis(String accountId, ListThreatApiReque .append("actorsCount", sort.getOrDefault("actorsCount", -1)))); List apis = new ArrayList<>(); - try (MongoCursor cursor = coll.aggregate(pipeline).cursor()) { + try (MongoCursor cursor = maliciousEventDao.aggregateRaw(accountId, pipeline).cursor()) { while (cursor.hasNext()) { Document doc = cursor.next(); Document agg = (Document) doc.get("_id"); @@ -137,11 +131,6 @@ public ThreatCategoryWiseCountResponse getSubCategoryWiseCount( loggerMaker.info("getSubCategoryWiseCount start ts " + Context.now()); - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); - List pipeline = new ArrayList<>(); Document match = new Document(); @@ -169,7 +158,7 @@ public ThreatCategoryWiseCountResponse getSubCategoryWiseCount( List categoryWiseCounts = new ArrayList<>(); // 5. Execute aggregation with controlled batch size - try (MongoCursor cursor = coll.aggregate(pipeline).batchSize(1000).cursor()) { + try (MongoCursor cursor = maliciousEventDao.aggregateRaw(accountId, pipeline).batchSize(1000).cursor()) { while (cursor.hasNext()) { Document doc = cursor.next(); Document agg = (Document) doc.get("_id"); @@ -199,11 +188,6 @@ public ThreatSeverityWiseCountResponse getSeverityWiseCount( loggerMaker.info("getSeverityWiseCount start ts " + Context.now()); - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); - List categoryWiseCounts = new ArrayList<>(); String[] severities = { "CRITICAL", "HIGH", "MEDIUM", "LOW" }; @@ -216,7 +200,7 @@ public ThreatSeverityWiseCountResponse getSeverityWiseCount( Filters.in("filterId", req.getLatestAttackList()) ); - long count = coll.countDocuments(filter); + long count = maliciousEventDao.countDocuments(accountId, filter); if (count > 0) { categoryWiseCounts.add( diff --git a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/utils/ThreatUtils.java b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/utils/ThreatUtils.java index 16dfd040a8..979356e9ef 100644 --- a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/utils/ThreatUtils.java +++ b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/utils/ThreatUtils.java @@ -1,7 +1,8 @@ package com.akto.threat.backend.utils; -import com.akto.threat.backend.constants.MongoDBCollection; -import com.mongodb.client.*; +import com.akto.threat.backend.dao.MaliciousEventDao; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoCursor; import com.mongodb.client.model.IndexOptions; import com.mongodb.client.model.Indexes; import org.bson.Document; @@ -11,28 +12,12 @@ public class ThreatUtils { - public static void createIndexIfAbsent(String accountId, MongoClient mongoClient) { - MongoDatabase database = mongoClient.getDatabase(accountId); - - MongoCursor stringMongoCursor = database.listCollectionNames().cursor(); - boolean maliciousEventCollectionExists = false; - - while (stringMongoCursor.hasNext()) { - String collectionName = stringMongoCursor.next(); - if(collectionName.equalsIgnoreCase(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS)) { - maliciousEventCollectionExists = true; - break; - } - } - - if (!maliciousEventCollectionExists) { - database.createCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS); - } - - MongoCollection coll = database.getCollection(MongoDBCollection.ThreatDetection.MALICIOUS_EVENTS, Document.class); + public static void createIndexIfAbsent(String accountId, MaliciousEventDao maliciousEventDao) { + // Get the collection from DAO - this will create the collection if it doesn't exist + MongoCollection collection = maliciousEventDao.getCollection(accountId); Set existingIndexes = new HashSet<>(); - try (MongoCursor cursor = coll.listIndexes().iterator()) { + try (MongoCursor cursor = collection.listIndexes().iterator()) { while (cursor.hasNext()) { Document index = cursor.next(); existingIndexes.add(index.get("name", "")); @@ -51,7 +36,7 @@ public static void createIndexIfAbsent(String accountId, MongoClient mongoClient for (Map.Entry entry : requiredIndexes.entrySet()) { if (!existingIndexes.contains(entry.getKey())) { - coll.createIndex(entry.getValue(), new IndexOptions().name(entry.getKey())); + collection.createIndex(entry.getValue(), new IndexOptions().name(entry.getKey())); } } }