
Commit f53056d

more compact operator formatting

1 parent 14f2846 · commit f53056d

29 files changed: +248 -255 lines

.style.yapf (+6 -1)
@@ -11,4 +11,9 @@ indent_dictionary_value = True
 allow_multiline_dictionary_keys = True
 each_dict_entry_on_separate_line = False
 allow_multiline_lambdas = True
-blank_line_before_nested_class_or_def = False
+blank_line_before_nested_class_or_def = False
+arithmetic_precedence_indication = True
+no_spaces_around_selected_binary_operators = "*,/"
+coalesce_brackets = True
+space_between_ending_comma_and_closing_bracket = False
+split_before_expression_after_opening_paren = False

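For orientation, here is a minimal before/after sketch of what the new knobs do (the snippet itself is hypothetical; the option names are exactly those added above). no_spaces_around_selected_binary_operators = "*,/" drops the spaces around * and /, arithmetic_precedence_indication = True keeps spaces around lower-precedence operators such as + so precedence stays visible, and space_between_ending_comma_and_closing_bracket = False turns (1, ) into (1,):

# Hypothetical snippet, as the old style formatted it:
speed = int(downloaded_bytes / elapsed_time)
offset = n_layers // 2 + 1
shape = (1, )

# After `yapf -i` with the settings added in this commit:
speed = int(downloaded_bytes/elapsed_time)
offset = n_layers//2 + 1
shape = (1,)
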
exo/api/chatgpt_api.py (+3 -3)
@@ -158,7 +158,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
     self.on_chat_completion_request = on_chat_completion_request
-    self.app = web.Application(client_max_size=100 * 1024 * 1024) # 100MB to support image upload
+    self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
@@ -171,7 +171,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     )
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
-    self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
+    self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
@@ -186,7 +186,7 @@ async def middleware(request):
     return middleware

   async def handle_root(self, request):
-    return web.FileResponse(self.static_dir / "index.html")
+    return web.FileResponse(self.static_dir/"index.html")

   async def handle_post_chat_token_encode(self, request):
     data = await request.json()

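Aside: client_max_size here is aiohttp's cap on the request body size, and the touched expression is plain arithmetic, so its value is unchanged by the respacing. A minimal standalone sketch of the same call (outside exo):

from aiohttp import web

# 100*1024*1024 == 104857600 bytes (100 MiB), with or without spaces around "*"
app = web.Application(client_max_size=100*1024*1024)
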
exo/download/hf/hf_helpers.py (+21 -21)
@@ -62,12 +62,12 @@ def _add_wildcard_to_directories(pattern: str) -> str:

 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
-  return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+  return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))


 async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
-  token_path = get_hf_home() / "token"
+  token_path = get_hf_home()/"token"
   if await aios.path.exists(token_path):
     async with aiofiles.open(token_path, 'r') as f:
       return (await f.read()).strip()
@@ -85,7 +85,7 @@ async def get_auth_headers():
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = repo_id.replace("/", "--")
-  return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"


 async def fetch_file_list(session, repo_id, revision, path=""):
@@ -181,9 +181,9 @@ async def download_file(
         downloaded_this_session += len(chunk)
         if progress_callback and total_size:
           elapsed_time = (datetime.now() - start_time).total_seconds()
-          speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
+          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
           remaining_size = total_size - downloaded_size
-          eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
           status = "in_progress" if downloaded_size < total_size else "complete"
           if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
           await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
@@ -199,17 +199,17 @@ async def download_repo_files(
   max_parallel_downloads: int = 4
 ) -> Path:
   repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root / "refs"
-  snapshots_dir = repo_root / "snapshots"
-  cachedreqs_dir = repo_root / "cachedreqs"
+  refs_dir = repo_root/"refs"
+  snapshots_dir = repo_root/"snapshots"
+  cachedreqs_dir = repo_root/"cachedreqs"

   # Ensure directories exist
   await aios.makedirs(refs_dir, exist_ok=True)
   await aios.makedirs(snapshots_dir, exist_ok=True)
   await aios.makedirs(cachedreqs_dir, exist_ok=True)

   # Check if we have a cached commit hash
-  refs_file = refs_dir / revision
+  refs_file = refs_dir/revision
   if await aios.path.exists(refs_file):
     async with aiofiles.open(refs_file, 'r') as f:
       commit_hash = (await f.read()).strip()
@@ -230,13 +230,13 @@ async def download_repo_files(
       await f.write(commit_hash)

   # Set up the snapshot directory
-  snapshot_dir = snapshots_dir / commit_hash
+  snapshot_dir = snapshots_dir/commit_hash
   await aios.makedirs(snapshot_dir, exist_ok=True)

   # Set up the cached file list directory
-  cached_file_list_dir = cachedreqs_dir / commit_hash
+  cached_file_list_dir = cachedreqs_dir/commit_hash
   await aios.makedirs(cached_file_list_dir, exist_ok=True)
-  cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
+  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"

   async with aiohttp.ClientSession() as session:
     # Check if we have a cached file list
@@ -261,17 +261,17 @@ async def download_repo_files(
     start_time = datetime.now()

     async def download_with_progress(file_info, progress_state):
-      local_path = snapshot_dir / file_info["path"]
+      local_path = snapshot_dir/file_info["path"]
       if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
         if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
         progress_state['completed_files'] += 1
         progress_state['downloaded_bytes'] += file_info["size"]
         file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
         if progress_callback:
           elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
           remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
           status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
           await progress_callback(
             RepoProgressEvent(
@@ -287,9 +287,9 @@ async def file_progress_callback(event: RepoFileProgressEvent):
     file_progress[event.file_path] = event
     if progress_callback:
       elapsed_time = (datetime.now() - start_time).total_seconds()
-      overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+      overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
       remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-      overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+      overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
       status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
       await progress_callback(
         RepoProgressEvent(
@@ -305,9 +305,9 @@ async def file_progress_callback(event: RepoFileProgressEvent):
       ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
       if progress_callback:
         elapsed_time = (datetime.now() - start_time).total_seconds()
-        overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
         remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-        overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
         status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
         await progress_callback(
           RepoProgressEvent(
@@ -347,11 +347,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[

   # Check if the file exists
   repo_root = get_repo_root(repo_id)
-  snapshot_dir = repo_root / "snapshots"
+  snapshot_dir = repo_root/"snapshots"
   index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)

   if index_file:
-    index_file_path = snapshot_dir / index_file
+    index_file_path = snapshot_dir/index_file
     if await aios.path.exists(index_file_path):
       async with aiofiles.open(index_file_path, 'r') as f:
         index_data = json.loads(await f.read())

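Note that every "/" touched in this file is pathlib's join operator (Path.__truediv__), not division, so dropping the surrounding spaces is purely cosmetic. A quick standalone check (illustrative paths, not from the repo):

from pathlib import Path

# Both spellings call Path.__truediv__ and build the identical path object
assert Path.home() / ".cache" / "huggingface" == Path.home()/".cache"/"huggingface"
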
exo/download/hf/hf_shard_download.py (+1 -1)
@@ -22,7 +22,7 @@ async def ensure_shard(self, shard: Shard) -> Path:
       return self.completed_downloads[shard]
     if self.quick_check:
       repo_root = get_repo_root(shard.model_id)
-      snapshots_dir = repo_root / "snapshots"
+      snapshots_dir = repo_root/"snapshots"
       if snapshots_dir.exists():
         most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
         return most_recent_dir

exo/helpers.py (+1 -1)
@@ -169,7 +169,7 @@ def is_valid_uuid(val):


 def get_or_create_node_id():
-  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
+  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id"
   try:
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:

exo/inference/debug_inference_engine.py (+1 -1)
@@ -10,7 +10,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   from exo.inference.tinygrad.inference import Tokenizer
   from pathlib import Path

-  _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
+  _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))

   prompt = "In a single word only, what is the last name of the president of the United States? "
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)

exo/inference/mlx/models/deepseek_v2.py (+1 -1)
@@ -59,7 +59,7 @@ def __call__(
       mask = mask.astype(h.dtype)

     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)

     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

exo/inference/mlx/models/llama.py (+1 -1)
@@ -58,7 +58,7 @@ def __call__(
       mask = create_attention_mask(h, cache)

     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)

     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, cache=c)

exo/inference/mlx/models/llava.py (+12 -12)
@@ -74,8 +74,8 @@ def __call__(self, queries, keys, values, mask=None):
     keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
     values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

-    scale = math.sqrt(1 / queries.shape[-1])
-    scores = (queries * scale) @ keys
+    scale = math.sqrt(1/queries.shape[-1])
+    scores = (queries*scale) @ keys
     if mask is not None:
       scores = scores + mask.astype(scores.dtype)
     scores = mx.softmax(scores, axis=-1)
@@ -129,7 +129,7 @@ def __init__(self, config: VisionConfig):
     self.image_size = config.image_size
     self.patch_size = config.patch_size

-    self.class_embedding = mx.zeros((config.hidden_size, ))
+    self.class_embedding = mx.zeros((config.hidden_size,))

     self.patch_embedding = nn.Conv2d(
       in_channels=config.num_channels,
@@ -170,12 +170,12 @@ def __call__(
     x = self.embeddings(x)
     x = self.pre_layrnorm(x)

-    encoder_states = (x, ) if output_hidden_states else None
+    encoder_states = (x,) if output_hidden_states else None

     for l in self.encoder.layers:
       x = l(x, mask=None)
       if output_hidden_states:
-        encoder_states = encoder_states + (x, )
+        encoder_states = encoder_states + (x,)

     pooler_output = self.post_layernorm(x[:, 0, :])
     return pooler_output, x, encoder_states
@@ -263,12 +263,12 @@ def __init__(self, config: TextConfig):
     head_dim = config.hidden_size // n_heads
     self.scale = head_dim**-0.5

-    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
-    self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-    self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-    self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+    self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False)
+    self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False)

-    rope_scale = (1 / config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
+    rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
     self.rope = nn.RoPE(
       head_dim,
       traditional=config.rope_traditional,
@@ -312,7 +312,7 @@ def __init__(self, dim, hidden_dim):
     self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

   def __call__(self, x) -> mx.array:
-    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+    return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x))


 class TransformerBlock(nn.Module):
@@ -382,7 +382,7 @@ def __call__(
       mask = mask.astype(h.dtype)

     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)

     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

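A note on the attention hunk above: scale = math.sqrt(1/queries.shape[-1]) is the usual 1/sqrt(head_dim) attention scaling, and multiplying queries by it before the matmul is equivalent to dividing the scores afterwards. A tiny standalone check (head_dim chosen arbitrarily):

import math

head_dim = 64
# sqrt(1/d) == 1/sqrt(d), so (q*scale) @ k matches (q @ k)/sqrt(d)
assert abs(math.sqrt(1/head_dim) - 1/math.sqrt(head_dim)) < 1e-12
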
exo/inference/mlx/sharded_model.py (+2 -2)
@@ -38,7 +38,7 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
     if top_p > 0 and top_p < 1.0:
       token = top_p_sampling(logits, top_p, temp)
     else:
-      token = mx.random.categorical(logits * (1 / temp))
+      token = mx.random.categorical(logits*(1/temp))

     return token

@@ -74,7 +74,7 @@ def __call__(
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)

   def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
     if self.max_kv_size is not None:
       cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
     else:

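The sample() hunk above is plain temperature sampling: multiplying logits by 1/temp sharpens the distribution for temp < 1 and flattens it for temp > 1 before a token is drawn. A minimal standalone sketch (logit values made up):

import mlx.core as mx

logits = mx.array([2.0, 1.0, 0.1])
temp = 0.7
# logits*(1/temp) rescales the distribution; as temp -> 0 this approaches argmax
token = mx.random.categorical(logits*(1/temp))
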
exo/inference/mlx/sharded_utils.py (+3 -3)
@@ -60,7 +60,7 @@ def _get_classes(config: dict):


 def load_config(model_path: Path) -> dict:
   try:
-    with open(model_path / "config.json", "r") as f:
+    with open(model_path/"config.json", "r") as f:
       config = json.load(f)
   except FileNotFoundError:
     logging.error(f"Config file not found in {model_path}")
@@ -103,11 +103,11 @@ def load_model_shard(
     "n_layers": shard.n_layers,
   }

-  weight_files = glob.glob(str(model_path / "model*.safetensors"))
+  weight_files = glob.glob(str(model_path/"model*.safetensors"))

   if not weight_files:
     # Try weight for back-compat
-    weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+    weight_files = glob.glob(str(model_path/"weight*.safetensors"))

   if not weight_files:
     logging.error(f"No safetensors found in {model_path}")

exo/inference/mlx/test_sharded_model.py (+1 -1)
@@ -38,7 +38,7 @@ def __call__(self, x, cache=None):
 n_layers = 5
 shard1 = Shard("test", 0, n_layers // 2, n_layers)
 sharded_model1 = DummyModel(shard1)
-shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
+shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers)
 sharded_model2 = DummyModel(shard2)

 model.load_weights("./test_weights.npz")

exo/inference/tinygrad/inference.py (+5 -5)
@@ -33,9 +33,9 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No

   # load weights
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
-    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
-    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
+    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
     weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
@@ -60,7 +60,7 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_s
     toks = self.tokenizer.encode(prompt)
     h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()

-    if h.shape == (1, ):
+    if h.shape == (1,):
       start_pos += len(toks)
       start_pos += 1
       n_captured_toks = 0
@@ -76,7 +76,7 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr

     h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()

-    if h.shape == (1, ):
+    if h.shape == (1,):
       start_pos += n_captured_toks
       start_pos += 1
       n_captured_toks = 0
