initial commit
This commit is contained in:
835
swift/Sources/Helper.swift
Normal file
835
swift/Sources/Helper.swift
Normal file
@@ -0,0 +1,835 @@
|
||||
import Foundation
|
||||
import Accelerate
|
||||
import OnnxRuntimeBindings
|
||||
|
||||
// MARK: - Available Languages
|
||||
|
||||
/// Language codes accepted by the text front end.
let AVAILABLE_LANGS = ["en", "ko", "es", "pt", "fr"]

/// Reports whether `lang` is one of the codes listed in `AVAILABLE_LANGS`.
///
/// The comparison is exact, so it is case-sensitive.
///
/// - Parameter lang: Language code to look up, e.g. `"en"`.
/// - Returns: `true` when the code is listed, `false` otherwise.
func isValidLang(_ lang: String) -> Bool {
    for code in AVAILABLE_LANGS where code == lang {
        return true
    }
    return false
}
|
||||
|
||||
// MARK: - Configuration Structures
|
||||
|
||||
/// Model configuration decoded from JSON (this file reads it from `tts.json`).
///
/// Property names intentionally mirror the JSON keys, so the synthesized
/// coding implementation works without custom `CodingKeys`.
struct Config: Decodable, Encodable {
    /// Settings stored under the `ae` key.
    let ae: AEConfig

    /// Settings stored under the `ttl` key.
    let ttl: TTLConfig

    /// Audio-side settings.
    struct AEConfig: Decodable, Encodable {
        /// Sampling rate used when converting durations to sample counts
        /// (presumably Hz — confirm against the model documentation).
        let sample_rate: Int

        /// Base chunk size; multiplied by `chunk_compress_factor` to obtain
        /// the number of samples covered by one latent frame.
        let base_chunk_size: Int
    }

    /// Latent-side settings.
    struct TTLConfig: Decodable, Encodable {
        /// Multiplier applied to both the base chunk size and the latent dimension.
        let chunk_compress_factor: Int

        /// Latent dimension before the compression factor is applied.
        let latent_dim: Int
    }
}
|
||||
|
||||
// MARK: - Voice Style Data Structure
|
||||
|
||||
/// On-disk (JSON) representation of one voice style file.
///
/// Property names mirror the JSON keys so no custom `CodingKeys` are required.
struct VoiceStyleData: Decodable, Encodable {
    /// Style tensor stored under the `style_ttl` key.
    let style_ttl: StyleComponent

    /// Style tensor stored under the `style_dp` key.
    let style_dp: StyleComponent

    /// One serialized tensor: nested values plus its declared shape and element type.
    struct StyleComponent: Decodable, Encodable {
        /// Tensor values as a three-level nested array.
        let data: [[[Float]]]

        /// Declared shape of `data`; `loadVoiceStyle` reads entries 1 and 2.
        let dims: [Int]

        /// Element type label as written in the file (not interpreted in this file).
        let type: String
    }
}
|
||||
|
||||
// MARK: - Unicode Text Processor
|
||||
|
||||
/// Turns raw text into rows of integer IDs by looking up each Unicode scalar
/// in a table loaded from disk.
class UnicodeProcessor {
    /// Lookup table: the position is a Unicode scalar value, the element is the
    /// ID emitted for that scalar.
    let indexer: [Int64]

    /// Loads the lookup table from a JSON file containing a flat array of integers.
    ///
    /// - Parameter unicodeIndexerPath: File-system path of the JSON table.
    /// - Throws: Any error raised while reading the file or decoding the JSON.
    init(unicodeIndexerPath: String) throws {
        let tableURL = URL(fileURLWithPath: unicodeIndexerPath)
        let raw = try Data(contentsOf: tableURL)
        indexer = try JSONDecoder().decode([Int64].self, from: raw)
    }

    /// Preprocesses every text and converts it to a padded row of IDs.
    ///
    /// Lengths are counted in Unicode scalars (not `Character`s) because
    /// `preprocessText` applies NFKD decomposition.
    ///
    /// - Parameters:
    ///   - textList: Texts to convert.
    ///   - langList: Language code for each text; it is indexed in parallel with
    ///     `textList`, so it must contain at least as many entries.
    /// - Returns: `textIds`, one row per text, right-padded with `0` up to the
    ///   longest row, with `-1` for scalars that lie outside the lookup table;
    ///   and `textMask`, the result of `getTextMask` for the unpadded lengths.
    func call(_ textList: [String], _ langList: [String]) -> (textIds: [[Int64]], textMask: [[[Float]]]) {
        let scalarRows: [[Unicode.Scalar]] = textList.indices.map { position in
            let cleaned = preprocessText(textList[position], lang: langList[position])
            return Array(cleaned.unicodeScalars)
        }

        let lengths = scalarRows.map { $0.count }
        let width = lengths.max() ?? 0

        let textIds: [[Int64]] = scalarRows.map { scalars in
            var ids = [Int64](repeating: 0, count: width)
            for (column, scalar) in scalars.enumerated() {
                let codePoint = Int(scalar.value)
                // Scalars beyond the end of the table have no ID.
                ids[column] = codePoint < indexer.count ? indexer[codePoint] : -1
            }
            return ids
        }

        return (textIds, getTextMask(lengths))
    }
}
|
||||
|
||||
/// Constant tables and precompiled patterns used by `preprocessText`.
///
/// Compiling the regular expressions once avoids rebuilding them on every call.
private enum TextCleanup {
    /// One or more whitespace characters.
    /// `try!` is acceptable here: the pattern is a compile-time constant.
    static let whitespaceRun = try! NSRegularExpression(pattern: "\\s+")

    /// Matches text that already ends in punctuation, a quote or a closing bracket.
    static let terminalPunctuation = try! NSRegularExpression(pattern: "[.!?;:,'\"\\u201C\\u201D\\u2018\\u2019)\\]}…。」』】〉》›»]$")

