rootspring revised this gist . Go to revision
1 file changed, 85 insertions
cache.rs(file created)
@@ -0,0 +1,85 @@ | |||
1 | + | use super::core::{Limit, LimitTypes}; | |
2 | + | use governor::{DefaultKeyedRateLimiter, Quota}; | |
3 | + | use moka::future::Cache; | |
4 | + | use serenity::all::{GuildId, UserId}; | |
5 | + | use std::collections::HashMap; | |
6 | + | use std::num::NonZeroU32; | |
7 | + | use std::sync::{Arc, LazyLock}; | |
8 | + | ||
9 | + | // Hashmap of limit types to a hashmap of limit ids to its ratelimiter | |
10 | + | pub type RatelimiterMap<RlKey> = | |
11 | + | HashMap<LimitTypes, HashMap<String, DefaultKeyedRateLimiter<RlKey>>>; | |
12 | + | ||
13 | + | #[derive(Debug)] | |
14 | + | pub struct GuildLimitsCache { | |
15 | + | pub global: RatelimiterMap<()>, | |
16 | + | pub per_user: RatelimiterMap<UserId>, | |
17 | + | } | |
18 | + | ||
19 | + | impl GuildLimitsCache { | |
20 | + | /// Attempts to limit a user, returning a tuple of whether the user is allowed to continue, the time at which the bucket will be replenished, limit id that was hit | |
21 | + | pub async fn limit( | |
22 | + | &self, | |
23 | + | user_id: UserId, | |
24 | + | limit_type: LimitTypes, | |
25 | + | ) -> (bool, Option<governor::clock::QuantaInstant>, Option<String>) { | |
26 | + | if let Some(limits) = self.per_user.get(&limit_type) { | |
27 | + | for (limit_id, lim) in limits.iter() { | |
28 | + | match lim.check_key(&user_id) { | |
29 | + | Ok(()) => continue, // TODO: Return the time at which the bucket will be replenished | |
30 | + | Err(wait) => { | |
31 | + | return ( | |
32 | + | false, | |
33 | + | Some(wait.earliest_possible()), | |
34 | + | Some(limit_id.clone()), | |
35 | + | ) | |
36 | + | } | |
37 | + | } | |
38 | + | } | |
39 | + | } | |
40 | + | ||
41 | + | (true, None, None) | |
42 | + | } | |
43 | + | } | |
44 | + | ||
45 | + | static GUILD_LIMITS: LazyLock<Cache<GuildId, Arc<GuildLimitsCache>>> = | |
46 | + | LazyLock::new(|| Cache::builder().support_invalidation_closures().build()); | |
47 | + | ||
48 | + | pub async fn get_limits( | |
49 | + | data: &silverpelt::data::Data, | |
50 | + | guild_id: GuildId, | |
51 | + | ) -> Result<Arc<GuildLimitsCache>, silverpelt::Error> { | |
52 | + | if let Some(limits) = GUILD_LIMITS.get(&guild_id).await { | |
53 | + | Ok(limits.clone()) | |
54 | + | } else { | |
55 | + | let mut limits = GuildLimitsCache { | |
56 | + | global: HashMap::new(), | |
57 | + | per_user: HashMap::new(), | |
58 | + | }; | |
59 | + | ||
60 | + | // Init limits map here | |
61 | + | let limits_db = Limit::guild(&data.pool, guild_id).await?; | |
62 | + | ||
63 | + | for limit in limits_db { | |
64 | + | let limit_per = NonZeroU32::new(limit.limit_per as u32).ok_or("Invalid limit_per")?; | |
65 | + | let quota = Quota::with_period(std::time::Duration::from_secs(limit.limit_time as u64)) | |
66 | + | .ok_or("Failed to create quota")? | |
67 | + | .allow_burst(limit_per); | |
68 | + | ||
69 | + | let lim = DefaultKeyedRateLimiter::keyed(quota); | |
70 | + | ||
71 | + | // TODO: Support global limits | |
72 | + | limits | |
73 | + | .per_user | |
74 | + | .entry(limit.limit_type) | |
75 | + | .or_default() | |
76 | + | .insert(limit.limit_id.clone(), lim); | |
77 | + | } | |
78 | + | ||
79 | + | let limits = Arc::new(limits); | |
80 | + | ||
81 | + | GUILD_LIMITS.insert(guild_id, limits.clone()).await; | |
82 | + | ||
83 | + | Ok(limits) | |
84 | + | } | |
85 | + | } |
Newer
Older