forked from defnull/multipart
-
Notifications
You must be signed in to change notification settings - Fork 0
/
multipart.py
534 lines (405 loc) · 16.9 KB
/
multipart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
# -*- coding: utf-8 -*-
"""
Parser for multipart/form-data
==============================
This module provides a parser for the multipart/form-data format. It can read
from a file, a socket or a WSGI environment. The parser can be used to replace
cgi.FieldStorage to work around its limitations.
"""
__author__ = "Marcel Hellkamp"
__version__ = "0.2.4"
__license__ = "MIT"
__all__ = ["MultipartError", "MultipartParser", "MultipartPart", "parse_form_data"]
import re
import sys
from io import BytesIO
from tempfile import TemporaryFile
from urllib.parse import parse_qs
from wsgiref.headers import Headers
from collections.abc import MutableMapping as DictMixin
##############################################################################
################################ Helper & Misc ###############################
##############################################################################
# Some of these were copied from bottle: http://bottle.paws.de/
# ---------
# MultiDict
# ---------
class MultiDict(DictMixin):
    """ A mapping that keeps every value ever stored under a key.

        HTTP headers (Set-Cookie, for example) and form fields may repeat
        with differing values. Assigning to a key therefore appends to that
        key's list of values instead of overwriting it, and plain item
        access returns the most recently added value.
    """

    def __init__(self, *args, **kwargs):
        self.dict = {}
        initial = dict(*args, **kwargs)
        for name, item in initial.items():
            self[name] = item

    def __len__(self):
        # Number of distinct keys, not number of stored values.
        return len(self.dict)

    def __iter__(self):
        return iter(self.dict)

    def __contains__(self, key):
        return key in self.dict

    def __delitem__(self, key):
        # Drops all values stored under the key.
        del self.dict[key]

    def keys(self):
        return self.dict.keys()

    def __getitem__(self, key):
        # Passing KeyError as the default makes get() raise for missing keys.
        return self.get(key, KeyError, -1)

    def __setitem__(self, key, value):
        self.append(key, value)

    def append(self, key, value):
        """ Add :value to the list of values stored under :key. """
        bucket = self.dict.setdefault(key, [])
        bucket.append(value)

    def replace(self, key, value):
        """ Discard all values stored under :key and store only :value. """
        self.dict[key] = [value]

    def getall(self, key):
        """ Return the list of values for :key (an empty list if there are none). """
        found = self.dict.get(key)
        return found if found else []

    def get(self, key, default=None, index=-1):
        """ Return one value stored under :key.

            :param default: Returned if the key is missing. The exception
                            class ``KeyError`` acts as a sentinel: with it, a
                            missing key raises ``KeyError`` instead.
            :param index: Position in the key's value list (default: last).
        """
        missing = key not in self.dict
        if missing and default != KeyError:
            return [default][index]
        values = self.dict[key]
        return values[index]

    def iterallitems(self):
        """ Yield a (key, value) pair for every stored value of every key. """
        for name, bucket in self.dict.items():
            for item in bucket:
                yield name, item
def to_bytes(data, enc="utf8"):
    """ Encode :data with :enc if it is a ``str``; return anything else unchanged. """
    return data.encode(enc) if isinstance(data, str) else data
def copy_file(stream, target, maxread=-1, buffer_size=2 ** 16):
    """ Read from :stream and write to :target until :maxread or EOF.

        :param stream: Source object implementing ``.read(size)``.
        :param target: Destination object implementing ``.write(data)``.
        :param maxread: Maximum number of bytes to copy. A negative value
                        means "copy until EOF".
        :param buffer_size: Maximum number of bytes requested per read.
        :return: The number of bytes actually copied.
    """
    # NOTE: the default chunk size used to be ``2 * 16`` (32 bytes), a typo
    # for ``2 ** 16`` that made default copies needlessly slow.
    size, read = 0, stream.read
    while True:
        # Never ask for more than what is left of the maxread budget. Once
        # the budget is used up this requests 0 bytes, which yields an empty
        # chunk and ends the loop below.
        to_read = buffer_size if maxread < 0 else min(buffer_size, maxread - size)
        part = read(to_read)
        if not part:
            return size
        target.write(part)
        size += len(part)
