
Commit ee586f9

Make async

1 parent cfb27f6

3 files changed, +98 −80 lines changed

ollama_manager/app.py

Lines changed: 3 additions & 1 deletion

@@ -5,10 +5,12 @@
 from ollama_manager.commands.delete import delete_model
 from ollama_manager.commands.pull import pull_model
 from ollama_manager.commands.run import run_model
+from ollama_manager.utils import coro


 @click.group()
-def cli():
+@coro
+async def cli():
     pass
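
Note: Click does not run coroutine functions on its own, so the commit bridges the gap with a small adapter. A minimal sketch of the pattern as this diff uses it (coro is defined in ollama_manager/utils/__init__.py, shown below):

    import asyncio
    from functools import wraps

    import click


    def coro(f):
        # Click calls the wrapper synchronously; the wrapper drives
        # the coroutine to completion on a fresh event loop.
        @wraps(f)
        def wrapper(*args, **kwargs):
            return asyncio.run(f(*args, **kwargs))

        return wrapper


    @click.group()
    @coro
    async def cli():
        pass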

ollama_manager/commands/pull.py

Lines changed: 85 additions & 79 deletions

@@ -3,11 +3,11 @@
 import sys

 import click
+import httpx
 import ollama
-import requests
-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup, SoupStrainer

-from ollama_manager.utils import get_session, handle_interaction, make_request
+from ollama_manager.utils import coro, handle_interaction


 def extract_quantization(text):
@@ -92,11 +92,8 @@ def format_bytes(size_bytes: int) -> str:
     return f"{scaled_size:.2f} {_SUFFIXES[magnitude]}"


-def list_remote_model_tags(model_name: str, session: requests.Session):
-    response = make_request(
-        session=session,
-        url=f"https://ollama.com/library/{model_name}/tags",
-    )
+async def list_remote_model_tags(model_name: str, client: httpx.AsyncClient):
+    response = await client.get(f"https://ollama.com/library/{model_name}/tags")
     soup = BeautifulSoup(response.text, "html.parser")

     # Find the span with the specific attribute
@@ -136,21 +133,21 @@ def list_remote_model_tags(model_name: str, session: requests.Session):
     return results


-def list_remote_models(session: requests.Session) -> list[str] | None:
-    response = make_request(session=session, url="https://ollama.com/search")
+async def list_remote_models(client: httpx.AsyncClient) -> list[str] | None:
+    response = await client.get(url="https://ollama.com/search")

-    soup = BeautifulSoup(response.text, "html.parser")
-    # Find the span with the specific attribute
-    # @NOTE: This might change with website updates.
+    title_strainer = SoupStrainer("span", attrs={"x-test-search-response-title": True})
+    soup = BeautifulSoup(response.text, "html.parser", parse_only=title_strainer)
     elements = soup.find_all("span", attrs={"x-test-search-response-title": True})
+
     if not elements:
         return None

     return [element.text.strip() for element in elements]


-def list_hugging_face_models(
-    session: requests.Session, limit: int, query: str
+async def list_hugging_face_models(
+    client: httpx.AsyncClient, limit: int, query: str
 ) -> list[dict[str, str]]:
     BASE_API_ENDPOINT = "https://huggingface.co/api/models"
     params = {
@@ -162,7 +159,7 @@ def list_hugging_face_models(
         "config": False,
         "search": query,
     }
-    res = make_request(session, url=BASE_API_ENDPOINT, params=params)
+    res = await client.get(url=BASE_API_ENDPOINT, params=params)
     hf_response = res.json()
     payload = []

@@ -175,9 +172,12 @@ def list_hugging_face_models(
     return payload


-def list_hugging_face_model_quantization(session: requests.Session, model_name: str):
-    API_ENDPOINT = f"https://huggingface.co/api/models/{model_name}?blobs=true"
-    res = make_request(session=session, url=API_ENDPOINT)
+async def list_hugging_face_model_quantization(
+    client: httpx.AsyncClient, model_name: str
+):
+    res = await client.get(
+        url=f"https://huggingface.co/api/models/{model_name}?blobs=true"
+    )
     hf_response = res.json()
     payload = []
     files = hf_response.get("siblings")
@@ -217,75 +217,81 @@ def list_hugging_face_model_quantization(session: requests.Session, model_name:
     type=int,
     default=20,
 )
-def pull_model(hugging_face: bool, query: str, limit: int):
+@coro
+async def pull_model(hugging_face: bool, query: str, limit: int):
     """
     Pull models from Ollama library:

     https://ollama.dev/search
     """
-    session = get_session()
-    if hugging_face:
-        if not query:
-            query = input("🤗 hf search: ")
-        models = list_hugging_face_models(session, limit, query)
-    else:
-        models = list_remote_models(session)
-
-    if not models:
-        print("❌ No models selected for download")
-        sys.exit(0)
-
-    model_selection = handle_interaction(
-        models, title="📦 Select remote Ollama model\s:\n", multi_select=False
-    )
-    if model_selection:
+    print("Pulling Model....")
+    async with httpx.AsyncClient() as client:
         if hugging_face:
-            model_tags = list_hugging_face_model_quantization(
-                session=session, model_name=model_selection[0]
-            )
+            if not query:
+                query = input("🤗 hf search: ")
+            models = await list_hugging_face_models(client, limit, query)
         else:
-            model_tags = list_remote_model_tags(
-                model_name=model_selection[0], session=session
-            )
-        if not model_tags:
-            print(f"❌ Failed fetching tags for: {model_selection}. Please try again.")
-            sys.exit(1)
+            models = await list_remote_models(client)

-        max_length = max(len(f"{model_selection}:{tag['title']}") for tag in model_tags)
+        if not models:
+            print("❌ No models selected for download")
+            sys.exit(0)

-        if hugging_face:
-            model_name_with_tags = [
-                f"{tag['title']:<{max_length}}{tag['size']:<{max_length}}{tag['updated']}"
-                for tag in model_tags
-            ]
-        else:
-            model_name_with_tags = [
-                f"{model_selection[0]}:{tag['title']:<{max_length + 5}}{tag['size']:<{max_length + 5}}{tag['updated']}"
-                for tag in model_tags
-            ]
-        selected_model_with_tag = handle_interaction(
-            model_name_with_tags, title="🔖 Select tag/quantization:\n"
+        model_selection = handle_interaction(
+            models, title="📦 Select remote Ollama model\s:\n", multi_select=False
        )
-        if not selected_model_with_tag:
-            print("No tag selected for the model")
-            sys.exit(1)
+        if model_selection:
+            if hugging_face:
+                model_tags = await list_hugging_face_model_quantization(
+                    client=client, model_name=model_selection[0]
+                )
+            else:
+                model_tags = await list_remote_model_tags(
+                    model_name=model_selection[0], client=client
+                )
+            if not model_tags:
+                print(
+                    f"❌ Failed fetching tags for: {model_selection}. Please try again."
+                )
+                sys.exit(1)
+
+            max_length = max(
+                len(f"{model_selection}:{tag['title']}") for tag in model_tags
+            )

-        if hugging_face:
-            final_model = (
-                f"hf.co/{model_selection[0]}:{model_name_with_tags[0]}".split()[0]
+            if hugging_face:
+                model_name_with_tags = [
+                    f"{tag['title']:<{max_length}}{tag['size']:<{max_length}}{tag['updated']}"
+                    for tag in model_tags
+                ]
+            else:
+                model_name_with_tags = [
+                    f"{model_selection[0]}:{tag['title']:<{max_length + 5}}{tag['size']:<{max_length + 5}}{tag['updated']}"
+                    for tag in model_tags
+                ]
+            selected_model_with_tag = handle_interaction(
+                model_name_with_tags, title="🔖 Select tag/quantization:\n"
             )
-        else:
-            final_model = selected_model_with_tag[0].split()[0]
-        print(f">>> Pulling model: {final_model}")
-        try:
-            response = ollama.pull(final_model, stream=True)
-            screen_padding = 100
-
-            for data in response:
-                out = f"Status: {data.get('status')} | Completed: {format_bytes(data.get('completed'))}/{format_bytes(data.get('total'))}"
-                print(f"{out:<{screen_padding}}", end="\r", flush=True)
-
-            print(f'\r{" " * screen_padding}\r')  # Clear screen
-            print(f"✅ {final_model} model is ready for use!\n\n>>> olm run\n")
-        except Exception as e:
-            print(f"❌ Failed downloading {final_model}\n{str(e)}")
+            if not selected_model_with_tag:
+                print("No tag selected for the model")
+                sys.exit(1)
+
+            if hugging_face:
+                final_model = (
+                    f"hf.co/{model_selection[0]}:{model_name_with_tags[0]}".split()[0]
+                )
+            else:
+                final_model = selected_model_with_tag[0].split()[0]
+            print(f">>> Pulling model: {final_model}")
+            try:
+                response = ollama.pull(final_model, stream=True)
+                screen_padding = 100
+
+                for data in response:
+                    out = f"Status: {data.get('status')} | Completed: {format_bytes(data.get('completed'))}/{format_bytes(data.get('total'))}"
+                    print(f"{out:<{screen_padding}}", end="\r", flush=True)
+
+                print(f'\r{" " * screen_padding}\r')  # Clear screen
+                print(f"✅ {final_model} model is ready for use!\n\n>>> olm run\n")
+            except Exception as e:
+                print(f"❌ Failed downloading {final_model}\n{str(e)}")

ollama_manager/utils/__init__.py

Lines changed: 10 additions & 0 deletions

@@ -1,10 +1,20 @@
+import asyncio
 import sys
+from functools import wraps

 import ollama
 import requests
 from simple_term_menu import TerminalMenu


+def coro(f):
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        return asyncio.run(f(*args, **kwargs))
+
+    return wrapper
+
+
 def get_session() -> requests.Session:
     session = requests.Session()
     session.headers = {
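
The new coro helper is a generic sync-to-async adapter: @wraps preserves the wrapped function's name and docstring, which Click uses for command names and help text. A small illustrative usage outside Click (the add function is hypothetical, not part of the commit):

    import asyncio

    from ollama_manager.utils import coro


    @coro
    async def add(a: int, b: int) -> int:
        await asyncio.sleep(0)  # stand-in for real async work
        return a + b


    # The decorated coroutine is now callable synchronously.
    assert add(2, 3) == 5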
