1
- use std:: collections:: HashSet ;
1
+ use std:: { collections:: HashSet , hash :: RandomState } ;
2
2
3
3
use indexmap:: { IndexMap , IndexSet } ;
4
4
use iso8601_timestamp:: Timestamp ;
@@ -305,6 +305,18 @@ impl Message {
305
305
..Default :: default ( )
306
306
} ;
307
307
308
+ // Parse mentions in message.
309
+ let mut mentions = HashSet :: new ( ) ;
310
+ if allow_mentions {
311
+ if let Some ( content) = & data. content {
312
+ for capture in RE_MENTION . captures_iter ( content) {
313
+ if let Some ( mention) = capture. get ( 1 ) {
314
+ mentions. insert ( mention. as_str ( ) . to_string ( ) ) ;
315
+ }
316
+ }
317
+ }
318
+ }
319
+
308
320
// Verify replies are valid.
309
321
let mut replies = HashSet :: new ( ) ;
310
322
if let Some ( entries) = data. replies {
@@ -325,29 +337,27 @@ impl Message {
325
337
}
326
338
}
327
339
328
- // Parse mentions in message.
329
- let mut mentions = HashSet :: new ( ) ;
330
- if allow_mentions {
331
- if let Some ( content) = & data. content {
332
- for capture in RE_MENTION . captures_iter ( content) {
333
- if let Some ( mention) = capture. get ( 1 ) {
334
- mentions. insert ( mention. as_str ( ) . to_string ( ) ) ;
335
- }
336
- }
337
- }
338
- }
339
-
340
340
if !mentions. is_empty ( ) {
341
341
// FIXME: temp fix to stop spam attacks
342
342
match channel {
343
- Channel :: DirectMessage { recipients, .. } | Channel :: Group { recipients, .. } => {
344
- mentions = mentions. intersection ( recipients) ;
343
+ Channel :: DirectMessage { ref recipients, .. }
344
+ | Channel :: Group { ref recipients, .. } => {
345
+ let recipients_hash: HashSet < & String , RandomState > =
346
+ HashSet :: from_iter ( recipients. iter ( ) ) ;
347
+
348
+ mentions. retain ( |m| recipients_hash. contains ( m) ) ;
345
349
}
346
- Channel :: TextChannel { server, .. } | Channel :: VoiceChannel { server, .. } => {
347
- let valid_members = db. fetch_members ( server. into ( ) , mentions) . await ;
350
+ Channel :: TextChannel { ref server, .. }
351
+ | Channel :: VoiceChannel { ref server, .. } => {
352
+ let mentions_vec = Vec :: from_iter ( mentions. iter ( ) . cloned ( ) ) ;
353
+ let valid_members = db. fetch_members ( server. as_str ( ) , & mentions_vec[ ..] ) . await ;
348
354
if let Ok ( valid_members) = valid_members {
349
- let valid_ids = valid_members. iter ( ) . map ( |member| member. id . user ) ;
350
- mentions = mentions. intersection ( valid_ids) ;
355
+ let valid_ids: HashSet < String , RandomState > = HashSet :: from_iter (
356
+ valid_members. iter ( ) . map ( |member| member. id . user . clone ( ) ) ,
357
+ ) ;
358
+ mentions. retain ( |m| valid_ids. contains ( m) ) ;
359
+ } else {
360
+ revolt_config:: capture_error ( & valid_members. unwrap_err ( ) ) ;
351
361
}
352
362
}
353
363
Channel :: SavedMessages { .. } => mentions. clear ( ) ,
0 commit comments