Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding PK support for SSH Tunnel, Updating AI21, Anthropic Models #49

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions alfred/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
end_point: Optional[str] = None,
local_path: Optional[str] = None,
ssh_tunnel: bool = False,
ssh_pk: str = '~/.ssh/id_rsa',
ssh_node: Optional[str] = None,
cache: Optional[Cache] = None,
**kwargs: Any,
Expand All @@ -48,6 +49,8 @@ def __init__(
:param local_path: (optional) The local path of the model. (e.g. "/home/user/.cache/model")
:param ssh_tunnel: Whether to establish an SSH tunnel to the end point.
:type ssh_tunnel: bool
:param ssh_pk: ssh RSA key location
:type ssh_pk: str
:param ssh_node: (optional) The final SSH node to establish the SSH tunnel. (e.g. gpu node on a cluster with login node as jump)
:type ssh_node: str
:param cache: (optional) The cache to use. (e.g. "SQLite", "Dummy")
Expand Down Expand Up @@ -110,6 +113,7 @@ def __init__(
remote_port=self.end_point_port,
remote_node_address=ssh_node,
username=user_name,
key_file=ssh_pk,
)

tunnel.start()
Expand Down
31 changes: 23 additions & 8 deletions alfred/client/ssh/sshtunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Union, Callable

import paramiko
import os

from alfred.client.ssh.utils import port_finder, forward_tunnel

Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(
username: Optional[str] = None,
remote_node_address: Optional[str] = None,
remote_bind_port: Optional[Union[int, str]] = 443,
key_file: str="~/.ssh/id_rsa",
handler: Callable = None,
):
"""
Expand All @@ -56,6 +58,8 @@ def __init__(
:type remote_node_address: str
:param remote_bind_port: (optional) The remote bind port to connect to, defaults to 443
:type remote_bind_port: Optional[Union[int, str]], optional
:param key_file: (optional) SSH key file
:type key_file: str
:param handler: The handler for interactive authentication, defaults to adaptive handler
:type handler: Callable, optional
"""
Expand All @@ -69,6 +73,10 @@ def __init__(
self.username = username or input("Username: ")
self.remote_node_address = remote_node_address
self.remote_bind_port = remote_bind_port

# if key file exist the nuse key_file else None
self.key_file = paramiko.RSAKey.from_private_key_file(os.path.expanduser(key_file)) if os.path.isfile(os.path.expanduser(key_file)) else None

self.handler = handler or self.adaptive_handler

def start(self):
Expand All @@ -89,19 +97,26 @@ def _start(self):

self.client = paramiko.SSHClient()
self.client.load_system_host_keys()
self.client.set_missing_host_key_policy(paramiko.WarningPolicy())
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

try:
self.client.connect(self.remote_host, username=self.username)
self.client.connect(
self.remote_host,
username=self.username,
pkey=self.key_file,
look_for_keys=False
)
except paramiko.ssh_exception.SSHException:
pass

try:
self.client.get_transport().auth_interactive(
username=self.username, handler=self.handler)
except paramiko.ssh_exception.AuthenticationException:
logger.error("Wrong Password, Please restart the Tunnel")
raise paramiko.ssh_exception.AuthenticationException
if not self.client.get_transport().is_authenticated():
try:
self.client.get_transport().auth_interactive(
username=self.username, handler=self.handler)
except paramiko.ssh_exception.AuthenticationException:
logger.error("Wrong Password, Please restart the Tunnel")
raise paramiko.ssh_exception.AuthenticationException

logger.log(logging.INFO, f"Connected to {self.remote_host} @ port 22")

port = self.client.get_transport().sock.getsockname()[1]
Expand Down
10 changes: 5 additions & 5 deletions alfred/fm/ai21.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
logger = logging.getLogger(__name__)

AI21_MODELS = (
"j1-large",
"j1-grande",
"j1-jumbo",
"j1-light",
"j1-mid",
"j1-ultra",
)


Expand All @@ -26,7 +26,7 @@ def _ai21_query(
query_string: str,
temperature: float = 0.0,
max_tokens: int = 10,
model: str = "j1-large",
model: str = "j1-mid",
) -> str:
"""
Run a single query through the foundation model
Expand Down Expand Up @@ -62,7 +62,7 @@ def _ai21_query(

def __init__(
self,
model_string: str = "j1-large",
model_string: str = "j1-mid",
api_key: Optional[str] = None,
):
"""
Expand Down
13 changes: 6 additions & 7 deletions alfred/fm/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
logger = logging.getLogger(__name__)

ANTHROPIC_MODELS = (
"claude-v1",
"claude-v1.0",
"claude-v1.2",
"claude-instant-v1",
"claude-instant-v1.0",
"claude-instant-1",
"claude-instant-1.2",
"claude-2",
"claude-2.0",
)

try:
Expand All @@ -39,7 +38,7 @@ def _anthropic_query(
query: Union[str, List],
temperature: float = 0.0,
max_tokens: int = 3,
model: str = "claude-v1",
model: str = "claude-instant-1",
**kwargs: Any,
) -> str:
"""
Expand Down Expand Up @@ -82,7 +81,7 @@ def _anthropic_query(
return response["completion"]

def __init__(self,
model_string: str = "claude-v1",
model_string: str = "claude-instant-1",
api_key: Optional[str] = None):
"""
Initialize the Anthropic API wrapper.
Expand Down
17 changes: 9 additions & 8 deletions docs/alfred/client/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Client:
end_point: Optional[str] = None,
local_path: Optional[str] = None,
ssh_tunnel: bool = False,
ssh_pk: str = "~/.ssh/id_rsa",
ssh_node: Optional[str] = None,
cache: Optional[Cache] = None,
**kwargs: Any
Expand All @@ -47,7 +48,7 @@ class Client:

### Client().__call__

[Show source in client.py:270](../../../alfred/client/client.py#L270)
[Show source in client.py:274](../../../alfred/client/client.py#L274)

__call__() function to run the model on the queries.
Equivalent to run() function.
Expand Down Expand Up @@ -75,7 +76,7 @@ def __call__(

### Client().calibrate

[Show source in client.py:285](../../../alfred/client/client.py#L285)
[Show source in client.py:289](../../../alfred/client/client.py#L289)

calibrate are used to calibrate foundation models contextually given the template.
A voter class may be passed to calibrate the model with a specific voter.
Expand Down Expand Up @@ -120,7 +121,7 @@ def calibrate(

### Client().chat

[Show source in client.py:387](../../../alfred/client/client.py#L387)
[Show source in client.py:391](../../../alfred/client/client.py#L391)

Chat with the model APIs.
Currently, Alfred supports Chat APIs from Anthropic and OpenAI
Expand All @@ -139,7 +140,7 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any):

### Client().encode

[Show source in client.py:361](../../../alfred/client/client.py#L361)
[Show source in client.py:365](../../../alfred/client/client.py#L365)

embed() function to embed the queries.

Expand All @@ -162,7 +163,7 @@ def encode(

### Client().generate

[Show source in client.py:229](../../../alfred/client/client.py#L229)
[Show source in client.py:233](../../../alfred/client/client.py#L233)

Wrapper function to generate the response(s) from the model. (For completion)

Expand Down Expand Up @@ -191,7 +192,7 @@ def generate(

### Client().remote_run

[Show source in client.py:207](../../../alfred/client/client.py#L207)
[Show source in client.py:211](../../../alfred/client/client.py#L211)

Wrapper function for running the model on the queries thru a gRPC Server.

Expand All @@ -218,7 +219,7 @@ def remote_run(

### Client().run

[Show source in client.py:187](../../../alfred/client/client.py#L187)
[Show source in client.py:191](../../../alfred/client/client.py#L191)

Run the model on the queries.

Expand All @@ -245,7 +246,7 @@ def run(

### Client().score

[Show source in client.py:246](../../../alfred/client/client.py#L246)
[Show source in client.py:250](../../../alfred/client/client.py#L250)

Wrapper function to score the response(s) from the model. (For ranking)

Expand Down
9 changes: 5 additions & 4 deletions docs/alfred/client/ssh/sshtunnel.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ SSHTunnel

## SSHTunnel

[Show source in sshtunnel.py:12](../../../../alfred/client/ssh/sshtunnel.py#L12)
[Show source in sshtunnel.py:13](../../../../alfred/client/ssh/sshtunnel.py#L13)

SSH Tunnel implemented with paramiko and supports interactive authentication
This tunnel would be very useful if you have a alfred.fm model on remote server that you want to access
Expand All @@ -36,14 +36,15 @@ class SSHTunnel:
username: Optional[str] = None,
remote_node_address: Optional[str] = None,
remote_bind_port: Optional[Union[int, str]] = 443,
key_file: str = "~/.ssh/id_rsa",
handler: Callable = None,
):
...
```

### SSHTunnel.adaptive_handler

[Show source in sshtunnel.py:20](../../../../alfred/client/ssh/sshtunnel.py#L20)
[Show source in sshtunnel.py:21](../../../../alfred/client/ssh/sshtunnel.py#L21)

Authentication handler for paramiko's interactive authentication

Expand All @@ -57,7 +58,7 @@ def adaptive_handler(title, instructions, prompt_list):

### SSHTunnel().start

[Show source in sshtunnel.py:74](../../../../alfred/client/ssh/sshtunnel.py#L74)
[Show source in sshtunnel.py:82](../../../../alfred/client/ssh/sshtunnel.py#L82)

Wrapper for _start() with exception handling

Expand All @@ -70,7 +71,7 @@ def start(self):

### SSHTunnel().stop

[Show source in sshtunnel.py:127](../../../../alfred/client/ssh/sshtunnel.py#L127)
[Show source in sshtunnel.py:142](../../../../alfred/client/ssh/sshtunnel.py#L142)

Stop the tunnel

Expand Down
2 changes: 1 addition & 1 deletion docs/alfred/fm/ai21.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ This class provides a wrapper for the OpenAI API for generating completions.

```python
class AI21Model(APIAccessFoundationModel):
def __init__(self, model_string: str = "j1-large", api_key: Optional[str] = None):
def __init__(self, model_string: str = "j1-mid", api_key: Optional[str] = None):
...
```

Expand Down
8 changes: 5 additions & 3 deletions docs/alfred/fm/anthropic.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Anthropic

## AnthropicModel

[Show source in anthropic.py:30](../../../alfred/fm/anthropic.py#L30)
[Show source in anthropic.py:29](../../../alfred/fm/anthropic.py#L29)

A wrapper for the anthropic API.

Expand All @@ -23,13 +23,15 @@ This class provides a wrapper for the anthropic API for generating completions.

```python
class AnthropicModel(APIAccessFoundationModel):
def __init__(self, model_string: str = "claude-v1", api_key: Optional[str] = None):
def __init__(
self, model_string: str = "claude-instant-1", api_key: Optional[str] = None
):
...
```

### AnthropicModel().chat

[Show source in anthropic.py:145](../../../alfred/fm/anthropic.py#L145)
[Show source in anthropic.py:144](../../../alfred/fm/anthropic.py#L144)

Launch an interactive chat session with the Anthropic API.

Expand Down