Merge pull request #29 from igorbenav/workflow-fix
fix number of passed parameters bug
igorbenav authored Dec 16, 2024
2 parents d52d1d2 + 3bfc69b commit ddc49b7
Showing 9 changed files with 49 additions and 28 deletions.
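
Only clientai/agent/core/workflow.py changes behavior; the remaining eight files are formatting-only (Protocol stub bodies moved onto their own lines, plus one re-wrapped environment-variable assignment).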
6 changes: 4 additions & 2 deletions clientai/_typing.py
@@ -26,7 +26,8 @@ def generate_text(
         return_full_response: bool = False,
         stream: bool = False,
         **kwargs: Any,
-    ) -> GenericResponse[R, T, S]: ...
+    ) -> GenericResponse[R, T, S]:
+        ...
 
     def chat(
         self,
@@ -35,7 +36,8 @@ def chat(
         return_full_response: bool = False,
         stream: bool = False,
         **kwargs: Any,
-    ) -> GenericResponse[R, T, S]: ...
+    ) -> GenericResponse[R, T, S]:
+        ...
 
 
 P = TypeVar("P", bound=AIProviderProtocol)
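
This hunk, like most of the files below, is formatting-only: the `...` placeholder body of each Protocol stub moves from the signature line onto its own line. Both spellings define the same stub; a minimal illustration with a hypothetical Greeter protocol:

from typing import Protocol


class Greeter(Protocol):  # hypothetical protocol, for illustration only
    def greet(self, name: str) -> str: ...  # stub body on the signature line


class Greeter2(Protocol):  # the formatting this commit switches to
    def greet(self, name: str) -> str:
        ...  # stub body on its own line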
32 changes: 20 additions & 12 deletions clientai/agent/core/workflow.py
@@ -388,23 +388,22 @@ def _execute_step(
         """
         try:
             if param_count == 0:
-                return engine.execute_step(step, agent, stream=current_stream)
+                return engine.execute_step(step, stream=current_stream)
             elif param_count == 1:
                 input_data = (
                     agent.context.original_input
                     if len(agent.context.last_results) == 0
                     else last_result
                 )
                 return engine.execute_step(
-                    step, agent, input_data, stream=current_stream
+                    step, input_data, stream=current_stream
                 )
             else:
                 previous_results = self._get_previous_results(
-                    agent, param_count
+                    agent, param_count - 1
                 )
-                return engine.execute_step(
-                    step, agent, *previous_results, stream=current_stream
-                )
+                args = [last_result] + previous_results
+                return engine.execute_step(step, *args, stream=current_stream)
         except Exception as e:
             raise StepError(
                 f"Failed to execute step '{step.name}': {str(e)}"
@@ -524,10 +523,13 @@ def execute(
                         is_intermediate_step=True,
                     )
 
-                    result = engine.execute_step(
-                        step,
-                        last_result,
-                        stream=current_stream,
+                    result = self._execute_step(
+                        step=step,
+                        agent=agent,
+                        last_result=last_result,
+                        param_count=param_count,
+                        current_stream=current_stream,
+                        engine=engine,
                     )
 
                     step_result = self._handle_step_result(
@@ -537,6 +539,7 @@ def execute(
                     last_result = step_result
 
                     logger.debug(f"Step {step.name} completed")
+
                 except (StepError, ValueError) as e:
                     logger.error(f"Error in step '{step.name}': {e}")
                     if step.config.required:
@@ -562,8 +565,13 @@ def execute(
                 is_intermediate_step=False,
             )
 
-            result = engine.execute_step(
-                final_step, last_result, stream=current_stream
+            result = self._execute_step(
+                step=final_step,
+                agent=agent,
+                last_result=last_result,
+                param_count=param_count,
+                current_stream=current_stream,
+                engine=engine,
             )
 
             if not current_stream:
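
The hunks above carry the actual bug fix: _execute_step no longer passes agent positionally to engine.execute_step, and a step that declares n parameters now receives the last result plus the n - 1 results before it, instead of n previous results plus the agent. A minimal sketch of the corrected argument selection, with simplified stand-ins (build_step_args and the plain results list are hypothetical; the real code works with Step, Agent, and engine objects):

from typing import Any, List


def build_step_args(
    param_count: int, original_input: Any, results: List[Any]
) -> List[Any]:
    """Pick the positional inputs for a step (results are newest first)."""
    if param_count == 0:
        return []  # step takes no input
    if param_count == 1:
        # first step sees the original input; later steps see the last result
        return [original_input] if not results else [results[0]]
    # n-parameter step: last result plus the n - 1 results before it
    # (the bug also passed the agent, giving the step n + 1 arguments)
    return [results[0]] + results[1:param_count]


# A 2-parameter step after three completed steps gets the two newest results:
assert build_step_args(2, "query", ["r3", "r2", "r1"]) == ["r3", "r2"]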
6 changes: 4 additions & 2 deletions clientai/agent/steps/decorators.py
@@ -601,11 +601,13 @@ def summarize_data(self, data: str) -> str:
 @overload
 def run(
     *, description: Optional[str] = None
-) -> Callable[[Callable[..., T]], RunFunction]: ...
+) -> Callable[[Callable[..., T]], RunFunction]:
+    ...
 
 
 @overload
-def run(func: Callable[..., T]) -> RunFunction: ...
+def run(func: Callable[..., T]) -> RunFunction:
+    ...
 
 
 def run(
3 changes: 2 additions & 1 deletion clientai/agent/types/__init__.py
@@ -32,7 +32,8 @@ class ToolProtocol(Protocol):
     name: str
     description: str
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        ...
 
 
 __all__ = [
3 changes: 2 additions & 1 deletion clientai/groq/_typing.py
@@ -88,7 +88,8 @@ def create(
         model: str,
         stream: bool = False,
         **kwargs: Any,
-    ) -> Union[GroqResponse, Iterator[GroqStreamResponse]]: ...
+    ) -> Union[GroqResponse, Iterator[GroqStreamResponse]]:
+        ...
 
 
 class GroqChatProtocol(Protocol):
6 changes: 4 additions & 2 deletions clientai/ollama/_typing.py
@@ -59,7 +59,8 @@ def generate(
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         **kwargs: Any,
-    ) -> Union[OllamaResponse, Iterator[OllamaStreamResponse]]: ...
+    ) -> Union[OllamaResponse, Iterator[OllamaStreamResponse]]:
+        ...
 
     def chat(
         self,
@@ -69,7 +70,8 @@ def chat(
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         **kwargs: Any,
-    ) -> Union[OllamaChatResponse, Iterator[OllamaStreamResponse]]: ...
+    ) -> Union[OllamaChatResponse, Iterator[OllamaStreamResponse]]:
+        ...
 
 
 Client = "ollama.Client"
6 changes: 3 additions & 3 deletions clientai/ollama/manager/platform_info.py
@@ -273,9 +273,9 @@ def get_environment(self, config: OllamaServerConfig) -> Dict[str, str]:
             )
             env["GPU_DEVICE_ORDINAL"] = ",".join(map(str, devices))
             if config.gpu_memory_fraction is not None:
-                env["GPU_MAX_HEAP_SIZE"] = (
-                    f"{int(config.gpu_memory_fraction * 100)}%"
-                )
+                env[
+                    "GPU_MAX_HEAP_SIZE"
+                ] = f"{int(config.gpu_memory_fraction * 100)}%"
 
         elif self.gpu_vendor == GPUVendor.APPLE:
             if config.gpu_layers is not None:
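
The platform_info.py hunk only re-wraps the assignment; the computed value is identical. For example, assuming a gpu_memory_fraction of 0.8:

env = {}
gpu_memory_fraction = 0.8  # stand-in for config.gpu_memory_fraction
env["GPU_MAX_HEAP_SIZE"] = f"{int(gpu_memory_fraction * 100)}%"
assert env["GPU_MAX_HEAP_SIZE"] == "80%"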
6 changes: 4 additions & 2 deletions clientai/openai/_typing.py
@@ -71,7 +71,8 @@ class OpenAIStreamResponse:
 class OpenAIChatCompletionProtocol(Protocol):
     def create(
         self, **kwargs: Any
-    ) -> Union[OpenAIResponse, Iterator[OpenAIStreamResponse]]: ...
+    ) -> Union[OpenAIResponse, Iterator[OpenAIStreamResponse]]:
+        ...
 
 
 class OpenAIChatProtocol(Protocol):
@@ -89,7 +90,8 @@ def create(
         messages: List[Message],
         stream: bool = False,
         **kwargs: Any,
-    ) -> Union[OpenAIResponse, OpenAIStreamResponse]: ...
+    ) -> Union[OpenAIResponse, OpenAIStreamResponse]:
+        ...
 
 
 OpenAIProvider = Any
9 changes: 6 additions & 3 deletions clientai/replicate/_typing.py
@@ -12,7 +12,8 @@ class ReplicatePredictionProtocol(Protocol):
     error: Optional[str]
     output: Any
 
-    def stream(self) -> Iterator[Any]: ...
+    def stream(self) -> Iterator[Any]:
+        ...
 
 
 ReplicatePrediction = ReplicatePredictionProtocol
@@ -59,10 +60,12 @@ class ReplicateResponse(TypedDict):
 
 class ReplicatePredictionsProtocol(Protocol):
     @staticmethod
-    def create(**kwargs: Any) -> ReplicatePredictionProtocol: ...
+    def create(**kwargs: Any) -> ReplicatePredictionProtocol:
+        ...
 
     @staticmethod
-    def get(id: str) -> ReplicatePredictionProtocol: ...
+    def get(id: str) -> ReplicatePredictionProtocol:
+        ...
 
 
 class ReplicateClientProtocol(Protocol):
