Skip to content

Commit

Permalink
add Nakdimon.h5 to package data
Browse files Browse the repository at this point in the history
Signed-off-by: Elazar Gershuni <elazarg@gmail.com>
  • Loading branch information
elazarg committed Aug 3, 2024
1 parent fd3f1a3 commit 01c631b
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 25 deletions.
5 changes: 1 addition & 4 deletions examples/usage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# pip install git+https://github.com/nakdimon/nakdimon.git
# mkdir models
# wget https://github.com/elazarg/nakdimon/raw/master/models/Nakdimon.h5
# mv Nakdimon.h5 models/Nakdimon.h5

from nakdimon import diacritize

result = diacritize("שלום עולם!", "models/Nakdimon.h5")
result = diacritize("שלום עולם!")
print(result)
File renamed without changes.
8 changes: 4 additions & 4 deletions nakdimon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import os
import logging

from nakdimon.config import MAIN_MODEL

def do_train(**kwargs) -> None:
from nakdimon import train
Expand Down Expand Up @@ -48,7 +48,7 @@ def main() -> None:

parser_train = subparsers.add_parser('train', help='train Nakdimon')
parser_train.add_argument('--wandb', action='store_true', help='use wandb.', default=False)
parser_train.add_argument('--model', help='path to output model (.h5 file)', default='models/Full.h5', dest='model_path')
parser_train.add_argument('--model', help='path to output model (.h5 file)', default=MAIN_MODEL, dest='model_path')
parser_train.add_argument('--ablation', help='ablation test', default=None, dest='ablation_name')
parser_train.set_defaults(func=do_train)

Expand All @@ -60,7 +60,7 @@ def main() -> None:
parser_test = subparsers.add_parser('run_test', help='diacritize a test set')
parser_test.add_argument('--test_set', choices=available_tests, help='choose test set', default='tests/new')
parser_test.add_argument('--system', choices=test_systems, help='diacritization system to use', default='Nakdimon')
parser_test.add_argument('--model', help='path to model (.h5 file)', default='models/Nakdimon.h5', dest='model_path')
parser_test.add_argument('--model', help='path to model (.h5 file)', default=MAIN_MODEL, dest='model_path')
parser_test.add_argument('--skip-existing', action='store_true', help='skip existing files')
parser_test.set_defaults(func=do_run_test)

Expand Down Expand Up @@ -114,6 +114,6 @@ def diacritize_main():
sys.exit(0)


def diacritize(text: str, model_path: str = 'models/Nakdimon.h5') -> str:
def diacritize(text: str, model_path: str = MAIN_MODEL) -> str:
import nakdimon.predict
return nakdimon.predict.predict(text, model_path)
4 changes: 4 additions & 0 deletions nakdimon/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from importlib.resources import files

MODELS_DIR = 'models'
MAIN_MODEL = files('nakdimon').joinpath('Nakdimon.h5')
3 changes: 2 additions & 1 deletion nakdimon/external_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from nakdimon.hebrew import Niqqud
from nakdimon import hebrew
from nakdimon.config import MAIN_MODEL


class DottingError(RuntimeError):
Expand Down Expand Up @@ -184,7 +185,7 @@ def run_nakdimon(text: str) -> str:
'Snopi': fetch_snopi, # Too slow
'Morfix': fetch_morfix, # terms-of-use issue
'Dicta': fetch_dicta,
'Nakdimon': make_nakdimon_no_server('models/Nakdimon.h5'),
'Nakdimon': make_nakdimon_no_server(MAIN_MODEL),
}
all_oov = set()

Expand Down
19 changes: 11 additions & 8 deletions nakdimon/predict.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import logging
import pathlib
from functools import lru_cache

import tensorflow as tf

from nakdimon import utils, dataset, hebrew

from nakdimon.config import MAIN_MODEL

if tf.config.set_visible_devices([], 'GPU'):
logging.warning('No GPU available.')


@lru_cache()
def load_cached_model(m):
def load_cached_model(m: pathlib.Path | str) -> tf.Module:
if isinstance(m, str):
return load_cached_model(pathlib.Path(m))
assert isinstance(m, pathlib.Path)
model = tf.keras.models.load_model(m, custom_objects={'loss': None})
return model

