Skip to content

Commit

Permalink
Merge pull request #25 from MayNiklas/privacy-mode
Browse files Browse the repository at this point in the history
Privacy update
  • Loading branch information
MayNiklas authored Nov 5, 2024
2 parents 341d4e6 + ffd3476 commit 880bc0c
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 44 deletions.
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,22 @@ It simplifies offloading the heavy lifting of using Whisper to a central GPU ser
## Features

* Transcribes audio files to text using OpenAI Whisper

### Simple frontend
* Includes a simple static frontend to transcribe audio files (`/`)
* Includes interactive API documentation using the Swagger UI (`/docs`)
* Uses GPU acceleration if available
* Implements a task queue to handle multiple requests (first in, first out)
* Stateless: to prioritize data privacy, the API only stores data in RAM. Audio files are stored using tempfile and are deleted after processing

### Resource friendly GPU support
* Uses GPU acceleration if available
* Supports loading the model into VRAM on startup OR on first request
* Supports unloading the model after a certain time of inactivity

### Privacy focussed
* Stateless: to prioritize data privacy, the API only stores data in RAM. Audio files are stored using tempfile and are deleted after processing.
* Logs don't contain any transcribed text and transcription ids are obfuscated
* Results are deleted from RAM after a given time

## Setup recommendations

This service performs best when it runs on a server with a GPU. For the high-quality models, I recommend a GPU with at least 12GB of VRAM. The RTX 3060 12GB is most likely the cheapest option for this task.
Expand Down Expand Up @@ -206,7 +214,8 @@ nix develop .#withCUDA
| `CPU_FALLBACK_MODEL` | The fallback when `MAX_MODEL` is not set and CPU mode is needed | name of official model | medium |
| `LOG_DIR` | The directory to store log-file(s) in "" means 'this directory', dir is created if needed | wanted directory name or empty str | "data/" |
| `LOG_FILE` | The name of the log file | arbitrary filename | whisper_api.log |
| `LOG_LEVEL_CONSOLE` | Level of logging for the console | "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") | "INFO" |
| `LOG_LEVEL_CONSOLE` | Level of logging for the console | "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" | "INFO" |
| `LOG_PRIVACY_MODE` | Don't display full task uuids and other sensitive data in the logs | `1` (yes) or `0` (no) | 1 |
| `LOG_LEVEL_FILE` | Level of logging for the file | "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" | "INFO" |
| `LOG_FORMAT` | Format of the log messages | any valid log message format | \*see below\* |
| `LOG_DATE_FORMAT` | Format of the date in log messages | any valid date format | "%d.%m. %H:%M:%S" |
Expand Down
7 changes: 4 additions & 3 deletions src/whisper_api/api_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from whisper_api.environment import AUTHORIZED_MAILS
from whisper_api.environment import LOG_DIR
from whisper_api.log_setup import logger
from whisper_api.log_setup import uuid_log_format

V1_PREFIX = "/api/v1"

