Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 100 additions & 1 deletion native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ pub(crate) fn cast_array(
}
(Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?),
(Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
(Map(_, _), Utf8) => Ok(cast_map_to_string(array.as_map(), cast_options)?),
(Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
array.as_struct(),
&from_type,
Expand Down Expand Up @@ -729,6 +730,68 @@ fn casts_struct_to_string(
Ok(Arc::new(builder.finish()))
}

/// Casts every row of a map array to its string rendering, e.g. `{k1 -> v1, k2 -> v2}`.
///
/// The key and value child arrays are first cast to `Utf8` through `cast_array`, so the
/// text of each key/value is whatever that cast produces for the element type. Rows are
/// then assembled from the map offsets:
///
/// * a null map row yields a null output row;
/// * a map with no entries yields `{}`;
/// * a null key is rendered as `spark_cast_options.null_string`;
/// * a null value is rendered as ` ` followed by `spark_cast_options.null_string`, or as
///   nothing at all when that string is empty, so the entry reads `key ->` with no
///   trailing space. NOTE(review): this is intended to mirror Spark's `ToStringBase`
///   map formatting — confirm against the targeted Spark versions.
///
/// # Errors
/// Propagates any error returned while casting the key or value child arrays.
///
/// # Panics
/// Panics if `cast_array` returns something other than a `StringArray` for the `Utf8`
/// target type.
fn cast_map_to_string(
    array: &MapArray,
    spark_cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    let null_string: &str = &spark_cast_options.null_string;

    let casted_keys = cast_array(
        Arc::clone(array.keys()),
        &DataType::Utf8,
        spark_cast_options,
    )?;
    let casted_values = cast_array(
        Arc::clone(array.values()),
        &DataType::Utf8,
        spark_cast_options,
    )?;
    let key_values = casted_keys
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("Casted keys should be StringArray");
    let value_values = casted_values
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("Casted values should be StringArray");

    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
    // Scratch buffer reused for every row; it grows to the largest row seen.
    let mut row_buf = String::new();

    // Each adjacent pair of offsets delimits the entries belonging to one map row.
    for (row_index, bounds) in array.offsets().windows(2).enumerate() {
        if array.is_null(row_index) {
            builder.append_null();
            continue;
        }

        let (start, end) = (bounds[0] as usize, bounds[1] as usize);
        row_buf.clear();
        row_buf.push('{');
        for idx in start..end {
            if idx > start {
                row_buf.push_str(", ");
            }

            if key_values.is_null(idx) {
                row_buf.push_str(null_string);
            } else {
                row_buf.push_str(key_values.value(idx));
            }

            // The separating space belongs to the value, so that an empty null string
            // does not leave a dangling space after the arrow.
            row_buf.push_str(" ->");
            if value_values.is_null(idx) {
                if !null_string.is_empty() {
                    row_buf.push(' ');
                    row_buf.push_str(null_string);
                }
            } else {
                row_buf.push(' ');
                row_buf.push_str(value_values.value(idx));
            }
        }
        row_buf.push('}');
        builder.append_value(&row_buf);
    }

    Ok(Arc::new(builder.finish()))
}

impl Display for Cast {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down Expand Up @@ -869,7 +932,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{ListArray, NullArray, StringArray};
use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder};
use arrow::array::{ListArray, MapFieldNames, NullArray, StringArray};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::TimestampMicrosecondType;
use arrow::datatypes::{Field, Fields};
Expand Down Expand Up @@ -994,6 +1058,41 @@ mod tests {
}
}

/// Casting a map array to `Utf8` renders populated rows as `{k -> v, ...}`, an empty
/// map as `{}`, and keeps a null map row null.
#[test]
fn test_cast_map_to_utf8() {
    let field_names = MapFieldNames {
        entry: "entries".into(),
        key: "key".into(),
        value: "value".into(),
    };
    let mut builder = MapBuilder::new(
        Some(field_names),
        StringBuilder::new(),
        Int32Builder::new(),
    );

    // Row 0: two entries, the second one with a null value.
    builder.keys().append_value("a");
    builder.values().append_value(1);
    builder.keys().append_value("b");
    builder.values().append_null();
    builder.append(true).unwrap();

    // Row 1: a valid map without entries. Row 2: a null map.
    builder.append(true).unwrap();
    builder.append(false).unwrap();

    let input: ArrayRef = Arc::new(builder.finish());
    let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    let output = cast_array(input, &DataType::Utf8, &options).unwrap();

    let rendered: Vec<Option<&str>> = output.as_string::<i32>().iter().collect();
    assert_eq!(
        vec![Some(r#"{a -> 1, b -> null}"#), Some(r#"{}"#), None],
        rendered
    );
}

#[test]
fn test_cast_string_array_to_string() {
let values_array =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
}
}
Compatible()
case MapType(keyType, valueType, _) =>
isSupported(keyType, DataTypes.StringType, timeZoneId, evalMode) match {
case Compatible(_) =>
isSupported(valueType, DataTypes.StringType, timeZoneId, evalMode)
case other =>
other
}
case _ => unsupported(fromType, DataTypes.StringType)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ SELECT cast(named_struct('a', named_struct('b', named_struct('c', 1, 'd', 'leaf'
query
SELECT cast(named_struct('s1', '', 's2', ' ', 's3', cast(null as string)) as string)

-- Map-valued field: not supported, falls back to Spark.
query expect_fallback(to StringType is not supported)
-- Map-valued field: supported via recursive map -> string casting.
query
SELECT cast(named_struct('m', map('k', 1)) as string)

-- ----------------------------------------------------------------------------
Expand Down Expand Up @@ -270,69 +270,70 @@ SELECT cast(array(cast(1.5 as double), cast('NaN' as double), cast('-Infinity' a
query
SELECT cast(array(array(array(1, 2), array(3)), array(array(cast(null as int)))) as string)

-- Array of map: not supported, falls back to Spark.
query expect_fallback(to StringType is not supported)
-- Array of map: supported via recursive map -> string casting.
query
SELECT cast(array(map('k', 1)) as string)

-- ----------------------------------------------------------------------------
-- Map → string
-- ----------------------------------------------------------------------------
-- Comet does not implement map-to-string casts, so every map → string falls back to Spark.
-- Comet now implements map-to-string casts, including nested maps.
-- Note: maps materialized through parquet have nondeterministic entry order, so map column
-- tests use literal maps directly rather than reading from a parquet table.

-- Map with string keys, int values.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('a', 1, 'b', 2, 'c', 3) as string)

-- Map with NULL values rendered as "null".
query expect_fallback(Cast from MapType)
query
SELECT cast(map('a', 1, 'b', cast(null as int), 'c', 3) as string)

-- Map with int keys, string values.
query expect_fallback(Cast from MapType)
query
SELECT cast(map(1, 'one', 2, 'two', 3, 'three') as string)

-- Map with boolean values.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('t', true, 'f', false, 'n', cast(null as boolean)) as string)

-- Map with bigint values at min/max.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('max', 9223372036854775807, 'min', -9223372036854775808, 'zero', cast(0 as bigint)) as string)

-- Map with decimal values.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('pos', cast('1.234567890123456789' as decimal(38, 18)), 'neg', cast('-1.234567890123456789' as decimal(38, 18)), 'null', cast(null as decimal(38, 18))) as string)

-- Map with date and timestamp values.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('a', date '2024-01-15', 'b', date '1970-01-01', 'c', cast(null as date)) as string)

query expect_fallback(Cast from MapType)
query
SELECT cast(map('a', timestamp '2024-01-15 10:30:45', 'b', cast(null as timestamp)) as string)

-- Map with binary values.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('a', X'616263', 'b', X'', 'c', cast(null as binary)) as string)

-- Map with float / double values: NaN / ±0 / ±Infinity / NULL.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('nan', cast('NaN' as float), 'neg0', cast(-0.0 as float), 'null', cast(null as float)) as string)

query expect_fallback(Cast from MapType)
query
SELECT cast(map('nan', cast('NaN' as double), 'inf', cast('Infinity' as double), 'ninf', cast('-Infinity' as double), 'null', cast(null as double)) as string)

-- Map with struct values: each value rendered as `{f1, f2, ...}`.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('a', named_struct('x', 1, 'y', 'first'), 'b', cast(null as struct<x: int, y: string>)) as string)

-- Map with array values.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('a', array(1, 2, 3), 'b', array(cast(null as int)), 'c', cast(null as array<int>)) as string)

