diff --git a/alfred/client/client.py b/alfred/client/client.py
index e7b173b..2117471 100644
--- a/alfred/client/client.py
+++ b/alfred/client/client.py
@@ -32,6 +32,7 @@ def __init__(
         end_point: Optional[str] = None,
         local_path: Optional[str] = None,
         ssh_tunnel: bool = False,
+        ssh_pk: str = "~/.ssh/id_rsa",
         ssh_node: Optional[str] = None,
         cache: Optional[Cache] = None,
         **kwargs: Any,
@@ -48,6 +49,8 @@ def __init__(
        :param local_path: (optional) The local path of the model. (e.g. "/home/user/.cache/model")
        :param ssh_tunnel: Whether to establish an SSH tunnel to the end point.
        :type ssh_tunnel: bool
+       :param ssh_pk: (optional) Path to the SSH private key used for the tunnel, defaults to "~/.ssh/id_rsa"
+       :type ssh_pk: str
        :param ssh_node: (optional) The final SSH node to establish the SSH tunnel. (e.g. gpu node on a cluster with login node as jump)
        :type ssh_node: str
        :param cache: (optional) The cache to use. (e.g. "SQLite", "Dummy")
@@ -110,6 +113,7 @@ def __init__(
                 remote_port=self.end_point_port,
                 remote_node_address=ssh_node,
                 username=user_name,
+                key_file=ssh_pk,
             )
             tunnel.start()
diff --git a/alfred/client/ssh/sshtunnel.py b/alfred/client/ssh/sshtunnel.py
index d036b6d..279e332 100644
--- a/alfred/client/ssh/sshtunnel.py
+++ b/alfred/client/ssh/sshtunnel.py
@@ -3,6 +3,7 @@
 from typing import Optional, Union, Callable
 
 import paramiko
+import os
 
 from alfred.client.ssh.utils import port_finder, forward_tunnel
@@ -39,6 +40,7 @@ def __init__(
         username: Optional[str] = None,
         remote_node_address: Optional[str] = None,
         remote_bind_port: Optional[Union[int, str]] = 443,
+        key_file: str = "~/.ssh/id_rsa",
         handler: Callable = None,
     ):
         """
@@ -56,6 +58,8 @@ def __init__(
         :type remote_node_address: str
         :param remote_bind_port: (optional) The remote bind port to connect to, defaults to 443
         :type remote_bind_port: Optional[Union[int, str]], optional
+        :param key_file: (optional) Path to the SSH private key file, defaults to "~/.ssh/id_rsa"
+        :type key_file: str
         :param handler: The handler for interactive authentication, defaults to adaptive handler
         :type handler: Callable, optional
         """
@@ -69,6 +73,10 @@ def __init__(
         self.username = username or input("Username: ")
         self.remote_node_address = remote_node_address
         self.remote_bind_port = remote_bind_port
+
+        # use the key file if it exists; otherwise leave it as None and rely on interactive authentication
+        self.key_file = paramiko.RSAKey.from_private_key_file(os.path.expanduser(key_file)) if os.path.isfile(os.path.expanduser(key_file)) else None
+
         self.handler = handler or self.adaptive_handler
 
     def start(self):
@@ -89,19 +97,26 @@ def _start(self):
         self.client = paramiko.SSHClient()
         self.client.load_system_host_keys()
-        self.client.set_missing_host_key_policy(paramiko.WarningPolicy())
+        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
 
         try:
-            self.client.connect(self.remote_host, username=self.username)
+            self.client.connect(
+                self.remote_host,
+                username=self.username,
+                pkey=self.key_file,
+                look_for_keys=False
+            )
         except paramiko.ssh_exception.SSHException:
             pass
 
-        try:
-            self.client.get_transport().auth_interactive(
-                username=self.username, handler=self.handler)
-        except paramiko.ssh_exception.AuthenticationException:
-            logger.error("Wrong Password, Please restart the Tunnel")
-            raise paramiko.ssh_exception.AuthenticationException
+        if not self.client.get_transport().is_authenticated():
+            try:
+                self.client.get_transport().auth_interactive(
+                    username=self.username, handler=self.handler)
+            except paramiko.ssh_exception.AuthenticationException:
+                logger.error("Wrong Password, Please restart the Tunnel")
+                raise paramiko.ssh_exception.AuthenticationException
+        logger.log(logging.INFO, f"Connected 
to {self.remote_host} @ port 22") port = self.client.get_transport().sock.getsockname()[1] diff --git a/alfred/fm/ai21.py b/alfred/fm/ai21.py index f4e3872..817a61d 100644 --- a/alfred/fm/ai21.py +++ b/alfred/fm/ai21.py @@ -9,9 +9,9 @@ logger = logging.getLogger(__name__) AI21_MODELS = ( - "j1-large", - "j1-grande", - "j1-jumbo", + "j1-light", + "j1-mid", + "j1-ultra", ) @@ -26,7 +26,7 @@ def _ai21_query( query_string: str, temperature: float = 0.0, max_tokens: int = 10, - model: str = "j1-large", + model: str = "j1-mid", ) -> str: """ Run a single query through the foundation model @@ -62,7 +62,7 @@ def _ai21_query( def __init__( self, - model_string: str = "j1-large", + model_string: str = "j1-mid", api_key: Optional[str] = None, ): """ diff --git a/alfred/fm/anthropic.py b/alfred/fm/anthropic.py index 8ed78f2..5c95fb4 100644 --- a/alfred/fm/anthropic.py +++ b/alfred/fm/anthropic.py @@ -13,11 +13,10 @@ logger = logging.getLogger(__name__) ANTHROPIC_MODELS = ( - "claude-v1", - "claude-v1.0", - "claude-v1.2", - "claude-instant-v1", - "claude-instant-v1.0", + "claude-instant-1", + "claude-instant-1.2", + "claude-2", + "claude-2.0", ) try: @@ -39,7 +38,7 @@ def _anthropic_query( query: Union[str, List], temperature: float = 0.0, max_tokens: int = 3, - model: str = "claude-v1", + model: str = "claude-instant-1", **kwargs: Any, ) -> str: """ @@ -82,7 +81,7 @@ def _anthropic_query( return response["completion"] def __init__(self, - model_string: str = "claude-v1", + model_string: str = "claude-instant-1", api_key: Optional[str] = None): """ Initialize the Anthropic API wrapper. diff --git a/docs/alfred/client/client.md b/docs/alfred/client/client.md index aa8f394..05e26db 100644 --- a/docs/alfred/client/client.md +++ b/docs/alfred/client/client.md @@ -38,6 +38,7 @@ class Client: end_point: Optional[str] = None, local_path: Optional[str] = None, ssh_tunnel: bool = False, + ssh_pk: str = "~/.ssh/id_rsa", ssh_node: Optional[str] = None, cache: Optional[Cache] = None, **kwargs: Any @@ -47,7 +48,7 @@ class Client: ### Client().__call__ -[Show source in client.py:270](../../../alfred/client/client.py#L270) +[Show source in client.py:274](../../../alfred/client/client.py#L274) __call__() function to run the model on the queries. Equivalent to run() function. @@ -75,7 +76,7 @@ def __call__( ### Client().calibrate -[Show source in client.py:285](../../../alfred/client/client.py#L285) +[Show source in client.py:289](../../../alfred/client/client.py#L289) calibrate are used to calibrate foundation models contextually given the template. A voter class may be passed to calibrate the model with a specific voter. @@ -120,7 +121,7 @@ def calibrate( ### Client().chat -[Show source in client.py:387](../../../alfred/client/client.py#L387) +[Show source in client.py:391](../../../alfred/client/client.py#L391) Chat with the model APIs. Currently, Alfred supports Chat APIs from Anthropic and OpenAI @@ -139,7 +140,7 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any): ### Client().encode -[Show source in client.py:361](../../../alfred/client/client.py#L361) +[Show source in client.py:365](../../../alfred/client/client.py#L365) embed() function to embed the queries. @@ -162,7 +163,7 @@ def encode( ### Client().generate -[Show source in client.py:229](../../../alfred/client/client.py#L229) +[Show source in client.py:233](../../../alfred/client/client.py#L233) Wrapper function to generate the response(s) from the model. 
(For completion) @@ -191,7 +192,7 @@ def generate( ### Client().remote_run -[Show source in client.py:207](../../../alfred/client/client.py#L207) +[Show source in client.py:211](../../../alfred/client/client.py#L211) Wrapper function for running the model on the queries thru a gRPC Server. @@ -218,7 +219,7 @@ def remote_run( ### Client().run -[Show source in client.py:187](../../../alfred/client/client.py#L187) +[Show source in client.py:191](../../../alfred/client/client.py#L191) Run the model on the queries. @@ -245,7 +246,7 @@ def run( ### Client().score -[Show source in client.py:246](../../../alfred/client/client.py#L246) +[Show source in client.py:250](../../../alfred/client/client.py#L250) Wrapper function to score the response(s) from the model. (For ranking) diff --git a/docs/alfred/client/ssh/sshtunnel.md b/docs/alfred/client/ssh/sshtunnel.md index 36c88ab..51da227 100644 --- a/docs/alfred/client/ssh/sshtunnel.md +++ b/docs/alfred/client/ssh/sshtunnel.md @@ -16,7 +16,7 @@ SSHTunnel ## SSHTunnel -[Show source in sshtunnel.py:12](../../../../alfred/client/ssh/sshtunnel.py#L12) +[Show source in sshtunnel.py:13](../../../../alfred/client/ssh/sshtunnel.py#L13) SSH Tunnel implemented with paramiko and supports interactive authentication This tunnel would be very useful if you have a alfred.fm model on remote server that you want to access @@ -36,6 +36,7 @@ class SSHTunnel: username: Optional[str] = None, remote_node_address: Optional[str] = None, remote_bind_port: Optional[Union[int, str]] = 443, + key_file: str = "~/.ssh/id_rsa", handler: Callable = None, ): ... @@ -43,7 +44,7 @@ class SSHTunnel: ### SSHTunnel.adaptive_handler -[Show source in sshtunnel.py:20](../../../../alfred/client/ssh/sshtunnel.py#L20) +[Show source in sshtunnel.py:21](../../../../alfred/client/ssh/sshtunnel.py#L21) Authentication handler for paramiko's interactive authentication @@ -57,7 +58,7 @@ def adaptive_handler(title, instructions, prompt_list): ### SSHTunnel().start -[Show source in sshtunnel.py:74](../../../../alfred/client/ssh/sshtunnel.py#L74) +[Show source in sshtunnel.py:82](../../../../alfred/client/ssh/sshtunnel.py#L82) Wrapper for _start() with exception handling @@ -70,7 +71,7 @@ def start(self): ### SSHTunnel().stop -[Show source in sshtunnel.py:127](../../../../alfred/client/ssh/sshtunnel.py#L127) +[Show source in sshtunnel.py:142](../../../../alfred/client/ssh/sshtunnel.py#L142) Stop the tunnel diff --git a/docs/alfred/fm/ai21.md b/docs/alfred/fm/ai21.md index 0d0920a..847c7fa 100644 --- a/docs/alfred/fm/ai21.md +++ b/docs/alfred/fm/ai21.md @@ -22,7 +22,7 @@ This class provides a wrapper for the OpenAI API for generating completions. ```python class AI21Model(APIAccessFoundationModel): - def __init__(self, model_string: str = "j1-large", api_key: Optional[str] = None): + def __init__(self, model_string: str = "j1-mid", api_key: Optional[str] = None): ... ``` diff --git a/docs/alfred/fm/anthropic.md b/docs/alfred/fm/anthropic.md index b64b19f..8441b2d 100644 --- a/docs/alfred/fm/anthropic.md +++ b/docs/alfred/fm/anthropic.md @@ -13,7 +13,7 @@ Anthropic ## AnthropicModel -[Show source in anthropic.py:30](../../../alfred/fm/anthropic.py#L30) +[Show source in anthropic.py:29](../../../alfred/fm/anthropic.py#L29) A wrapper for the anthropic API. @@ -23,13 +23,15 @@ This class provides a wrapper for the anthropic API for generating completions. 
```python class AnthropicModel(APIAccessFoundationModel): - def __init__(self, model_string: str = "claude-v1", api_key: Optional[str] = None): + def __init__( + self, model_string: str = "claude-instant-1", api_key: Optional[str] = None + ): ... ``` ### AnthropicModel().chat -[Show source in anthropic.py:145](../../../alfred/fm/anthropic.py#L145) +[Show source in anthropic.py:144](../../../alfred/fm/anthropic.py#L144) Launch an interactive chat session with the Anthropic API.
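
A minimal usage sketch of the new `ssh_pk` argument, not part of the patch: the end point, node name, and key path below are placeholders, the import path follows the `alfred/client/client.py` module layout, and model-selection arguments accepted elsewhere in the `Client` constructor are omitted here. Note that the key is loaded with `paramiko.RSAKey.from_private_key_file`, so the file should be an RSA private key.

```python
from alfred.client.client import Client  # import path per alfred/client/client.py

# Placeholder host/port, node, and key path -- substitute your own deployment values.
# Model-selection arguments of Client are omitted; only the tunnel-related arguments
# touched by this change are shown.
client = Client(
    end_point="gpu-login.example.org:10719",  # hypothetical "<host>:<port>" end point
    ssh_tunnel=True,                          # open the paramiko tunnel on construction
    ssh_pk="~/.ssh/alfred_rsa",               # new argument; defaults to "~/.ssh/id_rsa"
    ssh_node="gpu-node-07",                   # hypothetical final hop behind the login node
)
```

If the key file does not exist, `SSHTunnel` sets `key_file` to `None` and, because `_start()` only calls `auth_interactive` when the transport is not yet authenticated, the tunnel falls back to the interactive handler rather than failing outright.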