diff --git a/backend/app/connectors/velociraptor/routes/flows.py b/backend/app/connectors/velociraptor/routes/flows.py index c824f60a..65de0f8a 100644 --- a/backend/app/connectors/velociraptor/routes/flows.py +++ b/backend/app/connectors/velociraptor/routes/flows.py @@ -56,6 +56,74 @@ async def get_velociraptor_id(session: AsyncSession, hostname: str) -> str: return agent.velociraptor_id +async def get_velociraptor_org(session: AsyncSession, hostname: str) -> str: + """ + Retrieves the velociraptor_org associated with the given hostname. + + Args: + session (AsyncSession): The database session. + hostname (str): The hostname of the agent. + + Returns: + str: The velociraptor_org associated with the hostname. + + Raises: + HTTPException: If the agent with the given hostname is not found or if the velociraptor_org is not available. + """ + logger.info(f"Getting velociraptor_org from hostname {hostname}") + result = await session.execute(select(Agents).filter(Agents.hostname == hostname)) + agent = result.scalars().first() + + if not agent: + raise HTTPException( + status_code=404, + detail=f"Agent with hostname {hostname} not found", + ) + + if agent.velociraptor_org is None: + raise HTTPException( + status_code=404, + detail=f"Velociraptor ORG for hostname {hostname} is not available", + ) + + logger.info(f"velociraptor_org for hostname {hostname} is {agent.velociraptor_org}") + return agent.velociraptor_org + + +async def get_velociraptor_org_via_client_id(session: AsyncSession, client_id: str) -> str: + """ + Retrieves the velociraptor_org associated with the given hostname. + + Args: + session (AsyncSession): The database session. + hostname (str): The hostname of the agent. + + Returns: + str: The velociraptor_org associated with the hostname. + + Raises: + HTTPException: If the agent with the given hostname is not found or if the velociraptor_org is not available. + """ + logger.info(f"Getting velociraptor_org from client id {client_id}") + result = await session.execute(select(Agents).filter(Agents.velociraptor_id == client_id)) + agent = result.scalars().first() + + if not agent: + raise HTTPException( + status_code=404, + detail=f"Agent with client id {client_id} not found", + ) + + if agent.velociraptor_org is None: + raise HTTPException( + status_code=404, + detail=f"Velociraptor ORG for hostname {client_id} is not available", + ) + + logger.info(f"velociraptor_org for hostname {client_id} is {agent.velociraptor_org}") + return agent.velociraptor_org + + @velociraptor_flows_router.get( "/{hostname}", response_model=FlowResponse, @@ -79,8 +147,12 @@ async def get_all_flows_for_hostname( logger.info(f"Fetching all flows for hostname {hostname}") velociraptor_id = await get_velociraptor_id(session, hostname) + velociraptor_org = await get_velociraptor_org( + session, + hostname, + ) logger.info(f"velociraptor_id for hostname {hostname} is {velociraptor_id}") - return await get_flows(velociraptor_id) + return await get_flows(velociraptor_id, velociraptor_org) @velociraptor_flows_router.post( @@ -91,6 +163,7 @@ async def get_all_flows_for_hostname( ) async def retrieve_flow( retrieve_flow_request: RetrieveFlowRequest, + session: AsyncSession = Depends(get_db), ) -> CollectArtifactResponse: """ Retrieve ran flows for a specific host. @@ -103,4 +176,4 @@ async def retrieve_flow( CollectArtifactResponse: The response containing the retrieved flows. """ logger.info(f"Fetching flow for flow_id {retrieve_flow_request.session_id}") - return await get_flow(retrieve_flow_request) + return await get_flow(retrieve_flow_request, await get_velociraptor_org_via_client_id(session, retrieve_flow_request.client_id)) diff --git a/backend/app/connectors/velociraptor/services/artifacts.py b/backend/app/connectors/velociraptor/services/artifacts.py index 842b2221..843dadaf 100644 --- a/backend/app/connectors/velociraptor/services/artifacts.py +++ b/backend/app/connectors/velociraptor/services/artifacts.py @@ -125,7 +125,7 @@ async def run_artifact_collection( f"FROM scope()" ), ) - flow = velociraptor_service.execute_query(query) + flow = velociraptor_service.execute_query(query, org_id=collect_artifact_body.velociraptor_org) logger.info(f"Successfully ran artifact collection on {flow}") artifact_key = get_artifact_key(analyzer_body=collect_artifact_body) @@ -133,12 +133,13 @@ async def run_artifact_collection( flow_id = flow["results"][0][artifact_key]["flow_id"] logger.info(f"Extracted flow_id: {flow_id}") - completed = velociraptor_service.watch_flow_completion(flow_id) + completed = velociraptor_service.watch_flow_completion(flow_id, org_id=collect_artifact_body.velociraptor_org) logger.info(f"Successfully watched flow completion on {completed}") results = velociraptor_service.read_collection_results( client_id=collect_artifact_body.velociraptor_id, flow_id=flow_id, + org_id=collect_artifact_body.velociraptor_org, artifact=collect_artifact_body.artifact_name, ) @@ -186,7 +187,7 @@ async def run_remote_command(run_command_body: RunCommandBody) -> RunCommandResp "FROM scope()" ), ) - flow = velociraptor_service.execute_query(query) + flow = velociraptor_service.execute_query(query, org_id=run_command_body.velociraptor_org) logger.info(f"Successfully ran artifact collection on {flow}") artifact_key = get_artifact_key(analyzer_body=run_command_body) @@ -194,12 +195,13 @@ async def run_remote_command(run_command_body: RunCommandBody) -> RunCommandResp flow_id = flow["results"][0][artifact_key]["flow_id"] logger.info(f"Extracted flow_id: {flow_id}") - completed = velociraptor_service.watch_flow_completion(flow_id) + completed = velociraptor_service.watch_flow_completion(flow_id, org_id=run_command_body.velociraptor_org) logger.info(f"Successfully watched flow completion on {completed}") results = velociraptor_service.read_collection_results( client_id=run_command_body.velociraptor_id, flow_id=flow_id, + org_id=run_command_body.velociraptor_org, artifact=run_command_body.artifact_name, ) @@ -250,7 +252,7 @@ async def quarantine_host(quarantine_body: QuarantineBody) -> QuarantineResponse "FROM scope()" ), ) - flow = velociraptor_service.execute_query(query) + flow = velociraptor_service.execute_query(query, org_id=quarantine_body.velociraptor_org) logger.info(f"Successfully ran artifact collection on {flow}") artifact_key = get_artifact_key(analyzer_body=quarantine_body) @@ -258,12 +260,13 @@ async def quarantine_host(quarantine_body: QuarantineBody) -> QuarantineResponse flow_id = flow["results"][0][artifact_key]["flow_id"] logger.info(f"Extracted flow_id: {flow_id}") - completed = velociraptor_service.watch_flow_completion(flow_id) + completed = velociraptor_service.watch_flow_completion(flow_id, org_id=quarantine_body.velociraptor_org) logger.info(f"Successfully watched flow completion on {completed}") results = velociraptor_service.read_collection_results( client_id=quarantine_body.velociraptor_id, flow_id=flow_id, + org_id=quarantine_body.velociraptor_org, artifact=quarantine_body.artifact_name, ) diff --git a/backend/app/connectors/velociraptor/services/flows.py b/backend/app/connectors/velociraptor/services/flows.py index 53d423be..627895cb 100644 --- a/backend/app/connectors/velociraptor/services/flows.py +++ b/backend/app/connectors/velociraptor/services/flows.py @@ -21,7 +21,7 @@ def create_query(query: str) -> str: return query -async def get_flows(velociraptor_id: str) -> FlowResponse: +async def get_flows(velociraptor_id: str, velociraptor_org: str = "root") -> FlowResponse: """ Get all artifacts from Velociraptor. @@ -33,7 +33,7 @@ async def get_flows(velociraptor_id: str) -> FlowResponse: query = create_query( f"SELECT * FROM flows(client_id='{velociraptor_id}')", ) - all_flows = velociraptor_service.execute_query(query) + all_flows = velociraptor_service.execute_query(query, org_id=velociraptor_org) logger.info(f"all_flows: {all_flows}") flows = [FlowClientSession(**flow) for flow in all_flows["results"]] logger.info(f"flows: {flows}") @@ -59,7 +59,7 @@ async def get_flows(velociraptor_id: str) -> FlowResponse: ) -async def get_flow(retrieve_flow_request: RetrieveFlowRequest): +async def get_flow(retrieve_flow_request: RetrieveFlowRequest, velociraptor_org: str = "root"): """ Get all artifacts from Velociraptor. @@ -71,7 +71,7 @@ async def get_flow(retrieve_flow_request: RetrieveFlowRequest): query = create_query( f"SELECT * FROM flow_results(client_id='{retrieve_flow_request.client_id}', flow_id='{retrieve_flow_request.session_id}')", ) - flow_results = velociraptor_service.execute_query(query) + flow_results = velociraptor_service.execute_query(query, org_id=velociraptor_org) logger.info(f"flow_results: {flow_results}") try: if flow_results["success"]: diff --git a/backend/app/connectors/velociraptor/utils/universal.py b/backend/app/connectors/velociraptor/utils/universal.py index 95c032f4..ab092d2d 100644 --- a/backend/app/connectors/velociraptor/utils/universal.py +++ b/backend/app/connectors/velociraptor/utils/universal.py @@ -141,7 +141,7 @@ async def create(cls, connector_name: str): # ! Modify this to use AsyncSessionLocal End - def create_vql_request(self, vql: str): + def create_vql_request(self, vql: str, org_id: str = "root"): """ Creates a VQLCollectorArgs object with given VQL query. @@ -153,6 +153,7 @@ def create_vql_request(self, vql: str): """ return api_pb2.VQLCollectorArgs( max_wait=1, + org_id=org_id, Query=[ api_pb2.VQLRequest( Name="VQLRequest", @@ -161,7 +162,7 @@ def create_vql_request(self, vql: str): ], ) - def execute_query(self, vql: str): + def execute_query(self, vql: str, org_id: str = "root"): """ Executes a VQL query and returns the results. @@ -173,7 +174,7 @@ def execute_query(self, vql: str): """ logger.info(f"Executing query: {vql}") - client_request = self.create_vql_request(vql) + client_request = self.create_vql_request(vql, org_id) try: results = [] @@ -202,7 +203,7 @@ def execute_query(self, vql: str): logger.error(f"Failed to execute query: {e}") raise HTTPException(status_code=500, detail=f"Failed to execute query: {e}") - def watch_flow_completion(self, flow_id: str): + def watch_flow_completion(self, flow_id: str, org_id: str = "root"): """ Watch for the completion of a flow. @@ -213,14 +214,14 @@ def watch_flow_completion(self, flow_id: str): dict: A dictionary with the success status and a message. """ vql = f"SELECT * FROM watch_monitoring(artifact='System.Flow.Completion') WHERE FlowId='{flow_id}' LIMIT 1" - # vql = f"SELECT * FROM query(org_id='OL680', query='SELECT * FROM watch_monitoring(artifact='System.Flow.Completion') WHERE FlowId='{flow_id}' LIMIT 1')" logger.info(f"Watching flow {flow_id} for completion") - return self.execute_query(vql) + return self.execute_query(vql, org_id) def read_collection_results( self, client_id: str, flow_id: str, + org_id: str = "root", artifact: str = "Generic.Client.Info/BasicInformation", ): """ @@ -235,7 +236,7 @@ def read_collection_results( dict: A dictionary with the success status, a message, and potentially the results. """ vql = f"SELECT * FROM source(client_id='{client_id}', flow_id='{flow_id}', artifact='{artifact}')" - return self.execute_query(vql) + return self.execute_query(vql, org_id) async def get_client_id(self, client_name: str): """