Skip to content

Commit

Permalink
Merge branch 'main' into fix-harmonize-torch
Browse files Browse the repository at this point in the history
  • Loading branch information
michdr authored Jan 27, 2025
2 parents 5313e22 + c9225cc commit 1a96db5
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 4,255 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ from pathlib import Path
import pandas as pd
from mostlyai import engine

# set up workspace
# set up workspace and default logging
ws = Path("ws-tabular-flat")
engine.init_logging()

# load original data
url = "https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/census"
Expand Down Expand Up @@ -73,8 +74,11 @@ from pathlib import Path
import pandas as pd
from mostlyai import engine

# set up workspace
engine.init_logging()

# set up workspace and default logging
ws = Path("ws-tabular-sequential")
engine.init_logging()

# load original data
url = "https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/baseball"
Expand Down Expand Up @@ -107,8 +111,9 @@ from pathlib import Path
import pandas as pd
from mostlyai import engine

# set up workspace
# init workspace and logging
ws = Path("ws-language-flat")
engine.init_logging()

# load original data
trn_df = pd.read_parquet("https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/headlines/headlines.parquet")
Expand Down
23 changes: 3 additions & 20 deletions examples/flat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/flat.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import sys\n",
"import numpy as np\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" stream=sys.stdout,\n",
" format=\"[%(asctime)s] %(levelname)-7s: %(message)s\",\n",
" datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -44,8 +26,9 @@
"import pandas as pd\n",
"from mostlyai import engine\n",
"\n",
"# set up workspace\n",
"# init workspace and logging\n",
"ws = Path(\"ws-tabular-flat\")\n",
"engine.init_logging()\n",
"\n",
"# load original data\n",
"url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/census\"\n",
Expand Down Expand Up @@ -152,7 +135,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.12.3"
},
"toc": {
"base_numbering": 1,
Expand Down
132 changes: 3 additions & 129 deletions examples/language.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,126 +16,6 @@
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/language.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"import logging\n",
"import sys\n",
"import numpy as np\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" stream=sys.stdout,\n",
" format=\"[%(asctime)s] %(levelname)-7s: %(message)s\",\n",
" datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"dataset_name = \"sacred\"\n",
"ctx_encoding_types = {\"book\": \"TABULAR_CATEGORICAL\"}\n",
"tgt_encoding_types = {\"text\": \"LANGUAGE_TEXT\"}\n",
"fn = \"https://github.com/mostly-ai/public-demo-data/raw/dev/sacred_verses/sacred.csv.gz\"\n",
"df = pd.read_csv(fn)[list(ctx_encoding_types.keys()) + list(tgt_encoding_types.keys())]\n",
"df.text = df[\"text\"].str[:30] # trim to 30 chars max to speed up demo\n",
"print(df.shape)\n",
"df.iloc[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"ws = Path(\"language-ws\")\n",
"\n",
"pk = \"pk\"\n",
"df[pk] = list(range(df.shape[0]))\n",
"ctx_columns = [pk, *[key for key in ctx_encoding_types.keys()]] if ctx_encoding_types else [pk]\n",
"tgt_columns = [pk, *[key for key in tgt_encoding_types.keys()]]\n",
"ctx_df = df[ctx_columns]\n",
"tgt_df = df[tgt_columns]\n",
"\n",
"split(\n",
" tgt_data=tgt_df,\n",
" tgt_context_key=pk,\n",
" tgt_encoding_types=tgt_encoding_types,\n",
" ctx_data=ctx_df,\n",
" ctx_primary_key=pk,\n",
" ctx_encoding_types=ctx_encoding_types,\n",
" workspace_dir=ws,\n",
")\n",
"analyze(workspace_dir=ws)\n",
"encode(workspace_dir=ws)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"encoded_data = pd.read_parquet(ws / \"OriginalData\" / \"encoded-data\")\n",
"encoded_data.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# train(workspace_dir=workspace_dir, max_training_time=2, model=\"Locutusque/TinyMistral-248M\")\n",
"# train(workspace_dir=workspace_dir, max_training_time=2, model=\"EleutherAI/pythia-160m\")\n",
"# train(workspace_dir=workspace_dir, max_training_time=2, model=\"EleutherAI/pythia-410m\")\n",
"train(workspace_dir=ws, max_training_time=2, model=\"MOSTLY_AI/LSTMFromScratch-3m\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"generate(sample_size=100, workspace_dir=ws)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -146,8 +26,9 @@
"import pandas as pd\n",
"from mostlyai import engine\n",
"\n",
"# set up workspace\n",
"# init workspace and logging\n",
"ws = Path(\"ws-language-flat\")\n",
"engine.init_logging()\n",
"\n",
"# load original data\n",
"url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/arxiv\"\n",
Expand Down Expand Up @@ -182,13 +63,6 @@
"syn_tgt_df = pd.read_parquet(ws / \"SyntheticData\") # load synthetic data\n",
"syn_tgt_df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -207,7 +81,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.12.3"
},
"toc": {
"base_numbering": 1,
Expand Down
38 changes: 4 additions & 34 deletions examples/sequential.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,18 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-02T22:08:30.505289Z",
"start_time": "2024-12-02T22:08:30.501492Z"
},
"tags": []
},
"outputs": [],
"source": [
"import logging\n",
"import sys\n",
"import numpy as np\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" stream=sys.stdout,\n",
" format=\"[%(asctime)s] %(levelname)-7s: %(message)s\",\n",
" datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import pandas as pd\n",
"import numpy as np\n",
"from mostlyai import engine\n",
"\n",
"# set up workspace\n",
"# init workspace and logging\n",
"ws = Path(\"ws-tabular-sequential\")\n",
"engine.init_logging()\n",
"\n",
"# load original data\n",
"url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/baseball\"\n",
Expand Down Expand Up @@ -145,13 +122,6 @@
"trn_avg_teams_per_player = trn_tgt_df.groupby(\"players_id\")[\"team\"].nunique().mean().round(1)\n",
"syn_avg_teams_per_player, trn_avg_teams_per_player"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -170,7 +140,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.12.3"
},
"toc": {
"base_numbering": 1,
Expand Down
3 changes: 2 additions & 1 deletion mostlyai/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.
import warnings

from mostlyai.engine.logging import init_logging
from mostlyai.engine.splitting import split
from mostlyai.engine.analysis import analyze
from mostlyai.engine.training import train
from mostlyai.engine.encoding import encode
from mostlyai.engine.generation import generate


__all__ = ["split", "analyze", "encode", "train", "generate"]
__all__ = ["split", "analyze", "encode", "train", "generate", "init_logging"]
__version__ = "1.0.1"

# suppress specific warning related to os.fork() in multi-threaded processes
Expand Down
15 changes: 0 additions & 15 deletions mostlyai/engine/_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,21 +168,6 @@ def make_stats_json_path_desc(**kwargs):
self.generated_data_path = Path(self._ws_path) / Path(*generated_data)
self.generated_data = make_path_desc(parts=generated_data, fetch_multiple=fetch_part_parquets)

# Report-related
self.report_ctx_data_path = self._ws_path / "report-ctx-data"
self.report_ctx_data = make_path_desc(
parts=["report-ctx-data"],
fetch_multiple=fetch_part_parquets,
)
self.report_trn_ctx_data = make_path_desc(
parts=["report-ctx-data"],
fetch_multiple=fetch_part_trn_parquets,
)
self.report_val_ctx_data = make_path_desc(
parts=["report-ctx-data"],
fetch_multiple=fetch_part_val_parquets,
)


def ensure_workspace_dir(workspace_dir: str | Path) -> Path:
workspace_dir = Path(workspace_dir)
Expand Down
34 changes: 34 additions & 0 deletions mostlyai/engine/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2025 MOSTLY AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import logging

_LOG = logging.getLogger(__name__.rsplit(".", 1)[0]) # get the logger with the root module name (mostlyai.engine)


def init_logging() -> None:
"""
Initialize the logging configuration. Either log to stdout or to a file.
"""

# log to stdout
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)-7s: %(message)s"))
handler.setLevel(logging.INFO)

if not _LOG.hasHandlers():
_LOG.addHandler(handler)
_LOG.setLevel(logging.INFO)
_LOG.propagate = False
Loading

0 comments on commit 1a96db5

Please sign in to comment.