1
1
from __future__ import annotations
2
2
3
+ from asyncio import TimeoutError as AsyncTimeoutError
4
+ from asyncio import sleep , wait_for
5
+ from contextlib import suppress
3
6
from importlib import import_module
4
- from logging import CRITICAL , INFO , Formatter , getLogger
7
+ from logging import CRITICAL , INFO , Formatter , getLogger , StreamHandler
5
8
from logging .handlers import RotatingFileHandler
6
9
from pathlib import Path
7
10
from random import choice
8
11
from sys import modules
9
12
from textwrap import dedent
13
+ from types import ModuleType
10
14
from typing import TYPE_CHECKING
11
15
12
16
from aiohttp import ClientSession
13
17
from asyncpg import create_pool
14
18
from nextcord import Embed , Interaction , Member , Thread , User , abc
15
- from nextcord .ext .commands import AutoShardedBot , when_mentioned , when_mentioned_or
19
+ from nextcord .ext .commands import (
20
+ AutoShardedBot ,
21
+ ExtensionNotFound ,
22
+ when_mentioned ,
23
+ when_mentioned_or ,
24
+ )
16
25
17
26
from .blacklist import Blacklist
18
27
from .emojis import Emojis
26
35
)
27
36
28
37
if TYPE_CHECKING :
29
- from typing import Any , Callable , Optional , Union
38
+ from typing import Any , Callable , Mapping , Optional , Union
30
39
31
40
from asyncpg import Pool
32
41
from nextcord import Guild , Message , PartialMessageable
72
81
"""
73
82
74
83
75
- def get_handler ():
84
+ def get_handlers ():
85
+ formatter = Formatter (
86
+ "%(levelname)-7s %(asctime)s %(filename)12s:%(funcName)-28s: %(message)s" ,
87
+ datefmt = "%H:%M:%S %d/%m/%Y" ,
88
+ )
76
89
h = RotatingFileHandler (
77
90
"./logs/bot/io.log" ,
78
91
maxBytes = 1000000 ,
79
92
backupCount = 5 ,
80
93
encoding = "utf-8" ,
81
94
)
82
- h .setFormatter (
83
- Formatter (
84
- "%(levelname)-7s %(asctime)s %(filename)12s:%(funcName)-28s: %(message)s" ,
85
- datefmt = "%H:%M:%S %d/%m/%Y" ,
86
- )
87
- )
95
+ i = StreamHandler ()
96
+
97
+ i .setFormatter (formatter )
98
+ h .setFormatter (formatter )
88
99
h .namer = lambda name : name .replace (".log" , "" ) + ".log"
89
- return h
100
+ return h , i
90
101
91
102
92
103
class BotBase (AutoShardedBot ):
@@ -141,9 +152,10 @@ def __init__(self, *args, config_module: str = "config", **kwargs) -> None:
141
152
log = getLogger ()
142
153
log .handlers = []
143
154
log .setLevel (INFO )
144
- h = get_handler ()
155
+ h , i = get_handlers ()
145
156
146
157
log .addHandler (h )
158
+ log .addHandler (i )
147
159
getLogger ("asyncio" ).setLevel (CRITICAL )
148
160
149
161
self .loop .set_exception_handler (self .asyncio_handler )
@@ -158,9 +170,8 @@ def __init__(self, *args, config_module: str = "config", **kwargs) -> None:
158
170
self .db_enabled = True
159
171
self .db_args = (db_url ,)
160
172
self .db_kwargs = {}
161
- elif (
162
- (db_name := getattr (config , "db_name" , None ))
163
- and (db_user := getattr (config , "db_user" , "ooliver" ))
173
+ elif (db_name := getattr (config , "db_name" , None )) and (
174
+ db_user := getattr (config , "db_user" , "ooliver" )
164
175
):
165
176
self .db_enabled = True
166
177
self .db_args = ()
@@ -169,6 +180,8 @@ def __init__(self, *args, config_module: str = "config", **kwargs) -> None:
169
180
"user" : db_user ,
170
181
"host" : getattr (config , "db_host" , None ),
171
182
}
183
+ if port := getattr (config , "db_port" , None ):
184
+ self .db_kwargs ["port" ] = port
172
185
else :
173
186
self .db_enabled = False
174
187
self .db_args = ()
@@ -191,6 +204,7 @@ def __init__(self, *args, config_module: str = "config", **kwargs) -> None:
191
204
self .logchannel : int | None = getattr (config , "logchannel" , None )
192
205
self .guild_ids : list [int ] | None = getattr (config , "guild_ids" , None )
193
206
self .database_init : str = initialise + getattr (config , "database_init" , "" )
207
+ self .name : Optional [str ] = getattr (config , "name" , None )
194
208
195
209
self ._single_events : dict [str , Callable ] = {
196
210
"on_message" : self .get_wrapped_message ,
@@ -231,11 +245,18 @@ def asyncio_handler(self, _, context: dict) -> None:
231
245
)
232
246
)
233
247
234
- async def startup (self ) -> None :
248
+ async def start (self , * args , ** kwargs ) -> None :
235
249
if self .db_enabled :
236
- db = await create_pool (* self .db_args , ** self .db_kwargs )
237
- assert db is not None
238
- self .db = db
250
+ for tries in range (5 ):
251
+ try :
252
+ db = await create_pool (* self .db_args , ** self .db_kwargs )
253
+ assert db is not None
254
+ self .db = db
255
+ except AssertionError :
256
+ await sleep (2.5 * tries + 1 )
257
+ else :
258
+ break
259
+
239
260
await self .db .execute (self .database_init )
240
261
241
262
if self .aiohttp_enabled :
@@ -244,9 +265,9 @@ async def startup(self) -> None:
244
265
if self .blacklist_enabled and self .db_enabled :
245
266
self .blacklist = Blacklist (self .db )
246
267
247
- def run (self , * args , ** kwargs ) -> None :
248
- self .loop .create_task (self .startup ())
268
+ await super ().start (* args , ** kwargs )
249
269
270
+ def run (self , * args , ** kwargs ) -> None :
250
271
cog_dir = f"{ self .mod } /cogs" if self .mod else "./cogs"
251
272
cogs = Path (cog_dir )
252
273
@@ -255,17 +276,21 @@ def run(self, *args, **kwargs) -> None:
255
276
if "extras" in ext .parts or any (part .startswith ("_" ) for part in ext .parts ):
256
277
continue
257
278
if ext .suffix == ".py" :
258
- a = str (ext ). replace ( "/" , "." )[: - 3 ]
279
+ a = "." . join (ext . parts ). removesuffix ( ".py" )
259
280
log .info ("Loading ext %s" , a )
260
281
self .load_extension (a )
261
282
log .info ("Loaded ext %s" , a )
262
283
263
284
super ().run (* args , ** kwargs )
264
285
265
286
async def close (self , * args , ** kwargs ) -> None :
266
- if self .aiohttp_enabled :
287
+ if self .aiohttp_enabled and hasattr ( self , "session" ) :
267
288
await self .session .close ()
268
289
290
+ if self .db_enabled and hasattr (self , "db" ):
291
+ with suppress (AsyncTimeoutError ):
292
+ await wait_for (self .db .close (), timeout = 5 )
293
+
269
294
await super ().close (* args , ** kwargs )
270
295
271
296
@staticmethod
@@ -531,3 +556,47 @@ async def on_guild_remove(self, guild: Guild):
531
556
await self .get_channel (self .logchannel ).send (embed = embed ) # type: ignore
532
557
except AttributeError :
533
558
pass
559
+
560
+ def load_extension (
561
+ self , name : str , * , extras : Optional [dict [str , Any ]] = None
562
+ ) -> None :
563
+ ext = f"{ self .name } .cogs.{ name } " if self .name else name
564
+
565
+ try :
566
+ super ().load_extension (ext , extras = extras )
567
+ except (ExtensionNotFound , ModuleNotFoundError ):
568
+ super ().load_extension (name , extras = extras )
569
+
570
+ if self .is_ready ():
571
+ self .loop .create_task (self .sync_all_application_commands ())
572
+
573
+ def reload_extension (self , name : str ) -> None :
574
+ ext = f"{ self .name } .cogs.{ name } " if self .name else name
575
+
576
+ try :
577
+ super ().reload_extension (ext )
578
+ except (ExtensionNotFound , ModuleNotFoundError ):
579
+ super ().reload_extension (name )
580
+
581
+ if self .is_ready ():
582
+ self .loop .create_task (self .sync_all_application_commands ())
583
+
584
+ def unload_extension (self , name : str ) -> None :
585
+ ext = f"{ self .name } .cogs.{ name } " if self .name else name
586
+
587
+ try :
588
+ super ().unload_extension (ext )
589
+ except (ExtensionNotFound , ModuleNotFoundError ):
590
+ super ().unload_extension (name )
591
+
592
+ self .loop .create_task (self .sync_all_application_commands ())
593
+
594
+ @property
595
+ def extensions (self ) -> Mapping [str , ModuleType ]:
596
+ if not self .name :
597
+ return super ().extensions
598
+
599
+ return {
600
+ k .removeprefix (f"{ self .name } .cogs." ): v
601
+ for k , v in super ().extensions .items ()
602
+ } | super ().extensions
0 commit comments