Files
Supertonic/swift/Sources/Helper.swift
2026-01-25 18:58:40 +09:00

836 lines
31 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import Foundation
import Accelerate
import OnnxRuntimeBindings
// MARK: - Available Languages

/// Language codes the preprocessor accepts (order is preserved for error messages).
let AVAILABLE_LANGS = ["en", "ko", "es", "pt", "fr"]

/// Returns true when `lang` is one of the supported language codes.
func isValidLang(_ lang: String) -> Bool {
    AVAILABLE_LANGS.contains(lang)
}
// MARK: - Configuration Structures

/// Mirror of the `tts.json` config file (see `loadCfgs`). Property names
/// match the JSON keys verbatim (snake_case), so no CodingKeys are needed.
struct Config: Codable {
    /// Audio/autoencoder settings.
    struct AEConfig: Codable {
        let sample_rate: Int      // output waveform sample rate in Hz
        let base_chunk_size: Int  // waveform samples per latent step (pre-compression); see sampleNoisyLatent
    }
    /// Text-to-latent settings.
    struct TTLConfig: Codable {
        let chunk_compress_factor: Int  // multiplies base_chunk_size to give the effective chunk size
        let latent_dim: Int             // per-chunk latent channels (scaled by chunk_compress_factor downstream)
    }
    let ae: AEConfig
    let ttl: TTLConfig
}
// MARK: - Voice Style Data Structure

/// JSON layout of a voice-style file (see `loadVoiceStyle`).
struct VoiceStyleData: Codable {
    /// One style tensor serialized as nested JSON arrays plus its shape.
    struct StyleComponent: Codable {
        let data: [[[Float]]]  // tensor values; loadVoiceStyle flattens these row-major
        let dims: [Int]        // tensor shape; dims[1] and dims[2] are used as the per-item dimensions
        let type: String       // element type tag from the exporter (not inspected by this file)
    }
    let style_ttl: StyleComponent  // style fed to text encoder / vector estimator
    let style_dp: StyleComponent   // style fed to the duration predictor
}
// MARK: - Unicode Text Processor

/// Turns raw text into model token ids via a JSON lookup table that maps
/// Unicode code points to token ids.
class UnicodeProcessor {
    /// `indexer[codePoint]` is the token id assigned to that Unicode scalar.
    let indexer: [Int64]

    /// Loads the code-point indexer (a JSON array of Int64) from disk.
    init(unicodeIndexerPath: String) throws {
        let raw = try Data(contentsOf: URL(fileURLWithPath: unicodeIndexerPath))
        self.indexer = try JSONDecoder().decode([Int64].self, from: raw)
    }

