Skip to content

Commit

Permalink
bump for 1.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
franneck94 committed Oct 29, 2023
1 parent 2335309 commit 1694744
Show file tree
Hide file tree
Showing 12 changed files with 40 additions and 47 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TensorCross

![Python](https://img.shields.io/badge/python-%203.7+-blue)
![Python](https://img.shields.io/badge/python-%203.9+-blue)
[![License](https://camo.githubusercontent.com/890acbdcb87868b382af9a4b1fac507b9659d9bf/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f6c6963656e73652d4d49542d626c75652e737667)](https://github.com/franneck94/TensorCross/blob/main/LICENSE)
[![Build](https://github.com/franneck94/TensorCross/workflows/Test/badge.svg)](https://github.com/franneck94/TensorCross/actions?query=workflow%3A%22Test+and+Coverage%22)
[![codecov](https://codecov.io/gh/franneck94/TensorCross/branch/main/graph/badge.svg)](https://codecov.io/gh/franneck94/TensorCross)
Expand All @@ -10,16 +10,16 @@
pip install tensorcross
```

Cross-Validation, Grid Search, and Random Search for tf.data.Datasets in TensorFlow 2.0+ and Python 3.7+.
Grid Search and Random Search with optionally CrossValidation for tf.data.Datasets in TensorFlow (Keras) 2.8+ and Python 3.9+.

## Motivation

Currently, there is the tf.keras.wrapper.KerasClassifier/KerasRegressor class,
which can be used to transform your tf.keras model into a sklearn estimator.
There was the tf.keras.wrapper.KerasClassifier/KerasRegressor class,
which can be used to transform your tf.keras model into a sklearn estimator.
However, this approach is only applicable if your dataset is a numpy.ndarray
for your x and y data.
for your x and y data and it was also removed from newer versions.
If you want to use the new tf.data.Dataset class, you cannot use the sklearn
wrappers.
wrappers.
This python package aims to help with this use-case.

## API
Expand Down
4 changes: 0 additions & 4 deletions examples/example_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
from tensorcross.utils import dataset_split


np.random.seed(0)
tf.random.set_seed(0)


def build_model(
optimizer: Optimizer,
learning_rate: float,
Expand Down
4 changes: 0 additions & 4 deletions examples/example_grid_search_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
from tensorcross.model_selection import GridSearchCV


np.random.seed(0)
tf.random.set_seed(0)


def build_model(
optimizer: Optimizer,
learning_rate: float,
Expand Down
4 changes: 0 additions & 4 deletions examples/example_random_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@
from tensorcross.utils import dataset_split


np.random.seed(0)
tf.random.set_seed(0)


def build_model(
optimizer: Optimizer,
learning_rate: float,
Expand Down
4 changes: 0 additions & 4 deletions examples/example_random_search_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
from tensorcross.model_selection import RandomSearchCV


np.random.seed(0)
tf.random.set_seed(0)


def build_model(
optimizer: Optimizer,
learning_rate: float,
Expand Down
15 changes: 7 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.black]
target-version = ['py310']
target-version = ['py39']
line-length = 80
skip-string-normalization = false
skip-magic-trailing-comma = false
Expand All @@ -11,7 +11,7 @@ extend-exclude = '''
'''

[tool.isort]
py_version = 310
py_version = 39
sections = [
"FUTURE",
"STDLIB",
Expand Down Expand Up @@ -39,12 +39,11 @@ known_third_party = [
"tqdm",
"cv2",
"skimage",
"tensorcross",
"tensorflow_datasets",
"scikeras"
]
known_first_party = []
known_local_folder = []
known_local_folder = ["tensorcross"]
# style: black
multi_line_output = 3
include_trailing_comma = true
Expand All @@ -62,7 +61,7 @@ skip_glob = [

[tool.mypy]
# Platform configuration
python_version = "3.10"
python_version = "3.8"
# imports related
ignore_missing_imports = true
follow_imports = "silent"
Expand Down Expand Up @@ -102,10 +101,10 @@ show_error_codes = true
exclude = ["examples"]

[tool.ruff]
target-version = "py310"
target-version = "py39"
select = ["F", "E"]
extend-select = ["W", "C90", "I", "N", "UP", "YTT", "ANN", "ASYNC", "B", "A", "C4", "DTZ", "FA", "ISC", "ICN", "PIE", "PYI", "PT", "RET", "SLOT", "SIM", "ARG", "PTH", "PD", "PLC", "PLE", "PLR", "PLW", "FLY", "NPY", "PERF", "RUF"]
ignore = ["I001", "ANN401", "SIM300", "PERF203", "ANN101", "B905"]
ignore = ["I001", "ANN401", "SIM300", "PERF203", "ANN101", "B905", "UP007", "UP006", "PTH", "UP035"]
fixable = ["F", "E", "W", "C90", "I", "N", "UP", "YTT", "ANN", "ASYNC", "B", "A", "C4", "DTZ", "FA", "ISC", "ICN", "PIE", "PYI", "PT", "RET", "SLOT", "SIM", "ARG", "PTH", "PD", "PLC", "PLE", "PLR", "PLW", "FLY", "NPY", "PERF", "RUF"]
unfixable = []
line-length = 80
Expand All @@ -131,7 +130,7 @@ skip-magic-trailing-comma = false
line-ending = "auto"

[tool.pyright]
pythonVersion = "3.10"
pythonVersion = "3.9"
typeCheckingMode = "basic"
# enable subset of "strict"
reportDuplicateImport = true
Expand Down
9 changes: 4 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

CLASSIFIERS = """\
License :: OSI Approved :: MIT License
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Topic :: Software Development
Operating System :: Microsoft :: Windows
Operating System :: POSIX
Expand All @@ -34,13 +33,13 @@
VERSION = __version__
ISRELEASED = False

MIN_PYTHON_VERSION = "3.7"
INSTALL_REQUIRES = ["keras", "numpy", "scipy", "scikit-learn"]
MIN_PYTHON_VERSION = "3.9"
INSTALL_REQUIRES = ["keras>=2.8", "numpy", "scipy", "scikit-learn>=1.0"]


PACKAGES = find_packages(include=["tensorcross", "tensorcross.*"])

metadata = dict(
metadata = dict( # noqa: C408
name=DISTNAME,
version=VERSION,
long_description=README,
Expand Down
2 changes: 2 additions & 0 deletions tensorcross/_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any
from typing import Dict
from typing import List
Expand Down
12 changes: 8 additions & 4 deletions tensorcross/model_selection/search.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import logging
import os
from abc import ABCMeta
from abc import abstractmethod
from collections.abc import Iterable
from collections.abc import Mapping
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -48,7 +50,7 @@ def __init__(
"params": [],
}

def _run_search(
def _run_search( # noqa: PLR0912
self,
train_dataset: tf.data.Dataset,
val_dataset: tf.data.Dataset,
Expand Down Expand Up @@ -136,7 +138,9 @@ def summary(self) -> str:
f"Best score: {self.results_['best_score']} "
f"using params: {self.results_['best_params']}"
)
dashed_line = "".join(map(lambda x: "-", best_params_str))
dashed_line = "".join(
map(lambda x: "-", best_params_str) # noqa: C417, ARG005
)

current_line = f"\n{dashed_line}\n{best_params_str}\n{dashed_line}"
results_str = current_line
Expand Down
14 changes: 9 additions & 5 deletions tensorcross/model_selection/search_cv.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import logging
import os
from abc import ABCMeta
from abc import abstractmethod
from collections.abc import Iterable
from collections.abc import Mapping
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(
"params": [],
}

def _run_search(
def _run_search( # noqa: PLR0912
self,
dataset: tf.data.Dataset,
parameter_obj: Union[ParameterGrid, ParameterSampler],
Expand Down Expand Up @@ -166,7 +168,9 @@ def summary(self) -> str:
f"Best score: {self.results_['best_score']} "
f"using params: {self.results_['best_params']}"
)
dashed_line = "".join(map(lambda x: "-", best_params_str))
dashed_line = "".join(
map(lambda x: "-", best_params_str) # noqa: C417, ARG005
)

current_line = f"\n{dashed_line}\n{best_params_str}\n{dashed_line}"
results_str = current_line
Expand Down Expand Up @@ -253,7 +257,7 @@ def fit(self, dataset: tf.data.Dataset, **kwargs: Any) -> None:


class RandomSearchCV(BaseSearchCV):
def __init__(
def __init__( # noqa: PLR0913
self,
model_fn: Callable[..., Model],
param_distributions: Dict[str, Callable],
Expand Down
5 changes: 3 additions & 2 deletions tensorcross/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Tuple

import tensorflow as tf
Expand Down Expand Up @@ -30,5 +32,4 @@ def dataset_split(
def dataset_join(
dataset_left: tf.data.Dataset, dataset_right: tf.data.Dataset
) -> tf.data.Dataset:
dataset_joined = dataset_left.concatenate(dataset_right)
return dataset_joined
return dataset_left.concatenate(dataset_right)
2 changes: 1 addition & 1 deletion tensorcross/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.4"
__version__ = "1.0.0"

0 comments on commit 1694744

Please sign in to comment.