Skip to content

Commit

Permalink
Remove dependency on 3rd party lib
Browse files Browse the repository at this point in the history
  • Loading branch information
pbernet committed Feb 20, 2024
1 parent 1e477da commit 494387b
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 46 deletions.
2 changes: 0 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ libraryDependencies ++= Seq(

"org.apache.opennlp" % "opennlp-tools" % "2.3.1",

"com.crowdscriber.captions" %% "caption-parser" % "0.1.6",

"org.apache.httpcomponents.client5" % "httpclient5" % "5.3.1",
"org.apache.httpcomponents.core5" % "httpcore5" % "5.2.4",
"commons-io" % "commons-io" % "2.11.0",
Expand Down
91 changes: 91 additions & 0 deletions src/main/scala/tools/SrtParser.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package tools

import org.apache.commons.lang3.StringUtils
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.stream.scaladsl.{Framing, Sink, Source, StreamConverters}
import org.apache.pekko.util.ByteString
import org.slf4j.{Logger, LoggerFactory}

import java.io.FileInputStream
import java.time.format.DateTimeFormatter
import java.time.{Duration, LocalTime}
import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContextExecutor}

/**
  * Streaming parser for SubRip (.srt) subtitle files in UTF-8 encoding.
  *
  * The file is expected to be reasonably well formatted: blocks are separated by
  * at least one empty line and each block consists of an index line, a timing line
  * (`HH:mm:ss,SSS --> HH:mm:ss,SSS`) and zero or more text lines.
  * Both LF and CRLF line endings are accepted, independent of the platform the
  * parser runs on.
  *
  * A parser instance owns its own `ActorSystem`; [[runSync]] terminates it when
  * done, so an instance is meant to be used for a single [[runSync]] call.
  *
  * @param sourceFilePath path of the .srt file to read
  */
class SrtParser(sourceFilePath: String) {
  val logger: Logger = LoggerFactory.getLogger(this.getClass)
  implicit val system: ActorSystem = ActorSystem()
  implicit val executionContext: ExecutionContextExecutor = system.dispatcher

  // Platform line separator. Kept as a public member for compatibility;
  // parsing itself no longer depends on it (see `source`).
  val ls = sys.props("line.separator")

  private val CarriageReturn: Byte = '\r'.toByte
  private val LineFeed = "\n"
  private val TimeSeparator = """\s*-->\s*"""

  // DateTimeFormatter is immutable and thread-safe: build it once, not per timestamp
  private val timeFormatter = DateTimeFormatter.ofPattern("HH:mm:ss,SSS")

  // Input is normalized to LF before framing, so an empty line is always "\n\n"
  private val frameByEmptyLine = Framing.delimiter(
    ByteString(LineFeed + LineFeed),
    maximumFrameLength = 2048,
    allowTruncation = true)

  /** Converts a `HH:mm:ss,SSS` timestamp to milliseconds since 00:00:00,000. */
  private def toMillisOfDay(value: String): Long =
    LocalTime.parse(value.trim, timeFormatter).toNanoOfDay / 1_000_000

  private def malformed(raw: String): IllegalArgumentException =
    new IllegalArgumentException(s"Malformed subtitle block in file: $sourceFilePath, block: '$raw'")

  /**
    * Converts the raw text of one block to a [[SubtitleBlock]].
    * The first line (block index) is ignored, the second line carries the timing,
    * all remaining lines are the subtitle text.
    *
    * @throws IllegalArgumentException if the block has no valid timing line
    */
  private def convertTo(raw: String): SubtitleBlock = {
    // More than one empty line between blocks leaves leading blank lines in the frame
    val parts = raw.split(LineFeed).dropWhile(_.trim.isEmpty)
    if (parts.length < 2) throw malformed(raw)

    val times = parts(1).split(TimeSeparator)
    if (times.length < 2) throw malformed(raw)

    SubtitleBlock(toMillisOfDay(times(0)), toMillisOfDay(times(1)), parts.drop(2).toList)
  }

  /** Emits one [[SubtitleBlock]] per block of the source file. */
  val source: Source[SubtitleBlock, Any] = {
    StreamConverters.fromInputStream(() => new FileInputStream(sourceFilePath))
      // Normalize CRLF to LF. Safe on raw UTF-8 chunks: 0x0D never occurs inside a multi-byte sequence
      .map(chunk => ByteString(chunk.toArray[Byte].filterNot(_ == CarriageReturn)))
      .via(frameByEmptyLine)
      .map(each => each.utf8String)
      // Trailing newlines at the end of the file yield blank frames
      .filter(each => each.trim.nonEmpty)
      .map(each => convertTo(each))
  }

