-
Notifications
You must be signed in to change notification settings - Fork 143
/
FastTelethon.py
287 lines (258 loc) · 8.53 KB
/
FastTelethon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# copied from https://github.com/tulir/mautrix-telegram/blob/master/mautrix_telegram/util/parallel_file_transfer.py
# Copyright (C) 2021 Tulir Asokan
import asyncio
import hashlib
import inspect
import logging
import math
import os
from collections import defaultdict
from fileinput import filename
from typing import (
AsyncGenerator,
Awaitable,
BinaryIO,
DefaultDict,
List,
Optional,
Tuple,
Union,
)
from telethon import TelegramClient, helpers, utils
from telethon.crypto import AuthKey
from telethon.network import MTProtoSender
from telethon.tl.alltlobjects import LAYER
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.auth import (
ExportAuthorizationRequest,
ImportAuthorizationRequest,
)
from telethon.tl.functions.upload import (
GetFileRequest,
SaveBigFilePartRequest,
SaveFilePartRequest,
)
from telethon.tl.types import (
Document,
InputDocumentFileLocation,
InputFile,
InputFileBig,
InputFileLocation,
InputPeerPhotoFileLocation,
InputPhotoFileLocation,
TypeInputFile,
)
# Module-level logger. It is registered under the "telethon" logger name, so
# its records are configured together with whatever else uses that name.
log: logging.Logger = logging.getLogger("telethon")
# Union of the telethon file-location types imported above.
# NOTE(review): nothing in the visible part of this file references
# TypeLocation -- presumably a leftover from the upstream module's download
# path; confirm before removing.
TypeLocation = Union[
Document,
InputDocumentFileLocation,
InputPeerPhotoFileLocation,
InputFileLocation,
InputPhotoFileLocation,
]
class UploadSender:
    """Send a strided subset of one file's parts over a single connection.

    A sender owns one ``MTProtoSender`` and one reusable save-part request.
    The request starts at part ``index`` and, after every part that was sent
    successfully, its part number is advanced by ``stride``.  At most one
    part is in flight per sender: ``next`` waits for the previously scheduled
    part before it schedules the new one.
    """

    client: TelegramClient
    sender: MTProtoSender
    request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
    part_count: int
    stride: int
    previous: Optional[asyncio.Task]
    loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        client: TelegramClient,
        sender: MTProtoSender,
        file_id: int,
        part_count: int,
        big: bool,
        index: int,
        stride: int,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Prepare the reusable request for this sender.

        Args:
            client: Client whose ``_call`` is used to issue the requests.
            sender: Connection the requests are sent through.
            file_id: Identifier shared by all parts of the file.
            part_count: Total number of parts; passed to the big-file request
                and used in the debug log.
            big: Build a ``SaveBigFilePartRequest`` when true, otherwise a
                ``SaveFilePartRequest``.
            index: Part number of the first part this sender uploads.
            stride: Increment applied to the part number after each part.
            loop: Event loop used to schedule the background send tasks.
        """
        self.client = client
        self.sender = sender
        self.part_count = part_count
        if big:
            self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
        else:
            self.request = SaveFilePartRequest(file_id, index, b"")
        self.stride = stride
        self.previous = None
        self.loop = loop

    async def next(self, data: bytes) -> None:
        """Schedule ``data`` as this sender's next part.

        Waits for the previously scheduled part (re-raising its error, if
        any) and then returns as soon as the new send task has been created;
        the new part itself is not awaited here.
        """
        if self.previous:
            await self.previous
        self.previous = self.loop.create_task(self._next(data))

    async def _next(self, data: bytes) -> None:
        """Send one part and advance the request to this sender's next part."""
        self.request.bytes = data
        log.debug(
            "Sending file part %s/%s with %s bytes",
            self.request.file_part,
            self.part_count,
            len(data),
        )
        await self.client._call(self.sender, self.request)
        # Only reached when the call succeeded, so a failed part never
        # advances the part number.
        self.request.file_part += self.stride

    async def disconnect(self) -> None:
        """Wait for the pending part, then close the connection.

        The connection is closed even when the pending part failed; in that
        case the part's exception is re-raised after the disconnect.
        """
        try:
            if self.previous:
                await self.previous
        finally:
            # Without the finally, a failed part would leave the underlying
            # connection open.
            await self.sender.disconnect()
