Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ These changes are available on the `master` branch, but have not yet been releas
([#2775](https://github.com/Pycord-Development/pycord/pull/2775))
- Added `discord.Interaction.created_at`.
([#2801](https://github.com/Pycord-Development/pycord/pull/2801))
- Added support for asynchronous functions in dynamic cooldowns.
([#2823](https://github.com/Pycord-Development/pycord/pull/2823))
- Added `User.nameplate` property.
([#2817](https://github.com/Pycord-Development/pycord/pull/2817))
- Added role gradients support with `Role.colours` and the `RoleColours` class.
Expand Down
16 changes: 7 additions & 9 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,10 @@ def guild_only(self, value: bool) -> None:
InteractionContextType.private_channel,
}

def _prepare_cooldowns(self, ctx: ApplicationContext):
async def _prepare_cooldowns(self, ctx: ApplicationContext):
if self._buckets.valid:
current = datetime.datetime.now().timestamp()
bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message
bucket = await self._buckets.get_bucket(ctx, current)

if bucket is not None:
retry_after = bucket.update_rate_limit(current)
Expand All @@ -356,18 +356,16 @@ async def prepare(self, ctx: ApplicationContext) -> None:
)

if self._max_concurrency is not None:
# For this application, context can be duck-typed as a Message
await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message

await self._max_concurrency.acquire(ctx)
try:
self._prepare_cooldowns(ctx)
await self._prepare_cooldowns(ctx)
await self.call_before_hooks(ctx)
except:
if self._max_concurrency is not None:
await self._max_concurrency.release(ctx) # type: ignore # ctx instead of non-existent message
raise

def is_on_cooldown(self, ctx: ApplicationContext) -> bool:
async def is_on_cooldown(self, ctx: ApplicationContext) -> bool:
"""Checks whether the command is currently on cooldown.

.. note::
Expand All @@ -387,7 +385,7 @@ def is_on_cooldown(self, ctx: ApplicationContext) -> bool:
if not self._buckets.valid:
return False

bucket = self._buckets.get_bucket(ctx) # type: ignore
bucket = await self._buckets.get_bucket(ctx) # type: ignore
current = utcnow().timestamp()
return bucket.get_tokens(current) == 0

Expand All @@ -400,7 +398,7 @@ def reset_cooldown(self, ctx: ApplicationContext) -> None:
The invocation context to reset the cooldown under.
"""
if self._buckets.valid:
bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message
bucket = await self._buckets.get_bucket(ctx)
bucket.reset()

def get_cooldown_retry_after(self, ctx: ApplicationContext) -> float:
Expand Down
6 changes: 5 additions & 1 deletion discord/ext/commands/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,11 @@ def me(self) -> Member | ClientUser:
message contexts, or when :meth:`Intents.guilds` is absent.
"""
# bot.user will never be None at this point.
return self.guild.me if self.guild is not None and self.guild.me is not None else self.bot.user # type: ignore
return (
self.guild.me
if self.guild is not None and self.guild.me is not None
else self.bot.user
) # type: ignore

@property
def voice_client(self) -> VoiceProtocol | None:
Expand Down
89 changes: 54 additions & 35 deletions discord/ext/commands/cooldowns.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@
import asyncio
import time
from collections import deque
from typing import TYPE_CHECKING, Any, Callable, Deque, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Deque, TypeVar

import discord.abc
from discord import utils
from discord.enums import Enum

from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached

if TYPE_CHECKING:
from ...commands import ApplicationContext
from ...ext.commands import Context
from ...message import Message

__all__ = (
Expand All @@ -60,31 +63,35 @@ class BucketType(Enum):
category = 5
role = 6

def get_key(self, msg: Message) -> Any:
def get_key(self, ctx: Context | ApplicationContext | Message) -> Any:
if self is BucketType.user:
return msg.author.id
return ctx.author.id
elif self is BucketType.guild:
return (msg.guild or msg.author).id
return (ctx.guild or ctx.author).id
elif self is BucketType.channel:
return msg.channel.id
return ctx.channel.id
elif self is BucketType.member:
return (msg.guild and msg.guild.id), msg.author.id
return (ctx.guild and ctx.guild.id), ctx.author.id
elif self is BucketType.category:
return (
msg.channel.category.id
if isinstance(msg.channel, discord.abc.GuildChannel)
and msg.channel.category
else msg.channel.id
ctx.channel.category.id
if isinstance(ctx.channel, discord.abc.GuildChannel)
and ctx.channel.category
else ctx.channel.id
)
elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
return (
ctx.channel
if isinstance(ctx.channel, PrivateChannel)
else ctx.author.top_role
).id # type: ignore

def __call__(self, msg: Message) -> Any:
return self.get_key(msg)
def __call__(self, ctx: Context | ApplicationContext | Message) -> Any:
return self.get_key(ctx)


class Cooldown:
Expand Down Expand Up @@ -208,14 +215,14 @@ class CooldownMapping:
def __init__(
self,
original: Cooldown | None,
type: Callable[[Message], Any],
type: Callable[[Context | ApplicationContext | Message], Any],
) -> None:
if not callable(type):
raise TypeError("Cooldown type must be a BucketType or callable")

self._cache: dict[Any, Cooldown] = {}
self._cooldown: Cooldown | None = original
self._type: Callable[[Message], Any] = type
self._type: Callable[[Context | ApplicationContext | Message], Any] = type

def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type)
Expand All @@ -227,15 +234,15 @@ def valid(self) -> bool:
return self._cooldown is not None

@property
def type(self) -> Callable[[Message], Any]:
def type(self) -> Callable[[Context | ApplicationContext | Message], Any]:
return self._type

@classmethod
def from_cooldown(cls: type[C], rate, per, type) -> C:
return cls(Cooldown(rate, per), type)

def _bucket_key(self, msg: Message) -> Any:
return self._type(msg)
def _bucket_key(self, ctx: Context | ApplicationContext | Message) -> Any:
return self._type(ctx)

def _verify_cache_integrity(self, current: float | None = None) -> None:
# we want to delete all cache objects that haven't been used
Expand All @@ -246,37 +253,47 @@ def _verify_cache_integrity(self, current: float | None = None) -> None:
for k in dead_keys:
del self._cache[k]

def create_bucket(self, message: Message) -> Cooldown:
async def create_bucket(
self, ctx: Context | ApplicationContext | Message
) -> Cooldown:
return self._cooldown.copy() # type: ignore

def get_bucket(self, message: Message, current: float | None = None) -> Cooldown:
async def get_bucket(
self, ctx: Context | ApplicationContext | Message, current: float | None = None
) -> Cooldown:
if self._type is BucketType.default:
return self._cooldown # type: ignore

self._verify_cache_integrity(current)
key = self._bucket_key(message)
key = self._bucket_key(ctx)
if key not in self._cache:
bucket = self.create_bucket(message)
bucket = await self.create_bucket(ctx)
if bucket is not None:
self._cache[key] = bucket
else:
bucket = self._cache[key]

return bucket

def update_rate_limit(
self, message: Message, current: float | None = None
async def update_rate_limit(
self, ctx: Context | ApplicationContext | Message, current: float | None = None
) -> float | None:
bucket = self.get_bucket(message, current)
bucket = await self.get_bucket(ctx, current)
return bucket.update_rate_limit(current)


class DynamicCooldownMapping(CooldownMapping):
def __init__(
self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
self,
factory: Callable[
[Context | ApplicationContext | Message], Cooldown | Awaitable[Cooldown]
],
type: Callable[[Context | ApplicationContext | Message], Any],
) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
self._factory: Callable[
[Context | ApplicationContext | Message], Cooldown | Awaitable[Cooldown]
] = factory

def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type)
Expand All @@ -287,8 +304,10 @@ def copy(self) -> DynamicCooldownMapping:
def valid(self) -> bool:
return True

def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
async def create_bucket(
self, ctx: Context | ApplicationContext | Message
) -> Cooldown:
return await utils.maybe_coroutine(self._factory, ctx)


class _Semaphore:
Expand Down Expand Up @@ -376,11 +395,11 @@ def __repr__(self) -> str:
f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
)

def get_key(self, message: Message) -> Any:
return self.per.get_key(message)
def get_key(self, ctx: Context | ApplicationContext | Message) -> Any:
return self.per.get_key(ctx)

async def acquire(self, message: Message) -> None:
key = self.get_key(message)
async def acquire(self, ctx: Context | ApplicationContext | Message) -> None:
key = self.get_key(ctx)

try:
sem = self._mapping[key]
Expand All @@ -391,10 +410,10 @@ async def acquire(self, message: Message) -> None:
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)

async def release(self, message: Message) -> None:
async def release(self, ctx: Context | ApplicationContext | Message) -> None:
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_key(message)
key = self.get_key(ctx)

try:
sem = self._mapping[key]
Expand Down
Loading
Loading