# -------------
# Header Parser
# -------------
# Characters that force a header value to be quoted (see header_quote),
# escaped so that they can be embedded in a regex character class.
_special = re.escape('()<>@,;:"\\/[]?={} \t')
# Matches any single one of the special characters above.
_re_special = re.compile(r'[%s]' % _special)
_quoted_string = r'"(?:\\.|[^"])*"' # Quoted string; backslash escapes are allowed inside
_value = r'(?:[^%s]+|%s)' % (_special, _quoted_string) # Safe (unquoted) or quoted string
# One option: start of string or ';', then key '=' value with optional
# whitespace around key and '='. Group 1 captures the key, group 2 captures
# the value (still including its quotes, if any).
_option = r'(?:;|^)\s*([^%s]+)\s*=\s*(%s)' % (_special, _value)
_re_option = re.compile(_option) # key=value part of a Content-Type like header
def header_quote(val):
    """ Return :val in a form usable as a header option value.

        Values without special characters are returned unchanged. Anything
        else is wrapped in double quotes, with backslashes and double quotes
        escaped by a backslash.
    """
    if _re_special.search(val):
        escaped = val.replace("\\", "\\\\").replace('"', '\\"')
        return '"' + escaped + '"'
    return val
def header_unquote(val, filename=False):
    """ Undo the quoting of a header option value.

        Values that are not wrapped in double quotes (including the empty
        string) are returned unchanged. Quoted values lose their surrounding
        quotes and have ``\\\\`` and ``\\"`` escapes resolved.

        :param val: The raw option value as found in the header.
        :param filename: Set to true if the value is a file name. Only then
                         is a full Windows or UNC path (as sent by IE6)
                         reduced to its last path component; other values
                         that merely look like a path are left intact.
        :return: The unquoted value.
    """
    # Slices instead of indices, so an empty value does not raise IndexError.
    if val[:1] == val[-1:] == '"':
        val = val[1:-1]
        if filename and (val[1:3] == ":\\" or val[:2] == "\\\\"):
            val = val.split("\\")[-1]  # fix ie6 bug: full path --> filename
        return val.replace("\\\\", "\\").replace('\\"', '"')
    return val
def parse_options_header(header, options=None):
    """ Split a Content-Type like header into its main value and its options.

        :param header: The header value, e.g. ``text/plain; charset=utf8``.
        :param options: Optional dict that receives the parsed options. It is
                        only used if it is non-empty and the header contains
                        a ``;``; otherwise a new dict is returned.
        :return: A ``(value, options)`` tuple. The value is lower-cased and
                 stripped of surrounding whitespace. Option keys are
                 lower-cased and option values are unquoted.
    """
    if ";" not in header:
        return header.lower().strip(), {}
    content_type, tail = header.split(";", 1)
    options = options or {}
    for match in _re_option.finditer(tail):
        key = match.group(1).lower()
        # Only file names get the "full path --> file name" treatment.
        options[key] = header_unquote(match.group(2), key == "filename")
    # Normalize exactly like the option-less branch above, so callers can
    # compare the result against lower-case constants in both cases.
    return content_type.lower().strip(), options
##############################################################################
################################## Multipart #################################
##############################################################################
class MultipartError(ValueError):
    """ Raised by this module when form data is malformed or exceeds a limit. """