class ParallelTransferrer:
    """Distribute the parts of one file upload over several connections.

    Usage, as seen from the methods below: call ``init_upload`` once, then
    ``upload`` once per part in file order, then ``finish_upload``.
    """

    client: TelegramClient
    loop: asyncio.AbstractEventLoop
    dc_id: int
    senders: Optional[List[UploadSender]]
    auth_key: AuthKey
    upload_ticker: int

    def __init__(self, client: TelegramClient, dc_id: Optional[int] = None) -> None:
        """Bind the transferrer to ``client`` and pick the target DC.

        Args:
            client: Connected client; its loop, session and private helpers
                are used to create the extra connections.
            dc_id: Data center to connect to.  Defaults to the session's DC.
        """
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id or self.client.session.dc_id
        # The session's auth key is only reused when the target DC is the
        # session's own DC; otherwise a key is negotiated in _create_sender.
        self.auth_key = (
            None
            if dc_id and self.client.session.dc_id != dc_id
            else self.client.session.auth_key
        )
        self.senders = None
        self.upload_ticker = 0

    async def _cleanup(self) -> None:
        """Disconnect every sender and forget them.

        Does nothing when no senders were created, so it is safe to call
        without a preceding successful ``init_upload``.
        """
        if self.senders is None:
            return
        await asyncio.gather(*[sender.disconnect() for sender in self.senders])
        self.senders = None

    @staticmethod
    def _get_connection_count(
        file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
    ) -> int:
        """Return how many connections to use for a file of ``file_size`` bytes.

        The count grows linearly with the file size, reaches ``max_count`` at
        ``full_size`` bytes and is never below 1 (a count of 0 would create a
        sender whose part number never advances).
        """
        if file_size > full_size:
            return max_count
        return max(1, math.ceil((file_size / full_size) * max_count))

    async def _init_upload(
        self, connections: int, file_id: int, part_count: int, big: bool
    ) -> None:
        """Create ``connections`` upload senders with indices 0..connections-1."""
        # The first sender is created on its own: _create_sender stores a
        # freshly negotiated auth key on self, so the remaining senders, which
        # are created concurrently, are constructed with that key.
        self.senders = [
            await self._create_upload_sender(file_id, part_count, big, 0, connections),
            *await asyncio.gather(
                *[
                    self._create_upload_sender(file_id, part_count, big, i, connections)
                    for i in range(1, connections)
                ]
            ),
        ]

    async def _create_upload_sender(
        self, file_id: int, part_count: int, big: bool, index: int, stride: int
    ) -> UploadSender:
        """Open a new connection and wrap it in an ``UploadSender``."""
        return UploadSender(
            self.client,
            await self._create_sender(),
            file_id,
            part_count,
            big,
            index,
            stride,
            loop=self.loop,
        )

    async def _create_sender(self) -> MTProtoSender:
        """Connect a new ``MTProtoSender`` to the target DC.

        When no auth key is available yet, an authorization is exported
        through the main client, imported over the new connection, and the
        resulting key is stored on ``self.auth_key``.
        """
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, loggers=self.client._log)
        await sender.connect(
            self.client._connection(
                dc.ip_address,
                dc.port,
                dc.id,
                loggers=self.client._log,
                proxy=self.client._proxy,
            )
        )
        if not self.auth_key:
            log.debug(f"Exporting auth to DC {self.dc_id}")
            auth = await self.client(ExportAuthorizationRequest(self.dc_id))
            # NOTE(review): this overwrites the query of the client's shared
            # init request and does not restore it afterwards.
            self.client._init_request.query = ImportAuthorizationRequest(
                id=auth.id, bytes=auth.bytes
            )
            req = InvokeWithLayerRequest(LAYER, self.client._init_request)
            await sender.send(req)
            self.auth_key = sender.auth_key
        return sender

    async def init_upload(
        self,
        file_id: int,
        file_size: int,
        part_size_kb: Optional[float] = None,
        connection_count: Optional[int] = None,
    ) -> Tuple[int, int, bool]:
        """Open the upload connections and compute the part layout.

        Args:
            file_id: Identifier shared by all parts of the file.
            file_size: Size of the file in bytes.
            part_size_kb: Part size in KiB; chosen by
                ``utils.get_appropriated_part_size`` when omitted.
            connection_count: Number of connections; derived from
                ``file_size`` when omitted.

        Returns:
            ``(part_size, part_count, is_large)`` where ``part_size`` is in
            bytes and ``is_large`` is true for files above 10 MiB.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        # int() keeps the part size usable as a slice bound even when
        # part_size_kb is passed as a float.
        part_size = int(
            (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
        )
        part_count = (file_size + part_size - 1) // part_size
        is_large = file_size > 10 * 1024 * 1024
        await self._init_upload(connection_count, file_id, part_count, is_large)
        return part_size, part_count, is_large

    async def upload(self, part: bytes) -> None:
        """Hand ``part`` to the next sender in round-robin order."""
        await self.senders[self.upload_ticker].next(part)
        self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)

    async def finish_upload(self) -> None:
        """Wait for outstanding parts and close all upload connections."""
        await self._cleanup()
# Lazily created asyncio locks keyed by an int: the first lookup of a key
# creates that key's lock, later lookups return the same lock object.
# NOTE(review): presumably keyed by DC id; nothing in the visible part of
# this file uses the mapping, so confirm against callers.
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock)
def stream_file(file_to_stream: BinaryIO, chunk_size=1024):
    """Yield successive chunks read from ``file_to_stream``.

    Each chunk is the result of one ``read(chunk_size)`` call; iteration
    stops at the first read that returns a falsy value (for example ``b""``
    at end of file).
    """
    chunk = file_to_stream.read(chunk_size)
    while chunk:
        yield chunk
        chunk = file_to_stream.read(chunk_size)
async def _internal_transfer_to_telegram(
    client: TelegramClient,
    response: BinaryIO,
    progress_callback: callable,
    file_name: str = None,
) -> Tuple[TypeInputFile, int]:
    """Upload the contents of ``response`` in parallel parts.

    Args:
        client: Client used to create the upload connections.
        response: Open binary file.  Its size is taken from the path in
            ``response.name``, so it must be backed by a file on disk.
        progress_callback: Optional callable invoked as
            ``progress_callback(current_position, file_size)`` after every
            chunk read; an awaitable result is awaited.  May be falsy.
        file_name: Name stored in the returned input file; ``"upload"`` is
            used when it is empty or ``None``.

    Returns:
        ``(input_file, file_size)`` where ``input_file`` is an
        ``InputFileBig`` for large files and an ``InputFile`` (carrying the
        MD5 hex digest of the data) otherwise.

    The upload connections are closed even when reading, the progress
    callback, or a part upload raises.
    """
    file_id = helpers.generate_random_long()
    file_size = os.path.getsize(response.name)

    hash_md5 = hashlib.md5()
    uploader = ParallelTransferrer(client)
    part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
    buffer = bytearray()
    try:
        for data in stream_file(response):
            if progress_callback:
                r = progress_callback(response.tell(), file_size)
                if inspect.isawaitable(r):
                    await r
            # Only the small-file InputFile built below takes a checksum.
            if not is_large:
                hash_md5.update(data)
            # A chunk that is exactly one part can skip the buffer entirely.
            if not buffer and len(data) == part_size:
                await uploader.upload(data)
                continue
            new_len = len(buffer) + len(data)
            if new_len >= part_size:
                # Fill the buffer up to exactly one part, send it, and keep
                # the remainder of the chunk for the next part.
                cutoff = part_size - len(buffer)
                buffer.extend(data[:cutoff])
                await uploader.upload(bytes(buffer))
                buffer.clear()
                buffer.extend(data[cutoff:])
            else:
                buffer.extend(data)
        # Whatever is left is the final, possibly shorter, part.
        if buffer:
            await uploader.upload(bytes(buffer))
    finally:
        # Previously skipped on any error above, leaving every upload
        # connection open.
        await uploader.finish_upload()

    name = file_name or "upload"
    if is_large:
        return InputFileBig(file_id, part_count, name), file_size
    return (
        InputFile(file_id, part_count, name, hash_md5.hexdigest()),
        file_size,
    )
async def upload_file(
    client: TelegramClient,
    file: BinaryIO,
    progress_callback: callable = None,
    file_name: str = None,
) -> TypeInputFile:
    """Upload ``file`` through ``_internal_transfer_to_telegram``.

    The arguments are forwarded unchanged; only the first element of the
    helper's result (the input-file object) is returned.
    """
    transfer_result = await _internal_transfer_to_telegram(
        client, file, progress_callback, file_name
    )
    input_file = transfer_result[0]
    return input_file