    /// Preprocesses each text for its language, then returns zero-padded
    /// token-id rows and the matching [bsz, 1, maxLen] attention mask.
    /// `langList` must have at least as many entries as `textList`.
    func call(_ textList: [String], _ langList: [String]) -> (textIds: [[Int64]], textMask: [[[Float]]]) {
        let processed = textList.indices.map { preprocessText(textList[$0], lang: langList[$0]) }
        // Lengths are counted in unicode scalars: preprocessText applies NFKD,
        // and every decomposed scalar becomes exactly one token below.
        let lengths = processed.map { $0.unicodeScalars.count }
        let width = lengths.max() ?? 0
        let textIds: [[Int64]] = processed.map { text in
            var row = [Int64](repeating: 0, count: width)
            for (j, scalar) in text.unicodeScalars.enumerated() {
                let codePoint = Int(scalar.value)
                // Code points beyond the table map to -1 (unknown token).
                row[j] = codePoint < indexer.count ? indexer[codePoint] : -1
            }
            return row
        }
        return (textIds, getTextMask(lengths))
    }
}
/// Normalizes `text` for tokenization and wraps it in `<lang>...</lang>` tags.
///
/// Pipeline: NFKD decomposition (so Hangul syllables become Jamo), emoji
/// stripping, dash/quote/symbol normalization, expansion of a few known
/// expressions, punctuation spacing cleanup, whitespace collapsing, and a
/// trailing period when the text does not already end with punctuation.
///
/// - Parameters:
///   - text: raw input text.
///   - lang: language code; must be in `AVAILABLE_LANGS` or the process traps.
/// - Returns: the normalized text wrapped as `<lang>...</lang>`.
func preprocessText(_ text: String, lang: String) -> String {
    // Use NFKD (decomposed) for proper Hangul Jamo decomposition
    var text = text.decomposedStringWithCompatibilityMapping
    // Remove emojis (wide Unicode ranges). NSRegularExpression patterns cannot
    // express escapes above \uFFFF, so filter scalars directly instead.
    text = text.unicodeScalars.filter { scalar in
        let value = scalar.value
        return !((value >= 0x1F600 && value <= 0x1F64F) ||
                 (value >= 0x1F300 && value <= 0x1F5FF) ||
                 (value >= 0x1F680 && value <= 0x1F6FF) ||
                 (value >= 0x1F700 && value <= 0x1F77F) ||
                 (value >= 0x1F780 && value <= 0x1F7FF) ||
                 (value >= 0x1F800 && value <= 0x1F8FF) ||
                 (value >= 0x1F900 && value <= 0x1F9FF) ||
                 (value >= 0x1FA00 && value <= 0x1FA6F) ||
                 (value >= 0x1FA70 && value <= 0x1FAFF) ||
                 (value >= 0x2600 && value <= 0x26FF) ||
                 (value >= 0x2700 && value <= 0x27BF) ||
                 (value >= 0x1F1E6 && value <= 0x1F1FF))
    }.map { String($0) }.joined()
    // Replace various dashes and symbols.
    // FIX: the dash/arrow keys were restored (as explicit \u{...} escapes)
    // from the inline comments — the original literals had degraded to empty
    // strings, and a dictionary literal with duplicate "" keys traps at
    // runtime ("Dictionary literal contains duplicate keys").
    let replacements: [String: String] = [
        "\u{2013}": "-",  // en dash
        "\u{2011}": "-",  // non-breaking hyphen
        "\u{2014}": "-",  // em dash
        "_": " ",         // underscore
        "\u{201C}": "\"", // left double quote
        "\u{201D}": "\"", // right double quote
        "\u{2018}": "'",  // left single quote
        "\u{2019}": "'",  // right single quote
        "´": "'",         // acute accent — NOTE(review): input is already NFKD-decomposed,
                          // so a bare U+00B4 may never appear here; verify intent
        "`": "'",         // grave accent
        "[": " ",         // left bracket
        "]": " ",         // right bracket
        "|": " ",         // vertical bar
        "/": " ",         // slash
        "#": " ",         // hash
        "\u{2192}": " ",  // right arrow
        "\u{2190}": " ",  // left arrow
    ]
    for (old, new) in replacements {
        text = text.replacingOccurrences(of: old, with: new)
    }
    // Remove special symbols.
    // TODO(review): three symbol literals were lost in transit (they appeared
    // as empty strings, which are no-ops for replacingOccurrences); restore
    // them from the upstream source when known.
    let specialSymbols = ["©", "\\"]
    for symbol in specialSymbols {
        text = text.replacingOccurrences(of: symbol, with: "")
    }
    // Replace known expressions
    let exprReplacements: [String: String] = [
        "@": " at ",
        "e.g.,": "for example, ",
        "i.e.,": "that is, ",
    ]
    for (old, new) in exprReplacements {
        text = text.replacingOccurrences(of: old, with: new)
    }
    // Fix spacing around punctuation (drop the stray space before the mark)
    text = text.replacingOccurrences(of: " ,", with: ",")
    text = text.replacingOccurrences(of: " .", with: ".")
    text = text.replacingOccurrences(of: " !", with: "!")
    text = text.replacingOccurrences(of: " ?", with: "?")
    text = text.replacingOccurrences(of: " ;", with: ";")
    text = text.replacingOccurrences(of: " :", with: ":")
    text = text.replacingOccurrences(of: " '", with: "'")
    // Collapse runs of duplicate quotes to a single quote
    while text.contains("\"\"") {
        text = text.replacingOccurrences(of: "\"\"", with: "\"")
    }
    while text.contains("''") {
        text = text.replacingOccurrences(of: "''", with: "'")
    }
    while text.contains("``") {
        text = text.replacingOccurrences(of: "``", with: "`")
    }
    // Collapse any whitespace run to a single space, then trim the ends
    let whitespacePattern = try! NSRegularExpression(pattern: "\\s+")
    let whitespaceRange = NSRange(text.startIndex..., in: text)
    text = whitespacePattern.stringByReplacingMatches(in: text, range: whitespaceRange, withTemplate: " ")
    text = text.trimmingCharacters(in: .whitespacesAndNewlines)
    // If text doesn't end with punctuation, quotes, or closing brackets, add a period
    if !text.isEmpty {
        let punctPattern = try! NSRegularExpression(pattern: "[.!?;:,'\"\\u201C\\u201D\\u2018\\u2019)\\]}…。」』】〉》›»]$")
        let punctRange = NSRange(text.startIndex..., in: text)
        if punctPattern.firstMatch(in: text, range: punctRange) == nil {
            text += "."
        }
    }
    // Validate language (traps on unsupported codes, matching original behavior)
    guard isValidLang(lang) else {
        fatalError("Invalid language: \(lang). Available: \(AVAILABLE_LANGS.joined(separator: ", "))")
    }
    // Wrap text with language tags
    text = "<\(lang)>\(text)</\(lang)>"
    return text
}
/// Builds a [bsz, 1, maxLen] binary mask: for each length, the first
/// `min(length, maxLen)` positions are 1.0 and the rest are 0.0.
/// When `maxLen` is nil, the largest length is used as the width.
func lengthToMask(_ lengths: [Int], maxLen: Int? = nil) -> [[[Float]]] {
    let width = maxLen ?? (lengths.max() ?? 0)
    return lengths.map { length in
        let ones = min(length, width)
        let row = Array(repeating: Float(1.0), count: ones)
                + Array(repeating: Float(0.0), count: width - ones)
        // Middle singleton dimension matches the models' [bsz, 1, len] layout.
        return [row]
    }
}
/// Convenience wrapper: mask for token-id rows, padded to the longest row.
func getTextMask(_ textIdsLengths: [Int]) -> [[[Float]]] {
    lengthToMask(textIdsLengths, maxLen: textIdsLengths.max() ?? 0)
}
/// Draws the initial Gaussian latent for the denoising loop, zeroed outside
/// each item's valid region.
///
/// - Parameters:
///   - duration: per-item durations in seconds.
///   - sampleRate: waveform sample rate (Hz).
///   - baseChunkSize: waveform samples per latent step before compression.
///   - chunkCompress: compression factor multiplying chunk size and latent dim.
///   - latentDim: base number of latent channels.
/// - Returns: noise of shape [bsz, latentDim*chunkCompress, latentLen] with the
///   mask already applied, plus the [bsz, 1, latentLen] mask itself.
func sampleNoisyLatent(duration: [Float], sampleRate: Int, baseChunkSize: Int, chunkCompress: Int, latentDim: Int) -> (noisyLatent: [[[Float]]], latentMask: [[[Float]]]) {
    let bsz = duration.count
    let chunkSize = baseChunkSize * chunkCompress
    // Ceil-divide the longest waveform into latent steps.
    let wavLenMax = Int((duration.max() ?? 0.0) * Float(sampleRate))
    let latentLen = (wavLenMax + chunkSize - 1) / chunkSize
    let channels = latentDim * chunkCompress

    // One standard-normal draw via the Box-Muller transform
    // (u1 bounded away from 0 so log(u1) stays finite).
    func gaussian() -> Float {
        let u1 = Float.random(in: 0.0001...1.0)
        let u2 = Float.random(in: 0.0...1.0)
        return sqrt(-2.0 * log(u1)) * cos(2.0 * Float.pi * u2)
    }

    var noisyLatent = (0..<bsz).map { _ in
        (0..<channels).map { _ in
            (0..<latentLen).map { _ in gaussian() }
        }
    }

    // Per-item latent lengths, ceil-divided like the max above.
    let latentLengths = duration.map { d -> Int in
        (Int(d * Float(sampleRate)) + chunkSize - 1) / chunkSize
    }
    let latentMask = lengthToMask(latentLengths, maxLen: latentLen)

    // Zero noise beyond each item's valid latent length.
    for b in 0..<bsz {
        for c in 0..<channels {
            for t in 0..<latentLen {
                noisyLatent[b][c][t] *= latentMask[b][0][t]
            }
        }
    }
    return (noisyLatent, latentMask)
}
/// Converts waveform lengths (in samples) to a latent-step mask using the
/// configured chunk size (base_chunk_size * chunk_compress_factor).
func getLatentMask(_ wavLengths: [Int64], _ cfgs: Config) -> [[[Float]]] {
    let latentSize = cfgs.ae.base_chunk_size * cfgs.ttl.chunk_compress_factor
    // Ceil-divide each waveform length into latent steps.
    let latentLengths = wavLengths.map { (Int($0) + latentSize - 1) / latentSize }
    return lengthToMask(latentLengths, maxLen: latentLengths.max() ?? 0)
}
// MARK: - WAV File I/O

