diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 1f574f1231..412a2197ef 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -393,6 +393,7 @@ pub(crate) fn cast_array( } (Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?), (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?), + (Map(_, _), Utf8) => Ok(cast_map_to_string(array.as_map(), cast_options)?), (Struct(_), Struct(_)) => Ok(cast_struct_to_struct( array.as_struct(), &from_type, @@ -729,6 +730,68 @@ fn casts_struct_to_string( Ok(Arc::new(builder.finish())) } +fn cast_map_to_string( + array: &MapArray, + spark_cast_options: &SparkCastOptions, +) -> DataFusionResult { + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16); + let mut str = String::with_capacity(array.len() * 16); + + let casted_keys = cast_array( + Arc::clone(array.keys()), + &DataType::Utf8, + spark_cast_options, + )?; + let casted_values = cast_array( + Arc::clone(array.values()), + &DataType::Utf8, + spark_cast_options, + )?; + let key_values = casted_keys + .as_any() + .downcast_ref::() + .expect("Casted keys should be StringArray"); + let value_values = casted_values + .as_any() + .downcast_ref::() + .expect("Casted values should be StringArray"); + + let offsets = array.offsets(); + for row_index in 0..array.len() { + if array.is_null(row_index) { + builder.append_null(); + } else { + str.clear(); + let start = offsets[row_index] as usize; + let end = offsets[row_index + 1] as usize; + + str.push('{'); + let mut first = true; + for idx in start..end { + if !first { + str.push_str(", "); + } + if key_values.is_null(idx) { + str.push_str(&spark_cast_options.null_string); + } else { + str.push_str(key_values.value(idx)); + } + str.push_str(" -> "); + if value_values.is_null(idx) { + str.push_str(&spark_cast_options.null_string); + } else { + str.push_str(value_values.value(idx)); + } + first = false; + } + str.push('}'); + builder.append_value(&str); + } + } + + Ok(Arc::new(builder.finish())) +} + impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( @@ -869,7 +932,8 @@ fn cast_binary_formatter(value: &[u8]) -> String { #[cfg(test)] mod tests { use super::*; - use arrow::array::{ListArray, NullArray, StringArray}; + use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder}; + use arrow::array::{ListArray, MapFieldNames, NullArray, StringArray}; use arrow::buffer::OffsetBuffer; use arrow::datatypes::TimestampMicrosecondType; use arrow::datatypes::{Field, Fields}; @@ -994,6 +1058,41 @@ mod tests { } } + #[test] + fn test_cast_map_to_utf8() { + let mut map_builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + Int32Builder::new(), + ); + + map_builder.keys().append_value("a"); + map_builder.values().append_value(1); + map_builder.keys().append_value("b"); + map_builder.values().append_null(); + map_builder.append(true).unwrap(); + + map_builder.append(true).unwrap(); + map_builder.append(false).unwrap(); + + let map_array: ArrayRef = Arc::new(map_builder.finish()); + let string_array = cast_array( + map_array, + &DataType::Utf8, + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), + ) + .unwrap(); + let string_array = string_array.as_string::(); + assert_eq!(3, string_array.len()); + assert_eq!(r#"{a -> 1, b -> null}"#, string_array.value(0)); + assert_eq!(r#"{}"#, string_array.value(1)); + assert!(string_array.is_null(2)); + } + #[test] fn test_cast_string_array_to_string() { let values_array = diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 42da809206..cea4b53006 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -305,6 +305,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { } } Compatible() + case MapType(keyType, valueType, _) => + isSupported(keyType, DataTypes.StringType, timeZoneId, evalMode) match { + case Compatible(_) => + isSupported(valueType, DataTypes.StringType, timeZoneId, evalMode) + case other => + other + } case _ => unsupported(fromType, DataTypes.StringType) } } diff --git a/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql b/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql index 8b1d989ae7..b8b924e01a 100644 --- a/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql +++ b/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql @@ -149,8 +149,8 @@ SELECT cast(named_struct('a', named_struct('b', named_struct('c', 1, 'd', 'leaf' query SELECT cast(named_struct('s1', '', 's2', ' ', 's3', cast(null as string)) as string) --- Map-valued field: not supported, falls back to Spark. -query expect_fallback(to StringType is not supported) +-- Map-valued field: supported via recursive map -> string casting. +query SELECT cast(named_struct('m', map('k', 1)) as string) -- ---------------------------------------------------------------------------- @@ -270,69 +270,70 @@ SELECT cast(array(cast(1.5 as double), cast('NaN' as double), cast('-Infinity' a query SELECT cast(array(array(array(1, 2), array(3)), array(array(cast(null as int)))) as string) --- Array of map: not supported, falls back to Spark. -query expect_fallback(to StringType is not supported) +-- Array of map: supported via recursive map -> string casting. +query SELECT cast(array(map('k', 1)) as string) -- ---------------------------------------------------------------------------- -- Map → string -- ---------------------------------------------------------------------------- --- Comet does not implement map-to-string casts, so every map → string falls back to Spark. +-- Comet now implements map-to-string casts, including nested maps. -- Note: maps materialized through parquet have nondeterministic entry order, so map column -- tests use literal maps directly rather than reading from a parquet table. -- Map with string keys, int values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', 1, 'b', 2, 'c', 3) as string) -- Map with NULL values rendered as "null". -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', 1, 'b', cast(null as int), 'c', 3) as string) -- Map with int keys, string values. -query expect_fallback(Cast from MapType) +query SELECT cast(map(1, 'one', 2, 'two', 3, 'three') as string) -- Map with boolean values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('t', true, 'f', false, 'n', cast(null as boolean)) as string) -- Map with bigint values at min/max. -query expect_fallback(Cast from MapType) +query SELECT cast(map('max', 9223372036854775807, 'min', -9223372036854775808, 'zero', cast(0 as bigint)) as string) -- Map with decimal values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('pos', cast('1.234567890123456789' as decimal(38, 18)), 'neg', cast('-1.234567890123456789' as decimal(38, 18)), 'null', cast(null as decimal(38, 18))) as string) -- Map with date and timestamp values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', date '2024-01-15', 'b', date '1970-01-01', 'c', cast(null as date)) as string) -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', timestamp '2024-01-15 10:30:45', 'b', cast(null as timestamp)) as string) -- Map with binary values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', X'616263', 'b', X'', 'c', cast(null as binary)) as string) -- Map with float / double values: NaN / ±0 / ±Infinity / NULL. -query expect_fallback(Cast from MapType) +query SELECT cast(map('nan', cast('NaN' as float), 'neg0', cast(-0.0 as float), 'null', cast(null as float)) as string) -query expect_fallback(Cast from MapType) +query SELECT cast(map('nan', cast('NaN' as double), 'inf', cast('Infinity' as double), 'ninf', cast('-Infinity' as double), 'null', cast(null as double)) as string) -- Map with struct values: each value rendered as `{f1, f2, ...}`. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', named_struct('x', 1, 'y', 'first'), 'b', cast(null as struct)) as string) -- Map with array values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', array(1, 2, 3), 'b', array(cast(null as int)), 'c', cast(null as array)) as string) --- Empty map. -query expect_fallback(Cast from MapType) +-- Empty map: still falls back because planning sees `map()` as `Map`, +-- which reaches the existing NullType -> StringType cast fallback. +query expect_fallback(Cast from NullType to StringType is not supported) SELECT cast(map() as string) -- NULL map: Spark constant-folds this to a literal NULL, so the cast never reaches Comet @@ -341,5 +342,5 @@ query SELECT cast(cast(null as map) as string) -- Map of map. -query expect_fallback(Cast from MapType) +query SELECT cast(map('outer', map('inner', 1)) as string) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index aac1bc0081..6ef5b6b79b 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -1672,19 +1672,24 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { Incompatible(Some("There can be rounding differences"))) } - test("cast MapType propagates Unsupported from nested value cast") { + test("cast MapType to StringType is Compatible") { + val fromType = MapType(IntegerType, IntegerType) + assert( + CometCast.isSupported(fromType, DataTypes.StringType, None, CometEvalMode.LEGACY) == + Compatible()) + } + + test("cast MapType propagates supported nested value cast") { // Map> → Map: the inner Map → String - // cast is Unsupported, and that must propagate through the outer Map - // arm rather than being silently swallowed. + // cast is now supported and must propagate through the outer Map arm. val innerFrom = MapType(IntegerType, IntegerType) - val expectedMessage = s"Cast from $innerFrom to ${DataTypes.StringType} is not supported" assert( CometCast.isSupported( MapType(IntegerType, innerFrom), MapType(IntegerType, StringType), None, CometEvalMode.LEGACY) == - Unsupported(Some(expectedMessage))) + Compatible()) } test("cast ArrayType(DateType) to unsupported ArrayType falls back") {