Skip to content

Commit

Permalink
* Updated TMCoalesced Test
Browse files Browse the repository at this point in the history
* Update version for all mnist datsets.
  • Loading branch information
perara committed Mar 25, 2024
1 parent e859102 commit 0645987
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion test/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class TMCoalescedClassifierTests(unittest.TestCase, ClassifierTests):
def setUp(self) -> None:
from tmu.models.classification.coalesced_classifier import TMCoalescedClassifier
self.model = TMCoalescedClassifier(
number_of_clauses=4,
number_of_clauses=10,
T=10,
s=10.0,
max_included_literals=32,
Expand Down
12 changes: 8 additions & 4 deletions tmu/data/fashion_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ def _transform(self, name, dataset):
class KuzushijiMNIST(TMUDataset):
def _retrieve_dataset(self) -> Dict[str, np.ndarray]:
kwargs = dict()
pyver = tuple([int(x) for x in sklearn.__version__.split(".")])
# Parse the sklearn version string
sklearn_version = parse_version(sklearn.__version__)

if pyver[0] >= 1 and pyver[1] >= 2:
# Check if the major version is >= 1 and the minor version is >= 2
if sklearn_version >= parse_version("1.2"):
kwargs["parser"] = "pandas"

X, y = fetch_openml(
Expand Down Expand Up @@ -94,9 +96,11 @@ def __init__(self):

def _retrieve_dataset(self) -> Dict[str, np.ndarray]:
kwargs = dict()
pyver = tuple([int(x) for x in sklearn.__version__.split(".")])
# Parse the sklearn version string
sklearn_version = parse_version(sklearn.__version__)

if pyver[0] >= 1 and pyver[1] >= 2:
# Check if the major version is >= 1 and the minor version is >= 2
if sklearn_version >= parse_version("1.2"):
kwargs["parser"] = "pandas"

X, y = fetch_openml(
Expand Down

0 comments on commit 0645987

Please sign in to comment.