class MultipartParser(object):
    """ Incremental parser for a multipart/form-data byte stream. """

    def __init__(
        self,
        stream,
        boundary,
        content_length=-1,
        disk_limit=2 ** 30,
        mem_limit=2 ** 20,
        memfile_limit=2 ** 18,
        buffer_size=2 ** 16,
        charset="latin1",
    ):
        """ Parse a multipart/form-data byte stream. This object is an iterator
            over the parts of the message.

            :param stream: A file-like stream. Must implement ``.read(size)``.
            :param boundary: The multipart boundary as a byte string.
            :param content_length: The maximum number of bytes to read.
                                   Negative values mean "read until EOF".
            :param disk_limit: Maximum combined size of all parts that are
                               not buffered in memory.
            :param mem_limit: Maximum combined size of all parts buffered in
                              memory. Capped at :disk_limit.
            :param memfile_limit: Passed on to each part: body size above
                                  which a part leaves its memory buffer.
            :param buffer_size: Chunk size used for reading. Capped at
                                :mem_limit.
            :param charset: Default charset passed on to each part.
            :raises MultipartError: If the boundary does not fit into the
                                    (capped) buffer size.
        """
        self.stream = stream
        self.boundary = boundary
        self.content_length = content_length
        self.disk_limit = disk_limit
        self.memfile_limit = memfile_limit
        self.mem_limit = min(mem_limit, self.disk_limit)
        self.buffer_size = min(buffer_size, self.mem_limit)
        self.charset = charset

        if self.buffer_size - 6 < len(boundary):  # "--boundary--\r\n"
            raise MultipartError("Boundary does not fit into buffer_size.")

        self._done = []  # Parts that were already parsed and yielded
        self._part_iter = None  # Lazily created generator over remaining parts

    def __iter__(self):
        """ Iterate over the parts of the multipart message. """
        if not self._part_iter:
            self._part_iter = self._iterparse()

        # Replay the parts parsed by earlier iterations first, then continue
        # to consume the stream where the last iteration stopped.
        for part in self._done:
            yield part

        for part in self._part_iter:
            self._done.append(part)
            yield part

    def parts(self):
        """ Returns a list with all parts of the multipart message. """
        return list(self)

    def get(self, name, default=None):
        """ Return the first part with that name or a default value (None). """
        for part in self:
            if name == part.name:
                return part

        return default

    def get_all(self, name):
        """ Return a list of parts with that name. """
        return [p for p in self if p.name == name]

    def _lineiter(self):
        """ Iterate over a binary file-like object line by line. Each line is
            returned as a (line, line_ending) tuple. If the line does not fit
            into self.buffer_size, line_ending is empty and the rest of the line
            is returned with the next iteration.
        """
        read = self.stream.read
        maxread, maxbuf = self.content_length, self.buffer_size
        buffer = b""  # buffer for the last (partial) line

        while True:
            data = read(maxbuf if maxread < 0 else min(maxbuf, maxread))
            maxread -= len(data)
            lines = (buffer + data).splitlines(True)

            if not lines:
                # Nothing buffered and nothing read: the stream is empty or
                # content_length is 0. There is no first line to inspect, so
                # stop here and let the caller report the missing boundary.
                break

            len_first_line = len(lines[0])

            # be sure that the first line does not become too big
            if len_first_line > self.buffer_size:
                # at the same time don't split a '\r\n' accidentally
                if len_first_line == self.buffer_size + 1 and lines[0].endswith(b"\r\n"):
                    splitpos = self.buffer_size - 1
                else:
                    splitpos = self.buffer_size
                lines[:1] = [lines[0][:splitpos], lines[0][splitpos:]]

            if data:
                # The last line may be continued by the next chunk, so hold it
                # back until more data (or EOF) arrives.
                buffer = lines[-1]
                lines = lines[:-1]

            for line in lines:
                if line.endswith(b"\r\n"):
                    yield line[:-2], b"\r\n"
                elif line.endswith(b"\n"):
                    yield line[:-1], b"\n"
                elif line.endswith(b"\r"):
                    yield line[:-1], b"\r"
                else:
                    yield line, b""

            if not data:
                break

    @staticmethod
    def _rewind(part):
        """ Seek a finished part back to its start, rejecting incomplete ones. """
        if not part.file:
            # The part never reached the blank line that ends its headers, so
            # it has no body file yet and cannot be handed out.
            raise MultipartError("Boundary found before end of part headers.")
        part.file.seek(0)

    def _iterparse(self):
        """ Generator that parses the stream and yields one part at a time.

            :raises MultipartError: If the stream contains no boundary, ends
                before the terminating boundary, has data after a leading
                terminator, contains a part without a complete header
                section, or exceeds the memory or disk limit.
        """
        lines, line = self._lineiter(), ""
        separator = b"--" + to_bytes(self.boundary)
        terminator = b"--" + to_bytes(self.boundary) + b"--"

        # Consume first boundary. Ignore any preamble, as required by RFC
        # 2046, section 5.1.1.
        for line, nl in lines:
            if line in (separator, terminator):
                break
        else:
            raise MultipartError("Stream does not contain boundary")

        # Check for empty data
        if line == terminator:
            for _ in lines:
                raise MultipartError("Data after end of stream")
            return

        # For each part in stream...
        mem_used, disk_used = 0, 0  # Track used resources to prevent DoS
        is_tail = False  # True if the last line was incomplete (cut off)

        opts = {
            "buffer_size": self.buffer_size,
            "memfile_limit": self.memfile_limit,
            "charset": self.charset,
        }

        part = MultipartPart(**opts)

        for line, nl in lines:
            if line == terminator and not is_tail:
                self._rewind(part)
                yield part
                break
            elif line == separator and not is_tail:
                if part.is_buffered():
                    mem_used += part.size
                else:
                    disk_used += part.size
                self._rewind(part)

                yield part

                part = MultipartPart(**opts)
            else:
                is_tail = not nl  # The next line continues this one
                try:
                    part.feed(line, nl)

                    if part.is_buffered():
                        if part.size + mem_used > self.mem_limit:
                            raise MultipartError("Memory limit reached.")
                    elif part.size + disk_used > self.disk_limit:
                        raise MultipartError("Disk limit reached.")
                except MultipartError:
                    part.close()
                    raise
        else:
            # If we run off the end of the loop, the current MultipartPart
            # will not have been yielded, so it's our responsibility to
            # close it.
            part.close()

        if line != terminator:
            raise MultipartError("Unexpected end of multipart stream.")