  /**
    * Parses the whole file and blocks until all subtitle blocks are read (max. 10 seconds).
    * Terminates the internal `ActorSystem` afterwards, because a running system
    * would keep the JVM of the caller alive.
    *
    * @return all blocks in file order
    */
  def runSync(): Seq[SubtitleBlock] = {
    try {
      val resultFut = source.runWith(Sink.seq)
      Await.result(resultFut, 10.seconds)
    } finally {
      system.terminate()
    }
  }
}

/**
  * Factory for [[SrtParser]] plus a small demo entry point.
  *
  * The demo lives in an explicit `main` instead of an `App` body, so that using
  * the factory from other code is independent of the demo's initialization.
  */
object SrtParser {
  val logger: Logger = LoggerFactory.getLogger(this.getClass)

  private val DefaultSourceFilePath = "src/main/resources/EN_challenges.srt"

  def apply(sourceFilePath: String): SrtParser = new SrtParser(sourceFilePath)

  /**
    * Parses a .srt file and logs its blocks.
    *
    * @param args optional: path of the .srt file, defaults to the bundled example file
    */
  def main(args: Array[String]): Unit = {
    val parser = SrtParser(args.headOption.getOrElse(DefaultSourceFilePath))
    try {
      val result = parser.runSync()
      logger.info(s"File contains: ${result.size} SubtitleBlock(s)")
      logger.info(s"Blocks: $result")
    } finally {
      // Without this the actor system threads keep the JVM running after main returns
      parser.system.terminate()
    }
  }
}

/**
  * One block of a SubRip (.srt) file.
  *
  * @param start start time in milliseconds
  * @param end   end time in milliseconds
  * @param lines text lines of the block
  */
case class SubtitleBlock(start: Long, end: Long, lines: Seq[String]) {
  val logger: Logger = LoggerFactory.getLogger(this.getClass)
  val ls = sys.props("line.separator")

  /** All text lines joined to a single line, separated by a blank. */
  def allLines: String = lines.mkString(" ")

  /** [[allLines]] followed by two line separators. */
  def allLinesEnd: String = allLines + ls + ls

  /**
    * Renders this block in .srt format: counter line, timing line, text lines,
    * terminated by an empty line.
    *
    * @param blockCounter the index written as the first line of the block
    * @return the formatted block
    */
  def formatOutBlock(blockCounter: Long): String = {
    // Spec: https://wiki.videolan.org/SubRip
    // Text lines are joined with the same separator as the rest of the block,
    // to avoid mixed line endings in the output file
    val outputFormatted = s"$blockCounter$ls${toTime(start)} --> ${toTime(end)}$ls${lines.mkString(ls)}$ls$ls"
    logger.info(s"Writing block:$ls {}", outputFormatted)
    outputFormatted
  }

  /** Formats milliseconds as `HH:mm:ss,SSS`. */
  private def toTime(ms: Long) = {
    val d = Duration.ofMillis(ms)
    // toHours (total hours) instead of toHoursPart, which wraps at 24 and would drop whole days
    val hours = StringUtils.leftPad(d.toHours.toString, 2, "0")
    val minutes = StringUtils.leftPad(d.toMinutesPart.toString, 2, "0")
    val seconds = StringUtils.leftPad(d.toSecondsPart.toString, 2, "0")
    val milliSeconds = StringUtils.leftPad(d.toMillisPart.toString, 3, "0")
    s"$hours:$minutes:$seconds,$milliSeconds"
  }
}
60 changes: 16 additions & 44 deletions src/main/scala/tools/SubtitleTranslator.scala
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
package tools

import com.crowdscriber.caption.common.Vocabulary.{Srt, SubtitleBlock}
import com.crowdscriber.caption.srtdissector.SrtDissector
import org.apache.commons.lang3.StringUtils
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.stream.scaladsl.{FileIO, Flow, Source}
import org.apache.pekko.stream.{IOResult, ThrottleMode}
import org.apache.pekko.util.ByteString
import org.slf4j.{Logger, LoggerFactory}

import java.io.FileInputStream
import java.nio.file.Paths
import java.time.Duration
import scala.concurrent.duration.DurationInt
import scala.concurrent.{ExecutionContextExecutor, Future}
import scala.util.{Failure, Success, Try}
import scala.util.{Failure, Success}