    /// Unicode scalar ranges that are stripped from the text (emoji and pictographs).
    static let strippedScalarRanges: [ClosedRange<UInt32>] = [
        0x1F600...0x1F64F,
        0x1F300...0x1F5FF,
        0x1F680...0x1F6FF,
        0x1F700...0x1F77F,
        0x1F780...0x1F7FF,
        0x1F800...0x1F8FF,
        0x1F900...0x1F9FF,
        0x1FA00...0x1FA6F,
        0x1FA70...0x1FAFF,
        0x2600...0x26FF,
        0x2700...0x27BF,
        0x1F1E6...0x1F1FF,
    ]

    /// Character substitutions, applied in this order.
    static let characterReplacements: [(old: String, new: String)] = [
        ("–", "-"),          // en dash
        // NOTE(review): NFKD appears to rewrite U+2011 to U+2010 before this table
        // is applied, so this entry may never match — confirm.
        ("‑", "-"),          // non-breaking hyphen
        ("—", "-"),          // em dash
        ("_", " "),          // underscore
        ("\u{201C}", "\""),  // left double quote
        ("\u{201D}", "\""),  // right double quote
        ("\u{2018}", "'"),   // left single quote
        ("\u{2019}", "'"),   // right single quote
        ("´", "'"),          // acute accent
        ("`", "'"),          // grave accent
        ("[", " "),          // left bracket
        ("]", " "),          // right bracket
        ("|", " "),          // vertical bar
        ("/", " "),          // slash
        ("#", " "),          // hash
        ("→", " "),          // right arrow
        ("←", " "),          // left arrow
    ]

    /// Symbols that are removed outright.
    static let removedSymbols = ["♥", "☆", "♡", "©", "\\"]

    /// Expressions expanded into words.
    static let expressionReplacements: [(old: String, new: String)] = [
        ("@", " at "),
        ("e.g.,", "for example, "),
        ("i.e.,", "that is, "),
    ]

    /// Punctuation marks that must not be preceded by a space.
    static let tightPunctuation = [",", ".", "!", "?", ";", ":", "'"]
}

