1
1
import base64
2
+ import requests
2
3
from abc import ABC , abstractmethod
3
4
from pathlib import Path
4
5
from typing import Any , Dict , List , Optional , Union , cast
6
+ from lmm_tools .config import BASETEN_API_KEY , BASETEN_URL
5
7
6
8
7
9
def encode_image (image : Union [str , Path ]) -> str :
@@ -12,7 +14,7 @@ def encode_image(image: Union[str, Path]) -> str:
12
14
13
15
class LMM(ABC):
    """Abstract base class for models that turn a prompt (and optional image) into text.

    Concrete subclasses must implement :meth:`generate`.
    """

    @abstractmethod
    def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
        """Return the model's text response for ``prompt``.

        Args:
            prompt: The text prompt to send to the model.
            image: Optional path to an image accompanying the prompt.
                Defaults to ``None`` so text-only calls can omit it.

        Returns:
            The generated response as a string.
        """
17
19
18
20
@@ -22,8 +24,16 @@ class LLaVALMM(LMM):
22
24
def __init__ (self , name : str ):
23
25
self .name = name
24
26
25
def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
    """POST the prompt (and optional image) to ``BASETEN_URL`` and return the reply text.

    Args:
        prompt: Text prompt, sent under the ``"prompt"`` key of the JSON body.
        image: Optional path to an image. When given, it is passed through
            ``encode_image`` and sent under the ``"image"`` key.

    Returns:
        The raw response body as text.

    Raises:
        requests.HTTPError: If the endpoint responds with a 4xx/5xx status.
        requests.Timeout: If connecting or reading exceeds the timeout below.
    """
    payload = {"prompt": prompt}
    if image:
        payload["image"] = encode_image(image)

    res = requests.post(
        BASETEN_URL,
        headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
        json=payload,
        # requests waits forever by default; bound the call so a stalled
        # endpoint cannot hang the caller. (connect, read) in seconds; the
        # read limit is generous because model inference is presumably slow.
        timeout=(10, 300),
    )
    # Surface HTTP failures instead of returning an error page as model output.
    res.raise_for_status()
    return res.text
27
37
28
38
29
39
class OpenAILMM (LMM ):
@@ -35,7 +45,7 @@ def __init__(self, name: str):
35
45
self .name = name
36
46
self .client = OpenAI ()
37
47
38
- def generate (self , prompt : str , image : Optional [Union [str , Path ]]) -> str :
48
+ def generate (self , prompt : str , image : Optional [Union [str , Path ]] = None ) -> str :
39
49
message : List [Dict [str , Any ]] = [
40
50
{
41
51
"role" : "user" ,
0 commit comments