diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 33817b7383..1012b01c10 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils, Literal, TryEval, UrlCodec} +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils, Literal, StringDecode, TryEval, UrlCodec} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils @@ -40,7 +40,11 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { ("encode", UrlCodec.getClass) -> CometUrlEncodeStaticInvoke, ("decode", UrlCodec.getClass) -> CometUrlDecodeStaticInvoke, ("aesEncrypt", classOf[ExpressionImplUtils]) -> CometStaticInvokeCodegenDispatch, - ("aesDecrypt", classOf[ExpressionImplUtils]) -> CometStaticInvokeCodegenDispatch) + ("aesDecrypt", classOf[ExpressionImplUtils]) -> CometStaticInvokeCodegenDispatch, + // Spark 4.0 lowers `decode(bin, charset)` to `StaticInvoke(StringDecode.decode, ...)` + // carrying the `legacyCharsets` / `legacyErrorAction` flags. Routing through the codegen + // dispatcher runs Spark's own decoder so both flags are honored. See #4465. + ("decode", classOf[StringDecode]) -> CometStaticInvokeCodegenDispatch) override def convert( expr: StaticInvoke, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index c4abe8ad4e..6a87fa3dd5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -19,15 +19,12 @@ package org.apache.comet.serde -import java.util.Locale - import org.apache.spark.sql.catalyst.expressions.{Attribute, BitLength, Cast, Concat, ConcatWs, Elt, Expression, FindInSet, FormatNumber, FormatString, GetJsonObject, If, InitCap, IsNull, Left, Length, Levenshtein, Like, Literal, Lower, OctetLength, Overlay, RegExpExtract, RegExpExtractAll, RegExpInStr, RegExpReplace, Right, RLike, SoundEx, StringLocate, StringLPad, StringRepeat, StringReplace, StringRPad, StringSplit, StringTranslate, Substring, SubstringIndex, ToCharacter, ToNumber, UnBase64, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withFallbackReason -import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} import org.apache.comet.shims.CometTypeShim @@ -551,34 +548,6 @@ object CometGetJsonObject extends CometCodegenDispatch[GetJsonObject] { } } -trait CommonStringExprs { - - def stringDecode( - expr: Expression, - charset: Expression, - bin: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - charset match { - case Literal(str, DataTypes.StringType) - if str.toString.toLowerCase(Locale.ROOT) == "utf-8" => - // decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls - // for invalid strings. - // Left child is the binary expression. - val binExpr = exprToProtoInternal(bin, inputs, binding) - if (binExpr.isDefined) { - CometCast.castToProto(expr, None, DataTypes.StringType, binExpr.get, CometEvalMode.TRY) - } else { - withFallbackReason(expr, bin) - None - } - case _ => - withFallbackReason(expr, "Comet only supports decoding with 'utf-8'.") - None - } - } -} - // Expressions routed through the JVM codegen dispatcher: no native implementation, so Spark's own // doGenCode runs inside the Comet pipeline, matching Spark exactly. object CometLevenshtein extends CometCodegenDispatch[Levenshtein] diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index d3e678fe54..1ad9ec75bf 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CommonStringExprs} +import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. */ -trait CometExprShim extends CommonStringExprs { +trait CometExprShim { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala index 464533b191..0be1185f59 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket, CommonStringExprs} +import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. */ -trait CometExprShim extends CommonStringExprs { +trait CometExprShim { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) diff --git a/spark/src/main/spark-3.x/org/apache/comet/serde/CometStringDecode.scala b/spark/src/main/spark-3.x/org/apache/comet/serde/CometStringDecode.scala index f2c25d75df..049a9fd22c 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/serde/CometStringDecode.scala +++ b/spark/src/main/spark-3.x/org/apache/comet/serde/CometStringDecode.scala @@ -19,17 +19,10 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, StringDecode} +import org.apache.spark.sql.catalyst.expressions.StringDecode -object CometStringDecode extends CometExpressionSerde[StringDecode] with CommonStringExprs { - - override def getUnsupportedReasons(): Seq[String] = - Seq("Only the `'utf-8'` charset is supported. Other charsets fall back to Spark.") - - override def convert( - expr: StringDecode, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - stringDecode(expr, expr.charset, expr.bin, inputs, binding) - } -} +/** + * Spark 3.x `decode(bin, charset)` runs through the codegen dispatcher so Spark's own decoder + * handles invalid byte sequences (replacement-character substitution). See #4465. + */ +object CometStringDecode extends CometCodegenDispatch[StringDecode] diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index 4610c3d553..645de7ba75 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -23,12 +23,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator -import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, StringType} +import org.apache.spark.sql.types.ArrayType import org.apache.comet.CometExplainInfo import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.{CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket, CommonStringExprs} +import org.apache.comet.serde.{CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} @@ -37,7 +36,7 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFa * are identical across minor versions; per-version traits override only `binaryOutputStyle` and * supply the matching `CometEvalModeUtil.sumEvalMode`. */ -trait Spark4xCometExprShim extends CommonStringExprs with CometExprShim4x { +trait Spark4xCometExprShim extends CometExprShim4x { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) @@ -85,19 +84,6 @@ trait Spark4xCometExprShim extends CommonStringExprs with CometExprShim4x { case _ => exprToProtoInternal(knc.child, inputs, binding) } - case s: StaticInvoke - if s.staticObject == classOf[StringDecode] && - s.dataType.isInstanceOf[StringType] && - s.functionName == "decode" && - s.arguments.size == 4 && - s.inputTypes == Seq( - BinaryType, - StringTypeWithCollation(supportsTrimCollation = true), - BooleanType, - BooleanType) => - val Seq(bin, charset, _, _) = s.arguments - stringDecode(expr, charset, bin, inputs, binding) - // On Spark 4.0+, RuntimeReplaceable expressions (StructsToJson, ParseUrl) become // Invoke(Literal(Evaluator), "evaluate", ...). Reconstruct the original expression and // recurse so support-level checks apply, propagating any explain info back onto the diff --git a/spark/src/test/resources/sql-tests/expressions/string/decode.sql b/spark/src/test/resources/sql-tests/expressions/string/decode.sql index 45aeaacee4..0e3bf0d124 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/decode.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/decode.sql @@ -25,12 +25,15 @@ -- time Comet sees the plan the wrapper has already been replaced with CaseWhen and Comet -- handles it through its existing CaseWhen + EqualNullSafe serde. -- --- The 2-arg charset form lowers to a cast(binary, string) inside Comet's stringDecode --- handler, but only when the charset is 'utf-8' (case-insensitive). Other charsets fall --- back to Spark JVM execution. +-- The 2-arg charset form runs through the codegen dispatcher (Spark's own doGenCode inside the +-- Comet pipeline) so behavior matches Spark exactly across all supported charsets and across +-- the Spark 4.0 `legacyCharsets` / `legacyErrorAction` modes (#4465). Invalid-byte tests live in +-- decode_invalid_utf8.sql and decode_invalid_utf8_strict.sql so each Spark version gets the +-- right expectation. +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true -- =========================================================================== --- Charset form: decode(bin, charset) for UTF-8 (the supported native path) +-- Charset form: decode(bin, charset) over valid input, multiple charsets -- =========================================================================== statement @@ -57,13 +60,13 @@ CREATE TABLE test_decode_charset_safe(b binary) USING parquet statement INSERT INTO test_decode_charset_safe VALUES (CAST('ab' AS BINARY)), (CAST('abcd' AS BINARY)), (CAST('' AS BINARY)), (NULL) -query expect_fallback(Comet only supports decoding with 'utf-8'.) +query SELECT decode(b, 'UTF-16BE') FROM test_decode_charset_safe -query expect_fallback(Comet only supports decoding with 'utf-8'.) +query SELECT decode(b, 'US-ASCII') FROM test_decode_charset_safe -query expect_fallback(Comet only supports decoding with 'utf-8'.) +query SELECT decode(b, 'ISO-8859-1') FROM test_decode_utf8 diff --git a/spark/src/test/resources/sql-tests/expressions/string/decode_invalid_utf8.sql b/spark/src/test/resources/sql-tests/expressions/string/decode_invalid_utf8.sql new file mode 100644 index 0000000000..5186a171db --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/decode_invalid_utf8.sql @@ -0,0 +1,58 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- decode() over invalid UTF-8 byte sequences with legacy / replacement-character semantics. +-- +-- On Spark 3.4 and 3.5 `decode(bin, charset)` always substitutes the Unicode replacement +-- character for malformed sequences (it lowers to `new String(bytes, charset)`, which uses the +-- JVM's default replace-on-error decoder). +-- On Spark 4.0+ the same substitute behavior is selected by enabling both +-- `spark.sql.legacy.javaCharsets` and `spark.sql.legacy.codingErrorAction`. +-- The 4.0 default (strict) mode is covered separately in decode_invalid_utf8_strict.sql. +-- +-- Regression coverage for #4465: prior to that fix Comet lowered `decode` to a TRY-mode binary→ +-- string cast, which produced wrong output (NULL or raw bytes) on invalid sequences regardless of +-- mode. The codegen dispatcher path delegates to Spark's own decoder so this fixture verifies the +-- replacement-character output matches Spark. +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true +-- Config: spark.sql.legacy.javaCharsets=true +-- Config: spark.sql.legacy.codingErrorAction=true + +statement +CREATE TABLE test_decode_invalid_utf8(b binary) USING parquet + +-- 0xFF: never valid in any UTF-8 position (neither a lead byte nor a continuation byte). +-- 0xC3 0x28: a 2-byte sequence whose continuation byte (0x28) is invalid. +-- 0xE2 0x82 0x28: a 3-byte sequence with an invalid continuation byte. +-- 'caf' || 0xE9: ISO-8859-1 'café' bytes — 0xE9 is a 3-byte UTF-8 lead byte without the two +-- continuation bytes that would follow it. +statement +INSERT INTO test_decode_invalid_utf8 VALUES + (X'FF'), + (X'C328'), + (X'E28228'), + (CONCAT(CAST('caf' AS BINARY), X'E9')), + (CAST('valid' AS BINARY)), + (NULL) + +query +SELECT decode(b, 'utf-8') FROM test_decode_invalid_utf8 + +query +SELECT decode(X'FF', 'utf-8'), + decode(X'C328', 'utf-8'), + decode(X'E28228', 'utf-8') diff --git a/spark/src/test/resources/sql-tests/expressions/string/decode_invalid_utf8_strict.sql b/spark/src/test/resources/sql-tests/expressions/string/decode_invalid_utf8_strict.sql new file mode 100644 index 0000000000..7c23a61047 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/decode_invalid_utf8_strict.sql @@ -0,0 +1,49 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- decode() over invalid UTF-8 byte sequences in Spark 4.0's default (strict) mode. +-- +-- Spark 4.0 added `spark.sql.legacy.codingErrorAction` (default `false`) which replaces the JVM +-- default substitute-on-error decoder with one that throws `MALFORMED_CHARACTER_CODING`. This +-- fixture asserts both Spark and Comet raise that error, with a sentinel valid-input query so the +-- assertion does not pass vacuously through an operator-level fallback. +-- Regression coverage for #4465. +-- MinSparkVersion: 4.0 +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true + +-- Sentinel: ensures Comet actually runs `decode` (codegen dispatcher) so the expect_error queries +-- below trip the kernel rather than being satisfied by an operator-level Spark fallback. +statement +CREATE TABLE test_decode_strict_sentinel(b binary) USING parquet + +statement +INSERT INTO test_decode_strict_sentinel VALUES (CAST('hello' AS BINARY)), (NULL) + +query +SELECT decode(b, 'utf-8') FROM test_decode_strict_sentinel + +-- 0xFF is not a valid UTF-8 lead byte; strict mode raises. +query expect_error(MALFORMED_CHARACTER_CODING) +SELECT decode(X'FF', 'utf-8') + +-- 0xC3 0x28: 2-byte sequence with invalid continuation. +query expect_error(MALFORMED_CHARACTER_CODING) +SELECT decode(X'C328', 'utf-8') + +-- 0xE2 0x82 0x28: 3-byte sequence with invalid continuation. +query expect_error(MALFORMED_CHARACTER_CODING) +SELECT decode(X'E28228', 'utf-8')