From c7fc11057332297016fc65410cf190c5f255bded Mon Sep 17 00:00:00 2001 From: Bryce Dubayah Date: Wed, 11 Sep 2024 20:56:37 +0000 Subject: [PATCH] fix lock for concurrent writes --- pyproject.toml | 2 +- truss/templates/trtllm-briton/src/engine.py | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f07866187..5349b6a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.35rc1" +version = "0.9.35rc1dev1" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/templates/trtllm-briton/src/engine.py b/truss/templates/trtllm-briton/src/engine.py index 8020bacce..235479e49 100644 --- a/truss/templates/trtllm-briton/src/engine.py +++ b/truss/templates/trtllm-briton/src/engine.py @@ -3,6 +3,7 @@ import fcntl import hashlib import json +import math import multiprocessing import os import signal @@ -122,8 +123,11 @@ def __init__(self, **kwargs): predict_concurrency = runtime.get("predict_concurrency", 1) cpu_count = os.cpu_count() self._max_fsm_workers = ( - min(predict_concurrency, cpu_count) if cpu_count else predict_concurrency + min(predict_concurrency, math.ceil(cpu_count / 2)) + if cpu_count + else predict_concurrency ) + print(f"Using {self._max_fsm_workers} workers for FSM schema generation") def load(self): if self._loaded: @@ -424,10 +428,15 @@ def worker(vocab_size: int, end_id: int, schema: Dict[str, Any], output_path: Pa vocab_size=vocab_size, eos_token_id=end_id, ) - with open(output_path, "wb") as f: - fcntl.flock(f, fcntl.LOCK_EX) - f.write(states_to_tokens_pb.SerializeToString()) - fcntl.flock(f, fcntl.LOCK_UN) + if not output_path.exists(): + try: + fd = os.open(output_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + with os.fdopen(fd, "wb") as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.write(states_to_tokens_pb.SerializeToString()) + fcntl.flock(f, fcntl.LOCK_UN) + except FileExistsError: + pass def dummy_task():