Skip to content

Commit 1dbfd84

Browse files
committed
Eval after each ODE step for accurate progress.
1 parent 4a7cba3 commit 1dbfd84

File tree

3 files changed: 10 additions (+10) and 12 deletions (-12).

3 files changed

+10
-12
lines changed

Sources/F5TTS/DiT.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ class InputEmbedding: Module {
7777

7878
let combined = MLX.concatenated([x, cond, textEmbed], axis: -1)
7979
var output = proj(combined)
80-
output.eval()
8180
output = conv_pos_embed(output) + output
82-
output.eval()
8381
return output
8482
}
8583
}

Sources/F5TTS/F5TTS.swift

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,10 @@ public class F5TTS: Module {
141141

142142
progressHandler?(Double(t))
143143

144-
return pred + (pred - nullPred) * cfgStrength
144+
let output = pred + (pred - nullPred) * cfgStrength
145+
output.eval()
146+
147+
return output
145148
}
146149

147150
// noise input
@@ -165,11 +168,10 @@ public class F5TTS: Module {
165168
let trajectory = self.odeint(fun: fn, y0: y0Padded, t: t)
166169
let sampled = trajectory[-1]
167170
var out = MLX.where(condMask, cond, sampled)
168-
171+
169172
if let vocoder = vocoder {
170173
out = vocoder(out)
171174
}
172-
173175
out.eval()
174176

175177
return (out, trajectory)
@@ -239,6 +241,8 @@ public class F5TTS: Module {
239241
}
240242

241243
let generatedAudio = outputAudio[audio.shape[0]...]
244+
245+
print("Got generated audio of shape: \(generatedAudio.shape)")
242246
return generatedAudio
243247
}
244248
}
@@ -331,9 +335,8 @@ public extension F5TTS {
331335

332336
static func estimatedDuration(refAudio: MLXArray, refText: String, text: String, speed: Double = 1.0) -> TimeInterval {
333337
let refDurationInFrames = refAudio.shape[0] / self.hopLength
334-
let pausePunctuation = "。,、;:?!"
335-
let refTextLength = refText.utf8.count + 3 * pausePunctuation.utf8.count
336-
let genTextLength = text.utf8.count + 3 * pausePunctuation.utf8.count
338+
let refTextLength = refText.utf8.count
339+
let genTextLength = text.utf8.count
337340

338341
let refAudioToTextRatio = Double(refDurationInFrames) / Double(refTextLength)
339342
let textLength = Double(genTextLength) / speed

Sources/F5TTS/Modules.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,7 @@ func precomputeFreqsCis(dim: Int, end: Int, theta: Float = 10000.0, thetaRescale
5858
let freqsCos = outerFreqs.cos()
5959
let freqsSin = outerFreqs.sin()
6060

61-
let output = MLX.concatenated([freqsCos, freqsSin], axis: -1)
62-
output.eval()
63-
64-
return output
61+
return MLX.concatenated([freqsCos, freqsSin], axis: -1)
6562
}
6663

6764
func getPosEmbedIndices(start: MLXArray, length: Int, maxPos: Int, scale: Float = 1.0) -> MLXArray {

Comments (0): there are no comments on this commit.