Expand Down Expand Up @@ -88,7 +89,7 @@ async def status(self, task_id: uuid_hex_t) -> TaskResponse:
"""
task = self.tasks.get(task_id, None)
if task is None:
logger.info(f"task_id '{task_id}' not found")
logger.info(f"task_id '{uuid_log_format(task_id)}' not found")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="task_id not valid",
Expand Down Expand Up @@ -138,15 +139,15 @@ async def srt(self, task_id: uuid_hex_t):
task = self.tasks.get(task_id, None)
# TODO maybe hold a set of tasks that were present but aren't any more for better message?
if task is None:
logger.info(f"task_id '{task_id}' not found")
logger.info(f"task_id '{uuid_log_format(task_id)}' not found")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="task_id not found",
)

# TODO better way for central declaration of those states
if task.status in ["pending", "processing", "failed"]:
logger.info(f"task_id '{task_id}' not ready or failed, status: '{task.status}'")
logger.info(f"task_id '{uuid_log_format(task_id)}' not ready or failed, status: '{task.status}'")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=task.to_transmit_full,
Expand Down
1 change: 1 addition & 0 deletions src/whisper_api/data_models/data_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Literal

uuid_hex_t = str
private_uuid_hex_t = str
task_type_str_t = Literal["transcribe", "translate"]
status_str_t = Literal["pending", "processing", "finished", "failed"]
model_sizes_str_t = Literal["base", "small", "medium", "turbo", "large"]
Expand Down
36 changes: 34 additions & 2 deletions src/whisper_api/data_models/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime as dt
import io
import re
from tempfile import NamedTemporaryFile
from typing import Any
from uuid import uuid4
Expand All @@ -12,9 +13,40 @@
from whisper_api.data_models.data_types import status_str_t
from whisper_api.data_models.data_types import task_type_str_t
from whisper_api.data_models.data_types import uuid_hex_t
from whisper_api.log_setup import uuid_log_format


class TaskResponse(BaseModel):
class PrivacyAwareTaskBaseModel(BaseModel):
    """BaseModel that doesn't print full uuids when in privacy mode.

    Only values rendered as ``uuid='...'`` in the model's string form are
    rewritten; each one is passed through ``uuid_log_format``, which decides
    whether the value is obfuscated or returned unchanged.
    """

    def __str__(self):
        """Return the BaseModel ``__str__`` with uuid obfuscation injected if needed."""
        original_repr = super().__str__()

        # regex pattern to extract the UUID
        pattern = r"uuid='([a-f0-9\-]+)'"

        # A callable replacement converts every match on its own. A fixed
        # replacement string built from the first match would stamp that one
        # uuid onto every other `uuid='...'` occurrence in the string, and it
        # would also be parsed for backslash escapes by re.sub.
        # If nothing matches, re.sub returns the string unchanged.
        return re.sub(
            pattern,
            lambda match: f"uuid='{uuid_log_format(match.group(1))}'",
            original_repr,
        )

    def __repr__(self):
        """Rebuild the BaseModel repr but with potential uuid obfuscation."""
        return f"{self.__class__.__name__}({self.__str__()})"


class TaskResponse(PrivacyAwareTaskBaseModel):
"""The class that is returned via the API"""

task_id: str
Expand Down Expand Up @@ -64,7 +96,7 @@ def get_srt_buffer(self) -> io.StringIO:
return buffer


class Task(BaseModel):
class Task(PrivacyAwareTaskBaseModel):
audiofile_name: named_temp_file_name_t
task_type: task_type_str_t

Expand Down
14 changes: 8 additions & 6 deletions src/whisper_api/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from whisper_api.environment import DEVELOP_MODE
from whisper_api.environment import LOAD_MODEL_ON_STARTUP
from whisper_api.environment import MAX_TASK_QUEUE_SIZE
from whisper_api.log_setup import uuid_log_format

gigabyte_factor = int(1e9)
vram_model_map: dict[model_sizes_str_t, int] = {
Expand Down Expand Up @@ -233,7 +234,7 @@ def handle_task(self, task: Task) -> Task:

self.send_task_update(task)
self.logger.info(
f"Sent update for task {task.uuid}, "
f"Sent update for task {uuid_log_format(task.uuid)}, "
f"status={task.status}, position_in_queue={task.position_in_queue}, "
f"whisper result: 'is {'not' if task.whisper_result else ''} None'"
)
Expand Down Expand Up @@ -266,12 +267,12 @@ def get_unload_time():
try:
with self.task_queue_lock:
task: Task = next(self.task_queue)
logging.debug(f"Extracted new task from queue: '{task.uuid}' ")
self.logger.debug(f"Extracted new task from queue: '{uuid_log_format(task.uuid)}' ")

self.__busy = True
sent_empty_queue_info = False

self.logger.info(f"Now processing task '{task.uuid}'")
self.logger.info(f"Now processing task '{uuid_log_format(task.uuid)}'")
self.logger.info(f"Sending status update to parent")
self.send_status_update() # queue changed in size - status update

Expand Down Expand Up @@ -330,7 +331,7 @@ def run(self):
data = msg.get("data", None)

if task_type is None:
self.logger.debug(f"Decoder received '{task_type=}', weird... continuing - data: {msg=}")
self.logger.debug(f"Decoder received '{task_type=}', weird... ignoring message - data: {msg=}")
continue

elif task_type == "exit": # data is arbitrary since it will not be considered
Expand Down Expand Up @@ -364,12 +365,13 @@ def run(self):
# so just hold it for the whole time and nothing can go wrong :)
with self.task_queue_lock:
try:
self.logger.debug(f"Adding task '{task.uuid}' to queue")
self.logger.debug(f"Adding task '{uuid_log_format(task.uuid)}' to queue")
self.task_queue.put(task)
except OverflowError:
# TODO: maybe add new status "rejected" and a reason to it?
self.logger.warning(
f"Task '{task.uuid}' failed because queue of size {self.task_queue.max_size} is full"
f"Task '{uuid_log_format(task.uuid)}' failed "
f"because queue of size {self.task_queue.max_size} is full"
)
task.status = "failed"
self.send_task_update(task)
Expand Down
1 change: 1 addition & 0 deletions src/whisper_api/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
LOG_ROTATION_WHEN = os.getenv("LOG_ROTATION", "H")
LOG_ROTATION_INTERVAL = int(os.getenv("LOG_ROTATION_INTERVAL", 2))
LOG_ROTATION_BACKUP_COUNT = int(os.getenv("LOG_ROTATION_BACKUP_COUNT", 48))
LOG_PRIVACY_MODE = int(os.getenv("LOG_PRIVACY_MODE", 1))

AUTHORIZED_MAILS = set(os.getenv("LOG_AUTHORIZED_MAILS", "").split(" "))
AUTHORIZED_MAILS = AUTHORIZED_MAILS - {""}
55 changes: 36 additions & 19 deletions src/whisper_api/log_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
from multiprocessing.connection import Connection
from typing import Literal

from whisper_api.data_models.data_types import private_uuid_hex_t
from whisper_api.data_models.data_types import uuid_hex_t
from whisper_api.environment import LOG_DATE_FORMAT
from whisper_api.environment import LOG_FORMAT
from whisper_api.environment import LOG_LEVEL_CONSOLE
from whisper_api.environment import LOG_LEVEL_FILE
from whisper_api.environment import LOG_PRIVACY_MODE
from whisper_api.environment import LOG_ROTATION_BACKUP_COUNT
from whisper_api.environment import LOG_ROTATION_INTERVAL
from whisper_api.environment import LOG_ROTATION_WHEN
Expand All @@ -36,6 +39,17 @@
logger.setLevel(logging.DEBUG)


def uuid_log_format(uid: uuid_hex_t) -> uuid_hex_t | private_uuid_hex_t:
    """
    Return the form in which a task uuid shall be written to the logs.

    With LOG_PRIVACY_MODE enabled, only the first and the last four characters
    of the uuid are kept; otherwise the uuid is returned unchanged.

    reason: the uuids shall not be visible in a privacy focussed production deployment,
    since a printed uuid might allow the host to access the data we try to hide.
    """
    if not LOG_PRIVACY_MODE:
        return uid

    head, tail = uid[:4], uid[-4:]
    return f"<task_uuid: {head}...{tail}>"


# TODO: rotating filehandler?
class PipedFileHandler(TimedRotatingFileHandler):
"""A logger that can be used in two processes, but only writes from MainProcess"""
Expand All @@ -48,7 +62,6 @@ def __init__(self, log_pipe: Connection, log_dir: str, log_file: str, **rotating

super().__init__(self.log_path, **rotating_file_handler_kwargs)
self.log_pipe = log_pipe
self.am_I_main = multiprocessing.current_process().name == "MainProcess"

if multiprocessing.current_process().name == "MainProcess":
# start listening for logs from children
Expand All @@ -58,6 +71,10 @@ def __init__(self, log_pipe: Connection, log_dir: str, log_file: str, **rotating
atexit.register(self.wait_for_listener)
self.is_end = False

@property
def am_I_main(self):
    """Tell whether the calling process is the one named 'MainProcess'.

    Evaluated on every access, so the answer reflects the process that is
    asking rather than the process that created the handler.
    """
    current = multiprocessing.current_process()
    return current.name == "MainProcess"

def wait_for_listener(self):
"""Ensure that we wait for thread when we shall exit"""
# well, we won't try to use the logger when waiting for logging to be finished :)
Expand All @@ -70,27 +87,27 @@ def wait_for_listener(self):

def emit(self, record: logging.LogRecord):
"""Emit the message or send it to the main"""
# only write from main process
if self.am_I_main:

# we need to replace the process name manually, otherwise processName is overwritten with 'MainProcess'
if record.processName != "MainProcess":
_formatter = logging.Formatter(
formatter_string.replace("{processName}", record.processName),
style=formatter_style,
datefmt=formatter_date_fmt,
)

# just use the normal formatter for main messages
else:
_formatter = logging.Formatter(formatter_string, style=formatter_style, datefmt=formatter_date_fmt)

self.setFormatter(_formatter)
super().emit(record)
# if we're in a child process, send the record to the pipe to main process
if not self.am_I_main:
self.log_pipe.send(record)
return

# if we're in a child process, send the record to the pipe
# only write from main process
# we need to replace the process name manually, otherwise processName is overwritten with 'MainProcess'
if record.processName != "MainProcess":
_formatter = logging.Formatter(
formatter_string.replace("{processName}", record.processName),
style=formatter_style,
datefmt=formatter_date_fmt,
)

# just use the normal formatter for main messages
else:
self.log_pipe.send(record)
_formatter = logging.Formatter(formatter_string, style=formatter_style, datefmt=formatter_date_fmt)

self.setFormatter(_formatter)
super().emit(record)

def listen_for_logs_from_children(self, pipe_to_listen_to: Connection, wait_before_exit_s: float = 1.0):
"""Tread listening for logs from children and sending them to main process"""
Expand Down
Loading

0 comments on commit 880bc0c

Please sign in to comment.