initial commit
This commit is contained in:
15
swift/.gitignore
vendored
Normal file
15
swift/.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# Swift Package Manager
|
||||
.build/
|
||||
.swiftpm/
|
||||
*.xcodeproj
|
||||
*.xcworkspace
|
||||
|
||||
# Build artifacts
|
||||
example_onnx
|
||||
|
||||
# Results
|
||||
results/*.wav
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
14
swift/Package.resolved
Normal file
14
swift/Package.resolved
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"pins" : [
|
||||
{
|
||||
"identity" : "onnxruntime-swift-package-manager",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/microsoft/onnxruntime-swift-package-manager.git",
|
||||
"state" : {
|
||||
"revision" : "12ce7374c86944e1f68f3a866d10105d8357f074",
|
||||
"version" : "1.20.0"
|
||||
}
|
||||
}
|
||||
],
|
||||
"version" : 2
|
||||
}
|
||||
22
swift/Package.swift
Normal file
22
swift/Package.swift
Normal file
@@ -0,0 +1,22 @@
|
||||
// swift-tools-version: 5.9
import PackageDescription

// Package manifest for the Supertonic TTS example executable.
// Declares a single executable target linking the ONNX Runtime Swift
// bindings published by Microsoft.
let package = Package(
    name: "Supertonic",
    platforms: [
        // README documents macOS 13.0+ as the minimum supported platform.
        .macOS(.v13)
    ],
    dependencies: [
        // ONNX Runtime bindings; Package.resolved currently pins 1.20.0.
        .package(url: "https://github.com/microsoft/onnxruntime-swift-package-manager.git", from: "1.16.0"),
    ],
    targets: [
        .executableTarget(
            name: "example_onnx",
            dependencies: [
                .product(name: "onnxruntime", package: "onnxruntime-swift-package-manager")
            ],
            path: "Sources"
        )
    ]
)
|
||||
|
||||
122
swift/README.md
Normal file
122
swift/README.md
Normal file
@@ -0,0 +1,122 @@
|
||||
# TTS ONNX Inference Examples
|
||||
|
||||
This guide provides examples for running TTS inference using `example_onnx`.
|
||||
|
||||
## 📰 Update News
|
||||
|
||||
**2026.01.06** - 🎉 **Supertonic 2** released with multilingual support! Now supports English (`en`), Korean (`ko`), Spanish (`es`), Portuguese (`pt`), and French (`fr`). [Demo](https://huggingface.co/spaces/Supertone/supertonic-2) | [Models](https://huggingface.co/Supertone/supertonic-2)
|
||||
|
||||
**2025.12.10** - Added [6 new voice styles](https://huggingface.co/Supertone/supertonic/tree/b10dbaf18b316159be75b34d24f740008fddd381) (M3, M4, M5, F3, F4, F5). See [Voices](https://supertone-inc.github.io/supertonic-py/voices/) for details
|
||||
|
||||
**2025.12.08** - Optimized ONNX models via [OnnxSlim](https://github.com/inisis/OnnxSlim) now available on [Hugging Face Models](https://huggingface.co/Supertone/supertonic)
|
||||
|
||||
**2025.11.23** - Enhanced text preprocessing with comprehensive normalization, emoji removal, symbol replacement, and punctuation handling for improved synthesis quality.
|
||||
|
||||
**2025.11.19** - Added `--speed` parameter to control speech synthesis speed (default: 1.05, recommended range: 0.9-1.5).
|
||||
|
||||
**2025.11.19** - Added automatic text chunking for long-form inference. Long texts are split into chunks and synthesized with natural pauses.
|
||||
|
||||
## Installation
|
||||
|
||||
This project uses Swift Package Manager (SPM) for dependency management.
|
||||
|
||||
### Prerequisites
|
||||
- Swift 5.9 or later
|
||||
- macOS 13.0 or later
|
||||
|
||||
### Build the project
|
||||
```bash
|
||||
swift build -c release
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Example 1: Default Inference
|
||||
Run inference with default settings:
|
||||
```bash
|
||||
.build/release/example_onnx
|
||||
```
|
||||
|
||||
This will use:
|
||||
- Voice style: `assets/voice_styles/M1.json`
|
||||
- Text: "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen."
|
||||
- Output directory: `results/`
|
||||
- Total steps: 5
|
||||
- Number of generations: 4
|
||||
|
||||
### Example 2: Batch Inference
|
||||
Process multiple voice styles and texts at once:
|
||||
```bash
|
||||
.build/release/example_onnx \
|
||||
--batch \
|
||||
--voice-style assets/voice_styles/M1.json,assets/voice_styles/F1.json \
|
||||
--text "The sun sets behind the mountains, painting the sky in shades of pink and orange.|오늘 아침에 공원을 산책했는데, 새소리와 바람 소리가 너무 기분 좋았어요." \
|
||||
--lang en,ko
|
||||
```
|
||||
|
||||
This will:
|
||||
- Generate speech for 2 different voice-text-language triplets
|
||||
- Use male voice (M1.json) for the first English text
|
||||
- Use female voice (F1.json) for the second Korean text
|
||||
- Process both samples in a single batch
|
||||
|
||||
### Example 3: High Quality Inference
|
||||
Increase denoising steps for better quality:
|
||||
```bash
|
||||
.build/release/example_onnx \
|
||||
--total-step 10 \
|
||||
--voice-style assets/voice_styles/M1.json \
|
||||
--text "Increasing the number of denoising steps improves the output's fidelity and overall quality."
|
||||
```
|
||||
|
||||
This will:
|
||||
- Use 10 denoising steps instead of the default 5
|
||||
- Produce higher quality output at the cost of slower inference
|
||||
|
||||
### Example 4: Long-Form Inference
|
||||
The system automatically chunks long texts into manageable segments, synthesizes each segment separately, and concatenates them with natural pauses (0.3 seconds by default) into a single audio file. This happens by default when you don't use the `--batch` flag:
|
||||
|
||||
```bash
|
||||
.build/release/example_onnx \
|
||||
--voice-style assets/voice_styles/M1.json \
|
||||
--text "This is a very long text that will be automatically split into multiple chunks. The system will process each chunk separately and then concatenate them together with natural pauses between segments. This ensures that even very long texts can be processed efficiently while maintaining natural speech flow and avoiding memory issues."
|
||||
```
|
||||
|
||||
This will:
|
||||
- Automatically split the text into chunks based on paragraph and sentence boundaries
|
||||
- Synthesize each chunk separately
|
||||
- Add 0.3 seconds of silence between chunks for natural pauses
|
||||
- Concatenate all chunks into a single audio file
|
||||
|
||||
**Note**: Automatic text chunking is disabled when using `--batch` mode. In batch mode, each text is processed as-is without chunking.
|
||||
|
||||
## Available Arguments
|
||||
|
||||
| Argument | Type | Default | Description |
|
||||
|----------|------|---------|-------------|
|
||||
| `--use-gpu` | flag | False | Use GPU for inference (default: CPU; GPU mode is not yet supported — see Notes) |
|
||||
| `--onnx-dir` | str | `assets/onnx` | Path to ONNX model directory |
|
||||
| `--total-step` | int | 5 | Number of denoising steps (higher = better quality, slower) |
|
||||
| `--n-test` | int | 4 | Number of times to generate each sample |
|
||||
| `--voice-style` | str+ | `assets/voice_styles/M1.json` | Voice style file path(s) |
|
||||
| `--text` | str+ | (long default text) | Text(s) to synthesize |
|
||||
| `--lang` | str+ | `en` | Language(s) for synthesis (en, ko, es, pt, fr) |
|
||||
| `--save-dir` | str | `results` | Output directory |
|
||||
| `--batch` | flag | False | Enable batch mode (multiple text-style-lang triplets, disables automatic chunking) |
|
||||
|
||||
## Multilingual Support
|
||||
|
||||
Supertonic 2 supports multiple languages. Use the `--lang` argument to specify the language:
|
||||
|
||||
- `en` - English (default)
|
||||
- `ko` - Korean (한국어)
|
||||
- `es` - Spanish (Español)
|
||||
- `pt` - Portuguese (Português)
|
||||
- `fr` - French (Français)
|
||||
|
||||
## Notes
|
||||
|
||||
- **Batch Processing**: When using `--batch`, the number of `--voice-style`, `--text`, and `--lang` entries must match
|
||||
- **Automatic Chunking**: Without `--batch`, long texts are automatically split and concatenated with 0.3s pauses
|
||||
- **Quality vs Speed**: Higher `--total-step` values produce better quality but take longer
|
||||
- **GPU Support**: GPU mode is not supported yet
|
||||
163
swift/Sources/ExampleONNX.swift
Normal file
163
swift/Sources/ExampleONNX.swift
Normal file
@@ -0,0 +1,163 @@
|
||||
import Foundation
|
||||
import OnnxRuntimeBindings
|
||||
|
||||
/// Command-line options for the example, with the defaults documented in the README.
struct Args {
    var useGpu: Bool = false                     // --use-gpu (GPU mode not yet supported)
    var onnxDir: String = "assets/onnx"          // --onnx-dir: ONNX model directory
    var totalStep: Int = 5                       // --total-step: denoising steps
    var speed: Float = 1.05                      // --speed: synthesis speed multiplier
    var nTest: Int = 4                           // --n-test: generations per sample
    var voiceStyle: [String] = ["assets/voice_styles/M1.json"]  // --voice-style (comma-separated)
    var text: [String] = ["This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen."]  // --text (| separated)
    var lang: [String] = ["en"]                  // --lang (comma-separated)
    var saveDir: String = "results"              // --save-dir: output directory
    var batch: Bool = false                      // --batch: batch mode (disables chunking)
}

/// Parses command-line options into an `Args` value.
///
/// - Parameter arguments: Argument vector to parse (element 0 is skipped as the
///   program name). Defaults to `CommandLine.arguments`; injectable for testing.
/// - Returns: `Args` with defaults applied for any unspecified option.
/// - Note: Unknown flags, options missing their value, and unparseable numeric
///   values fall back silently to the defaults, matching the original behavior.
func parseArgs(_ arguments: [String] = CommandLine.arguments) -> Args {
    var args = Args()
    var i = 1

    // Consumes and returns the value following the current flag,
    // or nil when the flag is the last token.
    func nextValue() -> String? {
        guard i + 1 < arguments.count else { return nil }
        i += 1
        return arguments[i]
    }

    while i < arguments.count {
        switch arguments[i] {
        case "--use-gpu":
            args.useGpu = true
        case "--onnx-dir":
            if let v = nextValue() { args.onnxDir = v }
        case "--total-step":
            if let v = nextValue() { args.totalStep = Int(v) ?? 5 }
        case "--speed":
            if let v = nextValue() { args.speed = Float(v) ?? 1.05 }
        case "--n-test":
            if let v = nextValue() { args.nTest = Int(v) ?? 4 }
        case "--voice-style":
            if let v = nextValue() { args.voiceStyle = v.components(separatedBy: ",") }
        case "--text":
            if let v = nextValue() { args.text = v.components(separatedBy: "|") }
        case "--lang":
            if let v = nextValue() { args.lang = v.components(separatedBy: ",") }
        case "--save-dir":
            if let v = nextValue() { args.saveDir = v }
        case "--batch":
            args.batch = true
        default:
            break // unknown argument: ignored
        }

        i += 1
    }

    return args
}
|
||||
|
||||
@main
struct ExampleONNX {
    /// Entry point: parses CLI options, loads the ONNX TTS pipeline, runs
    /// `nTest` rounds of synthesis, and writes WAV files into `saveDir`.
    static func main() async {
        print("=== TTS Inference with ONNX Runtime (Swift) ===\n")

        // --- 1. Parse arguments --- //
        let args = parseArgs()

        // Batch mode requires a 1:1:1 pairing of voice style, text, and language.
        if args.batch {
            guard args.voiceStyle.count == args.text.count else {
                print("Error: Number of voice styles (\(args.voiceStyle.count)) must match number of texts (\(args.text.count))")
                return
            }
            guard args.lang.count == args.text.count else {
                print("Error: Number of languages (\(args.lang.count)) must match number of texts (\(args.text.count))")
                return
            }
        }

        // NOTE(review): bsz is derived from the voice-style count even in
        // non-batch mode; the save loop below indexes args.text[i], so passing
        // more styles than texts without --batch would over-index — confirm
        // callers always supply matching counts outside batch mode.
        let bsz = args.voiceStyle.count

        do {
            let env = try ORTEnv(loggingLevel: .warning)

            // --- 2. Load TTS components --- //
            let textToSpeech = try loadTextToSpeech(args.onnxDir, args.useGpu, env)

            // --- 3. Load voice styles --- //
            let style = try loadVoiceStyle(args.voiceStyle, verbose: true)

            // --- 4. Synthesize speech --- //
            // Best-effort directory creation; a real failure surfaces later on write.
            try? FileManager.default.createDirectory(atPath: args.saveDir, withIntermediateDirectories: true)

            for n in 0..<args.nTest {
                print("\n[\(n + 1)/\(args.nTest)] Starting synthesis...")

                let wav: [Float]
                let duration: [Float]

                if args.batch {
                    // Batch path: one forward pass over all triplets; no chunking.
                    let result = try timer("Generating speech from text") {
                        try textToSpeech.batch(args.text, args.lang, style, args.totalStep, speed: args.speed)
                    }
                    wav = result.wav
                    duration = result.duration
                } else {
                    // Single-text path: long-form synthesis with 0.3 s inter-chunk silence.
                    let result = try timer("Generating speech from text") {
                        try textToSpeech.call(args.text[0], args.lang[0], style, args.totalStep, speed: args.speed, silenceDuration: 0.3)
                    }
                    wav = result.wav
                    duration = [result.duration]
                }

                // Save outputs
                for i in 0..<bsz {
                    let fname = "\(sanitizeFilename(args.text[i], maxLen: 20))_\(n + 1).wav"
                    let wavOut: [Float]

                    if args.batch {
                        // Batched audio is packed as [bsz * wavLen]; slice out item i
                        // and trim right-padding down to its actual spoken duration.
                        let wavLen = wav.count / bsz
                        let actualLen = Int(Float(textToSpeech.sampleRate) * duration[i])
                        let wavStart = i * wavLen
                        let wavEnd = min(wavStart + actualLen, wavStart + wavLen)
                        wavOut = Array(wav[wavStart..<wavEnd])
                    } else {
                        // For non-batch mode, wav is a single concatenated audio
                        let actualLen = Int(Float(textToSpeech.sampleRate) * duration[0])
                        wavOut = Array(wav.prefix(actualLen))
                    }

                    let outputPath = "\(args.saveDir)/\(fname)"
                    try writeWavFile(outputPath, wavOut, textToSpeech.sampleRate)
                    print("Saved: \(outputPath)")
                }
            }

            print("\n=== Synthesis completed successfully! ===")

        } catch {
            print("Error during inference: \(error)")
            exit(1)
        }
    }
}
|
||||
835
swift/Sources/Helper.swift
Normal file
835
swift/Sources/Helper.swift
Normal file
@@ -0,0 +1,835 @@
|
||||
import Foundation
|
||||
import Accelerate
|
||||
import OnnxRuntimeBindings
|
||||
|
||||
// MARK: - Available Languages

/// Language codes the synthesizer accepts (order preserved for error messages).
let AVAILABLE_LANGS = ["en", "ko", "es", "pt", "fr"]

/// Returns `true` when `lang` is one of the supported language codes.
/// Matching is exact (case-sensitive).
func isValidLang(_ lang: String) -> Bool {
    AVAILABLE_LANGS.contains(lang)
}
|
||||
|
||||
// MARK: - Configuration Structures

/// Mirrors the `tts.json` file shipped next to the ONNX models (see `loadCfgs`).
/// Property names are snake_case to match the JSON keys directly, avoiding a
/// custom `CodingKeys` mapping.
struct Config: Codable {
    /// Autoencoder ("ae") section of the config.
    struct AEConfig: Codable {
        let sample_rate: Int       // audio sample rate in Hz
        let base_chunk_size: Int   // base waveform chunk size in samples
    }

    /// Text-to-latent ("ttl") section of the config.
    struct TTLConfig: Codable {
        let chunk_compress_factor: Int  // latent chunk compression factor
        let latent_dim: Int             // latent feature dimension
    }

    let ae: AEConfig
    let ttl: TTLConfig
}
|
||||
|
||||
// MARK: - Voice Style Data Structure

/// Decoded shape of a voice-style JSON file (e.g. `assets/voice_styles/M1.json`).
struct VoiceStyleData: Codable {
    /// One serialized tensor: nested float payload plus its shape and a type tag.
    struct StyleComponent: Codable {
        let data: [[[Float]]]  // tensor payload, nested three levels deep
        let dims: [Int]        // tensor shape
        let type: String       // element type tag as stored in the JSON
    }

    let style_ttl: StyleComponent  // style tensor for the text-to-latent model
    let style_dp: StyleComponent   // style tensor for the duration predictor
}
|
||||
|
||||
// MARK: - Unicode Text Processor

/// Converts text into model token ids using a Unicode-scalar → id lookup table.
class UnicodeProcessor {
    // Lookup table indexed by Unicode scalar value; loaded from a JSON [Int64].
    let indexer: [Int64]

    /// Loads the indexer table from `unicodeIndexerPath` (a JSON array of Int64).
    /// - Throws: File-read or JSON-decoding errors.
    init(unicodeIndexerPath: String) throws {
        let data = try Data(contentsOf: URL(fileURLWithPath: unicodeIndexerPath))
        self.indexer = try JSONDecoder().decode([Int64].self, from: data)
    }

    /// Converts texts (with matching language codes) into padded id matrices.
    ///
    /// - Parameters:
    ///   - textList: Raw input texts; each is normalized via `preprocessText`.
    ///   - langList: Language code per text; must align index-wise with `textList`.
    /// - Returns: `textIds` — [batch x maxLen] Int64 ids, zero-padded on the
    ///   right to the longest text; `textMask` — matching [batch x 1 x maxLen]
    ///   0/1 float mask from `getTextMask`.
    func call(_ textList: [String], _ langList: [String]) -> (textIds: [[Int64]], textMask: [[[Float]]]) {
        var processedTexts = [String]()
        for (i, text) in textList.enumerated() {
            processedTexts.append(preprocessText(text, lang: langList[i]))
        }

        // Use unicodeScalars.count for correct length after NFKD decomposition
        var textIdsLengths = [Int]()
        for text in processedTexts {
            textIdsLengths.append(text.unicodeScalars.count)
        }

        let maxLen = textIdsLengths.max() ?? 0

        var textIds = [[Int64]]()
        for text in processedTexts {
            var row = Array(repeating: Int64(0), count: maxLen)
            let unicodeValues = Array(text.unicodeScalars.map { Int($0.value) })
            for (j, val) in unicodeValues.enumerated() {
                if val < indexer.count {
                    row[j] = indexer[val]
                } else {
                    // Scalars beyond the table map to -1 (out-of-vocabulary).
                    row[j] = -1
                }
            }
            textIds.append(row)
        }

        let textMask = getTextMask(textIdsLengths)
        return (textIds, textMask)
    }
}
|
||||
|
||||
/// Normalizes `text` for synthesis and wraps it in `<lang>...</lang>` tags.
///
/// Pipeline: NFKD decomposition (for Hangul Jamo), emoji removal, dash/quote/
/// symbol replacement, known-expression expansion, punctuation spacing fixes,
/// whitespace collapsing, and a trailing period when no closing punctuation
/// is present.
///
/// - Parameters:
///   - text: Raw input text.
///   - lang: Language code; must be one of `AVAILABLE_LANGS`.
/// - Returns: Normalized text wrapped in language tags.
/// - Note: Calls `fatalError` on an unsupported language code.
func preprocessText(_ text: String, lang: String) -> String {
    // Use NFKD (decomposed) for proper Hangul Jamo decomposition
    var text = text.decomposedStringWithCompatibilityMapping

    // Remove emojis (wide Unicode range)
    // Swift NSRegularExpression doesn't support Unicode escapes above \uFFFF
    // Use character filtering instead
    text = text.unicodeScalars.filter { scalar in
        let value = scalar.value
        return !((value >= 0x1F600 && value <= 0x1F64F) ||
            (value >= 0x1F300 && value <= 0x1F5FF) ||
            (value >= 0x1F680 && value <= 0x1F6FF) ||
            (value >= 0x1F700 && value <= 0x1F77F) ||
            (value >= 0x1F780 && value <= 0x1F7FF) ||
            (value >= 0x1F800 && value <= 0x1F8FF) ||
            (value >= 0x1F900 && value <= 0x1F9FF) ||
            (value >= 0x1FA00 && value <= 0x1FA6F) ||
            (value >= 0x1FA70 && value <= 0x1FAFF) ||
            (value >= 0x2600 && value <= 0x26FF) ||
            (value >= 0x2700 && value <= 0x27BF) ||
            (value >= 0x1F1E6 && value <= 0x1F1FF))
    }.map { String($0) }.joined()

    // Replace various dashes and symbols
    // (entries are independent, so dictionary iteration order does not matter)
    let replacements: [String: String] = [
        "–": "-", // en dash
        "‑": "-", // non-breaking hyphen
        "—": "-", // em dash
        "_": " ", // underscore
        "\u{201C}": "\"", // left double quote
        "\u{201D}": "\"", // right double quote
        "\u{2018}": "'", // left single quote
        "\u{2019}": "'", // right single quote
        "´": "'", // acute accent
        "`": "'", // grave accent
        "[": " ", // left bracket
        "]": " ", // right bracket
        "|": " ", // vertical bar
        "/": " ", // slash
        "#": " ", // hash
        "→": " ", // right arrow
        "←": " ", // left arrow
    ]

    for (old, new) in replacements {
        text = text.replacingOccurrences(of: old, with: new)
    }

    // Remove special symbols
    let specialSymbols = ["♥", "☆", "♡", "©", "\\"]
    for symbol in specialSymbols {
        text = text.replacingOccurrences(of: symbol, with: "")
    }

    // Replace known expressions
    let exprReplacements: [String: String] = [
        "@": " at ",
        "e.g.,": "for example, ",
        "i.e.,": "that is, ",
    ]

    for (old, new) in exprReplacements {
        text = text.replacingOccurrences(of: old, with: new)
    }

    // Fix spacing around punctuation
    text = text.replacingOccurrences(of: " ,", with: ",")
    text = text.replacingOccurrences(of: " .", with: ".")
    text = text.replacingOccurrences(of: " !", with: "!")
    text = text.replacingOccurrences(of: " ?", with: "?")
    text = text.replacingOccurrences(of: " ;", with: ";")
    text = text.replacingOccurrences(of: " :", with: ":")
    text = text.replacingOccurrences(of: " '", with: "'")

    // Remove duplicate quotes
    // (loops handle runs longer than two, e.g. """" -> ")
    while text.contains("\"\"") {
        text = text.replacingOccurrences(of: "\"\"", with: "\"")
    }
    while text.contains("''") {
        text = text.replacingOccurrences(of: "''", with: "'")
    }
    while text.contains("``") {
        text = text.replacingOccurrences(of: "``", with: "`")
    }

    // Remove extra spaces
    let whitespacePattern = try! NSRegularExpression(pattern: "\\s+")
    let whitespaceRange = NSRange(text.startIndex..., in: text)
    text = whitespacePattern.stringByReplacingMatches(in: text, range: whitespaceRange, withTemplate: " ")
    text = text.trimmingCharacters(in: .whitespacesAndNewlines)

    // If text doesn't end with punctuation, quotes, or closing brackets, add a period
    if !text.isEmpty {
        let punctPattern = try! NSRegularExpression(pattern: "[.!?;:,'\"\\u201C\\u201D\\u2018\\u2019)\\]}…。」』】〉》›»]$")
        let punctRange = NSRange(text.startIndex..., in: text)
        if punctPattern.firstMatch(in: text, range: punctRange) == nil {
            text += "."
        }
    }

    // Validate language
    guard isValidLang(lang) else {
        fatalError("Invalid language: \(lang). Available: \(AVAILABLE_LANGS.joined(separator: ", "))")
    }

    // Wrap text with language tags
    text = "<\(lang)>\(text)</\(lang)>"

    return text
}
|
||||
|
||||
/// Builds a [batch x 1 x maxLen] 0/1 float mask from per-item lengths.
/// Positions before each item's length are 1.0, the rest 0.0; lengths longer
/// than the mask width are clamped.
///
/// - Parameters:
///   - lengths: Valid length per batch item.
///   - maxLen: Mask width; defaults to the largest length in the batch.
/// - Returns: One `[1 x width]` row of floats per input length.
func lengthToMask(_ lengths: [Int], maxLen: Int? = nil) -> [[[Float]]] {
    let width = maxLen ?? (lengths.max() ?? 0)
    return lengths.map { length in
        let ones = min(length, width)
        let row = Array(repeating: Float(1.0), count: ones)
                + Array(repeating: Float(0.0), count: width - ones)
        return [row]
    }
}

/// Convenience wrapper: builds the text mask padded to the batch maximum.
func getTextMask(_ textIdsLengths: [Int]) -> [[[Float]]] {
    return lengthToMask(textIdsLengths, maxLen: textIdsLengths.max() ?? 0)
}
|
||||
|
||||
/// Draws standard-normal noise latents for a batch and the matching validity mask.
///
/// - Parameters:
///   - duration: Target audio duration in seconds, one entry per batch item.
///   - sampleRate: Audio sample rate in Hz.
///   - baseChunkSize: Base waveform chunk size in samples.
///   - chunkCompress: Chunk compression factor.
///   - latentDim: Latent feature dimension (multiplied by `chunkCompress`).
/// - Returns: `noisyLatent` — [bsz x latentDim*chunkCompress x latentLen]
///   Gaussian noise, zeroed past each item's valid length; `latentMask` — the
///   matching [bsz x 1 x latentLen] 0/1 mask.
/// - Note: Uses `Float.random`, so output is nondeterministic.
func sampleNoisyLatent(duration: [Float], sampleRate: Int, baseChunkSize: Int, chunkCompress: Int, latentDim: Int) -> (noisyLatent: [[[Float]]], latentMask: [[[Float]]]) {
    let bsz = duration.count
    let maxDur = duration.max() ?? 0.0

    // Waveform lengths in samples; the longest sets the padded latent width.
    let wavLenMax = Int(maxDur * Float(sampleRate))
    var wavLengths = [Int]()
    for d in duration {
        wavLengths.append(Int(d * Float(sampleRate)))
    }

    let chunkSize = baseChunkSize * chunkCompress
    let latentLen = (wavLenMax + chunkSize - 1) / chunkSize  // ceil division
    let latentDimVal = latentDim * chunkCompress

    var noisyLatent = [[[Float]]]()
    for _ in 0..<bsz {
        var batch = [[Float]]()
        for _ in 0..<latentDimVal {
            var row = [Float]()
            for _ in 0..<latentLen {
                // Box-Muller transform
                // (u1 is bounded away from 0 so log(u1) stays finite)
                let u1 = Float.random(in: 0.0001...1.0)
                let u2 = Float.random(in: 0.0...1.0)
                let val = sqrt(-2.0 * log(u1)) * cos(2.0 * Float.pi * u2)
                row.append(val)
            }
            batch.append(row)
        }
        noisyLatent.append(batch)
    }

    // Per-item valid latent lengths (ceil of samples / chunkSize).
    var latentLengths = [Int]()
    for len in wavLengths {
        latentLengths.append((len + chunkSize - 1) / chunkSize)
    }

    let latentMask = lengthToMask(latentLengths, maxLen: latentLen)

    // Apply mask
    // Zero noise beyond each item's valid length so padding stays silent.
    for b in 0..<bsz {
        for d in 0..<latentDimVal {
            for t in 0..<latentLen {
                noisyLatent[b][d][t] *= latentMask[b][0][t]
            }
        }
    }

    return (noisyLatent, latentMask)
}
|
||||
|
||||
/// Builds the latent-frame 0/1 mask for a batch of waveform lengths.
/// Each waveform length is ceil-divided by the latent size
/// (base_chunk_size * chunk_compress_factor), then padded to the batch max.
func getLatentMask(_ wavLengths: [Int64], _ cfgs: Config) -> [[[Float]]] {
    let latentSize = cfgs.ae.base_chunk_size * cfgs.ttl.chunk_compress_factor
    let latentLengths = wavLengths.map { (Int($0) + latentSize - 1) / latentSize }
    return lengthToMask(latentLengths, maxLen: latentLengths.max() ?? 0)
}
|
||||
|
||||
// MARK: - WAV File I/O

/// Writes mono 16-bit PCM audio to `filename` as a RIFF/WAVE file.
///
/// - Parameters:
///   - filename: Destination path (overwritten if it exists).
///   - audioData: Samples in [-1, 1]; out-of-range values are clamped.
///   - sampleRate: Sample rate in Hz recorded in the header.
/// - Throws: Any error from writing the file to disk.
func writeWavFile(_ filename: String, _ audioData: [Float], _ sampleRate: Int) throws {
    // Quantize floats to signed 16-bit PCM after clamping to [-1, 1].
    let samples = audioData.map { sample -> Int16 in
        Int16(max(-1.0, min(1.0, sample)) * 32767.0)
    }

    let channelCount: UInt16 = 1
    let bitsPerSample: UInt16 = 16
    let bytesPerSecond = UInt32(sampleRate) * UInt32(channelCount) * UInt32(bitsPerSample) / 8
    let frameSize = channelCount * bitsPerSample / 8
    let payloadSize = UInt32(samples.count * 2)

    var bytes = Data()
    // Local little-endian appenders for header fields.
    func put(_ value: UInt32) { withUnsafeBytes(of: value.littleEndian) { bytes.append(contentsOf: $0) } }
    func put(_ value: UInt16) { withUnsafeBytes(of: value.littleEndian) { bytes.append(contentsOf: $0) } }
    func put(_ tag: String) { bytes.append(tag.data(using: .ascii)!) }

    // RIFF chunk descriptor
    put("RIFF")
    put(UInt32(36 + payloadSize))
    put("WAVE")

    // "fmt " sub-chunk: linear PCM, mono, 16-bit
    put("fmt ")
    put(UInt32(16))
    put(UInt16(1)) // audio format 1 = PCM
    put(channelCount)
    put(UInt32(sampleRate))
    put(bytesPerSecond)
    put(frameSize)
    put(bitsPerSample)

    // "data" sub-chunk followed by the raw sample bytes
    put("data")
    put(payloadSize)
    samples.withUnsafeBytes { bytes.append(contentsOf: $0) }

    try bytes.write(to: URL(fileURLWithPath: filename))
}
|
||||
|
||||
// MARK: - Text Chunking

// Maximum characters per synthesis chunk when auto-splitting long text.
let MAX_CHUNK_LENGTH = 300

// Abbreviations whose trailing period must NOT be treated as a sentence
// boundary by `splitSentences`.
let ABBREVIATIONS = [
    "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
    "St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
    "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D."
]
|
||||
|
||||
/// Splits `text` into synthesis-sized chunks, preferring paragraph, then
/// sentence, then comma, then word boundaries (in that order).
///
/// - Parameters:
///   - text: Full input text.
///   - maxLen: Maximum characters per chunk; values <= 0 fall back to
///     `MAX_CHUNK_LENGTH`.
/// - Returns: A non-empty array of trimmed chunks; `[""]` for blank input.
func chunkText(_ text: String, maxLen: Int = 0) -> [String] {
    let actualMaxLen = maxLen > 0 ? maxLen : MAX_CHUNK_LENGTH
    let trimmedText = text.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)

    if trimmedText.isEmpty {
        return [""]
    }

    // Split by paragraphs using regex
    // (separator: newline, optional whitespace, newline — i.e. a blank line)
    let paraPattern = try! NSRegularExpression(pattern: "\\n\\s*\\n")
    let paraRange = NSRange(trimmedText.startIndex..., in: trimmedText)
    var paragraphs = [String]()
    var lastEnd = trimmedText.startIndex

    paraPattern.enumerateMatches(in: trimmedText, range: paraRange) { match, _, _ in
        if let match = match, let range = Range(match.range, in: trimmedText) {
            paragraphs.append(String(trimmedText[lastEnd..<range.lowerBound]))
            lastEnd = range.upperBound
        }
    }
    if lastEnd < trimmedText.endIndex {
        paragraphs.append(String(trimmedText[lastEnd...]))
    }
    if paragraphs.isEmpty {
        paragraphs = [trimmedText]
    }

    var chunks = [String]()

    for para in paragraphs {
        let trimmedPara = para.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)
        if trimmedPara.isEmpty {
            continue
        }

        // Whole paragraph fits: emit it as one chunk.
        if trimmedPara.count <= actualMaxLen {
            chunks.append(trimmedPara)
            continue
        }

        // Split by sentences
        let sentences = splitSentences(trimmedPara)
        var current = ""      // chunk under construction
        var currentLen = 0    // running character count tracked alongside `current`

        for sentence in sentences {
            let trimmedSentence = sentence.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)
            if trimmedSentence.isEmpty {
                continue
            }

            let sentenceLen = trimmedSentence.count
            if sentenceLen > actualMaxLen {
                // If sentence is longer than maxLen, split by comma or space
                // (flush any accumulated chunk first so order is preserved)
                if !current.isEmpty {
                    chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                    current = ""
                    currentLen = 0
                }

                // Try splitting by comma
                let parts = trimmedSentence.components(separatedBy: ",")
                for part in parts {
                    let trimmedPart = part.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines)
                    if trimmedPart.isEmpty {
                        continue
                    }

                    let partLen = trimmedPart.count
                    if partLen > actualMaxLen {
                        // Split by space as last resort
                        let words = trimmedPart.components(separatedBy: CharacterSet.whitespaces).filter { !$0.isEmpty }
                        var wordChunk = ""
                        var wordChunkLen = 0

                        for word in words {
                            let wordLen = word.count
                            // Flush when adding this word (plus a space) would overflow.
                            if wordChunkLen + wordLen + 1 > actualMaxLen && !wordChunk.isEmpty {
                                chunks.append(wordChunk.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                                wordChunk = ""
                                wordChunkLen = 0
                            }

                            if !wordChunk.isEmpty {
                                wordChunk += " "
                                wordChunkLen += 1
                            }
                            wordChunk += word
                            wordChunkLen += wordLen
                        }

                        if !wordChunk.isEmpty {
                            chunks.append(wordChunk.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                        }
                    } else {
                        // Comma-piece fits: pack into `current`, flushing on overflow.
                        if currentLen + partLen + 1 > actualMaxLen && !current.isEmpty {
                            chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                            current = ""
                            currentLen = 0
                        }

                        if !current.isEmpty {
                            current += ", "
                            currentLen += 2
                        }
                        current += trimmedPart
                        currentLen += partLen
                    }
                }
                continue
            }

            // Sentence fits on its own: pack into `current`, flushing on overflow.
            if currentLen + sentenceLen + 1 > actualMaxLen && !current.isEmpty {
                chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
                current = ""
                currentLen = 0
            }

            if !current.isEmpty {
                current += " "
                currentLen += 1
            }
            current += trimmedSentence
            currentLen += sentenceLen
        }

        if !current.isEmpty {
            chunks.append(current.trimmingCharacters(in: CharacterSet.whitespacesAndNewlines))
        }
    }

    return chunks.isEmpty ? [""] : chunks
}
|
||||
|
||||
/// Splits text into sentences on `.` / `!` / `?` followed by whitespace,
/// keeping abbreviation periods (see `ABBREVIATIONS`) inside their sentence.
///
/// - Parameter text: Paragraph text to split.
/// - Returns: Sentences including their terminating punctuation and the
///   whitespace that followed it; `[text]` when no boundary is found.
func splitSentences(_ text: String) -> [String] {
    // Swift's regex doesn't support lookbehind reliably, so we use a simpler approach
    // Split on sentence boundaries and then check if they're abbreviations
    let regex = try! NSRegularExpression(pattern: "([.!?])\\s+")
    let range = NSRange(text.startIndex..., in: text)

    // Find all matches
    let matches = regex.matches(in: text, range: range)
    if matches.isEmpty {
        return [text]
    }

    var sentences = [String]()
    var lastEnd = text.startIndex

    for match in matches {
        guard let matchRange = Range(match.range, in: text) else { continue }

        // Get the text before the punctuation
        let beforePunc = String(text[lastEnd..<matchRange.lowerBound])

        // Get the punctuation character
        let puncRange = Range(NSRange(location: match.range.location, length: 1), in: text)!
        let punc = String(text[puncRange])

        // Check if this ends with an abbreviation
        var isAbbrev = false
        let combined = beforePunc.trimmingCharacters(in: CharacterSet.whitespaces) + punc
        for abbrev in ABBREVIATIONS {
            if combined.hasSuffix(abbrev) {
                isAbbrev = true
                break
            }
        }

        if !isAbbrev {
            // This is a real sentence boundary
            // (for abbreviations, lastEnd is NOT advanced, so the abbreviation
            // stays inside the sentence still being accumulated)
            sentences.append(String(text[lastEnd..<matchRange.upperBound]))
            lastEnd = matchRange.upperBound
        }
    }

    // Add the remaining text
    if lastEnd < text.endIndex {
        sentences.append(String(text[lastEnd...]))
    }

    return sentences.isEmpty ? [text] : sentences
}
|
||||
|
||||
// MARK: - Utility Functions

/// Runs `f`, printing "`name`..." before it starts and the elapsed wall-clock
/// time after it finishes. Returns whatever `f` returns and rethrows anything
/// `f` throws.
func timer<T>(_ name: String, _ f: () throws -> T) rethrows -> T {
    let startedAt = Date()
    print("\(name)...")
    let value = try f()
    print(String(format: " -> %@ completed in %.2f sec", name, Date().timeIntervalSince(startedAt)))
    return value
}
|
||||
|
||||
/// Builds a filesystem-safe name fragment from `text`: truncates to `maxLen`
/// characters, keeps letters and digits, and replaces every other character
/// with "_".
///
/// - Parameters:
///   - text: Arbitrary input text (e.g. the sentence being synthesized).
///   - maxLen: Maximum number of characters kept from `text`.
/// - Returns: The sanitized fragment (empty if `text` is empty).
func sanitizeFilename(_ text: String, maxLen: Int) -> String {
    // prefix(_:) already clamps to the string length, so the original
    // count-vs-maxLen check was redundant; String(_: [Character]) replaces
    // the map-to-String + joined() chain.
    return String(text.prefix(maxLen).map { $0.isLetter || $0.isNumber ? $0 : "_" })
}
|
||||
|
||||
/// Decodes the pipeline `Config` from "<onnxDir>/tts.json".
/// - Parameter onnxDir: Directory containing the exported ONNX assets.
/// - Throws: File-read or JSON-decoding errors.
func loadCfgs(_ onnxDir: String) throws -> Config {
    let url = URL(fileURLWithPath: "\(onnxDir)/tts.json")
    return try JSONDecoder().decode(Config.self, from: Data(contentsOf: url))
}
|
||||
|
||||
// MARK: - ONNX Runtime Integration

/// Pair of style tensors fed to the ONNX sessions.
struct Style {
    let ttl: ORTValue  // style tensor for the text-to-latent model
    let dp: ORTValue   // style tensor for the duration predictor
}
|
||||
|
||||
class TextToSpeech {
|
||||
    // Pipeline configuration and text-to-id preprocessor.
    let cfgs: Config
    let textProcessor: UnicodeProcessor
    // The four ONNX sessions making up the TTS graph:
    // duration predictor, text encoder, vector estimator, vocoder.
    let dpOrt: ORTSession
    let textEncOrt: ORTSession
    let vectorEstOrt: ORTSession
    let vocoderOrt: ORTSession
    // Output sample rate, mirrored from cfgs.ae.sample_rate for convenience.
    let sampleRate: Int

    /// Stores the preloaded sessions and caches the output sample rate.
    init(cfgs: Config, textProcessor: UnicodeProcessor,
         dpOrt: ORTSession, textEncOrt: ORTSession,
         vectorEstOrt: ORTSession, vocoderOrt: ORTSession) {
        self.cfgs = cfgs
        self.textProcessor = textProcessor
        self.dpOrt = dpOrt
        self.textEncOrt = textEncOrt
        self.vectorEstOrt = vectorEstOrt
        self.vocoderOrt = vocoderOrt
        self.sampleRate = cfgs.ae.sample_rate
    }
|
||||
|
||||
private func _infer(_ textList: [String], _ langList: [String], _ style: Style, _ totalStep: Int, speed: Float = 1.05) throws -> (wav: [Float], duration: [Float]) {
|
||||
let bsz = textList.count
|
||||
|
||||
// Process text
|
||||
let (textIds, textMask) = textProcessor.call(textList, langList)
|
||||
|
||||
// Flatten text IDs
|
||||
let textIdsFlat = textIds.flatMap { $0 }
|
||||
let textIdsShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: textIds[0].count)]
|
||||
let textIdsValue = try ORTValue(tensorData: NSMutableData(bytes: textIdsFlat, length: textIdsFlat.count * MemoryLayout<Int64>.size),
|
||||
elementType: .int64,
|
||||
shape: textIdsShape)
|
||||
|
||||
// Flatten text mask
|
||||
let textMaskFlat = textMask.flatMap { $0.flatMap { $0 } }
|
||||
let textMaskShape: [NSNumber] = [NSNumber(value: bsz), 1, NSNumber(value: textMask[0][0].count)]
|
||||
let textMaskValue = try ORTValue(tensorData: NSMutableData(bytes: textMaskFlat, length: textMaskFlat.count * MemoryLayout<Float>.size),
|
||||
elementType: .float,
|
||||
shape: textMaskShape)
|
||||
|
||||
// Predict duration
|
||||
let dpOutputs = try dpOrt.run(withInputs: ["text_ids": textIdsValue, "style_dp": style.dp, "text_mask": textMaskValue],
|
||||
outputNames: ["duration"],
|
||||
runOptions: nil)
|
||||
|
||||
let durationData = try dpOutputs["duration"]!.tensorData() as Data
|
||||
var duration = durationData.withUnsafeBytes { ptr in
|
||||
Array(ptr.bindMemory(to: Float.self))
|
||||
}
|
||||
|
||||
// Apply speed factor to duration
|
||||
for i in 0..<duration.count {
|
||||
duration[i] /= speed
|
||||
}
|
||||
|
||||
// Encode text
|
||||
let textEncOutputs = try textEncOrt.run(withInputs: ["text_ids": textIdsValue, "style_ttl": style.ttl, "text_mask": textMaskValue],
|
||||
outputNames: ["text_emb"],
|
||||
runOptions: nil)
|
||||
|
||||
let textEmbValue = textEncOutputs["text_emb"]!
|
||||
|
||||
// Sample noisy latent
|
||||
var (xt, latentMask) = sampleNoisyLatent(duration: duration, sampleRate: sampleRate,
|
||||
baseChunkSize: cfgs.ae.base_chunk_size,
|
||||
chunkCompress: cfgs.ttl.chunk_compress_factor,
|
||||
latentDim: cfgs.ttl.latent_dim)
|
||||
|
||||
// Prepare constant arrays
|
||||
let totalStepArray = Array(repeating: Float(totalStep), count: bsz)
|
||||
let totalStepValue = try ORTValue(tensorData: NSMutableData(bytes: totalStepArray, length: totalStepArray.count * MemoryLayout<Float>.size),
|
||||
elementType: .float,
|
||||
shape: [NSNumber(value: bsz)])
|
||||
|
||||
// Denoising loop
|
||||
for step in 0..<totalStep {
|
||||
let currentStepArray = Array(repeating: Float(step), count: bsz)
|
||||
let currentStepValue = try ORTValue(tensorData: NSMutableData(bytes: currentStepArray, length: currentStepArray.count * MemoryLayout<Float>.size),
|
||||
elementType: .float,
|
||||
shape: [NSNumber(value: bsz)])
|
||||
|
||||
// Flatten xt
|
||||
let xtFlat = xt.flatMap { $0.flatMap { $0 } }
|
||||
let xtShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: xt[0].count), NSNumber(value: xt[0][0].count)]
|
||||
let xtValue = try ORTValue(tensorData: NSMutableData(bytes: xtFlat, length: xtFlat.count * MemoryLayout<Float>.size),
|
||||
elementType: .float,
|
||||
shape: xtShape)
|
||||
|
||||
// Flatten latent mask
|
||||
let latentMaskFlat = latentMask.flatMap { $0.flatMap { $0 } }
|
||||
let latentMaskShape: [NSNumber] = [NSNumber(value: bsz), 1, NSNumber(value: latentMask[0][0].count)]
|
||||
let latentMaskValue = try ORTValue(tensorData: NSMutableData(bytes: latentMaskFlat, length: latentMaskFlat.count * MemoryLayout<Float>.size),
|
||||
elementType: .float,
|
||||
shape: latentMaskShape)
|
||||
|
||||
let vectorEstOutputs = try vectorEstOrt.run(withInputs: [
|
||||
"noisy_latent": xtValue,
|
||||
"text_emb": textEmbValue,
|
||||
"style_ttl": style.ttl,
|
||||
"latent_mask": latentMaskValue,
|
||||
"text_mask": textMaskValue,
|
||||
"current_step": currentStepValue,
|
||||
"total_step": totalStepValue
|
||||
], outputNames: ["denoised_latent"], runOptions: nil)
|
||||
|
||||
let denoisedData = try vectorEstOutputs["denoised_latent"]!.tensorData() as Data
|
||||
let denoisedFlat = denoisedData.withUnsafeBytes { ptr in
|
||||
Array(ptr.bindMemory(to: Float.self))
|
||||
}
|
||||
|
||||
// Reshape to 3D
|
||||
let latentDimVal = xt[0].count
|
||||
let latentLen = xt[0][0].count
|
||||
xt = []
|
||||
var idx = 0
|
||||
for _ in 0..<bsz {
|
||||
var batch = [[Float]]()
|
||||
for _ in 0..<latentDimVal {
|
||||
var row = [Float]()
|
||||
for _ in 0..<latentLen {
|
||||
row.append(denoisedFlat[idx])
|
||||
idx += 1
|
||||
}
|
||||
batch.append(row)
|
||||
}
|
||||
xt.append(batch)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate waveform
|
||||
let finalXtFlat = xt.flatMap { $0.flatMap { $0 } }
|
||||
let finalXtShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: xt[0].count), NSNumber(value: xt[0][0].count)]
|
||||
let finalXtValue = try ORTValue(tensorData: NSMutableData(bytes: finalXtFlat, length: finalXtFlat.count * MemoryLayout<Float>.size),
|
||||
elementType: .float,
|
||||
shape: finalXtShape)
|
||||
|
||||
let vocoderOutputs = try vocoderOrt.run(withInputs: ["latent": finalXtValue],
|
||||
outputNames: ["wav_tts"],
|
||||
runOptions: nil)
|
||||
|
||||
let wavData = try vocoderOutputs["wav_tts"]!.tensorData() as Data
|
||||
let wav = wavData.withUnsafeBytes { ptr in
|
||||
Array(ptr.bindMemory(to: Float.self))
|
||||
}
|
||||
|
||||
return (wav, duration)
|
||||
}
|
||||
|
||||
func call(_ text: String, _ lang: String, _ style: Style, _ totalStep: Int, speed: Float = 1.05, silenceDuration: Float = 0.3) throws -> (wav: [Float], duration: Float) {
|
||||
let maxLen = lang == "ko" ? 120 : 300
|
||||
let chunks = chunkText(text, maxLen: maxLen)
|
||||
let langList = Array(repeating: lang, count: chunks.count)
|
||||
|
||||
var wavCat = [Float]()
|
||||
var durCat: Float = 0.0
|
||||
|
||||
for (i, chunk) in chunks.enumerated() {
|
||||
let result = try _infer([chunk], [langList[i]], style, totalStep, speed: speed)
|
||||
|
||||
let dur = result.duration[0]
|
||||
let wavLen = Int(Float(sampleRate) * dur)
|
||||
let wavChunk = Array(result.wav.prefix(wavLen))
|
||||
|
||||
if i == 0 {
|
||||
wavCat = wavChunk
|
||||
durCat = dur
|
||||
} else {
|
||||
let silenceLen = Int(silenceDuration * Float(sampleRate))
|
||||
let silence = [Float](repeating: 0.0, count: silenceLen)
|
||||
|
||||
wavCat.append(contentsOf: silence)
|
||||
wavCat.append(contentsOf: wavChunk)
|
||||
durCat += silenceDuration + dur
|
||||
}
|
||||
}
|
||||
|
||||
return (wavCat, durCat)
|
||||
}
|
||||
|
||||
func batch(_ textList: [String], _ langList: [String], _ style: Style, _ totalStep: Int, speed: Float = 1.05) throws -> (wav: [Float], duration: [Float]) {
|
||||
return try _infer(textList, langList, style, totalStep, speed: speed)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Component Loading Functions
|
||||
|
||||
/// Loads one or more voice-style JSON files and batches them into the two
/// ONNX style tensors ([bsz, D1, D2] each) expected by the model.
/// - Parameters:
///   - voiceStylePaths: Paths to the voice-style JSON files; must be non-empty.
///   - verbose: When true, prints how many styles were loaded.
/// - Returns: A `Style` holding the batched TTL and DP tensors.
/// - Throws: File-read, JSON-decoding, or `ORTValue` creation errors.
func loadVoiceStyle(_ voiceStylePaths: [String], verbose: Bool) throws -> Style {
    let bsz = voiceStylePaths.count

    // The first file fixes the per-sample dimensions; the remaining files
    // are assumed to match it.
    let firstData = try Data(contentsOf: URL(fileURLWithPath: voiceStylePaths[0]))
    let firstStyle = try JSONDecoder().decode(VoiceStyleData.self, from: firstData)

    let ttlDims = firstStyle.style_ttl.dims
    let dpDims = firstStyle.style_dp.dims

    let ttlDim1 = ttlDims[1]
    let ttlDim2 = ttlDims[2]
    let dpDim1 = dpDims[1]
    let dpDim2 = dpDims[2]

    // Pre-allocate the flat batch buffers; each sample occupies one stride.
    let ttlStride = ttlDim1 * ttlDim2
    let dpStride = dpDim1 * dpDim2
    var ttlFlat = [Float](repeating: 0.0, count: bsz * ttlStride)
    var dpFlat = [Float](repeating: 0.0, count: bsz * dpStride)

    // Decode each file and copy its row-major-flattened tensors into the
    // corresponding batch slot.
    for (i, path) in voiceStylePaths.enumerated() {
        let fileData = try Data(contentsOf: URL(fileURLWithPath: path))
        let voiceStyle = try JSONDecoder().decode(VoiceStyleData.self, from: fileData)

        let ttlValues = voiceStyle.style_ttl.data.flatMap { $0.flatMap { $0 } }
        for (j, v) in ttlValues.enumerated() {
            ttlFlat[i * ttlStride + j] = v
        }

        let dpValues = voiceStyle.style_dp.data.flatMap { $0.flatMap { $0 } }
        for (j, v) in dpValues.enumerated() {
            dpFlat[i * dpStride + j] = v
        }
    }

    let ttlShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: ttlDim1), NSNumber(value: ttlDim2)]
    let dpShape: [NSNumber] = [NSNumber(value: bsz), NSNumber(value: dpDim1), NSNumber(value: dpDim2)]

    let ttlValue = try ORTValue(tensorData: NSMutableData(bytes: &ttlFlat, length: ttlFlat.count * MemoryLayout<Float>.size),
                                elementType: .float,
                                shape: ttlShape)
    let dpValue = try ORTValue(tensorData: NSMutableData(bytes: &dpFlat, length: dpFlat.count * MemoryLayout<Float>.size),
                               elementType: .float,
                               shape: dpShape)

    if verbose {
        print("Loaded \(bsz) voice styles\n")
    }

    return Style(ttl: ttlValue, dp: dpValue)
}
|
||||
|
||||
/// Builds a fully-initialized `TextToSpeech` pipeline from an ONNX model
/// directory: loads the config, opens the four model sessions, and creates
/// the unicode text processor.
/// - Parameters:
///   - onnxDir: Directory containing tts.json, the four .onnx files, and
///     unicode_indexer.json.
///   - useGpu: Must be false; GPU inference is not implemented yet.
///   - env: Shared ONNX Runtime environment.
/// - Throws: An `NSError` (domain "TTS") when `useGpu` is true, plus any
///   config-loading or session-creation errors.
func loadTextToSpeech(_ onnxDir: String, _ useGpu: Bool, _ env: ORTEnv) throws -> TextToSpeech {
    guard !useGpu else {
        throw NSError(domain: "TTS", code: 1, userInfo: [NSLocalizedDescriptionKey: "GPU mode is not supported yet"])
    }
    print("Using CPU for inference\n")

    let cfgs = try loadCfgs(onnxDir)

    let sessionOptions = try ORTSessionOptions()

    // Opens one ONNX session for a model file inside `onnxDir`.
    func makeSession(_ file: String) throws -> ORTSession {
        try ORTSession(env: env, modelPath: "\(onnxDir)/\(file)", sessionOptions: sessionOptions)
    }

    let dpOrt = try makeSession("duration_predictor.onnx")
    let textEncOrt = try makeSession("text_encoder.onnx")
    let vectorEstOrt = try makeSession("vector_estimator.onnx")
    let vocoderOrt = try makeSession("vocoder.onnx")

    let textProcessor = try UnicodeProcessor(unicodeIndexerPath: "\(onnxDir)/unicode_indexer.json")

    return TextToSpeech(cfgs: cfgs, textProcessor: textProcessor,
                        dpOrt: dpOrt, textEncOrt: textEncOrt,
                        vectorEstOrt: vectorEstOrt, vocoderOrt: vocoderOrt)
}
|
||||
Reference in New Issue
Block a user