Skip to content

Commit 320c1e2

Browse files
committed
improve bot library
1 parent 5521491 commit 320c1e2

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

nightwatch/bot/client.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
# Exceptions
1515
class AuthorizationFailed(Exception):
16-
pass
16+
def __init__(self, message: str, json: dict | None = None) -> None:
17+
super().__init__(message)
18+
self.json = json
1719

1820
# Handle state
1921
class ClientState:
@@ -64,6 +66,12 @@ def __init__(self) -> None:
6466
self.__state = ClientState()
6567
self.__session = requests.Session()
6668

69+
# Public attributes (provided just for the hell of it)
70+
self.user: User | None = None
71+
"""The current user this client is connected as."""
72+
self.address: str | None = None
73+
"""The address this client is connected to."""
74+
6775
# Events (for overwriting)
6876
async def on_connect(self, ctx: Context) -> None:
6977
"""Listen to the :connect: event."""
@@ -105,12 +113,12 @@ async def __authorize(self, username: str, hex: str, address: str) -> tuple[str,
105113
# Handle payload
106114
payload = response.json()
107115
if payload["code"] != 200:
108-
raise AuthorizationFailed(response)
116+
raise AuthorizationFailed("Connection failed!", payload)
109117

110118
return host, int(port), f"ws{protocol}://", payload["authorization"]
111119

112-
except requests.RequestException:
113-
raise AuthorizationFailed("Connection failed!")
120+
except requests.RequestException as e:
121+
raise AuthorizationFailed("Connection failed!", e.response.json() if e.response is not None else None)
114122

115123
async def __match_event(self, event: dict[str, typing.Any]) -> None:
116124
match event:
@@ -129,6 +137,9 @@ async def __match_event(self, event: dict[str, typing.Any]) -> None:
129137

130138
case {"type": "join", "data": payload}:
131139
user = from_dict(User, payload["user"])
140+
if user == self.user:
141+
return
142+
132143
self.__state.user_list.append(user)
133144
await self.on_join(Context(self.__state, user = user))
134145

@@ -137,16 +148,22 @@ async def __match_event(self, event: dict[str, typing.Any]) -> None:
137148
self.__state.user_list.remove(user)
138149
await self.on_leave(Context(self.__state, user = user))
139150

140-
async def __event_loop(self, username: str, hex: str, address: str) -> None:
151+
async def event_loop(self, username: str, hex: str, address: str) -> None:
141152
"""Establish a connection and listen to websocket messages.
142153
This method shouldn't be called directly, use :Client.run: instead."""
143154

144155
host, port, protocol, auth = await self.__authorize(username, hex, address)
156+
self.user, self.address = User(username, hex, False, True), address
157+
145158
async with connect(f"{protocol}{host}:{port}/api/ws?authorization={auth}") as socket:
146159
self.__state.socket = socket
147160
while socket.state == 1:
148161
await self.__match_event(orjson.loads(await socket.recv()))
149162

163+
async def close(self) -> None:
164+
"""Closes the websocket connection."""
165+
await self.__state.socket.close()
166+
150167
def run(
151168
self,
152169
username: str,
@@ -160,4 +177,4 @@ def run(
160177
:hex: (str) -- the hex color code to connect with
161178
:address: (str) -- the FQDN to connect to
162179
"""
163-
asyncio.run(self.__event_loop(username, hex, address))
180+
asyncio.run(self.event_loop(username, hex, address))

0 commit comments

Comments
 (0)