Commit a6cec71

Add ability to use image for Reflexion (#17)
* added images for reflexion
* formatting fix
* added more comments, fixed regex
1 parent 04b6411 commit a6cec71
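
The new image argument threads an optional image path through Reflexion and into the underlying models. A minimal usage sketch, assuming the vision_agent.agent and vision_agent.lmm import paths implied by the diff and docstring below, and a local cards.png that exists only for illustration:

    from vision_agent.agent import Reflexion
    from vision_agent.lmm import OpenAILMM  # assumed module path, mirrors va.lmm.OpenAILMM in the docstring

    # Text-only use keeps the default OpenAILLM models.
    agent = Reflexion()
    print(agent("How many tires does a truck have?"))

    # Image use requires LMMs for both the reflector and the action agent.
    agent = Reflexion(self_reflect_model=OpenAILMM(), action_agent=OpenAILMM())
    print(agent("How many hearts are in this image?", image="cards.png"))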

5 files changed: +278 −44 lines changed

tests/test_lmm.py

Lines changed: 39 additions & 0 deletions
@@ -33,6 +33,45 @@ def test_generate_with_mock(openai_lmm_mock):  # noqa: F811
     )


+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_chat_with_mock(openai_lmm_mock):  # noqa: F811
+    lmm = OpenAILMM()
+    response = lmm.chat([{"role": "user", "content": "test prompt"}])
+    assert response == "mocked response"
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_call_with_mock(openai_lmm_mock):  # noqa: F811
+    lmm = OpenAILMM()
+    response = lmm("test prompt")
+    assert response == "mocked response"
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+    response = lmm([{"role": "user", "content": "test prompt"}])
+    assert response == "mocked response"
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+
 @pytest.mark.parametrize(
     "openai_lmm_mock",
     ['{"Parameters": {"prompt": "cat"}}'],

vision_agent/agent/agent.py

Lines changed: 7 additions & 2 deletions
@@ -1,8 +1,13 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from pathlib import Path
+from typing import Dict, List, Optional, Union


 class Agent(ABC):
     @abstractmethod
-    def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
+    def __call__(
+        self,
+        input: Union[List[Dict[str, str]], str],
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
         pass
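
With the widened signature, any Agent subclass now accepts an optional image path alongside the chat input. A hypothetical subclass, purely for illustration (EchoAgent is not part of the repository, and the Agent import path is an assumption):

    from pathlib import Path
    from typing import Dict, List, Optional, Union

    from vision_agent.agent import Agent  # assumed re-export of the ABC above


    class EchoAgent(Agent):
        def __call__(
            self,
            input: Union[List[Dict[str, str]], str],
            image: Optional[Union[str, Path]] = None,
        ) -> str:
            # A real agent would hand `image` to an LMM; this one just echoes.
            text = input if isinstance(input, str) else input[-1]["content"]
            return f"{text} (image={image})"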

vision_agent/agent/reflexion.py

Lines changed: 145 additions & 27 deletions
@@ -1,7 +1,10 @@
+import logging
 import re
+import sys
+from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union

-from vision_agent import LLM, OpenAILLM
+from vision_agent import LLM, LMM, OpenAILLM

 from .agent import Agent
 from .reflexion_prompts import (
@@ -13,20 +16,27 @@
     REFLECTION_HEADER,
 )

+logging.basicConfig(stream=sys.stdout)
+
+_LOGGER = logging.getLogger(__name__)
+

 def format_step(step: str) -> str:
     return step.strip("\n").strip().replace("\n", "")


 def parse_action(input: str) -> Tuple[str, str]:
-    pattern = r"^(\w+)\[(.+)\]$"
-    match = re.match(pattern, input)
+    # Make the pattern slightly less strict; the LMMs are not as good at following
+    # instructions, so they would often fail on the original regex.
+    pattern = r"(\w+)\[(.+)\]"
+    match = re.search(pattern, input)

     if match:
         action_type = match.group(1)
         argument = match.group(2)
         return action_type, argument

+    _LOGGER.error(f"Invalid action: {input}")
     raise ValueError(f"Invalid action: {input}")

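To see why the relaxed pattern matters, here is a small sketch (the model output string is invented; Finish is the action name used by the Reflexion prompts):

    import re

    llm_output = "Thought: a semi truck has 18 wheels. Action: Finish[18]"

    # The old anchored pattern only matched when the whole string was the action.
    assert re.match(r"^(\w+)\[(.+)\]$", llm_output) is None

    # The relaxed pattern still finds the action inside the surrounding text.
    match = re.search(r"(\w+)\[(.+)\]", llm_output)
    assert match is not None and match.groups() == ("Finish", "18")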

@@ -47,41 +57,107 @@ def format_chat(chat: List[Dict[str, str]]) -> str:


 class Reflexion(Agent):
+    r"""This is an implementation of the Reflexion paper https://arxiv.org/abs/2303.11366
+    based on the original implementation https://github.com/noahshinn/reflexion in the
+    hotpotqa folder. There are several differences between this implementation and the
+    original one. Because we do not have instant feedback on whether or not the agent
+    was correct, we use user feedback to determine if the agent was correct. The user
+    feedback is evaluated by the self_reflect_model with a new prompt. We also expand
+    Reflexion to include the ability to use an image as input to the action_agent and
+    the self_reflect_model. Using Reflexion with LMMs may not work well; if it gets it
+    wrong the first time, chances are it can't actually see the thing you want it to
+    see.
+
+    Examples::
+        >>> from vision_agent.agent import Reflexion
+        >>> agent = Reflexion()
+        >>> question = "How many tires does a truck have?"
+        >>> resp = agent(question)
+        >>> print(resp)
+        >>> "18"
+        >>> resp = agent([
+        >>>     {"role": "user", "content": question},
+        >>>     {"role": "assistant", "content": resp},
+        >>>     {"role": "user", "content": "No I mean those regular trucks but where the back tires are double."}
+        >>> ])
+        >>> print(resp)
+        >>> "6"
+        >>> agent = Reflexion(
+        >>>     self_reflect_model=va.lmm.OpenAILMM(),
+        >>>     action_agent=va.lmm.OpenAILMM()
+        >>> )
+        >>> question = "How many hearts are in this image?"
+        >>> resp = agent(question, image="cards.png")
+        >>> print(resp)
+        >>> "6"
+        >>> resp = agent([
+        >>>     {"role": "user", "content": question},
+        >>>     {"role": "assistant", "content": resp},
+        >>>     {"role": "user", "content": "No, please count the hearts on the bottom card."}
+        >>> ], image="cards.png")
+        >>> print(resp)
+        >>> "4"
+    """
+
     def __init__(
         self,
         cot_examples: str = COTQA_SIMPLE6,
         reflect_examples: str = COT_SIMPLE_REFLECTION,
         agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION,
         reflect_prompt: str = COT_REFLECT_INSTRUCTION,
         finsh_prompt: str = CHECK_FINSH,
-        self_reflect_llm: Optional[LLM] = None,
-        action_agent: Optional[Union[Agent, LLM]] = None,
+        self_reflect_model: Optional[Union[LLM, LMM]] = None,
+        action_agent: Optional[Union[Agent, LLM, LMM]] = None,
+        verbose: bool = False,
     ):
         self.agent_prompt = agent_prompt
         self.reflect_prompt = reflect_prompt
         self.finsh_prompt = finsh_prompt
         self.cot_examples = cot_examples
         self.refelct_examples = reflect_examples
         self.reflections: List[str] = []
+        if verbose:
+            _LOGGER.setLevel(logging.INFO)
+
+        if isinstance(self_reflect_model, LLM) and not isinstance(action_agent, LLM):
+            raise ValueError(
+                "If self_reflect_model is an LLM, then action_agent must also be an LLM."
+            )
+        if isinstance(self_reflect_model, LMM) and isinstance(action_agent, LLM):
+            raise ValueError(
+                "If self_reflect_model is an LMM, then action_agent must also be an agent or LMM."
+            )

-        if self_reflect_llm is None:
-            self.self_reflect_llm = OpenAILLM()
-        if action_agent is None:
-            self.action_agent = OpenAILLM()
+        self.self_reflect_model = (
+            OpenAILLM() if self_reflect_model is None else self_reflect_model
+        )
+        self.action_agent = OpenAILLM() if action_agent is None else action_agent

-    def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
+    def __call__(
+        self,
+        input: Union[str, List[Dict[str, str]]],
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-        return self.chat(input)
+        return self.chat(input, image)

-    def chat(self, chat: List[Dict[str, str]]) -> str:
+    def chat(
+        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+    ) -> str:
         if len(chat) == 0 or chat[0]["role"] != "user":
             raise ValueError(
-                f"Invalid chat. Should start with user and then assistant and contain at least one entry {chat}"
+                f"Invalid chat. Should start with user and alternate between user "
+                f"and assistant and contain at least one entry {chat}"
             )
+        if image is not None and isinstance(self.action_agent, LLM):
+            raise ValueError(
+                "If image is provided, then action_agent must be an agent or LMM."
+            )
+
         question = chat[0]["content"]
         if len(chat) == 1:
-            results = self._step(question)
+            results = self._step(question, image=image)
             self.last_scratchpad = results["scratchpad"]
             return results["action_arg"]
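
The two constructor checks, plus the image check in chat, enforce a simple pairing rule: an LLM reflector requires an LLM action agent, and passing an image requires an action agent that can see it. A hypothetical illustration (OpenAILLM is imported exactly as in this file; the vision_agent.lmm path and photo.png are assumptions):

    from vision_agent import OpenAILLM
    from vision_agent.agent import Reflexion
    from vision_agent.lmm import OpenAILMM  # assumed module path

    Reflexion()  # ok: both models default to OpenAILLM, text-only
    Reflexion(self_reflect_model=OpenAILMM(), action_agent=OpenAILMM())  # ok: images allowed
    # Reflexion(self_reflect_model=OpenAILLM(), action_agent=OpenAILMM())  # raises ValueError
    # Reflexion(self_reflect_model=OpenAILMM(), action_agent=OpenAILLM())  # raises ValueError
    # Reflexion()("What is in this image?", image="photo.png")             # raises ValueError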

@@ -91,23 +167,33 @@ def chat(self, chat: List[Dict[str, str]]) -> str:
         self.last_scratchpad += "\nObservation: "
         if is_correct:
             self.last_scratchpad += "Answer is CORRECT"
-            return self.self_reflect_llm(chat)
+            return self.self_reflect_model(chat)
         else:
             self.last_scratchpad += "Answer is INCORRECT"
             chat_context = "The previous conversation was:\n" + chat_str
-            reflections = self.reflect(question, chat_context, self.last_scratchpad)
-            results = self._step(question, reflections)
+            reflections = self.reflect(
+                question, chat_context, self.last_scratchpad, image
+            )
+            _LOGGER.info(f" {reflections}")
+            results = self._step(question, reflections, image=image)
             self.last_scratchpad = results["scratchpad"]
             return results["action_arg"]

-    def _step(self, question: str, reflections: str = "") -> Dict[str, str]:
+    def _step(
+        self,
+        question: str,
+        reflections: str = "",
+        image: Optional[Union[str, Path]] = None,
+    ) -> Dict[str, str]:
         # Think
         scratchpad = "\nThought:"
-        scratchpad += " " + self.prompt_agent(question, reflections, scratchpad)
+        scratchpad += " " + self.prompt_agent(question, reflections, scratchpad, image)
+        _LOGGER.info(f" {scratchpad}")

         # Act
         scratchpad += "\nAction:"
-        action = self.prompt_agent(question, reflections, scratchpad)
+        action = self.prompt_agent(question, reflections, scratchpad, image)
+        _LOGGER.info(f" {action}")
         scratchpad += " " + action
         action_type, argument = parse_action(action)
         return {
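
For orientation, _step builds up a plain-text scratchpad one section at a time and then parses the action out of the model's reply. A hypothetical transcript of a single step (the model text is invented):

    # What the scratchpad might contain after the Think and Act phases:
    scratchpad = "\nThought: A standard semi truck has 18 wheels."
    scratchpad += "\nAction: Finish[18]"

    # parse_action on the action text would then return ("Finish", "18"),
    # and _step reports that argument as action_arg.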
@@ -116,23 +202,55 @@ def _step(self, question: str, reflections: str = "") -> Dict[str, str]:
             "action_arg": argument,
         }

-    def reflect(self, question: str, context: str, scratchpad: str) -> str:
-        self.reflections += [self.prompt_reflection(question, context, scratchpad)]
+    def reflect(
+        self,
+        question: str,
+        context: str,
+        scratchpad: str,
+        image: Optional[Union[str, Path]],
+    ) -> str:
+        self.reflections += [
+            self.prompt_reflection(question, context, scratchpad, image)
+        ]
         return format_reflections(self.reflections)

-    def prompt_agent(self, question: str, reflections: str, scratchpad: str) -> str:
+    def prompt_agent(
+        self,
+        question: str,
+        reflections: str,
+        scratchpad: str,
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
+        if isinstance(self.action_agent, LLM):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad)
+                )
+            )
         return format_step(
             self.action_agent(
-                self._build_agent_prompt(question, reflections, scratchpad)
+                self._build_agent_prompt(question, reflections, scratchpad),
+                image=image,
             )
         )

     def prompt_reflection(
-        self, question: str, context: str = "", scratchpad: str = ""
+        self,
+        question: str,
+        context: str = "",
+        scratchpad: str = "",
+        image: Optional[Union[str, Path]] = None,
     ) -> str:
+        if isinstance(self.self_reflect_model, LLM):
+            return format_step(
+                self.self_reflect_model(
+                    self._build_reflect_prompt(question, context, scratchpad)
+                )
+            )
         return format_step(
-            self.self_reflect_llm(
-                self._build_reflect_prompt(question, context, scratchpad)
+            self.self_reflect_model(
+                self._build_reflect_prompt(question, context, scratchpad),
+                image=image,
             )
         )

vision_agent/data/data.py

Lines changed: 17 additions & 11 deletions
@@ -22,10 +22,12 @@ class DataStore:
     r"""A class to store and manage image data along with its generated metadata from an LMM."""

     def __init__(self, df: pd.DataFrame):
-        r"""Initializes the DataStore with a DataFrame containing image paths and image IDs. If the image IDs are not present, they are generated using UUID4. The DataFrame must contain an 'image_paths' column.
+        r"""Initializes the DataStore with a DataFrame containing image paths and image
+        IDs. If the image IDs are not present, they are generated using UUID4. The
+        DataFrame must contain an 'image_paths' column.

         Args:
-            df (pd.DataFrame): The DataFrame containing "image_paths" and "image_id" columns.
+            df: The DataFrame containing "image_paths" and "image_id" columns.
         """
         self.df = df
         self.lmm: Optional[LMM] = None
@@ -47,12 +49,14 @@ def add_lmm(self, lmm: LMM) -> Self:
     def add_column(
         self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None
     ) -> Self:
-        r"""Adds a new column to the DataFrame containing the generated metadata from the LMM.
+        r"""Adds a new column to the DataFrame containing the generated metadata from
+        the LMM.

         Args:
-            name (str): The name of the column to be added.
-            prompt (str): The prompt to be used to generate the metadata.
-            func (Optional[Callable[[Any], Any]]): A Python function to be applied on the output of `lmm.generate`. Defaults to None.
+            name: The name of the column to be added.
+            prompt: The prompt to be used to generate the metadata.
+            func: A Python function to be applied on the output of `lmm.generate`.
+                Defaults to None.
         """
         if self.lmm is None:
             raise ValueError("LMM not set yet")
@@ -67,10 +71,11 @@ def add_column(
         return self

     def build_index(self, target_col: str) -> Self:
-        r"""This will generate embeddings for the `target_col` and build a searchable index over them, so next time you run search it will search over this index.
+        r"""This will generate embeddings for the `target_col` and build a searchable
+        index over them, so next time you run search it will search over this index.

         Args:
-            target_col (str): The column name containing the data to be indexed."""
+            target_col: The column name containing the data to be indexed."""
         if self.emb is None:
             raise ValueError("Embedder not set yet")

@@ -92,11 +97,12 @@ def get_embeddings(self) -> npt.NDArray[np.float32]:
         )

     def search(self, query: str, top_k: int = 10) -> List[Dict]:
-        r"""Searches the index for the most similar images to the query and returns the top_k results.
+        r"""Searches the index for the most similar images to the query and returns
+        the top_k results.

         Args:
-            query (str): The query to search for.
-            top_k (int, optional): The number of results to return. Defaults to 10."""
+            query: The query to search for.
+            top_k: The number of results to return. Defaults to 10."""
         if self.index is None:
             raise ValueError("Index not built yet")
         if self.emb is None:
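
Taken together, these docstrings describe the usual DataStore flow: wrap a DataFrame of image paths, attach an LMM, generate a metadata column, then index and search it. A hedged sketch, assuming the vision_agent.data and vision_agent.lmm import paths and the listed image files, and leaving out the embedder setup, whose method name is not shown in this diff:

    import pandas as pd

    from vision_agent.data import DataStore  # assumed import path
    from vision_agent.lmm import OpenAILMM   # assumed import path

    ds = DataStore(pd.DataFrame({"image_paths": ["cards.png", "truck.png"]}))
    ds.add_lmm(OpenAILMM())
    ds.add_column(name="description", prompt="Describe the image in one sentence.")

    # With an embedder attached as well, the column can be indexed and searched:
    # ds.build_index("description")
    # ds.search("a playing card with hearts", top_k=5)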