/// Writes `audioData` (mono float samples, nominally in [-1, 1]) to
/// `filename` as a 16-bit PCM WAV file at `sampleRate` Hz.
///
/// - Throws: any error raised while writing the file to disk.
func writeWavFile(_ filename: String, _ audioData: [Float], _ sampleRate: Int) throws {
    let url = URL(fileURLWithPath: filename)
    // Clamp to [-1, 1] and quantize to signed 16-bit (truncating toward zero).
    let int16Data = audioData.map { sample -> Int16 in
        let clamped = max(-1.0, min(1.0, sample))
        return Int16(clamped * 32767.0)
    }
    // WAV header fields for mono 16-bit PCM.
    let numChannels: UInt16 = 1
    let bitsPerSample: UInt16 = 16
    let byteRate = UInt32(sampleRate) * UInt32(numChannels) * UInt32(bitsPerSample) / 8
    let blockAlign = numChannels * bitsPerSample / 8
    let dataSize = UInt32(int16Data.count * 2)
    var data = Data()
    // RIFF chunk descriptor (size = 36 header bytes after this field + payload).
    data.append("RIFF".data(using: .ascii)!)
    withUnsafeBytes(of: UInt32(36 + dataSize).littleEndian) { data.append(contentsOf: $0) }
    data.append("WAVE".data(using: .ascii)!)
    // "fmt " sub-chunk: 16-byte PCM format description.
    data.append("fmt ".data(using: .ascii)!)
    withUnsafeBytes(of: UInt32(16).littleEndian) { data.append(contentsOf: $0) }
    withUnsafeBytes(of: UInt16(1).littleEndian) { data.append(contentsOf: $0) } // audio format 1 = PCM
    withUnsafeBytes(of: numChannels.littleEndian) { data.append(contentsOf: $0) }
    withUnsafeBytes(of: UInt32(sampleRate).littleEndian) { data.append(contentsOf: $0) }
    withUnsafeBytes(of: byteRate.littleEndian) { data.append(contentsOf: $0) }
    withUnsafeBytes(of: blockAlign.littleEndian) { data.append(contentsOf: $0) }
    withUnsafeBytes(of: bitsPerSample.littleEndian) { data.append(contentsOf: $0) }
    // "data" sub-chunk header.
    data.append("data".data(using: .ascii)!)
    withUnsafeBytes(of: dataSize.littleEndian) { data.append(contentsOf: $0) }
    // Sample payload. FIX: serialize each sample explicitly little-endian
    // instead of dumping the host-order array buffer — the header fields were
    // already little-endian, but the old payload path produced a corrupt file
    // on big-endian hosts.
    for sample in int16Data {
        withUnsafeBytes(of: sample.littleEndian) { data.append(contentsOf: $0) }
    }
    try data.write(to: url)
}
// MARK: - Text Chunking