/// Normalizes `text` for the model and wraps it in language tags.
///
/// Steps, in order: NFKD decomposition, removal of emoji/pictograph scalars,
/// character substitutions, symbol removal, expression expansion, removal of the
/// space before punctuation, collapsing of repeated quotes, whitespace
/// collapsing and trimming, appending a period when the text does not already end
/// in punctuation, and finally wrapping as `<lang>…</lang>`.
///
/// - Parameters:
///   - text: Raw input text.
///   - lang: Language code; must satisfy `isValidLang`.
/// - Returns: The normalized, tagged text. Empty input yields `<lang></lang>`.
/// - Important: An unsupported `lang` terminates the process with `fatalError`.
///   The check now runs first so no work is wasted before failing.
func preprocessText(_ text: String, lang: String) -> String {
    guard isValidLang(lang) else {
        fatalError("Invalid language: \(lang). Available: \(AVAILABLE_LANGS.joined(separator: ", "))")
    }

    // NFKD (decomposed) form is used for proper Hangul Jamo decomposition.
    let decomposed = text.decomposedStringWithCompatibilityMapping

    // Emoji are filtered scalar by scalar instead of with a regular expression,
    // because the ranges lie above U+FFFF.
    var keptScalars = String.UnicodeScalarView()
    for scalar in decomposed.unicodeScalars {
        let isStripped = TextCleanup.strippedScalarRanges.contains { $0.contains(scalar.value) }
        if !isStripped {
            keptScalars.append(scalar)
        }
    }
    var text = String(keptScalars)

    for replacement in TextCleanup.characterReplacements {
        text = text.replacingOccurrences(of: replacement.old, with: replacement.new)
    }

    for symbol in TextCleanup.removedSymbols {
        text = text.replacingOccurrences(of: symbol, with: "")
    }

    for replacement in TextCleanup.expressionReplacements {
        text = text.replacingOccurrences(of: replacement.old, with: replacement.new)
    }

    // Remove the space in front of punctuation: "word ," -> "word,".
    for mark in TextCleanup.tightPunctuation {
        text = text.replacingOccurrences(of: " " + mark, with: mark)
    }

    // Collapse runs of identical quote characters into a single one.
    for quote in ["\"", "'", "`"] {
        let doubled = quote + quote
        while text.contains(doubled) {
            text = text.replacingOccurrences(of: doubled, with: quote)
        }
    }

    // Collapse whitespace runs to a single space, then trim both ends.
    let wholeText = NSRange(text.startIndex..., in: text)
    text = TextCleanup.whitespaceRun.stringByReplacingMatches(in: text, range: wholeText, withTemplate: " ")
    text = text.trimmingCharacters(in: .whitespacesAndNewlines)

    // Make sure non-empty text ends in some form of punctuation.
    if !text.isEmpty {
        let trimmedRange = NSRange(text.startIndex..., in: text)
        if TextCleanup.terminalPunctuation.firstMatch(in: text, range: trimmedRange) == nil {
            text += "."
        }
    }

    return "<\(lang)>\(text)</\(lang)>"
}
|
||||
|
||||
/// Builds a 0/1 mask for a batch of sequence lengths.
///
/// Each entry of `lengths` produces one `[1][width]` slice whose first
/// `length` positions are `1.0` and whose remaining positions are `0.0`.
///
/// - Parameters:
///   - lengths: Valid length of each sequence. Values above the width are
///     capped at the width; negative values are treated as `0` instead of
///     trapping on an invalid range.
///   - maxLen: Mask width. Defaults to the largest value in `lengths`
///     (`0` for an empty list). A negative width is treated as `0`.
/// - Returns: An array shaped `[lengths.count][1][width]`.
func lengthToMask(_ lengths: [Int], maxLen: Int? = nil) -> [[[Float]]] {
    let width = max(0, maxLen ?? (lengths.max() ?? 0))

    return lengths.map { length in
        let filled = min(max(length, 0), width)
        let row = [Float](repeating: 1.0, count: filled)
            + [Float](repeating: 0.0, count: width - filled)
        return [row]
    }
}
|
||||
|
||||
/// Builds the text mask for a batch, using the longest entry as the mask width.
///
/// - Parameter textIdsLengths: Unpadded length of each ID row.
/// - Returns: The result of `lengthToMask` with the width set to the largest
///   length (`0` when the list is empty).
func getTextMask(_ textIdsLengths: [Int]) -> [[[Float]]] {
    guard let longest = textIdsLengths.max() else {
        return lengthToMask(textIdsLengths, maxLen: 0)
    }
    return lengthToMask(textIdsLengths, maxLen: longest)
}
|
||||
|
||||
/// Draws a masked Gaussian-noise latent for a batch of durations.
///
/// Each duration is converted to a sample count (`duration * sampleRate`,
/// truncated) and then to a number of latent frames by ceiling division by
/// `baseChunkSize * chunkCompress`. Noise comes from the Box-Muller
/// transform and is multiplied by the mask, so positions beyond an item's own
/// frame count are zeroed.
///
/// - Parameters:
///   - duration: One duration per batch item.
///   - sampleRate: Multiplier that turns a duration into a sample count.
///   - baseChunkSize: Base chunk size in samples.
///   - chunkCompress: Compression factor applied to both the chunk size and
///     the latent dimension.
///   - latentDim: Latent dimension before the compression factor is applied.
/// - Returns: `noisyLatent` shaped `[batch][latentDim * chunkCompress][frames]`
///   and `latentMask` shaped `[batch][1][frames]`, where `frames` is derived
///   from the longest duration.
func sampleNoisyLatent(duration: [Float], sampleRate: Int, baseChunkSize: Int, chunkCompress: Int, latentDim: Int) -> (noisyLatent: [[[Float]]], latentMask: [[[Float]]]) {
    let rate = Float(sampleRate)
    let samplesPerFrame = baseChunkSize * chunkCompress
    let channelCount = latentDim * chunkCompress

    // Ceiling division: number of frames needed to cover `sampleCount` samples.
    func frameCount(forSamples sampleCount: Int) -> Int {
        return (sampleCount + samplesPerFrame - 1) / samplesPerFrame
    }

    let longestSampleCount = Int((duration.max() ?? 0.0) * rate)
    let totalFrames = frameCount(forSamples: longestSampleCount)
    let framesPerItem = duration.map { frameCount(forSamples: Int($0 * rate)) }

    let latentMask = lengthToMask(framesPerItem, maxLen: totalFrames)

    // One standard-normal sample via the Box-Muller transform.
    // The lower bound of 0.0001 keeps log() away from zero.
    func standardNormal() -> Float {
        let u1 = Float.random(in: 0.0001...1.0)
        let u2 = Float.random(in: 0.0...1.0)
        return sqrt(-2.0 * log(u1)) * cos(2.0 * Float.pi * u2)
    }

    let noisyLatent: [[[Float]]] = latentMask.map { itemMask in
        let keep = itemMask[0]
        return (0..<channelCount).map { _ in
            keep.map { maskValue in standardNormal() * maskValue }
        }
    }

    return (noisyLatent, latentMask)
}
|
||||
|
||||
/// Builds the latent mask for a batch of waveform lengths.
///
/// Each waveform length is converted to a frame count by ceiling division by
/// `base_chunk_size * chunk_compress_factor`; the mask width is the largest
/// resulting frame count.
///
/// - Parameters:
///   - wavLengths: Waveform length of each batch item, in samples.
///   - cfgs: Configuration supplying the chunk size and compression factor.
/// - Returns: The result of `lengthToMask` for the computed frame counts.
func getLatentMask(_ wavLengths: [Int64], _ cfgs: Config) -> [[[Float]]] {
    let samplesPerFrame = cfgs.ae.base_chunk_size * cfgs.ttl.chunk_compress_factor
    let frameCounts = wavLengths.map { sampleCount in
        (Int(sampleCount) + samplesPerFrame - 1) / samplesPerFrame
    }
    return lengthToMask(frameCounts, maxLen: frameCounts.max() ?? 0)
}
|
||||
|
||||
// MARK: - WAV File I/O
|
||||
|
||||
/// Writes mono 16-bit PCM audio to `filename` as a RIFF/WAVE file.
///
/// Samples are clamped to `[-1, 1]` and scaled by 32767. Every multi-byte
/// field, including the samples themselves, is written explicitly in
/// little-endian order, so the output does not depend on the host byte order.
///
/// - Parameters:
///   - filename: Destination path; an existing file is overwritten.
///   - audioData: Samples, nominally in the range `[-1, 1]`.
///   - sampleRate: Sample rate written to the header.
/// - Throws: An `NSError` in the `"TTS"` domain when `sampleRate` or the
///   payload size cannot be represented in the 32-bit header fields (previously
///   a runtime trap), or any error raised while writing the file.
func writeWavFile(_ filename: String, _ audioData: [Float], _ sampleRate: Int) throws {
    func headerError(_ message: String) -> NSError {
        return NSError(domain: "TTS", code: 4, userInfo: [NSLocalizedDescriptionKey: message])
    }

    let numChannels: UInt16 = 1
    let bitsPerSample: UInt16 = 16
    let blockAlign = numChannels * bitsPerSample / 8

    guard let rate = UInt32(exactly: sampleRate) else {
        throw headerError("Sample rate \(sampleRate) cannot be stored in a WAV header")
    }
    let (byteRate, byteRateOverflow) = rate.multipliedReportingOverflow(by: UInt32(blockAlign))
    guard !byteRateOverflow else {
        throw headerError("Sample rate \(sampleRate) is too large for a WAV header")
    }

    let payloadBytes = audioData.count * Int(blockAlign)
    // The RIFF chunk size field stores 36 + payload, so leave room for it.
    guard let dataSize = UInt32(exactly: payloadBytes), dataSize <= UInt32.max - 36 else {
        throw headerError("Audio data is too large for a WAV file")
    }

    // Convert float samples to little-endian int16.
    let pcm = audioData.map { sample -> Int16 in
        let clamped = max(-1.0, min(1.0, sample))
        return Int16(clamped * 32767.0).littleEndian
    }

    var wav = Data()
    wav.reserveCapacity(44 + payloadBytes)

    // Appends an integer field in little-endian byte order.
    func appendLittleEndian<T: FixedWidthInteger>(_ value: T) {
        withUnsafeBytes(of: value.littleEndian) { wav.append(contentsOf: $0) }
    }

    // RIFF chunk
    wav.append(contentsOf: "RIFF".utf8)
    appendLittleEndian(UInt32(36) + dataSize)
    wav.append(contentsOf: "WAVE".utf8)

    // fmt chunk
    wav.append(contentsOf: "fmt ".utf8)
    appendLittleEndian(UInt32(16))   // fmt chunk size
    appendLittleEndian(UInt16(1))    // PCM
    appendLittleEndian(numChannels)
    appendLittleEndian(rate)
    appendLittleEndian(byteRate)
    appendLittleEndian(blockAlign)
    appendLittleEndian(bitsPerSample)

    // data chunk
    wav.append(contentsOf: "data".utf8)
    appendLittleEndian(dataSize)
    pcm.withUnsafeBytes { wav.append(contentsOf: $0) }

    try wav.write(to: URL(fileURLWithPath: filename))
}
|
||||
|
||||
// MARK: - Text Chunking
|
||||
|
||||
/// Default upper bound, in characters, for a chunk produced by `chunkText`.
let MAX_CHUNK_LENGTH: Int = 300

