Skip to content

Commit ae493e7

Browse files
authored
avoid splitting script into feature files/scripts (#385)
1 parent a0ce094 commit ae493e7

File tree

3 files changed

+13
-267
lines changed

3 files changed

+13
-267
lines changed

src/datachain/catalog/catalog.py

Lines changed: 13 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import posixpath
1010
import subprocess
1111
import sys
12-
import tempfile
1312
import time
1413
import traceback
1514
from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -77,7 +76,6 @@
7776
)
7877

7978
from .datasource import DataSource
80-
from .subclass import SubclassFinder
8179

8280
if TYPE_CHECKING:
8381
from datachain.data_storage import (
@@ -92,7 +90,6 @@
9290

9391
DEFAULT_DATASET_DIR = "dataset"
9492
DATASET_FILE_SUFFIX = ".edatachain"
95-
FEATURE_CLASSES = ["DataModel"]
9693

9794
TTL_INT = 4 * 60 * 60
9895

@@ -569,12 +566,6 @@ def find_column_to_str( # noqa: PLR0911
569566
return ""
570567

571568

572-
def form_module_source(source_ast):
573-
module = ast.Module(body=source_ast, type_ignores=[])
574-
module = ast.fix_missing_locations(module)
575-
return ast.unparse(module)
576-
577-
578569
class Catalog:
579570
def __init__(
580571
self,
@@ -660,29 +651,10 @@ def attach_query_wrapper(self, code_ast):
660651
code_ast.body[-1:] = new_expressions
661652
return code_ast
662653

663-
def compile_query_script(
664-
self, script: str, feature_module_name: str
665-
) -> tuple[Union[str, None], str]:
654+
def compile_query_script(self, script: str) -> str:
666655
code_ast = ast.parse(script)
667656
code_ast = self.attach_query_wrapper(code_ast)
668-
finder = SubclassFinder(FEATURE_CLASSES)
669-
finder.visit(code_ast)
670-
671-
if not finder.feature_class:
672-
main_module = form_module_source([*finder.imports, *finder.main_body])
673-
return None, main_module
674-
675-
feature_import = ast.ImportFrom(
676-
module=feature_module_name,
677-
names=[ast.alias(name="*", asname=None)],
678-
level=0,
679-
)
680-
feature_module = form_module_source([*finder.imports, *finder.feature_class])
681-
main_module = form_module_source(
682-
[*finder.imports, feature_import, *finder.main_body]
683-
)
684-
685-
return feature_module, main_module
657+
return ast.unparse(code_ast)
686658

687659
def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]:
688660
config = config or self.client_config
@@ -1863,11 +1835,6 @@ def query(
18631835
C.size > 1000
18641836
)
18651837
"""
1866-
feature_file = tempfile.NamedTemporaryFile( # noqa: SIM115
1867-
dir=os.getcwd(), suffix=".py", delete=False
1868-
)
1869-
_, feature_module = os.path.split(feature_file.name)
1870-
18711838
if not job_id:
18721839
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
18731840
job_id = self.metastore.create_job(
@@ -1877,23 +1844,16 @@ def query(
18771844
python_version=python_version,
18781845
)
18791846

1880-
try:
1881-
lines, proc = self.run_query(
1882-
python_executable or sys.executable,
1883-
query_script,
1884-
envs,
1885-
feature_file,
1886-
capture_output,
1887-
feature_module,
1888-
output_hook,
1889-
params,
1890-
save,
1891-
job_id,
1892-
)
1893-
finally:
1894-
feature_file.close()
1895-
os.unlink(feature_file.name)
1896-
1847+
lines, proc = self.run_query(
1848+
python_executable or sys.executable,
1849+
query_script,
1850+
envs,
1851+
capture_output,
1852+
output_hook,
1853+
params,
1854+
save,
1855+
job_id,
1856+
)
18971857
output = "".join(lines)
18981858

18991859
if proc.returncode:
@@ -1947,31 +1907,19 @@ def run_query(
19471907
python_executable: str,
19481908
query_script: str,
19491909
envs: Optional[Mapping[str, str]],
1950-
feature_file: IO[bytes],
19511910
capture_output: bool,
1952-
feature_module: str,
19531911
output_hook: Callable[[str], None],
19541912
params: Optional[dict[str, str]],
19551913
save: bool,
19561914
job_id: Optional[str],
19571915
) -> tuple[list[str], subprocess.Popen]:
19581916
try:
1959-
feature_code, query_script_compiled = self.compile_query_script(
1960-
query_script, feature_module[:-3]
1961-
)
1962-
if feature_code:
1963-
feature_file.write(feature_code.encode())
1964-
feature_file.flush()
1965-
1917+
query_script_compiled = self.compile_query_script(query_script)
19661918
except Exception as exc:
19671919
raise QueryScriptCompileError(
19681920
f"Query script failed to compile, reason: {exc}"
19691921
) from exc
19701922
envs = dict(envs or os.environ)
1971-
if feature_code:
1972-
envs["DATACHAIN_FEATURE_CLASS_SOURCE"] = json.dumps(
1973-
{feature_module: feature_code}
1974-
)
19751923
envs.update(
19761924
{
19771925
"DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),

src/datachain/catalog/subclass.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

tests/unit/test_catalog.py

Lines changed: 0 additions & 142 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
from textwrap import dedent
21
from typing import TYPE_CHECKING
32

43
from datachain.catalog import Catalog
@@ -7,147 +6,6 @@
76
from datachain.data_storage import AbstractWarehouse
87

98

10-
def test_compile_query_script_no_feature_class(catalog):
11-
script = dedent(
12-
"""
13-
from datachain.query import C, DatasetQuery, asUDF
14-
DatasetQuery("s3://bkt/dir1")
15-
"""
16-
).strip()
17-
feature, result = catalog.compile_query_script(script, "tmpfeature")
18-
expected = dedent(
19-
"""
20-
from datachain.query import C, DatasetQuery, asUDF
21-
import datachain.query.dataset
22-
datachain.query.dataset.query_wrapper(
23-
DatasetQuery('s3://bkt/dir1'))
24-
"""
25-
).strip()
26-
assert feature is None
27-
assert result == expected
28-
29-
30-
def test_compile_query_script_with_feature_class(catalog):
31-
script = dedent(
32-
"""
33-
from datachain.query import C, DatasetQuery, asUDF
34-
from datachain.lib.data_model import DataModel as FromAlias
35-
from datachain.lib.data_model import DataModel
36-
import datachain.lib.data_model.DataModel as DirectImportedFeature
37-
import datachain
38-
39-
class NormalClass:
40-
t = 1
41-
42-
class SFClass(FromAlias):
43-
emb: float
44-
45-
class DirectImport(DirectImportedFeature):
46-
emb: float
47-
48-
class FullImport(datachain.lib.data_model.DataModel):
49-
emb: float
50-
51-
class Embedding(DataModel):
52-
emb: float
53-
54-
DatasetQuery("s3://bkt/dir1")
55-
"""
56-
).strip()
57-
feature, result = catalog.compile_query_script(script, "tmpfeature")
58-
expected_feature = dedent(
59-
"""
60-
from datachain.query import C, DatasetQuery, asUDF
61-
from datachain.lib.data_model import DataModel as FromAlias
62-
from datachain.lib.data_model import DataModel
63-
import datachain.lib.data_model.DataModel as DirectImportedFeature
64-
import datachain
65-
import datachain.query.dataset
66-
67-
class SFClass(FromAlias):
68-
emb: float
69-
70-
class DirectImport(DirectImportedFeature):
71-
emb: float
72-
73-
class FullImport(datachain.lib.data_model.DataModel):
74-
emb: float
75-
76-
class Embedding(DataModel):
77-
emb: float
78-
"""
79-
).strip()
80-
expected_result = dedent(
81-
"""
82-
from datachain.query import C, DatasetQuery, asUDF
83-
from datachain.lib.data_model import DataModel as FromAlias
84-
from datachain.lib.data_model import DataModel
85-
import datachain.lib.data_model.DataModel as DirectImportedFeature
86-
import datachain
87-
import datachain.query.dataset
88-
from tmpfeature import *
89-
90-
class NormalClass:
91-
t = 1
92-
datachain.query.dataset.query_wrapper(
93-
DatasetQuery('s3://bkt/dir1'))
94-
"""
95-
).strip()
96-
97-
assert feature == expected_feature
98-
assert result == expected_result
99-
100-
101-
def test_compile_query_script_with_decorator(catalog):
102-
script = dedent(
103-
"""
104-
import os
105-
from datachain.query import C, DatasetQuery, udf
106-
from datachain.sql.types import Float, Float32, Int, String, Binary
107-
108-
@udf(
109-
params=("name", ),
110-
output={"num": Float, "bin": Binary}
111-
)
112-
def my_func1(name):
113-
x = 3.14
114-
int_example = 25
115-
bin = int_example.to_bytes(2, "big")
116-
return (x, bin)
117-
118-
print("Test ENV = ", os.environ['TEST_ENV'])
119-
ds = DatasetQuery("s3://dql-small/*.jpg") \
120-
.add_signals(my_func1)
121-
122-
ds
123-
"""
124-
).strip()
125-
feature, result = catalog.compile_query_script(script, "tmpfeature")
126-
127-
expected_result = dedent(
128-
"""
129-
import os
130-
from datachain.query import C, DatasetQuery, udf
131-
from datachain.sql.types import Float, Float32, Int, String, Binary
132-
import datachain.query.dataset
133-
134-
@udf(params=('name',), output={'num': Float, 'bin': Binary})
135-
def my_func1(name):
136-
x = 3.14
137-
int_example = 25
138-
bin = int_example.to_bytes(2, 'big')
139-
return (x, bin)
140-
print('Test ENV = ', os.environ['TEST_ENV'])
141-
ds = DatasetQuery('s3://dql-small/*.jpg').add_signals(my_func1)
142-
datachain.query.dataset.query_wrapper(
143-
ds)
144-
"""
145-
).strip()
146-
147-
assert feature is None
148-
assert result == expected_result
149-
150-
1519
def test_catalog_warehouse_ready_callback(mocker, warehouse, id_generator, metastore):
15210
spy = mocker.spy(warehouse, "is_ready")
15311

0 commit comments

Comments
 (0)