// Default maximum chunk length in characters (see chunkText; Korean text
// uses a shorter limit set by the caller in TextToSpeech.call).
let MAX_CHUNK_LENGTH = 300

// Abbreviations whose trailing period must NOT be treated as a sentence
// boundary by splitSentences.
let ABBREVIATIONS = [
    "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
    "St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
    "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D."
]
/// Splits `text` into chunks no longer than `actualMaxLen` characters,
/// preferring (in order) paragraph, sentence, comma, then word boundaries.
/// A single word longer than the limit is emitted as its own over-length
/// chunk. Always returns at least one (possibly empty) element.
///
/// - Parameters:
///   - text: raw input text.
///   - maxLen: chunk limit; values <= 0 fall back to `MAX_CHUNK_LENGTH`.
func chunkText(_ text: String, maxLen: Int = 0) -> [String] {
    let actualMaxLen = maxLen > 0 ? maxLen : MAX_CHUNK_LENGTH
    let trimmedText = text.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)
    if trimmedText.isEmpty {
        return [""]
    }
    // Split by paragraphs: any blank line (newline, optional whitespace, newline).
    let paraPattern = try! NSRegularExpression(pattern: "\\n\\s*\\n")
    let paraRange = NSRange(trimmedText.startIndex..., in: trimmedText)
    var paragraphs = [String]()
    var lastEnd = trimmedText.startIndex
    paraPattern.enumerateMatches(in: trimmedText, range: paraRange) { match, _, _ in
        if let match = match, let range = Range(match.range, in: trimmedText) {
            paragraphs.append(String(trimmedText[lastEnd..<range.lowerBound]))
            lastEnd = range.upperBound
        }
    }
    // Tail after the last separator.
    if lastEnd < trimmedText.endIndex {
        paragraphs.append(String(trimmedText[lastEnd...]))
    }
    if paragraphs.isEmpty {
        paragraphs = [trimmedText]
    }
    var chunks = [String]()
    for para in paragraphs {
        let trimmedPara = para.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)
        if trimmedPara.isEmpty {
            continue
        }
        // Whole paragraph fits: emit it as one chunk.
        if trimmedPara.count <= actualMaxLen {
            chunks.append(trimmedPara)
            continue
        }
        // Paragraph too long: accumulate sentences into `current` up to the limit.
        let sentences = splitSentences(trimmedPara)
        var current = ""
        var currentLen = 0
        for sentence in sentences {
            let trimmedSentence = sentence.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)
            if trimmedSentence.isEmpty {
                continue
            }
            let sentenceLen = trimmedSentence.count
            if sentenceLen > actualMaxLen {
                // If sentence is longer than maxLen, split by comma or space.
                // Flush whatever was accumulated first so ordering is preserved.
                if !current.isEmpty {
                    chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                    current = ""
                    currentLen = 0
                }
                // Try splitting by comma
                let parts = trimmedSentence.components(separatedBy: ",")
                for part in parts {
                    let trimmedPart = part.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)
                    if trimmedPart.isEmpty {
                        continue
                    }
                    let partLen = trimmedPart.count
                    if partLen > actualMaxLen {
                        // Split by space as last resort
                        let words = trimmedPart.components(separatedBy: CharacterSet.whitespaces).filter { !$0.isEmpty }
                        var wordChunk = ""
                        var wordChunkLen = 0
                        for word in words {
                            let wordLen = word.count
                            // +1 accounts for the joining space added below.
                            if wordChunkLen + wordLen + 1 > actualMaxLen && !wordChunk.isEmpty {
                                chunks.append(wordChunk.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                                wordChunk = ""
                                wordChunkLen = 0
                            }
                            if !wordChunk.isEmpty {
                                wordChunk += " "
                                wordChunkLen += 1
                            }
                            wordChunk += word
                            wordChunkLen += wordLen
                        }
                        if !wordChunk.isEmpty {
                            chunks.append(wordChunk.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                        }
                    } else {
                        // Comma part fits: accumulate, re-joining with ", ".
                        if currentLen + partLen + 1 > actualMaxLen && !current.isEmpty {
                            chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                            current = ""
                            currentLen = 0
                        }
                        if !current.isEmpty {
                            current += ", "
                            currentLen += 2
                        }
                        current += trimmedPart
                        currentLen += partLen
                    }
                }
                // NOTE: a trailing comma part may remain in `current`; it is
                // merged with the following sentences (or flushed at the end).
                continue
            }
            // Sentence fits: accumulate, joining sentences with a single space.
            if currentLen + sentenceLen + 1 > actualMaxLen && !current.isEmpty {
                chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                current = ""
                currentLen = 0
            }
            if !current.isEmpty {
                current += " "
                currentLen += 1
            }
            current += trimmedSentence
            currentLen += sentenceLen
        }
        // Flush the final partial chunk of this paragraph.
        if !current.isEmpty {
            chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
        }
    }
    return chunks.isEmpty ? [""] : chunks
}
/// Splits `text` into sentences at ".", "!", "?" followed by whitespace,
/// keeping the delimiter (plus the trailing whitespace) with its sentence.
/// Boundaries that merely terminate a known abbreviation ("Dr.", "e.g.", ...)
/// are not treated as sentence ends.
func splitSentences(_ text: String) -> [String] {
    // Lookbehind support in NSRegularExpression is unreliable, so match plain
    // boundaries and filter out abbreviation hits afterwards.
    let boundary = try! NSRegularExpression(pattern: "([.!?])\\s+")
    let fullRange = NSRange(text.startIndex..., in: text)
    let hits = boundary.matches(in: text, range: fullRange)
    guard !hits.isEmpty else { return [text] }

    var pieces = [String]()
    var cursor = text.startIndex
    for hit in hits {
        guard let hitRange = Range(hit.range, in: text) else { continue }
        // Text accumulated since the last real boundary, up to this mark.
        let lead = String(text[cursor..<hitRange.lowerBound])
        // The punctuation mark itself (always a single BMP character here).
        let markRange = Range(NSRange(location: hit.range.location, length: 1), in: text)!
        let mark = String(text[markRange])
        // Skip this boundary when it just closes an abbreviation.
        let tail = lead.trimmingCharacters(in: CharacterSet.whitespaces) + mark
        let endsAbbreviation = ABBREVIATIONS.contains { tail.hasSuffix($0) }
        if !endsAbbreviation {
            pieces.append(String(text[cursor..<hitRange.upperBound]))
            cursor = hitRange.upperBound
        }
    }
    // Whatever follows the final boundary is the last sentence.
    if cursor < text.endIndex {
        pieces.append(String(text[cursor...]))
    }
    return pieces.isEmpty ? [text] : pieces
}
// MARK: - Utility Functions