/// Abbreviations whose trailing period `splitSentences` does not treat as the
/// end of a sentence.
let ABBREVIATIONS: [String] = [
    "Dr.",
    "Mr.",
    "Mrs.",
    "Ms.",
    "Prof.",
    "Sr.",
    "Jr.",
    "St.",
    "Ave.",
    "Rd.",
    "Blvd.",
    "Dept.",
    "Inc.",
    "Ltd.",
    "Co.",
    "Corp.",
    "etc.",
    "vs.",
    "i.e.",
    "e.g.",
    "Ph.D.",
]
|
||||
|
||||
/// Blank-line paragraph separator, compiled once.
/// `try!` is acceptable here: the pattern is a compile-time constant.
private let paragraphBreakRegex = try! NSRegularExpression(pattern: "\\n\\s*\\n")

/// Splits `text` into chunks of at most `maxLen` characters.
///
/// The text is split into paragraphs on blank lines. A paragraph that fits is
/// emitted whole. Otherwise it is split into sentences with `splitSentences`
/// and the sentences are packed greedily, joined by a space. A sentence that is
/// itself too long is split on commas (rejoined with `", "`), and a comma part
/// that is still too long is split on whitespace. Chunks never span paragraphs.
///
/// Differences from the previous implementation:
/// - word-level chunks of an oversized comma part are emitted *after* the text
///   that precedes them, so the output keeps the input order;
/// - the `", "` separator is counted as two characters, so comma-joined chunks
///   no longer exceed the limit by one.
///
/// - Parameters:
///   - text: Text to split.
///   - maxLen: Maximum chunk length in characters; values `<= 0` select
///     `MAX_CHUNK_LENGTH`.
/// - Returns: The chunks in input order, or `[""]` when there is no content.
///   A single word longer than the limit is emitted unsplit.
func chunkText(_ text: String, maxLen: Int = 0) -> [String] {
    let limit = maxLen > 0 ? maxLen : MAX_CHUNK_LENGTH
    let trimmedText = text.trimmingCharacters(in: .whitespacesAndNewlines)
    guard !trimmedText.isEmpty else {
        return [""]
    }

    // Split into paragraphs on blank lines.
    var paragraphs = [String]()
    var cursor = trimmedText.startIndex
    let fullRange = NSRange(trimmedText.startIndex..., in: trimmedText)
    for match in paragraphBreakRegex.matches(in: trimmedText, range: fullRange) {
        guard let separator = Range(match.range, in: trimmedText) else { continue }
        paragraphs.append(String(trimmedText[cursor..<separator.lowerBound]))
        cursor = separator.upperBound
    }
    if cursor < trimmedText.endIndex {
        paragraphs.append(String(trimmedText[cursor...]))
    }
    if paragraphs.isEmpty {
        paragraphs = [trimmedText]
    }

    var chunks = [String]()
    // Chunk currently being assembled, with its running length in characters.
    var current = ""
    var currentLen = 0

    // Emits the pending chunk, if any, and starts a new one.
    func flushCurrent() {
        guard !current.isEmpty else { return }
        chunks.append(current.trimmingCharacters(in: .whitespacesAndNewlines))
        current = ""
        currentLen = 0
    }

    // Appends `piece` to the pending chunk, flushing first when the piece plus
    // its separator would push the chunk past the limit.
    func add(_ piece: String, separator: String) {
        let pieceLen = piece.count
        if !current.isEmpty && currentLen + separator.count + pieceLen > limit {
            flushCurrent()
        }
        if !current.isEmpty {
            current += separator
            currentLen += separator.count
        }
        current += piece
        currentLen += pieceLen
    }

    for paragraph in paragraphs {
        let trimmedPara = paragraph.trimmingCharacters(in: .whitespacesAndNewlines)
        if trimmedPara.isEmpty {
            continue
        }
        if trimmedPara.count <= limit {
            chunks.append(trimmedPara)
            continue
        }

        for sentence in splitSentences(trimmedPara) {
            let trimmedSentence = sentence.trimmingCharacters(in: .whitespacesAndNewlines)
            if trimmedSentence.isEmpty {
                continue
            }
            if trimmedSentence.count <= limit {
                add(trimmedSentence, separator: " ")
                continue
            }

            // The sentence alone exceeds the limit: close the pending chunk and
            // fall back to comma-separated parts.
            flushCurrent()
            for part in trimmedSentence.components(separatedBy: ",") {
                let trimmedPart = part.trimmingCharacters(in: .whitespacesAndNewlines)
                if trimmedPart.isEmpty {
                    continue
                }
                if trimmedPart.count <= limit {
                    add(trimmedPart, separator: ", ")
                    continue
                }

                // Last resort: split on whitespace. Flush first so earlier
                // parts are emitted before these words, keeping input order.
                flushCurrent()
                let words = trimmedPart.components(separatedBy: .whitespaces)
                for word in words where !word.isEmpty {
                    add(word, separator: " ")
                }
                flushCurrent()
            }
        }

        // Chunks never span paragraphs.
        flushCurrent()
    }

    return chunks.isEmpty ? [""] : chunks
}
|
||||
|
||||
/// Candidate sentence boundary: one of `.`, `!`, `?` followed by whitespace.
/// `try!` is acceptable here: the pattern is a compile-time constant.
private let sentenceBoundaryRegex = try! NSRegularExpression(pattern: "([.!?])\\s+")

