Skip to content

feat: add ai plugin #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Use the official Python runtime as a parent image
FROM python:3.12-slim
FROM python:3.12.3-slim

# Set the working directory to /app
WORKDIR /app
Expand Down
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,24 @@ memos:
token: xxxxxxx

plugins:
you_get_plugins:
- name: download
tag: webhook/download
- name: download
tag: webhook/download
you_get_plugin:
patterns:
- https://twitter.com/\w+/status/\d+
- https://x.com/\w+/status/\d+
- name: fix_typos
tag: task/fix_typos
zhipu_plugin:
api_key: xxxxxxx
prompt: |
You are a fix typos plugin.
Please fix the typos in the text.

The text is:
```
{content}
```
```

And config definitionn is in [config.py](src/dependencies/config.py)
32 changes: 29 additions & 3 deletions example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,35 @@ memos:
token: xxxxxxx

plugins:
you_get_plugins:
- name: download
tag: webhook/download
- name: download
tag: webhook/download
you_get_plugin:
patterns:
- https://twitter.com/\w+/status/\d+
- https://x.com/\w+/status/\d+
- name: fix_typos
tag: task/fix_typos
zhipu_plugin:
api_key: xxxxxxx
prompt: |
You are a fix typos plugin.
Please fix the typos in the text.

The text is:
```
{content}
```
- name: translate
tag: task/translate
zhipu_plugin:
api_key: xxxxxxx
prompt: |
You are a translation plugin.
Please translate the text.

The text is:
```
{content}
```


15 changes: 11 additions & 4 deletions memos_webhook/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import contextlib
from typing import Annotated, Any

from fastapi import BackgroundTasks, Depends, FastAPI, Request
from fastapi import BackgroundTasks, Depends, FastAPI, Request, status
from fastapi.responses import JSONResponse

import memos_webhook.proto_gen.memos.api.v1 as v1
from memos_webhook.dependencies.config import get_config, new_config
Expand Down Expand Up @@ -32,7 +33,7 @@ async def lifespan(app: FastAPI):
logger.info(f"memos info: {cfg.memos_host}:{cfg.memos_port}")
logger.info(f"log level: {cfg.log_level}")
if cfg.plugins:
plugin_names = [plugin.name for plugin in cfg.plugins.you_get_plugins]
plugin_names = [plugin.name for plugin in cfg.plugins]
logger.info(f"plugins: {plugin_names}")
logger.info(f"")
with new_memos_cli(cfg) as memos_cli:
Expand Down Expand Up @@ -77,8 +78,14 @@ async def webhook_handler(
"""The new webhook handler, use protojson."""
dict_json = await req.json()
logger.debug(f"webhook handler received request: {dict_json}")

proto_payload = v1.WebhookRequestPayload().from_dict(dict_json)
try:
proto_payload = v1.WebhookRequestPayload().from_dict(dict_json)
except Exception as e:
logger.error(f"parse payload error: {e}")
return JSONResponse(
content={"code": 1, "message": f"parse payload error: {e}"},
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
)
background_tasks.add_task(webhook_task, proto_payload, executor)
return {
"code": 0,
Expand Down
34 changes: 29 additions & 5 deletions memos_webhook/dependencies/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# mypy: disable-error-code="empty-body"
from typing import Any

from pydantic import BaseModel

from memos_webhook.utils.config_decorators import (ArgsConfigProvider,
Expand All @@ -9,18 +12,39 @@


class YouGetPluginConfig(BaseModel):
name: str
tag: str
patterns: list[str]


class ZhipuPluginConfig(BaseModel):
api_key: str
prompt: str
"""Multiple lines for prompt message.
Just same format as langchain prompt.
Available variables: `{content}`"""


class PluginConfig(BaseModel):
you_get_plugins: list[YouGetPluginConfig]
name: str
tag: str
you_get_plugin: YouGetPluginConfig | None = None
zhipu_plugin: ZhipuPluginConfig | None = None


def parse_config_list(raw: Any) -> list[PluginConfig]:
if raw is None:
return []
assert isinstance(raw, list)
res: list[PluginConfig] = []
for item in raw:
res.append(PluginConfig.model_validate(item))

return res


arg_parser = ArgsConfigProvider()


# type: ignore
class Config(BaseUnmarshalConfig, BaseDotenvConfig, BaseArgsConfig):
@from_env()
@arg_parser.from_flag(
Expand Down Expand Up @@ -81,9 +105,9 @@ def memos_port(self) -> str: ...
@from_unmarshal("memos", "token")
def memos_token(self) -> str: ...

@it_is(PluginConfig)
@it_is(list[PluginConfig], transformer=parse_config_list)
@from_unmarshal()
def plugins(self) -> PluginConfig: ...
def plugins(self) -> list[PluginConfig]: ...


_config: Config
Expand Down
2 changes: 1 addition & 1 deletion memos_webhook/dependencies/memos_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MemosCli:
resource_service: v1.ResourceServiceStub


_cli: MemosCli = None
_cli: MemosCli | None = None


@contextlib.contextmanager
Expand Down
48 changes: 32 additions & 16 deletions memos_webhook/dependencies/plugin_manager.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,51 @@
from memos_webhook.plugins.base_plugin import PluginExecutor, PluginProtocol
from memos_webhook.plugins.base_plugin import IPlugin, PluginExecutor
from memos_webhook.plugins.you_get_plugin import YouGetPlugin
from memos_webhook.plugins.zhipu_plugin import ZhipuPlugin

from .config import Config, PluginConfig, YouGetPluginConfig
from .memos_cli import MemosCli

_plugin_executor: PluginExecutor = None
_plugin_executor: PluginExecutor | None = None


_DEFAULT_PLUGIN_CFG = PluginConfig(
you_get_plugins=[
YouGetPluginConfig(
name="you-get",
tag="hook/download",
_DEFAULT_PLUGIN_CFG = [
PluginConfig(
name="you-get",
tag="hook/download",
you_get_plugin=YouGetPluginConfig(
patterns=[
"https://twitter.com/\\w+/status/\\d+",
"https://x.com/\\w+/status/\\d+",
],
),
]
)
)
]


def new_plugin_executor(cfg: Config, memos_cli: MemosCli) -> PluginExecutor:
global _plugin_executor
plugins_cfg = cfg.plugins
if plugins_cfg is None:
plugins_cfg = _DEFAULT_PLUGIN_CFG # temp fake cfg

plugins: list[PluginProtocol] = []
for you_get_plugin_cfg in plugins_cfg.you_get_plugins:
plugins.append(YouGetPlugin(you_get_plugin_cfg))
plugins_cfgs = cfg.plugins
if plugins_cfgs is None:
plugins_cfgs = _DEFAULT_PLUGIN_CFG # temp fake cfg

plugins: list[IPlugin] = []
for plugin_cfg in plugins_cfgs:
if plugin_cfg.you_get_plugin is not None:
plugins.append(
YouGetPlugin(
plugin_cfg.name,
plugin_cfg.tag,
plugin_cfg.you_get_plugin,
)
)
if plugin_cfg.zhipu_plugin is not None:
plugins.append(
ZhipuPlugin(
plugin_cfg.name,
plugin_cfg.tag,
plugin_cfg.zhipu_plugin,
)
)

_plugin_executor = PluginExecutor(memos_cli, plugins)
return _plugin_executor
Expand Down
Empty file.
Empty file.
92 changes: 92 additions & 0 deletions memos_webhook/langchains/zhipu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.embeddings.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage,
SystemMessage)
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic.v1 import root_validator
from pydantic.v1.main import BaseModel
from zhipuai import ZhipuAI
from zhipuai.api_resource.chat.completions import Completions


class ZhipuAIChatModel(BaseChatModel):
api_key: str = ""
temperature: float = 0.9
top_p: float = 0.7
client: ZhipuAI = None
max_tokens: int = 1024
model_name: str = "chatglm_turbo"

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
from zhipuai import ZhipuAI
from zhipuai.api_resource.chat.completions import Completions

# if values["client"] is None:
values["client"] = ZhipuAI(api_key=values["api_key"])
return values

@property
def _llm_type(self) -> str:
return "zhipuai"

def _generate(self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any) -> ChatResult:
completions = Completions(client=self.client)
zhipuMessages = [self._langchainMsgToZhipuMsg(m) for m in messages]
print(f"{zhipuMessages=}")
response = completions.create(
model=self.model_name,
max_tokens=self.max_tokens,
messages=zhipuMessages,
temperature=self.temperature,
top_p=self.top_p,
)
print(f"{response=}")
generation = ChatGeneration(message=AIMessage(
content=response.choices[0].message.content))

return ChatResult(generations=[generation],
llm_output=response.model_dump())

def _langchainMsgToZhipuMsg(self, message: BaseMessage) -> Dict:
if isinstance(message, AIMessage):
return {"role": "assistant", "content": message.content}
if isinstance(message, HumanMessage):
return {"role": "user", "content": message.content}
if isinstance(message, SystemMessage):
return {"role": "system", "content": message.content}

# fallback to human message
return {"role": "user", "content": message.content}


class ZhipuAIEmbedding(BaseModel, Embeddings):
api_key: str = ""
client: Any
model_name: str = "embedding-2"

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
from zhipuai import ZhipuAI
values["client"] = ZhipuAI(api_key=values["api_key"])
return values

def embed_query(self, text: str) -> List[float]:
print(f"[embed_query] {text=}")
embeddings = self.client.embeddings.create(
model=self.model_name,
input=text,
)
return embeddings.data[0].embedding

def embed_documents(self, texts: List[str]) -> List[List[float]]:
print(f"[embed_documents] {texts=}")
return [self.embed_query(text) for text in texts]
Loading
Loading