/// Runs `f`, printing `name` before and the elapsed wall-clock time after,
/// and returns its result. Nothing is printed when `f` throws.
func timer<T>(_ name: String, _ f: () throws -> T) rethrows -> T {
    let startedAt = Date()
    print("\(name)...")
    let value = try f()
    let seconds = Date().timeIntervalSince(startedAt)
    print(String(format: " -> %@ completed in %.2f sec", name, seconds))
    return value
}
/// Truncates `text` to at most `maxLen` characters and replaces every
/// non-alphanumeric character with an underscore, yielding a safe filename stem.
func sanitizeFilename(_ text: String, maxLen: Int) -> String {
    let clipped = text.prefix(maxLen)
    let safe = clipped.map { char -> Character in
        (char.isLetter || char.isNumber) ? char : "_"
    }
    return String(safe)
}
/// Loads and decodes the pipeline configuration from `<onnxDir>/tts.json`.
/// - Throws: file-read or JSON-decoding errors.
func loadCfgs(_ onnxDir: String) throws -> Config {
    let url = URL(fileURLWithPath: "\(onnxDir)/tts.json")
    return try JSONDecoder().decode(Config.self, from: Data(contentsOf: url))
}
// MARK: - ONNX Runtime Integration

/// Pair of precomputed speaker-style tensors: `ttl` is fed to the text
/// encoder and vector estimator (the models' "style_ttl" input) and `dp`
/// to the duration predictor (the "style_dp" input).
struct Style {
    let ttl: ORTValue
    let dp: ORTValue
}
/// End-to-end TTS pipeline around four ONNX Runtime sessions:
/// duration predictor -> text encoder -> vector estimator (iterative
/// denoising loop) -> vocoder.
class TextToSpeech {
    let cfgs: Config
    let textProcessor: UnicodeProcessor
    let dpOrt: ORTSession        // duration predictor session
    let textEncOrt: ORTSession   // text encoder session
    let vectorEstOrt: ORTSession // vector estimator (denoiser) session
    let vocoderOrt: ORTSession   // vocoder (latent -> waveform) session
    let sampleRate: Int          // waveform sample rate, taken from cfgs.ae.sample_rate