/// Returns whether `candidate` ends with an entry of `ABBREVIATIONS` that
/// stands as its own word, i.e. is not preceded by a letter or digit.
///
/// The word-boundary requirement stops "ATMs." or "LLMs." from being mistaken
/// for the abbreviation "Ms.".
private func endsWithKnownAbbreviation(_ candidate: String) -> Bool {
    for abbreviation in ABBREVIATIONS where candidate.hasSuffix(abbreviation) {
        guard let preceding = candidate.dropLast(abbreviation.count).last else {
            // The abbreviation is the entire candidate.
            return true
        }
        if !(preceding.isLetter || preceding.isNumber) {
            return true
        }
    }
    return false
}

/// Splits `text` into sentences.
///
/// Instead of a lookbehind pattern, every `.`, `!` or `?` followed by
/// whitespace is treated as a candidate boundary, and candidates that
/// terminate a known abbreviation (see `ABBREVIATIONS`) are skipped.
///
/// - Parameter text: Text to split.
/// - Returns: The sentences in order. Each one except possibly the last keeps
///   its trailing punctuation and whitespace; concatenating the result
///   reproduces `text`. When no boundary is found the result is `[text]`.
func splitSentences(_ text: String) -> [String] {
    let fullRange = NSRange(text.startIndex..., in: text)
    let matches = sentenceBoundaryRegex.matches(in: text, range: fullRange)
    guard !matches.isEmpty else {
        return [text]
    }

    var sentences = [String]()
    var sentenceStart = text.startIndex

    for match in matches {
        // Group 1 is the punctuation mark; the whole match also covers the
        // whitespace that follows it.
        guard let boundary = Range(match.range, in: text),
              let punctuation = Range(match.range(at: 1), in: text) else { continue }

        let beforePunctuation = String(text[sentenceStart..<punctuation.lowerBound])
        let candidate = beforePunctuation.trimmingCharacters(in: .whitespaces) + String(text[punctuation])

        if endsWithKnownAbbreviation(candidate) {
            // Not a real sentence end; keep accumulating.
            continue
        }

        sentences.append(String(text[sentenceStart..<boundary.upperBound]))
        sentenceStart = boundary.upperBound
    }

    // Whatever follows the last accepted boundary is the final sentence.
    if sentenceStart < text.endIndex {
        sentences.append(String(text[sentenceStart...]))
    }

    return sentences.isEmpty ? [text] : sentences
}
|
||||
|
||||
// MARK: - Utility Functions
|
||||
|
||||
/// Runs `f`, printing a start line before it and the elapsed wall-clock time
/// after it.
///
/// - Parameters:
///   - name: Label used in both printed lines.
///   - f: Work to measure.
/// - Returns: Whatever `f` returns.
/// - Throws: Rethrows any error from `f`; in that case no completion line is
///   printed.
func timer<T>(_ name: String, _ f: () throws -> T) rethrows -> T {
    let startedAt = Date()
    print("\(name)...")

    let value = try f()

    let seconds = Date().timeIntervalSince(startedAt)
    let summary = String(format: " -> %@ completed in %.2f sec", name, seconds)
    print(summary)
    return value
}
|
||||
|
||||
/// Produces a filename-safe version of `text`.
///
/// The text is cut to at most `maxLen` characters and every character that is
/// neither a letter nor a number is replaced by `_`.
///
/// - Parameters:
///   - text: Source text.
///   - maxLen: Maximum number of characters to keep. Negative values are
///     treated as `0` instead of trapping.
/// - Returns: The truncated, sanitized text.
func sanitizeFilename(_ text: String, maxLen: Int) -> String {
    // `prefix` already returns the whole string when it is shorter than the limit.
    let kept = text.prefix(max(0, maxLen))
    let sanitized = kept.map { (character: Character) -> Character in
        return (character.isLetter || character.isNumber) ? character : "_"
    }
    return String(sanitized)
}
|
||||
|
||||
/// Reads and decodes `tts.json` from the given directory.
///
/// - Parameter onnxDir: Directory that contains `tts.json`.
/// - Returns: The decoded configuration.
/// - Throws: Any error raised while reading the file or decoding the JSON.
func loadCfgs(_ onnxDir: String) throws -> Config {
    let configURL = URL(fileURLWithPath: "\(onnxDir)/tts.json")
    let raw = try Data(contentsOf: configURL)
    return try JSONDecoder().decode(Config.self, from: raw)
}
|
||||
|
||||
// MARK: - ONNX Runtime Integration
|
||||
|
||||
/// Voice style tensors handed to the ONNX sessions.
struct Style {
    /// Tensor supplied as the `style_ttl` input of the text encoder and the
    /// vector estimator.
    let ttl: ORTValue

