From 36899100db60159022fa20620facfa0404871c64 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Wed, 11 Dec 2024 15:47:46 -0800 Subject: [PATCH] Add rk4 sampling and use it by default. --- Package.resolved | 12 +- Sources/F5TTS/F5TTS.swift | 119 ++++++++++++------ Sources/f5-tts-generate/GenerateCommand.swift | 8 ++ 3 files changed, 97 insertions(+), 42 deletions(-) diff --git a/Package.resolved b/Package.resolved index 04b6b46..8f05b2d 100644 --- a/Package.resolved +++ b/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/maiqingqiang/Jinja", "state" : { - "revision" : "b435eb62b0d3d5f34167ec70a128355486981712", - "version" : "1.0.5" + "revision" : "6dbe4c449469fb586d0f7339f900f0dd4d78b167", + "version" : "1.0.6" } }, { @@ -14,8 +14,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "78a7cfe6701d6e9c88e9d4a0d1f7990af84b2146", - "version" : "0.18.0" + "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab", + "version" : "0.21.2" } }, { @@ -41,8 +41,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers", "state" : { - "revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0", - "version" : "0.1.13" + "revision" : "d42fdae473c49ea216671da8caae58e102d28709", + "version" : "0.1.14" } }, { diff --git a/Sources/F5TTS/F5TTS.swift b/Sources/F5TTS/F5TTS.swift index a5fa6e3..d8345bf 100644 --- a/Sources/F5TTS/F5TTS.swift +++ b/Sources/F5TTS/F5TTS.swift @@ -7,7 +7,74 @@ import Vocos // MARK: - F5TTS +func odeint_euler(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray { + var ys = [y0] + var yCurrent = y0 + + for i in 0..<(t.shape[0] - 1) { + let tCurrent = t[i].item(Float.self) + let dt = t[i + 1].item(Float.self) - tCurrent + + let k = fun(tCurrent, yCurrent) + let yNext = yCurrent + dt * k + + ys.append(yNext) + yCurrent = yNext + } + + return MLX.stacked(ys, axis: 0) +} + +func odeint_midpoint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray { + var ys = [y0] + var yCurrent = y0 + + for i in 0..<(t.shape[0] - 1) { + let tCurrent = t[i].item(Float.self) + let dt = t[i + 1].item(Float.self) - tCurrent + + let k1 = fun(tCurrent, yCurrent) + let mid = yCurrent + 0.5 * dt * k1 + + let k2 = fun(tCurrent + 0.5 * dt, mid) + let yNext = yCurrent + dt * k2 + + ys.append(yNext) + yCurrent = yNext + } + + return MLX.stacked(ys, axis: 0) +} + +func odeint_rk4(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray { + var ys = [y0] + var yCurrent = y0 + + for i in 0..<(t.shape[0] - 1) { + let tCurrent = t[i].item(Float.self) + let dt = t[i + 1].item(Float.self) - tCurrent + + let k1 = fun(tCurrent, yCurrent) + let k2 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k1) + let k3 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k2) + let k4 = fun(tCurrent + dt, yCurrent + dt * k3) + + let yNext = yCurrent + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4) + + ys.append(yNext) + yCurrent = yNext + } + + return MLX.stacked(ys) +} + public class F5TTS: Module { + public enum ODEMethod: String { + case euler + case midpoint + case rk4 + } + enum F5TTSError: Error { case unableToLoadModel case unableToLoadReferenceAudio @@ -38,40 +105,18 @@ public class F5TTS: Module { super.init() } - private func odeint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray { - var ys = [y0] - var yCurrent = y0 - - for i in 0..<(t.shape[0] - 1) { - let tCurrent = t[i].item(Float.self) - let dt = t[i + 1].item(Float.self) - tCurrent - - let k1 = fun(tCurrent, yCurrent) - let mid = yCurrent + 0.5 * dt * k1 - - let k2 = fun(tCurrent + 0.5 * dt, mid) - let yNext = yCurrent + dt * k2 - - ys.append(yNext) - yCurrent = yNext - } - - return MLX.stacked(ys, axis: 0) - } - private func sample( cond: MLXArray, text: [String], duration: Int? = nil, lens: MLXArray? = nil, - steps: Int = 32, + steps: Int = 8, + method: ODEMethod = .rk4, cfgStrength: Double = 2.0, swayCoef: Double? = -1.0, seed: Int? = nil, maxDuration: Int = 4096, vocoder: ((MLXArray) -> MLXArray)? = nil, - noRefAudio: Bool = false, - editMask: MLXArray? = nil, progressHandler: ((Double) -> Void)? = nil ) throws -> (MLXArray, MLXArray) { MLX.eval(self.parameters()) @@ -96,9 +141,6 @@ public class F5TTS: Module { lens = MLX.maximum(textLens, lens) var condMask = lensToMask(t: lens) - if let editMask = editMask { - condMask = condMask & editMask - } // duration var resolvedDuration: MLXArray? = (duration != nil) ? MLXArray(duration!) : nil @@ -125,10 +167,6 @@ public class F5TTS: Module { let mask: MLXArray? = (batch > 1) ? lensToMask(t: duration) : nil - if noRefAudio { - cond = MLX.zeros(like: cond) - } - // neural ode let fn: (Float, MLXArray) -> MLXArray = { t, x in @@ -169,7 +207,7 @@ public class F5TTS: Module { var y0: [MLXArray] = [] for dur in duration { - if let seed = seed { + if let seed { MLXRandom.seed(UInt64(seed)) } let noise = MLXRandom.normal([dur.item(Int.self), self.numChannels]) @@ -183,11 +221,17 @@ public class F5TTS: Module { t = t + coef * (MLX.cos(MLXArray(.pi) / 2 * t) - 1 + t) } - let trajectory = self.odeint(fun: fn, y0: y0Padded, t: t) + let odeintFn = switch method { + case .euler: odeint_euler + case .midpoint: odeint_midpoint + case .rk4: odeint_rk4 + } + + let trajectory = odeintFn(fn, y0Padded, t) let sampled = trajectory[-1] var out = MLX.where(condMask, cond, sampled) - if let vocoder = vocoder { + if let vocoder { out = vocoder(out) } out.eval() @@ -200,6 +244,8 @@ public class F5TTS: Module { referenceAudioURL: URL? = nil, referenceAudioText: String? = nil, duration: TimeInterval? = nil, + steps: Int = 8, + method: ODEMethod = .rk4, cfg: Double = 2.0, sway: Double = -1.0, speed: Double = 1.0, @@ -234,7 +280,8 @@ public class F5TTS: Module { cond: normalizedAudio.expandedDimensions(axis: 0), text: [processedText], duration: nil, - steps: 32, + steps: steps, + method: method, cfgStrength: cfg, swayCoef: sway, seed: seed, @@ -339,7 +386,7 @@ public extension F5TTS { static var framesPerSecond: Double = .init(sampleRate) / Double(hopLength) static func loadAudioArray(url: URL) throws -> MLXArray { - return try AudioUtilities.loadAudioFile(url: url) + try AudioUtilities.loadAudioFile(url: url) } static func referenceAudio() throws -> (MLXArray, String) { diff --git a/Sources/f5-tts-generate/GenerateCommand.swift b/Sources/f5-tts-generate/GenerateCommand.swift index 59f583c..28aff21 100644 --- a/Sources/f5-tts-generate/GenerateCommand.swift +++ b/Sources/f5-tts-generate/GenerateCommand.swift @@ -24,6 +24,12 @@ struct GenerateAudio: AsyncParsableCommand { @Option(name: .long, help: "Output path for the generated audio") var outputPath: String = "output.wav" + @Option(name: .long, help: "The number of steps to use for ODE sampling") + var steps: Int = 8 + + @Option(name: .long, help: "Method to use for ODE sampling. Options are 'euler', 'midpoint', and 'rk4'.") + var method: String = "rk4" + @Option(name: .long, help: "Strength of classifier free guidance") var cfg: Double = 2.0 @@ -49,6 +55,8 @@ struct GenerateAudio: AsyncParsableCommand { referenceAudioURL: refAudioPath != nil ? URL(filePath: refAudioPath!) : nil, referenceAudioText: refAudioText, duration: duration, + steps: steps, + method: F5TTS.ODEMethod(rawValue: method)!, cfg: cfg, sway: sway, speed: speed,