Skip to content

Commit c7ab22e

Browse files
committed
LangChain wrapper for VLM
1 parent 2dec240 commit c7ab22e

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

llama_ros/llama_ros/langchain/llama_ros.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@
2121
# SOFTWARE.
2222

2323

24+
import cv2
25+
import numpy as np
26+
import urllib.request
2427
from pydantic import root_validator
2528
from typing import Any, Dict, List, Optional, Iterator
2629

30+
from cv_bridge import CvBridge
2731
from action_msgs.msg import GoalStatus
2832
from llama_msgs.msg import LogitBias
2933
from llama_msgs.action import GenerateResponse
@@ -39,6 +43,7 @@ class LlamaROS(LLM):
3943

4044
namespace: str = "llama"
4145
llama_client: LlamaClientNode = None
46+
cv_bridge: CvBridge = CvBridge()
4247

4348
# sampling params
4449
n_prev: int = 64
@@ -94,11 +99,34 @@ def _llm_type(self) -> str:
9499
def cancel(self) -> None:
95100
self.llama_client.cancel_generate_text()
96101

97-
def _create_action_goal(self, prompt: str) -> GenerateResponse.Result:
102+
def _create_action_goal(
103+
self,
104+
prompt: str,
105+
stop: Optional[List[str]] = None,
106+
image_url: Optional[str] = None,
107+
image: Optional[np.ndarray] = None,
108+
) -> GenerateResponse.Result:
109+
98110
goal = GenerateResponse.Goal()
99111
goal.prompt = prompt
100112
goal.reset = True
101113

114+
# load image
115+
if image_url or image:
116+
117+
if image_url and not image:
118+
req = urllib.request.Request(
119+
image_url, headers={"User-Agent": "Mozilla/5.0"})
120+
response = urllib.request.urlopen(req)
121+
arr = np.asarray(bytearray(response.read()), dtype=np.uint8)
122+
image = cv2.imdecode(arr, -1)
123+
124+
goal.image = self.cv_bridge.cv2_to_imgmsg(image)
125+
126+
# add stop
127+
if stop:
128+
goal.stop = stop
129+
102130
# sampling params
103131
goal.sampling_config.n_prev = self.n_prev
104132
goal.sampling_config.n_probs = self.n_probs
@@ -150,7 +178,7 @@ def _call(
150178
**kwargs: Any,
151179
) -> str:
152180

153-
goal = self._create_action_goal(prompt)
181+
goal = self._create_action_goal(prompt, stop, **kwargs)
154182

155183
result, status = LlamaClientNode.get_instance(
156184
self.namespace).generate_response(goal)
@@ -167,10 +195,14 @@ def _stream(
167195
**kwargs: Any,
168196
) -> Iterator[GenerationChunk]:
169197

170-
goal = self._create_action_goal(prompt)
198+
goal = self._create_action_goal(prompt, stop, **kwargs)
171199

172200
for pt in LlamaClientNode.get_instance(
173201
self.namespace).generate_response(goal, stream=True):
202+
203+
if run_manager:
204+
run_manager.on_llm_new_token(pt.text, verbose=self.verbose,)
205+
174206
yield GenerationChunk(text=pt.text)
175207

176208
def get_num_tokens(self, text: str) -> int:

0 commit comments

Comments (0)