Skip to content

Commit

Permalink
simpilfy shutdown, reduce lines (#192)
Browse files Browse the repository at this point in the history
* simpilfy shutdown, reduce lines

* clean up reload module, reduce lines

* setdefault rate in sendfile, clean up

* formatting

* simpilfy logic
  • Loading branch information
nggit authored Dec 31, 2024
1 parent deb25e2 commit d497c13
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 106 deletions.
9 changes: 5 additions & 4 deletions tremolo/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ def bind(value='', **context):

def version(**context):
print(
'tremolo %s (%s %d.%d.%d, %s)' % (__version__,
sys.implementation.name,
*sys.version_info[:3],
sys.platform)
'tremolo %s (%s %d.%d.%d, %s)' %
(__version__,
sys.implementation.name,
*sys.version_info[:3],
sys.platform)
)
return 0

Expand Down
5 changes: 2 additions & 3 deletions tremolo/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,8 @@ def _handle_websocket(self):
self._server['websocket'] = WebSocket(self.request, self.response)

async def _handle_response(self, func, options):
options['rate'] = options.get('rate', self.options['download_rate'])
options['buffer_size'] = options.get('buffer_size',
self.options['buffer_size'])
options.setdefault('rate', self.options['download_rate'])
options.setdefault('buffer_size', self.options['buffer_size'])

if not self.request.has_body:
if 'websocket' in options:
Expand Down
35 changes: 18 additions & 17 deletions tremolo/lib/http_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def set_cookie(self, name, value='', expires=0, path='/', domain=None,
path = quote(path).encode('latin-1')

cookie = bytearray(
b'%s=%s; expires=%s; max-age=%d; path=%s' % (
name, value, date_expired, expires, path)
b'%s=%s; expires=%s; max-age=%d; path=%s' %
(name, value, date_expired, expires, path)
)

for k, v in ((b'domain', domain), (b'samesite', samesite)):
Expand Down Expand Up @@ -186,17 +186,15 @@ async def end(self, data=b'', keepalive=True, **kwargs):

await self.send(
b'HTTP/%s %d %s\r\nContent-Type: %s\r\nContent-Length: %d\r\n'
b'Connection: %s\r\n%s\r\n\r\n%s' % (
self.request.version,
*status,
self.get_content_type(),
content_length,
KEEPALIVE_OR_CLOSE[
keepalive and self.request.http_keepalive],
b'\r\n'.join(
b'\r\n'.join(v) for k, v in self.headers.items() if
k not in excludes),
data), **kwargs
b'Connection: %s\r\n%s\r\n\r\n%s' %
(self.request.version,
*status,
self.get_content_type(),
content_length,
KEEPALIVE_OR_CLOSE[keepalive and self.request.http_keepalive],
b'\r\n'.join(b'\r\n'.join(v) for k, v in self.headers.items()
if k not in excludes),
data), **kwargs
)
self.headers_sent(True)

Expand Down Expand Up @@ -276,6 +274,9 @@ async def sendfile(self, path, file_size=None, buffer_size=16384,
if isinstance(content_type, str):
content_type = content_type.encode('latin-1')

kwargs.setdefault(
'rate', self.request.protocol.options['download_rate']
)
kwargs['buffer_size'] = buffer_size
loop = self.request.protocol.loop

Expand Down Expand Up @@ -383,8 +384,8 @@ def run_sync(func, *args):
self.set_content_type(content_type)
self.set_header(b'Content-Length', b'%d' % size)
self.set_header(
b'Content-Range', b'bytes %d-%d/%d' % (
start, end, file_size)
b'Content-Range', b'bytes %d-%d/%d' %
(start, end, file_size)
)
await run_sync(handle.seek, start)

Expand All @@ -404,8 +405,8 @@ def run_sync(func, *args):
for start, end, size in ranges:
await self.write(
b'--%s\r\nContent-Type: %s\r\n'
b'Content-Range: bytes %d-%d/%d\r\n\r\n' % (
boundary, content_type, start, end, file_size),
b'Content-Range: bytes %d-%d/%d\r\n\r\n' %
(boundary, content_type, start, end, file_size),
**kwargs
)
await run_sync(handle.seek, start)
Expand Down
10 changes: 5 additions & 5 deletions tremolo/lib/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ async def recv(self):

if payload_length > self.protocol.options['ws_max_payload_size']:
raise WebSocketServerClosed(
'%d exceeds maximum payload size (%d)' % (
payload_length,
self.protocol.options['ws_max_payload_size']),
'%d exceeds maximum payload size (%d)' %
(payload_length,
self.protocol.options['ws_max_payload_size']),
code=1009
)

Expand Down Expand Up @@ -105,8 +105,8 @@ async def recv(self):
)

raise WebSocketServerClosed(
'unsupported opcode %x with payload length %d' % (
opcode, payload_length),
'unsupported opcode %x with payload length %d' %
(opcode, payload_length),
code=1008
)

Expand Down
130 changes: 55 additions & 75 deletions tremolo/tremolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ async def _serve(self, host, port, **options):

sock.listen(backlog)

if ('ssl' in options and options['ssl'] and
isinstance(options['ssl'], dict)):
if 'ssl' in options and isinstance(options['ssl'] or None, dict):
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile=options['ssl'].get('cert', ''),
Expand Down Expand Up @@ -202,17 +201,13 @@ async def _serve(self, host, port, **options):

options['_routes'].compile()

try:
for _, func in self.hooks['worker_start']:
if (await func(globals=context,
context=context,
app=self,
loop=self.loop,
logger=self.logger)):
break
except Exception as exc:
self.loop.stop()
raise exc
for _, func in self.hooks['worker_start']:
if (await func(globals=context,
context=context,
app=self,
loop=self.loop,
logger=self.logger)):
break
else:
from .asgi_lifespan import ASGILifespan
from .asgi_server import ASGIServer as Server
Expand All @@ -221,13 +216,13 @@ async def _serve(self, host, port, **options):
# '/path/to/module.py' -> 'module:app' (dir: '/path/to')
# '/path/to/module.py:myapp' -> 'module:myapp' (dir: '/path/to')

if (':\\' in options['app'] and options['app'].count(':') < 2 or
':' not in options['app']):
if options['app'].find(':', options['app'].find(':\\') + 1) == -1:
options['app'] += ':app'

path, attr_name = options['app'].rsplit(':', 1)
context.options['app_dir'], base_name = os.path.split(
os.path.abspath(path))
os.path.abspath(path)
)
module_name = os.path.splitext(base_name)[0]

if context.options['app_dir'] == '':
Expand Down Expand Up @@ -258,7 +253,6 @@ async def _serve(self, host, port, **options):
)

