Skip to content

Commit 57097b0

Browse files
authored
Add llava generate (#4)
added llava generate
1 parent 84b94c6 commit 57097b0

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

lmm_tools/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
BASETEN_API_KEY = "PRxjuebe.VQJQ7rCvswimP5y8GeSmZA03I4zw6dgB"
2+
BASETEN_URL = "https://model-232pg41q.api.baseten.co/production/predict"

lmm_tools/lmm/lmm.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import base64
2+
import requests
23
from abc import ABC, abstractmethod
34
from pathlib import Path
45
from typing import Any, Dict, List, Optional, Union, cast
6+
from lmm_tools.config import BASETEN_API_KEY, BASETEN_URL
57

68

79
def encode_image(image: Union[str, Path]) -> str:
@@ -12,7 +14,7 @@ def encode_image(image: Union[str, Path]) -> str:
1214

1315
class LMM(ABC):
1416
@abstractmethod
15-
def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
17+
def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
1618
pass
1719

1820

@@ -22,8 +24,16 @@ class LLaVALMM(LMM):
2224
def __init__(self, name: str):
2325
self.name = name
2426

25-
def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
26-
raise NotImplementedError("LLaVA LMM not implemented yet")
27+
def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
28+
data = {"prompt": prompt}
29+
if image:
30+
data["image"] = encode_image(image)
31+
res = requests.post(
32+
BASETEN_URL,
33+
headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
34+
json=data,
35+
)
36+
return res.text
2737

2838

2939
class OpenAILMM(LMM):
@@ -35,7 +45,7 @@ def __init__(self, name: str):
3545
self.name = name
3646
self.client = OpenAI()
3747

38-
def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
48+
def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
3949
message: List[Dict[str, Any]] = [
4050
{
4151
"role": "user",

0 commit comments

Comments
 (0)