diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index ac1cde10bdf5..a849f1813674 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5526,6 +5526,17 @@ "message" : [ " is not supported" ] + }, + "UNSUPPORTED_STATEFUL_OPERATOR" : { + "message" : [ + "Unsupported stateful operator <operatorName>. Please use a checkpoint with supported stateful operators." + ] + }, + "UNSUPPORTED_TRANSFORM_WITH_STATE_VARIABLE_TYPE" : { + "message" : [ + "Unsupported transformWithState variable type <variableType> (TTL_Enabled: <ttlEnabled>, ColFamilyName: <colFamilyName>).", + "Please use a checkpoint with supported transform with state variable types." + ] } }, "sqlState" : "55019" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala index 5b9b95ef413a..e9430ed9f9b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala @@ -120,7 +120,8 @@ case class FlatMapGroupsInPandasWithStateExec( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( groupingAttributes.map(SortOrder(_, Ascending))) - override def shortName: String = "applyInPandasWithState" + override def shortName: String = + StatefulOperatorsUtils.FLAT_MAP_GROUPS_IN_PANDAS_WITH_STATE_EXEC_OP_NAME override protected def withNewChildInternal( newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StatePartitionKeyExtractorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StatePartitionKeyExtractorFactory.scala new file mode 100644 index 000000000000..ece76d22fda4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StatePartitionKeyExtractorFactory.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.streaming.operators.stateful + +import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils +import org.apache.spark.sql.execution.streaming.operators.stateful.flatmapgroupswithstate.FlatMapGroupsWithStatePartitionKeyExtractor +import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{TransformWithStatePartitionKeyExtractorFactory, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionErrors, StatePartitionKeyExtractor, StateStore, StateStoreId} +import org.apache.spark.sql.types.StructType + +/** + * Factory for creating state partition key extractor for various streaming stateful operators. + * This is used for offline state repartitioning, when we need to repartition + * the state for a given operator. If an operator isn't included in this factory, + * then offline repartitioning will not be supported for it. + * + * To support offline repartitioning for a new stateful operator, you need to: + * 1. Create a state partition key extractor for the operator state. + * 2. Register the state partition key extractor in this factory. + */ +object StatePartitionKeyExtractorFactory { + import StatefulOperatorsUtils._ + + /** + * Creates a state partition key extractor for the given operator. + * An operator may have different extractor for different stores/column families. + * + * @param operatorName The name of the operator. + * @param stateKeySchema The schema of the state key. + * @param storeName The name of the store. + * @param colFamilyName The name of the column family. + * @param stateFormatVersion Optional, the version of the state format. Used by operators + * that have different extractors for different state formats. + * @param stateVariableInfo Optional, the state variable info for TransformWithState. + * @return The state partition key extractor. 
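For orientation, here is a hedged usage sketch of the create() method documented above. It is an editorial illustration, not part of the patch: the dedup key schema is made up, and the named arguments simply mirror the create() signature that follows the doc comment.

  import org.apache.spark.sql.execution.streaming.operators.stateful.StatePartitionKeyExtractorFactory
  import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
  import org.apache.spark.sql.types.{StringType, StructType}

  // Hypothetical state key schema for a streaming dedupe operator.
  val dedupKeySchema = new StructType().add("eventId", StringType)

  val extractor = StatePartitionKeyExtractorFactory.create(
    operatorName = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME,
    stateKeySchema = dedupKeySchema)

  // Dedupe uses the state key directly as the partition key, so the schemas should match.
  assert(extractor.partitionKeySchema == dedupKeySchema)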
+ */ + def create( + operatorName: String, + stateKeySchema: StructType, + storeName: String = StateStoreId.DEFAULT_STORE_NAME, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME, + stateFormatVersion: Option[Int] = None, + stateVariableInfo: Option[TransformWithStateVariableInfo] = None + ): StatePartitionKeyExtractor = { + operatorName match { + case STATE_STORE_SAVE_EXEC_OP_NAME => + new StreamingAggregationStatePartitionKeyExtractor(stateKeySchema) + case DEDUPLICATE_EXEC_OP_NAME => + new StreamingDeduplicateStatePartitionKeyExtractor(stateKeySchema) + case DEDUPLICATE_WITHIN_WATERMARK_EXEC_OP_NAME => + new StreamingDedupWithinWatermarkStatePartitionKeyExtractor(stateKeySchema) + case SESSION_WINDOW_STATE_STORE_SAVE_EXEC_OP_NAME => + new StreamingSessionWindowStatePartitionKeyExtractor(stateKeySchema) + case SYMMETRIC_HASH_JOIN_EXEC_OP_NAME => + SymmetricHashJoinStateManager.createPartitionKeyExtractor( + storeName, colFamilyName, stateKeySchema, stateFormatVersion.get) + case fmg if FLAT_MAP_GROUPS_OP_NAMES.contains(fmg) => + new FlatMapGroupsWithStatePartitionKeyExtractor(stateKeySchema) + case tws if TRANSFORM_WITH_STATE_OP_NAMES.contains(tws) => + require(stateVariableInfo.isDefined, + "stateVariableInfo is required for TransformWithState") + TransformWithStatePartitionKeyExtractorFactory.create( + storeName, colFamilyName, stateKeySchema, stateVariableInfo.get) + case _ => throw OfflineStateRepartitionErrors + .unsupportedStatefulOperatorError(checkpointLocation = "", operatorName) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingAggregationStateManager.scala index c7f7f388010d..8357053cdc46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingAggregationStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingAggregationStateManager.scala @@ -21,7 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{NoopStatePartitionKeyExtractor, ReadStateStore, StateStore, UnsafeRowPair} import org.apache.spark.sql.types.StructType /** @@ -205,3 +205,9 @@ class StreamingAggregationStateManagerImplV2( } } } + +/** + * For aggregation state v1 and v2, the state key is the partition key i.e. 
the aggregation key + */ +class StreamingAggregationStatePartitionKeyExtractor(stateKeySchema: StructType) + extends NoopStatePartitionKeyExtractor(stateKeySchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingSessionWindowStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingSessionWindowStateManager.scala index a74b4aaf0da1..b517eec7bddc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingSessionWindowStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/StreamingSessionWindowStateManager.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, ReadStateStore, StateStore, UnsafeRowPair} import org.apache.spark.sql.types.{StructType, TimestampType} import org.apache.spark.util.NextIterator @@ -280,3 +280,10 @@ class StreamingSessionWindowHelper(sessionExpression: Attribute, inputSchema: Se (window.getLong(0), window.getLong(1)) } } + +/** + * The State key is the session key (i.e. partition key) and the sessionStartTime. + * Drop the last field (sessionStartTime) to get the partition key. + */ +class StreamingSessionWindowStatePartitionKeyExtractor(stateKeySchema: StructType) + extends DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema, numLastColsToDrop = 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala index a0778fbfb614..6b9f90a9ab5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorPartitioning, StatefulOperatorStateInfo, StateStoreWriter, WatermarkSupport} +import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorPartitioning, StatefulOperatorStateInfo, StatefulOperatorsUtils, StateStoreWriter, WatermarkSupport} import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} @@ -95,7 +95,7 @@ trait FlatMapGroupsWithStateExecBase override def keyExpressions: Seq[Attribute] = groupingAttributes - override def shortName: String = "flatMapGroupsWithState" + override def shortName: String = 
StatefulOperatorsUtils.FLAT_MAP_GROUPS_WITH_STATE_EXEC_OP_NAME override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { timeoutConf match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExecHelper.scala index 280fcfc0ca1c..42b9412846d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExecHelper.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.ObjectOperator import org.apache.spark.sql.execution.streaming.operators.stateful.flatmapgroupswithstate.GroupStateImpl.NO_TIMESTAMP -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.{NoopStatePartitionKeyExtractor, StateStore} import org.apache.spark.sql.types._ @@ -246,3 +246,10 @@ object FlatMapGroupsWithStateExecHelper { } } } + +/** + * For FlatMapGroupsWithStateExec and FlatMapGroupsInPandasWithStateExec (v1 & v2), + * the state key is the partition key i.e. the grouping key + */ +class FlatMapGroupsWithStatePartitionKeyExtractor(stateKeySchema: StructType) + extends NoopStatePartitionKeyExtractor(stateKeySchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 579267892abd..c27f2d116bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._ -import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay} +import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay} import org.apache.spark.sql.types.{BooleanType, LongType, StructField, StructType} import org.apache.spark.util.NextIterator @@ -101,14 +101,14 @@ abstract class 
SymmetricHashJoinStateManager( joinStoreGenerator: JoinStateManagerStoreGenerator) extends Logging { import SymmetricHashJoinStateManager._ - protected val keySchema = StructType( + private[streaming] val keySchema = StructType( joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) protected val keyAttributes = toAttributes(keySchema) - protected val keyToNumValues = new KeyToNumValuesStore( + private[streaming] val keyToNumValues = new KeyToNumValuesStore( stateFormatVersion, snapshotOptions.map(_.getKeyToNumValuesHandlerOpts())) - protected val keyWithIndexToValue = new KeyWithIndexToValueStore( + private[streaming] val keyWithIndexToValue = new KeyWithIndexToValueStore( stateFormatVersion, snapshotOptions.map(_.getKeyWithIndexToValueHandlerOpts())) @@ -1254,21 +1254,53 @@ object SymmetricHashJoinStateManager { } } - private[join] sealed trait StateStoreType + private[streaming] sealed trait StateStoreType - private[join] case object KeyToNumValuesType extends StateStoreType { + private[streaming] case object KeyToNumValuesType extends StateStoreType { override def toString(): String = "keyToNumValues" } - private[join] case object KeyWithIndexToValueType extends StateStoreType { + private[streaming] case object KeyWithIndexToValueType extends StateStoreType { override def toString(): String = "keyWithIndexToValue" } - private[join] def getStateStoreName( + private[streaming] def getStateStoreName( joinSide: JoinSide, storeType: StateStoreType): String = { s"$joinSide-$storeType" } + private[join] def getStoreType(storeName: String): StateStoreType = { + if (storeName == getStateStoreName(LeftSide, KeyToNumValuesType) || + storeName == getStateStoreName(RightSide, KeyToNumValuesType)) { + KeyToNumValuesType + } else if (storeName == getStateStoreName(LeftSide, KeyWithIndexToValueType) || + storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) { + KeyWithIndexToValueType + } else { + throw new IllegalArgumentException(s"Unknown join store name: $storeName") + } + } + + /** + * Returns the partition key extractor for the given join store and column family name. + */ + def createPartitionKeyExtractor( + storeName: String, + colFamilyName: String, + stateKeySchema: StructType, + stateFormatVersion: Int): StatePartitionKeyExtractor = { + assert(stateFormatVersion <= 3, "State format version must be less than or equal to 3") + val name = if (stateFormatVersion == 3) colFamilyName else storeName + if (getStoreType(name) == KeyWithIndexToValueType) { + // For KeyWithIndex, the index is added to the join (i.e. partition) key. + // Drop the last field (index) to get the partition key + new DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema, numLastColsToDrop = 1) + } else { + // State key is the partition key + new NoopStatePartitionKeyExtractor(stateKeySchema) + } + } + /** Helper class for representing data (value, matched). 
*/ case class ValueAndMatchPair(value: UnsafeRow, matched: Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala index 808ac8e6226b..6206e6832618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala @@ -946,7 +946,7 @@ case class StateStoreSaveExec( } } - override def shortName: String = "stateStoreSave" + override def shortName: String = StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && @@ -1074,7 +1074,8 @@ case class SessionWindowStateStoreSaveExec( override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions - override def shortName: String = "sessionWindowStateStoreSaveExec" + override def shortName: String = + StatefulOperatorsUtils.SESSION_WINDOW_STATE_STORE_SAVE_EXEC_OP_NAME private val stateManager = StreamingSessionWindowStateManager.createStateManager( keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) @@ -1395,7 +1396,7 @@ case class StreamingDeduplicateExec( removeKeysOlderThanWatermark(store) } - override def shortName: String = "dedupe" + override def shortName: String = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME override protected def withNewChildInternal(newChild: SparkPlan): StreamingDeduplicateExec = copy(child = newChild) @@ -1416,6 +1417,12 @@ object StreamingDeduplicateExec { UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) } +/** + * For Deduplicate, the state key is the partition key i.e. the dedup key + */ +class StreamingDeduplicateStatePartitionKeyExtractor(stateKeySchema: StructType) + extends NoopStatePartitionKeyExtractor(stateKeySchema) + case class StreamingDeduplicateWithinWatermarkExec( keyExpressions: Seq[Attribute], child: SparkPlan, @@ -1478,7 +1485,7 @@ case class StreamingDeduplicateWithinWatermarkExec( } } - override def shortName: String = "dedupeWithinWatermark" + override def shortName: String = StatefulOperatorsUtils.DEDUPLICATE_WITHIN_WATERMARK_EXEC_OP_NAME override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): @@ -1494,6 +1501,12 @@ case class StreamingDeduplicateWithinWatermarkExec( newChild: SparkPlan): StreamingDeduplicateWithinWatermarkExec = copy(child = newChild) } +/** + * For DeduplicateWithinWatermark, the state key is the partition key i.e. 
the dedup key + */ +class StreamingDedupWithinWatermarkStatePartitionKeyExtractor(stateKeySchema: StructType) + extends NoopStatePartitionKeyExtractor(stateKeySchema) + trait SchemaValidationUtils extends Logging { // Determines whether the operator should be able to evolve their schema @@ -1561,4 +1574,14 @@ object StatefulOperatorsUtils { TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME ) val SYMMETRIC_HASH_JOIN_EXEC_OP_NAME = "symmetricHashJoin" + val STATE_STORE_SAVE_EXEC_OP_NAME = "stateStoreSave" + val DEDUPLICATE_EXEC_OP_NAME = "dedupe" + val DEDUPLICATE_WITHIN_WATERMARK_EXEC_OP_NAME = "dedupeWithinWatermark" + val SESSION_WINDOW_STATE_STORE_SAVE_EXEC_OP_NAME = "sessionWindowStateStoreSaveExec" + val FLAT_MAP_GROUPS_WITH_STATE_EXEC_OP_NAME = "flatMapGroupsWithState" + val FLAT_MAP_GROUPS_IN_PANDAS_WITH_STATE_EXEC_OP_NAME = "applyInPandasWithState" + val FLAT_MAP_GROUPS_OP_NAMES: Seq[String] = Seq( + FLAT_MAP_GROUPS_WITH_STATE_EXEC_OP_NAME, + FLAT_MAP_GROUPS_IN_PANDAS_WITH_STATE_EXEC_OP_NAME + ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala index 7e25960daf33..3b82447fb8a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala @@ -62,10 +62,43 @@ object StateStoreColumnFamilySchemaUtils { case _ => false } - def getTtlColFamilyName(stateName: String): String = { - "$ttl_" + stateName + private def makeCFName(prefix: String, stateName: String): String = prefix + stateName + + private def isCFPrefix(prefix: String, colFamilyName: String): Boolean = + colFamilyName.startsWith(prefix) + + private def getStateName(prefix: String, colFamilyName: String): String = { + require(isCFPrefix(prefix, colFamilyName), s"Column family name must have prefix $prefix") + colFamilyName.substring(prefix.length) } + private val TTL_COL_FAMILY_PREFIX = "$ttl_" + + def getTtlColFamilyName(stateName: String): String = + makeCFName(TTL_COL_FAMILY_PREFIX, stateName) + def isTtlColFamilyName(colFamilyName: String): Boolean = + isCFPrefix(TTL_COL_FAMILY_PREFIX, colFamilyName) + def getStateNameFromTtlColFamily(colFamilyName: String): String = + getStateName(TTL_COL_FAMILY_PREFIX, colFamilyName) + + private val MIN_EXPIRY_INDEX_PREFIX = "$min_" + + def getMinExpiryIndexCFName(stateName: String): String = + makeCFName(MIN_EXPIRY_INDEX_PREFIX, stateName) + def isMinExpiryIndexCFName(colFamilyName: String): Boolean = + isCFPrefix(MIN_EXPIRY_INDEX_PREFIX, colFamilyName) + def getStateNameFromMinExpiryIndexCFName(colFamilyName: String): String = + getStateName(MIN_EXPIRY_INDEX_PREFIX, colFamilyName) + + private val COUNT_INDEX_PREFIX = "$count_" + + def getCountIndexCFName(stateName: String): String = + makeCFName(COUNT_INDEX_PREFIX, stateName) + def isCountIndexCFName(colFamilyName: String): Boolean = + isCFPrefix(COUNT_INDEX_PREFIX, colFamilyName) + def getStateNameFromCountIndexCFName(colFamilyName: String): String = + getStateName(COUNT_INDEX_PREFIX, colFamilyName) + def getValueStateSchema[T]( stateName: String, keyEncoder: ExpressionEncoder[Any], @@ -135,7 +168,7 @@ object StateStoreColumnFamilySchemaUtils { // Min 
expiry index val minIndexSchema = StateStoreColFamilySchema( - s"$$min_$stateName", + getMinExpiryIndexCFName(stateName), keySchemaId = 0, keyEncoder.schema, valueSchemaId = 0, @@ -145,7 +178,7 @@ object StateStoreColumnFamilySchemaUtils { // Count index val countSchema = StateStoreColFamilySchema( - s"$$count_$stateName", + getCountIndexCFName(stateName), keySchemaId = 0, keyEncoder.schema, valueSchemaId = 0, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateTypesEncoderUtils.scala index 8ce300a40b43..d147ad66c246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateTypesEncoderUtils.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.ttl.StateTTL -import org.apache.spark.sql.execution.streaming.state.StateStoreErrors +import org.apache.spark.sql.execution.streaming.state.{IndexBasedStatePartitionKeyExtractor, NoopStatePartitionKeyExtractor, OfflineStateRepartitionErrors, StatePartitionKeyExtractor, StateStore, StateStoreErrors, StateStoreId} import org.apache.spark.sql.types._ /** @@ -311,3 +312,127 @@ class TimerKeyEncoder(keyExprEnc: ExpressionEncoder[Any]) { keyDeserializer.apply(retUnsafeRow) } } + +/** + * For MapState main CF, the state key is the composite key (grouping key + user key) i.e. + * StructType("key": StructType, "userKey": StructType). The partition key + * is the grouping key i.e. first field. + */ +class MapStatePartitionKeyExtractor(stateKeySchema: StructType) + extends IndexBasedStatePartitionKeyExtractor(stateKeySchema, partitionKeyIndex = 0) + +/** + * TTL main CF have state keys with schema (expirationMs, elementKey). The partition key + * is the elementKey part. This is used by Value and List TTL main CF. + */ +class TTLStatePartitionKeyExtractor(stateKeySchema: StructType) + extends IndexBasedStatePartitionKeyExtractor(stateKeySchema, partitionKeyIndex = 1) + +/** + * For MapTTL CF, TTL keys have schema (expirationMs, elementKey), + * but for map, the elementKey is the composite key (grouping key, user key). + * Hence we need to extract the composite key from TTL key, + * then extract the grouping key from the composite key. 
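To make the two-step extraction described above concrete, here is a small schema-level sketch. It is an editorial aside rather than part of the patch, and the field names (groupingKey, userKey, expirationMs, elementKey) are illustrative, not taken from the implementation.

  import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.MapTTLStatePartitionKeyExtractor
  import org.apache.spark.sql.types.{LongType, StringType, StructType}

  // Hypothetical MapState composite key: (grouping key, user key).
  val compositeKeySchema = new StructType()
    .add("key", new StructType().add("groupingKey", StringType))
    .add("userKey", new StructType().add("userKey", StringType))

  // Hypothetical MapState TTL key: (expirationMs, elementKey = composite key).
  val mapTtlKeySchema = new StructType()
    .add("expirationMs", LongType)
    .add("elementKey", compositeKeySchema)

  // Step 1 (TTLStatePartitionKeyExtractor): field 1 of the TTL key -> composite key.
  // Step 2 (MapStatePartitionKeyExtractor): field 0 of the composite key -> grouping key.
  val extractor = new MapTTLStatePartitionKeyExtractor(mapTtlKeySchema)
  assert(extractor.partitionKeySchema == new StructType().add("groupingKey", StringType))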
+ */ +class MapTTLStatePartitionKeyExtractor(stateKeySchema: StructType) + extends StatePartitionKeyExtractor { + // This will extract the compositeKey from the TTL key + private lazy val compositeKeyExtractor = new TTLStatePartitionKeyExtractor(stateKeySchema) + // This will extract the grouping key from the compositeKey + private lazy val partitionKeyExtractor = + new MapStatePartitionKeyExtractor(compositeKeyExtractor.partitionKeySchema) + + override lazy val partitionKeySchema: StructType = partitionKeyExtractor.partitionKeySchema + + override def partitionKey(stateKeyRow: UnsafeRow): UnsafeRow = { + partitionKeyExtractor.partitionKey(compositeKeyExtractor.partitionKey(stateKeyRow)) + } +} + +/** + * For extracting partition keys from Timer state keys (both event & processing time) + * Timer state has two key schemas: + * - Primary index CF: (key, expiryTimestampMs) + * - Secondary index CF: (expiryTimestampMs, key) + * The partition key for both is just the key field. + */ +class TimerStatePartitionKeyExtractor( + stateKeySchema: StructType, isSecondaryIndex: Boolean = false) + extends IndexBasedStatePartitionKeyExtractor( + stateKeySchema, partitionKeyIndex = if (isSecondaryIndex) 1 else 0) + +object TransformWithStatePartitionKeyExtractorFactory { + def create( + storeName: String, + colFamilyName: String, + stateKeySchema: StructType, + stateVariableInfo: TransformWithStateVariableInfo): StatePartitionKeyExtractor = { + require(storeName == StateStoreId.DEFAULT_STORE_NAME, "Store name must be default") + require(colFamilyName != StateStore.DEFAULT_COL_FAMILY_NAME, "Use non-default CF") + + if (stateVariableInfo.ttlEnabled) { + createForTTL(colFamilyName, stateKeySchema, stateVariableInfo) + } else { + createForStateVarType(colFamilyName, stateKeySchema, stateVariableInfo) + } + } + + private def createForTTL( + colFamilyName: String, + stateKeySchema: StructType, + stateVariableInfo: TransformWithStateVariableInfo): StatePartitionKeyExtractor = { + // TTL main CF + if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) { + val stateName = StateStoreColumnFamilySchemaUtils + .getStateNameFromTtlColFamily(colFamilyName) + require(stateName == stateVariableInfo.stateName, "State name must match") + + stateVariableInfo.stateVariableType match { + case StateVariableType.MapState => new MapTTLStatePartitionKeyExtractor(stateKeySchema) + case StateVariableType.ListState | StateVariableType.ValueState => + new TTLStatePartitionKeyExtractor(stateKeySchema) + case _ => throw OfflineStateRepartitionErrors.unsupportedTransformWithStateVarTypeError( + checkpointLocation = "", + stateVariableInfo.stateVariableType.toString, + stateVariableInfo.ttlEnabled, + colFamilyName) + } + } else if (StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) { + val stateName = StateStoreColumnFamilySchemaUtils + .getStateNameFromMinExpiryIndexCFName(colFamilyName) + require(stateName == stateVariableInfo.stateName, "State name must match") + + new NoopStatePartitionKeyExtractor(stateKeySchema) + } else if (StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) { + val stateName = StateStoreColumnFamilySchemaUtils + .getStateNameFromCountIndexCFName(colFamilyName) + require(stateName == stateVariableInfo.stateName, "State name must match") + + new NoopStatePartitionKeyExtractor(stateKeySchema) + } else { + // TTL is enabled but this is the main CF for the state variable data + createForStateVarType(colFamilyName, stateKeySchema, stateVariableInfo) + } + } + + 
private def createForStateVarType( + colFamilyName: String, + stateKeySchema: StructType, + stateVariableInfo: TransformWithStateVariableInfo): StatePartitionKeyExtractor = { + stateVariableInfo.stateVariableType match { + case StateVariableType.ListState | StateVariableType.ValueState => + new NoopStatePartitionKeyExtractor(stateKeySchema) + case StateVariableType.MapState => new MapStatePartitionKeyExtractor(stateKeySchema) + case StateVariableType.TimerState => + require(TimerStateUtils.isTimerCFName(colFamilyName), + s"Column family name must be for a timer: $colFamilyName") + new TimerStatePartitionKeyExtractor( + stateKeySchema, TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) + case _ => throw OfflineStateRepartitionErrors.unsupportedTransformWithStateVarTypeError( + checkpointLocation = "", + stateVariableInfo.stateVariableType.toString, + stateVariableInfo.ttlEnabled, + colFamilyName) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala index 068303e25e2e..5b4100e9c256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala @@ -56,7 +56,13 @@ object TransformWithStateVariableUtils { } } - def getRowCounterCFName(stateName: String): String = "$rowCounter_" + stateName + private val ROW_COUNTER_CF_PREFIX = "$rowCounter_" + + def getRowCounterCFName(stateName: String): String = ROW_COUNTER_CF_PREFIX + stateName + + def isRowCounterCFName(colFamilyName: String): Boolean = { + colFamilyName.startsWith(ROW_COUNTER_CF_PREFIX) + } } // Enum of possible State Variable types diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala index 647c0b3036a2..75624c9af9ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala @@ -54,6 +54,16 @@ object TimerStateUtils { buildTimerStateNames(baseStateName) } + + def isTimerCFName(colFamilyName: String): Boolean = { + colFamilyName.startsWith(PROC_TIMERS_STATE_NAME) || + colFamilyName.startsWith(EVENT_TIMERS_STATE_NAME) + } + + def isTimerSecondaryIndexCF(colFamilyName: String): Boolean = { + assert(isTimerCFName(colFamilyName), s"Column family name must be for a timer: $colFamilyName") + colFamilyName.endsWith(TIMESTAMP_TO_KEY_CF) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index 0e9b8ad8a63b..46df73fdb380 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -91,6 +91,21 @@ object OfflineStateRepartitionErrors { providerClass: String): StateRepartitionInvalidCheckpointError = { new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass) } + + def unsupportedStatefulOperatorError( + checkpointLocation: String, + operatorName: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionUnsupportedStatefulOperatorError(checkpointLocation, operatorName) + } + + def unsupportedTransformWithStateVarTypeError( + checkpointLocation: String, + variableType: String, + ttlEnabled: Boolean, + colFamilyName: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionUnsupportedTransformWithStateVarTypeError( + checkpointLocation, variableType, ttlEnabled, colFamilyName) + } } /** @@ -215,3 +230,24 @@ class StateRepartitionUnsupportedProviderError( checkpointLocation, subClass = "UNSUPPORTED_PROVIDER", messageParameters = Map("provider" -> provider)) + +class StateRepartitionUnsupportedStatefulOperatorError( + checkpointLocation: String, + operatorName: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "UNSUPPORTED_STATEFUL_OPERATOR", + messageParameters = Map("operatorName" -> operatorName)) + +class StateRepartitionUnsupportedTransformWithStateVarTypeError( + checkpointLocation: String, + variableType: String, + ttlEnabled: Boolean, + colFamilyName: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "UNSUPPORTED_TRANSFORM_WITH_STATE_VARIABLE_TYPE", + messageParameters = Map( + "variableType" -> variableType, + "ttlEnabled" -> ttlEnabled.toString, + "colFamilyName" -> colFamilyName)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index b4410362c4d3..b6d33ad9f57f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -65,7 +65,7 @@ trait StateSchemaProvider extends Serializable { // Test implementation that can be dynamically updated class TestStateSchemaProvider extends StateSchemaProvider { - private val schemas = mutable.Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue] + private[state] val schemas = mutable.Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue] /** * Captures a new schema pair (key schema and value schema) for a column family. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionKeyExtractor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionKeyExtractor.scala new file mode 100644 index 000000000000..cff20c511e97 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionKeyExtractor.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.collection.immutable.ArraySeq + +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.types.StructType + +/** + * Trait for extracting partition keys from state key rows. + * The partition key is the key used by shuffle. + * This is used for offline state repartitioning. + */ +trait StatePartitionKeyExtractor { + /** + * Returns the schema of the partition key. + */ + def partitionKeySchema: StructType + + /** + * Extracts the partition key row from the given state key row. + * + * @note Depending on the implementation, it might not be safe to buffer the + * returned UnsafeRow across multiple calls of this method, due to UnsafeRow re-use. + * If you are holding on to the row between multiple calls, you should copy the row. + */ + def partitionKey(stateKeyRow: UnsafeRow): UnsafeRow +} + +/** + * No-op state partition key extractor that returns the state key row as the partition key row. + * This is used by operators that use the partition key as the state key. + * + * @param stateKeySchema The schema of the state key row + */ +class NoopStatePartitionKeyExtractor(stateKeySchema: StructType) + extends StatePartitionKeyExtractor { + override lazy val partitionKeySchema: StructType = stateKeySchema + + override def partitionKey(stateKeyRow: UnsafeRow): UnsafeRow = stateKeyRow +} + +/** + * State partition key extractor that returns the field at the specified index + * of the state key row as the partition key row. + * + * @param stateKeySchema The schema of the state key row + * @param partitionKeyIndex The index of the field to extract as the partition key + */ +class IndexBasedStatePartitionKeyExtractor(stateKeySchema: StructType, partitionKeyIndex: Int) + extends StatePartitionKeyExtractor { + override lazy val partitionKeySchema: StructType = + stateKeySchema.fields(partitionKeyIndex).dataType.asInstanceOf[StructType] + + override def partitionKey(stateKeyRow: UnsafeRow): UnsafeRow = { + stateKeyRow.getStruct(partitionKeyIndex, partitionKeySchema.length) + } +} + +/** + * State partition key extractor that drops the last N fields of the state key row + * and returns the remaining fields as the partition key row. 
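A quick sketch of the drop-last-N behaviour described above, using a made-up session-window style key; the constructor parameters themselves are documented in the @param lines that follow. This is an editorial illustration under assumed field names, not part of the patch.

  import org.apache.spark.sql.execution.streaming.state.DropLastNFieldsStatePartitionKeyExtractor
  import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

  // Hypothetical session-window state key: grouping fields plus sessionStartTime.
  val stateKeySchema = new StructType()
    .add("userId", StringType)
    .add("sessionId", IntegerType)
    .add("sessionStartTime", LongType)

  val extractor =
    new DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema, numLastColsToDrop = 1)

  // The partition key keeps everything except the trailing sessionStartTime field.
  assert(extractor.partitionKeySchema ==
    new StructType().add("userId", StringType).add("sessionId", IntegerType))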
+ * + * @param stateKeySchema The schema of the state key row + * @param numLastColsToDrop The number of last columns to drop in the state key + */ +class DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema: StructType, numLastColsToDrop: Int) + extends StatePartitionKeyExtractor { + override lazy val partitionKeySchema: StructType = { + require(numLastColsToDrop < stateKeySchema.length, + s"numLastColsToDrop: $numLastColsToDrop must be less than the number of fields in the " + + s"state key schema: ${stateKeySchema.length}, to avoid empty partition key schema") + StructType(stateKeySchema.dropRight(numLastColsToDrop)) + } + + private lazy val partitionKeyExpr: Array[BoundReference] = + partitionKeySchema.fields.zipWithIndex.map { case (field, index) => + BoundReference(index, field.dataType, field.nullable) + } + + private lazy val partitionKeyProjection: UnsafeProjection = UnsafeProjection.create( + ArraySeq.unsafeWrapArray(partitionKeyExpr), + toAttributes(stateKeySchema)) + + override def partitionKey(stateKeyRow: UnsafeRow): UnsafeRow = { + partitionKeyProjection(stateKeyRow) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index ee5374c02435..1a1bed47552c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -23,6 +23,7 @@ import java.util.UUID import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.TransformWithStateVariableUtils import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.ttl.ListStateImplWithTTL import org.apache.spark.sql.streaming.{ListState, TimeMode, TTLConfig, ValueState} @@ -424,6 +425,31 @@ class ListStateSuite extends StateVariableSuiteBase { assert(ttlStateValue.isDefined) } } + + test("Partition key extraction - ListState without TTL") { + testListStatePartitionKeyExtraction(ttlEnabled = false) + } + + test("Partition key extraction - ListState with TTL") { + testListStatePartitionKeyExtraction(ttlEnabled = true) + } + + private def testListStatePartitionKeyExtraction(ttlEnabled: Boolean): Unit = { + testPartitionKeyExtraction( + addStateFunc = { (handle, ttlConfig, _) => + val testState: ListState[Long] = handle.getListState[Long]("testState", ttlConfig) + ImplicitGroupingKeyTracker.setImplicitKey("key1") + testState.appendValue(100L) + testState.appendValue(101L) + ImplicitGroupingKeyTracker.setImplicitKey("key2") + testState.appendValue(200L) + }, + stateVariableInfo = TransformWithStateVariableUtils.getListState("testState", ttlEnabled), + ttlEnabled = ttlEnabled, + expectedNumColFamilies = if (ttlEnabled) 4 else 2, + groupingKeyToExpectedCount = Map("key1" -> 1, "key2" -> 1) + ) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index dbbd0ce8388a..154344f05311 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,6 +22,7 @@ import java.util.UUID import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.TransformWithStateVariableUtils import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.ttl.MapStateImplWithTTL import org.apache.spark.sql.streaming.{ListState, MapState, TimeMode, TTLConfig, ValueState} @@ -289,6 +290,32 @@ class MapStateSuite extends StateVariableSuiteBase { assert(ttlStateValueIterator.hasNext) } } + + test("Partition key extraction - MapState without TTL") { + testMapStatePartitionKeyExtraction(ttlEnabled = false) + } + + test("Partition key extraction - MapState with TTL") { + testMapStatePartitionKeyExtraction(ttlEnabled = true) + } + + private def testMapStatePartitionKeyExtraction(ttlEnabled: Boolean): Unit = { + testPartitionKeyExtraction( + addStateFunc = { (handle, ttlConfig, _) => + val testState: MapState[String, Long] = + handle.getMapState[String, Long]("testState", ttlConfig) + ImplicitGroupingKeyTracker.setImplicitKey("key1") + testState.updateValue("userKey1", 100L) + testState.updateValue("userKey2", 101L) + ImplicitGroupingKeyTracker.setImplicitKey("key2") + testState.updateValue("userKey3", 200L) + }, + stateVariableInfo = TransformWithStateVariableUtils.getMapState("testState", ttlEnabled), + ttlEnabled = ttlEnabled, + expectedNumColFamilies = if (ttlEnabled) 2 else 1, + groupingKeyToExpectedCount = Map("key1" -> 2, "key2" -> 1) + ) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index 30ae505477ef..8e30f7760781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes -import org.apache.spark.sql.execution.streaming.operators.stateful.StreamingAggregationStateManager +import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory, StreamingAggregationStateManager} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -125,4 +125,34 @@ class StreamingAggregationStateManagerSuite extends StreamTest { // state manager should return row which is same as input row regardless of format version assert(inputRow === stateManager.get(memoryStateStore, keyRow)) } + + // ============================ Partition Key Extraction Tests ============================ + + Seq(1, 2).foreach { version => + test(s"Partition key 
extraction - StateManager v$version") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, stateFormatVersion = version) + + val keySchema = testKeyAttributes.toStructType + + // Create extractor for aggregation operation + val extractor = StatePartitionKeyExtractorFactory.create( + StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME, + keySchema + ) + assert(extractor.partitionKeySchema === keySchema, + "Partition key schema should match the aggregation key schema") + + // Write input aggregation row via state manager + val memoryStateStore = new MemoryStateStore() + stateManager.put(memoryStateStore, testRow) + assert(stateManager.getKey(testRow) === expectedTestKeyRow) + + // Verify the state key and partition key by reading via store + val pair = memoryStateStore.iterator().next() + assert(pair.key === expectedTestKeyRow) + assert(extractor.partitionKey(pair.key) === pair.key, + "Partition key should be the same as the state key") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala index f3b76a8df047..849c2ba56f79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes -import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StreamingSessionWindowStateManager} +import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOperatorsUtils, StatePartitionKeyExtractorFactory, StreamingSessionWindowStateManager} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -74,6 +74,11 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA s"provider ${providerOpt._2} state version v${version} - CRUD operations") { testAllOperations(version) } + + test("Partition key extraction - StreamingSessionWindowStateManager - " + + s"provider ${providerOpt._2} state version v$version") { + testPartitionKeyExtraction(version) + } } } @@ -168,6 +173,41 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA } } + private def testPartitionKeyExtraction(stateFormatVersion: Int): Unit = { + withStateManager(stateFormatVersion) { case (stateManager, store) => + // Test data + val testRow1 = createRow("a", 1, 100, 150, 1) + val testRow2 = createRow("a", 1, 200, 250, 2) + val testRow3 = createRow("b", 2, 100, 150, 3) + val expectedKeyRow1 = createKeyRow("a", 1) + val expectedKeyRow2 = createKeyRow("b", 2) + + // Create extractor for session window operation + val extractor = StatePartitionKeyExtractorFactory.create( + StatefulOperatorsUtils.SESSION_WINDOW_STATE_STORE_SAVE_EXEC_OP_NAME, + stateManager.getStateKeySchema + ) + + // Verify partition key schema excludes sessionStartTime + assert(extractor.partitionKeySchema === 
keysWithoutSessionAttributes.toStructType, + "Partition key schema should exclude sessionStartTime") + + // Update sessions and verify partition key extraction + stateManager.updateSessions(store, expectedKeyRow1, Seq(testRow1, testRow2)) + stateManager.updateSessions(store, expectedKeyRow2, Seq(testRow3)) + + // Verify partition keys for stored state keys + val stateKeys = store.iterator().map(_.key).toList + assert(stateKeys.length === 3, "Should have 3 state keys stored") + + val partitionKeys = stateKeys.map(extractor.partitionKey(_).copy()) + assert(partitionKeys.count(_ === expectedKeyRow1) === 2, + "Should have 2 partition keys matching (a, 1)") + assert(partitionKeys.count(_ === expectedKeyRow2) === 1, + "Should have 1 partition key matching (b, 2)") + } + } + private def withStateManager( stateFormatVersion: Int)( f: (StreamingSessionWindowStateManager, StateStore) => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index 44364626c20d..2e46202b3b0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -22,6 +22,7 @@ import java.util.UUID import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} @@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOperatorsUtils, StatePartitionKeyExtractorFactory} import org.apache.spark.sql.execution.streaming.operators.stateful.join.{JoinStateManagerStoreGenerator, SymmetricHashJoinStateManager} import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.LeftSide import org.apache.spark.sql.internal.SQLConf @@ -37,7 +38,8 @@ import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter { +class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter + with PrivateMethodTester { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -84,6 +86,12 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } } + SymmetricHashJoinStateManager.supportedVersions.foreach { version => + test(s"Partition key extraction - SymmetricHashJoinStateManager v$version") { + testPartitionKeyExtraction(version) + } + } + private def testAllOperations(stateFormatVersion: Int): Unit = { withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager => implicit val mgr = manager @@ -350,4 
+358,87 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } StateStore.stop() } + + private def testPartitionKeyExtraction(stateFormatVersion: Int): Unit = { + withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager => + implicit val mgr = manager + + val joinKeySchema = StructType( + joinKeyExprs.zipWithIndex.map { case (expr, i) => + StructField(s"field$i", expr.dataType, expr.nullable) + }) + + // Add some test data + append(key = 20, value = 100) + append(key = 20, value = 200) + append(key = 30, value = 150) + + Seq( + (getKeyToNumValuesStoreAndKeySchema(), SymmetricHashJoinStateManager + .getStateStoreName(LeftSide, SymmetricHashJoinStateManager.KeyToNumValuesType), + // expect 1 for both key 20 & 30 + 1, 1), + (getKeyWithIndexToValueStoreAndKeySchema(), SymmetricHashJoinStateManager + .getStateStoreName(LeftSide, SymmetricHashJoinStateManager.KeyWithIndexToValueType), + // expect 2 for key 20 & 1 for key 30 + 2, 1) + ).foreach { case ((store, keySchema), name, expectedNumKey20, expectedNumKey30) => + val storeName = if (stateFormatVersion == 3) { + StateStoreId.DEFAULT_STORE_NAME + } else { + name + } + + val colFamilyName = if (stateFormatVersion == 3) { + name + } else { + StateStore.DEFAULT_COL_FAMILY_NAME + } + + val extractor = StatePartitionKeyExtractorFactory.create( + StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME, + keySchema, + storeName, + colFamilyName, + stateFormatVersion = Some(stateFormatVersion) + ) + + assert(extractor.partitionKeySchema === joinKeySchema, + "Partition key schema should match the join key schema") + + // Copy both the state key and partition key to avoid UnsafeRow reuse issues + val stateKeys = store.iterator(colFamilyName).map(_.key.copy()).toList + val partitionKeys = stateKeys.map(extractor.partitionKey(_).copy()) + + assert(partitionKeys.length === expectedNumKey20 + expectedNumKey30, + "Should have same num partition keys as num state store keys") + assert(partitionKeys.count(_ === toJoinKeyRow(20)) === expectedNumKey20, + "Should have the expected num partition keys for join key 20") + assert(partitionKeys.count(_ === toJoinKeyRow(30)) === expectedNumKey30, + "Should have the expected num partition keys for join key 30") + } + } + } + + def getKeyToNumValuesStoreAndKeySchema() + (implicit manager: SymmetricHashJoinStateManager): (StateStore, StructType) = { + val keyToNumValuesHandler = manager.keyToNumValues + val keyToNumValuesStoreMethod = PrivateMethod[StateStore](Symbol("stateStore")) + val keyToNumValuesStore = keyToNumValuesHandler.invokePrivate(keyToNumValuesStoreMethod()) + + (keyToNumValuesStore, manager.keySchema) + } + + def getKeyWithIndexToValueStoreAndKeySchema() + (implicit manager: SymmetricHashJoinStateManager): (StateStore, StructType) = { + val keyWithIndexToValueHandler = manager.keyWithIndexToValue + + val keyWithIndexToValueStoreMethod = PrivateMethod[StateStore](Symbol("stateStore")) + val keyWithIndexToValueStore = + keyWithIndexToValueHandler.invokePrivate(keyWithIndexToValueStoreMethod()) + + val keySchemaMethod = PrivateMethod[StructType](Symbol("keyWithIndexSchema")) + val keyWithIndexToValueKeySchema = keyWithIndexToValueHandler.invokePrivate(keySchemaMethod()) + (keyWithIndexToValueStore, keyWithIndexToValueKeySchema) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala index b8ad09cb0d95..2b482f708a72 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.TransformWithStateVariableUtils import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateImpl import org.apache.spark.sql.streaming.TimeMode @@ -186,6 +187,26 @@ class TimerSuite extends StateVariableSuiteBase { assert(timerState.listTimers().toSet === Set(1000L)) } } + + testWithTimeMode("Partition key extraction - TimerState") { timeMode => + testPartitionKeyExtraction( + addStateFunc = { (_, _, store) => + ImplicitGroupingKeyTracker.setImplicitKey("key1") + val timerState1 = new TimerStateImpl(store, timeMode, stringEncoder) + timerState1.registerTimer(1000L) + timerState1.registerTimer(2000L) + + ImplicitGroupingKeyTracker.setImplicitKey("key2") + val timerState2 = new TimerStateImpl(store, timeMode, stringEncoder) + timerState2.registerTimer(1500L) + }, + stateVariableInfo = TransformWithStateVariableUtils.getTimerState("testState"), + ttlEnabled = false, + expectedNumColFamilies = 2, + groupingKeyToExpectedCount = Map("key1" -> 2, "key2" -> 1), + timeMode = timeMode + ) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 328bc38bf7d5..456e9a7c7337 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -29,6 +29,9 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException, Tas import org.apache.spark.TaskContext.withTaskContext import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils +import org.apache.spark.sql.execution.streaming.operators.stateful.StatePartitionKeyExtractorFactory +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{TransformWithStateVariableInfo, TransformWithStateVariableUtils} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.ttl.ValueStateImplWithTTL import org.apache.spark.sql.execution.streaming.runtime.StreamExecution @@ -421,6 +424,30 @@ class ValueStateSuite extends StateVariableSuiteBase { assert(ttlStateValueIterator.isDefined) } } + + test("Partition key extraction - ValueState without TTL") { + testValueStatePartitionKeyExtraction(ttlEnabled = false) + } + + test("Partition key extraction - ValueState with TTL") { + testValueStatePartitionKeyExtraction(ttlEnabled = true) + } + + private def testValueStatePartitionKeyExtraction(ttlEnabled: Boolean): Unit = { + testPartitionKeyExtraction( + addStateFunc = { 
(handle, ttlConfig, _) => + val testState: ValueState[Long] = handle.getValueState[Long]("testState", ttlConfig) + ImplicitGroupingKeyTracker.setImplicitKey("key1") + testState.update(100L) + ImplicitGroupingKeyTracker.setImplicitKey("key2") + testState.update(200L) + }, + stateVariableInfo = TransformWithStateVariableUtils.getValueState("testState", ttlEnabled), + ttlEnabled = ttlEnabled, + expectedNumColFamilies = if (ttlEnabled) 2 else 1, + groupingKeyToExpectedCount = Map("key1" -> 1, "key2" -> 1) + ) + } } /** @@ -454,10 +481,17 @@ abstract class StateVariableSuiteBase extends SharedSparkSession protected def useMultipleValuesPerKey = false protected def newStoreProviderWithStateVariable( - useColumnFamilies: Boolean): RocksDBStateStoreProvider = { + useColumnFamilies: Boolean, + schemaProvider: Option[StateSchemaProvider]): RocksDBStateStoreProvider = { newStoreProviderWithStateVariable(StateStoreId(newDir(), Random.nextInt(), 0), NoPrefixKeyStateEncoderSpec(schemaForKeyRow), - useColumnFamilies = useColumnFamilies) + useColumnFamilies = useColumnFamilies, + schemaProvider = schemaProvider) + } + + protected def newStoreProviderWithStateVariable( + useColumnFamilies: Boolean): RocksDBStateStoreProvider = { + newStoreProviderWithStateVariable(useColumnFamilies, schemaProvider = None) } protected def newStoreProviderWithStateVariable( @@ -465,14 +499,15 @@ abstract class StateVariableSuiteBase extends SharedSparkSession keyStateEncoderSpec: KeyStateEncoderSpec, sqlConf: SQLConf = SQLConf.get, conf: Configuration = new Configuration, - useColumnFamilies: Boolean = false): RocksDBStateStoreProvider = { + useColumnFamilies: Boolean = false, + schemaProvider: Option[StateSchemaProvider] = None): RocksDBStateStoreProvider = { val provider = new RocksDBStateStoreProvider() conf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString) provider.init( storeId, schemaForKeyRow, schemaForValueRow, keyStateEncoderSpec, useColumnFamilies, new StateStoreConf(sqlConf), conf, useMultipleValuesPerKey, - Some(new TestStateSchemaProvider)) + schemaProvider.orElse(Some(new TestStateSchemaProvider))) provider } @@ -491,6 +526,70 @@ abstract class StateVariableSuiteBase extends SharedSparkSession provider.close() } } + + protected def testPartitionKeyExtraction( + addStateFunc: (StatefulProcessorHandleImpl, TTLConfig, StateStore) => Unit, + stateVariableInfo: TransformWithStateVariableInfo, + ttlEnabled: Boolean, + expectedNumColFamilies: Int, + groupingKeyToExpectedCount: Map[String, Int], + timeMode: TimeMode = TimeMode.ProcessingTime()): Unit = { + val schemaProvider = new TestStateSchemaProvider + tryWithProviderResource( + newStoreProviderWithStateVariable(true, Some(schemaProvider))) { provider => + val store = provider.getStore(0) + val timestampMs = 10 + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + stringEncoder, timeMode, batchTimestampMs = Some(timestampMs)) + + val ttlConfig = if (ttlEnabled) { + TTLConfig(ttlDuration = Duration.ofMinutes(1)) + } else { + TTLConfig.NONE + } + + // call the passed in func to create the state variable and add state + addStateFunc(handle, ttlConfig, store) + + // Get all the column families and their key schemas + val colFamilyNameAndKeySchema = schemaProvider.schemas.filter(_._1.isKey) + .map(kv => (kv._1.colFamilyName, kv._2.sqlSchema)) + // don't include default CF + .filterNot(_._1 == StateStore.DEFAULT_COL_FAMILY_NAME) + + assert(colFamilyNameAndKeySchema.size === expectedNumColFamilies, + s"Should have $expectedNumColFamilies 
column families, " + + s"found ${colFamilyNameAndKeySchema.size}") + + // Verify partition key extraction for each column family + val expectedStateKeyCount = groupingKeyToExpectedCount.values.sum + colFamilyNameAndKeySchema.foreach { case (colFamilyName, keySchema) => + val extractor = StatePartitionKeyExtractorFactory.create( + StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME, + keySchema, + storeName = StateStoreId.DEFAULT_STORE_NAME, + colFamilyName = colFamilyName, + stateVariableInfo = Some(stateVariableInfo) + ) + + assert(extractor.partitionKeySchema === stringEncoder.schema, + "Partition key schema should match the grouping key schema") + + // Get all state keys and extract partition keys + val stateKeys = store.iterator(colFamilyName).map(_.key.copy()).toList + assert(stateKeys.length == expectedStateKeyCount, + s"Should have $expectedStateKeyCount state keys, found ${stateKeys.length}") + + val partitionKeys = stateKeys.map(extractor.partitionKey(_).copy()) + + groupingKeyToExpectedCount.foreach { case (keyStr, expectedCount) => + val keyRow = stringEncoder.createSerializer().apply(keyStr).copy() + assert(partitionKeys.count(_ === keyRow) == expectedCount, + s"Should have $expectedCount partition keys for $keyStr") + } + } + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 93efbe3b3cf5..5269edc68221 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.streaming import java.io.File import java.nio.ByteOrder import java.sql.Timestamp +import java.util.UUID +import org.apache.hadoop.conf.Configuration import org.scalatest.exceptions.TestFailedException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -32,14 +34,17 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils +import org.apache.spark.sql.execution.streaming.operators.stateful.StatePartitionKeyExtractorFactory import org.apache.spark.sql.execution.streaming.operators.stateful.flatmapgroupswithstate.{FlatMapGroupsWithStateExec, FlatMapGroupsWithStateExecHelper, FlatMapGroupsWithStateUserFuncException} import org.apache.spark.sql.execution.streaming.runtime._ -import org.apache.spark.sql.execution.streaming.state.{MemoryStateStore, RocksDBStateStoreProvider, StateStore} +import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** Class to check custom state types */ @@ -1215,6 +1220,142 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } } } + + testWithAllStateVersions("Partition key extraction - 
FlatMapGroupsWithState without timeout") { + testPartitionKeyExtraction(timeoutEnabled = false) + } + + testWithAllStateVersions("Partition key extraction - FlatMapGroupsWithState with timeout") { + testPartitionKeyExtraction(timeoutEnabled = true) + } + + private def testPartitionKeyExtraction(timeoutEnabled: Boolean): Unit = { + withTempDir { checkpointDir => + // 1 partition to make verification easier + val conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1") + + val timeoutConf = if (timeoutEnabled) ProcessingTimeTimeout else NoTimeout + + // Function to maintain running count + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + if (timeoutEnabled) { + state.setTimeoutDuration("3 seconds") + } + Iterator((key, count.toString)) + } + + val inputStream = MemoryStream[String] + val result = inputStream.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, timeoutConf)(stateFunc) + + val inputData = Seq("a", "b", "c") + + // Run streaming query to populate state + // ProcessingTimeTimeout requires a manual clock and trigger to work properly + if (timeoutEnabled) { + val clock = new StreamManualClock + testStream(result, Update)( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = conf), + AddData(inputStream, inputData: _*), + AdvanceManualClock(1 * 1000), + // CheckNewAnswer waits for the batch to complete and commit state + CheckNewAnswer(("a", "1"), ("b", "1"), ("c", "1")), + StopStream + ) + } else { + testStream(result, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, additionalConfs = conf), + AddData(inputStream, inputData: _*), + ProcessAllAvailable(), + StopStream + ) + } + + // Now access the state store to verify partition key extraction + val storeConf = new StateStoreConf(spark.sessionState.conf) + val storeId = StateStoreId(checkpointDir.getAbsolutePath + "/state", 0, 0) + val storeProviderId = StateStoreProviderId(storeId, UUID.randomUUID()) + + // The key schema for flatMapGroupsWithState is the grouping key (String) + val keySchema = new StructType().add("value", StringType) + + // Value schema differs between state format versions + // V1: flat structure with state fields + timestamp (IntegerType) + // V2: nested struct for state + timestamp (LongType) + val stateFormatVersion = sqlConf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val valueSchema = stateFormatVersion match { + case 1 => + // V1: UnsafeRow[ col1 | col2 | ... 
| timestamp (IntegerType) ] + var schema = new StructType().add("count", LongType) + if (timeoutEnabled) schema = schema.add("timeoutTimestamp", IntegerType) + schema + case 2 => + // V2: UnsafeRow[ groupState (nested struct) | timestamp (LongType) ] + var schema = new StructType() + .add("groupState", new StructType().add("count", LongType), nullable = true) + if (timeoutEnabled) schema = schema.add("timeoutTimestamp", LongType) + schema + case _ => + throw new IllegalArgumentException(s"Unknown state format version: $stateFormatVersion") + } + + val keyProjection = UnsafeProjection.create(keySchema) + def createExpectedKeyRow(key: String) = { + val row = new GenericInternalRow(Array[Any](UTF8String.fromString(key))) + keyProjection.apply(row).copy() + } + + val store = StateStore.getReadOnly( + storeProviderId, + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + version = 1, + stateStoreCkptId = None, + stateSchemaBroadcast = None, + useColumnFamilies = false, + storeConf, + new Configuration + ) + + try { + val extractor = StatePartitionKeyExtractorFactory.create( + StatefulOperatorsUtils.FLAT_MAP_GROUPS_WITH_STATE_EXEC_OP_NAME, + keySchema + ) + + // Verify partition key schema matches the key schema + assert(extractor.partitionKeySchema === keySchema, + "Partition key schema should match the grouping key schema") + + // Get all state keys written by the query + val stateKeys = store.iterator().map(_.key.copy()).toList + assert(stateKeys.length === inputData.length, + s"Should have ${inputData.length} unique keys, found ${stateKeys.length}") + + // Extract partition keys + val partitionKeys = stateKeys.map(extractor.partitionKey(_).copy()) + // Verify each partition key equals its corresponding state key + assert(partitionKeys === stateKeys, + "Partition keys should match state keys") + + // Expected keys + inputData.foreach { key => + val keyRow = createExpectedKeyRow(key) + assert(partitionKeys.count(_ === keyRow) == 1, s"Should have 1 partition key for $key") + } + } finally { + store.abort() + } + } + } } object FlatMapGroupsWithStateSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 832b22d6304f..dbc2b767b0f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -18,19 +18,28 @@ package org.apache.spark.sql.streaming import java.io.File +import java.util.UUID + +import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils +import org.apache.spark.sql.execution.streaming.operators.stateful.StatePartitionKeyExtractorFactory import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.tags.SlowSQLTest +import 
org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @SlowSQLTest -class StreamingDeduplicationSuite extends StateStoreMetricsTest { +class StreamingDeduplicationSuite extends StateStoreMetricsTest + with StreamingDeduplicationSuiteBase { import testImplicits._ @@ -587,6 +596,102 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest { } assert(ex.getMessage.contains("State store encoding format as avro is not supported")) } + + test("Partition key extraction - Dedupe") { + val df = (input: Dataset[(String, Int)]) => input.dropDuplicates("_1").toDF() + + testPartitionKeyExtraction( + StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME, + inputData = Seq(("a", 1), ("b", 2), ("c", 3)), + dedupeDF = df, + // The key schema for dedup is just the _1 column (String) + keySchema = new StructType().add("_1", StringType), + // Empty value schema for dedup + valueSchema = new StructType(), + sqlConf = spark.sessionState.conf + ) + } +} + +trait StreamingDeduplicationSuiteBase { self: StreamTest => + import testImplicits._ + + protected def testPartitionKeyExtraction( + operatorName: String, + inputData: Seq[(String, Int)], + dedupeDF: Dataset[(String, Int)] => DataFrame, + keySchema: StructType, + valueSchema: StructType, + sqlConf: SQLConf): Unit = { + withTempDir { checkpointDir => + // 1 partition to make verification easier + val conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1") + + val inputStream = MemoryStream[(String, Int)] + + // Run streaming query to populate state + testStream(dedupeDF(inputStream.toDS()), Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, additionalConfs = conf), + AddData(inputStream, inputData: _*), + ProcessAllAvailable(), + StopStream + ) + + // Now access the state store to verify partition key extraction + val storeConf = new StateStoreConf(sqlConf) + val storeId = StateStoreId(checkpointDir.getAbsolutePath + "/state", 0, 0) + val storeProviderId = StateStoreProviderId(storeId, UUID.randomUUID()) + + val keyProjection = UnsafeProjection.create(keySchema) + def createExpectedKeyRow(inputKey: String) = { + val row = new GenericInternalRow(Array[Any](UTF8String.fromString(inputKey))) + keyProjection.apply(row).copy() + } + + val store = StateStore.getReadOnly( + storeProviderId, + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + version = 1, + stateStoreCkptId = None, + stateSchemaBroadcast = None, + useColumnFamilies = false, + storeConf, + new Configuration + ) + + try { + val extractor = StatePartitionKeyExtractorFactory.create( + operatorName, + keySchema + ) + + // Verify partition key schema matches the key schema + assert(extractor.partitionKeySchema === keySchema, + "Partition key schema should match the dedup key schema") + + // Get all state keys written by the dedup query + val stateKeys = store.iterator().map(_.key.copy()).toList + assert(stateKeys.length === inputData.length, + s"Should have ${inputData.length} unique keys, found ${stateKeys.length}") + + // Extract partition keys + val partitionKeys = stateKeys.map(extractor.partitionKey(_).copy()) + // Verify each partition key equals its corresponding state key (for dedup) + assert(partitionKeys === stateKeys, + "Partition keys should match state keys") + + // Expected keys + inputData.foreach { case (key, _) => + val keyRow = createExpectedKeyRow(key) + assert(partitionKeys.count(_ === keyRow) == 1, s"Should have 1 partition key for $key") + } + } finally { + store.abort() + } + } + } } @SlowSQLTest diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala index a6223cef32da..9645f82ac241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming import org.apache.spark.sql.{AnalysisException, Dataset, SaveMode} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append +import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.types.{LongType, StringType, StructType} import org.apache.spark.tags.SlowSQLTest @SlowSQLTest -class StreamingDeduplicationWithinWatermarkSuite extends StateStoreMetricsTest { +class StreamingDeduplicationWithinWatermarkSuite extends StateStoreMetricsTest + with StreamingDeduplicationSuiteBase { import testImplicits._ @@ -234,4 +237,24 @@ class StreamingDeduplicationWithinWatermarkSuite extends StateStoreMetricsTest { CheckAnswer(2) ) } + + test("Partition key extraction - DedupeWithinWatermark") { + val df = (input: Dataset[(String, Int)]) => { + input.withColumn("eventTime", timestamp_seconds($"_2")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicatesWithinWatermark("_1") + .select($"_1", $"eventTime".cast("long").as[Long]) + } + + testPartitionKeyExtraction( + StatefulOperatorsUtils.DEDUPLICATE_WITHIN_WATERMARK_EXEC_OP_NAME, + inputData = Seq(("a", 17), ("b", 22), ("c", 21)), + dedupeDF = df, + // The key schema for dedup within watermark is just the _1 column (String) + keySchema = new StructType().add("_1", StringType), + // Value schema includes the expiration time + valueSchema = new StructType().add("expiresAtMicros", LongType), + sqlConf = spark.sessionState.conf + ) + } }
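
Note (illustrative, not part of the patch): every suite above drives the same surface introduced by this change, namely StatePartitionKeyExtractorFactory.create(...) followed by partitionKeySchema and partitionKey(stateKeyRow). The sketch below shows how a caller might use that extractor to route existing state keys into new buckets. The StatePartitionKeyRoutingSketch object, the routeStateKeys helper, its numPartitions parameter, and the floorMod-on-hashCode routing are all assumptions made purely for illustration; they are not how the offline repartition job itself assigns partitions.

    // Hypothetical sketch only; not part of this diff. It reuses the
    // two-argument create(...) call shape exercised by the deduplication
    // suites, where store name and column family fall back to their defaults.
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
    import org.apache.spark.sql.types.StructType

    object StatePartitionKeyRoutingSketch {
      // Assign each state key of a deduplication operator to one of
      // numPartitions buckets. The Math.floorMod routing below is an
      // illustration, not the partitioning scheme Spark applies.
      def routeStateKeys(
          stateKeys: Iterator[UnsafeRow],
          keySchema: StructType,
          numPartitions: Int): Iterator[(Int, UnsafeRow)] = {
        val extractor = StatePartitionKeyExtractorFactory.create(
          StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME,
          keySchema)
        stateKeys.map { stateKey =>
          // Copy both rows to avoid UnsafeRow reuse issues, mirroring the tests.
          val partitionKey = extractor.partitionKey(stateKey).copy()
          (Math.floorMod(partitionKey.hashCode(), numPartitions), stateKey.copy())
        }
      }
    }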