if exc:
self.loop.stop()
raise exc

sockname = sock.getsockname()
Expand Down Expand Up @@ -338,14 +332,11 @@ async def _serve(self, host, port, **options):
finally:
server.close()

try:
while context.tasks:
await context.tasks.pop()
while context.tasks:
await context.tasks.pop()

await server.wait_closed()
await self._worker_stop(context)
finally:
self.loop.stop()
await server.wait_closed()
await self._worker_stop(context)

async def _serve_forever(self, context):
limit_memory = context.options.get('limit_memory', 0)
Expand All @@ -362,21 +353,22 @@ async def _serve_forever(self, context):
# detect code changes
if 'reload' in context.options and context.options['reload']:
for module in (dict(modules) or sys.modules.values()):
if not hasattr(module, '__file__'):
module_file = getattr(module, '__file__', None)

if module_file is None:
continue

for path in paths:
if (module.__file__ is None or
module.__file__.startswith(path)):
if module_file.startswith(path):
break
else:
if not os.path.exists(module.__file__):
if not os.path.exists(module_file):
if module in modules:
del modules[module]

continue

sign = file_signature(module.__file__)
sign = file_signature(module_file)

if module in modules:
if modules[module] == sign:
Expand All @@ -388,7 +380,7 @@ async def _serve_forever(self, context):
modules[module] = sign
continue

self.logger.info('reload: %s', module.__file__)
self.logger.info('reload: %s', module_file)
sys.exit(3)

if limit_memory > 0 and memory_usage() > limit_memory:
Expand Down Expand Up @@ -448,21 +440,20 @@ def _worker(self, host, port, **kwargs):
asyncio.set_event_loop(self.loop)
task = self.loop.create_task(self._serve(host, port, **kwargs))

task.add_done_callback(lambda fut: self.loop.stop())
signal.signal(signal.SIGINT, lambda signum, frame: task.cancel())
signal.signal(signal.SIGTERM, lambda signum, frame: task.cancel())

