From f9848f7b72f3996f1b9f3cb2211b227346a1d71b Mon Sep 17 00:00:00 2001
From: Ivan Shcheklein
Date: Wed, 10 Jul 2024 14:55:44 -0700
Subject: [PATCH 1/5] use bigger fixture tree for distributed tests

---
 tests/func/test_dataset_query.py | 13 ++++++++-----
 tests/utils.py                   | 12 ++++++++++++
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py
index 8111e1737..9b44583ed 100644
--- a/tests/func/test_dataset_query.py
+++ b/tests/func/test_dataset_query.py
@@ -51,6 +51,7 @@
 from tests.data import ENTRIES
 from tests.utils import (
     DEFAULT_TREE,
+    LARGE_TREE,
     NUM_TREE,
     SIMPLE_DS_QUERY_RECORDS,
     TARRED_TREE,
@@ -1012,8 +1013,8 @@ def name_len_interrupt(_name):


 @pytest.mark.parametrize(
-    "cloud_type,version_aware",
-    [("s3", True)],
+    "cloud_type,version_aware,tree",
+    [("s3", True, LARGE_TREE)],
     indirect=True,
 )
 @pytest.mark.parametrize("batch", [False, True])
@@ -1023,7 +1024,9 @@
     reason="Set the DATACHAIN_DISTRIBUTED environment variable "
     "to test distributed UDFs",
 )
-def test_udf_distributed(cloud_test_catalog_tmpfile, batch, workers, datachain_job_id):
+def test_udf_distributed(
+    cloud_test_catalog_tmpfile, batch, workers, tree, datachain_job_id
+):
     catalog = cloud_test_catalog_tmpfile.catalog
     sources = [cloud_test_catalog_tmpfile.src_uri]
     globs = [s.rstrip("/") + "/*" for s in sources]
@@ -1048,8 +1051,8 @@ def name_len_batch(names):

     q = (
         DatasetQuery(name="animals", version=1, catalog=catalog)
-        .filter(C.size < 13)
-        .filter(C.parent.glob("cats*") | (C.size < 4))
+        .filter(C.size < 90)
+        .filter(C.parent.glob("cats*") | (C.size > 30))
         .add_signals(udf_func, parallel=2, workers=workers)
         .select(C.name, C.name_len, C.blank)
     )
diff --git a/tests/utils.py b/tests/utils.py
index b127af55d..374d1f999 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -43,6 +43,18 @@ def make_index(catalog, src: str, entries, ttl: int = 1234):
         "others": {"dog4": "ruff"},
     },
 }
+
+# Needs to be big enough so that distributed mode has a decent number of tasks to run.
+# Has the same structure as DEFAULT_TREE - cats and dogs.
+LARGE_TREE: dict[str, Any] = {
+    "description": "Cats and Dogs",
+    "cats": {f"cat{i}": "a" * i for i in range(1, 128)},
+    "dogs": {
+        **{f"dogs{i}": "a" * i for i in range(1, 64)},
+        "others": {f"dogs{i}": "a" * i for i in range(64, 98)},
+    },
+}
+
 NUM_TREE = {f"{i:06d}": f"{i}" for i in range(1024)}

From f4b650d594f7742840db847e811d999f4bdb5b9c Mon Sep 17 00:00:00 2001
From: Ivan Shcheklein
Date: Wed, 10 Jul 2024 19:03:10 -0700
Subject: [PATCH 2/5] add run_datachain_worker fixture, fix other tests

---
 tests/func/test_dataset_query.py | 80 +++++++++++++++++++++++++++-----
 1 file changed, 69 insertions(+), 11 deletions(-)

diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py
index 9b44583ed..0bbaff16f 100644
--- a/tests/func/test_dataset_query.py
+++ b/tests/func/test_dataset_query.py
@@ -4,10 +4,13 @@
 import os
 import pickle
 import random
+import signal
+import subprocess  # nosec B404
 import uuid
 from datetime import datetime, timedelta, timezone
 from json import dumps
 from textwrap import dedent
+from time import sleep
 from unittest.mock import ANY, patch

 import numpy as np