    init(cfgs: Config, textProcessor: UnicodeProcessor,
         dpOrt: ORTSession, textEncOrt: ORTSession,
         vectorEstOrt: ORTSession, vocoderOrt: ORTSession) {
        self.cfgs = cfgs
        self.textProcessor = textProcessor
        self.dpOrt = dpOrt
        self.textEncOrt = textEncOrt
        self.vectorEstOrt = vectorEstOrt
        self.vocoderOrt = vocoderOrt
        self.sampleRate = cfgs.ae.sample_rate
    }

    /// Synthesizes one batch of texts in a single pass through all four models.
    ///
    /// - Parameters:
    ///   - textList: raw input texts, one per batch element.
    ///   - langList: language code per text (parallel to `textList`).
    ///   - style: speaker-style tensors fed to the models.
    ///   - totalStep: number of denoising iterations.
    ///   - speed: speaking-rate factor; predicted durations are divided by it.
    /// - Returns: the vocoder's flattened output samples plus the per-item
    ///   durations (in seconds, per the usage in `call`).
    ///   NOTE(review): for bsz > 1 the wav layout depends on the vocoder
    ///   model's output shape — confirm before relying on it.
    private func _infer(_ textList: [String], _ langList: [String], _ style: Style, _ totalStep: Int, speed: Float = 1.05) throws -> (wav: [Float], duration: [Float]) {
        let bsz = textList.count
        // Tokenize the texts and build the [bsz, 1, maxLen] attention mask.
        let (textIds, textMask) = textProcessor.call(textList, langList)
        // Flatten text IDs row-major into a [bsz, maxLen] int64 tensor.
        let textIdsFlat = textIds.flatMap { $0 }
        let textIdsShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: textIds[0].count)]
        let textIdsValue = try ORTValue(tensorData: NSMutableData(bytes: textIdsFlat, length: textIdsFlat.count * MemoryLayout<Int64>.size),
                                        elementType: .int64,
                                        shape: textIdsShape)
        // Flatten text mask into a [bsz, 1, maxLen] float tensor.
        let textMaskFlat = textMask.flatMap { $0.flatMap { $0 } }
        let textMaskShape: [NSNumber] = [NSNumber(value: bsz), 1, NSNumber(value: textMask[0][0].count)]
        let textMaskValue = try ORTValue(tensorData: NSMutableData(bytes: textMaskFlat, length: textMaskFlat.count * MemoryLayout<Float>.size),
                                         elementType: .float,
                                         shape: textMaskShape)
        // Predict per-item durations.
        let dpOutputs = try dpOrt.run(withInputs: ["text_ids": textIdsValue, "style_dp": style.dp, "text_mask": textMaskValue],
                                      outputNames: ["duration"],
                                      runOptions: nil)
        let durationData = try dpOutputs["duration"]!.tensorData() as Data
        var duration = durationData.withUnsafeBytes { ptr in
            Array(ptr.bindMemory(to: Float.self))
        }
        // Apply speed factor: higher speed -> shorter duration.
        for i in 0..<duration.count {
            duration[i] /= speed
        }
        // Encode text into embeddings conditioned on the TTL style.
        let textEncOutputs = try textEncOrt.run(withInputs: ["text_ids": textIdsValue, "style_ttl": style.ttl, "text_mask": textMaskValue],
                                                outputNames: ["text_emb"],
                                                runOptions: nil)
        let textEmbValue = textEncOutputs["text_emb"]!
        // Draw the initial Gaussian latent sized from the predicted durations.
        var (xt, latentMask) = sampleNoisyLatent(duration: duration, sampleRate: sampleRate,
                                                 baseChunkSize: cfgs.ae.base_chunk_size,
                                                 chunkCompress: cfgs.ttl.chunk_compress_factor,
                                                 latentDim: cfgs.ttl.latent_dim)
        // total_step is passed to the model as a per-item float tensor.
        let totalStepArray = Array(repeating: Float(totalStep), count: bsz)
        let totalStepValue = try ORTValue(tensorData: NSMutableData(bytes: totalStepArray, length: totalStepArray.count * MemoryLayout<Float>.size),
                                          elementType: .float,
                                          shape: [NSNumber(value: bsz)])
        // Denoising loop: each iteration feeds the current latent back in.
        for step in 0..<totalStep {
            let currentStepArray = Array(repeating: Float(step), count: bsz)
            let currentStepValue = try ORTValue(tensorData: NSMutableData(bytes: currentStepArray, length: currentStepArray.count * MemoryLayout<Float>.size),
                                                elementType: .float,
                                                shape: [NSNumber(value: bsz)])
            // Flatten the current latent xt into [bsz, dim, len].
            let xtFlat = xt.flatMap { $0.flatMap { $0 } }
            let xtShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: xt[0].count), NSNumber(value: xt[0][0].count)]
            let xtValue = try ORTValue(tensorData: NSMutableData(bytes: xtFlat, length: xtFlat.count * MemoryLayout<Float>.size),
                                       elementType: .float,
                                       shape: xtShape)
            // Flatten the latent mask into [bsz, 1, len].
            let latentMaskFlat = latentMask.flatMap { $0.flatMap { $0 } }
            let latentMaskShape: [NSNumber] = [NSNumber(value: bsz), 1, NSNumber(value: latentMask[0][0].count)]
            let latentMaskValue = try ORTValue(tensorData: NSMutableData(bytes: latentMaskFlat, length: latentMaskFlat.count * MemoryLayout<Float>.size),
                                               elementType: .float,
                                               shape: latentMaskShape)
            let vectorEstOutputs = try vectorEstOrt.run(withInputs: [
                "noisy_latent": xtValue,
                "text_emb": textEmbValue,
                "style_ttl": style.ttl,
                "latent_mask": latentMaskValue,
                "text_mask": textMaskValue,
                "current_step": currentStepValue,
                "total_step": totalStepValue
            ], outputNames: ["denoised_latent"], runOptions: nil)
            let denoisedData = try vectorEstOutputs["denoised_latent"]!.tensorData() as Data
            let denoisedFlat = denoisedData.withUnsafeBytes { ptr in
                Array(ptr.bindMemory(to: Float.self))
            }
            // Reshape the flat output back to [bsz][dim][len] for the next pass.
            let latentDimVal = xt[0].count
            let latentLen = xt[0][0].count
            xt = []
            var idx = 0
            for _ in 0..<bsz {
                var batch = [[Float]]()
                for _ in 0..<latentDimVal {
                    var row = [Float]()
                    for _ in 0..<latentLen {
                        row.append(denoisedFlat[idx])
                        idx += 1
                    }
                    batch.append(row)
                }
                xt.append(batch)
            }
        }
        // Decode the final latent into a waveform.
        let finalXtFlat = xt.flatMap { $0.flatMap { $0 } }
        let finalXtShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: xt[0].count), NSNumber(value: xt[0][0].count)]
        let finalXtValue = try ORTValue(tensorData: NSMutableData(bytes: finalXtFlat, length: finalXtFlat.count * MemoryLayout<Float>.size),
                                        elementType: .float,
                                        shape: finalXtShape)
        let vocoderOutputs = try vocoderOrt.run(withInputs: ["latent": finalXtValue],
                                                outputNames: ["wav_tts"],
                                                runOptions: nil)
        let wavData = try vocoderOutputs["wav_tts"]!.tensorData() as Data
        let wav = wavData.withUnsafeBytes { ptr in
            Array(ptr.bindMemory(to: Float.self))
        }
        return (wav, duration)
    }

    /// Synthesizes a single text, chunking it to keep each inference short and
    /// inserting `silenceDuration` seconds of silence between chunks.
    ///
    /// - Returns: the concatenated waveform and its total duration in seconds.
    func call(_ text: String, _ lang: String, _ style: Style, _ totalStep: Int, speed: Float = 1.05, silenceDuration: Float = 0.3) throws -> (wav: [Float], duration: Float) {
        // Korean gets shorter chunks (denser syllable-to-token mapping after
        // Jamo decomposition — NOTE(review): rationale inferred; confirm).
        let maxLen = lang == "ko" ? 120 : 300
        let chunks = chunkText(text, maxLen: maxLen)
        let langList = Array(repeating: lang, count: chunks.count)
        var wavCat = [Float]()
        var durCat: Float = 0.0
        for (i, chunk) in chunks.enumerated() {
            let result = try _infer([chunk], [langList[i]], style, totalStep, speed: speed)
            let dur = result.duration[0]
            // Trim vocoder padding down to the predicted duration.
            let wavLen = Int(Float(sampleRate) * dur)
            let wavChunk = Array(result.wav.prefix(wavLen))
            if i == 0 {
                wavCat = wavChunk
                durCat = dur
            } else {
                // Insert inter-chunk silence before appending the next chunk.
                let silenceLen = Int(silenceDuration * Float(sampleRate))
                let silence = [Float](repeating: 0.0, count: silenceLen)
                wavCat.append(contentsOf: silence)
                wavCat.append(contentsOf: wavChunk)
                durCat += silenceDuration + dur
            }
        }
        return (wavCat, durCat)
    }

    /// Batched synthesis without chunking: runs `_infer` once on the full lists.
    func batch(_ textList: [String], _ langList: [String], _ style: Style, _ totalStep: Int, speed: Float = 1.05) throws -> (wav: [Float], duration: [Float]) {
        return try _infer(textList, langList, style, totalStep, speed: speed)
    }
}
// MARK: - Component Loading Functions