try:
self.loop.run_forever() # until loop.stop() is called
finally:
try:
if not task.cancelled():
exc = task.exception()
self.loop.close()

# to avoid None, SystemExit, etc. for being printed
if isinstance(exc, Exception):
self.logger.error(exc)
finally:
self.loop.close()
if not task.cancelled():
exc = task.exception()

if exc:
raise exc

def create_sock(self, host, port, reuse_port=True):
try:
Expand Down Expand Up @@ -533,29 +524,22 @@ def _handle_reload(self, **info):

if kwargs['app'] is None:
for module in list(sys.modules.values()):
if (hasattr(module, '__file__') and
module_file = getattr(module, '__file__', None)

if (module_file and
module.__name__ not in ('__main__',
'__mp_main__',
'tremolo') and
not module.__name__.startswith('tremolo.') and
module.__file__ is not None and
module.__file__.startswith(kwargs['app_dir']) and
os.path.exists(module.__file__)):
module_file.startswith(kwargs['app_dir']) and
os.path.exists(module_file)):
reload_module(module)

if kwargs['module_name'] in sys.modules:
_module = sys.modules[kwargs['module_name']]
else:
_module = import_module(kwargs['module_name'])
module = import_module(kwargs['module_name'])

# we need to update/rebind objects like
# routes, middleware, etc.
for attr_name in dir(_module):
if attr_name.startswith('__'):
continue

attr = getattr(_module, attr_name)

for attr in module.__dict__.values():
if isinstance(attr, self.__class__):
self.__dict__.update(attr.__dict__)

Expand All @@ -578,17 +562,17 @@ def _handle_reload(self, **info):
def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs):
kwargs['reuse_port'] = reuse_port
kwargs['log_level'] = kwargs.get('log_level', 'DEBUG').upper()
kwargs['shutdown_timeout'] = kwargs.get('shutdown_timeout', 30)
kwargs.setdefault('shutdown_timeout', 30)
server_name = kwargs.get('server_name', 'Tremolo')
terminal_width = min(get_terminal_size()[0], 72)

print(
'Starting %s (tremolo %s, %s %d.%d.%d, %s)' % (
server_name,
__version__,
sys.implementation.name,
*sys.version_info[:3],
sys.platform)
'Starting %s (tremolo %s, %s %d.%d.%d, %s)' %
(server_name,
__version__,
sys.implementation.name,
*sys.version_info[:3],
sys.platform)
)
print('-' * terminal_width)

Expand All @@ -599,11 +583,8 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs):
if not hasattr(__main__, '__file__'):
raise RuntimeError('could not find ASGI app')

for attr_name in dir(__main__):
if attr_name.startswith('__'):
continue

if getattr(__main__, attr_name) == kwargs['app']:
for attr_name, attr in __main__.__dict__.items():
if attr == kwargs['app']:
break
else:
attr_name = 'app'
Expand All @@ -629,14 +610,13 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs):

for routes in self.routes.values():
for route in routes:
pattern, func, kwds = route
pattern, func, kw = route

print(
' %s -> %s(%s)' % (
pattern,
func.__name__,
', '.join(
'%s=%s' % item for item in kwds.items()))
' %s -> %s(%s)' %
(pattern,
func.__name__,
', '.join('%s=%s' % item for item in kw.items()))
)

print()
Expand Down Expand Up @@ -671,11 +651,11 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs):

options = {**kwargs, **options}
print(
' run(host=%s, port=%d, worker_num=%d, %s)' % (
_host,
_port,
worker_num,
', '.join('%s=%s' % item for item in options.items()))
' run(host=%s, port=%d, worker_num=%d, %s)' %
(_host,
_port,
worker_num,
', '.join('%s=%s' % item for item in options.items()))
)

args = (_host, _port)
Expand Down
4 changes: 2 additions & 2 deletions tremolo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def parse_args(**callbacks):
options[name] = int(sys.argv[i])
except ValueError:
print(
'Invalid %s value "%s". It must be a number' % (
sys.argv[i - 1], sys.argv[i])
'Invalid %s value "%s". It must be a number' %
(sys.argv[i - 1], sys.argv[i])
)
sys.exit(1)
elif sys.argv[i - 1] == '--ssl-cert':
Expand Down

0 comments on commit d497c13

Please sign in to comment.