Last active 1724168445

cache.rs Raw
1use super::core::{Limit, LimitTypes};
2use governor::{DefaultKeyedRateLimiter, Quota};
3use moka::future::Cache;
4use serenity::all::{GuildId, UserId};
5use std::collections::HashMap;
6use std::num::NonZeroU32;
7use std::sync::{Arc, LazyLock};
8
9// Hashmap of limit types to a hashmap of limit ids to its ratelimiter
10pub type RatelimiterMap<RlKey> =
11 HashMap<LimitTypes, HashMap<String, DefaultKeyedRateLimiter<RlKey>>>;
12
13#[derive(Debug)]
14pub struct GuildLimitsCache {
15 pub global: RatelimiterMap<()>,
16 pub per_user: RatelimiterMap<UserId>,
17}
18
19impl GuildLimitsCache {
20 /// Attempts to limit a user, returning a tuple of whether the user is allowed to continue, the time at which the bucket will be replenished, limit id that was hit
21 pub async fn limit(
22 &self,
23 user_id: UserId,
24 limit_type: LimitTypes,
25 ) -> (bool, Option<governor::clock::QuantaInstant>, Option<String>) {
26 if let Some(limits) = self.per_user.get(&limit_type) {
27 for (limit_id, lim) in limits.iter() {
28 match lim.check_key(&user_id) {
29 Ok(()) => continue, // TODO: Return the time at which the bucket will be replenished
30 Err(wait) => {
31 return (
32 false,
33 Some(wait.earliest_possible()),
34 Some(limit_id.clone()),
35 )
36 }
37 }
38 }
39 }
40
41 (true, None, None)
42 }
43}
44
45static GUILD_LIMITS: LazyLock<Cache<GuildId, Arc<GuildLimitsCache>>> =
46 LazyLock::new(|| Cache::builder().support_invalidation_closures().build());
47
48pub async fn get_limits(
49 data: &silverpelt::data::Data,
50 guild_id: GuildId,
51) -> Result<Arc<GuildLimitsCache>, silverpelt::Error> {
52 if let Some(limits) = GUILD_LIMITS.get(&guild_id).await {
53 Ok(limits.clone())
54 } else {
55 let mut limits = GuildLimitsCache {
56 global: HashMap::new(),
57 per_user: HashMap::new(),
58 };
59
60 // Init limits map here
61 let limits_db = Limit::guild(&data.pool, guild_id).await?;
62
63 for limit in limits_db {
64 let limit_per = NonZeroU32::new(limit.limit_per as u32).ok_or("Invalid limit_per")?;
65 let quota = Quota::with_period(std::time::Duration::from_secs(limit.limit_time as u64))
66 .ok_or("Failed to create quota")?
67 .allow_burst(limit_per);
68
69 let lim = DefaultKeyedRateLimiter::keyed(quota);
70
71 // TODO: Support global limits
72 limits
73 .per_user
74 .entry(limit.limit_type)
75 .or_default()
76 .insert(limit.limit_id.clone(), lim);
77 }
78
79 let limits = Arc::new(limits);
80
81 GUILD_LIMITS.insert(guild_id, limits.clone()).await;
82
83 Ok(limits)
84 }
85}
86