/// Reads one or more voice-style JSON files and batches them into a pair of
/// ORT float tensors (TTL style and duration-predictor style).
///
/// All files are assumed to share the dims of the first file; values are
/// packed row-major into [bsz, dim1, dim2] buffers.
func loadVoiceStyle(_ voiceStylePaths: [String], verbose: Bool) throws -> Style {
    let bsz = voiceStylePaths.count

    // The first file fixes the per-item tensor dimensions for the whole batch.
    let headData = try Data(contentsOf: URL(fileURLWithPath: voiceStylePaths[0]))
    let headStyle = try JSONDecoder().decode(VoiceStyleData.self, from: headData)
    let ttlDims = headStyle.style_ttl.dims
    let dpDims = headStyle.style_dp.dims
    let ttlDim1 = ttlDims[1]
    let ttlDim2 = ttlDims[2]
    let dpDim1 = dpDims[1]
    let dpDim2 = dpDims[2]

    // Batch buffers, zero-initialized, filled one file at a time.
    var ttlFlat = [Float](repeating: 0.0, count: bsz * ttlDim1 * ttlDim2)
    var dpFlat = [Float](repeating: 0.0, count: bsz * dpDim1 * dpDim2)

    for (i, path) in voiceStylePaths.enumerated() {
        let fileData = try Data(contentsOf: URL(fileURLWithPath: path))
        let voiceStyle = try JSONDecoder().decode(VoiceStyleData.self, from: fileData)
        // Flatten the nested JSON arrays row-major into this file's slice.
        let ttlValues = voiceStyle.style_ttl.data.flatMap { $0.flatMap { $0 } }
        let ttlOffset = i * ttlDim1 * ttlDim2
        for (k, v) in ttlValues.enumerated() {
            ttlFlat[ttlOffset + k] = v
        }
        let dpValues = voiceStyle.style_dp.data.flatMap { $0.flatMap { $0 } }
        let dpOffset = i * dpDim1 * dpDim2
        for (k, v) in dpValues.enumerated() {
            dpFlat[dpOffset + k] = v
        }
    }

    let ttlShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: ttlDim1), NSNumber(value: ttlDim2)]
    let dpShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: dpDim1), NSNumber(value: dpDim2)]
    let ttlValue = try ORTValue(tensorData: NSMutableData(bytes: &ttlFlat, length: ttlFlat.count * MemoryLayout<Float>.size),
                                elementType: .float,
                                shape: ttlShape)
    let dpValue = try ORTValue(tensorData: NSMutableData(bytes: &dpFlat, length: dpFlat.count * MemoryLayout<Float>.size),
                               elementType: .float,
                               shape: dpShape)
    if verbose {
        print("Loaded \(bsz) voice styles\n")
    }
    return Style(ttl: ttlValue, dp: dpValue)
}
/// Builds a `TextToSpeech` pipeline from the ONNX models in `onnxDir`.
/// GPU execution is not implemented; passing `useGpu == true` throws.
func loadTextToSpeech(_ onnxDir: String, _ useGpu: Bool, _ env: ORTEnv) throws -> TextToSpeech {
    guard !useGpu else {
        throw NSError(domain: "TTS", code: 1, userInfo: [NSLocalizedDescriptionKey: "GPU mode is not supported yet"])
    }
    print("Using CPU for inference\n")

    let cfgs = try loadCfgs(onnxDir)
    let sessionOptions = try ORTSessionOptions()

    // One ORT session per pipeline stage, all sharing the same options.
    func makeSession(_ model: String) throws -> ORTSession {
        try ORTSession(env: env, modelPath: "\(onnxDir)/\(model).onnx", sessionOptions: sessionOptions)
    }
    let dpOrt = try makeSession("duration_predictor")
    let textEncOrt = try makeSession("text_encoder")
    let vectorEstOrt = try makeSession("vector_estimator")
    let vocoderOrt = try makeSession("vocoder")

    let textProcessor = try UnicodeProcessor(unicodeIndexerPath: "\(onnxDir)/unicode_indexer.json")
    return TextToSpeech(cfgs: cfgs, textProcessor: textProcessor,
                        dpOrt: dpOrt, textEncOrt: textEncOrt,
                        vectorEstOrt: vectorEstOrt, vocoderOrt: vocoderOrt)
}