diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4cbc9dc2..eeb2565e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,8 +18,23 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Some linters - run: echo "Some linters" + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install linters + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Run Ruff (format check) + run: ruff format --check + + - name: Run Ruff (lint + auto-fix) + run: ruff check + + # - name: Run MyPy (static type check) + # run: mypy test: uses: ./.github/workflows/tests.yml diff --git a/.github/workflows/other.yml b/.github/workflows/other.yml index f0fe8147..6557d839 100644 --- a/.github/workflows/other.yml +++ b/.github/workflows/other.yml @@ -13,8 +13,23 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Some linters - run: echo "Some linters" + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install linters + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Run Ruff (format check) + run: ruff format --check + + - name: Run Ruff (lint + auto-fix) + run: ruff check + + # - name: Run MyPy (static type check) + # run: mypy test: uses: ./.github/workflows/tests.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 69f92095..19ae0979 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -52,7 +52,7 @@ jobs: - name: Install dependencies run: | pip install --upgrade pip - pip install -r requirements-test.txt + pip install -r requirements.txt - name: Run tests run: | diff --git a/Makefile b/Makefile index 7dcd9565..5a6be830 100644 --- a/Makefile +++ b/Makefile @@ -12,8 +12,8 @@ help: @echo " make venv - Create virtual environment" @echo " make install - Install dependencies" @echo " make freeze - Freeze dependencies" - @echo " make test - Run tests" - @echo " make lint - Lint the code" + @echo " make tests - Run tests" + @echo " make linters - Run ruff formatter and linter" @echo " make up - Run project" @echo " make down - Stop project" @echo " make manage - Run manage.py command" @@ -42,11 +42,13 @@ install: venv freeze: venv $(PIP) freeze > requirements.txt -test: +tests: $(COMPOSE) exec -i $(SERVICE) python manage.py test --keepdb -lint: - $(COMPOSE) exec -i $(SERVICE) ruff format --exclude '**/migrations/*.py' +linters: + ruff format; \ + ruff check --fix; \ + # mypy migrate: ./migrate.sh diff --git a/docker-compose.yml b/docker-compose.yml index ab7b3a16..dd6ed547 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,8 +7,12 @@ services: volumes: - ./workflow:/var/app/workflow depends_on: - - db - - rabbitmq + migration: + condition: service_completed_successfully + db: + condition: service_started + rabbitmq: + condition: service_started ports: - "8001:8080" @@ -16,6 +20,7 @@ services: <<: *base environment: - DJANGO_SETTINGS_MODULE=workflow_app.settings # remove after change manage.py location + ports: [] command: > sh -c "celery --app workflow_app worker --concurrency=1 --loglevel=INFO -n $WORKFLOW_WORKER_NAME -Q $WORKFLOW_QUEUES --max-tasks-per-child=1" @@ -23,6 +28,7 @@ services: <<: *base environment: - DJANGO_SETTINGS_MODULE=workflow_app.settings # remove after change manage.py location + ports: [] command: > sh -c "celery --app workflow_app beat --loglevel=INFO --scheduler workflow.schedulers:DatabaseScheduler --pidfile=/tmp/celerybeat.pid" @@ -30,6 +36,12 @@ services: <<: *base environment: - DJANGO_SETTINGS_MODULE=workflow_app.settings # remove after change manage.py location + ports: [] + depends_on: + db: + condition: service_started + rabbitmq: + condition: service_started command: ["python", "manage.py", "migrate_all_schemes"] db: diff --git a/finmars_standardized_errors/formatter.py b/finmars_standardized_errors/formatter.py index 1d3a2578..e10dea53 100644 --- a/finmars_standardized_errors/formatter.py +++ b/finmars_standardized_errors/formatter.py @@ -1,24 +1,22 @@ +import datetime from dataclasses import asdict from http import HTTPStatus -from typing import List, Union -import datetime from django.utils.timezone import now - from rest_framework import exceptions from rest_framework.status import is_client_error from .models import ErrorRecord from .settings import package_settings -from .types import Error, ErrorResponse, ErrorType, ExceptionHandlerContext, ErrorResponseDetails +from .types import Error, ErrorResponse, ErrorResponseDetails, ErrorType, ExceptionHandlerContext class ExceptionFormatter: def __init__( - self, - exc: exceptions.APIException, - context: ExceptionHandlerContext, - original_exc: Exception, + self, + exc: exceptions.APIException, + context: ExceptionHandlerContext, + original_exc: Exception, ): self.exc = exc self.context = context @@ -43,16 +41,24 @@ def run(self): error_type = self.get_error_type() errors = self.get_errors() - url = str(self.context['request'].build_absolute_uri()) - username = str(self.context['request'].user.username) + url = str(self.context["request"].build_absolute_uri()) + username = str(self.context["request"].user.username) status_code = self.exc.status_code http_code_to_message = {v.value: v.description for v in HTTPStatus} message = http_code_to_message[status_code] error_datetime = str(datetime.datetime.strftime(now(), "%Y-%m-%d %H:%M:%S")) - ErrorRecord.objects.create(url=url, username=username, status_code=self.exc.status_code, message=message, details=asdict(ErrorResponseDetails(error_type, errors))) + ErrorRecord.objects.create( + url=url, + username=username, + status_code=self.exc.status_code, + message=message, + details=asdict(ErrorResponseDetails(error_type, errors)), + ) - error_response = self.get_error_response(url, username, status_code, message, error_datetime, error_type, errors) + error_response = self.get_error_response( + url, username, status_code, message, error_datetime, error_type, errors + ) return self.format_error_response(error_response) @@ -64,26 +70,32 @@ def get_error_type(self) -> ErrorType: else: return ErrorType.SERVER_ERROR - def get_errors(self) -> List[Error]: + def get_errors(self) -> list[Error]: """ Account for validation errors in nested serializers by returning a list of errors instead of a nested dict """ return flatten_errors(self.exc.detail) - def get_error_response(self, url: str, username: str, status_code: int, message, error_datetime, error_type: ErrorType, errors: List[Error]): - + def get_error_response( + self, + url: str, + username: str, + status_code: int, + message, + error_datetime, + error_type: ErrorType, + errors: list[Error], + ): error_response_details = ErrorResponseDetails(error_type, errors) return ErrorResponse(url, username, status_code, message, error_datetime, error_response_details) def format_error_response(self, error_response: ErrorResponse): - return {'error': asdict(error_response)} + return {"error": asdict(error_response)} -def flatten_errors( - detail: Union[list, dict, exceptions.ErrorDetail], attr=None, index=None -) -> List[Error]: +def flatten_errors(detail: list | dict | exceptions.ErrorDetail, attr=None, index=None) -> list[Error]: """ convert this: { @@ -129,13 +141,9 @@ def flatten_errors( new_attr = f"{attr}{package_settings.NESTED_FIELD_SEPARATOR}{index}" else: new_attr = str(index) - return flatten_errors(first_item, new_attr, index) + flatten_errors( - rest, attr, index - ) + return flatten_errors(first_item, new_attr, index) + flatten_errors(rest, attr, index) else: - return flatten_errors(first_item, attr, index) + flatten_errors( - rest, attr, index - ) + return flatten_errors(first_item, attr, index) + flatten_errors(rest, attr, index) elif isinstance(detail, dict): (key, value), *rest = list(detail.items()) if attr: diff --git a/finmars_standardized_errors/handler.py b/finmars_standardized_errors/handler.py index 0ae5f9b8..50dc9a80 100644 --- a/finmars_standardized_errors/handler.py +++ b/finmars_standardized_errors/handler.py @@ -1,5 +1,4 @@ import sys -from typing import Optional import django from django.conf import settings @@ -17,9 +16,7 @@ from .types import ExceptionHandlerContext -def exception_handler( - exc: Exception, context: ExceptionHandlerContext -) -> Optional[Response]: +def exception_handler(exc: Exception, context: ExceptionHandlerContext) -> Response | None: exception_handler_class = package_settings.EXCEPTION_HANDLER_CLASS msg = "`EXCEPTION_HANDLER_CLASS` should be a subclass of ExceptionHandler." assert issubclass(exception_handler_class, ExceptionHandler), msg @@ -31,7 +28,7 @@ def __init__(self, exc: Exception, context: ExceptionHandlerContext): self.exc = exc self.context = context - def run(self) -> Optional[Response]: + def run(self) -> Response | None: """entrypoint for handling an exception""" exc = self.convert_known_exceptions(self.exc) if self.should_not_handle(exc): @@ -96,7 +93,7 @@ def get_headers(self, exc: exceptions.APIException) -> dict: if getattr(exc, "auth_header", None): headers["WWW-Authenticate"] = exc.auth_header if getattr(exc, "wait", None): - headers["Retry-After"] = "%d" % exc.wait + headers["Retry-After"] = f"{exc.wait}" return headers def report_exception(self, exc: exceptions.APIException, response): diff --git a/finmars_standardized_errors/middleware.py b/finmars_standardized_errors/middleware.py index 990774a1..c862a8c9 100644 --- a/finmars_standardized_errors/middleware.py +++ b/finmars_standardized_errors/middleware.py @@ -16,11 +16,10 @@ import logging -_l = logging.getLogger('finmars') +_l = logging.getLogger("finmars") class ExceptionMiddleware(MiddlewareMixin): - def __init__(self, get_response): self.get_response = get_response @@ -31,7 +30,7 @@ def __call__(self, request): def process_exception(self, request, exception): # print('exception %s' % exception) - _l.error("ExceptionMiddleware process error %s" % request.build_absolute_uri()) + _l.error("ExceptionMiddleware process error %s", request.build_absolute_uri()) _l.error(traceback.format_exc()) lines = traceback.format_exc().splitlines()[-6:] @@ -47,23 +46,29 @@ def process_exception(self, request, exception): message = http_code_to_message[500] data = { - 'error': { - 'url': url, - 'username': username, - 'details': { - 'traceback': '\n'.join(traceback_lines), - 'error_message': repr(exception), + "error": { + "url": url, + "username": username, + "details": { + "traceback": "\n".join(traceback_lines), + "error_message": repr(exception), }, - 'message': message, - 'status_code': 500, - 'datetime': str(datetime.datetime.strftime(now(), '%Y-%m-%d %H:%M:%S')) + "message": message, + "status_code": 500, + "datetime": str(datetime.datetime.strftime(now(), "%Y-%m-%d %H:%M:%S")), } } - ErrorRecord.objects.create(url=url, username=username, status_code=500, message=message, details={ - 'traceback': '\n'.join(traceback_lines), - 'error_message': repr(exception), - }) + ErrorRecord.objects.create( + url=url, + username=username, + status_code=500, + message=message, + details={ + "traceback": "\n".join(traceback_lines), + "error_message": repr(exception), + }, + ) response_json = json.dumps(data, indent=2, sort_keys=True) diff --git a/finmars_standardized_errors/models.py b/finmars_standardized_errors/models.py index 43ec372c..6b7ca12e 100644 --- a/finmars_standardized_errors/models.py +++ b/finmars_standardized_errors/models.py @@ -3,27 +3,23 @@ from django.db import models from django.utils.translation import gettext_lazy -class ErrorRecord(models.Model): - url = models.CharField(max_length=255, null=True, blank=True, - verbose_name=gettext_lazy('url')) +class ErrorRecord(models.Model): + url = models.CharField(max_length=255, null=True, blank=True, verbose_name=gettext_lazy("url")) - username = models.CharField(max_length=255, null=True, blank=True, - verbose_name=gettext_lazy('username')) + username = models.CharField(max_length=255, null=True, blank=True, verbose_name=gettext_lazy("username")) - message = models.TextField(blank=True, default='', verbose_name=gettext_lazy('message')) - status_code = models.IntegerField(verbose_name=gettext_lazy('integer')) + message = models.TextField(blank=True, default="", verbose_name=gettext_lazy("message")) + status_code = models.IntegerField(verbose_name=gettext_lazy("integer")) - notes = models.TextField(blank=True, default='', verbose_name=gettext_lazy('notes')) + notes = models.TextField(blank=True, default="", verbose_name=gettext_lazy("notes")) - details_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy('details data')) + details_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("details data")) created = models.DateTimeField(auto_now_add=True) class Meta: - - ordering = ['-created'] - + ordering = ["-created"] @property def details(self): @@ -40,4 +36,4 @@ def details(self, val): if val: self.details_data = json.dumps(val, default=str, sort_keys=True) else: - self.details_data = None \ No newline at end of file + self.details_data = None diff --git a/finmars_standardized_errors/openapi.py b/finmars_standardized_errors/openapi.py index 88ebac12..1fc0726b 100644 --- a/finmars_standardized_errors/openapi.py +++ b/finmars_standardized_errors/openapi.py @@ -1,5 +1,4 @@ import inspect -from typing import List, Optional, Type from drf_spectacular.drainage import warn from drf_spectacular.openapi import AutoSchema as BaseAutoSchema @@ -65,31 +64,23 @@ def _get_response_bodies(self, direction="response"): "the value is the serializer." ) continue - error_responses[status_code] = self._get_response_for_code( - serializer, status_code - ) + error_responses[status_code] = self._get_response_for_code(serializer, status_code) return {**error_responses, **responses} - def _get_allowed_error_status_codes(self) -> List[str]: + def _get_allowed_error_status_codes(self) -> list[str]: allowed_status_codes = package_settings.ALLOWED_ERROR_STATUS_CODES or [] return [str(status_code) for status_code in allowed_status_codes] - def _should_add_error_response(self, responses: dict, status_code: str) -> bool: - if ( - self.view.get_exception_handler() is not standardized_errors_handler - or status_code in responses - ): + def _should_add_error_response(self, responses: dict, status_code: str) -> bool: # noqa: PLR0911 + if self.view.get_exception_handler() is not standardized_errors_handler or status_code in responses: # this means that the exception handler has been overridden for this view # or the error response has already been added via extend_schema, so we # should not override that return False if status_code == "400": - return ( - self._should_add_parse_error_response() - or self._should_add_validation_error_response() - ) + return self._should_add_parse_error_response() or self._should_add_validation_error_response() elif status_code == "401": return self._should_add_http401_error_response() elif status_code == "403": @@ -118,9 +109,7 @@ def _should_add_parse_error_response(self) -> bool: MultiPartParser, FileUploadParser, ) - return any( - isinstance(parser, parsers_that_raise_parse_errors) for parser in parsers - ) + return any(isinstance(parser, parsers_that_raise_parse_errors) for parser in parsers) def _should_add_validation_error_response(self) -> bool: """ @@ -130,18 +119,12 @@ def _should_add_validation_error_response(self) -> bool: request_serializer = self.get_request_serializer() has_request_body = self.method in ("PUT", "PATCH", "POST") and ( isinstance(request_serializer, serializers.Field) - or ( - inspect.isclass(request_serializer) - and issubclass(request_serializer, serializers.Field) - ) + or (inspect.isclass(request_serializer) and issubclass(request_serializer, serializers.Field)) ) filter_backends = get_django_filter_backends(self.get_filter_backends()) has_filters = any( - [ - filter_backend.get_schema_operation_parameters(self.view) - for filter_backend in filter_backends - ] + [filter_backend.get_schema_operation_parameters(self.view) for filter_backend in filter_backends] ) return has_request_body or has_filters @@ -159,38 +142,19 @@ def _should_add_http403_error_response(self) -> bool: # in the view, then the error raised is a 401 not a 403 (check implementation # of rest_framework.views.APIView.permission_denied) is_authenticated = ( - len(permissions) == 1 - and isinstance(permissions[0], IsAuthenticated) - and self.view.get_authenticators() + len(permissions) == 1 and isinstance(permissions[0], IsAuthenticated) and self.view.get_authenticators() ) return bool(permissions) and not is_allow_any and not is_authenticated def _should_add_http404_error_response(self) -> bool: paginator = self._get_paginator() - paginator_can_raise_404 = isinstance( - paginator, (PageNumberPagination, CursorPagination) - ) + paginator_can_raise_404 = isinstance(paginator, PageNumberPagination | CursorPagination) versioning_scheme_can_raise_404 = self.view.versioning_class and issubclass( self.view.versioning_class, - ( - URLPathVersioning, - NamespaceVersioning, - HostNameVersioning, - QueryParameterVersioning, - ), - ) - has_path_parameters = bool( - [ - parameter - for parameter in self._get_parameters() - if parameter["in"] == "path" - ] - ) - return ( - paginator_can_raise_404 - or versioning_scheme_can_raise_404 - or has_path_parameters + URLPathVersioning | NamespaceVersioning | HostNameVersioning | QueryParameterVersioning, ) + has_path_parameters = bool([parameter for parameter in self._get_parameters() if parameter["in"] == "path"]) + return paginator_can_raise_404 or versioning_scheme_can_raise_404 or has_path_parameters def _should_add_http405_error_response(self) -> bool: # API consumers can at all ties use the wrong method against any endpoint @@ -199,8 +163,7 @@ def _should_add_http405_error_response(self) -> bool: def _should_add_http406_error_response(self) -> bool: content_negotiator = self.view.get_content_negotiator() return isinstance(content_negotiator, DefaultContentNegotiation) or ( - self.view.versioning_class - and issubclass(self.view.versioning_class, AcceptHeaderVersioning) + self.view.versioning_class and issubclass(self.view.versioning_class, AcceptHeaderVersioning) ) def _should_add_http415_error_response(self) -> bool: @@ -210,13 +173,8 @@ def _should_add_http415_error_response(self) -> bool: handles everything (media type "*/*"), then this error can be raised. """ content_negotiator = self.view.get_content_negotiator() - parsers_that_handle_everything = [ - parser for parser in self.view.get_parsers() if parser.media_type == "*/*" - ] - return ( - isinstance(content_negotiator, DefaultContentNegotiation) - and not parsers_that_handle_everything - ) + parsers_that_handle_everything = [parser for parser in self.view.get_parsers() if parser.media_type == "*/*"] + return isinstance(content_negotiator, DefaultContentNegotiation) and not parsers_that_handle_everything def _should_add_http429_error_response(self) -> bool: return bool(self.view.get_throttles()) @@ -227,9 +185,7 @@ def _should_add_http500_error_response(self) -> bool: def _get_error_response_serializer(self, status_code: str): error_schemas = package_settings.ERROR_SCHEMAS or {} - error_schemas = { - str(status_code): schema for status_code, schema in error_schemas.items() - } + error_schemas = {str(status_code): schema for status_code, schema in error_schemas.items()} if serializer := error_schemas.get(status_code): return serializer @@ -272,13 +228,13 @@ def _get_http400_serializer(self): def _get_serializer_for_validation_error_response( self, - ) -> Optional[Type[serializers.Serializer]]: + ) -> type[serializers.Serializer] | None: fields_with_error_codes = self._determine_fields_with_error_codes() operation_id = self.get_operation_id() return get_validation_error_serializer(operation_id, fields_with_error_codes) - def _determine_fields_with_error_codes(self) -> "List[InputDataField]": + def _determine_fields_with_error_codes(self) -> "list[InputDataField]": if self.method in ("PUT", "PATCH", "POST"): serializer = self.get_request_serializer() fields = get_flat_serializer_fields(serializer) diff --git a/finmars_standardized_errors/openapi_hooks.py b/finmars_standardized_errors/openapi_hooks.py index 4b58236a..a4e87fc9 100644 --- a/finmars_standardized_errors/openapi_hooks.py +++ b/finmars_standardized_errors/openapi_hooks.py @@ -14,7 +14,7 @@ from .settings import package_settings -def postprocess_schema_enums(result, generator, **kwargs): +def postprocess_schema_enums(result, generator, **kwargs): # noqa: PLR0912,PLR0915 """ This a copy of the postprocessing hook for enums provided by drf-spectacular with only one change in `iter_prop_containers`. The change allows excluding @@ -31,11 +31,11 @@ def postprocess_schema_enums(result, generator, **kwargs): def iter_prop_containers(schema, component_name=None): if not component_name: - for component_name, schema in schema.items(): + for component_name, schema in schema.items(): # noqa: B020,PLR1704 if spectacular_settings.COMPONENT_SPLIT_PATCH: - component_name = re.sub("^Patched(.+)", r"\1", component_name) + component_name = re.sub("^Patched(.+)", r"\1", component_name) # noqa: PLW2901 if spectacular_settings.COMPONENT_SPLIT_REQUEST: - component_name = re.sub("(.+)Request$", r"\1", component_name) + component_name = re.sub("(.+)Request$", r"\1", component_name) # noqa: PLW2901 yield from iter_prop_containers(schema, component_name) elif isinstance(schema, list): for item in schema: @@ -54,9 +54,7 @@ def iter_prop_containers(schema, component_name=None): yield from iter_prop_containers(schema.get("anyOf", []), component_name) def create_enum_component(name, schema): - component = ResolvedComponent( - name=name, type=ResolvedComponent.SCHEMA, schema=schema, object=name - ) + component = ResolvedComponent(name=name, type=ResolvedComponent.SCHEMA, schema=schema, object=name) generator.registry.register_on_missing(component) return component @@ -70,13 +68,11 @@ def create_enum_component(name, schema): for component_name, props in iter_prop_containers(schemas): for prop_name, prop_schema in props.items(): if prop_schema.get("type") == "array": - prop_schema = prop_schema.get("items", {}) + prop_schema = prop_schema.get("items", {}) # noqa: PLW2901 if "enum" not in prop_schema: continue # remove blank/null entry for hashing. will be reconstructed in the last step - prop_enum_cleaned_hash = list_hash( - [i for i in prop_schema["enum"] if i not in ["", None]] - ) + prop_enum_cleaned_hash = list_hash([i for i in prop_schema["enum"] if i not in ["", None]]) prop_hash_mapping[prop_name].add(prop_enum_cleaned_hash) hash_name_mapping[prop_enum_cleaned_hash].add((component_name, prop_name)) @@ -120,41 +116,28 @@ def create_enum_component(name, schema): for prop_name, prop_schema in props.items(): is_array = prop_schema.get("type") == "array" if is_array: - prop_schema = prop_schema.get("items", {}) + prop_schema = prop_schema.get("items", {}) # noqa: PLW2901 if "enum" not in prop_schema: continue prop_enum_original_list = prop_schema["enum"] - prop_schema["enum"] = [ - i for i in prop_schema["enum"] if i not in ["", None] - ] + prop_schema["enum"] = [i for i in prop_schema["enum"] if i not in ["", None]] prop_hash = list_hash(prop_schema["enum"]) # when choice sets are reused under multiple names, the generated name cannot be # resolved from the hash alone. fall back to prop_name and hash for resolution. - enum_name = ( - enum_name_mapping.get(prop_hash) - or enum_name_mapping[prop_hash, prop_name] - ) + enum_name = enum_name_mapping.get(prop_hash) or enum_name_mapping[prop_hash, prop_name] # split property into remaining property and enum component parts - enum_schema = { - k: v for k, v in prop_schema.items() if k in ["type", "enum"] - } - prop_schema = { - k: v for k, v in prop_schema.items() if k not in ["type", "enum"] - } + enum_schema = {k: v for k, v in prop_schema.items() if k in ["type", "enum"]} + prop_schema = {k: v for k, v in prop_schema.items() if k not in ["type", "enum"]} # noqa: PLW2901 components = [create_enum_component(enum_name, schema=enum_schema)] if spectacular_settings.ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE: if "" in prop_enum_original_list: - components.append( - create_enum_component("BlankEnum", schema={"enum": [""]}) - ) + components.append(create_enum_component("BlankEnum", schema={"enum": [""]})) if None in prop_enum_original_list: - components.append( - create_enum_component("NullEnum", schema={"enum": [None]}) - ) + components.append(create_enum_component("NullEnum", schema={"enum": [None]})) if len(components) == 1: prop_schema.update(components[0].ref) @@ -167,7 +150,5 @@ def create_enum_component(name, schema): props[prop_name] = safe_ref(prop_schema) # sort again with additional components - result["components"] = generator.registry.build( - spectacular_settings.APPEND_COMPONENTS - ) + result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) return result diff --git a/finmars_standardized_errors/openapi_utils.py b/finmars_standardized_errors/openapi_utils.py index 31999946..f69dff00 100644 --- a/finmars_standardized_errors/openapi_utils.py +++ b/finmars_standardized_errors/openapi_utils.py @@ -1,7 +1,6 @@ from collections import defaultdict from dataclasses import dataclass from dataclasses import field as dataclass_field -from typing import List, Optional, Set, Type, Union from django import forms from django.core.validators import ( @@ -35,9 +34,9 @@ from .settings import package_settings -def get_flat_serializer_fields( - field: Union[serializers.Field, List[serializers.Field]], prefix: str = None -) -> "List[InputDataField]": +def get_flat_serializer_fields( # noqa: PLR0911 + field: serializers.Field | list[serializers.Field], prefix: str = None +) -> "list[InputDataField]": """ return a flat list of serializer fields. The fields list will later be used to identify error codes that can be raised by each field. So, it contains @@ -65,11 +64,9 @@ def get_flat_serializer_fields( non_field_errors_name = get_prefix(prefix, drf_settings.NON_FIELD_ERRORS_KEY) f = InputDataField(non_field_errors_name, field) return [f] + get_flat_serializer_fields(list(field.fields.values()), prefix) - elif isinstance(field, (list, tuple)): + elif isinstance(field, list | tuple): first, *remaining = field - return get_flat_serializer_fields(first, prefix) + get_flat_serializer_fields( - remaining, prefix - ) + return get_flat_serializer_fields(first, prefix) + get_flat_serializer_fields(remaining, prefix) elif hasattr(field, "child"): # composite field (List or Dict fields) prefix = get_prefix(prefix, field.field_name) @@ -84,7 +81,7 @@ def get_flat_serializer_fields( return [InputDataField(name, field)] -def get_prefix(prefix: Optional[str], name: str) -> str: +def get_prefix(prefix: str | None, name: str) -> str: if prefix and name: return f"{prefix}{package_settings.NESTED_FIELD_SEPARATOR}{name}" elif prefix: @@ -94,8 +91,8 @@ def get_prefix(prefix: Optional[str], name: str) -> str: def get_serializer_fields_with_error_codes( - serializer_fields: "List[InputDataField]", -) -> "List[InputDataField]": + serializer_fields: "list[InputDataField]", +) -> "list[InputDataField]": fields_with_error_codes = [] for sfield in serializer_fields: if error_codes := get_serializer_field_error_codes(sfield.field, sfield.name): @@ -106,27 +103,21 @@ def get_serializer_fields_with_error_codes( sfields_with_unique_together_validators = [ sfield for sfield in fields_with_error_codes - if is_basic_serializer(sfield.field) - and has_validator(sfield.field, UniqueTogetherValidator) + if is_basic_serializer(sfield.field) and has_validator(sfield.field, UniqueTogetherValidator) ] - add_unique_together_error_codes( - sfields_with_unique_together_validators, fields_with_error_codes - ) + add_unique_together_error_codes(sfields_with_unique_together_validators, fields_with_error_codes) sfields_with_unique_for_validators = [ sfield for sfield in fields_with_error_codes - if is_basic_serializer(sfield.field) - and has_validator(sfield.field, BaseUniqueForValidator) + if is_basic_serializer(sfield.field) and has_validator(sfield.field, BaseUniqueForValidator) ] - add_unique_for_error_codes( - sfields_with_unique_for_validators, fields_with_error_codes - ) + add_unique_for_error_codes(sfields_with_unique_for_validators, fields_with_error_codes) return fields_with_error_codes -def get_serializer_field_error_codes(field: serializers.Field, attr: str) -> Set[str]: +def get_serializer_field_error_codes(field: serializers.Field, attr: str) -> set[str]: # noqa: PLR0912 if field.read_only or isinstance(field, serializers.HiddenField): return set() @@ -135,11 +126,7 @@ def get_serializer_field_error_codes(field: serializers.Field, attr: str) -> Set error_codes.add("required") if not field.allow_null: error_codes.add("null") - if ( - hasattr(field, "allow_blank") - and not field.allow_blank - and not isinstance(field, serializers.ChoiceField) - ): + if hasattr(field, "allow_blank") and not field.allow_blank and not isinstance(field, serializers.ChoiceField): error_codes.add("blank") if getattr(field, "max_digits", None) is not None: error_codes.add("max_digits") @@ -198,9 +185,7 @@ def get_serializer_field_error_codes(field: serializers.Field, attr: str) -> Set # to error messages, so it should not be added automatically to error codes error_codes_with_specific_conditions.append("invalid") - remaining_error_codes = set(field.error_messages).difference( - error_codes_with_specific_conditions - ) + remaining_error_codes = set(field.error_messages).difference(error_codes_with_specific_conditions) error_codes.update(remaining_error_codes) # for top-level (as opposed to nested) serializer non_field_errors, @@ -218,23 +203,17 @@ def get_serializer_field_error_codes(field: serializers.Field, attr: str) -> Set # {'zones': {0: [ErrorDetail(string='A valid integer is required.', code='invalid')]}} if isinstance(field, serializers.ManyRelatedField): # required and null are added depending on the ManyRelatedField definition - child_error_codes = set(field.child_relation.error_messages).difference( - ["required", "null"] - ) + child_error_codes = set(field.child_relation.error_messages).difference(["required", "null"]) error_codes.update(child_error_codes) return error_codes -def add_unique_together_error_codes( - sfields_with_unique_together_validators, sfields_with_error_codes -): +def add_unique_together_error_codes(sfields_with_unique_together_validators, sfields_with_error_codes): for sfield in sfields_with_unique_together_validators: sfield.error_codes.add("unique") unique_together_validators = [ - validator - for validator in sfield.field.validators - if isinstance(validator, UniqueTogetherValidator) + validator for validator in sfield.field.validators if isinstance(validator, UniqueTogetherValidator) ] # fields involved in a unique together constraint have an implied # "required" state, so we're adding the "required" error code to them @@ -245,19 +224,13 @@ def add_unique_together_error_codes( add_error_code(sfield.name, field, "required", sfields_with_error_codes) -def add_unique_for_error_codes( - sfields_with_unique_for_validators, sfields_with_error_codes -): +def add_unique_for_error_codes(sfields_with_unique_for_validators, sfields_with_error_codes): for sfield in sfields_with_unique_for_validators: unique_for_validators = [ - validator - for validator in sfield.field.validators - if isinstance(validator, BaseUniqueForValidator) + validator for validator in sfield.field.validators if isinstance(validator, BaseUniqueForValidator) ] for v in unique_for_validators: - add_error_code( - sfield.name, v.date_field, "required", sfields_with_error_codes - ) + add_error_code(sfield.name, v.date_field, "required", sfields_with_error_codes) add_error_code(sfield.name, v.field, "required", sfields_with_error_codes) add_error_code(sfield.name, v.field, "unique", sfields_with_error_codes) @@ -279,7 +252,7 @@ def add_error_code(attr, field_name, error_code, sfields): break -def get_filter_forms(view: APIView, filter_backends: list) -> List[forms.Form]: +def get_filter_forms(view: APIView, filter_backends: list) -> list[forms.Form]: filter_forms = [] for backend in filter_backends: model = get_view_model(view) @@ -291,7 +264,7 @@ def get_filter_forms(view: APIView, filter_backends: list) -> List[forms.Form]: return filter_forms -def get_form_fields_with_error_codes(form: forms.Form) -> "List[InputDataField]": +def get_form_fields_with_error_codes(form: forms.Form) -> "list[InputDataField]": data_fields = [] for field_name, field in form.fields.items(): error_codes = set() @@ -303,20 +276,20 @@ def get_form_fields_with_error_codes(form: forms.Form) -> "List[InputDataField]" return data_fields -def get_form_fields(field: Union[forms.Field, List[forms.Field]]) -> List[forms.Field]: +def get_form_fields(field: forms.Field | list[forms.Field]) -> list[forms.Field]: if not field: return [] - if isinstance(field, (list, tuple)): + if isinstance(field, list | tuple): first, *rest = field return get_form_fields(first) + get_form_fields(rest) - elif isinstance(field, (forms.ComboField, forms.MultiValueField)): + elif isinstance(field, forms.ComboField | forms.MultiValueField): return [field] + get_form_fields(field.fields) else: return [field] -def get_form_field_error_codes(field: forms.Field) -> Set[str]: +def get_form_field_error_codes(field: forms.Field) -> set[str]: if field.disabled: return set() @@ -338,9 +311,7 @@ def get_form_field_error_codes(field: forms.Field) -> Set[str]: # add the error codes defined in error_messages after excluding the ones # that are conditionally raised error_codes_with_specific_conditions = ["required", "max_length", "empty"] - remaining_error_codes = set(field.error_messages).difference( - error_codes_with_specific_conditions - ) + remaining_error_codes = set(field.error_messages).difference(error_codes_with_specific_conditions) error_codes.update(remaining_error_codes) # the "missing" error code is defined but never used by FileField @@ -349,13 +320,11 @@ def get_form_field_error_codes(field: forms.Field) -> Set[str]: return error_codes.difference(["missing", "incomplete"]) -def has_validator(field: Union[serializers.Field, forms.Field], validator): +def has_validator(field: serializers.Field | forms.Field, validator): return any(isinstance(v, validator) for v in field.validators) -def get_error_codes_from_validators( - field: Union[serializers.Field, forms.Field] -) -> Set[str]: +def get_error_codes_from_validators(field: serializers.Field | forms.Field) -> set[str]: error_codes = set() for validator in field.validators: @@ -385,9 +354,7 @@ def get_error_codes_from_validators( return error_codes -def get_validation_error_serializer( - operation_id: str, data_fields: "List[InputDataField]" -): +def get_validation_error_serializer(operation_id: str, data_fields: "list[InputDataField]"): validation_error_component_name = f"{camelize(operation_id)}ValidationError" errors_component_name = f"{camelize(operation_id)}Error" @@ -417,11 +384,9 @@ class Meta: return ValidationErrorSerializer -def get_error_serializer( - operation_id: str, attr: str, error_codes: Set[str] -) -> Type[serializers.Serializer]: +def get_error_serializer(operation_id: str, attr: str, error_codes: set[str]) -> type[serializers.Serializer]: attr_choices = [(attr, attr)] - error_code_choices = sorted(zip(error_codes, error_codes)) + error_code_choices = sorted(zip(error_codes, error_codes, strict=False)) camelcase_operation_id = camelize(operation_id) attr_with_underscores = attr.replace(package_settings.NESTED_FIELD_SEPARATOR, "_") @@ -443,8 +408,8 @@ class Meta: @dataclass class InputDataField: name: str - field: Union[serializers.Field, forms.Field] - error_codes: Set[str] = dataclass_field(default_factory=set) + field: serializers.Field | forms.Field + error_codes: set[str] = dataclass_field(default_factory=set) def get_django_filter_backends(backends): @@ -456,9 +421,7 @@ def get_django_filter_backends(backends): filter_backends = [filter_backend() for filter_backend in backends] return [ - backend - for backend in filter_backends - if isinstance(backend, DjangoFilterBackend) and backend.raise_exception + backend for backend in filter_backends if isinstance(backend, DjangoFilterBackend) and backend.raise_exception ] diff --git a/finmars_standardized_errors/serializers.py b/finmars_standardized_errors/serializers.py index aed3312d..a34732a3 100644 --- a/finmars_standardized_errors/serializers.py +++ b/finmars_standardized_errors/serializers.py @@ -1,8 +1,9 @@ -from finmars_standardized_errors.models import ErrorRecord from rest_framework import serializers +from finmars_standardized_errors.models import ErrorRecord + class ErrorRecordSerializer(serializers.ModelSerializer): class Meta: model = ErrorRecord - fields = ['id', 'url', 'username', 'message', 'status_code', 'notes', 'created', 'details'] \ No newline at end of file + fields = ["id", "url", "username", "message", "status_code", "notes", "created", "details"] diff --git a/finmars_standardized_errors/types.py b/finmars_standardized_errors/types.py index b25e2e65..69c33bbf 100644 --- a/finmars_standardized_errors/types.py +++ b/finmars_standardized_errors/types.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Optional, TypedDict +from typing import TypedDict from rest_framework.request import Request from rest_framework.views import APIView @@ -10,7 +10,7 @@ class ExceptionHandlerContext(TypedDict): view: APIView args: tuple kwargs: dict - request: Optional[Request] + request: Request | None class ErrorType(str, Enum): @@ -23,13 +23,14 @@ class ErrorType(str, Enum): class Error: code: str detail: str - attr: Optional[str] + attr: str | None @dataclass class ErrorResponseDetails: type: ErrorType - errors: List[Error] + errors: list[Error] + @dataclass class ErrorResponse: diff --git a/finmars_standardized_errors/views.py b/finmars_standardized_errors/views.py index b34a84f2..59275e7c 100644 --- a/finmars_standardized_errors/views.py +++ b/finmars_standardized_errors/views.py @@ -10,24 +10,24 @@ class ErrorRecordViewSet(ModelViewSet): queryset = ErrorRecord.objects.all() serializer_class = ErrorRecordSerializer - permission_classes = [ - IsAuthenticated - ] + permission_classes = [IsAuthenticated] filter_backends = [] - ordering_fields = ['created'] + ordering_fields = ["created"] def list(self, request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - query = request.GET.get('query', None) + query = request.GET.get("query", None) if query: queryset = queryset.filter( - Q(username__icontains=query) | Q(message__icontains=query) | Q(details_data__icontains=query) | Q( - url__icontains=query) | Q( - status_code__icontains=query) | Q( - created__icontains=query)) + Q(username__icontains=query) + | Q(message__icontains=query) + | Q(details_data__icontains=query) + | Q(url__icontains=query) + | Q(status_code__icontains=query) + | Q(created__icontains=query) + ) page = self.paginate_queryset(queryset) diff --git a/healthcheck/__init__.py b/healthcheck/__init__.py index 73153f6a..33ca2da5 100644 --- a/healthcheck/__init__.py +++ b/healthcheck/__init__.py @@ -1,3 +1 @@ -from __future__ import unicode_literals - -default_app_config = 'healthcheck.apps.HealthcheckConfig' +default_app_config = "healthcheck.apps.HealthcheckConfig" diff --git a/healthcheck/apps.py b/healthcheck/apps.py index 67e865f0..d5afb8d1 100644 --- a/healthcheck/apps.py +++ b/healthcheck/apps.py @@ -2,4 +2,4 @@ class HealthcheckConfig(AppConfig): - name = 'healthcheck' + name = "healthcheck" diff --git a/healthcheck/conf.py b/healthcheck/conf.py index cac65365..1777924b 100644 --- a/healthcheck/conf.py +++ b/healthcheck/conf.py @@ -1,6 +1,6 @@ from django.conf import settings -HEALTHCHECK = getattr(settings, 'HEALTHCHECK', {}) -HEALTHCHECK.setdefault('DISK_USAGE_MAX', 90) -HEALTHCHECK.setdefault('MEMORY_MIN', 100) -HEALTHCHECK.setdefault('WARNINGS_AS_ERRORS', True) +HEALTHCHECK = getattr(settings, "HEALTHCHECK", {}) +HEALTHCHECK.setdefault("DISK_USAGE_MAX", 90) +HEALTHCHECK.setdefault("MEMORY_MIN", 100) +HEALTHCHECK.setdefault("WARNINGS_AS_ERRORS", True) diff --git a/healthcheck/exceptions.py b/healthcheck/exceptions.py index 93a9bf04..bdab3fbe 100644 --- a/healthcheck/exceptions.py +++ b/healthcheck/exceptions.py @@ -8,7 +8,7 @@ def __init__(self, message): self.message = message def __str__(self): - return "%s: %s" % (self.message_type, self.message) + return f"{self.message_type}: {self.message}" class ServiceWarning(HealthCheckException): diff --git a/healthcheck/handlers.py b/healthcheck/handlers.py index 0857d1d3..c3707511 100644 --- a/healthcheck/handlers.py +++ b/healthcheck/handlers.py @@ -1,29 +1,31 @@ +import datetime +import locale import logging +import socket import time from timeit import default_timer as timer -from healthcheck.conf import HEALTHCHECK -from healthcheck.exceptions import HealthCheckException, ServiceReturnedUnexpectedResult, ServiceWarning, \ - ServiceUnavailable -from django.db import DatabaseError, IntegrityError - -import locale -import socket - -import datetime import psutil +from django.db import DatabaseError, IntegrityError +from healthcheck.conf import HEALTHCHECK +from healthcheck.exceptions import ( + HealthCheckException, + ServiceReturnedUnexpectedResult, + ServiceUnavailable, + ServiceWarning, +) from healthcheck.models import HealthcheckTestModel -_l = logging.getLogger('healthcheck') +_l = logging.getLogger("healthcheck") host = socket.gethostname() -DISK_USAGE_MAX = HEALTHCHECK['DISK_USAGE_MAX'] -MEMORY_MIN = HEALTHCHECK['MEMORY_MIN'] +DISK_USAGE_MAX = HEALTHCHECK["DISK_USAGE_MAX"] +MEMORY_MIN = HEALTHCHECK["MEMORY_MIN"] -class BaseHealthCheck: +class BaseHealthCheck: def __init__(self): self.errors = [] @@ -60,15 +62,12 @@ def add_error(self, error, cause=None): def pretty_status(self): if self.errors: - - return { - 'errors': [str(e) for e in self.errors] - } + return {"errors": [str(e) for e in self.errors]} return self.get_info() def get_info(self): - return 'working' + return "working" @property def status(self): @@ -81,36 +80,32 @@ def identifier(self): class DiskUsagePlugin(BaseHealthCheck): def check_status(self): try: - du = psutil.disk_usage('/') + du = psutil.disk_usage("/") if DISK_USAGE_MAX and du.percent >= DISK_USAGE_MAX: - raise ServiceWarning( - "{host} {percent}% disk usage exceeds {disk_usage}%".format( - host=host, percent=du.percent, disk_usage=DISK_USAGE_MAX) - ) + raise ServiceWarning(f"{host} {du.percent}% disk usage exceeds {DISK_USAGE_MAX}%") except ValueError as e: self.add_error(ServiceReturnedUnexpectedResult("ValueError"), e) def get_info(self): - data = [] item = {} - du = psutil.disk_usage('/') + du = psutil.disk_usage("/") - item['componentType'] = 'system' - item['observedValue'] = du.percent - item['observedUnit'] = 'percent' - item['time'] = datetime.datetime.now().isoformat() - item['status'] = 'pass' - item['output'] = '' + item["componentType"] = "system" + item["observedValue"] = du.percent + item["observedUnit"] = "percent" + item["time"] = datetime.datetime.now().isoformat() + item["status"] = "pass" + item["output"] = "" data.append(item) return data def identifier(self): - return 'disk:utilization' + return "disk:utilization" class MemoryUsagePlugin(BaseHealthCheck): @@ -118,18 +113,14 @@ def check_status(self): try: memory = psutil.virtual_memory() if MEMORY_MIN and memory.available < (MEMORY_MIN * 1024 * 1024): - locale.setlocale(locale.LC_ALL, '') - avail = '{:n}'.format(int(memory.available / 1024 / 1024)) - threshold = '{:n}'.format(MEMORY_MIN) - raise ServiceWarning( - "{host} {avail} MB available RAM below {threshold} MB".format( - host=host, avail=avail, threshold=threshold) - ) + locale.setlocale(locale.LC_ALL, "") + avail = f"{int(memory.available / 1024 / 1024):n}" + threshold = f"{MEMORY_MIN:n}" + raise ServiceWarning(f"{host} {avail} MB available RAM below {threshold} MB") except ValueError as e: self.add_error(ServiceReturnedUnexpectedResult("ValueError"), e) def get_info(self): - data = [] item = {} @@ -137,23 +128,22 @@ def get_info(self): memory = psutil.virtual_memory() available_memory_mb = int(memory.used / 1024 / 1024) - item['componentType'] = 'system' - item['observedValue'] = available_memory_mb - item['observedUnit'] = 'MiB' - item['time'] = datetime.datetime.now().isoformat() - item['status'] = 'pass' - item['output'] = '' + item["componentType"] = "system" + item["observedValue"] = available_memory_mb + item["observedUnit"] = "MiB" + item["time"] = datetime.datetime.now().isoformat() + item["status"] = "pass" + item["output"] = "" data.append(item) return data def identifier(self): - return 'memory:utilization' + return "memory:utilization" class DatabasePlugin(BaseHealthCheck): - response_time = None def check_status(self): @@ -166,57 +156,54 @@ def check_status(self): obj.name = "Second" obj.save() obj.delete() - except IntegrityError: - raise ServiceReturnedUnexpectedResult("Integrity Error") - except DatabaseError: - raise ServiceUnavailable("Database error") + except IntegrityError as e: + raise ServiceReturnedUnexpectedResult("Integrity Error") from e + except DatabaseError as e: + raise ServiceUnavailable("Database error") from e def get_info(self): - data = [] item = {} response_time_ms = int(round(self.response_time * 1000)) - item['componentType'] = 'datastore' - item['observedValue'] = response_time_ms - item['observedUnit'] = 'ms' - item['time'] = datetime.datetime.now().isoformat() - item['status'] = 'pass' - item['output'] = '' + item["componentType"] = "datastore" + item["observedValue"] = response_time_ms + item["observedUnit"] = "ms" + item["time"] = datetime.datetime.now().isoformat() + item["status"] = "pass" + item["output"] = "" data.append(item) return data def identifier(self): - return 'database:responseTime' + return "database:responseTime" class UptimePlugin(BaseHealthCheck): - def check_status(self): pass def get_info(self): - data = [] item = {} uptime = datetime.datetime.now() - datetime.datetime.fromtimestamp(psutil.boot_time()) - item['componentType'] = 'system' - item['observedValue'] = uptime.total_seconds() - item['observedUnit'] = 's' - item['time'] = datetime.datetime.now().isoformat() - item['status'] = 'pass' - item['output'] = '' + item["componentType"] = "system" + item["observedValue"] = uptime.total_seconds() + item["observedUnit"] = "s" + item["time"] = datetime.datetime.now().isoformat() + item["status"] = "pass" + item["output"] = "" data.append(item) return data def identifier(self): - return 'uptime' + return "uptime" diff --git a/healthcheck/views.py b/healthcheck/views.py index 7553b710..26e546ef 100644 --- a/healthcheck/views.py +++ b/healthcheck/views.py @@ -1,15 +1,14 @@ +from concurrent.futures import ThreadPoolExecutor + from django.http import JsonResponse +from django.utils.decorators import method_decorator from django.views.decorators.cache import never_cache from rest_framework.views import APIView -from concurrent.futures import ThreadPoolExecutor -from django.utils.decorators import method_decorator - -from healthcheck.handlers import DatabasePlugin, MemoryUsagePlugin, DiskUsagePlugin, UptimePlugin +from healthcheck.handlers import DatabasePlugin, DiskUsagePlugin, MemoryUsagePlugin, UptimePlugin class HealthcheckView(APIView): - _errors = None _plugins = None @@ -22,12 +21,7 @@ def errors(self): @property def plugins(self): if not self._plugins: - self._plugins = [ - DiskUsagePlugin(), - MemoryUsagePlugin(), - DatabasePlugin(), - UptimePlugin() - ] + self._plugins = [DiskUsagePlugin(), MemoryUsagePlugin(), DatabasePlugin(), UptimePlugin()] return self._plugins def run_check(self): @@ -39,29 +33,29 @@ def _run(plugin): return plugin finally: from django.db import connections + connections.close_all() with ThreadPoolExecutor(max_workers=len(self.plugins) or 1) as executor: for plugin in executor.map(_run, self.plugins): - errors.extend(plugin.errors) + errors.extend(plugin.errors) return errors - @method_decorator(never_cache, name='dispatch') + @method_decorator(never_cache, name="dispatch") def get(self, request, *args, **kwargs): - data = {} - data['version'] = 1 - data['checks'] = {} - data['status'] = 'pass' - data['notes'] = '' - data['description'] = '' - data['output'] = '' + data["version"] = 1 + data["checks"] = {} + data["status"] = "pass" + data["notes"] = "" + data["description"] = "" + data["output"] = "" status_code = 200 if self.errors: status_code = 500 - data['status'] = 'fail' + data["status"] = "fail" # for item in self.plugins: # @@ -69,7 +63,4 @@ def get(self, request, *args, **kwargs): # # data['checks'][key] = item.pretty_status() - return JsonResponse( - data, - status=status_code - ) + return JsonResponse(data, status=status_code) diff --git a/logstash/__init__.py b/logstash/__init__.py index 69146553..d8cdce7d 100644 --- a/logstash/__init__.py +++ b/logstash/__init__.py @@ -1,4 +1,4 @@ from logstash.formatter import LogstashFormatterVersion - from logstash.handler_tcp import TCPLogstashHandler +__all__ = ["LogstashFormatterVersion", "TCPLogstashHandler"] diff --git a/logstash/formatter.py b/logstash/formatter.py index 2ee94b57..94215c96 100644 --- a/logstash/formatter.py +++ b/logstash/formatter.py @@ -13,8 +13,7 @@ class LogstashFormatterBase(logging.Formatter): - - def __init__(self, message_type='Logstash', tags=None, fqdn=False): + def __init__(self, message_type="Logstash", tags=None, fqdn=False): self.message_type = message_type self.tags = tags if tags is not None else [] @@ -27,10 +26,31 @@ def get_extra_fields(self, record): # The list contains all the attributes listed in # http://docs.python.org/library/logging.html#logrecord-attributes skip_list = ( - 'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename', - 'funcName', 'id', 'levelname', 'levelno', 'lineno', 'module', - 'msecs', 'msecs', 'message', 'msg', 'name', 'pathname', 'process', - 'processName', 'relativeCreated', 'thread', 'threadName', 'extra') + "args", + "asctime", + "created", + "exc_info", + "exc_text", + "filename", + "funcName", + "id", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "msecs", + "message", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "thread", + "threadName", + "extra", + ) easy_types = (str, bool, dict, float, int, list, type(None)) @@ -47,67 +67,63 @@ def get_extra_fields(self, record): def get_debug_fields(self, record): fields = { - 'stack_trace': self.format_exception(record.exc_info), - 'lineno': record.lineno, - 'process': record.process, - 'thread_name': record.threadName, + "stack_trace": self.format_exception(record.exc_info), + "lineno": record.lineno, + "process": record.process, + "thread_name": record.threadName, } # funcName was added in 2.5 - if not getattr(record, 'funcName', None): - fields['funcName'] = record.funcName + if not getattr(record, "funcName", None): + fields["funcName"] = record.funcName # processName was added in 2.6 - if not getattr(record, 'processName', None): - fields['processName'] = record.processName + if not getattr(record, "processName", None): + fields["processName"] = record.processName return fields @classmethod def format_source(cls, message_type, host, path): - return "%s://%s/%s" % (message_type, host, path) + return f"{message_type}://{host}/{path}" @classmethod def format_timestamp(cls, time): tstamp = datetime.utcfromtimestamp(time) - return tstamp.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (tstamp.microsecond / 1000) + "Z" + return tstamp.strftime("%Y-%m-%dT%H:%M:%S") + f".{tstamp.microsecond / 1000}Z" @classmethod def format_exception(cls, exc_info): - return ''.join(traceback.format_exception(*exc_info)) if exc_info else '' + return "".join(traceback.format_exception(*exc_info)) if exc_info else "" @classmethod def serialize(cls, message): - if sys.version_info < (3, 0): + if sys.version_info < (3, 0): # noqa: UP036 return json.dumps(message) else: - return bytes(json.dumps(message), 'utf-8') + return bytes(json.dumps(message), "utf-8") class LogstashFormatterVersion(LogstashFormatterBase): - def mask_secret_data(self, message): - - return message def format(self, record): # Create message dict message = { - '@timestamp': self.format_timestamp(record.created), - '@version': '1', - 'message': self.mask_secret_data(record.getMessage()), - 'host': self.host, - 'path': record.pathname, - 'tags': self.tags, - 'type': self.message_type, - 'host_location': settings.HOST_LOCATION, - + "@timestamp": self.format_timestamp(record.created), + "@version": "1", + "message": self.mask_secret_data(record.getMessage()), + "host": self.host, + "path": record.pathname, + "tags": self.tags, + "type": self.message_type, + "host_location": settings.HOST_LOCATION, # Extra Fields - 'level': record.levelname, - 'logger_name': record.name, - 'module': record.module, - 'lineno': record.lineno, + "level": record.levelname, + "logger_name": record.name, + "module": record.module, + "lineno": record.lineno, } # Add extra fields diff --git a/logstash/handler_tcp.py b/logstash/handler_tcp.py index 15166655..abddb94b 100644 --- a/logstash/handler_tcp.py +++ b/logstash/handler_tcp.py @@ -1,12 +1,12 @@ import ssl -import warnings from logging.handlers import SocketHandler + from logstash import formatter # Derive from object to force a new-style class and thus allow super() to work # on Python 2.6 -class TCPLogstashHandler(SocketHandler, object): +class TCPLogstashHandler(SocketHandler): """Python logging handler for Logstash. Sends events over TCP. :param host: The host of the logstash server. :param port: The port of the logstash server (default 5959). @@ -18,11 +18,25 @@ class TCPLogstashHandler(SocketHandler, object): :param ssl_verify: Should the server's SSL certificate be verified? :param keyfile: The path to client side SSL key file (default is None). :param certfile: The path to client side SSL certificate file (default is None). - :param ca_certs: The path to the file containing recognised CA certificates. System wide CA certs are used if omitted. + :param ca_certs: The path to the file containing recognised CA certificates. System wide CA certs + are used if omitted. """ - def __init__(self, host, port=5959, message_type='logstash', tags=None, fqdn=False, version=0, ssl=True, ssl_verify=True, keyfile=None, certfile=None, ca_certs=None): - super(TCPLogstashHandler, self).__init__(host, port) + def __init__( + self, + host, + port=5959, + message_type="logstash", + tags=None, + fqdn=False, + version=0, + ssl=True, + ssl_verify=True, + keyfile=None, + certfile=None, + ca_certs=None, + ): + super().__init__(host, port) self.ssl = ssl self.ssl_verify = ssl_verify @@ -33,11 +47,10 @@ def __init__(self, host, port=5959, message_type='logstash', tags=None, fqdn=Fal self.formatter = formatter.LogstashFormatterVersion(message_type, tags, fqdn) def makePickle(self, record): - return self.formatter.format(record) + b'\n' - + return self.formatter.format(record) + b"\n" def makeSocket(self, timeout=1): - s = super(TCPLogstashHandler, self).makeSocket(timeout) + s = super().makeSocket(timeout) if not self.ssl: return s @@ -58,4 +71,4 @@ def makeSocket(self, timeout=1): # if self.certfile and self.keyfile: # context.load_cert_chain(self.certfile, keyfile=self.keyfile) - return context.wrap_socket(s, server_hostname=self.host) \ No newline at end of file + return context.wrap_socket(s, server_hostname=self.host) diff --git a/manage.py b/manage.py index 72601803..2783f9c4 100755 --- a/manage.py +++ b/manage.py @@ -1,11 +1,12 @@ #!/usr/bin/env python """Django's command-line utility for administrative tasks.""" + import os import sys def main(): - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'workflow_app.settings') + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "workflow_app.settings") try: from django.core.management import execute_from_command_line except ImportError as exc: @@ -17,5 +18,5 @@ def main(): execute_from_command_line(sys.argv) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 20eb9950..eec49e03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "finmars-workflow" -version = "1.17.0" +version = "1.22.0" [tool.ruff] line-length = 119 @@ -39,7 +39,7 @@ version = "1.17.0" [tool.coverage.run] branch = true concurrency = ["multiprocessing"] -source = ["poms"] +source = ["workflow"] [tool.coverage.report] ignore_errors = true diff --git a/requirements-test.txt b/requirements-test.txt deleted file mode 100644 index a591374e..00000000 --- a/requirements-test.txt +++ /dev/null @@ -1,4 +0,0 @@ --r requirements.txt -coverage==7.7.0 -factory-boy==3.3.3 -tblib==3.0.0 diff --git a/requirements.txt b/requirements.txt index 0c54fcdb..65116fb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -145,3 +145,6 @@ watchdog==6.0.0 wcwidth==0.2.5 whitenoise==6.3.0 zipp==3.19.1 +coverage==7.7.0 +factory-boy==3.3.3 +tblib==3.0.0 diff --git a/workflow/admin.py b/workflow/admin.py index ed28e048..7c360e66 100644 --- a/workflow/admin.py +++ b/workflow/admin.py @@ -1,8 +1,7 @@ from django.contrib import admin from django.contrib.auth.admin import UserAdmin -from workflow.models import User, Task, Workflow, Space, Schedule -from workflow_app import settings +from workflow.models import Schedule, Space, Task, User, Workflow admin.site.site_header = "Workflow Admin" admin.site.site_title = "Workflow Admin" diff --git a/workflow/api.py b/workflow/api.py index 4954b71d..7bdcb6d4 100644 --- a/workflow/api.py +++ b/workflow/api.py @@ -1,5 +1,6 @@ +from functools import partial, wraps from threading import local -from functools import wraps, partial + from workflow.tasks.base import BaseTask _registered_task = local() @@ -28,9 +29,7 @@ def decorator(func): _registered_task.func.__name__ = func.__name__ task_kwargs["name"] = prefixed_name - task_kwargs["base"] = ( - BaseTask # Extremely important, never forget to replace BaseTask - ) + task_kwargs["base"] = BaseTask # Extremely important, never forget to replace BaseTask # Register the function as a Celery task with the updated name # task = celery_app.task(*task_args, **task_kwargs)(execute_workflow_step) diff --git a/workflow/apps.py b/workflow/apps.py index 633c7db4..b27e9ef7 100644 --- a/workflow/apps.py +++ b/workflow/apps.py @@ -1,8 +1,7 @@ import logging from django.apps import AppConfig -from django.db import DEFAULT_DB_ALIAS -from django.db import connection +from django.db import DEFAULT_DB_ALIAS, connection from django.db.models.signals import post_migrate from workflow_app import settings @@ -23,15 +22,9 @@ class WorkflowConfig(AppConfig): def ready(self): import sys - sys.stdout.close = lambda: (_ for _ in ()).throw( - Exception("stdout close attempt detected") - ) + sys.stdout.close = lambda: (_ for _ in ()).throw(Exception("stdout close attempt detected")) - if not ( - "makemigrations" in sys.argv - or "migrate" in sys.argv - or "migrate_all_schemes" in sys.argv - ): + if not ("makemigrations" in sys.argv or "migrate" in sys.argv or "migrate_all_schemes" in sys.argv): from workflow.system import get_system_workflow_manager system_workflow_manager = get_system_workflow_manager() @@ -47,13 +40,13 @@ def bootstrap(self, app_config, verbosity=2, using=DEFAULT_DB_ALIAS, **kwargs): current_space_code = get_current_search_path() - _l.info("bootstrap: Current search path: %s" % current_space_code) + _l.info("bootstrap: Current search path: %s", current_space_code) self.create_space_if_not_exist() self.create_finmars_bot() except Exception as e: - _l.info("bootstrap: failed: %e" % e) + _l.info("bootstrap: failed: %e", e) def create_space_if_not_exist(self): from workflow.models import Space @@ -71,12 +64,10 @@ def create_space_if_not_exist(self): space.name = space_code space.realm_code = settings.REALM_CODE space.save() - _l.info("bootstrap.space_exists: %s " % space_code) + _l.info("bootstrap.space_exists: %s ", space_code) except Space.DoesNotExist: - space = Space.objects.create( - space_code=space_code, name=space_code, realm_code=settings.REALM_CODE - ) - _l.info("bootstrap.creating_new_space: %s " % space_code) + space = Space.objects.create(space_code=space_code, name=space_code, realm_code=settings.REALM_CODE) + _l.info("bootstrap.creating_new_space: %s ", space_code) def create_finmars_bot(self): from workflow.models import User @@ -85,6 +76,6 @@ def create_finmars_bot(self): user = User.objects.get(username="finmars_bot") except Exception as e: - user = User.objects.create(username="finmars_bot", is_bot=True) + user = User.objects.create(username="finmars_bot", is_bot=True) # noqa: F841 - _l.info("Finmars bot created %s" % e) + _l.info("Finmars bot created %s", e) diff --git a/workflow/authentication.py b/workflow/authentication.py index de06bc28..d11bfb76 100644 --- a/workflow/authentication.py +++ b/workflow/authentication.py @@ -1,6 +1,6 @@ import logging -import jwt +import jwt from django.conf import settings from django.contrib.auth import get_user_model from django.utils.translation import gettext_lazy as _ @@ -19,11 +19,9 @@ def get_access_token(request): try: token = auth[1].decode() - except UnicodeError: - msg = _( - "Invalid token header. Token string should not contain invalid characters." - ) - raise exceptions.AuthenticationFailed(msg) + except UnicodeError as e: + msg = _("Invalid token header. Token string should not contain invalid characters.") + raise exceptions.AuthenticationFailed(msg) from e return token @@ -44,8 +42,8 @@ def get_auth_token_from_request(self, request): if not auth: for key, value in request.COOKIES.items(): - if "access_token" == key: - auth = ["Token".encode(), value.encode()] + if key == "access_token": + auth = [b"Token", value.encode()] if not auth or auth[0].lower() != self.keyword.lower().encode(): return None @@ -59,11 +57,9 @@ def get_auth_token_from_request(self, request): try: token = auth[1].decode() - except UnicodeError: - msg = _( - "Invalid token header. Token string should not contain invalid characters." - ) - raise exceptions.AuthenticationFailed(msg) + except UnicodeError as e: + msg = _("Invalid token header. Token string should not contain invalid characters.") + raise exceptions.AuthenticationFailed(msg) from e return token @@ -71,7 +67,7 @@ def authenticate(self, request): # print('KeycloakAuthentication.authenticate') # print('KeycloakAuthentication.request.method %s' % request.method) - user_model = get_user_model() + user_model = get_user_model() # noqa: F841 if request.method == "OPTIONS": finmars_bot = User.objects.get(username="finmars_bot") @@ -80,9 +76,7 @@ def authenticate(self, request): token = self.get_auth_token_from_request(request) if token is None: - return ( - None # No token or not a Bearer token, continue to next authentication - ) + return None # No token or not a Bearer token, continue to next authentication return self.authenticate_credentials(token, request) @@ -109,15 +103,15 @@ def authenticate_credentials(self, key, request=None): userinfo = self.keycloak.userinfo(key) except Exception as e: msg = _("Invalid or expired token.") - raise exceptions.AuthenticationFailed(msg) + raise exceptions.AuthenticationFailed(msg) from e - user_model = get_user_model() + user_model = get_user_model() # noqa: F841 # user = user_model.objects.get(username=userinfo['preferred_username']) try: user = User.objects.get(username=userinfo["preferred_username"]) - except Exception as e: + except Exception: # _l.error("User not found %s" % e) # raise exceptions.AuthenticationFailed(e) @@ -131,7 +125,7 @@ def authenticate_credentials(self, key, request=None): password=generate_random_string(12), ) - except Exception as e: + except Exception: try: # TODO # Do not remove this thing @@ -142,7 +136,7 @@ def authenticate_credentials(self, key, request=None): except Exception as e: # _l.error("Error create new user %s" % e) - raise exceptions.AuthenticationFailed(e) + raise exceptions.AuthenticationFailed(e) from e return user, key @@ -174,11 +168,9 @@ def get_auth_token_from_request(self, request): try: token = auth[1].decode() - except UnicodeError: - msg = _( - "Invalid token header. Token string should not contain invalid characters." - ) - raise exceptions.AuthenticationFailed(msg) + except UnicodeError as e: + msg = _("Invalid token header. Token string should not contain invalid characters.") + raise exceptions.AuthenticationFailed(msg) from e return token @@ -190,14 +182,12 @@ def authenticate(self, request): token = self.get_auth_token_from_request(request) if token is None: - return ( - None # No token or not a Bearer token, continue to next authentication - ) + return None # No token or not a Bearer token, continue to next authentication return self.authenticate_credentials(token, request) def authenticate_credentials(self, key, request=None): - user_model = get_user_model() + user_model = get_user_model() # noqa: F841 # user = user_model.objects.get(username=userinfo['preferred_username']) @@ -205,19 +195,19 @@ def authenticate_credentials(self, key, request=None): # Decode the JWT token payload = jwt.decode(key, settings.SECRET_KEY, algorithms=["HS256"]) - except jwt.ExpiredSignatureError: - raise exceptions.AuthenticationFailed("Token has expired") - except jwt.InvalidTokenError: - raise exceptions.AuthenticationFailed("Invalid token") + except jwt.ExpiredSignatureError as e: + raise exceptions.AuthenticationFailed("Token has expired") from e + except jwt.InvalidTokenError as e: + raise exceptions.AuthenticationFailed("Invalid token") from e except Exception as e: - raise exceptions.AuthenticationFailed(str(e)) + raise exceptions.AuthenticationFailed(str(e)) from e try: user = User.objects.get(username=payload["username"]) except Exception as e: # _l.error("User not found %s" % e) - raise exceptions.AuthenticationFailed(e) + raise exceptions.AuthenticationFailed(e) from e return user, key diff --git a/workflow/builder.py b/workflow/builder.py index 84294535..1237cc71 100644 --- a/workflow/builder.py +++ b/workflow/builder.py @@ -4,29 +4,26 @@ from celery.utils import uuid from workflow.exceptions import WorkflowSyntaxError -from workflow.models import Workflow, Task +from workflow.models import Task, Workflow +from workflow.system import get_system_workflow_manager from workflow.tasks.workflows import ( - start, end, - failure_hooks_launcher, execute_workflow_step, + failure_hooks_launcher, + start, ) from workflow_app import celery_app _l = logging.getLogger("workflow") -from workflow.system import get_system_workflow_manager - system_workflow_manager = get_system_workflow_manager() -class WorkflowBuilder(object): +class WorkflowBuilder: def __init__(self, workflow_id, workflow_data): self.workflow_id = workflow_id self._workflow = None - self.workflow_data = ( - workflow_data # Pass in the workflow data (which might have a version) - ) + self.workflow_data = workflow_data # Pass in the workflow data (which might have a version) @property def workflow(self): @@ -93,9 +90,7 @@ def parse_flat_tasks(self, tasks, is_hook=False): if "type" not in task[name] and task[name]["type"] != "group": raise WorkflowSyntaxError() - sub_canvas_tasks = [ - self.new_task(t, is_hook, single=False) for t in task[name]["tasks"] - ] + sub_canvas_tasks = [self.new_task(t, is_hook, single=False) for t in task[name]["tasks"]] sub_canvas = group(*sub_canvas_tasks, task_id=uuid()) canvas.append(sub_canvas) @@ -116,19 +111,13 @@ def build(self): self.tasks = system_workflow_manager.get_tasks(str(self.workflow)) # Initialize hooks - self.failure_hook = system_workflow_manager.get_failure_hook_task( - str(self.workflow) - ) + self.failure_hook = system_workflow_manager.get_failure_hook_task(str(self.workflow)) self.failure_hook_canvas = [] - self.success_hook = system_workflow_manager.get_success_hook_task( - str(self.workflow) - ) + self.success_hook = system_workflow_manager.get_success_hook_task(str(self.workflow)) self.success_hook_canvas = [] - self.before_start_hook = system_workflow_manager.get_before_start_hook_task( - str(self.workflow) - ) + self.before_start_hook = system_workflow_manager.get_before_start_hook_task(str(self.workflow)) self.parse_queues() @@ -142,9 +131,7 @@ def build(self): if self.before_start_hook: initial_previous = self.previous self.previous = None - self.before_start_hook_canvas = self.parse_flat_tasks( - [self.before_start_hook], True - )[0] + self.before_start_hook_canvas = self.parse_flat_tasks([self.before_start_hook], True)[0] _l.info(f"Before start hook canvas: {self.before_start_hook_canvas}") @@ -195,9 +182,7 @@ def build_hooks(self): # Success Hook if self.success_hook and not self.success_hook_canvas: - self.success_hook_canvas = [ - self.parse_flat_tasks([self.success_hook], True)[0] - ] + self.success_hook_canvas = [self.parse_flat_tasks([self.success_hook], True)[0]] self.previous = initial_previous @@ -226,7 +211,7 @@ def cancel(self): status_to_cancel = [Task.STATUS_PROGRESS, Task.STATUS_INIT, Task.STATUS_NESTED_PROGRESS] for task in self.workflow.tasks: if task.status in status_to_cancel: - celery_app.control.revoke(task.celery_task_id, terminate=True, signal='SIGKILL') + celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGKILL") task.mark_task_as_finished() task.status = Task.STATUS_CANCELED task.save() diff --git a/workflow/fields.py b/workflow/fields.py index af356d0a..1cf1a1b5 100644 --- a/workflow/fields.py +++ b/workflow/fields.py @@ -1,6 +1,6 @@ from rest_framework import serializers -from workflow.models import Space, User +from workflow.models import Space class CurrentSpaceDefault: diff --git a/workflow/filters.py b/workflow/filters.py index c26cdb65..c49918ff 100644 --- a/workflow/filters.py +++ b/workflow/filters.py @@ -1,8 +1,8 @@ import logging from datetime import datetime, timedelta -from django.db.models import Q import django_filters +from django.db.models import Q from rest_framework.filters import BaseFilterBackend, SearchFilter _l = logging.getLogger("workflow") @@ -74,9 +74,7 @@ def filter_queryset(self, request, queryset, view): queryset = queryset.filter(created_at__gte=date) if date_to: - date = datetime.strptime(date_to, "%Y-%m-%d") + timedelta( - days=1, microseconds=-1 - ) + date = datetime.strptime(date_to, "%Y-%m-%d") + timedelta(days=1, microseconds=-1) queryset = queryset.filter(created_at__lte=date) if status: diff --git a/workflow/finmars.py b/workflow/finmars.py index f763f962..81cf901a 100644 --- a/workflow/finmars.py +++ b/workflow/finmars.py @@ -14,7 +14,7 @@ from flatten_json import flatten from workflow.authentication import FinmarsRefreshToken -from workflow.models import User, Space +from workflow.models import Space, User from workflow_app import settings _l = logging.getLogger("workflow") @@ -41,9 +41,7 @@ def get_access_token(ttl_minutes=60 * 8, *args, **kwargs): bot = User.objects.get(username="finmars_bot") # Define the expiration time +1 hour from now - expiration_time = datetime.datetime.utcnow() + datetime.timedelta( - minutes=ttl_minutes - ) + expiration_time = datetime.datetime.utcnow() + datetime.timedelta(minutes=ttl_minutes) space = Space.objects.all().first() @@ -69,9 +67,7 @@ def get_refresh_token(ttl_minutes=60 * 8, *args, **kwargs): bot = User.objects.get(username="finmars_bot") # Define the expiration time +1 hour from now - expiration_time = datetime.datetime.utcnow() + datetime.timedelta( - minutes=ttl_minutes - ) + expiration_time = datetime.datetime.utcnow() + datetime.timedelta(minutes=ttl_minutes) space = Space.objects.all().first() @@ -164,17 +160,9 @@ def execute_expression(expression): + "/api/v1/utils/expression/" ) else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/utils/expression/" - ) + url = "https://" + settings.DOMAIN_NAME + "/" + space.space_code + "/api/v1/utils/expression/" - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -216,9 +204,7 @@ def execute_expression_procedure(payload): + "/api/v1/procedures/expression-procedure/execute/" ) - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -252,17 +238,9 @@ def execute_data_procedure(payload): + "/api/v1/procedures/data-procedure/execute/" ) else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/procedures/data-procedure/execute/" - ) + url = "https://" + settings.DOMAIN_NAME + "/" + space.space_code + "/api/v1/procedures/data-procedure/execute/" - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -284,23 +262,9 @@ def get_data_procedure_instance(id): space = get_space() if space.realm_code: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.realm_code - + "/" - + space.space_code - + "/api/v1/procedures/data-procedure-instance/%s/" % id - ) + url = f"https://{settings.DOMAIN_NAME}/{space.realm_code}/{space.space_code}/api/v1/procedures/data-procedure-instance/{id}/" else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/procedures/data-procedure-instance/%s/" % id - ) + url = f"https://{settings.DOMAIN_NAME}/{space.space_code}/api/v1/procedures/data-procedure-instance/{id}/" response = requests.get(url=url, headers=headers, verify=settings.VERIFY_SSL) @@ -344,9 +308,7 @@ def execute_pricing_procedure(payload): + "/api/v1/procedures/pricing-procedure/execute/" ) - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -354,7 +316,10 @@ def execute_pricing_procedure(payload): return response.json() -def execute_task(task_name, payload={}): +def execute_task(task_name, payload=None): + if payload is None: + payload = {} + refresh = get_refresh_token() # _l.info('refresh %s' % refresh.access_token) @@ -380,17 +345,9 @@ def execute_task(task_name, payload={}): + "/api/v1/tasks/task/execute/" ) else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/tasks/task/execute/" - ) + url = "https://" + settings.DOMAIN_NAME + "/" + space.space_code + "/api/v1/tasks/task/execute/" - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -413,14 +370,12 @@ def update_task_status(platform_task_id, status, result=None, error=None): } url = f"{get_base_path()}/api/v1/tasks/task/{platform_task_id}/update-status/" - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) try: response.raise_for_status() return response.json() except Exception as e: - _l.error("update_task_status error: %s" % e) + _l.error("update_task_status error: %s", e) def get_task(id): @@ -437,23 +392,9 @@ def get_task(id): space = get_space() if space.realm_code: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.realm_code - + "/" - + space.space_code - + "/api/v1/tasks/task/%s/" % id - ) + url = f"https://{settings.DOMAIN_NAME}/{space.realm_code}/{space.space_code}/api/v1/tasks/task/{id}/" else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/tasks/task/%s/" % id - ) + url = f"https://{settings.DOMAIN_NAME}/{space.space_code}/api/v1/tasks/task/{id}/" response = requests.get(url=url, headers=headers, verify=settings.VERIFY_SSL) @@ -463,11 +404,9 @@ def get_task(id): return response.json() -def _wait_task_to_complete_recursive( - task_id=None, retries=5, retry_interval=60, counter=None -): +def _wait_task_to_complete_recursive(task_id=None, retries=5, retry_interval=60, counter=None): if counter == retries: - raise Exception("Task exceeded retries %s count" % retries) + raise Exception("Task exceeded retries %s count", retries) try: result = get_task(task_id) @@ -475,7 +414,7 @@ def _wait_task_to_complete_recursive( if result["status"] not in ["progress", "P", "I"]: return result except Exception as e: - _l.error("_wait_task_to_complete_recursive %s" % e) + _l.error("_wait_task_to_complete_recursive %s", e) counter = counter + 1 @@ -511,7 +450,7 @@ def poll_workflow_status(workflow_id, max_retries=100, wait_time=5): return status # Return the status when it's success or error else: - _l.error(f"Error fetching status") + _l.error("Error fetching status") time.sleep(wait_time) # Wait before the next attempt @@ -519,11 +458,9 @@ def poll_workflow_status(workflow_id, max_retries=100, wait_time=5): return None # Indicate that the status was not found -def _wait_procedure_to_complete_recursive( - procedure_instance_id=None, retries=5, retry_interval=60, counter=None -): +def _wait_procedure_to_complete_recursive(procedure_instance_id=None, retries=5, retry_interval=60, counter=None): if counter == retries: - raise Exception("Task exceeded retries %s count" % retries) + raise Exception("Task exceeded retries %s count", retries) result = get_data_procedure_instance(procedure_instance_id) @@ -542,9 +479,7 @@ def _wait_procedure_to_complete_recursive( ) -def wait_procedure_to_complete( - procedure_instance_id=None, retries=5, retry_interval=60 -): +def wait_procedure_to_complete(procedure_instance_id=None, retries=5, retry_interval=60): counter = 0 result = None @@ -583,17 +518,9 @@ def execute_transaction_import(payload): + "/api/v1/import/transaction-import/execute/" ) else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/import/transaction-import/execute/" - ) + url = "https://" + settings.DOMAIN_NAME + "/" + space.space_code + "/api/v1/import/transaction-import/execute/" - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -627,17 +554,9 @@ def execute_simple_import(payload): + "/api/v1/import/simple-import/execute/" ) else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/import/simple-import/execute/" - ) + url = "https://" + settings.DOMAIN_NAME + "/" + space.space_code + "/api/v1/import/simple-import/execute/" - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -657,15 +576,7 @@ def request_api(path, method="get", data=None): space = get_space() if space.realm_code: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.realm_code - + "/" - + space.space_code - + path - ) + url = "https://" + settings.DOMAIN_NAME + "/" + space.realm_code + "/" + space.space_code + path else: url = "https://" + settings.DOMAIN_NAME + "/" + space.space_code + path @@ -675,28 +586,18 @@ def request_api(path, method="get", data=None): response = requests.get(url=url, headers=headers, verify=settings.VERIFY_SSL) elif method.lower() == "post": - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) elif method.lower() == "put": - response = requests.put( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.put(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) elif method.lower() == "patch": - response = requests.patch( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.patch(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) elif method.lower() == "delete": response = requests.delete(url=url, headers=headers, verify=settings.VERIFY_SSL) - if ( - response.status_code != 200 - and response.status_code != 201 - and response.status_code != 204 - ): + if response.status_code not in [200, 201, 204]: raise Exception(response.text) if response.status_code != 204: @@ -833,9 +734,7 @@ def get_yesterday( yesterday = today - timedelta(days=1) return yesterday - def get_list_of_business_days_between_two_dates( - self, date_from, date_to, to_string=False - ): + def get_list_of_business_days_between_two_dates(self, date_from, date_to, to_string=False): result = [] format = "%Y-%m-%d" @@ -864,28 +763,17 @@ def import_from_storage(self, file_path): space = get_space() if file_path[0] == "/": - file_path = os.path.join( - settings.WORKFLOW_STORAGE_ROOT - + "/tasks/" - + space.space_code - + file_path - ) + file_path = os.path.join(settings.WORKFLOW_STORAGE_ROOT + "/tasks/" + space.space_code + file_path) else: - file_path = os.path.join( - settings.WORKFLOW_STORAGE_ROOT - + "/tasks/" - + space.space_code - + "/" - + file_path - ) + file_path = os.path.join(settings.WORKFLOW_STORAGE_ROOT + "/tasks/" + space.space_code + "/" + file_path) - _l.info("import_from_storage.file_path %s" % file_path) + _l.info("import_from_storage.file_path %s", file_path) directory, filename = os.path.split(file_path) module_name, _ = os.path.splitext(filename) - _l.info("import_from_storage.module_name %s" % module_name) - _l.info("import_from_storage.file_path %s" % file_path) + _l.info("import_from_storage.module_name %s", module_name) + _l.info("import_from_storage.file_path %s", file_path) loader = importlib.machinery.SourceFileLoader(module_name, file_path) module = loader.load_module() @@ -976,9 +864,7 @@ def to_ascii_or_unicode(char): try: # Try to encode the character in ASCII ascii_char = char.encode("ascii") - return ( - ascii_char.decode() - ) # Return as string if it's a valid ASCII character + return ascii_char.decode() # Return as string if it's a valid ASCII character except UnicodeEncodeError: # If it's not an ASCII character, return its Unicode code point return f"U{ord(char)}" @@ -1017,13 +903,11 @@ def get_secret(self, path, provider="finmars"): + space.realm_code + "/" + space.space_code - + f"/api/v1/vault/vault-record/?user_code=" + + "/api/v1/vault/vault-record/?user_code=" + path ) - response = requests.get( - url=url, headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.get(url=url, headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -1073,9 +957,7 @@ def get_secret(self, path, provider="finmars"): + f"/api/v1/vault/vault-secret/get/?engine_name={engine_name}&path={secret_path}" ) - response = requests.get( - url=url, headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.get(url=url, headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: raise Exception(response.text) @@ -1083,7 +965,7 @@ def get_secret(self, path, provider="finmars"): return response.json()["data"]["data"] else: - raise Exception("Unknown provider %s" % provider) + raise Exception("Unknown provider %s", provider) storage = Storage() diff --git a/workflow/keycloak.py b/workflow/keycloak.py index 453a1b2b..968130b4 100644 --- a/workflow/keycloak.py +++ b/workflow/keycloak.py @@ -54,24 +54,11 @@ def __init__(self, server_url, realm_name, client_id, client_secret_key=None): self.client_secret_key = client_secret_key # Keycloak useful Urls - self.well_known_endpoint = ( - self.server_url - + "/realms/" - + self.realm_name - + "/.well-known/openid-configuration" - ) + self.well_known_endpoint = self.server_url + "/realms/" + self.realm_name + "/.well-known/openid-configuration" self.token_introspection_endpoint = ( - self.server_url - + "/realms/" - + self.realm_name - + "/protocol/openid-connect/token/introspect" - ) - self.userinfo_endpoint = ( - self.server_url - + "/realms/" - + self.realm_name - + "/protocol/openid-connect/userinfo" + self.server_url + "/realms/" + self.realm_name + "/protocol/openid-connect/token/introspect" ) + self.userinfo_endpoint = self.server_url + "/realms/" + self.realm_name + "/protocol/openid-connect/userinfo" def well_known(self): """Lists endpoints and other configuration options @@ -80,14 +67,11 @@ def well_known(self): Returns: [type]: [list of keycloak endpoints] """ - response = requests.request( - "GET", self.well_known_endpoint, verify=settings.VERIFY_SSL - ) + response = requests.request("GET", self.well_known_endpoint, verify=settings.VERIFY_SSL) error = response.raise_for_status() if error: LOGGER.error( - "Error obtaining list of endpoints from endpoint: " - f"{self.well_known_endpoint}, response error {error}" + f"Error obtaining list of endpoints from endpoint: {self.well_known_endpoint}, response error {error}" ) return {} return response.json() @@ -156,7 +140,7 @@ def is_token_active(self, token): """ introspect_token = self.introspect(token) is_active = introspect_token.get("active", None) - return True if is_active else False + return bool(is_active) def roles_from_token(self, token): """ @@ -172,18 +156,10 @@ def roles_from_token(self, token): realm_access = token_decoded.get("realm_access", None) resource_access = token_decoded.get("resource_access", None) - client_access = ( - resource_access.get(self.client_id, None) - if resource_access is not None - else None - ) + client_access = resource_access.get(self.client_id, None) if resource_access is not None else None - client_roles = ( - client_access.get("roles", None) if client_access is not None else None - ) - realm_roles = ( - realm_access.get("roles", None) if realm_access is not None else None - ) + client_roles = client_access.get("roles", None) if client_access is not None else None + realm_roles = realm_access.get("roles", None) if realm_access is not None else None if client_roles is None: return realm_roles @@ -201,9 +177,7 @@ def userinfo(self, token): json: user info data """ headers = {"authorization": "Bearer " + token} - response = requests.request( - "GET", self.userinfo_endpoint, headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.request("GET", self.userinfo_endpoint, headers=headers, verify=settings.VERIFY_SSL) error = response.raise_for_status() if error: LOGGER.error( diff --git a/workflow/management/commands/cancel_existing_tasks.py b/workflow/management/commands/cancel_existing_tasks.py index 1a985e3d..caae4adb 100644 --- a/workflow/management/commands/cancel_existing_tasks.py +++ b/workflow/management/commands/cancel_existing_tasks.py @@ -6,10 +6,8 @@ __author__ = "szhitenev" -from workflow.utils import get_all_tenant_schemas - - from workflow.system import get_system_workflow_manager +from workflow.utils import get_all_tenant_schemas system_workflow_manager = get_system_workflow_manager() @@ -34,4 +32,4 @@ def handle(self, *args, **options): cursor.execute("SET search_path TO public;") except Exception as e: - print("cancel_existing_tasks error e %s " % e) + print(f"cancel_existing_tasks error e {e}") diff --git a/workflow/management/commands/copy_css_libs.py b/workflow/management/commands/copy_css_libs.py index 4ef81b3e..859f5458 100644 --- a/workflow/management/commands/copy_css_libs.py +++ b/workflow/management/commands/copy_css_libs.py @@ -1,5 +1,6 @@ -import shutil import os +import shutil + from django.core.management.base import BaseCommand @@ -20,6 +21,4 @@ def handle(self, *args, **options): dst = f"workflow/static/css/{lib.split('/')[-1]}" shutil.copy(src, dst) - self.stdout.write( - self.style.SUCCESS("Successfully copied CSS libraries to static directory") - ) + self.stdout.write(self.style.SUCCESS("Successfully copied CSS libraries to static directory")) diff --git a/workflow/management/commands/copy_js_libs.py b/workflow/management/commands/copy_js_libs.py index 4d90821a..4c38c5c2 100644 --- a/workflow/management/commands/copy_js_libs.py +++ b/workflow/management/commands/copy_js_libs.py @@ -1,5 +1,6 @@ -import shutil import os +import shutil + from django.core.management.base import BaseCommand @@ -28,8 +29,4 @@ def handle(self, *args, **options): dst = f"workflow/static/scripts/{lib.split('/')[-1]}" shutil.copy(src, dst) - self.stdout.write( - self.style.SUCCESS( - "Successfully copied JavaScript libraries to static directory" - ) - ) + self.stdout.write(self.style.SUCCESS("Successfully copied JavaScript libraries to static directory")) diff --git a/workflow/management/commands/generate_super_user.py b/workflow/management/commands/generate_super_user.py index a976a32b..0b93e854 100644 --- a/workflow/management/commands/generate_super_user.py +++ b/workflow/management/commands/generate_super_user.py @@ -19,19 +19,13 @@ def handle(self, *args, **options): try: superuser = User.objects.get(username=username) - self.stdout.write( - "Skip. Super user '%s' already exists." % superuser.username - ) + self.stdout.write(f"Skip. Super user '{superuser.username}' already exists.") except User.DoesNotExist: - superuser = User.objects.create_superuser( - username=username, email=email, password=password - ) + superuser = User.objects.create_superuser(username=username, email=email, password=password) superuser.save() - self.stdout.write("Super user '%s' created." % superuser.username) + self.stdout.write(f"Super user '{superuser.username}' created.") else: - self.stdout.write( - "Skip. Super user username and password are not provided." - ) + self.stdout.write("Skip. Super user username and password are not provided.") diff --git a/workflow/management/commands/sync_remote_storage_to_local_storage_all_spaces.py b/workflow/management/commands/sync_remote_storage_to_local_storage_all_spaces.py index a99fe020..b9ee088c 100644 --- a/workflow/management/commands/sync_remote_storage_to_local_storage_all_spaces.py +++ b/workflow/management/commands/sync_remote_storage_to_local_storage_all_spaces.py @@ -1,6 +1,4 @@ -from django.core.management import call_command from django.core.management.base import BaseCommand -from django.db import connection from workflow.system import get_system_workflow_manager diff --git a/workflow/management/commands/update_sequences.py b/workflow/management/commands/update_sequences.py index d283c373..d1a9604b 100644 --- a/workflow/management/commands/update_sequences.py +++ b/workflow/management/commands/update_sequences.py @@ -1,5 +1,5 @@ -from django.core.management.base import BaseCommand from django.apps import apps +from django.core.management.base import BaseCommand from django.db import connection from django.db.models import AutoField @@ -27,9 +27,7 @@ def handle(self, *args, **options): continue primary_key_column = primary_key_field.column - cursor.execute( - f"SELECT MAX({primary_key_column}) FROM {table_name}" - ) + cursor.execute(f"SELECT MAX({primary_key_column}) FROM {table_name}") max_id = cursor.fetchone()[0] or 0 cursor.execute(f"""SELECT distinct seq.relname AS sequence_name @@ -42,16 +40,14 @@ def handle(self, *args, **options): AND tab.relname = '{table_name}' and ns.nspname = '{schema}'""") sequence_name = cursor.fetchone()[0] cursor.execute( - f"""SELECT COUNT(*) FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace + """SELECT COUNT(*) FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind = 'S' AND c.relname = %s AND n.nspname = %s""", [sequence_name, schema], ) sequence_exists = cursor.fetchone()[0] if not sequence_exists: self.stdout.write( - self.style.WARNING( - f"Sequence {sequence_name} does not exist for {schema}.{table_name}" - ) + self.style.WARNING(f"Sequence {sequence_name} does not exist for {schema}.{table_name}") ) continue @@ -59,13 +55,7 @@ def handle(self, *args, **options): last_id = cursor.fetchone()[0] if max_id >= last_id: - cursor.execute( - f"SELECT setval('{sequence_name}', {max_id + 1}, false)" - ) - self.stdout.write( - f"Updated sequence {schema}.{sequence_name} to {max_id + 1}" - ) + cursor.execute(f"SELECT setval('{sequence_name}', {max_id + 1}, false)") + self.stdout.write(f"Updated sequence {schema}.{sequence_name} to {max_id + 1}") - self.stdout.write( - self.style.SUCCESS("Successfully updated sequences where necessary") - ) + self.stdout.write(self.style.SUCCESS("Successfully updated sequences where necessary")) diff --git a/workflow/middleware.py b/workflow/middleware.py index 8e507d65..7408b7fc 100644 --- a/workflow/middleware.py +++ b/workflow/middleware.py @@ -30,7 +30,7 @@ def __call__(self, request): # return HttpResponseBadRequest("Invalid space code.") with connection.cursor() as cursor: - cursor.execute(f"SET search_path TO public;") + cursor.execute("SET search_path TO public;") else: # Setting the PostgreSQL search path to the tenant's schema @@ -48,13 +48,9 @@ def __call__(self, request): response = self.get_response(request) if not response.streaming and "/admin/" in request.path_info: - response.content = response.content.replace( - b"spacexxxxx", request.space_code.encode() - ) + response.content = response.content.replace(b"spacexxxxx", request.space_code.encode()) if "location" in response: - response["location"] = response["location"].replace( - "spacexxxxx", request.space_code - ) + response["location"] = response["location"].replace("spacexxxxx", request.space_code) # Optionally, reset the search path to default after the request is processed # This can be important in preventing "leakage" of the schema setting across requests diff --git a/workflow/models.py b/workflow/models.py index b0d37cee..357b7167 100644 --- a/workflow/models.py +++ b/workflow/models.py @@ -1,31 +1,28 @@ -from __future__ import unicode_literals -from celery import schedules import json +import logging +from datetime import datetime import pytz +from celery import schedules +from croniter import croniter from django.conf import settings from django.contrib.auth.models import AbstractUser from django.core.serializers.json import DjangoJSONEncoder -from django.db import models, connection +from django.db import connection, models +from django.utils.timezone import now from django.utils.translation import gettext_lazy -from django_celery_beat.models import PeriodicTask, CrontabSchedule +from django.utils.translation import gettext_lazy as _ +from django_celery_beat.models import CrontabSchedule, PeriodicTask from workflow.storage import get_storage from workflow.utils import get_all_tenant_schemas, get_next_node_by_condition +from workflow_app import celery_app LANGUAGE_MAX_LENGTH = 5 TIMEZONE_MAX_LENGTH = 20 TIMEZONE_CHOICES = sorted(list((k, k) for k in pytz.all_timezones)) TIMEZONE_COMMON_CHOICES = sorted(list((k, k) for k in pytz.common_timezones)) -from django.utils.translation import gettext_lazy as _ -from workflow_app import celery_app -from django.utils.timezone import now -from croniter import croniter -from datetime import datetime -import pytz - -import logging _l = logging.getLogger("workflow") @@ -44,17 +41,11 @@ class User(AbstractUser): verbose_name=gettext_lazy("timezone"), ) - two_factor_verification = models.BooleanField( - default=False, verbose_name=gettext_lazy("two factor verification") - ) + two_factor_verification = models.BooleanField(default=False, verbose_name=gettext_lazy("two factor verification")) - json_data = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("json data") - ) + json_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("json data")) - is_verified = models.BooleanField( - default=False, verbose_name=gettext_lazy("is verified") - ) + is_verified = models.BooleanField(default=False, verbose_name=gettext_lazy("is verified")) password = models.CharField(_("password"), max_length=256) @@ -101,27 +92,17 @@ class Meta: class Space(TimeStampedModel): - name = models.CharField( - max_length=255, null=True, blank=True, verbose_name=gettext_lazy("name") - ) + name = models.CharField(max_length=255, null=True, blank=True, verbose_name=gettext_lazy("name")) - realm_code = models.CharField( - max_length=255, null=True, blank=True, verbose_name=gettext_lazy("realm_code") - ) + realm_code = models.CharField(max_length=255, null=True, blank=True, verbose_name=gettext_lazy("realm_code")) - space_code = models.CharField( - max_length=255, verbose_name=gettext_lazy("space_code") - ) + space_code = models.CharField(max_length=255, verbose_name=gettext_lazy("space_code")) class WorkflowTemplate(TimeStampedModel): - name = models.CharField( - max_length=255, null=True, blank=True, verbose_name=gettext_lazy("name") - ) + name = models.CharField(max_length=255, null=True, blank=True, verbose_name=gettext_lazy("name")) - user_code = models.CharField( - max_length=1024, null=True, blank=True, verbose_name=gettext_lazy("user_code") - ) + user_code = models.CharField(max_length=1024, null=True, blank=True, verbose_name=gettext_lazy("user_code")) notes = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("notes")) @@ -164,9 +145,7 @@ class Workflow(TimeStampedModel): (STATUS_CANCELED, "canceled"), ) - name = models.CharField( - max_length=255, null=True, blank=True, verbose_name=gettext_lazy("name") - ) + name = models.CharField(max_length=255, null=True, blank=True, verbose_name=gettext_lazy("name")) workflow_template = models.ForeignKey( WorkflowTemplate, @@ -176,12 +155,8 @@ class Workflow(TimeStampedModel): related_name="workflows", ) - current_node_id = models.CharField( - max_length=255, null=True, blank=True - ) # Store the current node ID - last_task_output = models.JSONField( - null=True, blank=True - ) # New field for storing last output + current_node_id = models.CharField(max_length=255, null=True, blank=True) # Store the current node ID + last_task_output = models.JSONField(null=True, blank=True) # New field for storing last output node_id = models.CharField( max_length=255, blank=True, @@ -189,9 +164,7 @@ class Workflow(TimeStampedModel): help_text="Node ID from the workflow JSON structure", ) - user_code = models.CharField( - max_length=1024, null=True, blank=True, verbose_name=gettext_lazy("user_code") - ) + user_code = models.CharField(max_length=1024, null=True, blank=True, verbose_name=gettext_lazy("user_code")) status = models.CharField( null=True, @@ -200,14 +173,10 @@ class Workflow(TimeStampedModel): choices=STATUS_CHOICES, verbose_name="status", ) - payload_data = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("payload data") - ) + payload_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("payload data")) periodic = models.BooleanField(default=False, verbose_name=gettext_lazy("periodic")) - is_manager = models.BooleanField( - default=False, verbose_name=gettext_lazy("is manager") - ) + is_manager = models.BooleanField(default=False, verbose_name=gettext_lazy("is manager")) platform_task_id = models.IntegerField( null=True, help_text="Platform Task ID in case if Platform initiated some pipeline", @@ -235,9 +204,7 @@ class Workflow(TimeStampedModel): related_name="workflows", ) - finished_at = models.DateTimeField( - null=True, db_index=True, verbose_name=gettext_lazy("finished at") - ) + finished_at = models.DateTimeField(null=True, db_index=True, verbose_name=gettext_lazy("finished at")) parent = models.ForeignKey( "self", @@ -313,24 +280,25 @@ def save(self, *args, **kwargs): last_task = self.tasks.last() if last_task: - update_task_status( - self.platform_task_id, self.status, result=last_task.result - ) + update_task_status(self.platform_task_id, self.status, result=last_task.result) except Exception as ex: - _l.warning("update_task_status %s" % ex) + _l.warning("update_task_status %s", ex) def cancel(self): status_to_cancel = [Task.STATUS_PROGRESS, Task.STATUS_INIT, Task.STATUS_NESTED_PROGRESS] for task in self.tasks.all(): if task.status in status_to_cancel: - celery_app.control.revoke(task.celery_task_id, terminate=True, signal='SIGKILL') + celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGKILL") task.status = Task.STATUS_CANCELED task.mark_task_as_finished() task.save() self.status = Workflow.STATUS_CANCELED self.save() - def run_new_workflow(self, user_code, payload={}): + def run_new_workflow(self, user_code, payload=None): + if payload is None: + payload = {} + if not user_code: raise Exception("User code is required.") @@ -344,16 +312,12 @@ def run_new_workflow(self, user_code, payload={}): user_code = f"{self.space.space_code}.{user_code}" - new_workflow = system_workflow_manager.get_by_user_code( - user_code, sync_remote=True - ) + new_workflow = system_workflow_manager.get_by_user_code(user_code, sync_remote=True) is_manager = new_workflow["workflow"].get("is_manager", False) if is_manager: - raise Exception( - "New Workflow is manager. Manager can't execute another manager" - ) + raise Exception("New Workflow is manager. Manager can't execute another manager") _l.info("run_new_workflow. Going to execute: %s", user_code) @@ -403,9 +367,7 @@ class Task(TimeStampedModel): celery_task_id = models.CharField(null=True, max_length=255) name = models.CharField(null=True, max_length=255) - source_code = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("source code") - ) + source_code = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("source code")) status = models.CharField( null=True, max_length=255, @@ -413,43 +375,27 @@ class Task(TimeStampedModel): choices=STATUS_CHOICES, verbose_name="status", ) - worker_name = models.CharField( - null=True, max_length=255, verbose_name="worker name" - ) + worker_name = models.CharField(null=True, max_length=255, verbose_name="worker name") type = models.CharField(max_length=50, blank=True, null=True) - payload_data = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("payload data") - ) - result_data = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("result data") - ) + payload_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("payload data")) + result_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("result data")) - progress_data = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("progress data") - ) + progress_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("progress data")) log = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("log")) notes = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("notes")) - error_message = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("error message") - ) + error_message = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("error message")) verbose_name = models.CharField(null=True, max_length=255) - verbose_result = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("verbose result") - ) + verbose_result = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("verbose result")) - previous_data = models.TextField( - null=True, blank=True, verbose_name=gettext_lazy("previous data") - ) + previous_data = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("previous data")) is_hook = models.BooleanField(default=False, verbose_name=gettext_lazy("is hook")) - finished_at = models.DateTimeField( - null=True, db_index=True, verbose_name=gettext_lazy("finished at") - ) + finished_at = models.DateTimeField(null=True, db_index=True, verbose_name=gettext_lazy("finished at")) space = models.ForeignKey( Space, @@ -462,7 +408,7 @@ class Meta: ordering = ["-created_at"] def __str__(self): - return "".format(self) + return f"" @property def payload(self): @@ -532,7 +478,7 @@ def update_progress(self, progress): # description # } - _l.info("update_progress %s" % progress) + _l.info("update_progress %s", progress) self.progress = progress @@ -540,9 +486,7 @@ def update_progress(self, progress): def handle_task_success(self, retval): if self.status == Task.STATUS_NESTED_PROGRESS: - _l.info( - f"Task {self.id} is in STATUS_NESTED_PROGRESS status; waiting for nested workflow to complete." - ) + _l.info(f"Task {self.id} is in STATUS_NESTED_PROGRESS status; waiting for nested workflow to complete.") # If the task status is STATUS_NESTED_PROGRESS, we should exit without marking it complete # It will be resumed by the nested workflow completion logic return @@ -578,12 +522,8 @@ def handle_task_success(self, retval): next_node_ids = [] if current_node["data"]["node"]["type"] == "condition": # Use the condition result to determine the next path - _l.info( - f"BaseTask.on_success.Processing conditional node {current_node_id}, result: {retval}" - ) - next_node_id = get_next_node_by_condition( - current_node_id, retval, connections - ) + _l.info(f"BaseTask.on_success.Processing conditional node {current_node_id}, result: {retval}") + next_node_id = get_next_node_by_condition(current_node_id, retval, connections) if next_node_id: next_node_ids.append(next_node_id) else: @@ -592,7 +532,8 @@ def handle_task_success(self, retval): if not next_node_ids: _l.info( - f"BaseTask.on_success.No next nodes found for current node ID: {current_node_id}. Marking workflow as complete." + f"BaseTask.on_success.No next nodes found for current node ID: {current_node_id}. " + "Marking workflow as complete." ) _l.info(f"BaseTask.on_success.workflow owner {self.workflow.owner}") @@ -600,13 +541,12 @@ def handle_task_success(self, retval): self.workflow.status = Workflow.STATUS_SUCCESS self.workflow.finished_at = now() self.workflow.save() - _l.info( - f"BaseTask.on_success.Workflow ID {self.workflow.id} status updated to SUCCESS." - ) + _l.info(f"BaseTask.on_success.Workflow ID {self.workflow.id} status updated to SUCCESS.") if self.workflow.parent: _l.info( - f"BaseTask.on_success.Workflow has a parent with ID {self.workflow.parent.id}. Triggering next task." + f"BaseTask.on_success.Workflow has a parent with ID {self.workflow.parent.id}. " + "Triggering next task." ) parent_workflow = self.workflow.parent @@ -631,9 +571,7 @@ def handle_task_success(self, retval): ) continue - _l.info( - f"BaseTask.on_success.Processing next node: {next_node_id}, Name: {next_node['name']}" - ) + _l.info(f"BaseTask.on_success.Processing next node: {next_node_id}, Name: {next_node['name']}") # Check if the workflow is in WAIT state @@ -663,9 +601,7 @@ def enabled(self): with connection.cursor() as cursor: cursor.execute(f"SET search_path TO {schema};") - schema_schedules = list( - self.filter(enabled=True).prefetch_related("crontab") - ) + schema_schedules = list(self.filter(enabled=True).prefetch_related("crontab")) result.extend(schema_schedules) return result @@ -677,13 +613,9 @@ class Schedule(PeriodicTask, TimeStampedModel): notes = models.TextField(null=True, blank=True, verbose_name=gettext_lazy("notes")) - payload = models.JSONField( - null=True, blank=True, verbose_name=gettext_lazy("payload") - ) + payload = models.JSONField(null=True, blank=True, verbose_name=gettext_lazy("payload")) - is_manager = models.BooleanField( - default=False, verbose_name=gettext_lazy("is manager") - ) + is_manager = models.BooleanField(default=False, verbose_name=gettext_lazy("is manager")) space = models.ForeignKey( Space, @@ -699,19 +631,14 @@ class Schedule(PeriodicTask, TimeStampedModel): related_name="schedules", ) - workflow_user_code = models.TextField( - verbose_name=gettext_lazy("workflow_user_code") - ) + workflow_user_code = models.TextField(verbose_name=gettext_lazy("workflow_user_code")) @property def crontab_line(self) -> str | None: if self.crontab: - return "{} {} {} {} {}".format( - self.crontab.minute, - self.crontab.hour, - self.crontab.day_of_month, - self.crontab.month_of_year, - self.crontab.day_of_week, + return ( + f"{self.crontab.minute} {self.crontab.hour} {self.crontab.day_of_month} " + f"{self.crontab.month_of_year} {self.crontab.day_of_week}" ) @crontab_line.setter @@ -750,7 +677,7 @@ def save(self, *args, **kwargs): "schedule_id": self.id, } ) - _l.info("Schedule save: %s" % self.kwargs) + _l.info("Schedule save: %s", self.kwargs) return super().save(*args, **kwargs) def __str__(self): diff --git a/workflow/pagination.py b/workflow/pagination.py index 0e7ca410..19eef0a9 100644 --- a/workflow/pagination.py +++ b/workflow/pagination.py @@ -1,9 +1,6 @@ import logging -import sys import time -from django.core.paginator import InvalidPage -from rest_framework.exceptions import NotFound from rest_framework.pagination import PageNumberPagination from rest_framework.settings import api_settings @@ -15,13 +12,13 @@ class PageNumberPaginationExt(PageNumberPagination): max_page_size = api_settings.PAGE_SIZE * 10 def post_paginate_queryset(self, queryset, request, view=None): - start_time = time.time() + start_time = time.time() # noqa: F841 qs = super().paginate_queryset(queryset, request, view) # _l.debug('post_paginate_queryset before list page') - list_page_st = time.perf_counter() + list_page_st = time.perf_counter() # noqa: F841 # _l.debug('res %s' % len(qs)) diff --git a/workflow/schedulers.py b/workflow/schedulers.py index 4bf2199a..ed1d8d94 100644 --- a/workflow/schedulers.py +++ b/workflow/schedulers.py @@ -1,6 +1,7 @@ from celery.utils.log import get_logger -from django_celery_beat.schedulers import ModelEntry, DatabaseScheduler as DCBScheduler from django.db.utils import DatabaseError, InterfaceError +from django_celery_beat.schedulers import DatabaseScheduler as DCBScheduler + from workflow.models import Schedule from workflow.utils import get_all_tenant_schemas, set_schema_from_context @@ -18,7 +19,7 @@ def all_as_schedule(self): for schema in schemas: set_schema_from_context({"space_code": schema}) for model in self.Model.objects.enabled(): - try: + try: # noqa: SIM105 s[model.name] = self.Entry(model, app=self.app) except ValueError: pass @@ -41,10 +42,7 @@ def schedule_changed(self): logger.exception("Database gave error: %r", exc) return False except InterfaceError: - warning( - "DatabaseScheduler: InterfaceError in schedule_changed(), " - "waiting to retry in next call..." - ) + warning("DatabaseScheduler: InterfaceError in schedule_changed(), waiting to retry in next call...") return False try: if ts and ts > (last if last else ts): diff --git a/workflow/serializers.py b/workflow/serializers.py index 82a0939a..f5dc41fa 100644 --- a/workflow/serializers.py +++ b/workflow/serializers.py @@ -2,9 +2,9 @@ from rest_framework import serializers -from workflow.fields import SpaceField, OwnerField +from workflow.fields import OwnerField, SpaceField from workflow.finmars import Storage -from workflow.models import Workflow, Task, Schedule, WorkflowTemplate +from workflow.models import Schedule, Task, Workflow, WorkflowTemplate class TaskSerializer(serializers.ModelSerializer): @@ -18,9 +18,7 @@ class TaskSerializer(serializers.ModelSerializer): def to_representation(self, instance): representation = super().to_representation(instance) if representation["worker_name"]: - representation["worker_name"] = representation["worker_name"].replace( - "celery@", "" - ) + representation["worker_name"] = representation["worker_name"].replace("celery@", "") return representation class Meta: @@ -96,9 +94,7 @@ def update(self, instance, validated_data): class SimpleWorkflowSerializer(serializers.ModelSerializer): space = SpaceField() - workflow_template_object = WorkflowTemplateSerializer( - read_only=True, source="workflow_template" - ) + workflow_template_object = WorkflowTemplateSerializer(read_only=True, source="workflow_template") payload = serializers.JSONField(allow_null=True, required=False) tasks = TaskSerializer(many=True, read_only=True) @@ -129,9 +125,7 @@ class Meta: class WorkflowSerializer(serializers.ModelSerializer): space = SpaceField() - workflow_template_object = WorkflowTemplateSerializer( - read_only=True, source="workflow_template" - ) + workflow_template_object = WorkflowTemplateSerializer(read_only=True, source="workflow_template") payload = serializers.JSONField(allow_null=True, required=False) tasks = TaskSerializer(many=True, read_only=True) parent = SimpleWorkflowSerializer(read_only=True) @@ -163,18 +157,12 @@ class Meta: def get_workflow_version(self, obj) -> int: if obj.workflow_template: - if workflow_template_data := WorkflowTemplateSerializer( - obj.workflow_template - ).data: + if workflow_template_data := WorkflowTemplateSerializer(obj.workflow_template).data: if data := workflow_template_data.get("data"): if isinstance(data, str): data = json.loads(data) - if ( - (version := data.get("version")) - and isinstance(version, str) - and version.isdigit() - ): + if (version := data.get("version")) and isinstance(version, str) and version.isdigit(): return int(version) return 1 @@ -268,10 +256,10 @@ def __init__(self, *args, **kwargs): def validate_crontab_line(self, value): try: minute, hour, day, month, weekday = value.split(" ") - except ValueError: + except ValueError as e: raise serializers.ValidationError( "Wrong crontab format. Make sure there are 5 space-separated values" - ) + ) from e return value def get_owner_username(self, obj): diff --git a/workflow/storage.py b/workflow/storage.py index ab360480..b0176144 100644 --- a/workflow/storage.py +++ b/workflow/storage.py @@ -4,7 +4,7 @@ import shutil import tempfile from io import BytesIO -from zipfile import ZipFile, ZIP_DEFLATED +from zipfile import ZIP_DEFLATED, ZipFile from cryptography.hazmat.primitives.ciphers.aead import AESGCM from django.core.files.base import ContentFile, File @@ -35,7 +35,7 @@ def __init__(self, *args, name=None, **kwargs): self.name = name -class EncryptedStorage(object): +class EncryptedStorage: def get_symmetric_key(self): if settings.ENCRYPTION_KEY: self.symmetric_key = bytes.fromhex(settings.ENCRYPTION_KEY) @@ -49,9 +49,7 @@ def get_symmetric_key(self): self.symmetric_key = self._get_symmetric_key_from_vault() except Exception as e: - raise Exception( - "Could not connect to Vault symmetric_key is not set. Error %s" % e - ) + raise Exception("Could not connect to Vault symmetric_key is not set. Error %s", e) from e def _get_symmetric_key_from_vault(self): # Retrieve the symmetric key from Vault @@ -154,13 +152,13 @@ def folder_exists_and_has_files(self, folder_path): try: # TODO maybe wrong implementation if not self.listdir: - raise NotImplemented("Listdir method not implemented") + raise NotImplementedError("Listdir method not implemented") # Check if the folder exists by listing its contents files, folders = self.listdir(folder_path) # Return True if there are any files in the folder return bool(files) - except Exception as e: + except Exception: return False def download_file_and_save_locally(self, storage_file_path, local_file_path): @@ -201,19 +199,12 @@ def download_paths_as_zip(self, paths): if path[0] == "/": self.download_directory(space.space_code + path, local_filename) else: - self.download_directory( - space.space_code + "/" + path, local_filename - ) + self.download_directory(space.space_code + "/" + path, local_filename) + elif path[0] == "/": + self.download_file_and_save_locally(space.space_code + path, local_filename) else: - if path[0] == "/": - self.download_file_and_save_locally( - space.space_code + path, local_filename - ) - else: - self.download_file_and_save_locally( - space.space_code + "/" + path, local_filename - ) + self.download_file_and_save_locally(space.space_code + "/" + path, local_filename) self.zip_directory(temp_dir_path, zip_filename) @@ -236,9 +227,7 @@ def download_directory(self, directory_path, local_destination_path): for root, _, files in self.sftp_client.walk(directory_path): for file in files: remote_path = os.path.join(root, file) - local_path = os.path.join( - local_destination_path, os.path.relpath(remote_path, directory_path) - ) + local_path = os.path.join(local_destination_path, os.path.relpath(remote_path, directory_path)) os.makedirs(os.path.dirname(local_path), exist_ok=True) self.sftp_client.get(remote_path, local_path) @@ -267,9 +256,7 @@ def download_directory(self, directory_path, local_destination_path): for blob in blob_list: # Check if the blob is inside the folder if blob.name.startswith(directory_path): - local_path = os.path.join( - local_destination_path, os.path.relpath(blob.name, directory_path) - ) + local_path = os.path.join(local_destination_path, os.path.relpath(blob.name, directory_path)) # Create the local directory structure os.makedirs(os.path.dirname(local_path), exist_ok=True) @@ -290,9 +277,7 @@ def download_directory_as_zip(self, directory_path): for blob in blob_list: # Check if the blob is inside the folder if blob.name.startswith(directory_path): - local_path = os.path.join( - temp_dir, os.path.relpath(blob.name, directory_path) - ) + local_path = os.path.join(temp_dir, os.path.relpath(blob.name, directory_path)) # Create the local directory structure os.makedirs(os.path.dirname(local_path), exist_ok=True) @@ -328,16 +313,14 @@ def delete_directory(self, directory_path): self.bucket.delete_objects(Delete={"Objects": objects_to_delete}) def download_directory(self, directory_path, local_destination_path): - _l.info("directory_path %s" % directory_path) + _l.info("directory_path %s", directory_path) folder = os.path.dirname(local_destination_path) if folder: os.makedirs(folder, exist_ok=True) for obj in self.bucket.objects.filter(Prefix=directory_path): - local_path = os.path.join( - local_destination_path, os.path.relpath(obj.key, directory_path) - ) + local_path = os.path.join(local_destination_path, os.path.relpath(obj.key, directory_path)) os.makedirs(os.path.dirname(local_path), exist_ok=True) self.bucket.download_file(obj.key, local_path) @@ -348,9 +331,7 @@ def download_directory_as_zip(self, directory_path): # Download all files from the remote folder to the temporary local directory for obj in self.bucket.objects.filter(Prefix=directory_path): - local_path = os.path.join( - temp_dir, os.path.relpath(obj.key, directory_path) - ) + local_path = os.path.join(temp_dir, os.path.relpath(obj.key, directory_path)) os.makedirs(os.path.dirname(local_path), exist_ok=True) self.bucket.download_file(obj.key, local_path) diff --git a/workflow/system.py b/workflow/system.py index 47953ee1..0abaf92b 100644 --- a/workflow/system.py +++ b/workflow/system.py @@ -1,28 +1,24 @@ +import fnmatch import importlib import json +import logging import os import shutil import sys from pathlib import Path -import fnmatch import yaml - +from django.db import connection from pluginbase import PluginBase from workflow.exceptions import WorkflowNotFound from workflow.models import Space from workflow.storage import get_storage from workflow.utils import build_celery_schedule, construct_path, get_all_tenant_schemas -from workflow_app import celery_app -from workflow_app import settings -from django.db import connection - +from workflow_app import celery_app, settings storage = get_storage() -import logging - _l = logging.getLogger("workflow") @@ -39,8 +35,7 @@ def register_workflows(self, space_code=None): schemas = [space_code] for schema in schemas: - - if schema != 'public': + if schema != "public": with connection.cursor() as cursor: cursor.execute(f"SET search_path TO {schema};") @@ -86,15 +81,15 @@ def load_workflows_for_schema(self, schema): # Use Pathlib to simplify path manipulations root_path = Path(local_workflows_folder_path) - _l.info("local_workflows_folder_path %s" % local_workflows_folder_path) + _l.info("local_workflows_folder_path %s", local_workflows_folder_path) # Iterate through all files in the /workflows directory and subdirectories for workflow_file in root_path.glob("**/workflow.*"): - _l.debug("workflow_file %s" % workflow_file) + _l.debug("workflow_file %s", workflow_file) if workflow_file.suffix in [".yaml", ".yml", ".json"]: try: - with open(str(workflow_file), "r") as f: + with open(str(workflow_file)) as f: if workflow_file.suffix in [".yaml", ".yml"]: config = yaml.load(f, Loader=yaml.SafeLoader) elif workflow_file.suffix == ".json": @@ -109,14 +104,10 @@ def load_workflows_for_schema(self, schema): config["workflow"]["space_code"] = space.space_code self.workflows[space.space_code + "." + user_code] = config - _l.debug( - f"Loaded workflow for user code: {space.space_code}.{user_code}" - ) + _l.debug(f"Loaded workflow for user code: {space.space_code}.{user_code}") except Exception as e: - _l.warning( - f"Could not load workflow config file: {workflow_file} - {e}" - ) + _l.warning(f"Could not load workflow config file: {workflow_file} - {e}") else: _l.debug(f"Skipped unsupported file format: {workflow_file}") @@ -124,7 +115,7 @@ def load_workflows_for_schema(self, schema): _l.error(f"Error loading workflows for schema {schema}: {e}") def get_by_user_code(self, user_code, sync_remote=False): - _l.info("get_by_user_code %s" % user_code) + _l.info("get_by_user_code %s", user_code) workflow = self.workflows.get(user_code) @@ -186,25 +177,16 @@ def sync_remote_storage_to_local_storage_for_schema( # Check if the local workflows directory exists before attempting to remove it if os.path.exists(local_workflows_folder_path): - _l.info( - f"Removing local workflows directory: {local_workflows_folder_path}" - ) + _l.info(f"Removing local workflows directory: {local_workflows_folder_path}") try: shutil.rmtree(local_workflows_folder_path) - _l.info( - "====[CLEAR DIRECTORY]==== Successfully removed local workflows directory." - ) + _l.info("====[CLEAR DIRECTORY]==== Successfully removed local workflows directory.") except Exception as e: _l.error(f"Failed to remove local workflows directory: {e}") else: - _l.info( - f"Local workflows directory does not exist, no need to remove: {local_workflows_folder_path}" - ) + _l.info(f"Local workflows directory does not exist, no need to remove: {local_workflows_folder_path}") - _l.info( - "remote_workflows_folder_path %s" - % construct_path(remote_workflows_folder_path, module_path) - ) + _l.info("remote_workflows_folder_path %s", construct_path(remote_workflows_folder_path, module_path)) module_path_components = [] if module_path: @@ -214,47 +196,27 @@ def sync_remote_storage_to_local_storage_for_schema( count = 0 for configuration_directory in configuration_directories: - if ( - len(module_path_components) > 0 - and configuration_directory != module_path_components[0] - ): + if len(module_path_components) > 0 and configuration_directory != module_path_components[0]: continue - organization_folder_path = construct_path( - remote_workflows_folder_path, configuration_directory - ) + organization_folder_path = construct_path(remote_workflows_folder_path, configuration_directory) organization_directories, _ = storage.listdir(organization_folder_path) for organization_directory in organization_directories: - if ( - len(module_path_components) > 1 - and organization_directory != module_path_components[1] - ): + if len(module_path_components) > 1 and organization_directory != module_path_components[1]: continue - module_folder_path = construct_path( - organization_folder_path, organization_directory - ) + module_folder_path = construct_path(organization_folder_path, organization_directory) modules_directories, _ = storage.listdir(module_folder_path) for module_directory in modules_directories: - if ( - len(module_path_components) > 2 - and module_directory != module_path_components[2] - ): + if len(module_path_components) > 2 and module_directory != module_path_components[2]: continue - workflow_folder_path = construct_path( - module_folder_path, module_directory - ) + workflow_folder_path = construct_path(module_folder_path, module_directory) workflow_directories, _ = storage.listdir(workflow_folder_path) for workflow_directory in workflow_directories: - if ( - len(module_path_components) > 3 - and workflow_directory != module_path_components[3] - ): + if len(module_path_components) > 3 and workflow_directory != module_path_components[3]: continue - file_folder_path = construct_path( - workflow_folder_path, workflow_directory - ) + file_folder_path = construct_path(workflow_folder_path, workflow_directory) _, files = storage.listdir(file_folder_path) # _l.info("sync_remote_storage_to_local_storage_for_schema.files %s" % files) @@ -278,10 +240,7 @@ def sync_remote_storage_to_local_storage_for_schema( # Log the file syncing # _l.info(f"Syncing file: {filepath}") - if any( - fnmatch.fnmatch(filename, pattern) - for pattern in patterns - ): + if any(fnmatch.fnmatch(filename, pattern) for pattern in patterns): with storage.open(filepath) as f: f_content = f.read() @@ -292,9 +251,7 @@ def sync_remote_storage_to_local_storage_for_schema( "local", filepath.lstrip("/"), ) - os.makedirs( - os.path.dirname(local_path), exist_ok=True - ) + os.makedirs(os.path.dirname(local_path), exist_ok=True) with open(local_path, "wb") as new_file: new_file.write(f_content) @@ -303,12 +260,9 @@ def sync_remote_storage_to_local_storage_for_schema( except Exception as e: _l.error(f"Could not sync file: {filename} - {e}") - # _l.info("load_user_tasks_from_storage_to_local_filesystem.Going to sync file %s DONE " % filepath) + # _l.info("load_user_tasks_from_storage_to_local_filesystem.Going to sync file %s DONE " % filepath) # noqa: E501 - _l.info( - "sync_remote_storage_to_local_storage_for_schema.Done syncing %s files" - % count - ) + _l.info("sync_remote_storage_to_local_storage_for_schema.Done syncing %s files", count) def import_user_tasks(self, workflow_path="**", raise_exception=False): self.plugin_base = PluginBase(package="workflow.foobar") @@ -321,11 +275,9 @@ def import_user_tasks(self, workflow_path="**", raise_exception=False): folder = Path(local_workflows_folder_path).resolve() - _l.info("import_user_tasks %s" % local_workflows_folder_path) + _l.info("import_user_tasks %s", local_workflows_folder_path) - self.plugin_source = self.plugin_base.make_plugin_source( - searchpath=[str(folder)] - ) + self.plugin_source = self.plugin_base.make_plugin_source(searchpath=[str(folder)]) tasks = Path(folder).glob(f"{workflow_path}/*.py") @@ -334,17 +286,13 @@ def import_user_tasks(self, workflow_path="**", raise_exception=False): for task in tasks: if task.stem == "__init__": continue - module_name = ( - str(task.relative_to(folder)).replace("/", ".").rsplit(".", 1)[0] - ) + module_name = str(task.relative_to(folder)).replace("/", ".").rsplit(".", 1)[0] try: # Load the module with a specific and isolated namespace spec = importlib.util.spec_from_file_location(module_name, task) module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = ( - module # Optional: register to sys.modules if needed globally - ) + sys.modules[module_name] = module # Optional: register to sys.modules if needed globally spec.loader.exec_module(module) except Exception as e: _l.info(f"Could not load user script {task}. Error {e}") @@ -375,13 +323,10 @@ def import_user_tasks(self, workflow_path="**", raise_exception=False): _l.info("Tasks are loaded") def cancel_all_existing_tasks(self, worker_name): - from workflow.models import Task - from workflow.models import Workflow + from workflow.models import Task, Workflow # find workflows through tasks - tasks = Task.objects.filter( - status__in=[Task.STATUS_PROGRESS, Task.STATUS_INIT], worker_name=worker_name - ) + tasks = Task.objects.filter(status__in=[Task.STATUS_PROGRESS, Task.STATUS_INIT], worker_name=worker_name) workflow_ids = [task.workflow_id for task in tasks] workflows = Workflow.objects.filter( status__in=[Workflow.STATUS_PROGRESS, Workflow.STATUS_INIT], @@ -399,16 +344,16 @@ def cancel_all_existing_tasks(self, worker_name): try: # just in case if rabbitmq still holds a task if task.celery_task_id: - celery_app.control.revoke(task.celery_task_id, terminate=True, signal='SIGKILL') + celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGKILL") except Exception as e: - _l.error("Something went wrong %s" % e) + _l.error("Something went wrong %s", e) task.mark_task_as_finished() task.save() - _l.info("Canceled %s tasks " % len(tasks)) + _l.info("Canceled %s tasks ", len(tasks)) def init_periodic_tasks(self): for user_code, config in self.workflows.items(): @@ -421,9 +366,7 @@ def init_periodic_tasks(self): if "periodic" in workflow: periodic_conf = workflow.get("periodic") periodic_payload = periodic_conf.get("payload", "{}") - schedule_str, schedule_value = build_celery_schedule( - user_code, periodic_conf - ) + schedule_str, schedule_value = build_celery_schedule(user_code, periodic_conf) celery_app.conf.beat_schedule.update( { @@ -446,14 +389,14 @@ def init_periodic_tasks(self): } ) - _l.info("Schedule %s" % celery_app.conf.beat_schedule) + _l.info("Schedule %s", celery_app.conf.beat_schedule) system_workflow_manager = None def get_system_workflow_manager(): - global system_workflow_manager + global system_workflow_manager # noqa: PLW0603 if "makemigrations" in sys.argv or "migrate" in sys.argv: _l.info("system_workflow_manager ignored - TEST MODE") diff --git a/workflow/tasks/base.py b/workflow/tasks/base.py index b12a0145..e3006fb6 100644 --- a/workflow/tasks/base.py +++ b/workflow/tasks/base.py @@ -1,16 +1,14 @@ from celery import Task as _Task -from celery.signals import task_prerun, task_postrun, task_failure, task_internal_error -from celery.exceptions import TimeLimitExceeded, SoftTimeLimitExceeded +from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded +from celery.signals import task_failure, task_internal_error, task_postrun, task_prerun from celery.utils.log import get_task_logger from django.db import connection -from django.utils.timezone import now from workflow.models import Task, Workflow from workflow.utils import ( - send_alert, schema_exists, + send_alert, set_schema_from_context, - get_next_node_by_condition, ) from workflow_app import celery_app @@ -41,14 +39,14 @@ def workflow_prerun(task_id, task, sender, *args, **kwargs): logger.info(f"task_prerun.context {space_code}") else: # REMOVE IN 1.9.0, PROBABLY SECURITY ISSUE with connection.cursor() as cursor: - cursor.execute(f"SET search_path TO public;") + cursor.execute("SET search_path TO public;") else: raise Exception("No space_code in context") else: raise Exception("No context in kwargs") with celery_app.app.app_context(): - print("task_id %s" % task_id) + print("task_id %s", task_id) task = Task.objects.get(celery_task_id=task_id) task.status = Task.STATUS_PROGRESS @@ -66,9 +64,9 @@ def cleanup(task_id, **kwargs): @task_failure.connect @task_internal_error.connect def on_failure(task_id, exception, args, einfo, **kwargs): - logger.info("task_failure.task_id: %s" % task_id) - logger.info("task_failure.kwargs: %s" % kwargs["kwargs"]) - logger.info("task_failure.exception: %s" % exception) + logger.info("task_failure.task_id: %s", task_id) + logger.info("task_failure.kwargs: %s", kwargs["kwargs"]) + logger.info("task_failure.exception: %s", exception) context = kwargs["kwargs"].get("context") set_schema_from_context(context) @@ -76,7 +74,7 @@ def on_failure(task_id, exception, args, einfo, **kwargs): task = Task.objects.get(celery_task_id=task_id) workflow = Workflow.objects.get(id=task.workflow_id) - if isinstance(exception, (TimeLimitExceeded, SoftTimeLimitExceeded)): + if isinstance(exception, TimeLimitExceeded | SoftTimeLimitExceeded): workflow.status = Workflow.STATUS_TIMEOUT task.status = Task.STATUS_TIMEOUT else: @@ -91,9 +89,7 @@ def on_failure(task_id, exception, args, einfo, **kwargs): workflow.save() if task.workflow.parent: - logger.info( - f"task_failureWorkflow has a parent with ID {task.workflow.parent.id}. Triggering next task." - ) + logger.info(f"task_failureWorkflow has a parent with ID {task.workflow.parent.id}. Triggering next task.") parent_workflow = task.workflow.parent parent_task = Task.objects.get( @@ -172,8 +168,8 @@ def is_workflow_already_running(self, workflow_user_code): return is_running def before_start(self, task_id, args, kwargs): - logger.info("BaseTask.before_start.task_id %s" % task_id) - logger.info("BaseTask.before_start.kwargs: %s" % kwargs) + logger.info("BaseTask.before_start.task_id %s", task_id) + logger.info("BaseTask.before_start.kwargs: %s", kwargs) context = kwargs.get("context") set_schema_from_context(context) @@ -189,13 +185,13 @@ def before_start(self, task_id, args, kwargs): self.workflow = workflow logger.info(f"Task {task_id} is now in progress") - super(BaseTask, self).before_start(task_id, args, kwargs) + super().before_start(task_id, args, kwargs) def on_success(self, retval, task_id, args, kwargs): - super(BaseTask, self).on_success(retval, task_id, args, kwargs) + super().on_success(retval, task_id, args, kwargs) - logger.info("BaseTask.on_success.task_id %s" % task_id) - logger.info("BaseTask.on_success.kwargs: %s" % kwargs) + logger.info("BaseTask.on_success.task_id %s", task_id) + logger.info("BaseTask.on_success.kwargs: %s", kwargs) context = kwargs.get("context") set_schema_from_context(context) diff --git a/workflow/tasks/export_backend_historical_records.py b/workflow/tasks/export_backend_historical_records.py index 21943962..f73a9abb 100644 --- a/workflow/tasks/export_backend_historical_records.py +++ b/workflow/tasks/export_backend_historical_records.py @@ -1,7 +1,7 @@ -from datetime import datetime, timedelta import json -import requests +from datetime import datetime, timedelta +import requests from celery.utils.log import get_task_logger from django.conf import settings @@ -39,17 +39,9 @@ def call_export_backend_historical_records(self, *args, **kwargs): + "/api/v1/history/historical-record/export/" ) else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + space.space_code - + "/api/v1/history/historical-record/export/" - ) + url = "https://" + settings.DOMAIN_NAME + "/" + space.space_code + "/api/v1/history/historical-record/export/" - response = requests.post( - url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL - ) + response = requests.post(url=url, data=json.dumps(data), headers=headers, verify=settings.VERIFY_SSL) if response.status_code != 200: logger.error(response.text) diff --git a/workflow/tasks/workflows.py b/workflow/tasks/workflows.py index b95fba53..2a2a22c2 100644 --- a/workflow/tasks/workflows.py +++ b/workflow/tasks/workflows.py @@ -1,3 +1,4 @@ +import logging import os.path import time import traceback @@ -7,14 +8,12 @@ from celery.utils.log import get_task_logger from django.utils.timezone import now -from workflow.models import Task, Workflow, User, Space, Schedule, WorkflowTemplate +from workflow.models import Schedule, Space, Task, User, Workflow, WorkflowTemplate from workflow.tasks.base import BaseTask -from workflow.utils import set_schema_from_context, are_inputs_ready +from workflow.utils import are_inputs_ready, set_schema_from_context from workflow_app import celery_app logger = get_task_logger(__name__) -import logging - _l = logging.getLogger("workflow") @@ -68,10 +67,8 @@ def mark_as_canceled_init_tasks(self, workflow_id, *args, **kwargs): @celery_app.task(bind=True) -def failure_hooks_launcher( - self, workflow_id, queue, tasks_names, payload, *args, **kwargs -): - logger.info("failure_hooks_launcher %s" % workflow_id) +def failure_hooks_launcher(self, workflow_id, queue, tasks_names, payload, *args, **kwargs): + logger.info("failure_hooks_launcher %s", workflow_id) context = kwargs.get("context") set_schema_from_context(context) @@ -98,7 +95,7 @@ def failure_hooks_launcher( ) task.save() - logger.info("failure_hooks_launcher.task %s" % task) + logger.info("failure_hooks_launcher.task %s", task) canvas.append(signature) @@ -110,21 +107,19 @@ def failure_hooks_launcher( try: result.get() except Exception as e: - logger.error("failure_hooks_launcher.result.get e %s" % e) + logger.error("failure_hooks_launcher.result.get e %s", e) pass logger.info("Going to cancel init tasks") task_id = uuid() - signature_mark_as_canceled = celery_app.tasks.get( - "workflow.tasks.workflows.mark_as_canceled_init_tasks" - ).subtask( + signature_mark_as_canceled = celery_app.tasks.get("workflow.tasks.workflows.mark_as_canceled_init_tasks").subtask( kwargs={"workflow_id": workflow_id}, queue="workflow", task_id=task_id, ) - logger.info("signature_mark_as_canceled %s" % signature_mark_as_canceled) + logger.info("signature_mark_as_canceled %s", signature_mark_as_canceled) signature_mark_as_canceled.apply_async() @@ -140,10 +135,10 @@ def failure_hooks_launcher( def execute(self, user_code, payload, is_manager, *args, **kwargs): from workflow.system import get_system_workflow_manager - manager = get_system_workflow_manager() + manager = get_system_workflow_manager() # noqa: F841 try: - logger.info("periodic.execute %s" % user_code) + logger.info("periodic.execute %s", user_code) context = kwargs.get("context") @@ -154,7 +149,7 @@ def execute(self, user_code, payload, is_manager, *args, **kwargs): schedule_id = kwargs.get("schedule_id") - logger.info("periodic.schedule_id %s" % schedule_id) + logger.info("periodic.schedule_id %s", schedule_id) schedule = Schedule.objects.get(id=schedule_id) @@ -179,12 +174,12 @@ def execute(self, user_code, payload, is_manager, *args, **kwargs): return data except Exception as e: - logger.error("periodic task error: %s" % e, exc_info=True) + logger.error("periodic task error: %s", e, exc_info=True) @celery_app.task(bind=True, base=BaseTask) -def execute_workflow_step(self, *args, **kwargs): - from workflow.api import get_registered_task, clear_registered_task +def execute_workflow_step(self, *args, **kwargs): # noqa: PLR0912,PLR0915 + from workflow.api import clear_registered_task, get_registered_task from workflow.system import get_system_workflow_manager clear_registered_task() @@ -214,18 +209,14 @@ def execute_workflow_step(self, *args, **kwargs): # If the code has defined a `main()` function, call it if "main" in exec_scope: - # logger.info(f"Executing main() function in user-provided source code for node {self.task.source_code}") + # logger.info(f"Executing main() function in user-provided source code for node {self.task.source_code}") # noqa: E501 result = exec_scope["main"](self, *args, **kwargs) return result else: - logger.warning( - f"No main() function found in source code for node. Skipping execution." - ) + logger.warning("No main() function found in source code for node. Skipping execution.") except Exception as e: - logger.error( - f"Error executing custom source code for node {self.task.source_code}: {e}" - ) + logger.error(f"Error executing custom source code for node {self.task.source_code}: {e}") raise e else: @@ -234,19 +225,13 @@ def execute_workflow_step(self, *args, **kwargs): if target_workflow_user_code.endswith(".task"): target_workflow_user_code = target_workflow_user_code[:-5] - _l.info("target_workflow_user_code %s" % target_workflow_user_code) + _l.info("target_workflow_user_code %s", target_workflow_user_code) - target_space_workflow_user_code = ( - f"{context.get('space_code')}.{target_workflow_user_code}" - ) + target_space_workflow_user_code = f"{context.get('space_code')}.{target_workflow_user_code}" - target_wf = manager.get_by_user_code( - target_space_workflow_user_code, sync_remote=True - ) + target_wf = manager.get_by_user_code(target_space_workflow_user_code, sync_remote=True) - _l.info( - f"execute_workflow_step: Target Workflow Version {target_wf.get('version')}" - ) + _l.info(f"execute_workflow_step: Target Workflow Version {target_wf.get('version')}") if int(target_wf.get("version", 1)) == 2: target_workflow_template = None @@ -259,7 +244,7 @@ def execute_workflow_step(self, *args, **kwargs): ) except Exception as e: _l.error("No target_workflow_template exist, abort") - raise Exception(e) + raise Exception(e) from e child_workflow = Workflow.objects.create( owner=parent_workflow.owner, @@ -308,16 +293,14 @@ def execute_workflow_step(self, *args, **kwargs): if isinstance(imports, list): for extra_path in imports: logger.info(f"importing {extra_path}") - extra_path = os.path.normpath(os.path.join(module_path, extra_path)) + extra_path = os.path.normpath(os.path.join(module_path, extra_path)) # noqa: PLW2901 last_segment = extra_path.split("/")[-1] if "*" in last_segment or "?" in last_segment: # given a wildcard for file name - extra_path, pattern = extra_path.rsplit("/", maxsplit=1) + extra_path, pattern = extra_path.rsplit("/", maxsplit=1) # noqa: PLW2901 else: pattern = "*.*" - manager.sync_remote_storage_to_local_storage_for_schema( - extra_path, [pattern] - ) + manager.sync_remote_storage_to_local_storage_for_schema(extra_path, [pattern]) manager.import_user_tasks(module_path, raise_exception=True) @@ -332,7 +315,7 @@ def execute_workflow_step(self, *args, **kwargs): @celery_app.task(bind=True) def execute_workflow_v2(self, *args, **kwargs): - logger.info(f"Opening the workflow with ID: {kwargs.get('workflow_id', None)}") + logger.info(f"Opening the workflow with ID: {kwargs.get('workflow_id')}") # Log the context passed to the workflow context = kwargs.get("context") @@ -369,11 +352,7 @@ def execute_workflow_v2(self, *args, **kwargs): logger.info(f"Adjacency list created: {adjacency_list}") # Start from nodes without incoming edges (root nodes) - start_nodes = [ - node_id - for node_id in nodes - if not any(node_id == conn["target"] for conn in connections) - ] + start_nodes = [node_id for node_id in nodes if not any(node_id == conn["target"] for conn in connections)] logger.info(f"Start nodes determined: {start_nodes}") # Execute tasks from start nodes @@ -396,9 +375,7 @@ def execute_workflow_v2(self, *args, **kwargs): @celery_app.task(bind=True) -def process_next_node( - self, current_node_id, workflow_id, nodes, adjacency_list, **kwargs -): +def process_next_node(self, current_node_id, workflow_id, nodes, adjacency_list, **kwargs): # noqa: PLR0912,PLR0915 context = kwargs.get("context") logger.info(f"process_next_node context received: {context}") set_schema_from_context(context) @@ -411,18 +388,14 @@ def process_next_node( current_node = nodes[current_node_id] if workflow.status == Workflow.STATUS_WAIT: - logger.info( - f"Workflow {workflow_id} is currently waiting. Stopping execution until resumed." - ) + logger.info(f"Workflow {workflow_id} is currently waiting. Stopping execution until resumed.") # Save the current_node_id for resuming workflow.current_node_id = current_node_id workflow.save() return # Exit the task without further execution if not are_inputs_ready(workflow, current_node_id, kwargs.get("connections")): - logger.info( - f"Task for Node ID: {current_node_id}, inputs are not ready, wait" - ) + logger.info(f"Task for Node ID: {current_node_id}, inputs are not ready, wait") return if current_node["data"]["node"]["type"] == "source_code": @@ -432,9 +405,7 @@ def process_next_node( else: workflow_user_code = current_node["data"]["workflow"]["user_code"] - logger.info( - f"Executing task for Node ID: {current_node_id}, Task Name: {workflow_user_code}" - ) + logger.info(f"Executing task for Node ID: {current_node_id}, Task Name: {workflow_user_code}") payload = workflow.payload # Default to the workflow payload previous_output = None @@ -443,50 +414,36 @@ def process_next_node( # Identify the previous node connected to "in" previous_node_id = None for connection in kwargs.get("connections"): - if ( - connection["target"] == current_node_id - and connection["targetInput"] == "in" - ): + if connection["target"] == current_node_id and connection["targetInput"] == "in": previous_node_id = connection["source"] break if previous_node_id: previous_task = ( - Task.objects.filter(workflow=workflow, node_id=previous_node_id) - .order_by("-created_at") - .first() + Task.objects.filter(workflow=workflow, node_id=previous_node_id).order_by("-created_at").first() ) if previous_task: previous_output = previous_task.result - logger.info( - f"Using previous output from node ID {previous_node_id}: {previous_output}" - ) + logger.info(f"Using previous output from node ID {previous_node_id}: {previous_output}") # Find the previous task that provided the payload if "payload_input" in current_node["inputs"]: # Identify the payload generator node connected to "payload_input" payload_generator_node_id = None for connection in kwargs.get("connections"): - if ( - connection["target"] == current_node_id - and connection["targetInput"] == "payload_input" - ): + if connection["target"] == current_node_id and connection["targetInput"] == "payload_input": payload_generator_node_id = connection["source"] break if payload_generator_node_id: payload_task = ( - Task.objects.filter( - workflow=workflow, node_id=payload_generator_node_id - ) + Task.objects.filter(workflow=workflow, node_id=payload_generator_node_id) .order_by("-created_at") .first() ) if payload_task: payload = payload_task.result - logger.info( - f"Using payload from node ID {payload_generator_node_id}: {payload}" - ) + logger.info(f"Using payload from node ID {payload_generator_node_id}: {payload}") # Create Celery signature for the current task task_id = uuid() diff --git a/workflow/tests/base.py b/workflow/tests/base.py index 8d1dd2a4..0d5d5013 100644 --- a/workflow/tests/base.py +++ b/workflow/tests/base.py @@ -1,11 +1,10 @@ import random import string + from django.test import TestCase class BaseTestCase(TestCase): @classmethod def random_string(cls, length: int = 10) -> str: - return "".join( - random.SystemRandom().choice(string.ascii_uppercase) for _ in range(length) - ) + return "".join(random.SystemRandom().choice(string.ascii_uppercase) for _ in range(length)) diff --git a/workflow/tests/factories.py b/workflow/tests/factories.py index 1b506d91..f7e7056a 100644 --- a/workflow/tests/factories.py +++ b/workflow/tests/factories.py @@ -1,17 +1,17 @@ import json import random import string + import factory from faker import Faker -from workflow.models import Space, TimeStampedModel, User, WorkflowTemplate + +from workflow.models import Space, User, WorkflowTemplate fake = Faker() def random_code(length): - return "".join( - random.choice(string.ascii_lowercase + string.digits) for _ in range(length) - ) + return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) class UserFactory(factory.django.DjangoModelFactory): @@ -38,9 +38,7 @@ class Meta: model = WorkflowTemplate name = factory.Faker("word") - user_code = factory.LazyAttribute( - lambda _: f"{fake.word()}.{fake.word()}.{fake.word()}:{fake.word()}" - ) + user_code = factory.LazyAttribute(lambda _: f"{fake.word()}.{fake.word()}.{fake.word()}:{fake.word()}") notes = factory.Faker("text") data = factory.LazyAttribute(lambda _: json.dumps({"version": "2", "workflow": {}})) space = factory.SubFactory(SpaceFactory) diff --git a/workflow/tests/test_schedule_view_set.py b/workflow/tests/test_schedule_view_set.py index 43324381..c053465b 100644 --- a/workflow/tests/test_schedule_view_set.py +++ b/workflow/tests/test_schedule_view_set.py @@ -1,6 +1,7 @@ from rest_framework import status from rest_framework.test import APIClient -from workflow.models import Schedule, User, Space + +from workflow.models import Schedule, Space, User from workflow.system import get_system_workflow_manager from .base import BaseTestCase @@ -12,9 +13,7 @@ def setUp(self): self.realm_code = f"realm{self.random_string(5)}" self.space_code = f"space{self.random_string(5)}" self.url_prefix = f"/{self.realm_code}/{self.space_code}/workflow/api/schedule/" - self.space = Space.objects.create( - realm_code=self.realm_code, space_code=self.space_code - ) + self.space = Space.objects.create(realm_code=self.realm_code, space_code=self.space_code) self.user = User.objects.create( username=self.random_string(5), is_staff=True, @@ -65,9 +64,7 @@ def test_update_schedule(self): "crontab_line": "30 * * * *", "payload": {"updated": "data"}, } - response = self.client.patch( - self.url_prefix + f"{self.schedule.pk}/", data, format="json" - ) + response = self.client.patch(self.url_prefix + f"{self.schedule.pk}/", data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.schedule.refresh_from_db() self.assertEqual(self.schedule.user_code, data["user_code"]) diff --git a/workflow/tests/test_signals.py b/workflow/tests/test_signals.py index 92bb3626..4e7cce82 100644 --- a/workflow/tests/test_signals.py +++ b/workflow/tests/test_signals.py @@ -1,23 +1,21 @@ import time -from django.test import TransactionTestCase, override_settings + from celery.contrib.testing.worker import start_worker from celery.exceptions import SoftTimeLimitExceeded from celery.utils import uuid -from workflow.models import Space, Workflow, Task, User -from workflow_app import celery_app +from django.test import TransactionTestCase, override_settings + +from workflow.models import Space, Task, User, Workflow from workflow.tasks.base import BaseTask +from workflow_app import celery_app -@override_settings( - CELERY_BROKER_URL="memory://", CELERY_RESULT_BACKEND="cache+memory://" -) +@override_settings(CELERY_BROKER_URL="memory://", CELERY_RESULT_BACKEND="cache+memory://") class SignalsTest(TransactionTestCase): def setUp(self): space = Space.objects.first() task_id = uuid() - self.workflow = Workflow.objects.create( - space=space, owner=User.objects.first(), status=Workflow.STATUS_INIT - ) + self.workflow = Workflow.objects.create(space=space, owner=User.objects.first(), status=Workflow.STATUS_INIT) self.task = Task.objects.create( workflow=self.workflow, space=space, diff --git a/workflow/tests/test_workflow_template_view_set.py b/workflow/tests/test_workflow_template_view_set.py index 13e5ee20..b1971e44 100644 --- a/workflow/tests/test_workflow_template_view_set.py +++ b/workflow/tests/test_workflow_template_view_set.py @@ -1,5 +1,7 @@ from datetime import timedelta + from rest_framework.test import APIClient + from workflow.tests.factories import SpaceFactory, UserFactory, WorkflowTemplateFactory from .base import BaseTestCase @@ -19,12 +21,8 @@ def setUp(self): self.client = APIClient() self.client.force_authenticate(self.user) - self.workflow_template1 = WorkflowTemplateFactory( - space=self.space, owner=self.user - ) - self.workflow_template2 = WorkflowTemplateFactory( - space=self.space, owner=self.user - ) + self.workflow_template1 = WorkflowTemplateFactory(space=self.space, owner=self.user) + self.workflow_template2 = WorkflowTemplateFactory(space=self.space, owner=self.user) def get_ids(self, response): return [w["id"] for w in response.data["results"]] @@ -36,12 +34,8 @@ def test_get_list(self): self.assertEqual(response.data["count"], 2) def test_filter_queryset_date_range(self): - created_at_after = ( - self.workflow_template1.created_at - timedelta(days=1) - ).strftime("%Y-%m-%d") - crated_at_before = ( - self.workflow_template1.created_at + timedelta(days=1) - ).strftime("%Y-%m-%d") + created_at_after = (self.workflow_template1.created_at - timedelta(days=1)).strftime("%Y-%m-%d") + crated_at_before = (self.workflow_template1.created_at + timedelta(days=1)).strftime("%Y-%m-%d") self.workflow_template2.created_at -= timedelta(days=3) self.workflow_template2.save() @@ -60,9 +54,7 @@ def test_filter_queryset_date_range(self): self.assertNotIn(self.workflow_template2.id, ids) def test_filter_query_user_code(self): - response = self.client.get( - self.url, {"user_code": self.workflow_template1.user_code} - ) + response = self.client.get(self.url, {"user_code": self.workflow_template1.user_code}) self.assertEqual(response.status_code, 200) diff --git a/workflow/tests/test_workflow_view_set.py b/workflow/tests/test_workflow_view_set.py index 4e0e62d7..08660fb2 100644 --- a/workflow/tests/test_workflow_view_set.py +++ b/workflow/tests/test_workflow_view_set.py @@ -1,6 +1,8 @@ from datetime import date + from rest_framework.test import APIClient -from workflow.models import Workflow, User, Space + +from workflow.models import Space, User, Workflow from workflow.tests.factories import WorkflowTemplateFactory from .base import BaseTestCase @@ -12,9 +14,7 @@ def setUp(self): self.realm_code = f"realm{self.random_string(5)}" self.space_code = f"space{self.random_string(5)}" self.url_prefix = f"/{self.realm_code}/{self.space_code}/workflow/api/workflow/" - self.space = Space.objects.create( - realm_code=self.realm_code, space_code=self.space_code - ) + self.space = Space.objects.create(realm_code=self.realm_code, space_code=self.space_code) self.user = User.objects.create( username=self.random_string(5), is_staff=True, @@ -62,9 +62,7 @@ def test_filter_queryset_payload(self): ids = [w["id"] for w in response.data["results"]] self.assertIn(self.workflow1.id, ids) self.assertNotIn(self.workflow2.id, ids) - self.assertTrue( - all([w for w in response.data["results"] if "another" in w["payload"]]) - ) + self.assertTrue(all([w for w in response.data["results"] if "another" in w["payload"]])) def test_filter_queryset_payload_partial_not_found(self): response = self.client.get(self.url_prefix, {"payload": "test"}) diff --git a/workflow/user_sessions.py b/workflow/user_sessions.py index 69cc351e..a0f4d142 100644 --- a/workflow/user_sessions.py +++ b/workflow/user_sessions.py @@ -1,3 +1,6 @@ +import base64 +import contextlib +import io import json import logging import sys @@ -5,14 +8,9 @@ import traceback import uuid -_l = logging.getLogger("workflow") -import contextlib -import io - -import sys import matplotlib.pyplot as plt -import base64 -import json + +_l = logging.getLogger("workflow") class UserSession: @@ -36,10 +34,10 @@ def execute_code(user_id, file_path, code): session = sessions[user_id] context = session.get_file_context(file_path) - _l.info("execute_code.context %s" % context) + _l.info("execute_code.context %s", context) # Create a StringIO object to capture the standard output - stdout = io.StringIO() + stdout = io.StringIO() # noqa: F841 # Add print() to last line if it's not an assignment # code_lines = code.split('\n') @@ -54,7 +52,7 @@ def execute_code(user_id, file_path, code): # Execute the code try: exec(code, context) - except Exception as e: + except Exception: # Print the traceback of the error traceback.print_exc(file=redirected_output) @@ -125,9 +123,7 @@ def _execute_code(code, context): except Exception as e: # Handle any errors that occur during execution - traceback_str = "".join( - traceback.format_exception(None, e, e.__traceback__) - ) + traceback_str = "".join(traceback.format_exception(None, e, e.__traceback__)) return {"type": "error", "data": traceback_str} finally: @@ -139,10 +135,10 @@ def _execute_code(code, context): def execute_file(user_id, file_path, data): - session = sessions[user_id] + session = sessions[user_id] # noqa: F841 context = {} - _l.info("execute_file.context %s" % context) + _l.info("execute_file.context %s", context) # Create a StringIO object to capture the standard output diff --git a/workflow/utils.py b/workflow/utils.py index c90fcb0d..8fb9ed81 100644 --- a/workflow/utils.py +++ b/workflow/utils.py @@ -1,17 +1,15 @@ import logging import os -import sys import random import string +import sys from celery.schedules import crontab -from jsonschema.validators import validator_for from django.db import connection -from django.db import transaction +from jsonschema.validators import validator_for from workflow.exceptions import WorkflowSyntaxError - _l = logging.getLogger("workflow") @@ -35,7 +33,7 @@ def format_schema_errors(e): def build_celery_schedule(workflow_name, data): """A celery schedule can accept seconds or crontab""" - _l.info("build_celery_schedule %s" % workflow_name) + _l.info("build_celery_schedule %s", workflow_name) def _handle_schedule(schedule): try: @@ -62,7 +60,7 @@ def _handle_crontab(ct): ) excluded_keys = ["payload"] - keys = [k for k in data.keys() if k not in excluded_keys] + keys = [k for k in data if k not in excluded_keys] schedule_functions = { # Legacy syntax for backward compatibility @@ -72,7 +70,7 @@ def _handle_crontab(ct): "interval": float, } - if len(keys) != 1 or keys[0] not in schedule_functions.keys(): + if len(keys) != 1 or keys[0] not in schedule_functions: # When there is no key (schedule, interval, crontab) in the periodic configuration raise WorkflowSyntaxError(workflow_name) @@ -82,20 +80,20 @@ def _handle_crontab(ct): # Apply the function mapped to the schedule type return str(schedule_input), schedule_functions[schedule_key](schedule_input) except Exception as e: - _l.error("build_celery_schedule.e %s" % e) + _l.error("build_celery_schedule.e %s", e) - raise WorkflowSyntaxError(workflow_name) + raise WorkflowSyntaxError(workflow_name) from e def send_alert(workflow): - from workflow.models import Workflow - from workflow.models import User - from workflow.models import Task - from workflow_app import settings - from rest_framework_simplejwt.tokens import RefreshToken - import requests import json + import requests + from rest_framework_simplejwt.tokens import RefreshToken + + from workflow.models import Task, User, Workflow + from workflow_app import settings + if workflow.status == Workflow.STATUS_ERROR: # _l.info("Going to report Error to Finmars") @@ -109,15 +107,15 @@ def send_alert(workflow): headers = { "Content-type": "application/json", "Accept": "application/json", - "Authorization": "Bearer %s" % refresh.access_token, + "Authorization": f"Bearer {refresh.access_token}", } error_task = workflow.tasks.filter(status=Task.STATUS_ERROR).first() - error_description = "Unknown" + error_description = "Unknown" # noqa: F841 if error_task: - error_description = str(error_task.error_message) + error_description = str(error_task.error_message) # noqa: F841 _l.info("Going to report Error to Finmars") @@ -141,15 +139,9 @@ def send_alert(workflow): + "/api/v1/utils/expression/" ) else: - url = ( - "https://" - + settings.DOMAIN_NAME - + "/" - + workflow.space.space_code - + "/api/v1/utils/expression/" - ) + url = "https://" + settings.DOMAIN_NAME + "/" + workflow.space.space_code + "/api/v1/utils/expression/" - response = requests.post( + response = requests.post( # noqa: F841 url=url, data=json.dumps(data), headers=headers, @@ -159,7 +151,7 @@ def send_alert(workflow): # _l.info('response %s' % response.text) except Exception as e: - _l.error("Could not send system message to finmars. Error %s" % e) + _l.error("Could not send system message to finmars. Error %s", e) def construct_path(*args): @@ -231,9 +223,7 @@ def set_schema_from_context(context): def generate_random_string(N): - return "".join( - random.choice(string.ascii_lowercase + string.digits) for _ in range(N) - ) + return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(N)) def are_inputs_ready(workflow, node_id, connections): @@ -273,9 +263,7 @@ def get_next_node_by_condition(current_node_id, condition_result, connections): Returns: - The ID of the next node to execute. """ - _l.info( - f"Evaluating condition for node {current_node_id}, result: {condition_result}" - ) + _l.info(f"Evaluating condition for node {current_node_id}, result: {condition_result}") # Define which output to follow based on condition result @@ -288,19 +276,12 @@ def get_next_node_by_condition(current_node_id, condition_result, connections): raise Exception("Wrong condition_result") # Iterate through connections to find the target node - for connection in connections: - if ( - connection["source"] == current_node_id - and connection["sourceOutput"] == output_to_follow - ): + for connection in connections: # noqa: F402 + if connection["source"] == current_node_id and connection["sourceOutput"] == output_to_follow: next_node_id = connection["target"] - _l.info( - f"Following output '{output_to_follow}' to next node {next_node_id}" - ) + _l.info(f"Following output '{output_to_follow}' to next node {next_node_id}") return next_node_id # If no matching connection is found, return None and log a warning - _l.warning( - f"No matching connection found for node {current_node_id} with output '{output_to_follow}'" - ) + _l.warning(f"No matching connection found for node {current_node_id} with output '{output_to_follow}'") return None diff --git a/workflow/views.py b/workflow/views.py index 648415ae..046f0825 100644 --- a/workflow/views.py +++ b/workflow/views.py @@ -1,10 +1,8 @@ -import json import logging import os import traceback import django_filters -import pexpect from django.core.management import call_command from django.http import HttpResponse from django.utils import timezone @@ -20,29 +18,28 @@ from rest_framework.viewsets import ModelViewSet, ViewSet from workflow.filters import ( - WorkflowQueryFilter, WholeWordsSearchFilter, + WorkflowQueryFilter, WorkflowSearchParamFilter, ) -from workflow.models import Workflow, Task, Schedule, WorkflowTemplate +from workflow.models import Schedule, Task, Workflow, WorkflowTemplate from workflow.serializers import ( - WorkflowSerializer, - TaskSerializer, - PingSerializer, - WorkflowLightSerializer, BulkSerializer, + PingSerializer, + ResumeWorkflowSerializer, RunWorkflowSerializer, ScheduleSerializer, + TaskSerializer, + WorkflowLightSerializer, + WorkflowSerializer, WorkflowTemplateSerializer, - ResumeWorkflowSerializer, ) -from workflow.user_sessions import create_session, execute_code, sessions, execute_file +from workflow.system import get_system_workflow_manager +from workflow.user_sessions import create_session, execute_code, execute_file, sessions from workflow.workflows import execute_workflow - -_l = logging.getLogger("workflow") from workflow_app import celery_app -from workflow.system import get_system_workflow_manager +_l = logging.getLogger("workflow") system_workflow_manager = get_system_workflow_manager() @@ -106,7 +103,7 @@ def run_workflow(self, request, pk=None, *args, **kwargs): platform_task_id, ) - _l.info("data %s" % data) + _l.info("data %s", data) return Response(data) @@ -114,9 +111,7 @@ def run_workflow(self, request, pk=None, *args, **kwargs): class WorkflowFilterSet(FilterSet): name = django_filters.CharFilter() user_code = django_filters.CharFilter() - status = django_filters.MultipleChoiceFilter( - field_name="status", choices=Workflow.STATUS_CHOICES - ) + status = django_filters.MultipleChoiceFilter(field_name="status", choices=Workflow.STATUS_CHOICES) created_at = django_filters.DateFromToRangeFilter() class Meta: @@ -175,7 +170,7 @@ def run_workflow(self, request, pk=None, *args, **kwargs): if request.space_code not in user_code: user_code = f"{request.space_code}.{user_code}" - _l.info("user_code %s" % user_code) + _l.info("user_code %s", user_code) system_workflow_manager.get_by_user_code(user_code, sync_remote=True) @@ -188,7 +183,7 @@ def run_workflow(self, request, pk=None, *args, **kwargs): platform_task_id, ) - _l.info("data %s" % data) + _l.info("data %s", data) return Response(data) @@ -207,13 +202,9 @@ def relaunch(self, request, pk=None, *args, **kwargs): @action(detail=True, methods=("POST",), url_path="cancel") def cancel(self, request, pk=None, *args, **kwargs): - - - workflow = Workflow.objects.get(id=pk) - if workflow.status == Workflow.STATUS_INIT or workflow.status == Workflow.STATUS_PROGRESS or workflow.status == Workflow.STATUS_WAIT: - + if workflow.status in [Workflow.STATUS_INIT, Workflow.STATUS_PROGRESS, Workflow.STATUS_WAIT]: workflow.cancel() return Response(workflow.to_dict()) @@ -229,12 +220,10 @@ def cancel(self, request, pk=None, *args, **kwargs): ) def bulk_cancel(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) - valid = serializer.is_valid(raise_exception=False) + valid = serializer.is_valid(raise_exception=False) # noqa: F841 data = serializer.validated_data - workflows = Workflow.objects.filter( - id__in=data["ids"], status=Workflow.STATUS_PROGRESS - ) + workflows = Workflow.objects.filter(id__in=data["ids"], status=Workflow.STATUS_PROGRESS) for workflow in workflows: workflow.cancel() @@ -278,9 +267,7 @@ def pause_workflow(self, request, pk=None, *args, **kwargs): ) except Workflow.DoesNotExist: - return Response( - {"message": "Workflow not found."}, status=status.HTTP_404_NOT_FOUND - ) + return Response({"message": "Workflow not found."}, status=status.HTTP_404_NOT_FOUND) # Resume Workflow Action @action( @@ -310,9 +297,7 @@ def resume_workflow(self, request, pk=None, *args, **kwargs): if active_tasks.exists(): return Response( - { - "message": "Cannot resume workflow while there are active tasks running." - }, + {"message": "Cannot resume workflow while there are active tasks running."}, status=status.HTTP_400_BAD_REQUEST, ) @@ -329,10 +314,7 @@ def resume_workflow(self, request, pk=None, *args, **kwargs): # Trigger the next task from the stored `current_node_id` if workflow.current_node_id: - nodes = { - node["id"]: node - for node in workflow.workflow_template.data["workflow"]["nodes"] - } + nodes = {node["id"]: node for node in workflow.workflow_template.data["workflow"]["nodes"]} connections = workflow.workflow_template.data["workflow"]["connections"] adjacency_list = {node_id: [] for node_id in nodes} for connection in connections: @@ -365,9 +347,7 @@ def resume_workflow(self, request, pk=None, *args, **kwargs): ) except Workflow.DoesNotExist: - return Response( - {"message": "Workflow not found."}, status=status.HTTP_404_NOT_FOUND - ) + return Response({"message": "Workflow not found."}, status=status.HTTP_404_NOT_FOUND) class TaskViewSet(ModelViewSet): @@ -422,9 +402,7 @@ def list(self, request, *args, **kwargs): # _l.info("RefreshStorageViewSet.stop flower result %s" % result) # c = pexpect.spawn("python /var/app/manage.py sync_remote_storage_to_local_storage", timeout=240) - system_workflow_manager.sync_remote_storage_to_local_storage( - request.space_code - ) + system_workflow_manager.sync_remote_storage_to_local_storage(request.space_code) # c = pexpect.spawn("supervisorctl start celery", timeout=240) # result = c.read() @@ -440,8 +418,8 @@ def list(self, request, *args, **kwargs): system_workflow_manager.register_workflows(request.space_code) except Exception as e: - _l.info("Could not restart celery.exception %s" % e) - _l.info("Could not restart celery.traceback %s" % traceback.format_exc()) + _l.info("Could not restart celery.exception %s", e) + _l.info("Could not restart celery.traceback %s", traceback.format_exc()) return Response({"status": "ok"}) @@ -454,9 +432,7 @@ def list(self, request, *args, **kwargs): # _l.info('DefinitionViewSet.definition %s' % definition) if definition["workflow"]["space_code"] == request.space_code: - workflow_definitions.append( - {"user_code": user_code, **definition["workflow"]} - ) + workflow_definitions.append({"user_code": user_code, **definition["workflow"]}) return Response(workflow_definitions) @@ -471,7 +447,7 @@ def list(self, request, *args, **kwargs): # Read the last 2MB of your log file bytes_to_read = 2 * 1024 * 1024 # 2MB in bytes - with open(log_file_path, "r") as log_file: + with open(log_file_path) as log_file: log_file.seek(max(0, log_file.tell() - bytes_to_read), 0) log_content = log_file.read() @@ -590,19 +566,16 @@ def run_manual(self, request, *args, **kwargs): ) except Schedule.DoesNotExist: - return Response( - {"error": "Schedule not found."}, status=status.HTTP_404_NOT_FOUND - ) + return Response({"error": "Schedule not found."}, status=status.HTTP_404_NOT_FOUND) except Exception as e: - return Response( - {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR - ) + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) class CeleryStatusViewSet(ViewSet): """ A simple ViewSet that returns Celery queue and worker status. """ + def list(self, request, *args, **kwargs): insp = celery_app.control.inspect() data = { @@ -612,4 +585,4 @@ def list(self, request, *args, **kwargs): "scheduled": insp.scheduled() or {}, } - return Response(data) \ No newline at end of file + return Response(data) diff --git a/workflow/workflows.py b/workflow/workflows.py index 0e1628ac..d28fa380 100644 --- a/workflow/workflows.py +++ b/workflow/workflows.py @@ -1,8 +1,7 @@ import logging from workflow.builder import WorkflowBuilder - -from workflow.models import Workflow, User, Space, WorkflowTemplate +from workflow.models import Space, User, Workflow, WorkflowTemplate from workflow.tasks.workflows import execute_workflow_v2 _l = logging.getLogger("workflow") @@ -11,12 +10,15 @@ def execute_workflow( username, user_code, - payload={}, + payload=None, realm_code=None, space_code=None, platform_task_id=None, crontab_id=None, ): + if payload is None: + payload = {} + user = User.objects.get(username=username) from workflow.system import get_system_workflow_manager @@ -28,14 +30,12 @@ def execute_workflow( workflow_template = None - _l.info("Looking for worklow template %s" % user_code) + _l.info("Looking for worklow template %s", user_code) space_less_user_code = ".".join(user_code.split(".")[1:]) - try: - workflow_template = WorkflowTemplate.objects.get( - user_code=space_less_user_code, space=space - ) + try: # noqa: SIM105 + workflow_template = WorkflowTemplate.objects.get(user_code=space_less_user_code, space=space) except WorkflowTemplate.DoesNotExist: pass @@ -67,7 +67,7 @@ def execute_workflow( else: _l.info("Execute old version") - data = obj.to_dict() + data = obj.to_dict() # noqa: F841 workflow = WorkflowBuilder(obj.id, wf) workflow.build() # Build the workflow execution plan workflow.run() # Run the workflow diff --git a/workflow_app/celery.py b/workflow_app/celery.py index a0a2a194..6f1bce33 100644 --- a/workflow_app/celery.py +++ b/workflow_app/celery.py @@ -1,22 +1,22 @@ # from __future__ import absolute_import, unicode_literals import logging + from celery import Celery from celery.signals import setup_logging # noqa from django.conf import settings - -_l = logging.getLogger('workflow') +_l = logging.getLogger("workflow") print("Creating Celery app Instance...") -app = Celery('workflow') +app = Celery("workflow") # Using a string here means the worker doesn't have to serialize # the configuration object to child processes. # - namespace='CELERY' means all celery-related configuration keys # should have a `CELERY_` prefix. -app.config_from_object('django.conf:settings', namespace='CELERY') +app.config_from_object("django.conf:settings", namespace="CELERY") @setup_logging.connect @@ -26,4 +26,5 @@ def config_loggers(*args, **kwargs): dictConfig(settings.LOGGING) -app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) \ No newline at end of file + +app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) diff --git a/workflow_app/gunicorn.py b/workflow_app/gunicorn.py index d42ce904..b4325eb9 100644 --- a/workflow_app/gunicorn.py +++ b/workflow_app/gunicorn.py @@ -1,6 +1,5 @@ import os - chdir = "/var/app/" project_name = os.getenv("PROJECT_NAME", "workflow_app") @@ -19,6 +18,7 @@ INSTANCE_TYPE = os.getenv("INSTANCE_TYPE", "web") + def on_starting(server): if INSTANCE_TYPE == "web": print("I'm web_instance") @@ -26,4 +26,4 @@ def on_starting(server): else: print("Gunicorn should not start for INSTANCE_TYPE:", INSTANCE_TYPE) server.log.info("Exiting because this pod is not a web instance") - exit(0) + exit(0) # noqa: PLR1722 diff --git a/workflow_app/openapi.py b/workflow_app/openapi.py index 27d0f13d..b58e213b 100644 --- a/workflow_app/openapi.py +++ b/workflow_app/openapi.py @@ -2,18 +2,17 @@ from django.shortcuts import render from django.urls import include, path from drf_yasg import openapi +from drf_yasg.generators import OpenAPISchemaGenerator from drf_yasg.views import get_schema_view -from drf_yasg.generators import OpenAPISchemaGenerator class TenantSchemaGenerator(OpenAPISchemaGenerator): - def get_schema(self, request=None, public=False): swagger = super().get_schema(request, public) # Iterate over paths and replace placeholder parameters with default values - for path in list(swagger.paths.keys()): - new_path = path.replace('{realm_code}', request.realm_code).replace('{space_code}', request.space_code) + for path in list(swagger.paths.keys()): # noqa: F402 + new_path = path.replace("{realm_code}", request.realm_code).replace("{space_code}", request.space_code) swagger.paths[new_path] = swagger.paths[path] del swagger.paths[path] @@ -22,20 +21,23 @@ def get_schema(self, request=None, public=False): def get_tags(self, operation_keys=None): tags = super().get_tags(operation_keys) - print('tags %s' % tags) + print(f"tags {tags}") # Custom logic to modify tags if necessary return tags + def scheme_get_method_decorator(func): - def wrapper(self, request, version='', format=None, *args, **kwargs): - return func(self, request, version='', format=None) + def wrapper(self, request, version="", format=None, *args, **kwargs): + return func(self, request, version="", format=None) + return wrapper + def generate_schema(local_urlpatterns): schema_view = get_schema_view( openapi.Info( title="Finmars Workflow API", - default_version='v1', + default_version="v1", description="Finmars Documentation", terms_of_service="https://www.finmars.com/policies/terms/", contact=openapi.Contact(email="admin@finmars.com"), @@ -43,8 +45,8 @@ def generate_schema(local_urlpatterns): x_logo={ "url": "https://finmars.com/wp-content/uploads/2023/04/logo.png", "backgroundColor": "#000", - "href": '/' + settings.REALM_CODE + '/docs/api/v1/' - } + "href": "/" + settings.REALM_CODE + "/docs/api/v1/", + }, ), patterns=local_urlpatterns, public=True, @@ -61,7 +63,7 @@ def get_api_documentation(*args, **kwargs): from .urls import router local_urlpatterns = [ - path('//workflow/api/', include(router.urls)), + path("//workflow/api/", include(router.urls)), ] schema_view = generate_schema(local_urlpatterns) @@ -69,25 +71,22 @@ def get_api_documentation(*args, **kwargs): return schema_view - def render_main_page(request, *args, **kwargs): - context = { - 'realm_code': request.realm_code, - 'space_code': request.space_code - } + context = {"realm_code": request.realm_code, "space_code": request.space_code} - return render(request, 'finmars_redoc.html', context) + return render(request, "finmars_redoc.html", context) def get_redoc_urlpatterns(): api_schema_view = get_api_documentation() urlpatterns = [ - - path('//workflow/docs/', render_main_page, name='main'), - path('//workflow/docs/api/', - api_schema_view.with_ui('redoc', cache_timeout=0), name='api'), - + path("//workflow/docs/", render_main_page, name="main"), + path( + "//workflow/docs/api/", + api_schema_view.with_ui("redoc", cache_timeout=0), + name="api", + ), ] return urlpatterns diff --git a/workflow_app/urls.py b/workflow_app/urls.py index 9743a2ce..b9476674 100644 --- a/workflow_app/urls.py +++ b/workflow_app/urls.py @@ -16,49 +16,55 @@ from django.conf import settings from django.contrib import admin -from django.urls import re_path, include, path +from django.urls import include, path, re_path from django.views.generic import TemplateView from rest_framework import routers -from workflow.views import WorkflowViewSet, TaskViewSet, PingViewSet, DefinitionViewSet, RefreshStorageViewSet, \ - LogFileViewSet, CodeExecutionViewSet, RealmMigrateSchemeView, FileExecutionViewSet, ScheduleViewSet, \ - WorkflowTemplateViewSet, CeleryStatusViewSet +from workflow.views import ( + CeleryStatusViewSet, + CodeExecutionViewSet, + DefinitionViewSet, + FileExecutionViewSet, + LogFileViewSet, + RealmMigrateSchemeView, + RefreshStorageViewSet, + ScheduleViewSet, + TaskViewSet, + WorkflowTemplateViewSet, + WorkflowViewSet, +) from workflow_app.openapi import get_redoc_urlpatterns router = routers.DefaultRouter() -router.register(r'workflow', WorkflowViewSet, 'workflow') -router.register(r'workflow-template', WorkflowTemplateViewSet, 'workflow-template') -router.register(r'task', TaskViewSet, "task") +router.register(r"workflow", WorkflowViewSet, "workflow") +router.register(r"workflow-template", WorkflowTemplateViewSet, "workflow-template") +router.register(r"task", TaskViewSet, "task") # router.register(r'ping', PingViewSet, "ping") -router.register(r'refresh-storage', RefreshStorageViewSet, "refresh-storage") -router.register(r'definition', DefinitionViewSet, "ping") -router.register(r'schedule', ScheduleViewSet, "schedule") -router.register(r'log', LogFileViewSet, "log") -router.register(r'execute-code', CodeExecutionViewSet, basename='execute-code') -router.register(r'execute-file', FileExecutionViewSet, basename='execute-file') +router.register(r"refresh-storage", RefreshStorageViewSet, "refresh-storage") +router.register(r"definition", DefinitionViewSet, "ping") +router.register(r"schedule", ScheduleViewSet, "schedule") +router.register(r"log", LogFileViewSet, "log") +router.register(r"execute-code", CodeExecutionViewSet, basename="execute-code") +router.register(r"execute-file", FileExecutionViewSet, basename="execute-file") router.register(r"authorizer/migrate", RealmMigrateSchemeView, "migrate") -router.register(r'celery-status', CeleryStatusViewSet, basename='celery-status') +router.register(r"celery-status", CeleryStatusViewSet, basename="celery-status") urlpatterns = [ - # Old Approach (delete in 1.9.0) # re_path(r'^(?P[^/]+)/workflow/api/', include(router.urls)), # re_path(r'^(?P[^/]+)/workflow/admin/docs/', include('django.contrib.admindocs.urls')), # re_path(r'^(?P[^/]+)/workflow/admin/', admin.site.urls), - - re_path(r'^(?P[^/]+)/workflow/$', TemplateView.as_view(template_name='index.html')), - + re_path(r"^(?P[^/]+)/workflow/$", TemplateView.as_view(template_name="index.html")), # New Approach - re_path(r'^(?P[^/]+)/(?P[^/]+)/workflow/api/', include(router.urls)), - re_path(r'^(?P[^/]+)/(?P[^/]+)/workflow/api/v1/', include(router.urls)), + re_path(r"^(?P[^/]+)/(?P[^/]+)/workflow/api/", include(router.urls)), + re_path(r"^(?P[^/]+)/(?P[^/]+)/workflow/api/v1/", include(router.urls)), # re_path(r'^(?P[^/]+)/(?P[^/]+)/workflow/admin/docs/', # include('django.contrib.admindocs.urls')), re_path(rf"^{settings.REALM_CODE}/(?:space\w{{5}})/workflow/admin/", admin.site.urls), - - re_path(r'^(?P[^/]+)/(?P[^/]+)/workflow/$', - TemplateView.as_view(template_name='index.html')) - + re_path( + r"^(?P[^/]+)/(?P[^/]+)/workflow/$", TemplateView.as_view(template_name="index.html") + ), ] if "drf_yasg" in settings.INSTALLED_APPS: @@ -68,5 +74,5 @@ import debug_toolbar urlpatterns = [ - path('__debug__/', include(debug_toolbar.urls)), - ] + urlpatterns + path("__debug__/", include(debug_toolbar.urls)), + ] + urlpatterns diff --git a/workflow_app/utils.py b/workflow_app/utils.py index 6be0afbf..c6c7534a 100644 --- a/workflow_app/utils.py +++ b/workflow_app/utils.py @@ -1,24 +1,24 @@ -import os import contextlib +import os import warnings -def ENV_BOOL(env_name, default): +def ENV_BOOL(env_name, default): val = os.environ.get(env_name, default) if not val: return default - if val == 'True' or val == True: + if val == "True" or val is True: return True - if val == 'False' or val == False: + if val == "False" or val is False: return False - warnings.warn('Variable %s is not boolean. It is %s' % (env_name, val)) + warnings.warn("Variable %s is not boolean. It is %s", env_name, val) -def ENV_STR(env_name, default): +def ENV_STR(env_name, default): val = os.environ.get(env_name, default) if not val: @@ -26,8 +26,8 @@ def ENV_STR(env_name, default): return val -def ENV_INT(env_name, default): +def ENV_INT(env_name, default): val = os.environ.get(env_name, default) if not val: @@ -37,14 +37,13 @@ def ENV_INT(env_name, default): def print_finmars(): - text = """ ███████╗ ██╗ ███╗ ██╗ ███╗ ███╗ █████╗ ██████╗ ███████╗ ██╔════╝ ██║ ████╗ ██║ ████╗ ████║ ██╔══██╗ ██╔══██╗ ██╔════╝ █████╗ ██║ ██╔██╗ ██║ ██╔████╔██║ ███████║ ██████╔╝ ███████╗ ██╔══╝ ██║ ██║╚██╗██║ ██║╚██╔╝██║ ██╔══██║ ██╔══██╗ ╚════██║ ██║ ██║ ██║ ╚████║ ██║ ╚═╝ ██║ ██║ ██║ ██║ ██║ ███████║ -╚═╝ ╚═╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝ +╚═╝ ╚═╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝ """ print(text) @@ -52,9 +51,9 @@ def print_finmars(): def filter_sentry_events(event, hint): with contextlib.suppress(Exception): - frames = event['exception']['values'][0]['stacktrace']['frames'] + frames = event["exception"]["values"][0]["stacktrace"]["frames"] for i, frame in enumerate(frames): - if frame['function'] == 'execute_workflow_step' and len(frames) > i+1: + if frame["function"] == "execute_workflow_step" and len(frames) > i + 1: # do not report exceptions raised in custom modules return None return event diff --git a/workflow_app/wsgi.py b/workflow_app/wsgi.py index 92298d31..a4101026 100644 --- a/workflow_app/wsgi.py +++ b/workflow_app/wsgi.py @@ -11,6 +11,6 @@ from django.core.wsgi import get_wsgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'workflow_app.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "workflow_app.settings") application = get_wsgi_application()