forked from JJTech0130/pypush
-
Notifications
You must be signed in to change notification settings - Fork 0
/
apns.py
323 lines (244 loc) · 9.68 KB
/
apns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
from __future__ import annotations
import random
import socket
import threading
import time
from hashlib import sha1
from base64 import b64encode, b64decode
import logging
logger = logging.getLogger("apns")
import tlslite
# The exact tlslite-ng version string this module compares against.
_REQUIRED_TLSLITE_VERSION = "0.8.0-alpha43"

# A mismatch is only reported through the logger; the import still proceeds.
if tlslite.__version__ != _REQUIRED_TLSLITE_VERSION:
    logger.warning("tlslite-ng is not the correct version!")
    logger.warning("Please install tlslite-ng==0.8.0a43 or you will experience issues!")
import albert
import bags
# COURIER_HOST = "windows.courier.push.apple.com"  # TODO: Get this from config

# Fetch the APNs init bag once and reuse it for both lookups below, rather than
# calling bags.apns_init_bag() separately for the host count and the hostname.
_APNS_INIT_BAG = bags.apns_init_bag()

# Pick a random courier server: the host is "<n>-<APNSCourierHostname>" with n
# drawn from 1..APNSCourierHostcount (random.randint is inclusive on both ends).
COURIER_HOST = (
    f"{random.randint(1, _APNS_INIT_BAG['APNSCourierHostcount'])}"
    f"-{_APNS_INIT_BAG['APNSCourierHostname']}"
)
# TCP port that _connect() dials on COURIER_HOST.
COURIER_PORT = 5223
# ALPN protocol list passed to the TLS handshake in _connect().
ALPN = [b"apns-security-v2"]
# Connect to the courier server
def _connect(private_key: str, cert: str) -> tlslite.TLSConnection:
    """Open a TLS connection to the courier server using a client certificate.

    Args:
        private_key: PEM private key text, parsed with ``tlslite.parsePEMKey``.
        cert: Certificate text, parsed with ``tlslite.X509().parse``
            (presumably PEM as well -- confirm against the caller).

    Returns:
        The ``tlslite.TLSConnection`` after the client-certificate handshake.

    Raises:
        Whatever ``socket.create_connection`` or tlslite raises. If anything
        fails after the TCP connection is open, the socket is closed before
        the exception propagates.
    """
    # Connect to the courier server
    raw_sock = socket.create_connection((COURIER_HOST, COURIER_PORT))
    try:
        # Wrap the socket in TLS
        conn = tlslite.TLSConnection(raw_sock)
        # Parse the certificate and private key into separate names so the
        # string parameters are not rebound to objects of a different type.
        cert_chain = tlslite.X509CertChain([tlslite.X509().parse(cert)])
        parsed_key = tlslite.parsePEMKey(private_key, private=True)
        # Handshake with the server
        conn.handshakeClientCert(cert_chain, parsed_key, alpn=ALPN)
    except BaseException:
        # Do not leak the file descriptor when parsing or the handshake fails.
        raw_sock.close()
        raise

    logger.info(f"Connected to APNs ({COURIER_HOST})")
    return conn
