Skip to content

Commit 22b4e97

Browse files
committedMar 8, 2023
Begin moving files
1 parent 4473842 commit 22b4e97

20 files changed

+945
-2
lines changed
 

‎.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,5 @@ cython_debug/
150150
# and can be added to the global gitignore or merged into this file. For a more nuclear
151151
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
152152
#.idea/
153+
154+
.DS_Store

‎.pre-commit-config.yaml

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
repos:
2+
- repo: https://github.com/pycqa/autoflake
3+
rev: v2.0.0
4+
hooks:
5+
- id: autoflake
6+
language_version: python3
7+
args:
8+
[
9+
"--in-place",
10+
"--recursive",
11+
"--remove-all-unused-imports",
12+
"--remove-unused-variables",
13+
"--exclude",
14+
'**/__init__.py, **/conftest.py, tests/fixtures/**.py',
15+
]
16+
17+
- repo: https://github.com/pycqa/isort
18+
rev: 5.12.0
19+
hooks:
20+
- id: isort
21+
language_version: python3
22+
23+
- repo: https://github.com/psf/black
24+
rev: 22.12.0
25+
hooks:
26+
- id: black
27+
args: ["--preview"]
28+
language_version: python3

‎MANIFEST.in

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Things to always exclude
2+
global-exclude .git*
3+
global-exclude .ipynb_checkpoints
4+
global-exclude *.py[co]
5+
global-exclude __pycache__/**
6+
7+
# Top-level Config
8+
include LICENSE
9+
include MANIFEST.in
10+
include setup.cfg
11+
include requirements.txt
12+
13+
# Prompt templates
14+
graft src/marvin/programs
15+
graft src/marvin/prompts

‎README.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
# marvin
2-
1+
# Marvin
2+
3+
> "'Let’s build robots with Genuine People Personalities,' they said. So they tried it out with me. I’m a personality prototype, you can tell, can’t you?"

‎requirements-dev.txt

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
black[jupyter]>=22.12
2+
pre-commit>=2.21.0
3+
pytest-asyncio>=0.20.3
4+
pytest-sugar>=0.9.6
5+
pytest-env>=0.8.1
6+
pytest>=7.2.0
7+
pdbpp>=0.10.3
8+
pyperclip>=1.8.2
9+
ipython>=8.0

‎requirements.txt

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
aiosqlite==0.18.0
2+
asyncpg==0.27.0
3+
cloudpickle==2.2.1
4+
fastapi==0.89.1
5+
httpx==0.23.3
6+
jinja2==3.1.2
7+
nest_asyncio==1.5.6
8+
openai==0.27.0
9+
pendulum==2.1.2
10+
pydantic[dotenv,email]==1.10.4
11+
rich==13.3.1
12+
sqlalchemy[asyncio]==1.4.41
13+
sqlitedict==2.1.0
14+
sqlmodel==0.0.8
15+
tiktoken==0.3.0
16+
ulid-py==1.1.0
17+
uvicorn==0.20.0
18+
xxhash==3.2.0
19+
yake==0.4.8

‎scripts/init.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import asyncio
2+
3+
import marvin
4+
import marvin.examples.prefect
5+
6+
7+
async def main():
8+
# reset the DB
9+
await marvin.database.ddl.reset_db(confirm=True)
10+
11+
# hydrate with docs
12+
await marvin.examples.prefect.load_prefect()
13+
14+
15+
if __name__ == "__main__":
16+
asyncio.run(main())

‎setup.cfg

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
[tool:pytest]
2+
markers =
3+
ai: mark a test as dependent on external AI APIs.
4+
norecursedirs = *.egg-info .git .mypy_cache node_modules .pytest_cache .vscode
5+
asyncio_mode = auto
6+
filterwarnings =
7+
ignore:'crypt' is deprecated and slated for removal in Python 3.13:DeprecationWarning
8+
9+
env =
10+
MARVIN_TEST_MODE=1
11+
D:MARVIN_DATABASE_CONNECTION_URL=sqlite+aiosqlite:////tmp/marvin-tests/test.sqlite
12+
MARVIN_LOG_CONSOLE_WIDTH=120
13+
MARVIN_LOG_LEVEL=DEBUG
14+
15+
[isort]
16+
skip = __init__.py
17+
profile = black
18+
skip_gitignore = True
19+
multi_line_output = 3
20+
21+
[flake8]
22+
# Match black line-length
23+
max-line-length = 88
24+
extend-ignore =
25+
# See https://github.com/PyCQA/pycodestyle/issues/373
26+
E203,
27+
28+
[pycodestyle]
29+
# Match black line-length
30+
max-line-length = 88
31+
extend-ignore =
32+
# See https://github.com/PyCQA/pycodestyle/issues/373
33+
E203,

‎setup.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from setuptools import find_packages, setup
2+
3+
required_deps = open("requirements.txt").read().strip().split("\n")
4+
dev_deps = open("requirements-dev.txt").read().strip().split("\n")
5+
6+
setup(
7+
# Package metadata
8+
name="marvin",
9+
url="https://github.com/PrefectHQ/marvin",
10+
version="0.3",
11+
long_description=open("README.md").read(),
12+
long_description_content_type="text/markdown",
13+
# Package setup
14+
packages=find_packages(where="src"),
15+
package_dir={"": "src"},
16+
include_package_data=True,
17+
# Requirements
18+
python_requires=">=3.10",
19+
install_requires=required_deps,
20+
extras_require={
21+
"dev": required_deps + dev_deps,
22+
},
23+
)

‎src/marvin/__init__.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# load env vars
2+
from dotenv import load_dotenv
3+
4+
load_dotenv()
5+
6+
# load nest_asyncio
7+
import nest_asyncio
8+
9+
nest_asyncio.apply()
10+
11+
# load marvin root objects
12+
from marvin.config import settings
13+
from marvin.utilities.logging import get_logger
14+
15+
# load marvin

‎src/marvin/config.py

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from contextlib import contextmanager
2+
from pathlib import Path
3+
from typing import Literal
4+
5+
from pydantic import BaseSettings, Field, SecretStr, root_validator, validator
6+
from rich import print
7+
from rich.text import Text
8+
9+
import marvin
10+
11+
12+
class Settings(BaseSettings):
13+
class Config:
14+
env_file = ".env"
15+
env_prefix = "MARVIN_"
16+
validate_assignment = True
17+
18+
home: Path = Path("~/.marvin").expanduser()
19+
test_mode: bool = False
20+
21+
# LOGGING
22+
verbose: bool = False
23+
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "DEBUG"
24+
log_console_width: int | None = Field(
25+
None,
26+
description=(
27+
"Marvin will auto-detect the console width when possible, but in deployed"
28+
" settings logs will assume a console width of 80 characters unless"
29+
" specified here."
30+
),
31+
)
32+
rich_tracebacks: bool = Field(False, description="Enable rich traceback formatting")
33+
34+
# EMBEDDINGS
35+
# specify the path to the embeddings cache, relative to the home dir
36+
embeddings_cache_path: Path = Path("cache/embeddings.sqlite")
37+
embeddings_cache_warn_size: int = 4000000000 # 4GB
38+
39+
# OPENAI
40+
openai_api_key: SecretStr = Field(
41+
"", env=["MARVIN_OPENAI_API_KEY", "OPENAI_API_KEY"]
42+
)
43+
44+
# DATABASE
45+
database_echo: bool = False
46+
database_connection_url: SecretStr = "sqlite+aiosqlite:////$HOME/marvin.db"
47+
48+
# REDIS
49+
redis_connection_url: SecretStr = ""
50+
51+
# BOTS
52+
bot_create_profile_picture: bool = False
53+
54+
@root_validator
55+
def initial_setup(cls, values):
56+
values["home"].mkdir(parents=True, exist_ok=True)
57+
58+
# prefix HOME to embeddings cache path
59+
if not values["embeddings_cache_path"].is_absolute():
60+
values["embeddings_cache_path"] = (
61+
values["home"] / values["embeddings_cache_path"]
62+
)
63+
values["embeddings_cache_path"].parent.mkdir(parents=True, exist_ok=True)
64+
65+
# interpolate HOME into database connection URL
66+
values["database_connection_url"] = SecretStr(
67+
values["database_connection_url"]
68+
.get_secret_value()
69+
.replace("$HOME", str(values["home"]))
70+
)
71+
72+
# print if verbose = True
73+
if values["verbose"]:
74+
print(Text("Verbose mode enabled", style="green"))
75+
76+
return values
77+
78+
@validator("openai_api_key")
79+
def warn_if_missing_api_keys(cls, v, field):
80+
if not v:
81+
print(
82+
Text(
83+
f"WARNING: `{field.name}` is not set. Some features may not work.",
84+
style="red",
85+
)
86+
)
87+
return v
88+
89+
@root_validator
90+
def test_mode_settings(cls, values):
91+
if values["test_mode"]:
92+
print(Text("Marvin is running in test mode!", style="yellow"))
93+
values["log_level"] = "DEBUG"
94+
values["verbose"] = True
95+
return values
96+
97+
def __setattr__(self, name, value):
98+
result = super().__setattr__(name, value)
99+
# update log level on assignment
100+
if name == "log_level":
101+
marvin.utilities.logging.setup_logging()
102+
return result
103+
104+
105+
settings = Settings()
106+
107+
108+
@contextmanager
109+
def temporary_settings(**kwargs):
110+
old_settings = settings.dict()
111+
settings.__dict__.update(kwargs)
112+
try:
113+
yield
114+
finally:
115+
settings.__dict__.clear()
116+
settings.__dict__.update(old_settings)

‎src/marvin/infra/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import db

‎src/marvin/infra/db.py

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import inspect
2+
from contextlib import asynccontextmanager
3+
from functools import wraps
4+
from typing import AsyncGenerator, Callable, Literal
5+
6+
import sqlmodel
7+
from sqlalchemy.dialects.postgresql import JSONB as postgres_JSONB
8+
from sqlalchemy.dialects.sqlite import JSON as sqlite_JSON
9+
from sqlalchemy.ext.asyncio import create_async_engine
10+
from sqlalchemy.orm import sessionmaker
11+
from sqlmodel.ext.asyncio.session import AsyncSession
12+
13+
import marvin
14+
15+
engine_kwargs = {}
16+
# sqlite doesn't support pool configuration
17+
if marvin.settings.database_connection_url.get_secret_value().startswith("postgresql"):
18+
engine_kwargs.update(
19+
pool_size=50,
20+
max_overflow=20,
21+
)
22+
23+
engine = create_async_engine(
24+
marvin.settings.database_connection_url.get_secret_value(),
25+
echo=marvin.settings.database_echo,
26+
**engine_kwargs,
27+
)
28+
async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
29+
30+
31+
def get_dialect() -> Literal["postgresql", "sqlite"]:
32+
return engine.dialect.name
33+
34+
35+
async def get_session() -> AsyncGenerator[AsyncSession, None]:
36+
return async_session_maker()
37+
38+
39+
@asynccontextmanager
40+
async def session_context(begin_transaction: bool = False):
41+
"""
42+
Provides a SQLAlchemy session and a context manager for opening/closing
43+
the underlying connection.
44+
45+
Args:
46+
begin_transaction: if True, the context manager will begin a SQL transaction.
47+
Exiting the context manager will COMMIT or ROLLBACK any changes.
48+
"""
49+
async with await get_session() as session:
50+
if begin_transaction:
51+
async with session.begin():
52+
yield session
53+
else:
54+
try:
55+
yield session
56+
await session.commit()
57+
except Exception:
58+
await session.rollback()
59+
raise
60+
61+
62+
def provide_session(begin_transaction: bool = False) -> Callable:
63+
"""
64+
Decorator that provides a database interface to a function.
65+
66+
The decorated function _must_ have a kwarg that is annotated as `AsyncSession`.
67+
"""
68+
if isinstance(begin_transaction, Callable):
69+
raise TypeError("provide_session() must be called when decorating a function.")
70+
71+
def wrapper(fn: Callable) -> Callable:
72+
SESSION_KWARG = None
73+
sig = inspect.signature(fn)
74+
for name, param in sig.parameters.items():
75+
if param.annotation is AsyncSession:
76+
SESSION_KWARG = name
77+
break
78+
if SESSION_KWARG is None:
79+
raise TypeError("No `AsyncSession` kwarg found in function signature.")
80+
81+
@wraps(fn)
82+
async def async_wrapper(*args, **kwargs):
83+
try:
84+
arguments = sig.bind_partial(*args, **kwargs).arguments
85+
86+
# typeerror would indicate an illegal argument was passed;
87+
# we'll let the function reraise for clarity
88+
except TypeError:
89+
arguments = {}
90+
91+
if SESSION_KWARG not in arguments or arguments[SESSION_KWARG] is None:
92+
async with session_context(
93+
begin_transaction=begin_transaction
94+
) as session:
95+
kwargs[SESSION_KWARG] = session
96+
return await fn(*args, **kwargs)
97+
return await fn(*args, **kwargs)
98+
99+
return async_wrapper
100+
101+
return wrapper
102+
103+
104+
if get_dialect() == "sqlite":
105+
JSONType = sqlite_JSON
106+
else:
107+
JSONType = postgres_JSONB
108+
109+
110+
async def destroy_db(confirm: bool = False):
111+
if not confirm:
112+
raise ValueError("You must confirm that you want to destroy the database.")
113+
114+
async with session_context(begin_transaction=True) as session:
115+
for table in reversed(sqlmodel.SQLModel.metadata.sorted_tables):
116+
if marvin.database.engine.get_dialect() == "postgresql":
117+
await session.execute(f'DROP TABLE IF EXISTS "{table.name}" CASCADE;')
118+
else:
119+
await session.execute(f'DROP TABLE IF EXISTS "{table.name}";')
120+
marvin.get_logger("db").debug_style(
121+
f"Table {table.name!r} dropped.", "white on red"
122+
)
123+
marvin.get_logger("db").info_style("Database destroyed!", "white on red")
124+
125+
126+
async def create_db():
127+
async with marvin.database.engine.engine.begin() as conn:
128+
await conn.run_sync(sqlmodel.SQLModel.metadata.create_all)
129+
marvin.get_logger("db").info_style("Database created!", "green")
130+
131+
132+
async def reset_db(confirm: bool = False):
133+
await destroy_db(confirm=confirm)
134+
await create_db()

‎src/marvin/utilities/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import logging, async_utils, types, strings, collections, tests

‎src/marvin/utilities/async_utils.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import asyncio
2+
import concurrent.futures
3+
import functools
4+
import multiprocessing as mp
5+
6+
import cloudpickle
7+
8+
import marvin
9+
10+
process_pool = concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("fork"))
11+
12+
13+
async def run_async(func, *args, **kwargs):
14+
loop = asyncio.get_event_loop()
15+
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
16+
17+
18+
def _cloudpickle_wrapper(pickle):
19+
return cloudpickle.loads(pickle)()
20+
21+
22+
async def run_async_process(func, *args, **kwargs):
23+
# in test mode, don't spawn processes
24+
if marvin.settings.test_mode:
25+
return await run_async(func, *args, **kwargs)
26+
27+
pickled_func = cloudpickle.dumps(functools.partial(func, *args, **kwargs))
28+
loop = asyncio.get_event_loop()
29+
return await loop.run_in_executor(process_pool, _cloudpickle_wrapper, pickled_func)

‎src/marvin/utilities/collections.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import itertools
2+
from typing import Any, Callable, Iterable, TypeVar
3+
4+
T = TypeVar("T")
5+
6+
7+
def batched(
8+
iterable: Iterable[T], size: int, size_fn: Callable[[Any], int] = None
9+
) -> Iterable[T]:
10+
"""
11+
If size_fn is not provided, then the batch size will be determined by the
12+
number of items in the batch.
13+
14+
If size_fn is provided, then it will be used
15+
to compute the batch size. Note that if a single item is larger than the
16+
batch size, it will be returned as a batch of its own.
17+
"""
18+
if size_fn is None:
19+
it = iter(iterable)
20+
while True:
21+
batch = tuple(itertools.islice(it, size))
22+
if not batch:
23+
break
24+
yield batch
25+
else:
26+
batch = []
27+
batch_size = 0
28+
for item in iter(iterable):
29+
batch.append(item)
30+
batch_size += size_fn(item)
31+
if batch_size > size:
32+
yield batch
33+
batch = []
34+
batch_size = 0
35+
if batch:
36+
yield batch

‎src/marvin/utilities/logging.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import logging
2+
from functools import lru_cache, partial
3+
4+
from rich.console import Console
5+
from rich.logging import RichHandler
6+
from rich.markup import escape
7+
from rich.traceback import install as install_rich_tracebacks
8+
9+
import marvin
10+
11+
12+
@lru_cache()
13+
def get_logger(name: str = None) -> logging.Logger:
14+
parent_logger = logging.getLogger("marvin")
15+
16+
if name:
17+
# Append the name if given but allow explicit full names e.g. "marvin.test"
18+
# should not become "marvin.marvin.test"
19+
if not name.startswith(parent_logger.name + "."):
20+
logger = parent_logger.getChild(name)
21+
else:
22+
logger = logging.getLogger(name)
23+
else:
24+
logger = parent_logger
25+
26+
add_logging_methods(logger)
27+
return logger
28+
29+
30+
def setup_logging():
31+
logger = get_logger()
32+
logger.setLevel(marvin.settings.log_level)
33+
34+
if not any(isinstance(h, RichHandler) for h in logger.handlers):
35+
handler = RichHandler(
36+
rich_tracebacks=True,
37+
markup=False,
38+
console=Console(width=marvin.settings.log_console_width),
39+
)
40+
formatter = logging.Formatter("%(name)s: %(message)s")
41+
handler.setFormatter(formatter)
42+
logger.addHandler(handler)
43+
44+
45+
def add_logging_methods(logger):
46+
def log_style(level: int, message: str, style: str = None):
47+
if not style:
48+
style = "default on default"
49+
message = f"[{style}]{escape(str(message))}[/]"
50+
logger.log(level, message, extra={"markup": True})
51+
52+
def log_kv(
53+
level: int,
54+
key: str,
55+
value: str,
56+
key_style: str = "default on default",
57+
value_style: str = "default on default",
58+
delimiter: str = ": ",
59+
):
60+
logger.log(
61+
level,
62+
f"[{key_style}]{escape(str(key))}{delimiter}[/][{value_style}]{escape(str(value))}[/]",
63+
extra={"markup": True},
64+
)
65+
66+
logger.debug_style = partial(log_style, logging.DEBUG)
67+
logger.info_style = partial(log_style, logging.INFO)
68+
logger.warning_style = partial(log_style, logging.WARNING)
69+
logger.error_style = partial(log_style, logging.ERROR)
70+
logger.critical_style = partial(log_style, logging.CRITICAL)
71+
72+
logger.debug_kv = partial(log_kv, logging.DEBUG)
73+
logger.info_kv = partial(log_kv, logging.INFO)
74+
logger.warning_kv = partial(log_kv, logging.WARNING)
75+
logger.error_kv = partial(log_kv, logging.ERROR)
76+
logger.critical_kv = partial(log_kv, logging.CRITICAL)
77+
78+
79+
setup_logging()
80+
if marvin.settings.rich_tracebacks:
81+
install_rich_tracebacks()

‎src/marvin/utilities/strings.py

+197
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import asyncio
2+
import re
3+
from functools import lru_cache
4+
from string import Formatter
5+
from typing import Any, Callable, Mapping, Sequence, Union
6+
7+
import pendulum
8+
import tiktoken
9+
import xxhash
10+
from jinja2 import ChoiceLoader, Environment, StrictUndefined, select_autoescape
11+
12+
import marvin
13+
14+
jinja_env = Environment(
15+
loader=ChoiceLoader(
16+
[
17+
# PackageLoader("marvin", "prompts"),
18+
# PackageLoader("marvin", "programs"),
19+
]
20+
),
21+
autoescape=select_autoescape(default_for_string=False),
22+
trim_blocks=True,
23+
lstrip_blocks=True,
24+
enable_async=True,
25+
auto_reload=True,
26+
undefined=StrictUndefined,
27+
)
28+
jinja_env.globals.update(
29+
zip=zip,
30+
str=str,
31+
len=len,
32+
arun=asyncio.run,
33+
pendulum=pendulum,
34+
dt=lambda: pendulum.now("UTC").to_day_datetime_string(),
35+
)
36+
37+
38+
class StrictFormatter(Formatter):
39+
"""A subclass of formatter that checks for extra keys."""
40+
41+
def check_unused_args(
42+
self,
43+
used_args: Sequence[Union[int, str]],
44+
args: Sequence,
45+
kwargs: Mapping[str, Any],
46+
) -> None:
47+
"""Check to see if extra parameters are passed."""
48+
extra = set(kwargs).difference(used_args)
49+
if extra:
50+
raise KeyError(extra)
51+
52+
53+
@lru_cache(maxsize=2000)
54+
def hash_text(*text: str) -> str:
55+
bs = [t.encode() if not isinstance(t, bytes) else t for t in text]
56+
return xxhash.xxh3_128_hexdigest(b"".join(bs))
57+
58+
59+
VERSION_NUMBERS = re.compile(r"\b\d+\.\d+(?:\.\d+)?\w*\b")
60+
61+
62+
def tokenize(text: str) -> list[int]:
63+
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
64+
return tokenizer.encode(text)
65+
66+
67+
def detokenize(tokens: list[int]) -> str:
68+
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
69+
return tokenizer.decode(tokens)
70+
71+
72+
def count_tokens(text: str) -> int:
73+
return len(tokenize(text))
74+
75+
76+
def slice_tokens(text: str, n_tokens: int) -> str:
77+
tokens = tokenize(text)
78+
return detokenize(tokens[:n_tokens])
79+
80+
81+
def split_text(
82+
text: str,
83+
chunk_size: int,
84+
chunk_overlap: float = None,
85+
last_chunk_threshold: float = None,
86+
return_index: bool = False,
87+
) -> str | tuple[str, int]:
88+
"""
89+
Split a text into a list of strings. Chunks are split by tokens.
90+
91+
Args:
92+
text (str): The text to split.
93+
chunk_size (int): The number of tokens in each chunk.
94+
chunk_overlap (float): The fraction of overlap between chunks.
95+
last_chunk_threshold (float): If the last chunk is less than this fraction of
96+
the chunk_size, it will be added to the prior chunk
97+
return_index (bool): If True, return a tuple of (chunk, index) where index is the
98+
character index of the start of the chunk in the original text.
99+
"""
100+
if chunk_overlap is None:
101+
chunk_overlap = 0.1
102+
if chunk_overlap < 0 or chunk_overlap > 1:
103+
raise ValueError("chunk_overlap must be between 0 and 1")
104+
if last_chunk_threshold is None:
105+
last_chunk_threshold = 0.25
106+
107+
tokens = tokenize(text)
108+
109+
chunks = []
110+
for i in range(0, len(tokens), chunk_size - int(chunk_overlap * chunk_size)):
111+
chunks.append((tokens[i : i + chunk_size], len(detokenize(tokens[:i]))))
112+
113+
# if the last chunk is too small, merge it with the previous chunk
114+
if len(chunks) > 1 and len(chunks[-1][0]) < chunk_size * last_chunk_threshold:
115+
chunks[-2][0].extend(chunks.pop(-1)[0])
116+
117+
if return_index:
118+
return [(detokenize(chunk), index) for chunk, index in chunks]
119+
else:
120+
return [detokenize(chunk) for chunk, _ in chunks]
121+
122+
123+
def _extract_keywords(text: str, n_keywords: int = None) -> list[str]:
124+
# deferred import
125+
import yake
126+
127+
kw = yake.KeywordExtractor(
128+
lan="en",
129+
n=1,
130+
dedupLim=0.9,
131+
dedupFunc="seqm",
132+
windowsSize=1,
133+
top=n_keywords or marvin.settings.default_n_keywords,
134+
features=None,
135+
)
136+
keywords = kw.extract_keywords(text)
137+
# return only keyword, not score
138+
return [k[0] for k in keywords]
139+
140+
141+
async def extract_keywords(text: str, n_keywords: int = None) -> list[str]:
142+
# keyword extraction can take a while and is blocking
143+
return await marvin.utilities.async_utils.run_async_process(
144+
_extract_keywords, text=text, n_keywords=n_keywords
145+
)
146+
# return _extract_keywords(text=text, n_keywords=n_keywords)
147+
148+
149+
def create_minimap_fn(content: str) -> Callable[[int], str]:
150+
"""
151+
Given a document with markdown headers, returns a function that outputs the current headers
152+
for any character position in the document.
153+
"""
154+
minimap: dict[int, str] = {}
155+
in_code_block = False
156+
current_stack = {}
157+
characters = 0
158+
for line in content.splitlines():
159+
characters += len(line)
160+
if line.startswith("```"):
161+
in_code_block = not in_code_block
162+
if in_code_block:
163+
continue
164+
165+
if line.startswith("# "):
166+
current_stack = {1: line}
167+
elif line.startswith("## "):
168+
for i in range(2, 6):
169+
current_stack.pop(i, None)
170+
current_stack[2] = line
171+
elif line.startswith("### "):
172+
for i in range(3, 6):
173+
current_stack.pop(i, None)
174+
current_stack[3] = line
175+
elif line.startswith("#### "):
176+
for i in range(4, 6):
177+
current_stack.pop(i, None)
178+
current_stack[4] = line
179+
elif line.startswith("##### "):
180+
for i in range(5, 6):
181+
current_stack.pop(i, None)
182+
current_stack[5] = line
183+
else:
184+
continue
185+
186+
minimap[characters - len(line)] = current_stack
187+
188+
def get_location_fn(n: int) -> str:
189+
if n < 0:
190+
raise ValueError("n must be >= 0")
191+
# get the stack of headers that is closest to - but before - the current position
192+
stack = minimap.get(max((k for k in minimap if k <= n), default=0), {})
193+
194+
ordered_stack = [stack.get(i) for i in range(1, 6)]
195+
return "\n".join([s for s in ordered_stack if s is not None])
196+
197+
return get_location_fn

‎src/marvin/utilities/tests.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import httpx
2+
3+
4+
def assert_status_code(response: httpx.Response, status_code: int):
5+
try:
6+
full_response = response.json()
7+
except:
8+
full_response = response.text
9+
error_message = (
10+
f"assert {response.status_code} == {status_code}"
11+
f"\nFull response: {full_response}"
12+
)
13+
assert response.status_code == status_code, error_message

‎src/marvin/utilities/types.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import json
2+
import re
3+
from functools import lru_cache
4+
from typing import Any, Callable, Generic, TypeVar
5+
6+
import pydantic
7+
import ulid
8+
from fastapi import APIRouter, Response, status
9+
from fastapi.encoders import jsonable_encoder
10+
from pydantic import BaseModel, constr
11+
from sqlalchemy import TypeDecorator
12+
13+
from marvin.infra.db import JSONType
14+
15+
T = TypeVar("T")
16+
UUID_REGEX = re.compile(
17+
r"\b[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}\b"
18+
)
19+
# ulid
20+
ULID_REGEX = r"\b[0-9A-HJ-NP-TV-Z]{26}\b"
21+
# specific prefix
22+
PREFIXED_ULID_REGEX = r"\b{prefix}_[0-9A-HJ-NP-TV-Z]{{26}}\b"
23+
# any prefix
24+
ANY_PREFIX_ULID_REGEX = r"\b[^\W0-9_][^\W_]+_[0-9A-HJ-NP-TV-Z]{26}\b"
25+
# optional prefix
26+
ANY_ULID_REGEX = r"\b(?:[^\W0-9_][^\W_]+_)?[0-9A-HJ-NP-TV-Z]{26}\b"
27+
28+
29+
@lru_cache()
30+
def get_id_type(prefix: str = None) -> type:
31+
if prefix is None:
32+
type_ = constr(regex=ULID_REGEX)
33+
type_.new = lambda: str(ulid.new())
34+
else:
35+
if "_" in prefix:
36+
raise ValueError("Prefix must not contain underscores.")
37+
type_ = constr(regex=PREFIXED_ULID_REGEX.format(prefix=prefix))
38+
type_.new = lambda: f"{prefix}_{ulid.new()}"
39+
type_.regex = PREFIXED_ULID_REGEX.format(prefix=prefix)
40+
return type_
41+
42+
43+
class MarvinBaseModel(BaseModel):
44+
class Config:
45+
copy_on_model_validation = "shallow"
46+
validate_assignment = True
47+
extra = "forbid"
48+
json_encoders = {}
49+
50+
def dict(self, *args, json_compatible=False, **kwargs):
51+
if json_compatible:
52+
return json.loads(self.json(*args, **kwargs))
53+
return super().dict(*args, **kwargs)
54+
55+
def copy_with_updates(self, exclude: set[str] = None, **updates):
56+
"""
57+
Copies the current model and updates the copy with the provided updates,
58+
which can be partial nested dictionaries.
59+
60+
Unlike `copy(update=updates)`, this method will properly validate
61+
updates and apply nested updates.
62+
"""
63+
updated = self.dict(exclude=exclude)
64+
65+
stack = [(updated, k, v) for k, v in updates.items()]
66+
while stack:
67+
m, k, v = stack.pop()
68+
mv = m.get(k)
69+
if isinstance(mv, dict) and isinstance(v, dict):
70+
stack.extend([(mv, vk, vv) for vk, vv in v.items()])
71+
else:
72+
m[k] = v
73+
74+
excluded = set(self.__exclude_fields__ or []).union(exclude or [])
75+
excluded_kwargs = {e: getattr(self, e) for e in excluded if e not in updated}
76+
return type(self)(**updated, **excluded_kwargs)
77+
78+
79+
class MarvinRouter(APIRouter):
80+
"""
81+
Utilities to make the router a little more convenient to use.
82+
"""
83+
84+
def add_api_route(
85+
self, path: str, endpoint: Callable[..., Any], **kwargs: Any
86+
) -> None:
87+
"""
88+
Add an API route.
89+
90+
For routes that return content and have not specified a `response_model`,
91+
use return type annotation to infer the response model.
92+
93+
For routes that return No-Content status codes, explicitly set
94+
a `response_class` to ensure nothing is returned in the response body.
95+
"""
96+
if kwargs.get("status_code") == status.HTTP_204_NO_CONTENT:
97+
# any routes that return No-Content status codes must
98+
# explicilty set a response_class that will handle status codes
99+
# and not return anything in the body
100+
kwargs["response_class"] = Response
101+
return super().add_api_route(path, endpoint, **kwargs)
102+
103+
104+
def pydantic_column_type(pydantic_type):
105+
"""
106+
SA Column for converting pydantic models to and from JSON
107+
"""
108+
109+
class PydanticJSONType(TypeDecorator, Generic[T]):
110+
impl = JSONType()
111+
112+
def bind_processor(self, dialect):
113+
impl_processor = self.impl.bind_processor(dialect)
114+
if impl_processor:
115+
116+
def process(value: T):
117+
if value is not None:
118+
if isinstance(pydantic_type, pydantic.main.ModelMetaclass):
119+
# This allows to assign non-InDB models and if they're
120+
# compatible, they're directly parsed into the InDB
121+
# representation, thus hiding the implementation in the
122+
# background. However, the InDB model will still be returned
123+
value_to_dump = pydantic_type.from_orm(value)
124+
else:
125+
value_to_dump = value
126+
value = jsonable_encoder(value_to_dump)
127+
return impl_processor(value)
128+
129+
else:
130+
131+
def process(value):
132+
if isinstance(pydantic_type, pydantic.main.ModelMetaclass):
133+
# This allows to assign non-InDB models and if they're
134+
# compatible, they're directly parsed into the InDB
135+
# representation, thus hiding the implementation in the
136+
# background. However, the InDB model will still be returned
137+
value_to_dump = pydantic_type.from_orm(value)
138+
else:
139+
value_to_dump = value
140+
value = json.dumps(jsonable_encoder(value_to_dump))
141+
return value
142+
143+
return process
144+
145+
def result_processor(self, dialect, coltype) -> T:
146+
impl_processor = self.impl.result_processor(dialect, coltype)
147+
if impl_processor:
148+
149+
def process(value):
150+
value = impl_processor(value)
151+
if value is None:
152+
return None
153+
154+
data = value
155+
# Explicitly use the generic directly, not type(T)
156+
full_obj = pydantic.parse_obj_as(pydantic_type, data)
157+
return full_obj
158+
159+
else:
160+
161+
def process(value):
162+
if value is None:
163+
return None
164+
165+
# Explicitly use the generic directly, not type(T)
166+
full_obj = pydantic.parse_obj_as(pydantic_type, value)
167+
return full_obj
168+
169+
return process
170+
171+
def compare_values(self, x, y):
172+
return x == y
173+
174+
return PydanticJSONType

0 commit comments

Comments
 (0)
Please sign in to comment.