    /// Tensor supplied as the `style_dp` input of the duration predictor.
    let dp: ORTValue
}
|
||||
|
||||
/// Text-to-speech pipeline built from four ONNX sessions: a duration
/// predictor, a text encoder, a vector estimator run in a denoising loop, and
/// a vocoder.
class TextToSpeech {
    /// Model configuration.
    let cfgs: Config
    /// Converts text into ID rows and masks.
    let textProcessor: UnicodeProcessor
    /// Session producing the `duration` output.
    let dpOrt: ORTSession
    /// Session producing the `text_emb` output.
    let textEncOrt: ORTSession
    /// Session producing the `denoised_latent` output.
    let vectorEstOrt: ORTSession
    /// Session producing the `wav_tts` output.
    let vocoderOrt: ORTSession
    /// Copy of `cfgs.ae.sample_rate`.
    let sampleRate: Int

    init(cfgs: Config, textProcessor: UnicodeProcessor,
         dpOrt: ORTSession, textEncOrt: ORTSession,
         vectorEstOrt: ORTSession, vocoderOrt: ORTSession) {
        self.cfgs = cfgs
        self.textProcessor = textProcessor
        self.dpOrt = dpOrt
        self.textEncOrt = textEncOrt
        self.vectorEstOrt = vectorEstOrt
        self.vocoderOrt = vocoderOrt
        self.sampleRate = cfgs.ae.sample_rate
    }

    // MARK: Tensor helpers

    /// Builds the error thrown for invalid arguments or unexpected model output.
    private static func inferenceError(_ message: String) -> NSError {
        return NSError(domain: "TTS", code: 2, userInfo: [NSLocalizedDescriptionKey: message])
    }

    /// Wraps a flat `Float` buffer in a tensor of the given shape.
    private static func floatTensor(_ values: [Float], shape: [Int]) throws -> ORTValue {
        let bytes = NSMutableData(bytes: values, length: values.count * MemoryLayout<Float>.size)
        return try ORTValue(tensorData: bytes,
                            elementType: .float,
                            shape: shape.map { NSNumber(value: $0) })
    }

    /// Wraps a flat `Int64` buffer in a tensor of the given shape.
    private static func int64Tensor(_ values: [Int64], shape: [Int]) throws -> ORTValue {
        let bytes = NSMutableData(bytes: values, length: values.count * MemoryLayout<Int64>.size)
        return try ORTValue(tensorData: bytes,
                            elementType: .int64,
                            shape: shape.map { NSNumber(value: $0) })
    }

    /// Reads the named output as a flat `Float` array, throwing instead of
    /// crashing when the output is missing.
    private static func floatOutput(_ name: String, in outputs: [String: ORTValue]) throws -> [Float] {
        guard let value = outputs[name] else {
            throw inferenceError("Model did not return an output named '\(name)'")
        }
        let raw = try value.tensorData() as Data
        return raw.withUnsafeBytes { ptr in
            Array(ptr.bindMemory(to: Float.self))
        }
    }

    // MARK: Inference

