diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 0f1788288a..19911640ba 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -126,7 +126,7 @@ class DatabricksAgentV2(DatabricksAgent): """ def __init__(self): - super(AsyncAgentBase, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata) + super(DatabricksAgent, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata) def get_header() -> typing.Dict[str, str]: diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 2ed01c2cd7..15e3b48a03 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -145,9 +145,14 @@ def __init__( self._default_applications_path or "local:///usr/local/bin/entrypoint.py" ) + if isinstance(task_config, DatabricksV2): + task_type = "databricks" + else: + task_type = "spark" + super(PysparkFunctionTask, self).__init__( task_config=task_config, - task_type=self._SPARK_TASK_TYPE, + task_type=task_type, task_function=task_function, container_image=container_image, **kwargs,