@@ -63,6 +66,52 @@
     text_embedding,
 )

+WORKER_COUNT = 1
+WORKER_SHUTDOWN_WAIT_SEC = 30
+
+
+@pytest.fixture()
+def run_datachain_worker():
+    if not os.environ.get("DATACHAIN_DISTRIBUTED"):
+        pytest.skip("Distributed tests are disabled")
+    workers = []
+    worker_cmd = [
+        "celery",
+        "-A",
"clickhousedbadapter.distributed", + "worker", + "--loglevel=INFO", + "-P", + "solo", + "-Q", + "datachain-worker", + "-n", + "datachain-worker-tests", + ] + workers.append(subprocess.Popen(worker_cmd, shell=False)) # noqa: S603 + try: + from clickhousedbadapter.distributed import app + + inspect = app.control.inspect() + attempts = 0 + # Wait 10 seconds for the Celery worker(s) to be up + while not inspect.active() and attempts < 10: + sleep(1) + attempts += 1 + + if attempts == 10: + raise RuntimeError("Celery worker(s) did not start in time") + + yield workers + finally: + for worker in workers: + os.kill(worker.pid, signal.SIGTERM) + for worker in workers: + try: + worker.wait(timeout=WORKER_SHUTDOWN_WAIT_SEC) + except subprocess.TimeoutExpired: + os.kill(worker.pid, signal.SIGKILL) + def from_result_row(col_names, row): return dict(zip(col_names, row)) @@ -1025,7 +1074,12 @@ def name_len_interrupt(_name): "to test distributed UDFs", ) def test_udf_distributed( - cloud_test_catalog_tmpfile, batch, workers, tree, datachain_job_id + cloud_test_catalog_tmpfile, + batch, + workers, + tree, + datachain_job_id, + run_datachain_worker, ): catalog = cloud_test_catalog_tmpfile.catalog sources = [cloud_test_catalog_tmpfile.src_uri] @@ -1058,7 +1112,7 @@ def name_len_batch(names): ) result = q.results() - assert len(result) == 3 + assert len(result) == 148 string_default = String.default_value(catalog.warehouse.db.dialect) for r in result: # Check that the UDF ran successfully @@ -1067,8 +1121,8 @@ def name_len_batch(names): @pytest.mark.parametrize( - "cloud_type,version_aware", - [("s3", True)], + "cloud_type,version_aware,tree", + [("s3", True, LARGE_TREE)], indirect=True, ) @pytest.mark.parametrize("workers", (1, 2)) @@ -1078,7 +1132,7 @@ def name_len_batch(names): "to test distributed UDFs", ) def test_udf_distributed_exec_error( - cloud_test_catalog_tmpfile, workers, datachain_job_id + cloud_test_catalog_tmpfile, workers, datachain_job_id, tree, run_datachain_worker ): catalog = cloud_test_catalog_tmpfile.catalog sources = [cloud_test_catalog_tmpfile.src_uri] @@ -1102,8 +1156,8 @@ def name_len_error(_name): @pytest.mark.parametrize( - "cloud_type,version_aware", - [("s3", True)], + "cloud_type,version_aware,tree", + [("s3", True, LARGE_TREE)], indirect=True, ) @pytest.mark.skipif( @@ -1111,7 +1165,9 @@ def name_len_error(_name): reason="Set the DATACHAIN_DISTRIBUTED environment variable " "to test distributed UDFs", ) -def test_udf_distributed_interrupt(cloud_test_catalog_tmpfile, capfd, datachain_job_id): +def test_udf_distributed_interrupt( + cloud_test_catalog_tmpfile, capfd, datachain_job_id, tree, run_datachain_worker +): catalog = cloud_test_catalog_tmpfile.catalog sources = [cloud_test_catalog_tmpfile.src_uri] globs = [s.rstrip("/") + "/*" for s in sources] @@ -1137,8 +1193,8 @@ def name_len_interrupt(_name): @pytest.mark.parametrize( - "cloud_type,version_aware", - [("s3", True)], + "cloud_type,version_aware, tree", + [("s3", True, LARGE_TREE)], indirect=True, ) @pytest.mark.skipif( @@ -1146,7 +1202,9 @@ def name_len_interrupt(_name): reason="Set the DATACHAIN_DISTRIBUTED environment variable " "to test distributed UDFs", ) -def test_udf_distributed_cancel(cloud_test_catalog_tmpfile, capfd, datachain_job_id): +def test_udf_distributed_cancel( + cloud_test_catalog_tmpfile, capfd, datachain_job_id, tree, run_datachain_worker +): catalog = cloud_test_catalog_tmpfile.catalog metastore = catalog.metastore sources = [cloud_test_catalog_tmpfile.src_uri] From 
From 1fe1612907e3b3668b99518c8cf709eba4b1170b Mon Sep 17 00:00:00 2001
From: Ivan Shcheklein
Date: Thu, 11 Jul 2024 14:00:03 -0700
Subject: [PATCH 3/5] add catalog fixture to pass tests on Studio