class MultipartPart(object):
    """ A single part of a multipart message: its headers and its body.

        The part is filled line by line through :meth:`feed`. Header lines
        are collected first; after the blank line that ends them, all
        following lines are written to ``self.file``.
    """

    def __init__(self, buffer_size=2 ** 16, memfile_limit=2 ** 18, charset="latin1"):
        """
            :param buffer_size: Chunk size used when the body is copied from
                                memory to a temporary file.
            :param memfile_limit: Body size above which the in-memory buffer
                                  is replaced by a temporary file.
            :param charset: Charset used to decode header lines, and the body
                            unless the part's Content-Type names another one.
        """
        self.headerlist = []
        self.headers = None
        self.file = False  # Becomes a file object once the headers are complete
        self.size = 0
        self._buf = b""  # Line ending held back until the next body line arrives
        self.disposition = None
        self.options = {}
        self.name = None
        self.filename = None
        self.content_type = None
        self.content_length = -1  # -1: no Content-Length header seen
        self.charset = charset
        self.memfile_limit = memfile_limit
        self.buffer_size = buffer_size

    def feed(self, line, nl=""):
        """ Process one line (without its line ending :nl) of this part. """
        if self.file:
            return self.write_body(line, nl)
        return self.write_header(line, nl)

    def write_header(self, line, nl):
        """ Parse one header line. A blank line completes the header section.

            :raises MultipartError: If the line has no line ending (it was
                too long for the read buffer or the stream ended), or if it
                is neither a continuation line nor contains a colon.
        """
        line = line.decode(self.charset)

        if not nl:
            raise MultipartError("Unexpected end of line in header.")

        if not line.strip():  # blank line -> end of header segment
            self.finish_header()
        elif line[0] in " \t" and self.headerlist:
            # Leading whitespace: this line continues the previous header.
            name, value = self.headerlist.pop()
            self.headerlist.append((name, value + line.strip()))
        else:
            if ":" not in line:
                raise MultipartError("Syntax error in header: No colon.")

            name, value = line.split(":", 1)
            self.headerlist.append((name.strip(), value.strip()))

    def write_body(self, line, nl):
        """ Append one line of body data.

            The line ending is not written right away but kept until the next
            call, so the line break directly in front of the next boundary
            never becomes part of the body.

            :raises MultipartError: If the body grows beyond the size given
                in the part's Content-Length header.
        """
        if not line and not nl:
            return  # This does not even flush the buffer

        self.size += len(line) + len(self._buf)
        self.file.write(self._buf + line)
        self._buf = nl

        if self.content_length > 0 and self.size > self.content_length:
            raise MultipartError("Size of body exceeds Content-Length header.")

        if self.size > self.memfile_limit and isinstance(self.file, BytesIO):
            # TODO: What about non-file uploads that exceed the memfile_limit?
            self.file, old = TemporaryFile(mode="w+b"), self.file
            old.seek(0)
            copy_file(old, self.file, self.size, self.buffer_size)
            old.close()  # The memory buffer is no longer needed

    def finish_header(self):
        """ Evaluate the collected headers and prepare the body file.

            :raises MultipartError: If the Content-Disposition header is
                missing or the Content-Length header is not an integer.
        """
        self.file = BytesIO()
        self.headers = Headers(self.headerlist)
        content_disposition = self.headers.get("Content-Disposition", "")
        content_type = self.headers.get("Content-Type", "")

        if not content_disposition:
            raise MultipartError("Content-Disposition header is missing.")

        self.disposition, self.options = parse_options_header(content_disposition)
        self.name = self.options.get("name")
        self.filename = self.options.get("filename")
        self.content_type, options = parse_options_header(content_type)
        self.charset = options.get("charset") or self.charset
        try:
            self.content_length = int(self.headers.get("Content-Length") or "-1")
        except ValueError:
            # Report garbage as a parse error like every other malformed
            # header, instead of leaking a bare ValueError.
            raise MultipartError("Invalid Content-Length header.")

    def is_buffered(self):
        """ Return true if the data is fully buffered in memory."""
        return isinstance(self.file, BytesIO)

    @property
    def value(self):
        """ Data decoded with the specified charset """
        return self.raw.decode(self.charset)

    @property
    def raw(self):
        """ Data without decoding """
        pos = self.file.tell()
        self.file.seek(0)

        try:
            val = self.file.read()
        finally:
            # Restore the file position, whether or not the read succeeded.
            self.file.seek(pos)

        return val

    def save_as(self, path):
        """ Write the body to a new file at :path; return the bytes written.

            The read position of the body file is restored afterwards.
        """
        with open(path, "wb") as fp:
            pos = self.file.tell()

            try:
                self.file.seek(0)
                size = copy_file(self.file, fp)
            finally:
                self.file.seek(pos)

        return size

    def close(self):
        """ Close the body file, if there is one. Safe to call repeatedly. """
        if self.file:
            self.file.close()
            self.file = False
