From bf802eb20d7a4f04bf43c3f2ba85910ca773977f Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 13:26:48 -0700
Subject: [PATCH 01/10] move it back down

---
 lumen/ai/translate.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/lumen/ai/translate.py b/lumen/ai/translate.py
index 38a70411..ab7122cc 100644
--- a/lumen/ai/translate.py
+++ b/lumen/ai/translate.py
@@ -9,9 +9,9 @@
 
 import param
 
-from pydantic import BaseConfig, BaseModel, create_model
-from pydantic.color import Color
+from pydantic import BaseModel, ConfigDict, create_model
 from pydantic.fields import FieldInfo, PydanticUndefined
+from pydantic_extra_types.color import Color
 
 DATE_TYPE = datetime.datetime | datetime.date
 PARAM_TYPE_MAPPING: dict[param.Parameter, type] = {
@@ -36,8 +36,7 @@ class ArbitraryTypesModel(BaseModel):
     """
     A Pydantic model that allows arbitrary types.
     """
-    class Config(BaseConfig):
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
 def _create_literal(obj: list[str | type]) -> type:

From de756322a9970829ef9b3bfe8cd54cb4cdee65d2 Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 13:28:25 -0700
Subject: [PATCH 02/10] Add test

---
 lumen/tests/ai/test_utils.py | 304 +++++++++++++++++++++++++++++++++++
 1 file changed, 304 insertions(+)
 create mode 100644 lumen/tests/ai/test_utils.py

diff --git a/lumen/tests/ai/test_utils.py b/lumen/tests/ai/test_utils.py
new file mode 100644
index 00000000..4fa44786
--- /dev/null
+++ b/lumen/tests/ai/test_utils.py
@@ -0,0 +1,304 @@
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import jinja2
+import numpy as np
+import pandas as pd
+import pytest
+
+from panel.chat import ChatStep
+
+from lumen.ai.utils import (
+    UNRECOVERABLE_ERRORS, clean_sql, describe_data, format_schema, get_schema,
+    render_template, report_error, retry_llm_output,
+)
+
+
+def test_render_template_with_valid_template():
+    template_content = "Hello {{ name }}!"
+    with patch.object(Path, "read_text", return_value=template_content):
+        result = render_template("test_template.txt", name="World")
+        assert result == "Hello World!"
+
+
+def test_render_template_missing_key():
+    template_content = "Hello {{ name }}!"
+    with patch.object(Path, "read_text", return_value=template_content):
+        with pytest.raises(jinja2.exceptions.UndefinedError):
+            render_template("test_template.txt")
+
+
+class TestRetryLLMOutput:
+
+    @patch("time.sleep", return_value=None)
+    def test_success(self, mock_sleep):
+        @retry_llm_output(retries=2)
+        def mock_func(errors=None):
+            return "Success"
+
+        result = mock_func()
+        assert result == "Success"
+        assert mock_sleep.call_count == 0
+
+    @patch("time.sleep", return_value=None)
+    def test_failure(self, mock_sleep):
+        @retry_llm_output(retries=2)
+        def mock_func(errors=None):
+            if errors is not None:
+                assert errors == ["Failed"]
+            raise Exception("Failed")
+
+        with pytest.raises(Exception, match="Failed"):
+            mock_func()
+        assert mock_sleep.call_count == 1
+
+    @patch("time.sleep", return_value=None)
+    def test_failure_unrecoverable(self, mock_sleep):
+        @retry_llm_output(retries=2)
+        def mock_func(errors=None):
+            if errors is not None:
+                assert errors == ["Failed"]
+            raise unrecoverable_error("Failed")
+
+        unrecoverable_error = UNRECOVERABLE_ERRORS[0]
+        with pytest.raises(unrecoverable_error, match="Failed"):
+            mock_func(errors=["Failed"])
+        assert mock_sleep.call_count == 0
+
+    @pytest.mark.asyncio
+    @patch("asyncio.sleep", return_value=None)
+    async def test_async_success(self, mock_sleep):
+        @retry_llm_output(retries=2)
+        async def mock_func(errors=None):
+            return "Success"
+
+        result = await mock_func()
+        assert result == "Success"
+        assert mock_sleep.call_count == 0
+
+    @pytest.mark.asyncio
+    @patch("asyncio.sleep", return_value=None)
+    async def test_async_failure(self, mock_sleep):
+        @retry_llm_output(retries=2)
+        async def mock_func(errors=None):
+            if errors is not None:
+                assert errors == ["Failed"]
+            raise Exception("Failed")
+
+        with pytest.raises(Exception, match="Failed"):
+            await mock_func()
+        assert mock_sleep.call_count == 1
+
+    @pytest.mark.asyncio
+    @patch("asyncio.sleep", return_value=None)
+    async def test_async_failure_unrecoverable(self, mock_sleep):
+        @retry_llm_output(retries=2)
+        async def mock_func(errors=None):
+            if errors is not None:
+                assert errors == ["Failed"]
+            raise unrecoverable_error("Failed")
+
+        unrecoverable_error = UNRECOVERABLE_ERRORS[0]
+        with pytest.raises(unrecoverable_error, match="Failed"):
+            await mock_func(errors=["Failed"])
+        assert mock_sleep.call_count == 0
+
+
+def test_format_schema_with_enum():
+    schema = {
+        "field1": {"type": "string", "enum": ["a", "b", "c", "d", "e", "f"]},
+        "field2": {"type": "integer"},
+    }
+    expected = {
+        "field1": {"type": "str", "enum": ["a", "b", "c", "d", "e", "..."]},
+        "field2": {"type": "int"},
+    }
+    assert format_schema(schema) == expected
+
+
+def test_format_schema_no_enum():
+    schema = {
+        "field1": {"type": "boolean"},
+        "field2": {"type": "integer"},
+    }
+    expected = {
+        "field1": {"type": "bool"},
+        "field2": {"type": "int"},
+    }
+    assert format_schema(schema) == expected
+
+
+class TestGetSchema:
+
+    @pytest.mark.asyncio
+    async def test_get_schema_from_source(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {"field1": {"type": "integer"}}
+        schema = await get_schema(mock_source)
+        assert "field1" in schema
+        assert schema["field1"]["type"] == "int"
+
+    @pytest.mark.asyncio
+    async def test_min_max(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {
+            "field1": {
+                "type": "integer",
+                "inclusiveMinimum": 0,
+                "inclusiveMaximum": 100,
+            }
+        }
+        schema = await get_schema(mock_source, include_min_max=True)
+        assert "min" in schema["field1"]
+        assert "max" in schema["field1"]
+        assert schema["field1"]["min"] == 0
+        assert schema["field1"]["max"] == 100
+
+    @pytest.mark.asyncio
+    async def test_no_min_max(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {
+            "field1": {
+                "type": "integer",
+                "inclusiveMinimum": 0,
+                "inclusiveMaximum": 100,
+            }
+        }
+        schema = await get_schema(mock_source, include_min_max=False)
+        assert "min" not in schema["field1"]
+        assert "max" not in schema["field1"]
+        assert "inclusiveMinimum" not in schema["field1"]
+        assert "inclusiveMaximum" not in schema["field1"]
+
+    @pytest.mark.asyncio
+    async def test_enum(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {
+            "field1": {"type": "string", "enum": ["value1", "value2"]}
+        }
+        schema = await get_schema(mock_source, include_enum=False)
+        assert "enum" not in schema["field1"]
+        assert schema["field1"]["enum"] == ["value1", "value2"]
+
+    @pytest.mark.asyncio
+    async def test_count(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {
+            "field1": {"type": "integer"},
+            "count": 1000,
+        }
+        schema = await get_schema(mock_source, include_count=True)
+        assert "count" in schema["field1"]
+        assert schema["field1"]["count"] == 1000
+
+    @pytest.mark.asyncio
+    async def test_no_count(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {
+            "field1": {"type": "integer"},
+            "count": 1000,
+        }
+        schema = await get_schema(mock_source, include_count=False)
+        assert "count" not in schema
+
+    @pytest.mark.asyncio
+    async def test_table(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {"field1": {"type": "integer"}}
+        schema = await get_schema(mock_source, table="test_table")
+        mock_source.get_schema.assert_called_with("test_table", limit=100)
+        assert "field1" in schema
+
+    @pytest.mark.asyncio
+    async def test_custom_limit(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {"field1": {"type": "integer"}}
+        schema = await get_schema(mock_source, table="test_table", limit=50)
+        mock_source.get_schema.assert_called_with("test_table", limit=50)
+        assert "field1" in schema
+
+class TestDescribeData:
+
+    @pytest.mark.asyncio
+    async def test_describe_numeric_data(self):
+        df = pd.DataFrame({
+            "col1": np.arange(0, 100000),
+            "col2": np.arange(0, 100000)
+        })
+        result = await describe_data(df)
+        assert "col1" in result["stats"]
+        assert "col2" in result["stats"]
+        assert result["stats"]["col1"]["nulls"] == '0'
+        assert result["stats"]["col2"]["nulls"] == '0'
+
+    @pytest.mark.asyncio
+    async def test_describe_with_nulls(self):
+        df = pd.DataFrame({
+            "col1": np.arange(0, 100000),
+            "col2": np.arange(0, 100000)
+        })
+        df.loc[:5000, "col1"] = np.nan
+        df.loc[:5000, "col2"] = np.nan
+        result = await describe_data(df)
+        assert result["stats"]["col1"]["nulls"] != '0'
+        assert result["stats"]["col2"]["nulls"] != '0'
+
+    @pytest.mark.asyncio
+    async def test_describe_string_data(self):
+        df = pd.DataFrame({
+            "col1": ["apple", "banana", "cherry", "date", "elderberry"] * 2000,
+            "col2": ["a", "b", "c", "d", "e"] * 2000
+        })
+        result = await describe_data(df)
+        assert result["stats"]["col1"]["nunique"] == 5
+        assert result["stats"]["col2"]["lengths"]["max"] == 1
+        assert result["stats"]["col1"]["lengths"]["max"] == 10
+
+    @pytest.mark.asyncio
+    async def test_describe_datetime_data(self):
+        df = pd.DataFrame({
+            "col1": pd.date_range("2018-08-18", periods=10000),
+            "col2": pd.date_range("2018-08-18", periods=10000),
+        })
+        result = await describe_data(df)
+        assert "col1" in result["stats"]
+        assert "col2" in result["stats"]
+
+    @pytest.mark.asyncio
+    async def test_describe_large_data(self):
+        df = pd.DataFrame({
+            "col1": range(6000),
+            "col2": range(6000, 12000)
+        })
+        result = await describe_data(df)
+        assert result["summary"]["is_summarized"] is True
+        assert len(df.sample(5000)) == 5000  # Should summarize to 5000 rows
+
+    @pytest.mark.asyncio
+    async def test_describe_small_data(self):
+        df = pd.DataFrame({
+            "col1": [1, 2],
+            "col2": [3, 4]
+        })
+        result = await describe_data(df)
+        assert result.equals(df)
+
+
+def test_clean_sql_removes_backticks():
+    sql_expr = "```sql SELECT * FROM `table`; ```"
+    cleaned_sql = clean_sql(sql_expr)
+    assert cleaned_sql == 'SELECT * FROM "table"'
+
+
+def test_clean_sql_strips_whitespace_and_semicolons():
+    sql_expr = "SELECT * FROM table; "
+    cleaned_sql = clean_sql(sql_expr)
+    assert cleaned_sql == "SELECT * FROM table"
+
+
+def test_report_error():
+    step = ChatStep()
+    report_error(Exception("Test error"), step)
+    assert step.failed_title == "Test error"
+    assert step.status == "failed"
+    assert step.objects[0].object == "\n```python\nTest error\n```"

From 941e41ba1fcb0b156a1508ee263e1c3a1384010b Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 13:46:22 -0700
Subject: [PATCH 03/10] add deps and flags

---
 pixi.toml      | 8 ++++++++
 pyproject.toml | 2 ++
 2 files changed, 10 insertions(+)

diff --git a/pixi.toml b/pixi.toml
index c1249fad..2b1b3df0 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -64,6 +64,14 @@ toolz = "*"
 intake-sql = "*"
 python-duckdb = "*"
 sqlalchemy = "*"
+# ai
+datashader = "*"
+duckdb = "*"
+instructor = ">=1.4.3"
+nbformat = "*"
+openai = "*"
+pyarrow = "*"
+pydantic = ">=2.8.0"
 
 # [feature.ai.dependencies]
 # datashader = "*"
diff --git a/pyproject.toml b/pyproject.toml
index 2b8b718d..c720c542 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,6 +79,8 @@ addopts = [
   "--doctest-modules",
   "--doctest-ignore-import-errors",
 ]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
 minversion = "7"
 xfail_strict = true
 log_cli_level = "INFO"

From fb71ebc405bab414c4a4c44e52de3350165d47ed Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 13:54:40 -0700
Subject: [PATCH 04/10] missing dep

---
 pixi.toml      | 1 +
 pyproject.toml | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/pixi.toml b/pixi.toml
index 2b1b3df0..1ad8b4ed 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -50,6 +50,7 @@ pytest-cov = "*"
 pytest-github-actions-annotate-failures = "*"
 pytest-rerunfailures = "*"
 pytest-xdist = "*"
+pytest-asyncio = "*"
 
 [feature.test-core.tasks]
 test-unit = 'pytest lumen/tests -n logical --dist loadgroup'
diff --git a/pyproject.toml b/pyproject.toml
index c720c542..0bc5dac3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ Source = "https://github.com/holoviz/lumen"
 HoloViz = "https://holoviz.org/"
 
 [project.optional-dependencies]
-tests = ['pytest', 'pytest-rerunfailures']
+tests = ['pytest', 'pytest-rerunfailures', 'pytest-asyncio']
 sql = ['duckdb', 'intake-sql', 'sqlalchemy']
 ai = ['nbformat', 'duckdb', 'pyarrow', 'openai', 'instructor >=1.4.3', 'pydantic >=2.8.0', 'datashader']
 ai-local = ['lumen[ai]', 'huggingface_hub']

From e1bcbee07f8d0cafff59e2f7545b842067348759 Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 13:59:48 -0700
Subject: [PATCH 05/10] fix test

---
 lumen/tests/ai/test_utils.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/lumen/tests/ai/test_utils.py b/lumen/tests/ai/test_utils.py
index 4fa44786..7f94aac3 100644
--- a/lumen/tests/ai/test_utils.py
+++ b/lumen/tests/ai/test_utils.py
@@ -172,13 +172,22 @@ async def test_no_min_max(self):
 
     @pytest.mark.asyncio
     async def test_enum(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {
+            "field1": {"type": "string", "enum": ["value1", "value2"]}
+        }
+        schema = await get_schema(mock_source, include_enum=True)
+        assert "enum" in schema["field1"]
+        assert schema["field1"]["enum"] == ["value1", "value2"]
+
+    @pytest.mark.asyncio
+    async def test_no_enum(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {
             "field1": {"type": "string", "enum": ["value1", "value2"]}
         }
         schema = await get_schema(mock_source, include_enum=False)
         assert "enum" not in schema["field1"]
-        assert schema["field1"]["enum"] == ["value1", "value2"]
 
     @pytest.mark.asyncio
     async def test_count(self):

From d1a5a7743a629d9769bda2b7e183db33c38aedfe Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 14:21:34 -0700
Subject: [PATCH 06/10] fix test

---
 pixi.toml      | 1 +
 pyproject.toml | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/pixi.toml b/pixi.toml
index 1ad8b4ed..8d3f8f83 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -73,6 +73,7 @@ nbformat = "*"
 openai = "*"
 pyarrow = "*"
 pydantic = ">=2.8.0"
+pydantic_extra_types = "*"
 
 # [feature.ai.dependencies]
 # datashader = "*"
diff --git a/pyproject.toml b/pyproject.toml
index 0bc5dac3..17c0b6c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@ HoloViz = "https://holoviz.org/"
 [project.optional-dependencies]
 tests = ['pytest', 'pytest-rerunfailures', 'pytest-asyncio']
 sql = ['duckdb', 'intake-sql', 'sqlalchemy']
-ai = ['nbformat', 'duckdb', 'pyarrow', 'openai', 'instructor >=1.4.3', 'pydantic >=2.8.0', 'datashader']
+ai = ['nbformat', 'duckdb', 'pyarrow', 'openai', 'instructor >=1.4.3', 'pydantic >=2.8.0', 'datashader', 'pydantic-extra-types']
 ai-local = ['lumen[ai]', 'huggingface_hub']
 ai-llama = ['lumen[ai-local]', 'llama-cpp-python']

From 2819471c442efa095fb37624089c1f2c9a3b6700 Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 14:31:50 -0700
Subject: [PATCH 07/10] push to pypi

---
 pixi.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pixi.toml b/pixi.toml
index 8d3f8f83..a3c48d6d 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -73,6 +73,8 @@ nbformat = "*"
 openai = "*"
 pyarrow = "*"
 pydantic = ">=2.8.0"
+
+[feature.test.pypi-dependencies]
 pydantic_extra_types = "*"
 
 # [feature.ai.dependencies]
 # datashader = "*"
 # duckdb = "*"
 # instructor = ">=1.4.3"
 # nbformat = "*"
 # openai = "*"
 # pyarrow = "*"
 # pydantic = ">=2.8.0"
 #
+
 # [feature.ai-local.dependencies]
 # huggingface_hub = "*"
 #

From 61e55a10537c62cfa59f9954b95a29678f8dff2a Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 22 Oct 2024 14:37:06 -0700
Subject: [PATCH 08/10] skip entire module

---
 lumen/tests/ai/test_utils.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/lumen/tests/ai/test_utils.py b/lumen/tests/ai/test_utils.py
index 7f94aac3..84a4f779 100644
--- a/lumen/tests/ai/test_utils.py
+++ b/lumen/tests/ai/test_utils.py
@@ -8,10 +8,13 @@
 
 from panel.chat import ChatStep
 
-from lumen.ai.utils import (
-    UNRECOVERABLE_ERRORS, clean_sql, describe_data, format_schema, get_schema,
-    render_template, report_error, retry_llm_output,
-)
+try:
+    from lumen.ai.utils import (
+        UNRECOVERABLE_ERRORS, clean_sql, describe_data, format_schema,
+        get_schema, render_template, report_error, retry_llm_output,
+    )
+except ImportError:
+    pytest.skip("Skipping tests that require lumen.ai", allow_module_level=True)
 
 
 def test_render_template_with_valid_template():

From 0aa145ecf5f75c0f68b30acb81026fd3256556b8 Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Mon, 28 Oct 2024 10:57:06 -0700
Subject: [PATCH 09/10] rm pytest asyncio

---
 lumen/tests/ai/test_utils.py | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/lumen/tests/ai/test_utils.py b/lumen/tests/ai/test_utils.py
index 84a4f779..a3b5133b 100644
--- a/lumen/tests/ai/test_utils.py
+++ b/lumen/tests/ai/test_utils.py
@@ -68,7 +68,6 @@ def mock_func(errors=None):
             mock_func(errors=["Failed"])
         assert mock_sleep.call_count == 0
 
-    @pytest.mark.asyncio
     @patch("asyncio.sleep", return_value=None)
     async def test_async_success(self, mock_sleep):
         @retry_llm_output(retries=2)
@@ -79,7 +78,6 @@ async def mock_func(errors=None):
         assert result == "Success"
         assert mock_sleep.call_count == 0
 
-    @pytest.mark.asyncio
     @patch("asyncio.sleep", return_value=None)
     async def test_async_failure(self, mock_sleep):
         @retry_llm_output(retries=2)
@@ -92,7 +90,6 @@ async def mock_func(errors=None):
             await mock_func()
         assert mock_sleep.call_count == 1
 
-    @pytest.mark.asyncio
     @patch("asyncio.sleep", return_value=None)
     async def test_async_failure_unrecoverable(self, mock_sleep):
         @retry_llm_output(retries=2)
@@ -133,7 +130,6 @@ def test_format_schema_no_enum():
 
 class TestGetSchema:
 
-    @pytest.mark.asyncio
     async def test_get_schema_from_source(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {"field1": {"type": "integer"}}
@@ -141,7 +137,6 @@ async def test_get_schema_from_source(self):
         assert "field1" in schema
         assert schema["field1"]["type"] == "int"
 
-    @pytest.mark.asyncio
     async def test_min_max(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {
@@ -157,7 +152,6 @@ async def test_min_max(self):
         assert schema["field1"]["min"] == 0
         assert schema["field1"]["max"] == 100
 
-    @pytest.mark.asyncio
     async def test_no_min_max(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {
@@ -173,7 +167,6 @@ async def test_no_min_max(self):
         assert "inclusiveMinimum" not in schema["field1"]
         assert "inclusiveMaximum" not in schema["field1"]
 
-    @pytest.mark.asyncio
     async def test_enum(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {
@@ -183,7 +176,6 @@ async def test_enum(self):
         assert "enum" in schema["field1"]
         assert schema["field1"]["enum"] == ["value1", "value2"]
 
-    @pytest.mark.asyncio
     async def test_no_enum(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {
@@ -192,7 +184,6 @@ async def test_no_enum(self):
         schema = await get_schema(mock_source, include_enum=False)
         assert "enum" not in schema["field1"]
 
-    @pytest.mark.asyncio
     async def test_count(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {
@@ -203,7 +194,6 @@ async def test_count(self):
         assert "count" in schema["field1"]
         assert schema["field1"]["count"] == 1000
 
-    @pytest.mark.asyncio
     async def test_no_count(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {
@@ -213,7 +203,6 @@ async def test_no_count(self):
         schema = await get_schema(mock_source, include_count=False)
         assert "count" not in schema
 
-    @pytest.mark.asyncio
     async def test_table(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {"field1": {"type": "integer"}}
@@ -221,7 +210,6 @@ async def test_table(self):
         mock_source.get_schema.assert_called_with("test_table", limit=100)
         assert "field1" in schema
 
-    @pytest.mark.asyncio
     async def test_custom_limit(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {"field1": {"type": "integer"}}
@@ -231,7 +219,6 @@ async def test_custom_limit(self):
 
 class TestDescribeData:
 
-    @pytest.mark.asyncio
     async def test_describe_numeric_data(self):
         df = pd.DataFrame({
             "col1": np.arange(0, 100000),
@@ -243,7 +230,6 @@ async def test_describe_numeric_data(self):
         assert result["stats"]["col1"]["nulls"] == '0'
         assert result["stats"]["col2"]["nulls"] == '0'
 
-    @pytest.mark.asyncio
     async def test_describe_with_nulls(self):
         df = pd.DataFrame({
             "col1": np.arange(0, 100000),
@@ -255,7 +241,6 @@ async def test_describe_with_nulls(self):
         assert result["stats"]["col1"]["nulls"] != '0'
         assert result["stats"]["col2"]["nulls"] != '0'
 
-    @pytest.mark.asyncio
     async def test_describe_string_data(self):
         df = pd.DataFrame({
             "col1": ["apple", "banana", "cherry", "date", "elderberry"] * 2000,
@@ -266,7 +251,6 @@ async def test_describe_string_data(self):
         assert result["stats"]["col2"]["lengths"]["max"] == 1
         assert result["stats"]["col1"]["lengths"]["max"] == 10
 
-    @pytest.mark.asyncio
     async def test_describe_datetime_data(self):
         df = pd.DataFrame({
             "col1": pd.date_range("2018-08-18", periods=10000),
@@ -276,7 +260,6 @@ async def test_describe_datetime_data(self):
         assert "col1" in result["stats"]
         assert "col2" in result["stats"]
 
-    @pytest.mark.asyncio
     async def test_describe_large_data(self):
         df = pd.DataFrame({
             "col1": range(6000),
@@ -286,7 +269,6 @@ async def test_describe_large_data(self):
         assert result["summary"]["is_summarized"] is True
         assert len(df.sample(5000)) == 5000  # Should summarize to 5000 rows
 
-    @pytest.mark.asyncio
     async def test_describe_small_data(self):
         df = pd.DataFrame({
             "col1": [1, 2],

From 5fd9c0ebb4ae39c5012e5a33254cf82d6cb3b0ae Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Mon, 28 Oct 2024 11:08:14 -0700
Subject: [PATCH 10/10] fix tests

---
 lumen/ai/utils.py            | 2 +-
 lumen/tests/ai/test_utils.py | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/lumen/ai/utils.py b/lumen/ai/utils.py
index 5c9c53c5..49c78a8b 100644
--- a/lumen/ai/utils.py
+++ b/lumen/ai/utils.py
@@ -146,7 +146,7 @@ async def get_schema(
            continue
        elif not include_enum:
            spec.pop("enum")
-       elif "limit" in get_kwargs:
+       elif "limit" in get_kwargs and len(spec["enum"]) > get_kwargs["limit"]:
            spec["enum"].append("...")
 
    if count and include_count:
diff --git a/lumen/tests/ai/test_utils.py b/lumen/tests/ai/test_utils.py
index a3b5133b..459aaa69 100644
--- a/lumen/tests/ai/test_utils.py
+++ b/lumen/tests/ai/test_utils.py
@@ -184,6 +184,15 @@ async def test_no_enum(self):
         schema = await get_schema(mock_source, include_enum=False)
         assert "enum" not in schema["field1"]
 
+    async def test_enum_limit(self):
+        mock_source = MagicMock()
+        mock_source.get_schema.return_value = {
+            "field1": {"type": "string", "enum": ["value1", "value2", "value3"]}
+        }
+        schema = await get_schema(mock_source, include_enum=True, limit=2)
+        assert "enum" in schema["field1"]
+        assert "..." in schema["field1"]["enum"]
+
     async def test_count(self):
         mock_source = MagicMock()
         mock_source.get_schema.return_value = {