    /// Runs the full pipeline for a batch of texts.
    ///
    /// - Parameters:
    ///   - textList: Texts to synthesize; must not be empty.
    ///   - langList: Language code per text; needs at least `textList.count`
    ///     entries, each accepted by `isValidLang`.
    ///   - style: Voice style tensors.
    ///   - totalStep: Number of denoising iterations; must not be negative.
    ///   - speed: Predicted durations are divided by this value; must be
    ///     positive and finite.
    /// - Returns: The flat vocoder output and the speed-adjusted duration of
    ///   each batch item.
    /// - Throws: An `NSError` in the `"TTS"` domain for invalid arguments or
    ///   unexpected model output, plus any error raised by ONNX Runtime.
    private func _infer(_ textList: [String], _ langList: [String], _ style: Style, _ totalStep: Int, speed: Float = 1.05) throws -> (wav: [Float], duration: [Float]) {
        let bsz = textList.count

        // Validate up front: each of these used to end in a runtime trap.
        guard bsz > 0 else {
            throw TextToSpeech.inferenceError("At least one text is required")
        }
        guard langList.count >= bsz else {
            throw TextToSpeech.inferenceError("Expected \(bsz) language codes but received \(langList.count)")
        }
        if let unsupported = langList.prefix(bsz).first(where: { !isValidLang($0) }) {
            throw TextToSpeech.inferenceError("Invalid language: \(unsupported). Available: \(AVAILABLE_LANGS.joined(separator: ", "))")
        }
        guard speed.isFinite, speed > 0 else {
            throw TextToSpeech.inferenceError("Speed must be a positive, finite number")
        }
        guard totalStep >= 0 else {
            throw TextToSpeech.inferenceError("Total step count must not be negative")
        }

        // Process text
        let (textIds, textMask) = textProcessor.call(textList, langList)
        let textIdsValue = try TextToSpeech.int64Tensor(textIds.flatMap { $0 },
                                                        shape: [bsz, textIds[0].count])
        let textMaskValue = try TextToSpeech.floatTensor(textMask.flatMap { $0.flatMap { $0 } },
                                                         shape: [bsz, 1, textMask[0][0].count])

        // Predict duration and apply the speed factor
        let dpOutputs = try dpOrt.run(withInputs: ["text_ids": textIdsValue, "style_dp": style.dp, "text_mask": textMaskValue],
                                      outputNames: ["duration"],
                                      runOptions: nil)
        let duration = try TextToSpeech.floatOutput("duration", in: dpOutputs).map { $0 / speed }
        guard duration.count == bsz else {
            throw TextToSpeech.inferenceError("Expected \(bsz) durations but the model returned \(duration.count)")
        }

        // Encode text
        let textEncOutputs = try textEncOrt.run(withInputs: ["text_ids": textIdsValue, "style_ttl": style.ttl, "text_mask": textMaskValue],
                                                outputNames: ["text_emb"],
                                                runOptions: nil)
        guard let textEmbValue = textEncOutputs["text_emb"] else {
            throw TextToSpeech.inferenceError("Model did not return an output named 'text_emb'")
        }

        // Sample noisy latent
        let (noisyLatent, latentMask) = sampleNoisyLatent(duration: duration, sampleRate: sampleRate,
                                                          baseChunkSize: cfgs.ae.base_chunk_size,
                                                          chunkCompress: cfgs.ttl.chunk_compress_factor,
                                                          latentDim: cfgs.ttl.latent_dim)

        // The latent is kept as one flat buffer plus its shape; the models
        // consume and produce flat data, so rebuilding nested arrays on every
        // step is unnecessary.
        let latentShape = [bsz, noisyLatent[0].count, noisyLatent[0].first?.count ?? 0]
        var xtFlat = noisyLatent.flatMap { $0.flatMap { $0 } }
        let latentCount = xtFlat.count

        // Tensors that do not change between denoising steps
        let latentMaskValue = try TextToSpeech.floatTensor(latentMask.flatMap { $0.flatMap { $0 } },
                                                           shape: [bsz, 1, latentMask[0][0].count])
        let totalStepValue = try TextToSpeech.floatTensor([Float](repeating: Float(totalStep), count: bsz),
                                                          shape: [bsz])

        // Denoising loop
        for step in 0..<totalStep {
            let currentStepValue = try TextToSpeech.floatTensor([Float](repeating: Float(step), count: bsz),
                                                                shape: [bsz])
            let xtValue = try TextToSpeech.floatTensor(xtFlat, shape: latentShape)

            let vectorEstOutputs = try vectorEstOrt.run(withInputs: [
                "noisy_latent": xtValue,
                "text_emb": textEmbValue,
                "style_ttl": style.ttl,
                "latent_mask": latentMaskValue,
                "text_mask": textMaskValue,
                "current_step": currentStepValue,
                "total_step": totalStepValue
            ], outputNames: ["denoised_latent"], runOptions: nil)

            let denoised = try TextToSpeech.floatOutput("denoised_latent", in: vectorEstOutputs)
            guard denoised.count >= latentCount else {
                throw TextToSpeech.inferenceError("Denoised latent has \(denoised.count) values; expected \(latentCount)")
            }
            // Only the leading values that fill the latent shape are used.
            xtFlat = denoised.count == latentCount ? denoised : Array(denoised.prefix(latentCount))
        }

        // Generate waveform
        let finalXtValue = try TextToSpeech.floatTensor(xtFlat, shape: latentShape)
        let vocoderOutputs = try vocoderOrt.run(withInputs: ["latent": finalXtValue],
                                                outputNames: ["wav_tts"],
                                                runOptions: nil)
        let wav = try TextToSpeech.floatOutput("wav_tts", in: vocoderOutputs)

        return (wav, duration)
    }

    /// Synthesizes one text, chunking it first and joining the chunk audio with
    /// silence.
    ///
    /// The text is chunked with `chunkText` using a limit of 120 characters for
    /// `"ko"` and 300 otherwise. Each chunk is synthesized separately and
    /// trimmed to its predicted duration.
    ///
    /// - Parameters:
    ///   - text: Text to synthesize.
    ///   - lang: Language code.
    ///   - style: Voice style tensors.
    ///   - totalStep: Number of denoising iterations.
    ///   - speed: Predicted durations are divided by this value.
    ///   - silenceDuration: Silence inserted between chunks, in the same unit
    ///     as the predicted durations. Negative or non-finite values are
    ///     treated as `0` instead of trapping.
    /// - Returns: The concatenated samples and the total duration, including
    ///   the inserted silences.
    /// - Throws: Any error thrown during inference.
    func call(_ text: String, _ lang: String, _ style: Style, _ totalStep: Int, speed: Float = 1.05, silenceDuration: Float = 0.3) throws -> (wav: [Float], duration: Float) {
        let maxLen = lang == "ko" ? 120 : 300
        let chunks = chunkText(text, maxLen: maxLen)

        let pause: Float = silenceDuration.isFinite ? max(0, silenceDuration) : 0
        let silence = [Float](repeating: 0.0, count: Int(pause * Float(sampleRate)))

        var wavCat = [Float]()
        var durCat: Float = 0.0

        for (i, chunk) in chunks.enumerated() {
            let result = try _infer([chunk], [lang], style, totalStep, speed: speed)

            // `_infer` guarantees one duration per input text.
            let dur = result.duration[0]
            let wavLen = max(0, Int(Float(sampleRate) * dur))
            let wavChunk = result.wav.prefix(wavLen)

            if i == 0 {
                durCat = dur
            } else {
                wavCat.append(contentsOf: silence)
                durCat += pause + dur
            }
            wavCat.append(contentsOf: wavChunk)
        }

        return (wavCat, durCat)
    }

