diff --git a/Cargo.lock b/Cargo.lock
index fe1ee087e3f1..1b18989ae19c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3158,6 +3158,7 @@ dependencies = [
  "serde_json",
  "simd-json",
  "simdutf8",
+ "strum",
  "strum_macros",
  "tempfile",
  "tokio",
diff --git a/Cargo.toml b/Cargo.toml
index 543641b0536a..4a2248bec2f0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -78,6 +78,7 @@ sqlparser = "0.53"
 stacker = "0.1"
 streaming-iterator = "0.1.9"
 strength_reduce = "0.2"
+strum = "0.26"
 strum_macros = "0.26"
 thiserror = "2"
 tokio = "1.26"
diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml
index 0ea9402542a9..796b7642819b 100644
--- a/crates/polars-io/Cargo.toml
+++ b/crates/polars-io/Cargo.toml
@@ -46,6 +46,7 @@ serde = { workspace = true, features = ["rc"], optional = true }
 serde_json = { version = "1", optional = true }
 simd-json = { workspace = true, optional = true }
 simdutf8 = { workspace = true, optional = true }
+strum = { workspace = true, optional = true }
 strum_macros = { workspace = true, optional = true }
 tokio = { workspace = true, features = ["fs", "net", "rt-multi-thread", "time", "sync"], optional = true }
 tokio-util = { workspace = true, features = ["io", "io-util"], optional = true }
@@ -60,7 +61,7 @@ home = "0.5.4"
 tempfile = "3"
 
 [features]
-catalog = ["cloud", "serde", "reqwest", "futures", "strum_macros"]
+catalog = ["cloud", "serde", "reqwest", "futures", "strum", "strum_macros", "chrono"]
 default = ["decompress"]
 # support for arrows json parsing
 json = [
diff --git a/crates/polars-io/src/catalog/schema.rs b/crates/polars-io/src/catalog/schema.rs
index 33f54e056f03..07ed53139bcb 100644
--- a/crates/polars-io/src/catalog/schema.rs
+++ b/crates/polars-io/src/catalog/schema.rs
@@ -1,9 +1,13 @@
 use polars_core::prelude::{DataType, Field};
 use polars_core::schema::{Schema, SchemaRef};
-use polars_error::{polars_bail, polars_err, PolarsResult};
+use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult};
+use polars_utils::error::TruncateErrorDetail;
+use polars_utils::format_pl_smallstr;
 use polars_utils::pl_str::PlSmallStr;
 
-use super::unity::models::TableInfo;
+use super::unity::models::{ColumnInfo, ColumnTypeJson, TableInfo};
+use crate::catalog::unity::models::ColumnTypeJsonType;
+use crate::utils::decode_json_response;
 
 /// Returns `(schema, hive_schema)`
 pub fn table_info_to_schemas(
@@ -17,8 +21,6 @@ pub fn table_info_to_schemas(
     let mut hive_schema = Schema::default();
 
     for (i, col) in columns.iter().enumerate() {
-        let dtype = parse_type_str(&col.type_text)?;
-
         if let Some(position) = col.position {
             if usize::try_from(position).unwrap() != i {
                 polars_bail!(
@@ -28,6 +30,8 @@ pub fn table_info_to_schemas(
             }
         }
 
+        let field = column_info_to_field(col)?;
+
         if let Some(i) = col.partition_index {
             if usize::try_from(i).unwrap() != hive_schema.len() {
                 polars_bail!(
@@ -36,9 +40,9 @@ pub fn table_info_to_schemas(
                 )
             }
 
-            hive_schema.extend([Field::new(col.name.as_str().into(), dtype)]);
+            hive_schema.extend([field]);
         } else {
-            schema.extend([Field::new(col.name.as_str().into(), dtype)])
+            schema.extend([field])
         }
     }
 
@@ -50,34 +54,154 @@ pub fn table_info_to_schemas(
     ))
 }
 
-/// Parse a type string from a catalog API response.
+pub fn column_info_to_field(column_info: &ColumnInfo) -> PolarsResult<Field> {
+    Ok(Field::new(
+        column_info.name.clone(),
+        parse_type_json_str(&column_info.type_json)?,
+    ))
+}
+
+/// e.g.
+/// ```json
+/// {"name":"Int64","type":"long","nullable":true}
+/// {"name":"List","type":{"type":"array","elementType":"long","containsNull":true},"nullable":true}
+/// ```
+pub fn parse_type_json_str(type_json: &str) -> PolarsResult<DataType> {
+    let decoded: ColumnTypeJson = decode_json_response(type_json.as_bytes())?;
+
+    parse_type_json(&decoded).map_err(|e| {
+        e.wrap_msg(|e| {
+            format!(
+                "error parsing type response: {}, type_json: {}",
+                e,
+                TruncateErrorDetail(type_json)
+            )
+        })
+    })
+}
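+
+// For illustration (a hedged sketch, not part of the API surface): feeding the second
+// doc example above through `parse_type_json_str` would yield
+// `DataType::List(Box::new(DataType::Int64))`, since the leaf name `"long"` maps to `Int64`.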
+
+/// We prefer this as `type_text` cannot be trusted for consistency (e.g. we may expect `decimal(int,int)`
+/// but instead get `decimal`, or `struct<...>` and instead get `struct`).
+pub fn parse_type_json(type_json: &ColumnTypeJson) -> PolarsResult<DataType> {
+    use ColumnTypeJsonType::*;
+
+    let out = match &type_json.type_ {
+        TypeName(name) => match name.as_str() {
+            "array" => {
+                let inner_json: &ColumnTypeJsonType =
+                    type_json.element_type.as_ref().ok_or_else(|| {
+                        polars_err!(
+                            ComputeError:
+                            "missing elementType in response for array type"
+                        )
+                    })?;
+
+                let inner_dtype = parse_type_json_type(inner_json)?;
+
+                DataType::List(Box::new(inner_dtype))
+            },
+
+            "struct" => {
+                let fields_json: &[ColumnTypeJson] =
+                    type_json.fields.as_deref().ok_or_else(|| {
+                        polars_err!(
+                            ComputeError:
+                            "missing fields in response for struct type"
+                        )
+                    })?;
+
+                let fields = fields_json
+                    .iter()
+                    .map(|x| {
+                        let name = x.name.clone().ok_or_else(|| {
+                            polars_err!(
+                                ComputeError:
+                                "missing name in fields response for struct type"
+                            )
+                        })?;
+                        let dtype = parse_type_json(x)?;
+
+                        Ok(Field::new(name, dtype))
+                    })
+                    .collect::<PolarsResult<Vec<_>>>()?;
+
+                DataType::Struct(fields)
+            },
+
+            "map" => {
+                let key_type = type_json.key_type.as_ref().ok_or_else(|| {
+                    polars_err!(
+                        ComputeError:
+                        "missing keyType in response for map type"
+                    )
+                })?;
+
+                let value_type = type_json.value_type.as_ref().ok_or_else(|| {
+                    polars_err!(
+                        ComputeError:
+                        "missing valueType in response for map type"
+                    )
+                })?;
+
+                DataType::List(Box::new(DataType::Struct(vec![
+                    Field::new(
+                        PlSmallStr::from_static("key"),
+                        parse_type_json_type(key_type)?,
+                    ),
+                    Field::new(
+                        PlSmallStr::from_static("value"),
+                        parse_type_json_type(value_type)?,
+                    ),
+                ])))
+            },
+
+            name => parse_type_text(name)?,
+        },
+
+        TypeJson(type_json) => parse_type_json(type_json.as_ref())?,
+    };
+
+    Ok(out)
+}
+
+fn parse_type_json_type(type_json_type: &ColumnTypeJsonType) -> PolarsResult<DataType> {
+    use ColumnTypeJsonType::*;
+
+    match type_json_type {
+        TypeName(name) => parse_type_text(name),
+        TypeJson(type_json) => parse_type_json(type_json.as_ref()),
+    }
+}
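+
+// For illustration (hedged): Polars has no dedicated Map dtype, so a map type such as
+// `{"type":"map","keyType":"string","valueType":"long","valueContainsNull":true}`
+// decodes via the `"map"` branch above to
+// `List(Struct([Field("key", String), Field("value", Int64)]))`.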
+
+/// Parses the string variant of the `type` field within a `type_json`. This can be understood as
+/// the leaf / non-nested datatypes of the field.
 ///
 /// References:
 /// * https://spark.apache.org/docs/latest/sql-ref-datatypes.html
 /// * https://docs.databricks.com/api/workspace/tables/get
 /// * https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html
 ///
-/// Note: `type_precision` and `type_scale` in the API response are defined as supplementary data to
-/// the `type_text`, but from testing they aren't actually used - e.g. a decimal type would have a
-/// `type_text` of `decimal(18, 2)`
-fn parse_type_str(type_text: &str) -> PolarsResult<DataType> {
+/// Notes:
+/// * `type_precision` and `type_scale` in the API response are defined as supplementary data to
+///   the `type_text`, but from testing they aren't actually used - e.g. a decimal type would have a
+///   `type_text` of `decimal(18, 2)`
+fn parse_type_text(type_text: &str) -> PolarsResult<DataType> {
+    use polars_core::prelude::TimeUnit;
     use DataType::*;
 
     let dtype = match type_text {
         "boolean" => Boolean,
-        "byte" | "tinyint" => Int8,
-        "short" | "smallint" => Int16,
+        "tinyint" | "byte" => Int8,
+        "smallint" | "short" => Int16,
         "int" | "integer" => Int32,
-        "long" | "bigint" => Int64,
+        "bigint" | "long" => Int64,
 
         "float" | "real" => Float32,
         "double" => Float64,
 
         "date" => Date,
-        "timestamp" | "timestamp_ltz" | "timestamp_ntz" => {
-            Datetime(polars_core::prelude::TimeUnit::Nanoseconds, None)
-        },
+        "timestamp" | "timestamp_ntz" | "timestamp_ltz" => Datetime(TimeUnit::Nanoseconds, None),
 
         "string" => String,
         "binary" => Binary,
 
@@ -105,31 +229,10 @@ fn parse_type_str(type_text: &str) -> PolarsResult<DataType> {
                     v
                 )
             })?
-        } else if v.starts_with("array") {
-            // e.g. array<int>
-            DataType::List(Box::new(parse_type_str(extract_angle_brackets_inner(
-                v, "array",
-            )?)?))
-        } else if v.starts_with("struct") {
-            parse_struct_type_str(v)?
-        } else if v.starts_with("map") {
-            // e.g. map<int,string>
-            let inner = extract_angle_brackets_inner(v, "map")?;
-            let (key_type_str, value_type_str) = split_comma_nesting_aware(inner);
-            DataType::List(Box::new(DataType::Struct(vec![
-                Field::new(
-                    PlSmallStr::from_static("key"),
-                    parse_type_str(key_type_str)?,
-                ),
-                Field::new(
-                    PlSmallStr::from_static("value"),
-                    parse_type_str(value_type_str)?,
-                ),
-            ])))
         } else {
             polars_bail!(
                 ComputeError:
-                "parse_type_str unknown type name: {}",
+                "parse_type_text unknown type name: {}",
                 v
             )
         }
@@ -139,127 +242,272 @@ fn parse_type_str(type_text: &str) -> PolarsResult<DataType> {
     Ok(dtype)
 }
 
-/// `array<inner> -> inner`
-fn extract_angle_brackets_inner<'a>(value: &'a str, name: &'static str) -> PolarsResult<&'a str> {
-    let i = value.find('<');
-    let j = value.rfind('>');
+// Conversion functions to API format. Mainly used for constructing the request to create tables.
+
+pub fn schema_to_column_info_list(schema: &Schema) -> PolarsResult<Vec<ColumnInfo>> {
+    schema
+        .iter()
+        .enumerate()
+        .map(|(i, (name, dtype))| {
+            let name = name.clone();
+            let type_text = dtype_to_type_text(dtype)?;
+            let type_name = dtype_to_type_name(dtype)?;
+            let type_json = serde_json::to_string(&field_to_type_json(name.clone(), dtype)?)
+                .map_err(to_compute_err)?;
+
+            Ok(ColumnInfo {
+                name,
+                type_name,
+                type_text,
+                type_json,
+                position: Some(i.try_into().unwrap()),
+                comment: None,
+                partition_index: None,
+            })
+        })
+        .collect::<PolarsResult<Vec<_>>>()
+}
 
-    if i.is_none() || j.is_none() || i.unwrap().saturating_add(1) >= j.unwrap() {
-        polars_bail!(
-            ComputeError:
-            "type format did not match {}<...>: {}",
-            name, value
-        )
+/// Creates the `type_text` field of the API. Opposite of [`parse_type_text`]
+fn dtype_to_type_text(dtype: &DataType) -> PolarsResult<PlSmallStr> {
+    use polars_core::prelude::TimeUnit;
+    use DataType::*;
+
+    macro_rules! S! {
S { + ($e:expr) => { + PlSmallStr::from_static($e) + }; } - let i = i.unwrap(); - let j = j.unwrap(); + let out = match dtype { + Boolean => S!("boolean"), - let inner = value[i + 1..j].trim(); + Int8 => S!("tinyint"), + Int16 => S!("smallint"), + Int32 => S!("int"), + Int64 => S!("bigint"), - Ok(inner) -} + Float32 => S!("float"), + Float64 => S!("double"), -/// `struct,effective_list:struct>` -fn parse_struct_type_str(value: &str) -> PolarsResult { - let orig_value = value; - let mut value = extract_angle_brackets_inner(value, "struct")?; + Date => S!("date"), + Datetime(TimeUnit::Nanoseconds, None) => S!("timestamp_ntz"), - let mut fields = vec![]; + String => S!("string"), + Binary => S!("binary"), - while !value.is_empty() { - let (field_str, new_value) = split_comma_nesting_aware(value); - value = new_value; + Null => S!("null"), - let (name, dtype_str) = field_str.split_once(':').ok_or_else(|| { - polars_err!( - ComputeError: - "type format did not match struct: {}", - orig_value - ) - })?; + Decimal(precision, scale) => { + let precision = precision.unwrap_or(38); + let scale = scale.unwrap_or(0); - let dtype = parse_type_str(dtype_str)?; + format_pl_smallstr!("decimal({},{})", precision, scale) + }, - fields.push(Field::new(name.into(), dtype)); - } + List(inner) => { + if let Some((key_type, value_type)) = get_list_map_type(inner) { + format_pl_smallstr!( + "map<{},{}>", + dtype_to_type_text(key_type)?, + dtype_to_type_text(value_type)? + ) + } else { + format_pl_smallstr!("array<{}>", dtype_to_type_text(inner)?) + } + }, + + Struct(fields) => { + // Yes, it's possible to construct column names containing the brackets. This won't + // affect us as we parse using `type_json` rather than this field. + let mut out = std::string::String::from("struct<"); - Ok(DataType::Struct(fields)) + for Field { name, dtype } in fields { + out.push_str(name); + out.push(':'); + out.push_str(&dtype_to_type_text(dtype)?); + out.push(','); + } + + if out.ends_with(',') { + out.truncate(out.len() - 1); + } + + out.push('>'); + + out.into() + }, + + v => polars_bail!( + ComputeError: + "dtype_to_type_text unsupported type: {}", + v + ), + }; + + Ok(out) } -/// `default:decimal(38,18),promotional:struct` -> -/// * 1: `default:decimal(38,18)` -/// * 2: `struct` -/// -/// If there are no splits, returns the full string and an empty string. -fn split_comma_nesting_aware(value: &str) -> (&str, &str) { - let mut bracket_level = 0usize; - let mut angle_bracket_level = 0usize; - - for (i, b) in value.as_bytes().iter().enumerate() { - match b { - b'(' => bracket_level += 1, - b')' => bracket_level = bracket_level.saturating_sub(1), - b'<' => angle_bracket_level += 1, - b'>' => angle_bracket_level = angle_bracket_level.saturating_sub(1), - b',' if bracket_level == 0 && angle_bracket_level == 0 => { - return (&value[..i], &value[1 + i..]) - }, - _ => {}, - } +/// Creates the `type_name` field, from testing this wasn't exactly the same as the `type_text` field. +fn dtype_to_type_name(dtype: &DataType) -> PolarsResult { + use polars_core::prelude::TimeUnit; + use DataType::*; + + macro_rules! 
S { + ($e:expr) => { + PlSmallStr::from_static($e) + }; } - (value, &value[value.len()..]) + let out = match dtype { + Boolean => S!("BOOLEAN"), + + Int8 => S!("BYTE"), + Int16 => S!("SHORT"), + Int32 => S!("INT"), + Int64 => S!("LONG"), + + Float32 => S!("FLOAT"), + Float64 => S!("DOUBLE"), + + Date => S!("DATE"), + Datetime(TimeUnit::Nanoseconds, None) => S!("TIMESTAMP_NTZ"), + String => S!("STRING"), + Binary => S!("BINARY"), + + Null => S!("NULL"), + + Decimal(..) => S!("DECIMAL"), + + List(inner) => { + if get_list_map_type(inner).is_some() { + S!("MAP") + } else { + S!("ARRAY") + } + }, + + Struct(..) => S!("STRUCT"), + + v => polars_bail!( + ComputeError: + "dtype_to_type_text unsupported type: {}", + v + ), + }; + + Ok(out) } -#[cfg(test)] -mod tests { - #[test] - fn test_parse_type_str_nested_struct() { - use super::{parse_type_str, DataType, Field}; - - let type_str = "struct,effective_list:struct>"; - let dtype = parse_type_str(type_str).unwrap(); - - use DataType::*; - - assert_eq!( - dtype, - Struct(vec![ - Field::new("default".into(), Decimal(Some(38), Some(18))), - Field::new( - "promotional".into(), - Struct(vec![Field::new( - "default".into(), - Decimal(Some(38), Some(18)) - )]) - ), - Field::new( - "effective_list".into(), - Struct(vec![Field::new( - "default".into(), - Decimal(Some(38), Some(18)) - )]) - ) - ]) - ); +/// Creates the `type_json` field. +fn field_to_type_json(name: PlSmallStr, dtype: &DataType) -> PolarsResult { + Ok(ColumnTypeJson { + name: Some(name), + type_: dtype_to_type_json(dtype)?, + nullable: Some(true), + // We set this to Some(_) so that the output matches the one generated by Databricks. + metadata: Some(Default::default()), + + ..Default::default() + }) +} + +fn dtype_to_type_json(dtype: &DataType) -> PolarsResult { + use polars_core::prelude::TimeUnit; + use DataType::*; + + macro_rules! S { + ($e:expr) => { + ColumnTypeJsonType::from_static_type_name($e) + }; } - #[test] - fn test_parse_type_str_map() { - use super::{parse_type_str, DataType, Field}; + let out = match dtype { + Boolean => S!("boolean"), + + Int8 => S!("byte"), + Int16 => S!("short"), + Int32 => S!("integer"), + Int64 => S!("long"), + + Float32 => S!("float"), + Float64 => S!("double"), + + Date => S!("date"), + Datetime(TimeUnit::Nanoseconds, None) => S!("timestamp_ntz"), + + String => S!("string"), + Binary => S!("binary"), + + Null => S!("null"), + + Decimal(..) 
+
+        List(inner) => {
+            let out = if let Some((key_type, value_type)) = get_list_map_type(inner) {
+                ColumnTypeJson {
+                    type_: ColumnTypeJsonType::from_static_type_name("map"),
+                    key_type: Some(dtype_to_type_json(key_type)?),
+                    value_type: Some(dtype_to_type_json(value_type)?),
+                    value_contains_null: Some(true),
+
+                    ..Default::default()
+                }
+            } else {
+                ColumnTypeJson {
+                    type_: ColumnTypeJsonType::from_static_type_name("array"),
+                    element_type: Some(dtype_to_type_json(inner)?),
+                    contains_null: Some(true),
+
+                    ..Default::default()
+                }
+            };
+
+            ColumnTypeJsonType::TypeJson(Box::new(out))
+        },
+
+        Struct(fields) => {
+            let out = ColumnTypeJson {
+                type_: ColumnTypeJsonType::from_static_type_name("struct"),
+                fields: Some(
+                    fields
+                        .iter()
+                        .map(|Field { name, dtype }| field_to_type_json(name.clone(), dtype))
+                        .collect::<PolarsResult<Vec<_>>>()?,
+                ),
+
+                ..Default::default()
+            };
 
-        let type_str = "map<array<int>,array<decimal(18,2)>>";
-        let dtype = parse_type_str(type_str).unwrap();
+            ColumnTypeJsonType::TypeJson(Box::new(out))
+        },
 
-        use DataType::*;
+        v => polars_bail!(
+            ComputeError:
+            "dtype_to_type_json unsupported type: {}",
+            v
+        ),
+    };
 
-        assert_eq!(
-            dtype,
-            List(Box::new(Struct(vec![
-                Field::new("key".into(), List(Box::new(Int32))),
-                Field::new("value".into(), List(Box::new(Decimal(Some(18), Some(2)))))
-            ])))
-        );
-    }
-}
+    Ok(out)
+}
+
+/// Tries to interpret the List type as a `map` field, which is essentially
+/// List(Struct(("key", <dtype>), ("value", <dtype>))).
+///
+/// Returns `Option<(key_type, value_type)>`
+fn get_list_map_type(list_inner_dtype: &DataType) -> Option<(&DataType, &DataType)> {
+    let DataType::Struct(fields) = list_inner_dtype else {
+        return None;
+    };
+
+    let [fld1, fld2] = fields.as_slice() else {
+        return None;
+    };
+
+    if !(fld1.name == "key" && fld2.name == "value") {
+        return None;
+    }
+
+    Some((fld1.dtype(), fld2.dtype()))
+}
diff --git a/crates/polars-io/src/catalog/unity/client.rs b/crates/polars-io/src/catalog/unity/client.rs
index 53f0814a01c5..9f8d2e1cc3ad 100644
--- a/crates/polars-io/src/catalog/unity/client.rs
+++ b/crates/polars-io/src/catalog/unity/client.rs
@@ -1,7 +1,11 @@
+use polars_core::prelude::PlHashMap;
+use polars_core::schema::Schema;
 use polars_error::{polars_bail, to_compute_err, PolarsResult};
 
 use super::models::{CatalogInfo, SchemaInfo, TableInfo};
-use super::utils::PageWalker;
+use super::utils::{do_request, PageWalker};
+use crate::catalog::schema::schema_to_column_info_list;
+use crate::catalog::unity::models::{ColumnInfo, DataSourceFormat, TableType};
 use crate::impl_page_walk;
 use crate::utils::decode_json_response;
 
@@ -64,25 +68,200 @@ impl CatalogClient {
             table_name.replace('/', "%2F")
         );
 
-        let bytes = async {
+        let bytes = do_request(
             self.http_client
                 .get(format!(
                     "{}{}{}",
                     &self.workspace_url, "/api/2.1/unity-catalog/tables/", full_table_name
                 ))
-                .query(&[("full_name", full_table_name)])
-                .send()
-                .await?
-                .bytes()
-                .await
-        }
-        .await
-        .map_err(to_compute_err)?;
+                .query(&[("full_name", full_table_name)]),
+        )
+        .await?;
 
         let out: TableInfo = decode_json_response(&bytes)?;
 
         Ok(out)
     }
+
+    pub async fn create_catalog(
+        &self,
+        catalog_name: &str,
+        comment: Option<&str>,
+        storage_root: Option<&str>,
+    ) -> PolarsResult<CatalogInfo> {
+        let resp = do_request(
+            self.http_client
+                .post(format!(
+                    "{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/catalogs"
+                ))
+                .json(&Body {
+                    name: catalog_name,
+                    comment,
+                    storage_root,
+                }),
+        )
+        .await?;
+
+        return decode_json_response(&resp);
+
+        #[derive(serde::Serialize)]
+        struct Body<'a> {
+            name: &'a str,
+            comment: Option<&'a str>,
+            storage_root: Option<&'a str>,
+        }
+    }
+
+    pub async fn delete_catalog(&self, catalog_name: &str, force: bool) -> PolarsResult<()> {
+        let catalog_name = catalog_name.replace('/', "%2F");
+
+        do_request(
+            self.http_client
+                .delete(format!(
+                    "{}{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/catalogs/", catalog_name
+                ))
+                .query(&[("force", force)]),
+        )
+        .await?;
+
+        Ok(())
+    }
+
+    pub async fn create_schema(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        comment: Option<&str>,
+        storage_root: Option<&str>,
+    ) -> PolarsResult<SchemaInfo> {
+        let resp = do_request(
+            self.http_client
+                .post(format!(
+                    "{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/schemas"
+                ))
+                .json(&Body {
+                    name: schema_name,
+                    catalog_name,
+                    comment,
+                    storage_root,
+                }),
+        )
+        .await?;
+
+        return decode_json_response(&resp);
+
+        #[derive(serde::Serialize)]
+        struct Body<'a> {
+            name: &'a str,
+            catalog_name: &'a str,
+            comment: Option<&'a str>,
+            storage_root: Option<&'a str>,
+        }
+    }
+
+    pub async fn delete_schema(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        force: bool,
+    ) -> PolarsResult<()> {
+        let full_name = format!(
+            "{}.{}",
+            catalog_name.replace('/', "%2F"),
+            schema_name.replace('/', "%2F"),
+        );
+
+        do_request(
+            self.http_client
+                .delete(format!(
+                    "{}{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/schemas/", full_name
+                ))
+                .query(&[("force", force)]),
+        )
+        .await?;
+
+        Ok(())
+    }
+
+    /// Note, `data_source_format` can be None for some `table_type`s.
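+    ///
+    /// For illustration, a hedged usage sketch (the `client`, `schema` and names below
+    /// are assumptions for the example, not part of this diff):
+    ///
+    /// ```ignore
+    /// let table_info = client
+    ///     .create_table(
+    ///         "my_catalog",
+    ///         "my_schema",
+    ///         "my_table",
+    ///         Some(&schema),
+    ///         &TableType::Managed,
+    ///         Some(&DataSourceFormat::Delta),
+    ///         None, // comment
+    ///         None, // storage_location
+    ///         &mut [("key", "value")].into_iter(),
+    ///     )
+    ///     .await?;
+    /// ```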
+    #[allow(clippy::too_many_arguments)]
+    pub async fn create_table(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+        schema: Option<&Schema>,
+        table_type: &TableType,
+        data_source_format: Option<&DataSourceFormat>,
+        comment: Option<&str>,
+        storage_location: Option<&str>,
+        properties: &mut (dyn Iterator<Item = (&str, &str)> + Send + Sync),
+    ) -> PolarsResult<TableInfo> {
+        let columns = schema.map(schema_to_column_info_list).transpose()?;
+        let columns = columns.as_deref();
+
+        let resp = do_request(
+            self.http_client
+                .post(format!(
+                    "{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/tables"
+                ))
+                .json(&Body {
+                    name: table_name,
+                    catalog_name,
+                    schema_name,
+                    table_type,
+                    data_source_format,
+                    comment,
+                    columns,
+                    storage_location,
+                    properties: properties.collect(),
+                }),
+        )
+        .await?;
+
+        return decode_json_response(&resp);
+
+        #[derive(serde::Serialize)]
+        struct Body<'a> {
+            name: &'a str,
+            catalog_name: &'a str,
+            schema_name: &'a str,
+            comment: Option<&'a str>,
+            table_type: &'a TableType,
+            #[serde(skip_serializing_if = "Option::is_none")]
+            data_source_format: Option<&'a DataSourceFormat>,
+            columns: Option<&'a [ColumnInfo]>,
+            storage_location: Option<&'a str>,
+            properties: PlHashMap<&'a str, &'a str>,
+        }
+    }
+
+    pub async fn delete_table(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> PolarsResult<()> {
+        let full_name = format!(
+            "{}.{}.{}",
+            catalog_name.replace('/', "%2F"),
+            schema_name.replace('/', "%2F"),
+            table_name.replace('/', "%2F"),
+        );
+
+        do_request(self.http_client.delete(format!(
+            "{}{}{}",
+            &self.workspace_url, "/api/2.1/unity-catalog/tables/", full_name
+        )))
+        .await?;
+
+        Ok(())
+    }
 }
 
 pub struct CatalogClientBuilder {
diff --git a/crates/polars-io/src/catalog/unity/models.rs b/crates/polars-io/src/catalog/unity/models.rs
index da9f604e27ee..f66701b8255b 100644
--- a/crates/polars-io/src/catalog/unity/models.rs
+++ b/crates/polars-io/src/catalog/unity/models.rs
@@ -1,13 +1,52 @@
+use polars_core::prelude::PlHashMap;
+use polars_utils::pl_str::PlSmallStr;
+
 #[derive(Debug, serde::Deserialize)]
 pub struct CatalogInfo {
     pub name: String,
     pub comment: Option<String>,
+
+    #[serde(default)]
+    pub storage_location: Option<String>,
+
+    #[serde(default, deserialize_with = "null_to_default")]
+    pub properties: PlHashMap<PlSmallStr, String>,
+
+    #[serde(default, deserialize_with = "null_to_default")]
+    pub options: PlHashMap<PlSmallStr, String>,
+
+    #[serde(with = "chrono::serde::ts_milliseconds_option")]
+    pub created_at: Option<chrono::DateTime<chrono::Utc>>,
+
+    pub created_by: Option<String>,
+
+    #[serde(with = "chrono::serde::ts_milliseconds_option")]
+    pub updated_at: Option<chrono::DateTime<chrono::Utc>>,
+
+    pub updated_by: Option<String>,
 }
 
 #[derive(Debug, serde::Deserialize)]
 pub struct SchemaInfo {
     pub name: String,
     pub comment: Option<String>,
+
+    #[serde(default, deserialize_with = "null_to_default")]
+    pub properties: PlHashMap<PlSmallStr, String>,
+
+    #[serde(default)]
+    pub storage_location: Option<String>,
+
+    #[serde(with = "chrono::serde::ts_milliseconds_option")]
+    pub created_at: Option<chrono::DateTime<chrono::Utc>>,
+
+    pub created_by: Option<String>,
+
+    #[serde(with = "chrono::serde::ts_milliseconds_option")]
+    pub updated_at: Option<chrono::DateTime<chrono::Utc>>,
+
+    pub updated_by: Option<String>,
 }
 
 #[derive(Debug, serde::Deserialize)]
@@ -15,17 +54,36 @@ pub struct TableInfo {
     pub name: String,
     pub table_id: String,
     pub table_type: TableType,
+    #[serde(default)]
     pub comment: Option<String>,
+    #[serde(default)]
     pub storage_location: Option<String>,
+    #[serde(default)]
     pub data_source_format: Option<DataSourceFormat>,
+    #[serde(default)]
     pub columns: Option<Vec<ColumnInfo>>,
+
+    #[serde(default, deserialize_with = "null_to_default")]
+    pub properties: PlHashMap<PlSmallStr, String>,
+
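+    // Timestamps arrive from the API as UNIX epoch milliseconds; chrono's serde
+    // adapter decodes them to `DateTime<Utc>` (same for `updated_at` below).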
+    #[serde(with = "chrono::serde::ts_milliseconds_option")]
+    pub created_at: Option<chrono::DateTime<chrono::Utc>>,
+
+    pub created_by: Option<String>,
+
+    #[serde(with = "chrono::serde::ts_milliseconds_option")]
+    pub updated_at: Option<chrono::DateTime<chrono::Utc>>,
+
+    pub updated_by: Option<String>,
 }
 
-#[derive(Debug, strum_macros::Display, serde::Deserialize)]
+#[derive(
+    Debug, strum_macros::Display, strum_macros::EnumString, serde::Serialize, serde::Deserialize,
+)]
 #[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
 pub enum TableType {
@@ -39,7 +97,9 @@ pub enum TableType {
     ExternalShallowClone,
 }
 
-#[derive(Debug, strum_macros::Display, serde::Deserialize)]
+#[derive(
+    Debug, strum_macros::Display, strum_macros::EnumString, serde::Serialize, serde::Deserialize,
+)]
 #[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
 pub enum DataSourceFormat {
@@ -70,12 +130,139 @@ pub enum DataSourceFormat {
     VectorIndexFormat,
 }
 
-#[derive(Debug, serde::Deserialize)]
+#[derive(Debug, serde::Serialize, serde::Deserialize)]
 pub struct ColumnInfo {
-    pub name: String,
-    pub type_text: String,
-    pub type_interval_type: Option<String>,
+    pub name: PlSmallStr,
+    pub type_name: PlSmallStr,
+    pub type_text: PlSmallStr,
+    pub type_json: String,
     pub position: Option<u32>,
     pub comment: Option<String>,
     pub partition_index: Option<u32>,
 }
+
+/// Note: This struct contains all the field names for a few different possible type / field presence
+/// combinations. We use serde(default) and skip_serializing_if to get the desired serialization
+/// output.
+///
+/// E.g.:
+///
+/// ```text
+/// {
+///     "name": "List",
+///     "type": {"type": "array", "elementType": "long", "containsNull": True},
+///     "nullable": True,
+///     "metadata": {},
+/// }
+/// {
+///     "name": "Struct",
+///     "type": {
+///         "type": "struct",
+///         "fields": [{"name": "x", "type": "long", "nullable": True, "metadata": {}}],
+///     },
+///     "nullable": True,
+///     "metadata": {},
+/// }
+/// {
+///     "name": "ListStruct",
+///     "type": {
+///         "type": "array",
+///         "elementType": {
+///             "type": "struct",
+///             "fields": [{"name": "x", "type": "long", "nullable": True, "metadata": {}}],
+///         },
+///         "containsNull": True,
+///     },
+///     "nullable": True,
+///     "metadata": {},
+/// }
+/// {
+///     "name": "Map",
+///     "type": {
+///         "type": "map",
+///         "keyType": "string",
+///         "valueType": "string",
+///         "valueContainsNull": True,
+///     },
+///     "nullable": True,
+///     "metadata": {},
+/// }
+/// ```
+#[derive(Debug, Default, serde::Serialize, serde::Deserialize)]
+pub struct ColumnTypeJson {
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub name: Option<PlSmallStr>,
+
+    #[serde(rename = "type")]
+    pub type_: ColumnTypeJsonType,
+
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub nullable: Option<bool>,
+
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<PlHashMap<PlSmallStr, String>>,
+
+    // Used for List types
+    #[serde(
+        default,
+        rename = "elementType",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub element_type: Option<ColumnTypeJsonType>,
+
+    #[serde(
+        default,
+        rename = "containsNull",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub contains_null: Option<bool>,
+
+    // Used for Struct types
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub fields: Option<Vec<ColumnTypeJson>>,
+
+    // Used for Map types
+    #[serde(default, rename = "keyType", skip_serializing_if = "Option::is_none")]
+    pub key_type: Option<ColumnTypeJsonType>,
+
+    #[serde(default, rename = "valueType", skip_serializing_if = "Option::is_none")]
+    pub value_type: Option<ColumnTypeJsonType>,
+
+    #[serde(
+        default,
+        rename = "valueContainsNull",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub value_contains_null: Option<bool>,
+}
+
+#[derive(Debug, serde::Serialize, serde::Deserialize)]
+#[serde(untagged)]
+pub enum ColumnTypeJsonType {
+    /// * `{"type": "name", ..}`
+    TypeName(PlSmallStr),
+    /// * `{"type": {"type": "name", ..}}`
+    TypeJson(Box<ColumnTypeJson>),
+}
+
+impl Default for ColumnTypeJsonType {
+    fn default() -> Self {
+        Self::TypeName(PlSmallStr::EMPTY)
+    }
+}
+
+impl ColumnTypeJsonType {
+    pub const fn from_static_type_name(type_name: &'static str) -> Self {
+        Self::TypeName(PlSmallStr::from_static(type_name))
+    }
+}
+
+fn null_to_default<'de, T, D>(d: D) -> Result<T, D::Error>
+where
+    T: Default + serde::de::Deserialize<'de>,
+    D: serde::de::Deserializer<'de>,
+{
+    use serde::Deserialize;
+    let opt_val = Option::<T>::deserialize(d)?;
+    Ok(opt_val.unwrap_or_default())
+}
diff --git a/crates/polars-io/src/catalog/unity/utils.rs b/crates/polars-io/src/catalog/unity/utils.rs
index 56308f0e4466..fee60f869b02 100644
--- a/crates/polars-io/src/catalog/unity/utils.rs
+++ b/crates/polars-io/src/catalog/unity/utils.rs
@@ -1,7 +1,29 @@
 use bytes::Bytes;
 use polars_error::{to_compute_err, PolarsResult};
+use polars_utils::error::TruncateErrorDetail;
 use reqwest::RequestBuilder;
 
+/// Performs the request and attaches the response body to any error messages.
+pub(super) async fn do_request(request: reqwest::RequestBuilder) -> PolarsResult<Bytes> {
+    let resp = request.send().await.map_err(to_compute_err)?;
+    let opt_err = resp.error_for_status_ref().map(|_| ());
+    let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;
+
+    opt_err.map_err(|e| {
+        to_compute_err(e).wrap_msg(|e| {
+            let body = String::from_utf8_lossy(&resp_bytes);
+
+            format!(
+                "error: {}, response body: {}",
+                e,
+                TruncateErrorDetail(&body)
+            )
+        })
+    })?;
+
+    Ok(resp_bytes)
+}
+
 /// Support for traversing paginated response values that look like:
 /// ```text
 /// {
@@ -94,9 +116,6 @@ impl PageWalker {
             request
         };
 
-        async { request.send().await?.bytes().await }
-            .await
-            .map(Some)
-            .map_err(to_compute_err)
+        do_request(request).await.map(Some)
     }
 }
diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs
index 6496b0d54d68..a3529b9d7937 100644
--- a/crates/polars-io/src/utils/other.rs
+++ b/crates/polars-io/src/utils/other.rs
@@ -209,29 +209,17 @@ pub fn decode_json_response<T>(bytes: &[u8]) -> PolarsResult<T>
 where
     T: for<'de> serde::de::Deserialize<'de>,
 {
-    use polars_core::config;
     use polars_error::to_compute_err;
+    use polars_utils::error::TruncateErrorDetail;
 
     serde_json::from_slice(bytes)
         .map_err(to_compute_err)
        .map_err(|e| {
            e.wrap_msg(|e| {
-                let maybe_truncated = if config::verbose() {
-                    bytes
-                } else {
-                    // Clamp the output on non-verbose
-                    &bytes[..bytes.len().min(4096)]
-                };
-
                 format!(
-                    "error decoding response: {}, response value: {}{}",
+                    "error decoding response: {}, response value: {}",
                     e,
-                    String::from_utf8_lossy(maybe_truncated),
-                    if maybe_truncated.len() != bytes.len() {
-                        " ...(set POLARS_VERBOSE=1 to see full response)"
-                    } else {
-                        ""
-                    }
+                    TruncateErrorDetail(&String::from_utf8_lossy(bytes))
                 )
             })
         })
diff --git a/crates/polars-python/src/catalog/mod.rs b/crates/polars-python/src/catalog/mod.rs
index 9fafe6cbeb17..7cb5111b5403 100644
--- a/crates/polars-python/src/catalog/mod.rs
+++ b/crates/polars-python/src/catalog/mod.rs
@@ -1,14 +1,20 @@
-use polars::prelude::LazyFrame;
+use std::str::FromStr;
+
+use polars::prelude::{LazyFrame, PlHashMap, PlSmallStr, Schema};
+use polars_io::catalog::schema::parse_type_json_str;
 use polars_io::catalog::unity::client::{CatalogClient, CatalogClientBuilder};
-use polars_io::catalog::unity::models::{CatalogInfo, ColumnInfo, SchemaInfo, TableInfo};
+use polars_io::catalog::unity::models::{
+    CatalogInfo, ColumnInfo, DataSourceFormat, SchemaInfo, TableInfo, TableType,
+};
 use polars_io::cloud::credential_provider::PlCredentialProvider;
 use polars_io::pl_async;
 use pyo3::exceptions::PyValueError;
+use pyo3::sync::GILOnceCell;
 use pyo3::types::{PyAnyMethods, PyDict, PyList};
-use pyo3::{pyclass, pymethods, Bound, PyObject, PyResult, Python};
+use pyo3::{pyclass, pymethods, Bound, IntoPyObject, Py, PyAny, PyObject, PyResult, Python};
 
 use crate::lazyframe::PyLazyFrame;
-use crate::prelude::parse_cloud_options;
+use crate::prelude::{parse_cloud_options, Wrap};
 use crate::utils::{to_py_err, EnterPolarsExt};
 
 macro_rules! pydict_insert_keys {
@@ -26,6 +32,13 @@ macro_rules! pydict_insert_keys {
     };
 }
 
+// Result dataclasses. These are initialized from Python by calling [`PyCatalogClient::init_classes`].
+
+static CATALOG_INFO_CLS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+static SCHEMA_INFO_CLS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+static TABLE_INFO_CLS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+static COLUMN_INFO_CLS: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
+
 #[pyclass]
 pub struct PyCatalogClient(CatalogClient);
 
@@ -50,20 +63,24 @@ impl PyCatalogClient {
             pl_async::get_runtime().block_on_potential_spawn(self.client().list_catalogs())
         })?;
 
-        PyList::new(
+        let mut opt_err = None;
+
+        let out = PyList::new(
             py,
-            v.into_iter().map(|CatalogInfo { name, comment }| {
-                let dict = PyDict::new(py);
+            v.into_iter().map(|x| {
+                let v = catalog_info_to_pyobject(py, x);
 
-                pydict_insert_keys!(dict, {
-                    name,
-                    comment,
-                });
+                if let Ok(v) = v {
+                    Some(v)
+                } else {
+                    opt_err.replace(v);
+                    None
+                }
+            }),
+        )?;
 
-                dict
-            }),
-        )
-        .map(|x| x.into())
+        opt_err.transpose()?;
+
+        Ok(out.into())
     }
 
     #[pyo3(signature = (catalog_name))]
@@ -73,20 +90,25 @@ impl PyCatalogClient {
                 .block_on_potential_spawn(self.client().list_schemas(catalog_name))
         })?;
 
-        PyList::new(
+        let mut opt_err = None;
+
+        let out = PyList::new(
             py,
-            v.into_iter().map(|SchemaInfo { name, comment }| {
-                let dict = PyDict::new(py);
+            v.into_iter().map(|x| {
+                let v = schema_info_to_pyobject(py, x);
+                match v {
+                    Ok(v) => Some(v),
+                    Err(_) => {
+                        opt_err.replace(v);
+                        None
+                    },
+                }
+            }),
+        )?;
 
-                pydict_insert_keys!(dict, {
-                    name,
-                    comment,
-                });
+        opt_err.transpose()?;
 
-                dict
-            }),
-        )
-        .map(|x| x.into())
+        Ok(out.into())
     }
 
     #[pyo3(signature = (catalog_name, schema_name))]
@@ -101,31 +123,47 @@ impl PyCatalogClient {
                 .block_on_potential_spawn(self.client().list_tables(catalog_name, schema_name))
         })?;
 
-        PyList::new(
+        let mut opt_err = None;
+
+        let out = PyList::new(
             py,
-            v.into_iter()
-                .map(|table_entry| table_entry_to_pydict(py, table_entry)),
-        )
-        .map(|x| x.into())
+            v.into_iter().map(|table_info| {
+                let v = table_info_to_pyobject(py, table_info);
+
+                if let Ok(v) = v {
+                    Some(v)
+                } else {
+                    opt_err.replace(v);
+                    None
+                }
+            }),
+        )?
+        .into();
+
+        opt_err.transpose()?;
+
+        Ok(out)
     }
 
-    #[pyo3(signature = (catalog_name, schema_name, table_name))]
+    #[pyo3(signature = (table_name, catalog_name, schema_name))]
     pub fn get_table_info(
         &self,
         py: Python,
+        table_name: &str,
         catalog_name: &str,
         schema_name: &str,
-        table_name: &str,
     ) -> PyResult<PyObject> {
-        let table_entry = py.enter_polars(|| {
-            pl_async::get_runtime().block_on_potential_spawn(self.client().get_table_info(
-                catalog_name,
-                schema_name,
-                table_name,
-            ))
-        })?;
-
-        Ok(table_entry_to_pydict(py, table_entry).into())
+        let table_info = py
+            .enter_polars(|| {
+                pl_async::get_runtime().block_on_potential_spawn(self.client().get_table_info(
+                    table_name,
+                    catalog_name,
+                    schema_name,
+                ))
+            })
+            .map_err(to_py_err)?;
+
+        table_info_to_pyobject(py, table_info).map(|x| x.into())
     }
 
     #[pyo3(signature = (catalog_name, schema_name, table_name, cloud_options, credential_provider, retries))]
@@ -166,6 +204,161 @@ impl PyCatalogClient {
                 .into(),
         )
     }
+
+    #[pyo3(signature = (catalog_name, comment, storage_root))]
+    pub fn create_catalog(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        comment: Option<&str>,
+        storage_root: Option<&str>,
+    ) -> PyResult<PyObject> {
+        let catalog_info = py
+            .allow_threads(|| {
+                pl_async::get_runtime().block_on_potential_spawn(self.client().create_catalog(
+                    catalog_name,
+                    comment,
+                    storage_root,
+                ))
+            })
+            .map_err(to_py_err)?;
+
+        catalog_info_to_pyobject(py, catalog_info).map(|x| x.into())
+    }
+
+    #[pyo3(signature = (catalog_name, force))]
+    pub fn delete_catalog(&self, py: Python, catalog_name: &str, force: bool) -> PyResult<()> {
+        py.allow_threads(|| {
+            pl_async::get_runtime()
+                .block_on_potential_spawn(self.client().delete_catalog(catalog_name, force))
+        })
+        .map_err(to_py_err)
+    }
+
+    #[pyo3(signature = (catalog_name, schema_name, comment, storage_root))]
+    pub fn create_schema(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        schema_name: &str,
+        comment: Option<&str>,
+        storage_root: Option<&str>,
+    ) -> PyResult<PyObject> {
+        let schema_info = py
+            .allow_threads(|| {
+                pl_async::get_runtime().block_on_potential_spawn(self.client().create_schema(
+                    catalog_name,
+                    schema_name,
+                    comment,
+                    storage_root,
+                ))
+            })
+            .map_err(to_py_err)?;
+
+        schema_info_to_pyobject(py, schema_info).map(|x| x.into())
+    }
+
+    #[pyo3(signature = (catalog_name, schema_name, force))]
+    pub fn delete_schema(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        schema_name: &str,
+        force: bool,
+    ) -> PyResult<()> {
+        py.allow_threads(|| {
+            pl_async::get_runtime().block_on_potential_spawn(self.client().delete_schema(
+                catalog_name,
+                schema_name,
+                force,
+            ))
+        })
+        .map_err(to_py_err)
+    }
+
+    #[pyo3(signature = (
+        catalog_name, schema_name, table_name, schema, table_type, data_source_format, comment,
+        storage_root, properties
+    ))]
+    pub fn create_table(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+        schema: Option<Wrap<Schema>>,
+        table_type: &str,
+        data_source_format: Option<&str>,
+        comment: Option<&str>,
+        storage_root: Option<&str>,
+        properties: Vec<(String, String)>,
+    ) -> PyResult<PyObject> {
+        let table_info = py.allow_threads(|| {
+            pl_async::get_runtime()
+                .block_on_potential_spawn(
+                    self.client().create_table(
+                        catalog_name,
+                        schema_name,
+                        table_name,
+                        schema.as_ref().map(|x| &x.0),
+                        &TableType::from_str(table_type)
+                            .map_err(|e| PyValueError::new_err(e.to_string()))?,
+                        data_source_format
+                            .map(DataSourceFormat::from_str)
+                            .transpose()
+                            .map_err(|e| PyValueError::new_err(e.to_string()))?
+                            .as_ref(),
+                        comment,
+                        storage_root,
+                        &mut properties.iter().map(|(a, b)| (a.as_str(), b.as_str())),
+                    ),
+                )
+                .map_err(to_py_err)
+        })?;
+
+        table_info_to_pyobject(py, table_info).map(|x| x.into())
+    }
+
+    #[pyo3(signature = (catalog_name, schema_name, table_name))]
+    pub fn delete_table(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> PyResult<()> {
+        py.allow_threads(|| {
+            pl_async::get_runtime().block_on_potential_spawn(self.client().delete_table(
+                catalog_name,
+                schema_name,
+                table_name,
+            ))
+        })
+        .map_err(to_py_err)
+    }
+
+    #[pyo3(signature = (type_json))]
+    #[staticmethod]
+    pub fn type_json_to_polars_type(py: Python, type_json: &str) -> PyResult<PyObject> {
+        Ok(Wrap(parse_type_json_str(type_json).map_err(to_py_err)?)
+            .into_pyobject(py)?
+            .unbind())
+    }
+
+    #[pyo3(signature = (catalog_info_cls, schema_info_cls, table_info_cls, column_info_cls))]
+    #[staticmethod]
+    pub fn init_classes(
+        py: Python,
+        catalog_info_cls: Py<PyAny>,
+        schema_info_cls: Py<PyAny>,
+        table_info_cls: Py<PyAny>,
+        column_info_cls: Py<PyAny>,
+    ) {
+        CATALOG_INFO_CLS.get_or_init(py, || catalog_info_cls);
+        SCHEMA_INFO_CLS.get_or_init(py, || schema_info_cls);
+        TABLE_INFO_CLS.get_or_init(py, || table_info_cls);
+        COLUMN_INFO_CLS.get_or_init(py, || column_info_cls);
+    }
 }
 
 impl PyCatalogClient {
@@ -174,50 +367,141 @@ impl PyCatalogClient {
     }
 }
 
-fn table_entry_to_pydict(py: Python, table_entry: TableInfo) -> Bound<'_, PyDict> {
-    let TableInfo {
+fn catalog_info_to_pyobject(
+    py: Python,
+    CatalogInfo {
         name,
         comment,
+        storage_location,
+        properties,
+        options,
+        created_at,
+        created_by,
+        updated_at,
+        updated_by,
+    }: CatalogInfo,
+) -> PyResult<Bound<'_, PyAny>> {
+    let dict = PyDict::new(py);
+
+    let properties = properties_to_pyobject(py, properties);
+    let options = properties_to_pyobject(py, options);
+
+    pydict_insert_keys!(dict, {
+        name,
+        comment,
+        storage_location,
+        properties,
+        options,
+        created_at,
+        created_by,
+        updated_at,
+        updated_by
+    });
+
+    CATALOG_INFO_CLS
+        .get(py)
+        .unwrap()
+        .bind(py)
+        .call((), Some(&dict))
+}
+
+fn schema_info_to_pyobject(
+    py: Python,
+    SchemaInfo {
+        name,
+        comment,
+        properties,
+        storage_location,
+        created_at,
+        created_by,
+        updated_at,
+        updated_by,
+    }: SchemaInfo,
+) -> PyResult<Bound<'_, PyAny>> {
+    let dict = PyDict::new(py);
+
+    let properties = properties_to_pyobject(py, properties);
+
+    pydict_insert_keys!(dict, {
+        name,
+        comment,
+        properties,
+        storage_location,
+        created_at,
+        created_by,
+        updated_at,
+        updated_by
+    });
+
+    SCHEMA_INFO_CLS
+        .get(py)
+        .unwrap()
+        .bind(py)
+        .call((), Some(&dict))
+}
+
+fn table_info_to_pyobject(py: Python, table_info: TableInfo) -> PyResult<Bound<'_, PyAny>> {
+    let TableInfo {
+        name,
         table_id,
         table_type,
+        comment,
         storage_location,
         data_source_format,
         columns,
-    } = table_entry;
+        properties,
+        created_at,
+        created_by,
+        updated_at,
+        updated_by,
+    } = table_info;
+
+    let columns = columns
+        .map(|columns| {
+            columns
+                .into_iter()
+                .map(
+                    |ColumnInfo {
+                         name,
+                         type_name,
+                         type_text,
+                         type_json,
+                         position,
+                         comment,
+                         partition_index,
+                     }| {
+                        let dict = PyDict::new(py);
+
+                        let name = name.as_str();
+                        let type_name = type_name.as_str();
+                        let type_text = type_text.as_str();
+
+                        pydict_insert_keys!(dict, {
+                            name,
+                            type_name,
+                            type_text,
+                            type_json,
+                            position,
+                            comment,
+                            partition_index,
+                        });
+
+                        COLUMN_INFO_CLS
+                            .get(py)
+                            .unwrap()
+                            .bind(py)
+                            .call((), Some(&dict))
+                    },
+                )
+                .collect::<PyResult<Vec<_>>>()
+        })
+        .transpose()?;
 
     let dict = PyDict::new(py);
 
-    let columns = columns.map(|columns| {
-        columns
-            .into_iter()
-            .map(
-                |ColumnInfo {
-                     name,
-                     type_text,
-                     type_interval_type,
-                     position,
-                     comment,
-                     partition_index,
-                 }| {
-                    let dict = PyDict::new(py);
-
-                    pydict_insert_keys!(dict, {
-                        name,
-                        type_text,
-                        type_interval_type,
-                        position,
-                        comment,
-                        partition_index,
-                    });
-
-                    dict
-                },
-            )
-            .collect::<Vec<_>>()
-    });
-
     let data_source_format = data_source_format.map(|x| x.to_string());
     let table_type = table_type.to_string();
+    let properties = properties_to_pyobject(py, properties);
 
     pydict_insert_keys!(dict, {
         name,
@@ -227,7 +511,29 @@ fn table_entry_to_pydict(py: Python, table_entry: TableInfo) -> Bound<'_, PyDict> {
         storage_location,
         data_source_format,
         columns,
+        properties,
+        created_at,
+        created_by,
+        updated_at,
+        updated_by,
     });
 
+    TABLE_INFO_CLS
+        .get(py)
+        .unwrap()
+        .bind(py)
+        .call((), Some(&dict))
+}
+
+fn properties_to_pyobject(
+    py: Python,
+    properties: PlHashMap<PlSmallStr, String>,
+) -> Bound<'_, PyDict> {
+    let dict = PyDict::new(py);
+
+    for (key, value) in properties.into_iter() {
+        dict.set_item(key.as_str(), value).unwrap();
+    }
+
     dict
 }
diff --git a/crates/polars-utils/src/error.rs b/crates/polars-utils/src/error.rs
index 39d5b5bf4daf..861731ece753 100644
--- a/crates/polars-utils/src/error.rs
+++ b/crates/polars-utils/src/error.rs
@@ -1,6 +1,9 @@
 use std::borrow::Cow;
 use std::fmt::{Display, Formatter};
 
+use crate::config::verbose;
+use crate::format_pl_smallstr;
+
 type ErrString = Cow<'static, str>;
 
 #[derive(Debug)]
@@ -20,3 +23,28 @@ impl Display for PolarsUtilsError {
     }
 }
 
 pub type Result<T> = std::result::Result<T, PolarsUtilsError>;
+
+/// Utility whose Display impl truncates the string unless POLARS_VERBOSE is set.
+pub struct TruncateErrorDetail<'a>(pub &'a str);
+
+impl std::fmt::Display for TruncateErrorDetail<'_> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let maybe_truncated = if verbose() {
+            self.0
+        } else {
+            // Clamp the output on non-verbose
+            &self.0[..self.0.len().min(4096)]
+        };
+
+        f.write_str(maybe_truncated)?;
+
+        if maybe_truncated.len() != self.0.len() {
+            let n_more = self.0.len() - maybe_truncated.len();
+            f.write_str(" ...(set POLARS_VERBOSE=1 to see full response (")?;
+            f.write_str(&format_pl_smallstr!("{}", n_more))?;
+            f.write_str(" more characters))")?;
+        };
+
+        Ok(())
+    }
+}
diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs
index ba6da732d76c..960587bf79a5 100644
--- a/crates/polars-utils/src/lib.rs
+++ b/crates/polars-utils/src/lib.rs
@@ -16,7 +16,7 @@ pub mod chunks;
 pub mod clmul;
 mod config;
 pub mod cpuid;
-mod error;
+pub mod error;
 pub mod floor_divmod;
 pub mod functions;
 pub mod hashing;
diff --git a/crates/polars-utils/src/pl_str.rs b/crates/polars-utils/src/pl_str.rs
index 64d1bde78e80..162b460c8a61 100644
--- a/crates/polars-utils/src/pl_str.rs
+++ b/crates/polars-utils/src/pl_str.rs
@@ -45,6 +45,11 @@ impl PlSmallStr {
         self.0.as_str()
     }
 
+    #[inline(always)]
+    pub fn as_mut_str(&mut self) -> &mut str {
+        self.0.as_mut_str()
+    }
+
     #[inline(always)]
     pub fn into_string(self) -> String {
         self.0.into_string()
@@ -76,6 +81,13 @@ impl core::ops::Deref for PlSmallStr {
     }
 }
 
+impl core::ops::DerefMut for PlSmallStr {
+    #[inline(always)]
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.as_mut_str()
+    }
+}
+
 impl core::borrow::Borrow<str> for PlSmallStr {
     #[inline(always)]
     fn borrow(&self) -> &str {
diff --git a/py-polars/polars/catalog.py b/py-polars/polars/catalog.py
index d57d01b783f8..20d0d688193d 100644
--- a/py-polars/polars/catalog.py
+++ b/py-polars/polars/catalog.py
@@ -1,14 +1,21 @@
 from __future__ import annotations
 
+import contextlib
 import importlib
 import os
-from typing import TYPE_CHECKING, Any, Literal, TypedDict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Literal
 
+from polars._utils.unstable import issue_unstable_warning
 from polars._utils.wrap import wrap_ldf
+from polars.exceptions import DuplicateError
+from polars.schema import Schema
 
 if TYPE_CHECKING:
     from datetime import datetime
 
+    from polars._typing import SchemaDict
+    from polars.datatypes.classes import DataType
     from polars.io.cloud import CredentialProviderFunction
     from polars.lazyframe import LazyFrame
 
@@ -46,7 +53,7 @@ def __init__(
         * "databricks-sdk": Use the Databricks SDK to retrieve and use the
           bearer token from the environment.
         """
-        from polars.polars import PyCatalogClient
+        issue_unstable_warning("`Catalog` functionality is considered unstable.")
 
         if bearer_token == "databricks-sdk" or (
            bearer_token == "auto"
@@ -190,16 +197,11 @@ def scan_table(
         """
         table_info = self.get_table_info(catalog_name, schema_name, table_name)
+        storage_location, data_source_format = _extract_location_and_data_format(
+            table_info, "scan table"
+        )
 
-        if (source := table_info.get("storage_location")) is None:
-            msg = "cannot scan catalog table: no storage_location found"
-            raise ValueError(msg)
-
-        if (data_source_format := table_info.get("data_source_format")) is None:
-            msg = "cannot scan catalog table: no data_source_format found"
-            raise ValueError(msg)
-
-        if data_source_format in ["DELTA", "DELTA_SHARING"]:
+        if data_source_format in ["DELTA", "DELTASHARING"]:
             from polars.io.delta import scan_delta
 
             if credential_provider is not None and credential_provider != "auto":
                 msg = "credential_provider when scanning DELTA"
                 raise NotImplementedError(msg)
 
             return scan_delta(
-                source,
+                storage_location,
                 version=delta_table_version,
                 delta_table_options=delta_table_options,
                 storage_options=storage_options,
@@ -230,7 +232,10 @@ def scan_table(
         from polars.io.cloud.credential_provider import _maybe_init_credential_provider
 
         credential_provider = _maybe_init_credential_provider(
-            credential_provider, source, storage_options, "Catalog.scan_table"
+            credential_provider,
+            storage_location,
+            storage_options,
+            "Catalog.scan_table",
         )
 
         if storage_options:
@@ -250,37 +255,258 @@ def scan_table(
             )
         )
 
+    def create_catalog(
+        self,
+        catalog_name: str,
+        *,
+        comment: str | None = None,
+        storage_root: str | None = None,
+    ) -> CatalogInfo:
+        """
+        Create a catalog.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+
+        Parameters
+        ----------
+        catalog_name
+            Name of the catalog.
+        comment
+            Leaves a comment about the catalog.
+        storage_root
+            Base location at which to store the catalog.
+        """
+        return self._client.create_catalog(
+            catalog_name=catalog_name, comment=comment, storage_root=storage_root
+        )
+
+    def delete_catalog(
+        self,
+        catalog_name: str,
+        *,
+        force: bool = False,
+    ) -> None:
+        """
+        Delete a catalog.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+
+        Parameters
+        ----------
+        catalog_name
+            Name of the catalog.
+        force
+            Forcibly delete the catalog even if it is not empty.
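+
+        Examples
+        --------
+        Illustrative only - assumes a reachable workspace (names are hypothetical):
+
+        >>> catalog = pl.Catalog("https://my-workspace.cloud.databricks.com")  # doctest: +SKIP
+        >>> catalog.delete_catalog("my_catalog", force=True)  # doctest: +SKIP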
+        """
+        self._client.delete_catalog(catalog_name=catalog_name, force=force)
+
+    def create_schema(
+        self,
+        catalog_name: str,
+        schema_name: str,
+        *,
+        comment: str | None = None,
+        storage_root: str | None = None,
+    ) -> SchemaInfo:
+        """
+        Create a schema in the catalog.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+
+        Parameters
+        ----------
+        catalog_name
+            Name of the catalog.
+        schema_name
+            Name of the schema.
+        comment
+            Leaves a comment about the schema.
+        storage_root
+            Base location at which to store the schema.
+        """
+        return self._client.create_schema(
+            catalog_name=catalog_name,
+            schema_name=schema_name,
+            comment=comment,
+            storage_root=storage_root,
+        )
+
+    def delete_schema(
+        self,
+        catalog_name: str,
+        schema_name: str,
+        *,
+        force: bool = False,
+    ) -> None:
+        """
+        Delete a schema in the catalog.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+
+        Parameters
+        ----------
+        catalog_name
+            Name of the catalog.
+        schema_name
+            Name of the schema.
+        force
+            Forcibly delete the schema even if it is not empty.
+        """
+        self._client.delete_schema(
+            catalog_name=catalog_name, schema_name=schema_name, force=force
+        )
+
+    def create_table(
+        self,
+        catalog_name: str,
+        schema_name: str,
+        table_name: str,
+        *,
+        schema: SchemaDict | None,
+        table_type: TableType,
+        data_source_format: DataSourceFormat | None = None,
+        comment: str | None = None,
+        storage_root: str | None = None,
+        properties: dict[str, str] | None = None,
+    ) -> TableInfo:
+        """
+        Create a table in the catalog.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+
+        Parameters
+        ----------
+        catalog_name
+            Name of the catalog.
+        schema_name
+            Name of the schema.
+        table_name
+            Name of the table.
+        schema
+            Schema of the table.
+        table_type
+            Type of the table.
+        data_source_format
+            Storage format of the table.
+        comment
+            Leaves a comment about the table.
+        storage_root
+            Base location at which to store the table.
+        properties
+            Extra key-value metadata to store.
+        """
+        return self._client.create_table(
+            catalog_name=catalog_name,
+            schema_name=schema_name,
+            table_name=table_name,
+            schema=schema,
+            table_type=table_type,
+            data_source_format=data_source_format,
+            comment=comment,
+            storage_root=storage_root,
+            properties=list((properties or {}).items()),
+        )
+
+    def delete_table(
+        self,
+        catalog_name: str,
+        schema_name: str,
+        table_name: str,
+    ) -> None:
+        """
+        Delete the table stored at this location.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+
+        Parameters
+        ----------
+        catalog_name
+            Name of the catalog.
+        schema_name
+            Name of the schema.
+        table_name
+            Name of the table.
+        """
+        self._client.delete_table(
+            catalog_name=catalog_name,
+            schema_name=schema_name,
+            table_name=table_name,
+        )
+
     @classmethod
     def _get_databricks_token(cls) -> str:
-        cls._ensure_databricks_sdk_available()
+        if importlib.util.find_spec("databricks.sdk") is None:
+            msg = "could not get Databricks token: databricks-sdk is not installed"
+            raise ImportError(msg)
 
         # We code like this to bypass linting
         m = importlib.import_module("databricks.sdk.core").__dict__
 
         return m["DefaultCredentials"]()(m["Config"]())()["Authorization"][7:]
 
-    @staticmethod
-    def _ensure_databricks_sdk_available() -> None:
-        if importlib.util.find_spec("databricks.sdk") is None:
-            msg = "could not get Databricks token: databricks-sdk is not installed"
-            raise ImportError(msg)
 
+def _extract_location_and_data_format(
+    table_info: TableInfo, operation: str
+) -> tuple[str, DataSourceFormat]:
+    if table_info.storage_location is None:
+        msg = f"cannot {operation}: no storage_location found"
+        raise ValueError(msg)
 
-class CatalogInfo(TypedDict):
+    if table_info.data_source_format is None:
+        msg = f"cannot {operation}: no data_source_format found"
+        raise ValueError(msg)
+
+    return table_info.storage_location, table_info.data_source_format
+
+
+@dataclass
+class CatalogInfo:
     """Information for a catalog within a metastore."""
 
     name: str
     comment: str | None
+    properties: dict[str, str]
+    options: dict[str, str]
+    storage_location: str | None
+    created_at: datetime | None
+    created_by: str | None
+    updated_at: datetime | None
+    updated_by: str | None
 
+
+@dataclass
+class SchemaInfo:
+    """
+    Information for a schema within a catalog.
 
-class SchemaInfo(TypedDict):
-    """Information for a schema within a catalog."""
+    Note: This does not refer to a table schema. It can instead be understood
+    as a subdirectory within a catalog.
+    """
 
     name: str
     comment: str | None
+    properties: dict[str, str]
+    storage_location: str | None
+    created_at: datetime | None
+    created_by: str | None
+    updated_at: datetime | None
+    updated_by: str | None
 
 
-class TableInfo(TypedDict):
+@dataclass
+class TableInfo:
     """Information for a catalog table."""
 
     name: str
@@ -290,19 +516,64 @@ class TableInfo(TypedDict):
     storage_location: str | None
     data_source_format: DataSourceFormat | None
     columns: list[ColumnInfo] | None
+    properties: dict[str, str]
+    created_at: datetime | None
+    created_by: str | None
+    updated_at: datetime | None
+    updated_by: str | None
 
+    def get_polars_schema(self) -> Schema | None:
+        """
+        Get the native polars schema of this table.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+        """
+        issue_unstable_warning(
+            "`get_polars_schema` functionality is considered unstable."
+        )
+        if self.columns is None:
+            return None
+
+        schema = Schema()
+
+        for column_info in self.columns:
+            if column_info.name in schema:
+                msg = f"duplicate column name: {column_info.name}"
+                raise DuplicateError(msg)
+            schema[column_info.name] = column_info.get_polars_dtype()
+
+        return schema
+
+
+@dataclass
+class ColumnInfo:
     """Information for a column within a catalog table."""
 
     name: str
+    type_name: str
     type_text: str
-    type_interval_type: str | None
+    type_json: str
     position: int | None
     comment: str | None
    partition_index: int | None
 
+    def get_polars_dtype(self) -> DataType:
+        """
+        Get the native polars datatype of this column.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+        """
+        issue_unstable_warning(
+            "`get_polars_dtype` functionality is considered unstable."
+        )
+        return PyCatalogClient.type_json_to_polars_type(self.type_json)
+
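+
+# Illustrative sketch of the new end-to-end flow (hedged; names are hypothetical
+# and assume valid credentials):
+#
+#     catalog = Catalog("https://my-workspace.cloud.databricks.com")
+#     info = catalog.create_table(
+#         "my_catalog",
+#         "my_schema",
+#         "my_table",
+#         schema={"a": pl.Int64},
+#         table_type="MANAGED",
+#         data_source_format="DELTA",
+#     )
+#     assert info.get_polars_schema() == pl.Schema({"a": pl.Int64})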
+
+# TODO: Expose these type aliases to reference guide
 TableType = Literal[
     "MANAGED",
     "EXTERNAL",
@@ -323,7 +594,7 @@ class ColumnInfo(TypedDict):
     "ORC",
     "TEXT",
     "UNITY_CATALOG",
-    "DELTA_SHARING",
+    "DELTASHARING",
     "DATABRICKS_FORMAT",
     "REDSHIFT_FORMAT",
     "SNOWFLAKE_FORMAT",
@@ -336,3 +607,14 @@ class ColumnInfo(TypedDict):
     "HIVE_CUSTOM",
     "VECTOR_INDEX_FORMAT",
 ]
+
+# TODO: Move this back up after moving the data models to a separate file
+with contextlib.suppress(ImportError):
+    from polars.polars import PyCatalogClient
+
+    PyCatalogClient.init_classes(
+        catalog_info_cls=CatalogInfo,
+        schema_info_cls=SchemaInfo,
+        table_info_cls=TableInfo,
+        column_info_cls=ColumnInfo,
+    )
diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py
index 9530b7be9ea4..b719e80473a1 100644
--- a/py-polars/polars/dataframe/frame.py
+++ b/py-polars/polars/dataframe/frame.py
@@ -4489,8 +4489,6 @@ def write_delta(
 
         _check_if_delta_available()
 
-        credential_provider_creds = {}
-
         from deltalake import DeltaTable, write_deltalake
         from deltalake import __version__ as delta_version
         from packaging.version import Version
@@ -4520,6 +4518,8 @@ def write_delta(
         else:
             credential_provider = None
 
+        credential_provider_creds = {}
+
         if credential_provider is not None:
             credential_provider_creds = _get_credentials_from_provider_expiry_aware(
                 credential_provider