Expand All @@ -31,13 +35,12 @@ def merge_unconditional(texts, tnss, nss, dss, sss):
return res


def predict(text: str, model_or_model_path: tf.Module|str = 'models/Nakdimon.h5', maxlen=10000) -> str:
if isinstance(model_or_model_path, str):
model = load_cached_model(model_or_model_path)
elif isinstance(model_or_model_path, tf.Module):
model = model_or_model_path
else:
def predict(text: str, model_or_model_path: tf.Module | str = MAIN_MODEL, maxlen=10000) -> str:
if isinstance(model_or_model_path, (pathlib.Path, str)):
model_or_model_path = load_cached_model(model_or_model_path)
if not isinstance(model_or_model_path, tf.Module):
raise TypeError(f'Expected str or tf.Module, got {type(model_or_model_path)}')
model = model_or_model_path
data = dataset.Data.from_text(hebrew.iterate_dotted_text(text), maxlen)
prediction = model.predict(data.normalized)
[actual_niqqud, actual_dagesh, actual_sin] = [dataset.from_categorical(prediction[0]), dataset.from_categorical(prediction[1]), dataset.from_categorical(prediction[2])]
Expand Down
8 changes: 4 additions & 4 deletions nakdimon/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from nakdimon import hebrew
from nakdimon.train import TrainingParams
from nakdimon import metrics
from nakdimon.config import MODELS_DIR


pretrain_path = f'./models/wiki'
pretrain_path = f'{MODELS_DIR}/wiki'
model_name = pretrain_path + 'pretrain.h5'


Expand Down Expand Up @@ -138,7 +138,7 @@ def pretrain():
def train_ablation(params):
from train import train
model = train(params)
model.save(f'./models/ablations/{params.name}.h5')
model.save(f'./{MODELS_DIR}/ablations/{params.name}.h5')


if __name__ == '__main__':
Expand All @@ -152,7 +152,7 @@ def train_ablation(params):
import ablations
tf.config.set_visible_devices([], 'GPU')
model_name = 'PretrainedModernOnly'
model = tf.keras.models.load_model(f'models/ablations/{model_name}.h5',
model = tf.keras.models.load_model(f'{MODELS_DIR}/ablations/{model_name}.h5',
custom_objects={'loss': TrainingParams().loss})
print(model_name, *metrics.metricwise_mean(ablations.calculate_metrics(model)).values(), sep=', ')

5 changes: 3 additions & 2 deletions nakdimon/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from nakdimon import predict
from nakdimon.config import MAIN_MODEL

app = flask.Flask(__name__)

Expand All @@ -26,9 +27,9 @@ def diacritize():


def main():
logging.info("Loading models/Nakdimon.h5")
logging.info(f"Loading {MAIN_MODEL}")
try:
predict.predict("שלום", 'models/Nakdimon.h5')
predict.predict("שלום")
logging.info("Done loading.")
except OSError:
logging.warning("Could not load default model")
Expand Down
3 changes: 2 additions & 1 deletion nakdimon/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nakdimon.dataset import NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE, LETTERS_SIZE
from nakdimon import schedulers
from nakdimon import transformer
from nakdimon.config import MODELS_DIR

# assert tf.config.list_physical_devices('GPU')

Expand Down Expand Up @@ -270,7 +271,7 @@ def train(params: NakdimonParams, group, ablation=False, wandb_enabled=False):

def train_ablation(params, group):
model = train(params, group, ablation=True)
model.save(f'./models/ablations/{params.name}.h5')
model.save(f'./{MODELS_DIR}/ablations/{params.name}.h5')


class Full(NakdimonParams):
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "nakdimon"
version = "0.1.1"
version = "0.1.2"
authors = [
{ name="Elazar Gershuni", email="elazarg@gmail.com" },
]
Expand All @@ -28,6 +28,9 @@ include = ["nakdimon"]
exclude = [] # exclude packages matching these glob patterns (empty by default)
namespaces = false

[tool.setuptools.package-data]
nakdimon = ["Nakdimon.h5"]

[project.urls]
"Homepage" = "https://github.com/elazarg/nakdimon"
"Bug Tracker" = "https://github.com/elazarg/nakdimon/issues"
Expand Down

0 comments on commit 01c631b

Please sign in to comment.