Last active 1724168445

rootspring's Avatar rootspring revised this gist 1724168445. Go to revision

1 file changed, 85 insertions

cache.rs(file created)

@@ -0,0 +1,85 @@
1 + use super::core::{Limit, LimitTypes};
2 + use governor::{DefaultKeyedRateLimiter, Quota};
3 + use moka::future::Cache;
4 + use serenity::all::{GuildId, UserId};
5 + use std::collections::HashMap;
6 + use std::num::NonZeroU32;
7 + use std::sync::{Arc, LazyLock};
8 +
9 + // Hashmap of limit types to a hashmap of limit ids to its ratelimiter
10 + pub type RatelimiterMap<RlKey> =
11 + HashMap<LimitTypes, HashMap<String, DefaultKeyedRateLimiter<RlKey>>>;
12 +
13 + #[derive(Debug)]
14 + pub struct GuildLimitsCache {
15 + pub global: RatelimiterMap<()>,
16 + pub per_user: RatelimiterMap<UserId>,
17 + }
18 +
19 + impl GuildLimitsCache {
20 + /// Attempts to limit a user, returning a tuple of whether the user is allowed to continue, the time at which the bucket will be replenished, limit id that was hit
21 + pub async fn limit(
22 + &self,
23 + user_id: UserId,
24 + limit_type: LimitTypes,
25 + ) -> (bool, Option<governor::clock::QuantaInstant>, Option<String>) {
26 + if let Some(limits) = self.per_user.get(&limit_type) {
27 + for (limit_id, lim) in limits.iter() {
28 + match lim.check_key(&user_id) {
29 + Ok(()) => continue, // TODO: Return the time at which the bucket will be replenished
30 + Err(wait) => {
31 + return (
32 + false,
33 + Some(wait.earliest_possible()),
34 + Some(limit_id.clone()),
35 + )
36 + }
37 + }
38 + }
39 + }
40 +
41 + (true, None, None)
42 + }
43 + }
44 +
45 + static GUILD_LIMITS: LazyLock<Cache<GuildId, Arc<GuildLimitsCache>>> =
46 + LazyLock::new(|| Cache::builder().support_invalidation_closures().build());
47 +
48 + pub async fn get_limits(
49 + data: &silverpelt::data::Data,
50 + guild_id: GuildId,
51 + ) -> Result<Arc<GuildLimitsCache>, silverpelt::Error> {
52 + if let Some(limits) = GUILD_LIMITS.get(&guild_id).await {
53 + Ok(limits.clone())
54 + } else {
55 + let mut limits = GuildLimitsCache {
56 + global: HashMap::new(),
57 + per_user: HashMap::new(),
58 + };
59 +
60 + // Init limits map here
61 + let limits_db = Limit::guild(&data.pool, guild_id).await?;
62 +
63 + for limit in limits_db {
64 + let limit_per = NonZeroU32::new(limit.limit_per as u32).ok_or("Invalid limit_per")?;
65 + let quota = Quota::with_period(std::time::Duration::from_secs(limit.limit_time as u64))
66 + .ok_or("Failed to create quota")?
67 + .allow_burst(limit_per);
68 +
69 + let lim = DefaultKeyedRateLimiter::keyed(quota);
70 +
71 + // TODO: Support global limits
72 + limits
73 + .per_user
74 + .entry(limit.limit_type)
75 + .or_default()
76 + .insert(limit.limit_id.clone(), lim);
77 + }
78 +
79 + let limits = Arc::new(limits);
80 +
81 + GUILD_LIMITS.insert(guild_id, limits.clone()).await;
82 +
83 + Ok(limits)
84 + }
85 + }
Newer Older