Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions duckdb/experimental/spark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from duckdb import ColumnExpression, Expression, StarExpression

from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError
from ..exception import ContributionsAcceptedError
from .column import Column
from .readwriter import DataFrameWriter
from .type_utils import duckdb_to_spark_schema
Expand Down Expand Up @@ -569,6 +568,22 @@ def columns(self) -> list[str]:
"""
return [f.name for f in self.schema.fields]

@property
def dtypes(self) -> list[tuple[str, str]]:
    """Return every column paired with the string name of its data type.

    Returns
    -------
    list of tuple
        One ``(column_name, type_string)`` pair per schema field, in
        schema order.

    Examples
    --------
    >>> df.dtypes
    [('age', 'bigint'), ('name', 'string')]
    """
    pairs: list[tuple[str, str]] = []
    for field in self.schema.fields:
        pairs.append((field.name, field.dataType.simpleString()))
    return pairs

def _ipython_key_completions_(self) -> list[str]:
# Provides tab-completion for column names in PySpark DataFrame
# when accessed in bracket notation, e.g. df['<TAB>]
Expand Down Expand Up @@ -982,8 +997,27 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
def write(self) -> DataFrameWriter: # noqa: D102
    # Entry point for persisting this DataFrame; all save options live on
    # DataFrameWriter. NOTE(review): upstream PySpark exposes this as a
    # property — confirm whether a decorator exists above this line.
    return DataFrameWriter(self)

def printSchema(self) -> None: # noqa: D102
raise ContributionsAcceptedError
def printSchema(self, level: Optional[int] = None) -> None:
    """Print the schema of this DataFrame in tree format.

    Parameters
    ----------
    level : int, optional
        How many levels to print for nested schemas. Prints all levels
        by default.

        NOTE(review): ``level=0`` is forwarded to ``treeString`` as-is,
        printing only the ``root`` line — confirm this matches upstream
        PySpark, which treats a falsy level as "print everything".

    Raises
    ------
    PySparkValueError
        If ``level`` is negative.

    Examples
    --------
    >>> df.printSchema()
    root
    |-- age: bigint (nullable = true)
    |-- name: string (nullable = true)
    """
    is_negative = level is not None and level < 0
    if is_negative:
        raise PySparkValueError(
            error_class="NEGATIVE_VALUE",
            message_parameters={"arg_name": "level", "arg_value": str(level)},
        )
    print(self.schema.treeString(level))

def union(self, other: "DataFrame") -> "DataFrame":
"""Return a new :class:`DataFrame` containing union of rows in this and another
Expand Down
71 changes: 71 additions & 0 deletions duckdb/experimental/spark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,77 @@ def fieldNames(self) -> list[str]:
"""
return list(self.names)

def treeString(self, level: Optional[int] = None) -> str:
    """Returns a string representation of the schema in tree format.

    Parameters
    ----------
    level : int, optional
        Maximum depth to print. If None, prints all levels.

    Returns
    -------
    str
        Tree-formatted schema string (no trailing newline).

    Examples
    --------
    >>> schema = StructType([StructField("age", IntegerType(), True)])
    >>> print(schema.treeString())
    root
    |-- age: integer (nullable = true)
    """

    def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]:
        """Recursively build tree string lines for one nesting level."""
        lines = []
        # The "root" header is emitted exactly once, at the outermost call.
        if depth == 0:
            lines.append("root")

        # Depth cut-off: fields at or below max_depth are omitted entirely.
        if max_depth is not None and depth >= max_depth:
            return lines

        for field in schema.fields:
            # One extra leading space per nesting level, before the " |-- " marker.
            indent = " " * depth
            prefix = " |-- "
            nullable_str = "true" if field.nullable else "false"

            # Handle nested StructType
            if isinstance(field.dataType, StructType):
                lines.append(f"{indent}{prefix}{field.name}: struct (nullable = {nullable_str})")
                # Recursively handle nested struct - don't skip any lines, root only appears at depth 0
                nested_lines = _tree_string(field.dataType, depth + 1, max_depth)
                lines.extend(nested_lines)
            # Handle ArrayType
            elif isinstance(field.dataType, ArrayType):
                element_type = field.dataType.elementType
                if isinstance(element_type, StructType):
                    # Array of structs: emit a synthetic "element" node, then the
                    # struct's own fields two levels deeper (array + element).
                    lines.append(f"{indent}{prefix}{field.name}: array (nullable = {nullable_str})")
                    lines.append(
                        f"{indent} | |-- element: struct (containsNull = {field.dataType.containsNull})"
                    )
                    nested_lines = _tree_string(element_type, depth + 2, max_depth)
                    lines.extend(nested_lines)
                else:
                    # Array of non-struct elements is rendered inline as array<type>.
                    type_str = element_type.simpleString()
                    lines.append(f"{indent}{prefix}{field.name}: array<{type_str}> (nullable = {nullable_str})")
            # Handle MapType
            elif isinstance(field.dataType, MapType):
                # Maps are rendered inline; key/value types are not expanded,
                # even when the value type is itself a struct.
                key_type = field.dataType.keyType.simpleString()
                value_type = field.dataType.valueType.simpleString()
                lines.append(
                    f"{indent}{prefix}{field.name}: map<{key_type},{value_type}> (nullable = {nullable_str})"
                )
            # Handle simple types
            else:
                type_str = field.dataType.simpleString()
                lines.append(f"{indent}{prefix}{field.name}: {type_str} (nullable = {nullable_str})")

        return lines

    lines = _tree_string(self, 0, level)
    return "\n".join(lines)

def needConversion(self) -> bool: # noqa: D102
    # Struct values arrive as Row()/namedtuple and must be converted to plain tuple().
    return True
Expand Down
170 changes: 170 additions & 0 deletions tests/fast/spark/test_spark_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,173 @@ def test_cache(self, spark):
assert df is not cached
assert cached.collect() == df.collect()
assert cached.collect() == [Row(one=1, two=2, three=3, four=4)]

def test_dtypes(self, spark):
    # dtypes must expose (name, type) string pairs, in column order.
    rows = [("Alice", 25, 5000.0), ("Bob", 30, 6000.0)]
    frame = spark.createDataFrame(rows, ["name", "age", "salary"])
    result = frame.dtypes

    assert isinstance(result, list)
    assert len(result) == 3
    for column, type_str in result:
        assert isinstance(column, str)
        assert isinstance(type_str, str)

    assert [column for column, _ in result] == ["name", "age", "salary"]
    # Every reported type name is non-empty.
    for _, type_str in result:
        assert len(type_str) > 0

def test_dtypes_complex_types(self, spark):
    from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

    # Columns with array and nested-struct types should still be reported.
    address_type = StructType(
        [StructField("city", StringType(), True), StructField("zip", StringType(), True)]
    )
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("scores", ArrayType(IntegerType()), True),
            StructField("address", address_type, True),
        ]
    )
    rows = [
        ("Alice", [90, 85, 88], {"city": "NYC", "zip": "10001"}),
        ("Bob", [75, 80, 82], {"city": "LA", "zip": "90001"}),
    ]
    result = spark.createDataFrame(rows, schema).dtypes

    assert len(result) == 3
    assert [column for column, _ in result] == ["name", "scores", "address"]

def test_printSchema(self, spark, capsys):
    # printSchema should write a tree containing every column name and type.
    frame = spark.createDataFrame(
        [("Alice", 25, 5000), ("Bob", 30, 6000)], ["name", "age", "salary"]
    )
    frame.printSchema()
    printed = capsys.readouterr().out

    for expected in ("root", "name", "age", "salary"):
        assert expected in printed
    assert "string" in printed or "varchar" in printed.lower()
    assert "int" in printed.lower() or "bigint" in printed.lower()

def test_printSchema_nested(self, spark, capsys):
    from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

    # Nested struct and array columns must appear in the printed tree.
    person_type = StructType(
        [StructField("name", StringType(), True), StructField("age", IntegerType(), True)]
    )
    schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField("person", person_type, True),
            StructField("hobbies", ArrayType(StringType()), True),
        ]
    )
    rows = [
        (1, {"name": "Alice", "age": 25}, ["reading", "coding"]),
        (2, {"name": "Bob", "age": 30}, ["gaming", "music"]),
    ]
    spark.createDataFrame(rows, schema).printSchema()
    printed = capsys.readouterr().out

    for expected in ("root", "person", "hobbies"):
        assert expected in printed

def test_printSchema_negative_level(self, spark):
    # A negative level is rejected up front.
    frame = spark.createDataFrame([("Alice", 25)], ["name", "age"])
    with pytest.raises(PySparkValueError):
        frame.printSchema(level=-1)

def test_treeString_basic(self, spark):
    # Flat schema: a root header plus one " |-- " entry per column.
    frame = spark.createDataFrame([("Alice", 25, 5000)], ["name", "age", "salary"])
    rendered = frame.schema.treeString()

    assert rendered.startswith("root\n")
    for column in ("name", "age", "salary"):
        assert f" |-- {column}:" in rendered
    assert "(nullable = true)" in rendered
    assert rendered.count(" |-- ") == 3

def test_treeString_nested_struct(self, spark):
    from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType

    # A struct column is labelled "struct" and its fields are expanded below it.
    person_type = StructType(
        [StructField("name", StringType(), True), StructField("age", IntegerType(), True)]
    )
    schema = StructType(
        [StructField("id", IntegerType(), True), StructField("person", person_type, True)]
    )
    frame = spark.createDataFrame([(1, {"name": "Alice", "age": 25})], schema)
    rendered = frame.schema.treeString()

    assert "root\n" in rendered
    assert " |-- id:" in rendered
    assert " |-- person: struct (nullable = true)" in rendered
    assert "name:" in rendered
    assert "age:" in rendered

def test_treeString_with_level(self, spark):
    from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType

    details_type = StructType([StructField("address", StringType(), True)])
    person_type = StructType(
        [
            StructField("name", StringType(), True),
            StructField("details", details_type, True),
        ]
    )
    schema = StructType(
        [StructField("id", IntegerType(), True), StructField("person", person_type, True)]
    )
    frame = spark.createDataFrame(
        [(1, {"name": "Alice", "details": {"address": "123 Main St"}})], schema
    )

    # Level 1 should only show top-level fields
    shallow = frame.schema.treeString(level=1)
    assert " |-- id:" in shallow
    assert " |-- person: struct" in shallow
    # No nested field names: at most the root line plus the two top-level columns.
    non_empty = [entry for entry in shallow.split("\n") if entry.strip()]
    assert len(non_empty) <= 3

def test_treeString_array_type(self, spark):
    from spark_namespace.sql.types import ArrayType, StringType, StructField, StructType

    # An array of scalars is rendered inline as array<element-type>.
    schema = StructType(
        [StructField("name", StringType(), True), StructField("hobbies", ArrayType(StringType()), True)]
    )
    frame = spark.createDataFrame([("Alice", ["reading", "coding"])], schema)
    rendered = frame.schema.treeString()

    assert "root\n" in rendered
    assert " |-- name:" in rendered
    assert " |-- hobbies: array<" in rendered
    assert "(nullable = true)" in rendered
8 changes: 5 additions & 3 deletions tests/fast/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def test_insert_with_schema(self, duckdb_cursor):
res = duckdb_cursor.table("not_main.tbl").fetchall()
assert len(res) == 10

# TODO: This is not currently supported # noqa: TD002, TD003
with pytest.raises(duckdb.CatalogException, match="Table with name tbl does not exist"):
duckdb_cursor.table("not_main.tbl").insert([42, 21, 1337])
# Insert into a schema-qualified table should work; table has a single column from range(10)
duckdb_cursor.table("not_main.tbl").insert([42])
res2 = duckdb_cursor.table("not_main.tbl").fetchall()
assert len(res2) == 11
assert (42,) in res2
9 changes: 8 additions & 1 deletion tests/fast/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,14 @@ def test_value_relation(self, duckdb_cursor):
rel = duckdb_cursor.values((const(1), const(2), const(3)), const(4))

# Using Expressions that can't be resolved:
with pytest.raises(duckdb.BinderException, match='Referenced column "a" not found in FROM clause!'):
# Accept both historical and current Binder error message variants
with pytest.raises(
duckdb.BinderException,
match=(
r'Referenced column "a" not found in FROM clause!|'
r'Referenced column "a" was not found because the FROM clause is missing'
),
):
duckdb_cursor.values(duckdb.ColumnExpression("a"))

def test_insert_into_operator(self):
Expand Down
Loading