/**
* Translate all blocks of an English .srt file to a target lang using OpenAI API
*
* Workflow:
* - Load all blocks from the .srt source file
* - Group blocks to scenes (= all blocks within a session window), depending on maxGap
* - Load all blocks from the .srt source file with [[SrtParser]]
* - Group blocks to scenes (= all blocks within a session window), depending on `maxGapSeconds`
* - Translate all blocks of a scene in one prompt (one line per block) via the OpenAI API
* - Continuously write translated blocks to target file
*
Expand All @@ -42,31 +37,25 @@ object SubtitleTranslator extends App {
implicit val system: ActorSystem = ActorSystem()
implicit val executionContext: ExecutionContextExecutor = system.dispatcher

val sourceFilePath = "The.Quiet.Girl.2022.BluRay.1080p.DTS-HD.MA.5.1.AVC.REMUX-FraMeSToR.srt"
private val targetFilePath = "DE_The.Quiet.Girl.2022.BluRay.1080p.DTS-HD.MA.5.1.AVC.REMUX-FraMeSToR.srt"
val sourceFilePath = "src/main/resources/EN_challenges.srt"
private val targetFilePath = "DE_challenges.srt"
private val targetLanguage = "German"

private val defaultModel = "gpt-3.5-turbo"
private val fallbackModel = "gpt-3.5-turbo-instruct"

private val maxGapSeconds = 1 // idle time between scenes (= session windows)
private val endBlockTag = "\n" // one block per line
private val maxGapSeconds = 1 // gap time between two scenes (= session windows)
private val endBlockTag = sys.props("line.separator") // one block per line
private val maxCharPerTranslatedLine = 40 // recommendation
private val conversationPrefix = "-"

private var totalTokensUsed = 0

implicit class SubtitleBlockExtensions(block: SubtitleBlock) {
def allLines: String = block.lines.mkString(" ")
// Sync to ensure that all blocks are readable before translation starts
val parseResult = SrtParser(sourceFilePath).runSync()
logger.info("Number of subtitleBlocks to translate: {}", parseResult.length)

def allLinesEbNewLine: String = allLines + endBlockTag + endBlockTag
}

// Source file must be in utf-8
private val srt: Try[Srt] = SrtDissector(new FileInputStream(sourceFilePath))
logger.info("Number of subtitleBlocks to translate: {}", srt.get.length)

val source = Source.fromIterator(() => srt.get.iterator)
val source = Source(parseResult)

val workflow = Flow[SubtitleBlock]
.via(groupByScene(maxGapSeconds))
Expand All @@ -77,7 +66,7 @@ object SubtitleTranslator extends App {
val processingSink = Flow[SubtitleBlock]
.zipWithIndex
.map { case (block: SubtitleBlock, blockCounter: Long) =>
ByteString(formatOutBlock(block, blockCounter + 1))
ByteString(block.formatOutBlock(blockCounter + 1))
}
.toMat(fileSink)((_, bytesWritten) => bytesWritten)

Expand Down Expand Up @@ -115,7 +104,7 @@ object SubtitleTranslator extends App {
private def translateScene(sceneOrig: List[SubtitleBlock]) = {
logger.info(s"About to translate scene with: ${sceneOrig.size} original blocks")

val allLines = sceneOrig.foldLeft("")((acc, block) => acc + block.allLinesEbNewLine)
val allLines = sceneOrig.foldLeft("")((acc, block) => acc + block.allLinesEnd)
val toTranslate = generateTranslationPrompt(allLines)
logger.info(s"Translation prompt: $toTranslate")

Expand Down Expand Up @@ -174,7 +163,7 @@ object SubtitleTranslator extends App {
|Translate the text lines below from English to $targetLanguage.
|
|Desired format:
|<line separated list of translated lines, honor line breaks>
|<line separated list of translated text lines, honor all line breaks>
|
|Text lines:
|$text
Expand Down Expand Up @@ -235,30 +224,13 @@ object SubtitleTranslator extends App {
List(firstHalf, secondHalf)
}

private def toTime(ms: Int) = {
val d = Duration.ofMillis(ms)
val hours = StringUtils.leftPad(d.toHoursPart.toString, 2, "0")
val minutes = StringUtils.leftPad(d.toMinutesPart.toString, 2, "0")
val seconds = StringUtils.leftPad(d.toSecondsPart.toString, 2, "0")
val milliSeconds = StringUtils.leftPad(d.toMillisPart.toString, 3, "0")
s"$hours:$minutes:$seconds,$milliSeconds"
}

private def formatOutBlock(block: SubtitleBlock, blockCounter: Long) = {
val ls = sys.props("line.separator")
// Spec: https://wiki.videolan.org/SubRip
val outputFormatted = s"$blockCounter$ls${toTime(block.start)} --> ${toTime(block.end)}$ls${block.lines.mkString("\n")}$ls$ls"
logger.info(s"Writing block:$ls {}", outputFormatted)
outputFormatted
}

def terminateWhen(done: Future[IOResult]) = {
done.onComplete {
case Success(_) =>
println(s"Flow Success. Finished writing to target file: $targetFilePath. Around $totalTokensUsed tokens used. About to terminate...")
logger.info(s"Flow Success. Finished writing to target file: $targetFilePath. Around $totalTokensUsed tokens used. About to terminate...")
system.terminate()
case Failure(e) =>
println(s"Flow Failure: $e. Partial translations are in target file: $targetFilePath About to terminate...")
logger.info(s"Flow Failure: $e. Partial translations are in target file: $targetFilePath About to terminate...")
system.terminate()
}
}
Expand Down

0 comments on commit 494387b

Please sign in to comment.