-- Empty map.
query expect_fallback(Cast from MapType)
-- Empty map: still falls back because planning sees `map()` as `Map<NullType, NullType>`,
-- which reaches the existing NullType -> StringType cast fallback.
query expect_fallback(Cast from NullType to StringType is not supported)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is the intended behavior. How should we handle the empty map (`Map<NullType, NullType>`) case?

SELECT cast(map() as string)

-- NULL map: Spark constant-folds this to a literal NULL, so the cast never reaches Comet
Expand All @@ -341,5 +342,5 @@ query
SELECT cast(cast(null as map<string, int>) as string)

-- Map of map.
query expect_fallback(Cast from MapType)
query
SELECT cast(map('outer', map('inner', 1)) as string)
15 changes: 10 additions & 5 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1672,19 +1672,24 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Incompatible(Some("There can be rounding differences")))
}

test("cast MapType propagates Unsupported from nested value cast") {
test("cast MapType to StringType is Compatible") {
val fromType = MapType(IntegerType, IntegerType)
assert(
CometCast.isSupported(fromType, DataTypes.StringType, None, CometEvalMode.LEGACY) ==
Compatible())
}

test("cast MapType propagates supported nested value cast") {
// Map<Int, Map<Int, Int>> → Map<Int, String>: the inner Map → String
// cast is Unsupported, and that must propagate through the outer Map
// arm rather than being silently swallowed.
// cast is now supported and must propagate through the outer Map arm.
val innerFrom = MapType(IntegerType, IntegerType)
val expectedMessage = s"Cast from $innerFrom to ${DataTypes.StringType} is not supported"
assert(
CometCast.isSupported(
MapType(IntegerType, innerFrom),
MapType(IntegerType, StringType),
None,
CometEvalMode.LEGACY) ==
Unsupported(Some(expectedMessage)))
Compatible())
}

test("cast ArrayType(DateType) to unsupported ArrayType falls back") {
Expand Down
Loading