@@ -33,7 +33,6 @@ import scala.util.control.NonFatal
 import org.apache.spark._
 import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.internal.{Logging, LogKeys, MessageWithContext}
-import org.apache.spark.internal.LogKeys.TASK_NAME
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.rdd.InputFileBlockHolder
@@ -134,6 +133,15 @@ private[spark] object BasePythonRunner extends Logging {
     } else None
   }
 
+  /**
+   * Creates a task identifier string for logging, following Spark's standard format.
+   * Format: "task <partition>.<attempt> in stage <stageId> (TID <taskAttemptId>)"
+   */
+  private[spark] def taskIdentifier(context: TaskContext): String = {
+    s"task ${context.partitionId()}.${context.attemptNumber()} in stage ${context.stageId()} " +
+      s"(TID ${context.taskAttemptId()})"
+  }
+
   private[spark] def pythonWorkerStatusMessageWithContext(
       handle: Option[ProcessHandle],
       worker: PythonWorker,
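For reference, a sketch of the string this new helper builds, using made-up task values (partition 2, attempt 0, stage 5, TID 42); only the format itself comes from the code above:

  // taskIdentifier(context) for partition 2, attempt 0, stage 5, TID 42
  // yields: "task 2.0 in stage 5 (TID 42)"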
@@ -259,6 +267,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
 
+    // Log task context information at the start of computation
+    logInfo(s"Starting Python task execution (Stage ${context.stageId()}, " +
+      s"Attempt ${context.attemptNumber()}) - ${taskIdentifier(context)}")
+
     // Get the executor cores and pyspark memory, they are passed via the local properties when
     // the user specified them in a ResourceProfile.
     val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY))
@@ -334,7 +346,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     // Return an iterator that read lines from the process's stdout
     val dataIn = new DataInputStream(new BufferedInputStream(
       new ReaderInputStream(worker, writer, handle,
-        faultHandlerEnabled, idleTimeoutSeconds, killOnIdleTimeout),
+        faultHandlerEnabled, idleTimeoutSeconds, killOnIdleTimeout, context),
       bufferSize))
     val stdoutIterator = newReaderIterator(
       dataIn, writer, startTime, env, worker, handle.map(_.pid.toInt), releasedOrClosed, context)
@@ -585,6 +597,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     private var nextObj: OUT = _
     private var eos = false
 
+    // Track records and data size for logging
+    protected var recordsProcessed: Long = 0
+    protected var totalDataReceived: Long = 0
+
     override def hasNext: Boolean = nextObj != null || {
       if (!eos) {
         nextObj = read()
@@ -620,10 +636,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       val init = initTime - bootTime
       val finish = finishTime - initTime
       val total = finishTime - startTime
-      logInfo(log"Times: total = ${MDC(LogKeys.TOTAL_TIME, total)}, " +
-        log"boot = ${MDC(LogKeys.BOOT_TIME, boot)}, " +
-        log"init = ${MDC(LogKeys.INIT_TIME, init)}, " +
-        log"finish = ${MDC(LogKeys.FINISH_TIME, finish)}")
+
+      // Format data size for readability
+      val dataKB = totalDataReceived / 1024.0
+      val dataMB = dataKB / 1024.0
+      val dataStr = if (dataMB >= 1.0) {
+        f"$dataMB%.2f MB"
+      } else {
+        f"$dataKB%.2f KB"
+      }
+
+      logInfo(s"Times: total = $total, boot = $boot, init = $init, finish = $finish - " +
+        s"Records: $recordsProcessed, Data: $dataStr - ${taskIdentifier(context)}")
       metrics.get("pythonBootTime").foreach(_.add(boot))
       metrics.get("pythonInitTime").foreach(_.add(init))
       metrics.get("pythonTotalTime").foreach(_.add(total))
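A quick sanity check of the size formatting added above, using illustrative byte counts that are not taken from the patch:

  // Assuming totalDataReceived = 3 * 1024 * 1024 bytes:
  //   dataKB = 3072.0, dataMB = 3.0  -> dataStr == "3.00 MB"
  // Assuming totalDataReceived = 512 bytes:
  //   dataKB = 0.5, dataMB < 1.0     -> dataStr == "0.50 KB"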
@@ -660,8 +684,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
 
       case e: Exception if writer.exception.isDefined =>
-        logError("Python worker exited unexpectedly (crashed)", e)
-        logError("This may have been caused by a prior exception:", writer.exception.get)
+        logError(s"Python worker exited unexpectedly (crashed) - ${taskIdentifier(context)}", e)
+        logError(s"This may have been caused by a prior exception - ${taskIdentifier(context)}",
+          writer.exception.get)
         throw writer.exception.get
 
       case e: IOException if !faultHandlerEnabled =>
@@ -703,16 +728,14 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         Thread.sleep(taskKillTimeout)
         if (!context.isCompleted()) {
           try {
-            // Mimic the task name used in `Executor` to help the user find out the task to blame.
-            val taskName = s"${context.partitionId()}.${context.attemptNumber()} " +
-              s"in stage ${context.stageId()} (TID ${context.taskAttemptId()})"
-            logWarning(log"Incomplete task ${MDC(TASK_NAME, taskName)} " +
-              log"interrupted: Attempting to kill Python Worker")
+            logWarning(s"Incomplete task interrupted: Attempting to kill Python Worker " +
+              s"(Stage ${context.stageId()}, Attempt ${context.attemptNumber()}) - " +
+              s"${taskIdentifier(context)}")
             env.destroyPythonWorker(
               pythonExec, workerModule, daemonModule, envVars.asScala.toMap, worker)
           } catch {
             case e: Exception =>
-              logError("Exception when trying to kill worker", e)
+              logError(s"Exception when trying to kill worker - ${taskIdentifier(context)}", e)
           }
         }
       }
@@ -736,7 +759,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       handle: Option[ProcessHandle],
       faultHandlerEnabled: Boolean,
       idleTimeoutSeconds: Long,
-      killOnIdleTimeout: Boolean) extends InputStream {
+      killOnIdleTimeout: Boolean,
+      context: TaskContext) extends InputStream {
     private[this] var writerIfbhThreadLocalValue: Object = null
     private[this] val temp = new Array[Byte](1)
     private[this] val bufferStream = new DirectByteBufferOutputStream()
@@ -811,16 +835,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
             pythonWorkerStatusMessageWithContext(handle, worker, hasInput || buffer.hasRemaining))
         } else {
           logWarning(
-            log"Idle timeout reached for Python worker (timeout: " +
-            log"${MDC(LogKeys.PYTHON_WORKER_IDLE_TIMEOUT, idleTimeoutSeconds)} seconds). " +
-            log"No data received from the worker process: " +
-            pythonWorkerStatusMessageWithContext(handle, worker, hasInput || buffer.hasRemaining))
+            s"Idle timeout reached for Python worker (timeout: $idleTimeoutSeconds seconds). " +
+            s"No data received from the worker process - ${taskIdentifier(context)}")
           if (killOnIdleTimeout) {
             handle.foreach { handle =>
               if (handle.isAlive) {
-                logWarning(
-                  log"Terminating Python worker process due to idle timeout (timeout: " +
-                  log"${MDC(LogKeys.PYTHON_WORKER_IDLE_TIMEOUT, idleTimeoutSeconds)} seconds)")
+                logWarning(s"Terminating Python worker process due to idle timeout " +
+                  s"(timeout: $idleTimeoutSeconds seconds) - ${taskIdentifier(context)}")
                 pythonWorkerKilled = handle.destroy()
               }
             }
@@ -1015,7 +1036,10 @@ private[spark] class PythonRunner(
       try {
         stream.readInt() match {
           case length if length >= 0 =>
-            PythonWorkerUtils.readBytes(length, stream)
+            val data = PythonWorkerUtils.readBytes(length, stream)
+            recordsProcessed += 1
+            totalDataReceived += length
+            data
           case SpecialLengths.TIMING_DATA =>
             handleTimingData()
             read()