    /// Synthesizes several texts in a single batch, without chunking.
    ///
    /// - Returns: The flat vocoder output for the whole batch and the
    ///   speed-adjusted duration of each text.
    /// - Throws: Any error thrown during inference.
    func batch(_ textList: [String], _ langList: [String], _ style: Style, _ totalStep: Int, speed: Float = 1.05) throws -> (wav: [Float], duration: [Float]) {
        return try _infer(textList, langList, style, totalStep, speed: speed)
    }
}
|
||||
|
||||
// MARK: - Component Loading Functions
|
||||
|
||||
/// Loads one or more voice style files and stacks them into batched tensors.
///
/// Entries 1 and 2 of each `dims` array in the first file define the
/// per-style shape. Every file must contain exactly that many values for each
/// component, so a malformed file is reported as an error instead of writing
/// outside its slot. Each file is read and decoded once.
///
/// - Parameters:
///   - voiceStylePaths: Paths of the style JSON files; must not be empty.
///   - verbose: When `true`, prints how many styles were loaded.
/// - Returns: A `Style` whose tensors are shaped `[count, dim1, dim2]`.
/// - Throws: An `NSError` in the `"TTS"` domain for an empty path list or
///   inconsistent files, plus any file, JSON or ONNX Runtime error.
func loadVoiceStyle(_ voiceStylePaths: [String], verbose: Bool) throws -> Style {
    func styleError(_ message: String) -> NSError {
        return NSError(domain: "TTS", code: 3, userInfo: [NSLocalizedDescriptionKey: message])
    }

    let bsz = voiceStylePaths.count
    guard bsz > 0 else {
        throw styleError("At least one voice style path is required")
    }

    let decoder = JSONDecoder()
    var ttlDim1 = 0, ttlDim2 = 0
    var dpDim1 = 0, dpDim2 = 0
    var ttlFlat = [Float]()
    var dpFlat = [Float]()

    for (i, path) in voiceStylePaths.enumerated() {
        let data = try Data(contentsOf: URL(fileURLWithPath: path))
        let voiceStyle = try decoder.decode(VoiceStyleData.self, from: data)

        if i == 0 {
            // The first file fixes the shape shared by the whole batch.
            let ttlDims = voiceStyle.style_ttl.dims
            let dpDims = voiceStyle.style_dp.dims
            guard ttlDims.count >= 3, dpDims.count >= 3 else {
                throw styleError("Voice style '\(path)' must declare at least three dimensions")
            }
            ttlDim1 = ttlDims[1]
            ttlDim2 = ttlDims[2]
            dpDim1 = dpDims[1]
            dpDim2 = dpDims[2]
            guard ttlDim1 >= 0, ttlDim2 >= 0, dpDim1 >= 0, dpDim2 >= 0 else {
                throw styleError("Voice style '\(path)' declares a negative dimension")
            }
            ttlFlat.reserveCapacity(bsz * ttlDim1 * ttlDim2)
            dpFlat.reserveCapacity(bsz * dpDim1 * dpDim2)
        }

        let ttlValues = voiceStyle.style_ttl.data.flatMap { $0.flatMap { $0 } }
        guard ttlValues.count == ttlDim1 * ttlDim2 else {
            throw styleError("Voice style '\(path)' has \(ttlValues.count) style_ttl values; expected \(ttlDim1 * ttlDim2)")
        }
        ttlFlat.append(contentsOf: ttlValues)

        let dpValues = voiceStyle.style_dp.data.flatMap { $0.flatMap { $0 } }
        guard dpValues.count == dpDim1 * dpDim2 else {
            throw styleError("Voice style '\(path)' has \(dpValues.count) style_dp values; expected \(dpDim1 * dpDim2)")
        }
        dpFlat.append(contentsOf: dpValues)
    }

    let ttlShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: ttlDim1), NSNumber(value: ttlDim2)]
    let dpShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: dpDim1), NSNumber(value: dpDim2)]

    let ttlValue = try ORTValue(tensorData: NSMutableData(bytes: &ttlFlat, length: ttlFlat.count * MemoryLayout<Float>.size),
                                elementType: .float,
                                shape: ttlShape)
    let dpValue = try ORTValue(tensorData: NSMutableData(bytes: &dpFlat, length: dpFlat.count * MemoryLayout<Float>.size),
                               elementType: .float,
                               shape: dpShape)

    if verbose {
        print("Loaded \(bsz) voice styles\n")
    }

    return Style(ttl: ttlValue, dp: dpValue)
}
|
||||
|
||||
/// Loads the configuration, the four ONNX models and the Unicode indexer from
/// `onnxDir` and assembles a `TextToSpeech` instance.
///
/// - Parameters:
///   - onnxDir: Directory containing `tts.json`, the `.onnx` model files and
///     `unicode_indexer.json`.
///   - useGpu: Must be `false`; GPU mode is rejected with an error.
///   - env: ONNX Runtime environment used to create the sessions.
/// - Returns: A ready-to-use `TextToSpeech` instance.
/// - Throws: An `NSError` in the `"TTS"` domain when `useGpu` is `true`, plus
///   any error raised while loading the configuration, models or indexer.
func loadTextToSpeech(_ onnxDir: String, _ useGpu: Bool, _ env: ORTEnv) throws -> TextToSpeech {
    guard !useGpu else {
        throw NSError(domain: "TTS", code: 1, userInfo: [NSLocalizedDescriptionKey: "GPU mode is not supported yet"])
    }
    print("Using CPU for inference\n")

    let cfgs = try loadCfgs(onnxDir)
    let sessionOptions = try ORTSessionOptions()

    // All sessions share the same environment and options.
    func openSession(at modelPath: String) throws -> ORTSession {
        return try ORTSession(env: env, modelPath: modelPath, sessionOptions: sessionOptions)
    }

    let dpOrt = try openSession(at: "\(onnxDir)/duration_predictor.onnx")
    let textEncOrt = try openSession(at: "\(onnxDir)/text_encoder.onnx")
    let vectorEstOrt = try openSession(at: "\(onnxDir)/vector_estimator.onnx")
    let vocoderOrt = try openSession(at: "\(onnxDir)/vocoder.onnx")

    let textProcessor = try UnicodeProcessor(unicodeIndexerPath: "\(onnxDir)/unicode_indexer.json")

    return TextToSpeech(cfgs: cfgs, textProcessor: textProcessor,
                        dpOrt: dpOrt, textEncOrt: textEncOrt,
                        vectorEstOrt: vectorEstOrt, vocoderOrt: vocoderOrt)
}
|
||||
Reference in New Issue
Block a user