##############################################################################
#################################### WSGI ####################################
##############################################################################
def parse_form_data(environ, charset="utf8", strict=False, **kwargs):
    """ Parse form data from an environ dict and return a (forms, files) tuple.
        Both tuple values are dictionaries with the form-field name as a key
        (unicode) and lists as values (multiple values per key are possible).
        The forms-dictionary contains form-field values as unicode strings.
        The files-dictionary contains :class:`MultipartPart` instances, either
        because the form-field was a file-upload or the value is too big to fit
        into memory limits.

        :param environ: A WSGI environment dict.
        :param charset: The charset to use if unsure. (default: utf8)
        :param strict: If True, raise :exc:`MultipartError` on any parsing
                       errors. These are silently ignored by default.
        :param kwargs: Passed on to :class:`MultipartParser`. ``mem_limit``
                       also limits the size of url-encoded form data.
        :raises MultipartError: Only if :strict is true. All file parts
                                collected so far are closed first.
    """

    forms, files = MultiDict(), MultiDict()

    try:
        if environ.get("REQUEST_METHOD", "GET").upper() not in ("POST", "PUT"):
            raise MultipartError("Request method other than POST or PUT.")
        content_type = environ.get("CONTENT_TYPE", "")

        if not content_type:
            raise MultipartError("Missing Content-Type header.")

        try:
            content_length = int(environ.get("CONTENT_LENGTH") or "-1")
        except ValueError:
            # A garbage Content-Length is a parsing error like any other and
            # must follow the strict / non-strict contract.
            raise MultipartError("Invalid Content-Length header.")

        content_type, options = parse_options_header(content_type)
        stream = environ.get("wsgi.input") or BytesIO()
        kwargs["charset"] = charset = options.get("charset", charset)

        if content_type == "multipart/form-data":
            boundary = options.get("boundary", "")

            if not boundary:
                raise MultipartError("No boundary for multipart/form-data.")

            for part in MultipartParser(stream, boundary, content_length, **kwargs):
                if part.filename or not part.is_buffered():
                    files[part.name] = part
                else:  # TODO: Big form-fields are in the files dict. really?
                    forms[part.name] = part.value

        elif content_type in (
            "application/x-www-form-urlencoded",
            "application/x-url-encoded",
        ):
            mem_limit = kwargs.get("mem_limit", 2 ** 20)
            if content_length > mem_limit:
                raise MultipartError("Request too big. Increase MAXMEM.")

            data = stream.read(mem_limit).decode(charset)

            if stream.read(1):  # There is more data than fits into mem_limit
                raise MultipartError("Request too big. Increase MAXMEM.")

            data = parse_qs(data, keep_blank_values=True, encoding=charset)

            for key, values in data.items():
                for value in values:
                    forms[key] = value
        else:
            raise MultipartError("Unsupported content type.")

    except MultipartError:
        if strict:
            # Close every collected part. values() would only visit the last
            # part stored under each field name and leave the others open.
            for _, part in files.iterallitems():
                part.close()
            raise

    return forms, files