---
 tests/unit/lib/test_datachain_bootstrap.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/lib/test_datachain_bootstrap.py b/tests/unit/lib/test_datachain_bootstrap.py
index 349f71634..ac309a996 100644
--- a/tests/unit/lib/test_datachain_bootstrap.py
+++ b/tests/unit/lib/test_datachain_bootstrap.py
@@ -69,7 +69,7 @@
     assert udf._had_teardown is False


-def test_bootstrap_in_chain():
+def test_bootstrap_in_chain(catalog):
     base = 1278
     prime = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

@@ -83,7 +83,7 @@
     assert res == [base + val for val in prime]


-def test_vars_duplication_error():
+def test_vars_duplication_error(catalog):
     with pytest.raises(DatasetPrepareError):
         (
             DataChain.from_features(val=[2, 3, 5, 7, 11, 13, 17, 19, 23, 29])

From 56b15b81bb1b1f299d0760cd1b9d7d141050da3d Mon Sep 17 00:00:00 2001
From: Ivan Shcheklein
Date: Fri, 12 Jul 2024 20:29:57 -0700
Subject: [PATCH 4/5] remove ImageFile for now

---
 src/datachain/__init__.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py
index 0242cdbdd..22db24eb8 100644
--- a/src/datachain/__init__.py
+++ b/src/datachain/__init__.py
@@ -2,7 +2,6 @@
 from datachain.lib.feature import Feature
 from datachain.lib.feature_utils import pydantic_to_feature
 from datachain.lib.file import File, FileError, FileFeature, IndexedFile, TarVFile
-from datachain.lib.image import ImageFile, convert_images
 from datachain.lib.text import convert_text
 from datachain.lib.udf import Aggregator, Generator, Mapper
 from datachain.lib.utils import AbstractUDF, DataChainError
@@ -23,12 +22,10 @@
     "FileError",
     "FileFeature",
     "Generator",
-    "ImageFile",
     "IndexedFile",
     "Mapper",
     "Session",
     "TarVFile",
-    "convert_images",
     "convert_text",
     "pydantic_to_feature",
 ]

From 773aef6fcab88b7005a6e35e0e2d8514dd477810 Mon Sep 17 00:00:00 2001
From: Ivan Shcheklein
Date: Mon, 15 Jul 2024 19:10:25 -0700
Subject: [PATCH 5/5] rename CH package

---
 tests/func/test_dataset_query.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py
index 0bbaff16f..1abc1b5b4 100644
--- a/tests/func/test_dataset_query.py
+++ b/tests/func/test_dataset_query.py
@@ -78,7 +78,7 @@ def run_datachain_worker():
     worker_cmd = [
         "celery",
         "-A",
-        "clickhousedbadapter.distributed",
+        "datachain_server.distributed",
         "worker",
         "--loglevel=INFO",
         "-P",
@@ -90,7 +90,7 @@ def run_datachain_worker():
     ]
     workers.append(subprocess.Popen(worker_cmd, shell=False))  # noqa: S603
     try:
-        from clickhousedbadapter.distributed import app
+        from datachain_server.distributed import app

         inspect = app.control.inspect()
         attempts = 0