21
21
# SOFTWARE.
22
22
23
23
24
+ import cv2
25
+ import numpy as np
26
+ import urllib .request
24
27
from pydantic import root_validator
25
28
from typing import Any , Dict , List , Optional , Iterator
26
29
30
+ from cv_bridge import CvBridge
27
31
from action_msgs .msg import GoalStatus
28
32
from llama_msgs .msg import LogitBias
29
33
from llama_msgs .action import GenerateResponse
@@ -39,6 +43,7 @@ class LlamaROS(LLM):
39
43
40
44
namespace : str = "llama"
41
45
llama_client : LlamaClientNode = None
46
+ cv_bridge : CvBridge = CvBridge ()
42
47
43
48
# sampling params
44
49
n_prev : int = 64
@@ -94,11 +99,34 @@ def _llm_type(self) -> str:
94
99
def cancel (self ) -> None :
95
100
self .llama_client .cancel_generate_text ()
96
101
97
- def _create_action_goal (self , prompt : str ) -> GenerateResponse .Result :
102
+ def _create_action_goal (
103
+ self ,
104
+ prompt : str ,
105
+ stop : Optional [List [str ]] = None ,
106
+ image_url : Optional [str ] = None ,
107
+ image : Optional [np .ndarray ] = None ,
108
+ ) -> GenerateResponse .Result :
109
+
98
110
goal = GenerateResponse .Goal ()
99
111
goal .prompt = prompt
100
112
goal .reset = True
101
113
114
+ # load image
115
+ if image_url or image :
116
+
117
+ if image_url and not image :
118
+ req = urllib .request .Request (
119
+ image_url , headers = {"User-Agent" : "Mozilla/5.0" })
120
+ response = urllib .request .urlopen (req )
121
+ arr = np .asarray (bytearray (response .read ()), dtype = np .uint8 )
122
+ image = cv2 .imdecode (arr , - 1 )
123
+
124
+ goal .image = self .cv_bridge .cv2_to_imgmsg (image )
125
+
126
+ # add stop
127
+ if stop :
128
+ goal .stop = stop
129
+
102
130
# sampling params
103
131
goal .sampling_config .n_prev = self .n_prev
104
132
goal .sampling_config .n_probs = self .n_probs
@@ -150,7 +178,7 @@ def _call(
150
178
** kwargs : Any ,
151
179
) -> str :
152
180
153
- goal = self ._create_action_goal (prompt )
181
+ goal = self ._create_action_goal (prompt , stop , ** kwargs )
154
182
155
183
result , status = LlamaClientNode .get_instance (
156
184
self .namespace ).generate_response (goal )
@@ -167,10 +195,14 @@ def _stream(
167
195
** kwargs : Any ,
168
196
) -> Iterator [GenerationChunk ]:
169
197
170
- goal = self ._create_action_goal (prompt )
198
+ goal = self ._create_action_goal (prompt , stop , ** kwargs )
171
199
172
200
for pt in LlamaClientNode .get_instance (
173
201
self .namespace ).generate_response (goal , stream = True ):
202
+
203
+ if run_manager :
204
+ run_manager .on_llm_new_token (pt .text , verbose = self .verbose ,)
205
+
174
206
yield GenerationChunk (text = pt .text )
175
207
176
208
def get_num_tokens (self , text : str ) -> int :
0 commit comments