class IncomingQueue:
    """A lock-protected list of received items that can be searched by predicate.

    Every method acquires ``self.lock`` around its access to ``self.queue``.
    """

    def __init__(self):
        self.queue = []
        self.lock = threading.Lock()

    def append(self, item):
        """Add *item* to the end of the queue."""
        with self.lock:
            self.queue.append(item)

    def pop(self, index=-1):
        """Remove and return the item at *index* (the last item by default)."""
        with self.lock:
            return self.queue.pop(index)

    def __getitem__(self, index):
        with self.lock:
            return self.queue[index]

    def __len__(self):
        with self.lock:
            return len(self.queue)

    def find(self, finder):
        """Return the first item for which ``finder(item)`` is truthy, else None."""
        with self.lock:
            return next((i for i in self.queue if finder(i)), None)

    def pop_find(self, finder):
        """Remove and return the first item matching *finder*, or None if none match."""
        with self.lock:
            for position, item in enumerate(self.queue):
                if finder(item):
                    # Delete by position so exactly the matched entry is
                    # removed, not merely the first entry that compares equal.
                    del self.queue[position]
                    return item
        return None

    def remove_all(self, id):
        """Drop every item whose first element equals *id*."""
        with self.lock:
            self.queue = [i for i in self.queue if i[0] != id]

    def wait_pop_find(self, finder, delay=0.1, timeout=None):
        """Poll until an item matching *finder* can be popped, then return it.

        Args:
            finder: Predicate called with each queued item.
            delay: Seconds to sleep between polls.
            timeout: Maximum seconds to wait. ``None`` (the default) waits
                indefinitely, which is the original behaviour.

        Returns:
            The matching item, or ``None`` if *timeout* elapsed first.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            found = self.pop_find(finder)
            if found is not None:
                return found
            if deadline is None:
                time.sleep(delay)
                continue
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            # Never sleep past the deadline.
            time.sleep(min(delay, remaining))
class APNSConnection:
    """A client connection to an APNs courier server.

    Creating an instance opens the TLS connection and starts two daemon
    threads: one reads incoming payloads into ``incoming_queue`` (ACKing
    notifications as they arrive) and one sends a keep-alive every 300 seconds.

    ``self.token`` is only assigned by :meth:`connect`; ``filter``,
    ``send_message`` and ``_send_ack`` all read it, so ``connect`` must have
    completed before they are used.

    Payloads are ``(command_id, [(field_id, value), ...])`` tuples. Command ids
    used in this class: 7 connect, 8 connect response, 9 filter,
    0x0A notification, 0x0B ACK, 0x0C keep-alive, 0x0D keep-alive response,
    0x14 state.
    """

    def __init__(self, private_key=None, cert=None):
        """Connect to the courier and start the background threads.

        Args:
            private_key: Private key for the client certificate.
            cert: Client (push) certificate. If either argument is ``None`` a
                fresh pair is obtained from ``albert.generate_push_cert()``.
        """
        # Generate the private key and certificate if they're not provided
        if private_key is None or cert is None:
            logger.debug("APNs needs a new push certificate")
            self.private_key, self.cert = albert.generate_push_cert()
        else:
            self.private_key, self.cert = private_key, cert

        # Each connection owns its queue. A class-level default instance would
        # be shared by every APNSConnection, mixing their incoming payloads.
        # It must exist before the reader thread starts.
        self.incoming_queue = IncomingQueue()

        self.sock = _connect(self.private_key, self.cert)

        self.queue_filler_thread = threading.Thread(
            target=self._queue_filler, daemon=True
        )
        self.queue_filler_thread.start()

        self.keep_alive_thread = threading.Thread(
            target=self._keep_alive_loop, daemon=True
        )
        self.keep_alive_thread.start()

    # Sink everything in the queue
    def sink(self):
        """Discard all queued payloads by swapping in a fresh queue."""
        self.incoming_queue = IncomingQueue()

    def _queue_filler(self):
        """Reader-thread body: move payloads from the socket into the queue."""
        while not self.sock.closed:
            payload = _deserialize_payload(self.sock)
            if payload is None:
                continue

            # Automatically ACK incoming notifications to prevent APNs from getting mad at us
            if payload[0] == 0x0A:
                logger.debug("Sending automatic ACK")
                # Field 4 carries the message id (send_message writes it there).
                self._send_ack(_get_field(payload[1], 4))

            logger.debug(f"Received payload: {payload}")
            self.incoming_queue.append(payload)
            logger.debug(f"Queue length: {len(self.incoming_queue)}")

    def _keep_alive_loop(self):
        """Keep-alive-thread body: send a keep-alive every 300 seconds."""
        while not self.sock.closed:
            time.sleep(300)
            self._keep_alive()

    def connect(self, root: bool = True, token: bytes | None = None):
        """Send the connect command and wait for the connect response.

        Args:
            root: When true, bit ``0b0100`` is set in the flags field.
                NOTE(review): the meaning of the flag bits is not visible here.
            token: Existing push token to present, or ``None`` to send none.

        Returns:
            The token now in effect: the one in the response (field 3) if
            present, otherwise the *token* argument.

        Raises:
            Exception: If the response status (field 1) is not ``0x00``, or if
                neither the response nor the caller supplied a token.
        """
        if token is None:
            logger.debug(f"Sending connect message without token (root={root})")
        else:
            logger.debug(f"Sending connect message with token {b64encode(token).decode()} (root={root})")

        flags = 0b01000001
        if root:
            flags |= 0b0100

        fields = [(2, b"\x01"), (5, flags.to_bytes(4, "big"))]
        if token is not None:
            # The token, when sent, is the first field of the message.
            fields.insert(0, (1, token))
        self.sock.write(_serialize_payload(7, fields))

        response = self.incoming_queue.wait_pop_find(lambda i: i[0] == 8)

        if (
            response is None
            or response[0] != 8
            or _get_field(response[1], 1) != b"\x00"
        ):
            raise Exception("Failed to connect")

        new_token = _get_field(response[1], 3)
        if new_token is not None:
            self.token = new_token
        elif token is not None:
            self.token = token
        else:
            raise Exception("No token")

        logger.debug(f"Recieved connect response with token {b64encode(self.token).decode()}")

        return self.token

    def filter(self, topics: list[str]):
        """Send a filter command listing the SHA-1 digest of each topic."""
        logger.debug(f"Sending filter message with topics {topics}")
        fields = [(1, self.token)]
        fields.extend((2, sha1(topic.encode()).digest()) for topic in topics)
        self.sock.write(_serialize_payload(9, fields))

    def send_message(self, topic: str, payload: bytes, id=None):
        """Send a notification and block until an ACK arrives.

        Args:
            topic: Topic name; its SHA-1 digest is sent, not the string.
            payload: Raw message body (bytes; it is concatenated into the
                serialized frame).
            id: Message id bytes; four random bytes are used when ``None``.

        Raises:
            Exception: If the first field of the ACK is not ``0x00``.
        """
        logger.debug(f"Sending message to topic {topic} with payload {payload}")

        if id is None:
            id = random.randbytes(4)

        message = _serialize_payload(
            0x0A,
            [
                (4, id),
                (1, sha1(topic.encode()).digest()),
                (2, self.token),
                (3, payload),
            ],
        )
        self.sock.write(message)

        # Wait for ACK
        # NOTE(review): this takes the next ACK in the queue; it is not matched
        # against the id that was just sent.
        ack = self.incoming_queue.wait_pop_find(lambda i: i[0] == 0x0B)

        if ack[1][0][1] != b"\x00":
            raise Exception("Failed to send message")

    def set_state(self, state: int):
        """Send a state command carrying *state* as a single byte."""
        logger.debug(f"Sending state message with state {state}")
        self.sock.write(
            _serialize_payload(
                0x14,
                # NOTE(review): the meaning of the constant in field 2 is not
                # visible in this file.
                [(1, state.to_bytes(1, "big")), (2, 0x7FFFFFFF.to_bytes(4, "big"))],
            )
        )

    def _keep_alive(self):
        """Send one keep-alive and drop queued keep-alive responses."""
        logger.debug("Sending keep alive message")
        self.sock.write(_serialize_payload(0x0C, []))
        # Remove any keep alive responses we have or missed
        self.incoming_queue.remove_all(0x0D)

    def _send_ack(self, id: bytes):
        """Send an ACK for the message with the given id."""
        logger.debug(f"Sending ACK for message {id}")
        ack = _serialize_payload(0x0B, [(1, self.token), (4, id), (8, b"\x00")])
        self.sock.write(ack)
# #self.sock.write(_serialize_payload(0x0B, [(4, id)])
# #pass
# def recieve_message(self):
# payload = self.incoming_queue.wait_pop_find(lambda i: i[0] == 0x0A)
# # Send ACK
# self._send_ack(_get_field(payload[1], 4))
# return _get_field(payload[1], 3)
# TODO: Find a way to make this non-blocking
# def expect_message(self) -> tuple[int, list[tuple[int, bytes]]] | None:
# return _deserialize_payload(self.sock)
def _serialize_field(id: int, value: bytes) -> bytes:
return id.to_bytes(1, "big") + len(value).to_bytes(2, "big") + value
def _serialize_payload(id: int, fields: list[tuple[int, bytes]]) -> bytes:
    """Encode a payload: 1-byte command id, 4-byte big-endian body length, body.

    The body is the concatenation of each ``(field_id, value)`` pair encoded by
    ``_serialize_field``, in order. Pairs whose field id is ``None`` are skipped.
    """
    # Join once instead of growing a bytes object with += inside the loop.
    body = b"".join(
        _serialize_field(fid, value) for fid, value in fields if fid is not None
    )
    return id.to_bytes(1, "big") + len(body).to_bytes(4, "big") + body
def _deserialize_field(stream: bytes) -> tuple[int, bytes]:
id = int.from_bytes(stream[:1], "big")
length = int.from_bytes(stream[1:3], "big")
value = stream[3 : 3 + length]
return id, value
# Note: Takes a stream, not a buffer, as we do not know the length of the payload
# WILL BLOCK IF THE STREAM IS EMPTY
def _read_exact(stream, size: int) -> bytes:
    """Read from *stream* until *size* bytes are collected or it returns nothing.

    A single ``read(n)`` may return fewer than n bytes, so keep reading. The
    result is shorter than *size* only if the stream stopped producing data.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = stream.read(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _deserialize_payload(stream) -> tuple[int, list[tuple[int, bytes]]] | None:
    """Read one payload from *stream* and decode it.

    Frame layout: 1-byte command id, 4-byte big-endian body length, then the
    body, which is a sequence of fields decoded by ``_deserialize_field``.

    Args:
        stream: Object with a blocking ``read(n)`` method. NOTE(review):
            tlslite's ``read(max, min=1)`` appears to allow short reads --
            confirm; the read loop is correct either way.

    Returns:
        ``(command_id, [(field_id, value), ...])``, or ``None`` when the
        command id reads as 0 (including when nothing could be read) or when
        the stream ended before a complete frame arrived.
    """
    payload_id = int.from_bytes(_read_exact(stream, 1), "big")
    if payload_id == 0x0:
        return None

    length_bytes = _read_exact(stream, 4)
    if len(length_bytes) < 4:
        # Stream ended in the middle of the header.
        return None
    length = int.from_bytes(length_bytes, "big")

    buffer = _read_exact(stream, length)
    if len(buffer) < length:
        # Stream ended in the middle of the body; do not decode a partial frame.
        return None

    fields = []
    while len(buffer) > 0:
        fid, value = _deserialize_field(buffer)
        fields.append((fid, value))
        # Advance past the 3-byte field header plus the value just decoded.
        buffer = buffer[3 + len(value):]

    return payload_id, fields
def _deserialize_payload_from_buffer(
    buffer: bytes,
) -> tuple[int, list[tuple[int, bytes]]] | None:
    """Decode one payload from the start of an in-memory buffer.

    Frame layout: 1-byte command id, 4-byte big-endian body length, then the
    body, which is a sequence of fields decoded by ``_deserialize_field``.

    Returns:
        ``(command_id, [(field_id, value), ...])``, or ``None`` when the
        command id is 0 (including an empty buffer).

    Raises:
        Exception: If fewer body bytes are present than the header declares.
    """
    payload_id = int.from_bytes(buffer[:1], "big")
    if payload_id == 0x0:
        return None

    length = int.from_bytes(buffer[1:5], "big")
    if len(buffer) - 5 < length:
        raise Exception("Buffer is too short")

    # Decode only the declared body; bytes after it belong to whatever follows
    # this payload and must not be parsed as fields.
    body = buffer[5 : 5 + length]

    fields = []
    while len(body) > 0:
        fid, value = _deserialize_field(body)
        fields.append((fid, value))
        # Advance past the 3-byte field header plus the value just decoded.
        body = body[3 + len(value):]

    return payload_id, fields
# Returns the value of the first field with the given id
def _get_field(fields: list[tuple[int, bytes]], id: int) -> bytes:
for field_id, value in fields:
if field_id == id:
return value
return None