Skip to content

Commit

Permalink
Merge pull request #9 from RyoJerryYu/feat-add-ai-plugin
Browse files Browse the repository at this point in the history
feat: add ai plugin
  • Loading branch information
RyoJerryYu authored Jun 15, 2024
2 parents 62f476b + 612ed06 commit f84c524
Show file tree
Hide file tree
Showing 22 changed files with 359 additions and 82 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Use the official Python runtime as a parent image
FROM python:3.12-slim
FROM python:3.12.3-slim

# Set the working directory to /app
WORKDIR /app
Expand Down
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,24 @@ memos:
token: xxxxxxx
plugins:
you_get_plugins:
- name: download
tag: webhook/download
- name: download
tag: webhook/download
you_get_plugin:
patterns:
- https://twitter.com/\w+/status/\d+
- https://x.com/\w+/status/\d+
- name: fix_typos
tag: task/fix_typos
zhipu_plugin:
api_key: xxxxxxx
prompt: |
You are a fix typos plugin.
Please fix the typos in the text.
The text is:
```
{content}
```
```

And config definitionn is in [config.py](src/dependencies/config.py)
32 changes: 29 additions & 3 deletions example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,35 @@ memos:
token: xxxxxxx

plugins:
you_get_plugins:
- name: download
tag: webhook/download
- name: download
tag: webhook/download
you_get_plugin:
patterns:
- https://twitter.com/\w+/status/\d+
- https://x.com/\w+/status/\d+
- name: fix_typos
tag: task/fix_typos
zhipu_plugin:
api_key: xxxxxxx
prompt: |
You are a fix typos plugin.
Please fix the typos in the text.
The text is:
```
{content}
```
- name: translate
tag: task/translate
zhipu_plugin:
api_key: xxxxxxx
prompt: |
You are a translation plugin.
Please translate the text.
The text is:
```
{content}
```
15 changes: 11 additions & 4 deletions memos_webhook/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import contextlib
from typing import Annotated, Any

from fastapi import BackgroundTasks, Depends, FastAPI, Request
from fastapi import BackgroundTasks, Depends, FastAPI, Request, status
from fastapi.responses import JSONResponse

import memos_webhook.proto_gen.memos.api.v1 as v1
from memos_webhook.dependencies.config import get_config, new_config
Expand Down Expand Up @@ -32,7 +33,7 @@ async def lifespan(app: FastAPI):
logger.info(f"memos info: {cfg.memos_host}:{cfg.memos_port}")
logger.info(f"log level: {cfg.log_level}")
if cfg.plugins:
plugin_names = [plugin.name for plugin in cfg.plugins.you_get_plugins]
plugin_names = [plugin.name for plugin in cfg.plugins]
logger.info(f"plugins: {plugin_names}")
logger.info(f"")
with new_memos_cli(cfg) as memos_cli:
Expand Down Expand Up @@ -77,8 +78,14 @@ async def webhook_handler(
"""The new webhook handler, use protojson."""
dict_json = await req.json()
logger.debug(f"webhook handler received request: {dict_json}")

proto_payload = v1.WebhookRequestPayload().from_dict(dict_json)
try:
proto_payload = v1.WebhookRequestPayload().from_dict(dict_json)
except Exception as e:
logger.error(f"parse payload error: {e}")
return JSONResponse(
content={"code": 1, "message": f"parse payload error: {e}"},
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
)
background_tasks.add_task(webhook_task, proto_payload, executor)
return {
"code": 0,
Expand Down
34 changes: 29 additions & 5 deletions memos_webhook/dependencies/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# mypy: disable-error-code="empty-body"
from typing import Any

from pydantic import BaseModel

from memos_webhook.utils.config_decorators import (ArgsConfigProvider,
Expand All @@ -9,18 +12,39 @@


class YouGetPluginConfig(BaseModel):
name: str
tag: str
patterns: list[str]


class ZhipuPluginConfig(BaseModel):
api_key: str
prompt: str
"""Multiple lines for prompt message.
Just same format as langchain prompt.
Available variables: `{content}`"""


class PluginConfig(BaseModel):
you_get_plugins: list[YouGetPluginConfig]
name: str
tag: str
you_get_plugin: YouGetPluginConfig | None = None
zhipu_plugin: ZhipuPluginConfig | None = None


def parse_config_list(raw: Any) -> list[PluginConfig]:
if raw is None:
return []
assert isinstance(raw, list)
res: list[PluginConfig] = []
for item in raw:
res.append(PluginConfig.model_validate(item))

return res


arg_parser = ArgsConfigProvider()


# type: ignore
class Config(BaseUnmarshalConfig, BaseDotenvConfig, BaseArgsConfig):
@from_env()
@arg_parser.from_flag(
Expand Down Expand Up @@ -81,9 +105,9 @@ def memos_port(self) -> str: ...
@from_unmarshal("memos", "token")
def memos_token(self) -> str: ...

@it_is(PluginConfig)
@it_is(list[PluginConfig], transformer=parse_config_list)
@from_unmarshal()
def plugins(self) -> PluginConfig: ...
def plugins(self) -> list[PluginConfig]: ...


_config: Config
Expand Down
2 changes: 1 addition & 1 deletion memos_webhook/dependencies/memos_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MemosCli:
resource_service: v1.ResourceServiceStub


_cli: MemosCli = None
_cli: MemosCli | None = None


@contextlib.contextmanager
Expand Down
48 changes: 32 additions & 16 deletions memos_webhook/dependencies/plugin_manager.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,51 @@
from memos_webhook.plugins.base_plugin import PluginExecutor, PluginProtocol
from memos_webhook.plugins.base_plugin import IPlugin, PluginExecutor
from memos_webhook.plugins.you_get_plugin import YouGetPlugin
from memos_webhook.plugins.zhipu_plugin import ZhipuPlugin

from .config import Config, PluginConfig, YouGetPluginConfig
from .memos_cli import MemosCli

_plugin_executor: PluginExecutor = None
_plugin_executor: PluginExecutor | None = None


_DEFAULT_PLUGIN_CFG = PluginConfig(
you_get_plugins=[
YouGetPluginConfig(
name="you-get",
tag="hook/download",
_DEFAULT_PLUGIN_CFG = [
PluginConfig(
name="you-get",
tag="hook/download",
you_get_plugin=YouGetPluginConfig(
patterns=[
"https://twitter.com/\\w+/status/\\d+",
"https://x.com/\\w+/status/\\d+",
],
),
]
)
)
]


def new_plugin_executor(cfg: Config, memos_cli: MemosCli) -> PluginExecutor:
global _plugin_executor
plugins_cfg = cfg.plugins
if plugins_cfg is None:
plugins_cfg = _DEFAULT_PLUGIN_CFG # temp fake cfg

plugins: list[PluginProtocol] = []
for you_get_plugin_cfg in plugins_cfg.you_get_plugins:
plugins.append(YouGetPlugin(you_get_plugin_cfg))
plugins_cfgs = cfg.plugins
if plugins_cfgs is None:
plugins_cfgs = _DEFAULT_PLUGIN_CFG # temp fake cfg

plugins: list[IPlugin] = []
for plugin_cfg in plugins_cfgs:
if plugin_cfg.you_get_plugin is not None:
plugins.append(
YouGetPlugin(
plugin_cfg.name,
plugin_cfg.tag,
plugin_cfg.you_get_plugin,
)
)
if plugin_cfg.zhipu_plugin is not None:
plugins.append(
ZhipuPlugin(
plugin_cfg.name,
plugin_cfg.tag,
plugin_cfg.zhipu_plugin,
)
)

_plugin_executor = PluginExecutor(memos_cli, plugins)
return _plugin_executor
Expand Down
Empty file.
Empty file.
92 changes: 92 additions & 0 deletions memos_webhook/langchains/zhipu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.embeddings.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage,
SystemMessage)
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic.v1 import root_validator
from pydantic.v1.main import BaseModel
from zhipuai import ZhipuAI
from zhipuai.api_resource.chat.completions import Completions


class ZhipuAIChatModel(BaseChatModel):
api_key: str = ""
temperature: float = 0.9
top_p: float = 0.7
client: ZhipuAI = None
max_tokens: int = 1024
model_name: str = "chatglm_turbo"

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
from zhipuai import ZhipuAI
from zhipuai.api_resource.chat.completions import Completions

# if values["client"] is None:
values["client"] = ZhipuAI(api_key=values["api_key"])
return values

@property
def _llm_type(self) -> str:
return "zhipuai"

def _generate(self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any) -> ChatResult:
completions = Completions(client=self.client)
zhipuMessages = [self._langchainMsgToZhipuMsg(m) for m in messages]
print(f"{zhipuMessages=}")
response = completions.create(
model=self.model_name,
max_tokens=self.max_tokens,
messages=zhipuMessages,
temperature=self.temperature,
top_p=self.top_p,
)
print(f"{response=}")
generation = ChatGeneration(message=AIMessage(
content=response.choices[0].message.content))

return ChatResult(generations=[generation],
llm_output=response.model_dump())

def _langchainMsgToZhipuMsg(self, message: BaseMessage) -> Dict:
if isinstance(message, AIMessage):
return {"role": "assistant", "content": message.content}
if isinstance(message, HumanMessage):
return {"role": "user", "content": message.content}
if isinstance(message, SystemMessage):
return {"role": "system", "content": message.content}

# fallback to human message
return {"role": "user", "content": message.content}


class ZhipuAIEmbedding(BaseModel, Embeddings):
api_key: str = ""
client: Any
model_name: str = "embedding-2"

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
from zhipuai import ZhipuAI
values["client"] = ZhipuAI(api_key=values["api_key"])
return values

def embed_query(self, text: str) -> List[float]:
print(f"[embed_query] {text=}")
embeddings = self.client.embeddings.create(
model=self.model_name,
input=text,
)
return embeddings.data[0].embedding

def embed_documents(self, texts: List[str]) -> List[List[float]]:
print(f"[embed_documents] {texts=}")
return [self.embed_query(text) for text in texts]
Loading

0 comments on commit f84c524

Please sign in to comment.