From 77af47274c4d12a821dba58b99b90c59d021387d Mon Sep 17 00:00:00 2001 From: LGram16 Date: Sun, 25 Jan 2026 18:58:40 +0900 Subject: [PATCH] initial commit --- .gitattributes | 3 + .gitignore | 62 + LICENSE | 21 + README.md | 452 +++++++ cpp/CMakeLists.txt | 122 ++ cpp/README.md | 139 ++ cpp/example_onnx.cpp | 121 ++ cpp/helper.cpp | 1186 +++++++++++++++++ cpp/helper.h | 229 ++++ csharp/.gitignore | 41 + csharp/ExampleONNX.cs | 171 +++ csharp/Helper.cs | 861 ++++++++++++ csharp/Properties/launchSettings.json | 8 + csharp/README.md | 137 ++ csharp/Supertonic.csproj | 18 + csharp/csharp.sln | 24 + flutter/.gitignore | 45 + flutter/.metadata | 30 + flutter/README.md | 38 + flutter/analysis_options.yaml | 28 + flutter/lib/helper.dart | 695 ++++++++++ flutter/lib/main.dart | 391 ++++++ flutter/macos/.gitignore | 7 + flutter/macos/Flutter/Flutter-Debug.xcconfig | 2 + .../macos/Flutter/Flutter-Release.xcconfig | 2 + .../Flutter/GeneratedPluginRegistrant.swift | 16 + flutter/macos/Podfile | 45 + flutter/macos/Podfile.lock | 54 + flutter/macos/Runner/AppDelegate.swift | 13 + .../AppIcon.appiconset/Contents.json | 68 + .../AppIcon.appiconset/app_icon_1024.png | Bin 0 -> 102994 bytes .../AppIcon.appiconset/app_icon_128.png | Bin 0 -> 5680 bytes .../AppIcon.appiconset/app_icon_16.png | Bin 0 -> 520 bytes .../AppIcon.appiconset/app_icon_256.png | Bin 0 -> 14142 bytes .../AppIcon.appiconset/app_icon_32.png | Bin 0 -> 1066 bytes .../AppIcon.appiconset/app_icon_512.png | Bin 0 -> 36406 bytes .../AppIcon.appiconset/app_icon_64.png | Bin 0 -> 2218 bytes flutter/macos/Runner/Base.lproj/MainMenu.xib | 343 +++++ flutter/macos/Runner/Configs/AppInfo.xcconfig | 14 + flutter/macos/Runner/Configs/Debug.xcconfig | 2 + flutter/macos/Runner/Configs/Release.xcconfig | 2 + .../macos/Runner/Configs/Warnings.xcconfig | 13 + .../macos/Runner/DebugProfile.entitlements | 12 + flutter/macos/Runner/Info.plist | 32 + flutter/macos/Runner/MainFlutterWindow.swift | 15 + 
flutter/macos/Runner/Release.entitlements | 8 + flutter/macos/RunnerTests/RunnerTests.swift | 12 + flutter/pubspec.lock | 418 ++++++ flutter/pubspec.yaml | 26 + go/.gitignore | 17 + go/README.md | 165 +++ go/example_onnx.go | 193 +++ go/go.mod | 13 + go/go.sum | 12 + go/helper.go | 1066 +++++++++++++++ img/supertonic_preview_0.1.jpg | Bin 0 -> 784828 bytes img/voicebuilder_img.png | Bin 0 -> 458777 bytes ios/ExampleiOSApp/App.swift | 10 + ios/ExampleiOSApp/AudioPlayer.swift | 30 + ios/ExampleiOSApp/ContentView.swift | 99 ++ ios/ExampleiOSApp/Info.plist | 29 + ios/ExampleiOSApp/TTSService.swift | 114 ++ ios/ExampleiOSApp/TTSViewModel.swift | 82 ++ ios/ExampleiOSApp/project.yml | 29 + ios/README.md | 78 ++ java/.gitignore | 35 + java/ExampleONNX.java | 183 +++ java/Helper.java | 955 +++++++++++++ java/README.md | 130 ++ java/pom.xml | 110 ++ nodejs/README.md | 140 ++ nodejs/example_onnx.js | 119 ++ nodejs/helper.js | 559 ++++++++ nodejs/package.json | 26 + py/README.md | 145 ++ py/example_onnx.py | 116 ++ py/example_pypi.py | 16 + py/helper.py | 429 ++++++ py/pyproject.toml | 20 + py/requirements.txt | 5 + py/uv.lock | 1142 ++++++++++++++++ rust/.gitignore | 21 + rust/Cargo.toml | 44 + rust/README.md | 146 ++ rust/src/example_onnx.rs | 144 ++ rust/src/helper.rs | 838 ++++++++++++ swift/.gitignore | 15 + swift/Package.resolved | 14 + swift/Package.swift | 22 + swift/README.md | 122 ++ swift/Sources/ExampleONNX.swift | 163 +++ swift/Sources/Helper.swift | 835 ++++++++++++ test_all.sh | 330 +++++ web/.gitignore | 4 + web/README.md | 121 ++ web/helper.js | 561 ++++++++ web/index.html | 95 ++ web/main.js | 291 ++++ web/package.json | 21 + web/style.css | 453 +++++++ web/vite.config.js | 14 + 101 files changed, 16247 insertions(+) create mode 100644 .gitattributes create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 cpp/CMakeLists.txt create mode 100644 cpp/README.md create mode 100644 cpp/example_onnx.cpp create mode 
100644 cpp/helper.cpp create mode 100644 cpp/helper.h create mode 100644 csharp/.gitignore create mode 100644 csharp/ExampleONNX.cs create mode 100644 csharp/Helper.cs create mode 100644 csharp/Properties/launchSettings.json create mode 100644 csharp/README.md create mode 100644 csharp/Supertonic.csproj create mode 100644 csharp/csharp.sln create mode 100644 flutter/.gitignore create mode 100644 flutter/.metadata create mode 100644 flutter/README.md create mode 100644 flutter/analysis_options.yaml create mode 100644 flutter/lib/helper.dart create mode 100644 flutter/lib/main.dart create mode 100644 flutter/macos/.gitignore create mode 100644 flutter/macos/Flutter/Flutter-Debug.xcconfig create mode 100644 flutter/macos/Flutter/Flutter-Release.xcconfig create mode 100644 flutter/macos/Flutter/GeneratedPluginRegistrant.swift create mode 100644 flutter/macos/Podfile create mode 100644 flutter/macos/Podfile.lock create mode 100644 flutter/macos/Runner/AppDelegate.swift create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/Contents.json create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/app_icon_1024.png create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/app_icon_128.png create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/app_icon_16.png create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/app_icon_256.png create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/app_icon_32.png create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/app_icon_512.png create mode 100644 flutter/macos/Runner/Assets.xcassets/AppIcon.appiconset/app_icon_64.png create mode 100644 flutter/macos/Runner/Base.lproj/MainMenu.xib create mode 100644 flutter/macos/Runner/Configs/AppInfo.xcconfig create mode 100644 flutter/macos/Runner/Configs/Debug.xcconfig create mode 100644 flutter/macos/Runner/Configs/Release.xcconfig create mode 100644 
flutter/macos/Runner/Configs/Warnings.xcconfig create mode 100644 flutter/macos/Runner/DebugProfile.entitlements create mode 100644 flutter/macos/Runner/Info.plist create mode 100644 flutter/macos/Runner/MainFlutterWindow.swift create mode 100644 flutter/macos/Runner/Release.entitlements create mode 100644 flutter/macos/RunnerTests/RunnerTests.swift create mode 100644 flutter/pubspec.lock create mode 100644 flutter/pubspec.yaml create mode 100644 go/.gitignore create mode 100644 go/README.md create mode 100644 go/example_onnx.go create mode 100644 go/go.mod create mode 100644 go/go.sum create mode 100644 go/helper.go create mode 100644 img/supertonic_preview_0.1.jpg create mode 100644 img/voicebuilder_img.png create mode 100644 ios/ExampleiOSApp/App.swift create mode 100644 ios/ExampleiOSApp/AudioPlayer.swift create mode 100644 ios/ExampleiOSApp/ContentView.swift create mode 100644 ios/ExampleiOSApp/Info.plist create mode 100644 ios/ExampleiOSApp/TTSService.swift create mode 100644 ios/ExampleiOSApp/TTSViewModel.swift create mode 100644 ios/ExampleiOSApp/project.yml create mode 100644 ios/README.md create mode 100644 java/.gitignore create mode 100644 java/ExampleONNX.java create mode 100644 java/Helper.java create mode 100644 java/README.md create mode 100644 java/pom.xml create mode 100644 nodejs/README.md create mode 100644 nodejs/example_onnx.js create mode 100644 nodejs/helper.js create mode 100644 nodejs/package.json create mode 100644 py/README.md create mode 100644 py/example_onnx.py create mode 100644 py/example_pypi.py create mode 100644 py/helper.py create mode 100644 py/pyproject.toml create mode 100644 py/requirements.txt create mode 100644 py/uv.lock create mode 100644 rust/.gitignore create mode 100644 rust/Cargo.toml create mode 100644 rust/README.md create mode 100644 rust/src/example_onnx.rs create mode 100644 rust/src/helper.rs create mode 100644 swift/.gitignore create mode 100644 swift/Package.resolved create mode 100644 swift/Package.swift 
create mode 100644 swift/README.md create mode 100644 swift/Sources/ExampleONNX.swift create mode 100644 swift/Sources/Helper.swift create mode 100644 test_all.sh create mode 100644 web/.gitignore create mode 100644 web/README.md create mode 100644 web/helper.js create mode 100644 web/index.html create mode 100644 web/main.js create mode 100644 web/package.json create mode 100644 web/style.css create mode 100644 web/vite.config.js diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..61f4aa3 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +assets/onnx/*.onnx filter=lfs diff=lfs merge=lfs -text +ios/** linguist-ignore +web/** linguist-ignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..99f7df7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,62 @@ +assets/* +assets/.git +assets/.gitignore +assets/.gitattributes + +*.onnx +onnx + +# Output files +results + +# Python +__pycache__ +*.py[cod] +*$py.class +*.so +.Python + +# Virtual environments +.venv +venv/ +ENV/ +env/ + +# Node.js +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* +package-lock.json + +# Swift +.build/ +.swiftpm/ +*.xcodeproj +*.xcworkspace +xcuserdata/ +DerivedData/ + +# Distribution / packaging +build/ +dist/ +*.egg-info/ +.eggs/ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db +assets diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..943171b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Supertone Inc. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..7b51db4 --- /dev/null +++ b/README.md @@ -0,0 +1,452 @@ +# Supertonic โ€” Lightning Fast, On-Device TTS + +[![v2 Demo](https://img.shields.io/badge/๐Ÿค—%20v2-Demo-yellow)](https://huggingface.co/spaces/Supertone/supertonic-2) +[![v2 Models](https://img.shields.io/badge/๐Ÿค—%20v2-Models-blue)](https://huggingface.co/Supertone/supertonic-2) +[![v1 Demo](https://img.shields.io/badge/๐Ÿค—%20v1%20(old)-Demo-lightgrey)](https://huggingface.co/spaces/Supertone/supertonic#interactive-demo) +[![v1 Models](https://img.shields.io/badge/๐Ÿค—%20v1%20(old)-Models-lightgrey)](https://huggingface.co/Supertone/supertonic) + +

+ Supertonic Banner +

+ +**Supertonic** is a lightning-fast, on-device text-to-speech system designed for **extreme performance** with minimal computational overhead. Powered by ONNX Runtime, it runs entirely on your deviceโ€”no cloud, no API calls, no privacy concerns. + +### ๐Ÿ“ฐ Update News + +- **2026.01.22** - **[Voice Builder](https://supertonic.supertone.ai/voice_builder)** is now live! Turn your voice into a deployable, edge-native TTS with permanent ownership. + +

+ Voice Builder +

+ +- **2026.01.06** - ๐ŸŽ‰ **Supertonic 2** released with multilingual support! Now supports English (`en`), Korean (`ko`), Spanish (`es`), Portuguese (`pt`), and French (`fr`). [Demo](https://huggingface.co/spaces/Supertone/supertonic-2) | [Models](https://huggingface.co/Supertone/supertonic-2) +- **2025.12.10** - Added `supertonic` PyPI package! Install via `pip install supertonic`. For details, visit [supertonic-py documentation](https://supertone-inc.github.io/supertonic-py) +- **2025.12.10** - Added [6 new voice styles](https://huggingface.co/Supertone/supertonic/tree/b10dbaf18b316159be75b34d24f740008fddd381) (M3, M4, M5, F3, F4, F5). See [Voices](https://supertone-inc.github.io/supertonic-py/voices/) for details +- **2025.12.08** - Optimized ONNX models via [OnnxSlim](https://github.com/inisis/OnnxSlim) now available on [Hugging Face Models](https://huggingface.co/Supertone/supertonic) +- **2025.11.24** - Added Flutter SDK support with macOS compatibility + +### Table of Contents + +- [Demo](#demo) +- [Why Supertonic?](#why-supertonic) +- [Language Support](#language-support) +- [Getting Started](#getting-started) +- [Performance](#performance) +- [Built with Supertonic](#built-with-supertonic) +- [Citation](#citation) +- [License](#license) + +## Demo + +### Raspberry Pi + +Watch Supertonic running on a **Raspberry Pi**, demonstrating on-device, real-time text-to-speech synthesis: + +https://github.com/user-attachments/assets/ea66f6d6-7bc5-4308-8a88-1ce3e07400d2 + +### E-Reader + +Experience Supertonic on an **Onyx Boox Go 6** e-reader in airplane mode, achieving an average RTF of 0.3ร— with zero network dependency: + +https://github.com/user-attachments/assets/64980e58-ad91-423a-9623-78c2ffc13680 + +### Chrome Extension + +Turns any webpage into audio in under one second, delivering lightning-fast, on-device text-to-speech with zero network dependencyโ€”free, private, and effortless: + 
+https://github.com/user-attachments/assets/cc8a45fc-5c3e-4b2c-8439-a14c3d00d91c + +--- + +> ๐ŸŽง **Try it now**: Experience Supertonic in your browser with our [**Interactive Demo**](https://huggingface.co/spaces/Supertone/supertonic-2), or get started with pre-trained models from [**Hugging Face Hub**](https://huggingface.co/Supertone/supertonic-2) + +## Why Supertonic? + +- **โšก Blazingly Fast**: Generates speech up to **167ร— faster than real-time** on consumer hardware (M4 Pro)โ€”unmatched by any other TTS system +- **๐Ÿชถ Ultra Lightweight**: Only **66M parameters**, optimized for efficient on-device performance with minimal footprint +- **๐Ÿ“ฑ On-Device Capable**: **Complete privacy** and **zero latency**โ€”all processing happens locally on your device +- **๐ŸŽจ Natural Text Handling**: Seamlessly processes numbers, dates, currency, abbreviations, and complex expressions without pre-processing +- **โš™๏ธ Highly Configurable**: Adjust inference steps, batch processing, and other parameters to match your specific needs +- **๐Ÿงฉ Flexible Deployment**: Deploy seamlessly across servers, browsers, and edge devices with multiple runtime backends. 
+ +## Language Support + +We provide ready-to-use TTS inference examples across multiple ecosystems: + +| Language/Platform | Path | Description | +|-------------------|------|-------------| +| [**Python**](py/) | `py/` | ONNX Runtime inference | +| [**Node.js**](nodejs/) | `nodejs/` | Server-side JavaScript | +| [**Browser**](web/) | `web/` | WebGPU/WASM inference | +| [**Java**](java/) | `java/` | Cross-platform JVM | +| [**C++**](cpp/) | `cpp/` | High-performance C++ | +| [**C#**](csharp/) | `csharp/` | .NET ecosystem | +| [**Go**](go/) | `go/` | Go implementation | +| [**Swift**](swift/) | `swift/` | macOS applications | +| [**iOS**](ios/) | `ios/` | Native iOS apps | +| [**Rust**](rust/) | `rust/` | Memory-safe systems | +| [**Flutter**](flutter/) | `flutter/` | Cross-platform apps | + +> For detailed usage instructions, please refer to the README.md in each language directory. + +## Getting Started + +First, clone the repository: + +```bash +git clone https://github.com/supertone-inc/supertonic.git +cd supertonic +``` + +### Prerequisites + +Before running the examples, download the ONNX models and preset voices, and place them in the `assets` directory: + +> **Note:** The Hugging Face repository uses Git LFS. Please ensure Git LFS is installed and initialized before cloning or pulling large model files. +> - macOS: `brew install git-lfs && git lfs install` +> - Generic: see `https://git-lfs.com` for installers + +```bash +git clone https://huggingface.co/Supertone/supertonic-2 assets +``` + +### Quick Start + +**Python Example** ([Details](py/)) +```bash +cd py +uv sync +uv run example_onnx.py +``` + +**Node.js Example** ([Details](nodejs/)) +```bash +cd nodejs +npm install +npm start +``` + +**Browser Example** ([Details](web/)) +```bash +cd web +npm install +npm run dev +``` + +**Java Example** ([Details](java/)) +```bash +cd java +mvn clean install +mvn exec:java +``` + +**C++ Example** ([Details](cpp/)) +```bash +cd cpp +mkdir build && cd build +cmake .. 
&& cmake --build . --config Release +./example_onnx +``` + +**C# Example** ([Details](csharp/)) +```bash +cd csharp +dotnet restore +dotnet run +``` + +**Go Example** ([Details](go/)) +```bash +cd go +go mod download +go run example_onnx.go helper.go +``` + +**Swift Example** ([Details](swift/)) +```bash +cd swift +swift build -c release +.build/release/example_onnx +``` + +**Rust Example** ([Details](rust/)) +```bash +cd rust +cargo build --release +./target/release/example_onnx +``` + +**iOS Example** ([Details](ios/)) +```bash +cd ios/ExampleiOSApp +xcodegen generate +open ExampleiOSApp.xcodeproj +``` +- In Xcode: Targets โ†’ ExampleiOSApp โ†’ Signing: select your Team +- Choose your iPhone as run destination โ†’ Build & Run + + +### Technical Details + +- **Runtime**: ONNX Runtime for cross-platform inference (CPU-optimized; GPU mode is not tested) +- **Browser Support**: onnxruntime-web for client-side inference +- **Batch Processing**: Supports batch inference for improved throughput +- **Audio Output**: Outputs 16-bit WAV files + +## Performance + +We evaluated Supertonic's performance (with 2 inference steps) using two key metrics across input texts of varying lengths: Short (59 chars), Mid (152 chars), and Long (266 chars). + +**Metrics:** +- **Characters per Second**: Measures throughput by dividing the number of input characters by the time required to generate audio. Higher is better. +- **Real-time Factor (RTF)**: Measures the time taken to synthesize audio relative to its duration. Lower is better (e.g., RTF of 0.1 means it takes 0.1 seconds to generate one second of audio). 
+ +### Characters per Second +| System | Short (59 chars) | Mid (152 chars) | Long (266 chars) | +|--------|-----------------|----------------|-----------------| +| **Supertonic** (M4 pro - CPU) | 912 | 1048 | 1263 | +| **Supertonic** (M4 pro - WebGPU) | 996 | 1801 | 2509 | +| **Supertonic** (RTX4090) | 2615 | 6548 | 12164 | +| `API` [ElevenLabs Flash v2.5](https://elevenlabs.io/docs/api-reference/text-to-speech/convert) | 144 | 209 | 287 | +| `API` [OpenAI TTS-1](https://platform.openai.com/docs/guides/text-to-speech) | 37 | 55 | 82 | +| `API` [Gemini 2.5 Flash TTS](https://ai.google.dev/gemini-api/docs/speech-generation) | 12 | 18 | 24 | +| `API` [Supertone Sona speech 1](https://docs.supertoneapi.com/en/api-reference/endpoints/text-to-speech) | 38 | 64 | 92 | +| `Open` [Kokoro](https://github.com/hexgrad/kokoro/) | 104 | 107 | 117 | +| `Open` [NeuTTS Air](https://github.com/neuphonic/neutts-air) | 37 | 42 | 47 | + +> **Notes:** +> `API` = Cloud-based API services (measured from Seoul) +> `Open` = Open-source models +> Supertonic (M4 pro - CPU) and (M4 pro - WebGPU): Tested with ONNX +> Supertonic (RTX4090): Tested with PyTorch model +> Kokoro: Tested on M4 Pro CPU with ONNX +> NeuTTS Air: Tested on M4 Pro CPU with Q8-GGUF + +### Real-time Factor + +| System | Short (59 chars) | Mid (152 chars) | Long (266 chars) | +|--------|-----------------|----------------|-----------------| +| **Supertonic** (M4 pro - CPU) | 0.015 | 0.013 | 0.012 | +| **Supertonic** (M4 pro - WebGPU) | 0.014 | 0.007 | 0.006 | +| **Supertonic** (RTX4090) | 0.005 | 0.002 | 0.001 | +| `API` [ElevenLabs Flash v2.5](https://elevenlabs.io/docs/api-reference/text-to-speech/convert) | 0.133 | 0.077 | 0.057 | +| `API` [OpenAI TTS-1](https://platform.openai.com/docs/guides/text-to-speech) | 0.471 | 0.302 | 0.201 | +| `API` [Gemini 2.5 Flash TTS](https://ai.google.dev/gemini-api/docs/speech-generation) | 1.060 | 0.673 | 0.541 | +| `API` [Supertone Sona speech 
1](https://docs.supertoneapi.com/en/api-reference/endpoints/text-to-speech) | 0.372 | 0.206 | 0.163 | +| `Open` [Kokoro](https://github.com/hexgrad/kokoro/) | 0.144 | 0.124 | 0.126 | +| `Open` [NeuTTS Air](https://github.com/neuphonic/neutts-air) | 0.390 | 0.338 | 0.343 | + +
+Additional Performance Data (5-step inference) + +
+ +**Characters per Second (5-step)** + +| System | Short (59 chars) | Mid (152 chars) | Long (266 chars) | +|--------|-----------------|----------------|-----------------| +| **Supertonic** (M4 pro - CPU) | 596 | 691 | 850 | +| **Supertonic** (M4 pro - WebGPU) | 570 | 1118 | 1546 | +| **Supertonic** (RTX4090) | 1286 | 3757 | 6242 | + +**Real-time Factor (5-step)** + +| System | Short (59 chars) | Mid (152 chars) | Long (266 chars) | +|--------|-----------------|----------------|-----------------| +| **Supertonic** (M4 pro - CPU) | 0.023 | 0.019 | 0.018 | +| **Supertonic** (M4 pro - WebGPU) | 0.024 | 0.012 | 0.010 | +| **Supertonic** (RTX4090) | 0.011 | 0.004 | 0.002 | + +
+ +### Natural Text Handling + +Supertonic is designed to handle complex, real-world text inputs that contain numbers, currency symbols, abbreviations, dates, and proper nouns. + +> ๐ŸŽง **View audio samples more easily**: Check out our [**Interactive Demo**](https://huggingface.co/spaces/Supertone/supertonic#text-handling) for a better viewing experience of all audio examples + +**Overview of Test Cases:** + +| Category | Key Challenges | Supertonic | ElevenLabs | OpenAI | Gemini | Microsoft | +|:--------:|:--------------:|:----------:|:----------:|:------:|:------:|:---------:| +| Financial Expression | Decimal currency, abbreviated magnitudes (M, K), currency symbols, currency codes | โœ… | โŒ | โŒ | โŒ | โŒ | +| Time and Date | Time notation, abbreviated weekdays/months, date formats | โœ… | โŒ | โŒ | โŒ | โŒ | +| Phone Number | Area codes, hyphens, extensions (ext.) | โœ… | โŒ | โŒ | โŒ | โŒ | +| Technical Unit | Decimal numbers with units, abbreviated technical notations | โœ… | โŒ | โŒ | โŒ | โŒ | + +
+Example 1: Financial Expression + +
+ +**Text:** +> "The startup secured **$5.2M** in venture capital, a huge leap from their initial **$450K** seed round." + +**Challenges:** +- Decimal point in currency ($5.2M should be read as "five point two million") +- Abbreviated magnitude units (M for million, K for thousand) +- Currency symbol ($) that needs to be properly pronounced as "dollars" + +**Audio Samples:** + +| System | Result | Audio Sample | +|--------|--------|--------------| +| **Supertonic** | โœ… | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1eancUOhiSXCVoTu9ddh4S-OcVQaWrPV-/view?usp=sharing) | +| ElevenLabs Flash v2.5 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1-r2scv7XQ1crIDu6QOh3eqVl445W6ap_/view?usp=sharing) | +| OpenAI TTS-1 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1MFDXMjfmsAVOqwPx7iveS0KUJtZvcwxB/view?usp=sharing) | +| Gemini 2.5 Flash TTS | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1dEHpNzfMUucFTJPQK0k4RcFZvPwQTt09/view?usp=sharing) | +| VibeVoice Realtime 0.5B | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1b69XWBQnSZZ0WZeR3avv7E8mSdoN6p6P/view?usp=sharing) | + +
+ +
+Example 2: Time and Date + +
+ +**Text:** +> "The train delay was announced at **4:45 PM** on **Wed, Apr 3, 2024** due to track maintenance." + +**Challenges:** +- Time expression with PM notation (4:45 PM) +- Abbreviated weekday (Wed) +- Abbreviated month (Apr) +- Full date format (Apr 3, 2024) + +**Audio Samples:** + +| System | Result | Audio Sample | +|--------|--------|--------------| +| **Supertonic** | โœ… | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1ehkZU8eiizBenG2DgR5tzBGQBvHS0Uaj/view?usp=sharing) | +| ElevenLabs Flash v2.5 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1ta3r6jFyebmA-sT44l8EaEQcMLVmuOEr/view?usp=sharing) | +| OpenAI TTS-1 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1sskmem9AzHAQ3Hv8DRSZoqX_pye-CXuU/view?usp=sharing) | +| Gemini 2.5 Flash TTS | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1zx9X8oMsLMXW0Zx_SURoqjju-By2yh_n/view?usp=sharing) | +| VibeVoice Realtime 0.5B | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1ZpGEstZr4hA0EdAWBMCUFFWuAkIpYsVh/view?usp=sharing) | + +
+ +
+Example 3: Phone Number + +
+ +**Text:** +> "You can reach the hotel front desk at **(212) 555-0142 ext. 402** anytime." + +**Challenges:** +- Area code in parentheses that should be read as separate digits +- Phone number with hyphen separator (555-0142) +- Abbreviated extension notation (ext.) +- Extension number (402) + +**Audio Samples:** + +| System | Result | Audio Sample | +|--------|--------|--------------| +| **Supertonic** | โœ… | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1z-e5iTsihryMR8ll1-N1YXkB2CIJYJ6F/view?usp=sharing) | +| ElevenLabs Flash v2.5 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1HAzVXFTZfZm0VEK2laSpsMTxzufcuaxA/view?usp=sharing) | +| OpenAI TTS-1 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/15tjfAmb3GbjP_kmvD7zSdIWkhtAaCPOg/view?usp=sharing) | +| Gemini 2.5 Flash TTS | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1BCL8n7yligUZyso970ud7Gf5NWb1OhKD/view?usp=sharing) | +| VibeVoice Realtime 0.5B | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1c0c0YM_Qm7XxSk2uSVYLbITgEDTqaVzL/view?usp=sharing) | + +
+ +
+Example 4: Technical Unit + +
+ +**Text:** +> "Our drone battery lasts **2.3h** when flying at **30kph** with full camera payload." + +**Challenges:** +- Decimal time duration with abbreviation (2.3h = two point three hours) +- Speed unit with abbreviation (30kph = thirty kilometers per hour) +- Technical abbreviations (h for hours, kph for kilometers per hour) +- Technical/engineering context requiring proper pronunciation + +**Audio Samples:** + +| System | Result | Audio Sample | +|--------|--------|--------------| +| **Supertonic** | โœ… | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1kvOBvswFkLfmr8hGplH0V2XiMxy1shYf/view?usp=sharing) | +| ElevenLabs Flash v2.5 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1_SzfjWJe5YEd0t3R7DztkYhHcI_av48p/view?usp=sharing) | +| OpenAI TTS-1 | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1P5BSilj5xFPTV2Xz6yW5jitKZohO9o-6/view?usp=sharing) | +| Gemini 2.5 Flash TTS | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1GU82SnWC50OvC8CZNjhxvNZFKQb7I9_Y/view?usp=sharing) | +| VibeVoice Realtime 0.5B | โŒ | [๐ŸŽง Play Audio](https://drive.google.com/file/d/1lUTrxrAQy_viEK2Hlu3KLLtTCe8jvbdV/view?usp=sharing) | + +
+ +> **Note:** These samples demonstrate how each system handles text normalization and pronunciation of complex expressions **without requiring pre-processing or phonetic annotations**. + +## Built with Supertonic + +| Project | Description | Links | +|---------|-------------|-------| +| **TLDRL** | Free, on-device TTS extension for reading any webpage | [Chrome](https://chromewebstore.google.com/detail/tldrl-lightning-tts-power/mdbiaajonlkomihpcaffhkagodbcgbme) | +| **Read Aloud** | Open-source TTS browser extension | [Chrome](https://chromewebstore.google.com/detail/read-aloud-a-text-to-spee/hdhinadidafjejdhmfkjgnolgimiaplp) ยท [Edge](https://microsoftedge.microsoft.com/addons/detail/read-aloud-a-text-to-spe/pnfonnnmfjnpfgagnklfaccicnnjcdkm) ยท [GitHub](https://github.com/ken107/read-aloud) | +| **PageEcho** | E-Book reader app for iOS | [App Store](https://apps.apple.com/us/app/pageecho/id6755965837) | +| **VoiceChat** | On-device voice-to-voice LLM chatbot in the browser | [Demo](https://huggingface.co/spaces/RickRossTN/ai-voice-chat) ยท [GitHub](https://github.com/irelate-ai/voice-chat) | +| **OmniAvatar** | Talking avatar video generator from photo + speech | [Demo](https://huggingface.co/spaces/alexnasa/OmniAvatar) | +| **CopiloTTS** | Kotlin Multiplatform TTS SDK via ONNX Runtime | [GitHub](https://github.com/sigmadeltasoftware/CopiloTTS) | +| **Voice Mixer** | PyQt5 tool for mixing and modifying voice styles | [GitHub](https://github.com/Topping1/Supertonic-Voice-Mixer) | +| **Supertonic MNN** | Lightweight library based on MNN (fp32/fp16/int8) | [GitHub](https://github.com/vra/supertonic-mnn) ยท [PyPI](https://pypi.org/project/supertonic-mnn/) | +| **Transformers.js** | Hugging Face's JS library with Supertonic support | [GitHub PR](https://github.com/huggingface/transformers.js/pull/1459) ยท [Demo](https://huggingface.co/spaces/webml-community/Supertonic-TTS-WebGPU) | +| **Pinokio** | 1-click localhost cloud for Mac, Windows, and Linux | 
[Pinokio](https://pinokio.co/) ยท [GitHub](https://github.com/SUP3RMASS1VE/SuperTonic-TTS) | + +## Citation + +The following papers describe the core technologies used in Supertonic. If you use this system in your research or find these techniques useful, please consider citing the relevant papers: + +### SupertonicTTS: Main Architecture + +This paper introduces the overall architecture of SupertonicTTS, including the speech autoencoder, flow-matching based text-to-latent module, and efficient design choices. + +```bibtex +@article{kim2025supertonic, + title={SupertonicTTS: Towards Highly Efficient and Streamlined Text-to-Speech System}, + author={Kim, Hyeongju and Yang, Jinhyeok and Yu, Yechan and Ji, Seunghun and Morton, Jacob and Bous, Frederik and Byun, Joon and Lee, Juheon}, + journal={arXiv preprint arXiv:2503.23108}, + year={2025}, + url={https://arxiv.org/abs/2503.23108} +} +``` + +### Length-Aware RoPE: Text-Speech Alignment + +This paper presents Length-Aware Rotary Position Embedding (LARoPE), which improves text-speech alignment in cross-attention mechanisms. + +```bibtex +@article{kim2025larope, + title={Length-Aware Rotary Position Embedding for Text-Speech Alignment}, + author={Kim, Hyeongju and Lee, Juheon and Yang, Jinhyeok and Morton, Jacob}, + journal={arXiv preprint arXiv:2509.11084}, + year={2025}, + url={https://arxiv.org/abs/2509.11084} +} +``` + +### Self-Purifying Flow Matching: Training with Noisy Labels + +This paper describes the self-purification technique for training flow matching models robustly with noisy or unreliable labels. + +```bibtex +@article{kim2025spfm, + title={Training Flow Matching Models with Reliable Labels via Self-Purification}, + author={Kim, Hyeongju and Yu, Yechan and Yi, June Young and Lee, Juheon}, + journal={arXiv preprint arXiv:2509.19091}, + year={2025}, + url={https://arxiv.org/abs/2509.19091} +} +``` + +## License + +This project's sample code is released under the MIT License. 
- see the [LICENSE](https://github.com/supertone-inc/supertonic?tab=MIT-1-ov-file) for details. + +The accompanying model is released under the OpenRAIL-M License. - see the [LICENSE](https://huggingface.co/Supertone/supertonic-2/blob/main/LICENSE) file for details. + +This model was trained using PyTorch, which is licensed under the BSD 3-Clause License but is not redistributed with this project. - see the [LICENSE](https://docs.pytorch.org/FBGEMM/general/License.html) for details. + +Copyright (c) 2026 Supertone Inc. + diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt new file mode 100644 index 0000000..f17563c --- /dev/null +++ b/cpp/CMakeLists.txt @@ -0,0 +1,122 @@ +cmake_minimum_required(VERSION 3.15) +project(Supertonic_CPP) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Enable aggressive optimization +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# Add optimization flags +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -DNDEBUG -ffast-math") +set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -O3 -DNDEBUG -ffast-math") + +# Find required packages +find_package(PkgConfig REQUIRED) +find_package(OpenMP) + +# ONNX Runtime - Try multiple methods +# Method 1: Try to find via CMake config +find_package(onnxruntime QUIET CONFIG) + +if(NOT onnxruntime_FOUND) + # Method 2: Try pkg-config + pkg_check_modules(ONNXRUNTIME QUIET libonnxruntime) + + if(ONNXRUNTIME_FOUND) + set(ONNXRUNTIME_INCLUDE_DIR ${ONNXRUNTIME_INCLUDE_DIRS}) + set(ONNXRUNTIME_LIB ${ONNXRUNTIME_LIBRARIES}) + else() + # Method 3: Manual search in common locations + find_path(ONNXRUNTIME_INCLUDE_DIR + NAMES onnxruntime_cxx_api.h + PATHS + /usr/local/include + /opt/homebrew/include + /usr/include + ${CMAKE_PREFIX_PATH}/include + PATH_SUFFIXES onnxruntime + ) + + find_library(ONNXRUNTIME_LIB + NAMES onnxruntime libonnxruntime + PATHS + /usr/local/lib + /opt/homebrew/lib + /usr/lib + ${CMAKE_PREFIX_PATH}/lib + ) + endif() + + if(NOT 
ONNXRUNTIME_INCLUDE_DIR OR NOT ONNXRUNTIME_LIB) + message(FATAL_ERROR "ONNX Runtime not found. Please install it:\n" + " macOS: brew install onnxruntime\n" + " Ubuntu: See README.md for installation instructions") + endif() + + message(STATUS "Found ONNX Runtime:") + message(STATUS " Include: ${ONNXRUNTIME_INCLUDE_DIR}") + message(STATUS " Library: ${ONNXRUNTIME_LIB}") +endif() + +# nlohmann/json +find_package(nlohmann_json REQUIRED) + +# Include directories +if(NOT onnxruntime_FOUND) + include_directories(${ONNXRUNTIME_INCLUDE_DIR}) +endif() + +# Helper library +add_library(tts_helper STATIC + helper.cpp + helper.h +) + +if(onnxruntime_FOUND) + target_link_libraries(tts_helper + onnxruntime::onnxruntime + nlohmann_json::nlohmann_json + ) +else() + target_include_directories(tts_helper PUBLIC ${ONNXRUNTIME_INCLUDE_DIR}) + target_link_libraries(tts_helper + ${ONNXRUNTIME_LIB} + nlohmann_json::nlohmann_json + ) +endif() + +# Enable OpenMP if available +if(OpenMP_CXX_FOUND) + target_link_libraries(tts_helper OpenMP::OpenMP_CXX) + message(STATUS "OpenMP enabled for parallel processing") +else() + message(WARNING "OpenMP not found - parallel processing will be disabled") +endif() + +# Example executable +add_executable(example_onnx + example_onnx.cpp +) + +if(onnxruntime_FOUND) + target_link_libraries(example_onnx + tts_helper + onnxruntime::onnxruntime + nlohmann_json::nlohmann_json + ) +else() + target_link_libraries(example_onnx + tts_helper + ${ONNXRUNTIME_LIB} + nlohmann_json::nlohmann_json + ) +endif() + +# Installation +install(TARGETS example_onnx DESTINATION bin) +install(TARGETS tts_helper DESTINATION lib) +install(FILES helper.h DESTINATION include) + diff --git a/cpp/README.md b/cpp/README.md new file mode 100644 index 0000000..7ab8d93 --- /dev/null +++ b/cpp/README.md @@ -0,0 +1,139 @@ +# Supertonic C++ Implementation + +High-performance text-to-speech inference using ONNX Runtime. 
+ +## ๐Ÿ“ฐ Update News + +**2026.01.06** - ๐ŸŽ‰ **Supertonic 2** released with multilingual support! Now supports English (`en`), Korean (`ko`), Spanish (`es`), Portuguese (`pt`), and French (`fr`). [Demo](https://huggingface.co/spaces/Supertone/supertonic-2) | [Models](https://huggingface.co/Supertone/supertonic-2) + +**2025.12.10** - Added [6 new voice styles](https://huggingface.co/Supertone/supertonic/tree/b10dbaf18b316159be75b34d24f740008fddd381) (M3, M4, M5, F3, F4, F5). See [Voices](https://supertone-inc.github.io/supertonic-py/voices/) for details + +**2025.12.08** - Optimized ONNX models via [OnnxSlim](https://github.com/inisis/OnnxSlim) now available on [Hugging Face Models](https://huggingface.co/Supertone/supertonic) + +**2025.11.23** - Enhanced text preprocessing with comprehensive normalization, emoji removal, symbol replacement, and punctuation handling for improved synthesis quality. + +**2025.11.19** - Added `--speed` parameter to control speech synthesis speed (default: 1.05, recommended range: 0.9-1.5). + +**2025.11.19** - Added automatic text chunking for long-form inference. Long texts are split into chunks and synthesized with natural pauses. + +## Requirements + +- C++17 compiler, CMake 3.15+ +- Libraries: ONNX Runtime, nlohmann/json + +## Installation + +**Ubuntu/Debian:** +> โš ๏ธ **Note:** Installation instructions not yet verified. + +```bash +sudo apt-get install -y cmake g++ nlohmann-json3-dev +wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.3/onnxruntime-linux-x64-1.16.3.tgz +tar -xzf onnxruntime-linux-x64-1.16.3.tgz +sudo cp -r onnxruntime-linux-x64-1.16.3/include/* /usr/local/include/ +sudo cp -r onnxruntime-linux-x64-1.16.3/lib/* /usr/local/lib/ +sudo ldconfig +``` + +**macOS:** +```bash +brew install cmake nlohmann-json onnxruntime +``` + +**Windows (vcpkg):** +> โš ๏ธ **Note:** Installation instructions not yet verified. 
+ +```powershell +vcpkg install nlohmann-json:x64-windows onnxruntime:x64-windows +vcpkg integrate install +``` + +## Building + +```bash +cd cpp && mkdir build && cd build +cmake .. && cmake --build . --config Release +./example_onnx +``` + +## Basic Usage + +### Example 1: Default Inference +Run inference with default settings: +```bash +./example_onnx +``` + +This will use: +- Voice style: `../assets/voice_styles/M1.json` +- Text: "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen." +- Output directory: `results/` +- Total steps: 5 +- Number of generations: 4 + +### Example 2: Batch Inference +Process multiple voice styles and texts at once: +```bash +./example_onnx \ + --voice-style ../assets/voice_styles/M1.json,../assets/voice_styles/F1.json \ + --text "The sun sets behind the mountains, painting the sky in shades of pink and orange.|์˜ค๋Š˜ ์•„์นจ์— ๊ณต์›์„ ์‚ฐ์ฑ…ํ–ˆ๋Š”๋ฐ, ์ƒˆ์†Œ๋ฆฌ์™€ ๋ฐ”๋žŒ ์†Œ๋ฆฌ๊ฐ€ ๋„ˆ๋ฌด ์ข‹์•„์„œ ํ•œ์ฐธ์„ ๋ฉˆ์ถฐ ์„œ์„œ ๋“ค์—ˆ์–ด์š”." \ + --lang en,ko \ + --batch +``` + +This will: +- Use `--batch` flag to enable batch processing mode +- Generate speech for 2 different voice-text pairs +- Use male voice style (M1.json) for the first English text +- Use female voice style (F1.json) for the second Korean text +- Process both samples in a single batch (automatic text chunking disabled) + +### Example 3: High Quality Inference +Increase denoising steps for better quality: +```bash +./example_onnx \ + --total-step 10 \ + --voice-style ../assets/voice_styles/M1.json \ + --text "Increasing the number of denoising steps improves the output's fidelity and overall quality." 
+``` + +This will: +- Use 10 denoising steps instead of the default 5 +- Produce higher quality output at the cost of slower inference + +### Example 4: Long-Form Inference +For long texts, the system automatically chunks the text into manageable segments and generates a single audio file: +```bash +./example_onnx \ + --voice-style ../assets/voice_styles/M1.json \ + --text "Once upon a time, in a small village nestled between rolling hills, there lived a young artist named Clara. Every morning, she would wake up before dawn to capture the first light of day. The golden rays streaming through her window inspired countless paintings. Her work was known throughout the region for its vibrant colors and emotional depth. People from far and wide came to see her gallery, and many said her paintings could tell stories that words never could." +``` + +This will: +- Automatically split the long text into smaller chunks (max 300 characters by default) +- Process each chunk separately while maintaining natural speech flow +- Insert brief silences (0.3 seconds) between chunks for natural pacing +- Combine all chunks into a single output audio file + +**Note**: When using batch mode (`--batch`), automatic text chunking is disabled. Use non-batch mode for long-form text synthesis. 
+ +## Available Arguments + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--onnx-dir` | str | `../assets/onnx` | Path to ONNX model directory | +| `--total-step` | int | 5 | Number of denoising steps (higher = better quality, slower) | +| `--speed` | float | 1.05 | Speech speed factor (higher = faster, lower = slower) | +| `--n-test` | int | 4 | Number of times to generate each sample | +| `--voice-style` | str | `../assets/voice_styles/M1.json` | Voice style file path(s) (comma-separated for batch) | +| `--text` | str | (long default text) | Text(s) to synthesize (pipe-separated for batch) | +| `--lang` | str | `en` | Language(s) for text(s): `en`, `ko`, `es`, `pt`, `fr` (comma-separated for batch) | +| `--save-dir` | str | `results` | Output directory | +| `--batch` | flag | False | Enable batch mode (disables automatic text chunking) | + +## Notes + +- **Batch Processing**: The number of `--voice-style` files must match the number of `--text` entries +- **Multilingual Support**: Use `--lang` to specify language(s). 
Available: `en` (English), `ko` (Korean), `es` (Spanish), `pt` (Portuguese), `fr` (French) +- **Long-Form Inference**: Without `--batch` flag, long texts are automatically chunked and combined into a single audio file with natural pauses +- **Quality vs Speed**: Higher `--total-step` values produce better quality but take longer diff --git a/cpp/example_onnx.cpp b/cpp/example_onnx.cpp new file mode 100644 index 0000000..9d7d68b --- /dev/null +++ b/cpp/example_onnx.cpp @@ -0,0 +1,121 @@ +#include "helper.h" +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +struct Args { + std::string onnx_dir = "../assets/onnx"; + int total_step = 5; + float speed = 1.05f; + int n_test = 4; + std::vector voice_style = {"../assets/voice_styles/M1.json"}; + std::vector text = { + "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen." + }; + std::vector lang = {"en"}; + std::string save_dir = "results"; + bool batch = false; +}; + +auto splitString = [](const std::string& str, char delim) { + std::vector result; + size_t start = 0, pos; + while ((pos = str.find(delim, start)) != std::string::npos) { + result.push_back(str.substr(start, pos - start)); + start = pos + 1; + } + result.push_back(str.substr(start)); + return result; +}; + +Args parseArgs(int argc, char* argv[]) { + Args args; + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "--onnx-dir" && i + 1 < argc) args.onnx_dir = argv[++i]; + else if (arg == "--total-step" && i + 1 < argc) args.total_step = std::stoi(argv[++i]); + else if (arg == "--speed" && i + 1 < argc) args.speed = std::stof(argv[++i]); + else if (arg == "--n-test" && i + 1 < argc) args.n_test = std::stoi(argv[++i]); + else if (arg == "--voice-style" && i + 1 < argc) args.voice_style = splitString(argv[++i], ','); + else if (arg == "--text" && i + 1 < argc) args.text = splitString(argv[++i], '|'); + else if 
(arg == "--lang" && i + 1 < argc) args.lang = splitString(argv[++i], ','); + else if (arg == "--save-dir" && i + 1 < argc) args.save_dir = argv[++i]; + else if (arg == "--batch") args.batch = true; + } + return args; +} + +int main(int argc, char* argv[]) { + std::cout << "=== TTS Inference with ONNX Runtime (C++) ===\n\n"; + + // --- 1. Parse arguments --- // + Args args = parseArgs(argc, argv); + int total_step = args.total_step; + float speed = args.speed; + int n_test = args.n_test; + std::string save_dir = args.save_dir; + std::vector voice_style_paths = args.voice_style; + std::vector text_list = args.text; + std::vector lang_list = args.lang; + bool batch = args.batch; + + if (voice_style_paths.size() != text_list.size()) { + std::cerr << "Error: Number of voice styles (" << voice_style_paths.size() + << ") must match number of texts (" << text_list.size() << ")\n"; + return 1; + } + int bsz = voice_style_paths.size(); + + // --- 2. Load Text to Speech --- // + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "TTS"); + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu( + OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault + ); + + auto text_to_speech = loadTextToSpeech(env, args.onnx_dir, false); + std::cout << std::endl; + + // --- 3. Load Voice Style --- // + auto style = loadVoiceStyle(voice_style_paths, true); + + // --- 4. 
Synthesize speech --- // + fs::create_directories(save_dir); + + for (int n = 0; n < n_test; n++) { + std::cout << "\n[" << (n + 1) << "/" << n_test << "] Starting synthesis...\n"; + + auto result = timer("Generating speech from text", [&]() { + if (batch) { + return text_to_speech->batch(memory_info, text_list, lang_list, style, total_step, speed); + } else { + return text_to_speech->call(memory_info, text_list[0], lang_list[0], style, total_step, speed); + } + }); + + int sample_rate = text_to_speech->getSampleRate(); + int wav_shape_1 = result.wav.size() / bsz; + + for (int b = 0; b < bsz; b++) { + std::string fname = sanitizeFilename(text_list[b], 20) + "_" + std::to_string(n + 1) + ".wav"; + int wav_len = static_cast(sample_rate * result.duration[b]); + + std::vector wav_out( + result.wav.begin() + b * wav_shape_1, + result.wav.begin() + b * wav_shape_1 + wav_len + ); + + std::string output_path = save_dir + "/" + fname; + writeWavFile(output_path, wav_out, sample_rate); + std::cout << "Saved: " << output_path << "\n"; + } + + clearTensorBuffers(); + } + + std::cout << "\n=== Synthesis completed successfully! 
===\n"; + return 0; +} diff --git a/cpp/helper.cpp b/cpp/helper.cpp new file mode 100644 index 0000000..111f2fb --- /dev/null +++ b/cpp/helper.cpp @@ -0,0 +1,1186 @@ +#include "helper.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::json; + +// Available languages for multilingual TTS +const std::vector AVAILABLE_LANGS = {"en", "ko", "es", "pt", "fr"}; + +// Global tensor buffers for memory management +static std::vector> g_tensor_buffers_float; +static std::vector> g_tensor_buffers_int64; + +void clearTensorBuffers() { + g_tensor_buffers_float.clear(); + g_tensor_buffers_int64.clear(); +} + +// ============================================================================ +// Helper function - trim +// ============================================================================ + +static std::string trim(const std::string& str) { + size_t start = 0; + while (start < str.size() && std::isspace(static_cast(str[start]))) { + start++; + } + + size_t end = str.size(); + while (end > start && std::isspace(static_cast(str[end - 1]))) { + end--; + } + + return str.substr(start, end - start); +} + +// ============================================================================ +// UnicodeProcessor implementation +// ============================================================================ + +UnicodeProcessor::UnicodeProcessor(const std::string& unicode_indexer_json_path) { + indexer_ = loadJsonInt64(unicode_indexer_json_path); +} + +std::string UnicodeProcessor::preprocessText(const std::string& text, const std::string& lang) { + // TODO: Need advanced normalizer for better performance + // NOTE: C++ doesn't have built-in Unicode normalization like Python's NFKD + // For full Unicode normalization, consider using ICU library + // This implementation handles basic text preprocessing + + std::string result = text; + + // IMPORTANT: Do symbol replacements FIRST (before emoji removal) + // to preserve curly 
quotes and other punctuation that might be matched by emoji patterns + + // Replace various dashes and symbols + struct Replacement { + const char* from; + const char* to; + }; + + const Replacement replacements[] = { + {"โ€“", "-"}, // en dash + {"โ€‘", "-"}, // non-breaking hyphen + {"โ€”", "-"}, // em dash + {"_", " "}, // underscore + { u8"\u201C", "\"" }, // left double quote " + { u8"\u201D", "\"" }, // right double quote " + { u8"\u2018", "'" }, // left single quote ' + { u8"\u2019", "'" }, // right single quote ' + {"ยด", "'"}, // acute accent + {"`", "'"}, // grave accent + {"[", " "}, // left bracket + {"]", " "}, // right bracket + {"|", " "}, // vertical bar + {"/", " "}, // slash + {"#", " "}, // hash + {"โ†’", " "}, // right arrow + {"โ†", " "}, // left arrow + }; + + for (const auto& repl : replacements) { + size_t pos = 0; + while ((pos = result.find(repl.from, pos)) != std::string::npos) { + result.replace(pos, strlen(repl.from), repl.to); + pos += strlen(repl.to); + } + } + + // Remove emojis AFTER symbol replacements + // Only target actual emoji ranges (4-byte UTF-8 sequences starting with F0 9F) + std::regex emoji_pattern( + "[\xF0][\x9F][\x80-\xBF][\x80-\xBF]" // 4-byte emoji (U+1F000-U+1FFFF) + ); + result = std::regex_replace(result, emoji_pattern, ""); + + // Remove special symbols + const char* special_symbols[] = {"โ™ฅ", "โ˜†", "โ™ก", "ยฉ", "\\"}; + for (const char* symbol : special_symbols) { + size_t pos = 0; + while ((pos = result.find(symbol, pos)) != std::string::npos) { + result.erase(pos, strlen(symbol)); + } + } + + // Replace known expressions + const Replacement expr_replacements[] = { + {"@", " at "}, + {"e.g.,", "for example, "}, + {"i.e.,", "that is, "}, + }; + + for (const auto& repl : expr_replacements) { + size_t pos = 0; + while ((pos = result.find(repl.from, pos)) != std::string::npos) { + result.replace(pos, strlen(repl.from), repl.to); + pos += strlen(repl.to); + } + } + + // Fix spacing around punctuation + result = 
std::regex_replace(result, std::regex(" ,"), ","); + result = std::regex_replace(result, std::regex(" \\."), "."); + result = std::regex_replace(result, std::regex(" !"), "!"); + result = std::regex_replace(result, std::regex(" \\?"), "?"); + result = std::regex_replace(result, std::regex(" ;"), ";"); + result = std::regex_replace(result, std::regex(" :"), ":"); + result = std::regex_replace(result, std::regex(" '"), "'"); + + // Remove duplicate quotes + while (result.find("\"\"") != std::string::npos) { + size_t pos = result.find("\"\""); + result.replace(pos, 2, "\""); + } + while (result.find("''") != std::string::npos) { + size_t pos = result.find("''"); + result.replace(pos, 2, "'"); + } + while (result.find("``") != std::string::npos) { + size_t pos = result.find("``"); + result.replace(pos, 2, "`"); + } + + // Remove extra spaces + result = std::regex_replace(result, std::regex("\\s+"), " "); + result = trim(result); + + // If text doesn't end with punctuation, quotes, or closing brackets, add a period + if (!result.empty()) { + char last_char = result.back(); + bool ends_with_punct = ( + last_char == '.' || last_char == '!' || last_char == '?' || + last_char == ';' || last_char == ':' || last_char == ',' || + last_char == '\'' || last_char == '"' || last_char == ')' || + last_char == ']' || last_char == '}' || last_char == '>' + ); + + // Check for UTF-8 multibyte ending punctuation (e.g., โ€ฆ, ใ€‚, curly quotes, etc.) 
+ if (!ends_with_punct && result.size() >= 3) { + std::string last_three = result.substr(result.size() - 3); + if (last_three == "โ€ฆ" || last_three == "ใ€‚" || + last_three == "ใ€" || last_three == "ใ€" || + last_three == "ใ€‘" || last_three == "ใ€‰" || + last_three == "ใ€‹" || last_three == "โ€บ" || + last_three == "ยป" || last_three == u8"\u201C" || + last_three == u8"\u201D" || last_three == u8"\u2018" || + last_three == u8"\u2019") { + ends_with_punct = true; + } + } + + if (!ends_with_punct) { + result += "."; + } + } + + // Validate language + bool valid_lang = false; + for (const auto& available_lang : AVAILABLE_LANGS) { + if (lang == available_lang) { + valid_lang = true; + break; + } + } + if (!valid_lang) { + throw std::runtime_error("Invalid language: " + lang + ". Available: en, ko, es, pt, fr"); + } + + // Wrap text with language tags + result = "<" + lang + ">" + result + ""; + + return result; +} + +// Hangul syllable decomposition constants (Unicode Standard Annex #15) +static const uint32_t HANGUL_SBASE = 0xAC00; // Start of Hangul syllables +static const uint32_t HANGUL_LBASE = 0x1100; // Start of Hangul Jamo (leading consonants) +static const uint32_t HANGUL_VBASE = 0x1161; // Start of Hangul Jamo (vowels) +static const uint32_t HANGUL_TBASE = 0x11A7; // Start of Hangul Jamo (trailing consonants) +static const int HANGUL_LCOUNT = 19; // Number of leading consonants +static const int HANGUL_VCOUNT = 21; // Number of vowels +static const int HANGUL_TCOUNT = 28; // Number of trailing consonants (including none) +static const int HANGUL_NCOUNT = HANGUL_VCOUNT * HANGUL_TCOUNT; // 588 +static const int HANGUL_SCOUNT = HANGUL_LCOUNT * HANGUL_NCOUNT; // 11172 + +// Latin character NFKD decompositions for Spanish, Portuguese, French +static const std::unordered_map> LATIN_DECOMPOSITIONS = { + // Acute accent + {0x00C1, {0x0041, 0x0301}}, // ร โ†’ A + ฬ + {0x00C9, {0x0045, 0x0301}}, // ร‰ โ†’ E + ฬ + {0x00CD, {0x0049, 0x0301}}, // ร โ†’ I + ฬ + 
{0x00D3, {0x004F, 0x0301}}, // ร“ โ†’ O + ฬ + {0x00DA, {0x0055, 0x0301}}, // รš โ†’ U + ฬ + {0x00E1, {0x0061, 0x0301}}, // รก โ†’ a + ฬ + {0x00E9, {0x0065, 0x0301}}, // รฉ โ†’ e + ฬ + {0x00ED, {0x0069, 0x0301}}, // รญ โ†’ i + ฬ + {0x00F3, {0x006F, 0x0301}}, // รณ โ†’ o + ฬ + {0x00FA, {0x0075, 0x0301}}, // รบ โ†’ u + ฬ + // Grave accent + {0x00C0, {0x0041, 0x0300}}, // ร€ โ†’ A + ฬ€ + {0x00C8, {0x0045, 0x0300}}, // รˆ โ†’ E + ฬ€ + {0x00CC, {0x0049, 0x0300}}, // รŒ โ†’ I + ฬ€ + {0x00D2, {0x004F, 0x0300}}, // ร’ โ†’ O + ฬ€ + {0x00D9, {0x0055, 0x0300}}, // ร™ โ†’ U + ฬ€ + {0x00E0, {0x0061, 0x0300}}, // ร  โ†’ a + ฬ€ + {0x00E8, {0x0065, 0x0300}}, // รจ โ†’ e + ฬ€ + {0x00EC, {0x0069, 0x0300}}, // รฌ โ†’ i + ฬ€ + {0x00F2, {0x006F, 0x0300}}, // รฒ โ†’ o + ฬ€ + {0x00F9, {0x0075, 0x0300}}, // รน โ†’ u + ฬ€ + // Circumflex + {0x00C2, {0x0041, 0x0302}}, // ร‚ โ†’ A + ฬ‚ + {0x00CA, {0x0045, 0x0302}}, // รŠ โ†’ E + ฬ‚ + {0x00CE, {0x0049, 0x0302}}, // รŽ โ†’ I + ฬ‚ + {0x00D4, {0x004F, 0x0302}}, // ร” โ†’ O + ฬ‚ + {0x00DB, {0x0055, 0x0302}}, // ร› โ†’ U + ฬ‚ + {0x00E2, {0x0061, 0x0302}}, // รข โ†’ a + ฬ‚ + {0x00EA, {0x0065, 0x0302}}, // รช โ†’ e + ฬ‚ + {0x00EE, {0x0069, 0x0302}}, // รฎ โ†’ i + ฬ‚ + {0x00F4, {0x006F, 0x0302}}, // รด โ†’ o + ฬ‚ + {0x00FB, {0x0075, 0x0302}}, // รป โ†’ u + ฬ‚ + // Tilde + {0x00C3, {0x0041, 0x0303}}, // รƒ โ†’ A + ฬƒ + {0x00D1, {0x004E, 0x0303}}, // ร‘ โ†’ N + ฬƒ + {0x00D5, {0x004F, 0x0303}}, // ร• โ†’ O + ฬƒ + {0x00E3, {0x0061, 0x0303}}, // รฃ โ†’ a + ฬƒ + {0x00F1, {0x006E, 0x0303}}, // รฑ โ†’ n + ฬƒ + {0x00F5, {0x006F, 0x0303}}, // รต โ†’ o + ฬƒ + // Diaeresis + {0x00C4, {0x0041, 0x0308}}, // ร„ โ†’ A + ฬˆ + {0x00CB, {0x0045, 0x0308}}, // ร‹ โ†’ E + ฬˆ + {0x00CF, {0x0049, 0x0308}}, // ร โ†’ I + ฬˆ + {0x00D6, {0x004F, 0x0308}}, // ร– โ†’ O + ฬˆ + {0x00DC, {0x0055, 0x0308}}, // รœ โ†’ U + ฬˆ + {0x00E4, {0x0061, 0x0308}}, // รค โ†’ a + ฬˆ + {0x00EB, {0x0065, 0x0308}}, // รซ โ†’ e + ฬˆ + {0x00EF, {0x0069, 0x0308}}, // รฏ โ†’ i + ฬˆ + {0x00F6, 
{0x006F, 0x0308}}, // รถ โ†’ o + ฬˆ + {0x00FC, {0x0075, 0x0308}}, // รผ โ†’ u + ฬˆ + // Cedilla + {0x00C7, {0x0043, 0x0327}}, // ร‡ โ†’ C + ฬง + {0x00E7, {0x0063, 0x0327}}, // รง โ†’ c + ฬง +}; + +// Decompose a character using NFKD (Hangul + Latin accented) +static void decomposeCharacter(uint32_t codepoint, std::vector& output) { + // Check Hangul syllables first + if (codepoint >= HANGUL_SBASE && codepoint < HANGUL_SBASE + HANGUL_SCOUNT) { + // Decompose Hangul syllable into Jamo + uint32_t sIndex = codepoint - HANGUL_SBASE; + uint32_t lIndex = sIndex / HANGUL_NCOUNT; + uint32_t vIndex = (sIndex % HANGUL_NCOUNT) / HANGUL_TCOUNT; + uint32_t tIndex = sIndex % HANGUL_TCOUNT; + + output.push_back(static_cast(HANGUL_LBASE + lIndex)); + output.push_back(static_cast(HANGUL_VBASE + vIndex)); + if (tIndex > 0) { + output.push_back(static_cast(HANGUL_TBASE + tIndex)); + } + return; + } + + // Check Latin decompositions + auto it = LATIN_DECOMPOSITIONS.find(codepoint); + if (it != LATIN_DECOMPOSITIONS.end()) { + for (uint16_t cp : it->second) { + output.push_back(cp); + } + return; + } + + // Keep as-is + output.push_back(static_cast(codepoint & 0xFFFF)); +} + +std::vector UnicodeProcessor::textToUnicodeValues(const std::string& text) { + std::vector unicode_values; + size_t i = 0; + + while (i < text.size()) { + uint32_t codepoint = 0; + unsigned char c = static_cast(text[i]); + + if ((c & 0x80) == 0) { + // 1-byte ASCII (0xxxxxxx) + codepoint = c; + i += 1; + } + else if ((c & 0xE0) == 0xC0 && i + 1 < text.size()) { + // 2-byte UTF-8 (110xxxxx 10xxxxxx) + codepoint = (c & 0x1F) << 6; + codepoint |= (static_cast(text[i + 1]) & 0x3F); + i += 2; + } + else if ((c & 0xF0) == 0xE0 && i + 2 < text.size()) { + // 3-byte UTF-8 (1110xxxx 10xxxxxx 10xxxxxx) - includes Korean + codepoint = (c & 0x0F) << 12; + codepoint |= (static_cast(text[i + 1]) & 0x3F) << 6; + codepoint |= (static_cast(text[i + 2]) & 0x3F); + i += 3; + } + else if ((c & 0xF8) == 0xF0 && i + 3 < text.size()) { + 
// 4-byte UTF-8 (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + codepoint = (c & 0x07) << 18; + codepoint |= (static_cast(text[i + 1]) & 0x3F) << 12; + codepoint |= (static_cast(text[i + 2]) & 0x3F) << 6; + codepoint |= (static_cast(text[i + 3]) & 0x3F); + i += 4; + } + else { + // Invalid UTF-8, skip byte + i += 1; + continue; + } + + // Decompose Hangul syllables and Latin accented characters (NFKD) + decomposeCharacter(codepoint, unicode_values); + } + + return unicode_values; +} + +std::vector>> UnicodeProcessor::getTextMask( + const std::vector& text_ids_lengths +) { + return lengthToMask(text_ids_lengths); +} + +void UnicodeProcessor::call( + const std::vector& text_list, + const std::vector& lang_list, + std::vector>& text_ids, + std::vector>>& text_mask +) { + std::vector processed_texts; + for (size_t i = 0; i < text_list.size(); i++) { + processed_texts.push_back(preprocessText(text_list[i], lang_list[i])); + } + + // Convert texts to unicode values first to get correct character counts + std::vector> all_unicode_vals; + std::vector text_ids_lengths; + for (const auto& text : processed_texts) { + auto unicode_vals = textToUnicodeValues(text); + // Use number of Unicode codepoints, not bytes + text_ids_lengths.push_back(static_cast(unicode_vals.size())); + all_unicode_vals.push_back(std::move(unicode_vals)); + } + + int64_t max_len = *std::max_element(text_ids_lengths.begin(), text_ids_lengths.end()); + + text_ids.resize(text_list.size()); + for (size_t i = 0; i < all_unicode_vals.size(); i++) { + text_ids[i].resize(max_len, 0); + const auto& unicode_vals = all_unicode_vals[i]; + for (size_t j = 0; j < unicode_vals.size(); j++) { + if (unicode_vals[j] < indexer_.size()) { + text_ids[i][j] = indexer_[unicode_vals[j]]; + } + } + } + + text_mask = getTextMask(text_ids_lengths); +} + +// ============================================================================ +// Style implementation +// ============================================================================ 
+ +Style::Style(const std::vector& ttl_data, const std::vector& ttl_shape, + const std::vector& dp_data, const std::vector& dp_shape) + : ttl_data_(ttl_data), ttl_shape_(ttl_shape), dp_data_(dp_data), dp_shape_(dp_shape) {} + +// ============================================================================ +// TextToSpeech implementation +// ============================================================================ + +TextToSpeech::TextToSpeech( + const Config& cfgs, + UnicodeProcessor* text_processor, + Ort::Session* dp_ort, + Ort::Session* text_enc_ort, + Ort::Session* vector_est_ort, + Ort::Session* vocoder_ort +) : cfgs_(cfgs), + text_processor_(text_processor), + dp_ort_(dp_ort), + text_enc_ort_(text_enc_ort), + vector_est_ort_(vector_est_ort), + vocoder_ort_(vocoder_ort) { + + sample_rate_ = cfgs.ae.sample_rate; + base_chunk_size_ = cfgs.ae.base_chunk_size; + chunk_compress_factor_ = cfgs.ttl.chunk_compress_factor; + ldim_ = cfgs.ttl.latent_dim; +} + +void TextToSpeech::sampleNoisyLatent( + const std::vector& duration, + std::vector>>& noisy_latent, + std::vector>>& latent_mask +) { + int bsz = duration.size(); + float wav_len_max = *std::max_element(duration.begin(), duration.end()) * sample_rate_; + + std::vector wav_lengths; + for (float d : duration) { + wav_lengths.push_back(static_cast(d * sample_rate_)); + } + + int chunk_size = base_chunk_size_ * chunk_compress_factor_; + int latent_len = static_cast((wav_len_max + chunk_size - 1) / chunk_size); + int latent_dim = ldim_ * chunk_compress_factor_; + + // Generate random noise with normal distribution + std::random_device rd; + std::mt19937 gen(rd()); + std::normal_distribution dist(0.0f, 1.0f); + + noisy_latent.resize(bsz); + for (int b = 0; b < bsz; b++) { + noisy_latent[b].resize(latent_dim); + for (int d = 0; d < latent_dim; d++) { + noisy_latent[b][d].resize(latent_len); + for (int t = 0; t < latent_len; t++) { + noisy_latent[b][d][t] = dist(gen); + } + } + } + + latent_mask = 
getLatentMask(wav_lengths, base_chunk_size_, chunk_compress_factor_); + + // Apply mask + for (int b = 0; b < bsz; b++) { + for (int d = 0; d < latent_dim; d++) { + for (size_t t = 0; t < noisy_latent[b][d].size(); t++) { + noisy_latent[b][d][t] *= latent_mask[b][0][t]; + } + } + } +} + +TextToSpeech::SynthesisResult TextToSpeech::_infer( + Ort::MemoryInfo& memory_info, + const std::vector& text_list, + const std::vector& lang_list, + const Style& style, + int total_step, + float speed +) { + int bsz = text_list.size(); + + if (bsz != style.getTtlShape()[0]) { + throw std::runtime_error("Number of texts must match number of style vectors"); + } + + // Process text + std::vector> text_ids; + std::vector>> text_mask; + text_processor_->call(text_list, lang_list, text_ids, text_mask); + + std::vector text_ids_shape = {bsz, static_cast(text_ids[0].size())}; + std::vector text_mask_shape = {bsz, 1, static_cast(text_mask[0][0].size())}; + + auto text_ids_tensor = intArrayToTensor(memory_info, text_ids, text_ids_shape); + auto text_mask_tensor = arrayToTensor(memory_info, text_mask, text_mask_shape); + + // Create style tensors + auto style_ttl_tensor = Ort::Value::CreateTensor( + memory_info, + const_cast(style.getTtlData().data()), + style.getTtlData().size(), + style.getTtlShape().data(), + style.getTtlShape().size() + ); + + auto style_dp_tensor = Ort::Value::CreateTensor( + memory_info, + const_cast(style.getDpData().data()), + style.getDpData().size(), + style.getDpShape().data(), + style.getDpShape().size() + ); + + // Run duration predictor + const char* dp_input_names[] = {"text_ids", "style_dp", "text_mask"}; + const char* dp_output_names[] = {"duration"}; + std::vector dp_inputs; + dp_inputs.push_back(std::move(text_ids_tensor)); + dp_inputs.push_back(std::move(style_dp_tensor)); + dp_inputs.push_back(std::move(text_mask_tensor)); + + auto dp_outputs = dp_ort_->Run( + Ort::RunOptions{nullptr}, + dp_input_names, dp_inputs.data(), dp_inputs.size(), + 
dp_output_names, 1 + ); + + auto* dur_data = dp_outputs[0].GetTensorMutableData(); + std::vector duration(dur_data, dur_data + bsz); + + // Apply speed factor to duration + for (auto& dur : duration) { + dur /= speed; + } + + // Create new tensors for text encoder (previous ones were moved) + text_ids_tensor = intArrayToTensor(memory_info, text_ids, text_ids_shape); + text_mask_tensor = arrayToTensor(memory_info, text_mask, text_mask_shape); + style_ttl_tensor = Ort::Value::CreateTensor( + memory_info, + const_cast(style.getTtlData().data()), + style.getTtlData().size(), + style.getTtlShape().data(), + style.getTtlShape().size() + ); + + // Run text encoder + const char* text_enc_input_names[] = {"text_ids", "style_ttl", "text_mask"}; + const char* text_enc_output_names[] = {"text_emb"}; + std::vector text_enc_inputs; + text_enc_inputs.push_back(std::move(text_ids_tensor)); + text_enc_inputs.push_back(std::move(style_ttl_tensor)); + text_enc_inputs.push_back(std::move(text_mask_tensor)); + + auto text_enc_outputs = text_enc_ort_->Run( + Ort::RunOptions{nullptr}, + text_enc_input_names, text_enc_inputs.data(), text_enc_inputs.size(), + text_enc_output_names, 1 + ); + + // Sample noisy latent + std::vector>> xt, latent_mask; + sampleNoisyLatent(duration, xt, latent_mask); + + std::vector latent_shape = { + bsz, + static_cast(xt[0].size()), + static_cast(xt[0][0].size()) + }; + std::vector latent_mask_shape = { + bsz, 1, + static_cast(latent_mask[0][0].size()) + }; + + // Prepare scalar tensors + std::vector total_step_vec(bsz, static_cast(total_step)); + auto total_step_tensor = Ort::Value::CreateTensor( + memory_info, + total_step_vec.data(), + total_step_vec.size(), + std::vector{bsz}.data(), + 1 + ); + + // Store text_emb data to reuse across iterations + auto text_emb_info = text_enc_outputs[0].GetTensorTypeAndShapeInfo(); + size_t text_emb_size = text_emb_info.GetElementCount(); + auto* text_emb_data = text_enc_outputs[0].GetTensorMutableData(); + std::vector 
text_emb_vec(text_emb_data, text_emb_data + text_emb_size); + auto text_emb_shape = text_emb_info.GetShape(); + + // Iterative denoising + for (int step = 0; step < total_step; step++) { + std::vector current_step_vec(bsz, static_cast(step)); + + text_mask_tensor = arrayToTensor(memory_info, text_mask, text_mask_shape); + auto latent_mask_tensor = arrayToTensor(memory_info, latent_mask, latent_mask_shape); + auto noisy_latent_tensor = arrayToTensor(memory_info, xt, latent_shape); + style_ttl_tensor = Ort::Value::CreateTensor( + memory_info, + const_cast(style.getTtlData().data()), + style.getTtlData().size(), + style.getTtlShape().data(), + style.getTtlShape().size() + ); + + auto text_emb_tensor = Ort::Value::CreateTensor( + memory_info, + text_emb_vec.data(), + text_emb_vec.size(), + text_emb_shape.data(), + text_emb_shape.size() + ); + + auto current_step_tensor = Ort::Value::CreateTensor( + memory_info, + current_step_vec.data(), + current_step_vec.size(), + std::vector{bsz}.data(), + 1 + ); + + const char* vector_est_input_names[] = { + "noisy_latent", "text_emb", "style_ttl", "text_mask", "latent_mask", "total_step", "current_step" + }; + const char* vector_est_output_names[] = {"denoised_latent"}; + + std::vector vector_est_inputs; + vector_est_inputs.push_back(std::move(noisy_latent_tensor)); + vector_est_inputs.push_back(std::move(text_emb_tensor)); + vector_est_inputs.push_back(std::move(style_ttl_tensor)); + vector_est_inputs.push_back(std::move(text_mask_tensor)); + vector_est_inputs.push_back(std::move(latent_mask_tensor)); + + // Create a new total_step tensor for each iteration + auto total_step_tensor_iter = Ort::Value::CreateTensor( + memory_info, + total_step_vec.data(), + total_step_vec.size(), + std::vector{bsz}.data(), + 1 + ); + vector_est_inputs.push_back(std::move(total_step_tensor_iter)); + vector_est_inputs.push_back(std::move(current_step_tensor)); + + auto vector_est_outputs = vector_est_ort_->Run( + Ort::RunOptions{nullptr}, + 
vector_est_input_names, vector_est_inputs.data(), vector_est_inputs.size(), + vector_est_output_names, 1 + ); + + // Update xt with denoised output + auto* denoised_data = vector_est_outputs[0].GetTensorMutableData(); + size_t idx = 0; + for (int b = 0; b < bsz; b++) { + for (size_t d = 0; d < xt[b].size(); d++) { + for (size_t t = 0; t < xt[b][d].size(); t++) { + xt[b][d][t] = denoised_data[idx++]; + } + } + } + } + + // Run vocoder + auto latent_tensor = arrayToTensor(memory_info, xt, latent_shape); + const char* vocoder_input_names[] = {"latent"}; + const char* vocoder_output_names[] = {"wav_tts"}; + std::vector vocoder_inputs; + vocoder_inputs.push_back(std::move(latent_tensor)); + + auto vocoder_outputs = vocoder_ort_->Run( + Ort::RunOptions{nullptr}, + vocoder_input_names, vocoder_inputs.data(), vocoder_inputs.size(), + vocoder_output_names, 1 + ); + + auto wav_info = vocoder_outputs[0].GetTensorTypeAndShapeInfo(); + size_t wav_size = wav_info.GetElementCount(); + auto* wav_data = vocoder_outputs[0].GetTensorMutableData(); + + SynthesisResult result; + result.wav.assign(wav_data, wav_data + wav_size); + result.duration = duration; + + return result; +} + +TextToSpeech::SynthesisResult TextToSpeech::call( + Ort::MemoryInfo& memory_info, + const std::string& text, + const std::string& lang, + const Style& style, + int total_step, + float speed, + float silence_duration +) { + if (style.getTtlShape()[0] != 1) { + throw std::runtime_error("Single speaker text to speech only supports single style"); + } + + int max_len = (lang == "ko") ? 
120 : 300; + auto text_list = chunkText(text, max_len); + std::vector wav_cat; + float dur_cat = 0.0f; + + for (const auto& chunk : text_list) { + auto result = _infer(memory_info, {chunk}, {lang}, style, total_step, speed); + + if (wav_cat.empty()) { + wav_cat = result.wav; + dur_cat = result.duration[0]; + } else { + int silence_len = static_cast(silence_duration * sample_rate_); + std::vector silence(silence_len, 0.0f); + wav_cat.insert(wav_cat.end(), silence.begin(), silence.end()); + wav_cat.insert(wav_cat.end(), result.wav.begin(), result.wav.end()); + dur_cat += result.duration[0] + silence_duration; + } + } + + SynthesisResult final_result; + final_result.wav = wav_cat; + final_result.duration = {dur_cat}; + + return final_result; +} + +TextToSpeech::SynthesisResult TextToSpeech::batch( + Ort::MemoryInfo& memory_info, + const std::vector& text_list, + const std::vector& lang_list, + const Style& style, + int total_step, + float speed +) { + return _infer(memory_info, text_list, lang_list, style, total_step, speed); +} + +// ============================================================================ +// Utility functions +// ============================================================================ + +std::vector>> lengthToMask( + const std::vector& lengths, int max_len +) { + if (max_len == -1) { + max_len = *std::max_element(lengths.begin(), lengths.end()); + } + + std::vector>> mask; + for (auto len : lengths) { + std::vector> batch_mask(1); + batch_mask[0].resize(max_len); + for (int i = 0; i < max_len; i++) { + batch_mask[0][i] = (i < len) ? 
1.0f : 0.0f; + } + mask.push_back(batch_mask); + } + return mask; +} + +std::vector>> getLatentMask( + const std::vector& wav_lengths, + int base_chunk_size, + int chunk_compress_factor +) { + int latent_size = base_chunk_size * chunk_compress_factor; + std::vector latent_lengths; + for (auto len : wav_lengths) { + latent_lengths.push_back((len + latent_size - 1) / latent_size); + } + return lengthToMask(latent_lengths); +} + +// ============================================================================ +// ONNX model loading +// ============================================================================ + +std::unique_ptr loadOnnx( + Ort::Env& env, + const std::string& onnx_path, + const Ort::SessionOptions& opts +) { + return std::make_unique(env, onnx_path.c_str(), opts); +} + +OnnxModels loadOnnxAll( + Ort::Env& env, + const std::string& onnx_dir, + const Ort::SessionOptions& opts +) { + OnnxModels models; + models.dp = loadOnnx(env, onnx_dir + "/duration_predictor.onnx", opts); + models.text_enc = loadOnnx(env, onnx_dir + "/text_encoder.onnx", opts); + models.vector_est = loadOnnx(env, onnx_dir + "/vector_estimator.onnx", opts); + models.vocoder = loadOnnx(env, onnx_dir + "/vocoder.onnx", opts); + return models; +} + +// ============================================================================ +// Configuration and processor loading +// ============================================================================ + +Config loadCfgs(const std::string& onnx_dir) { + std::string cfg_path = onnx_dir + "/tts.json"; + std::ifstream file(cfg_path); + if (!file.is_open()) { + throw std::runtime_error("Failed to open config file: " + cfg_path); + } + + json j; + file >> j; + + Config cfg; + cfg.ae.sample_rate = j["ae"]["sample_rate"]; + cfg.ae.base_chunk_size = j["ae"]["base_chunk_size"]; + cfg.ttl.chunk_compress_factor = j["ttl"]["chunk_compress_factor"]; + cfg.ttl.latent_dim = j["ttl"]["latent_dim"]; + + return cfg; +} + +std::unique_ptr loadTextProcessor(const 
std::string& onnx_dir) { + std::string unicode_indexer_path = onnx_dir + "/unicode_indexer.json"; + return std::make_unique(unicode_indexer_path); +} + +// ============================================================================ +// Voice style loading +// ============================================================================ + +Style loadVoiceStyle(const std::vector& voice_style_paths, bool verbose) { + int bsz = voice_style_paths.size(); + + // Read first file to get dimensions + std::ifstream first_file(voice_style_paths[0]); + if (!first_file.is_open()) { + throw std::runtime_error("Failed to open voice style file: " + voice_style_paths[0]); + } + json first_json; + first_file >> first_json; + + auto ttl_dims = first_json["style_ttl"]["dims"].get>(); + auto dp_dims = first_json["style_dp"]["dims"].get>(); + + int64_t ttl_dim1 = ttl_dims[1]; + int64_t ttl_dim2 = ttl_dims[2]; + int64_t dp_dim1 = dp_dims[1]; + int64_t dp_dim2 = dp_dims[2]; + + // Pre-allocate arrays with full batch size + size_t ttl_size = bsz * ttl_dim1 * ttl_dim2; + size_t dp_size = bsz * dp_dim1 * dp_dim2; + std::vector ttl_flat(ttl_size); + std::vector dp_flat(dp_size); + + // Fill in the data + for (int i = 0; i < bsz; i++) { + std::ifstream file(voice_style_paths[i]); + if (!file.is_open()) { + throw std::runtime_error("Failed to open voice style file: " + voice_style_paths[i]); + } + + json j; + file >> j; + + // Flatten data + auto ttl_data_nested = j["style_ttl"]["data"].get>>>(); + std::vector ttl_data; + for (const auto& batch : ttl_data_nested) { + for (const auto& row : batch) { + ttl_data.insert(ttl_data.end(), row.begin(), row.end()); + } + } + + auto dp_data_nested = j["style_dp"]["data"].get>>>(); + std::vector dp_data; + for (const auto& batch : dp_data_nested) { + for (const auto& row : batch) { + dp_data.insert(dp_data.end(), row.begin(), row.end()); + } + } + + // Copy to pre-allocated array + size_t ttl_offset = i * ttl_dim1 * ttl_dim2; + std::copy(ttl_data.begin(), 
ttl_data.end(), ttl_flat.begin() + ttl_offset); + + size_t dp_offset = i * dp_dim1 * dp_dim2; + std::copy(dp_data.begin(), dp_data.end(), dp_flat.begin() + dp_offset); + } + + std::vector ttl_shape = {bsz, ttl_dim1, ttl_dim2}; + std::vector dp_shape = {bsz, dp_dim1, dp_dim2}; + + if (verbose) { + std::cout << "Loaded " << bsz << " voice styles" << std::endl; + } + + return Style(ttl_flat, ttl_shape, dp_flat, dp_shape); +} + +// ============================================================================ +// TextToSpeech loading +// ============================================================================ + +std::unique_ptr loadTextToSpeech( + Ort::Env& env, + const std::string& onnx_dir, + bool use_gpu +) { + Ort::SessionOptions opts; + if (use_gpu) { + throw std::runtime_error("GPU mode is not supported yet"); + } else { + std::cout << "Using CPU for inference" << std::endl; + } + + auto cfgs = loadCfgs(onnx_dir); + auto models = loadOnnxAll(env, onnx_dir, opts); + auto text_processor = loadTextProcessor(onnx_dir); + + // Transfer ownership to TextToSpeech (use raw pointers internally) + auto tts = std::make_unique( + cfgs, + text_processor.get(), + models.dp.get(), + models.text_enc.get(), + models.vector_est.get(), + models.vocoder.get() + ); + + // Keep the models and processor alive by storing them + // (In production, you'd want better lifetime management) + static OnnxModels static_models; + static std::unique_ptr static_text_processor; + static_models = std::move(models); + static_text_processor = std::move(text_processor); + + return tts; +} + +// ============================================================================ +// WAV file writing +// ============================================================================ + +void writeWavFile( + const std::string& filename, + const std::vector& audio_data, + int sample_rate +) { + std::ofstream file(filename, std::ios::binary); + if (!file.is_open()) { + throw std::runtime_error("Failed to open file 
for writing: " + filename); + } + + int num_channels = 1; + int bits_per_sample = 16; + int byte_rate = sample_rate * num_channels * bits_per_sample / 8; + int block_align = num_channels * bits_per_sample / 8; + int data_size = audio_data.size() * bits_per_sample / 8; + + // RIFF header + file.write("RIFF", 4); + int32_t chunk_size = 36 + data_size; + file.write(reinterpret_cast(&chunk_size), 4); + file.write("WAVE", 4); + + // fmt chunk + file.write("fmt ", 4); + int32_t fmt_chunk_size = 16; + file.write(reinterpret_cast(&fmt_chunk_size), 4); + int16_t audio_format = 1; // PCM + file.write(reinterpret_cast(&audio_format), 2); + int16_t num_channels_16 = num_channels; + file.write(reinterpret_cast(&num_channels_16), 2); + file.write(reinterpret_cast(&sample_rate), 4); + file.write(reinterpret_cast(&byte_rate), 4); + int16_t block_align_16 = block_align; + file.write(reinterpret_cast(&block_align_16), 2); + int16_t bits_per_sample_16 = bits_per_sample; + file.write(reinterpret_cast(&bits_per_sample_16), 2); + + // data chunk + file.write("data", 4); + file.write(reinterpret_cast(&data_size), 4); + + // Write audio data + for (float sample : audio_data) { + float clamped = std::max(-1.0f, std::min(1.0f, sample)); + int16_t int_sample = static_cast(clamped * 32767); + file.write(reinterpret_cast(&int_sample), 2); + } +} + +// ============================================================================ +// Tensor conversion utilities +// ============================================================================ + +Ort::Value arrayToTensor( + Ort::MemoryInfo& memory_info, + const std::vector>>& array, + const std::vector& dims +) { + // Flatten the array + std::vector flat; + for (const auto& batch : array) { + for (const auto& row : batch) { + for (float val : row) { + flat.push_back(val); + } + } + } + + // Store in global buffer to keep data alive + g_tensor_buffers_float.push_back(std::move(flat)); + auto& buffer = g_tensor_buffers_float.back(); + + return 
Ort::Value::CreateTensor( + memory_info, + buffer.data(), + buffer.size(), + dims.data(), + dims.size() + ); +} + +Ort::Value intArrayToTensor( + Ort::MemoryInfo& memory_info, + const std::vector>& array, + const std::vector& dims +) { + // Flatten the array + std::vector flat; + for (const auto& row : array) { + for (int64_t val : row) { + flat.push_back(val); + } + } + + // Store in global buffer to keep data alive + g_tensor_buffers_int64.push_back(std::move(flat)); + auto& buffer = g_tensor_buffers_int64.back(); + + return Ort::Value::CreateTensor( + memory_info, + buffer.data(), + buffer.size(), + dims.data(), + dims.size() + ); +} + +// ============================================================================ +// JSON loading helpers +// ============================================================================ + +std::vector loadJsonInt64(const std::string& file_path) { + std::ifstream file(file_path); + if (!file.is_open()) { + throw std::runtime_error("Failed to open file: " + file_path); + } + + json j; + file >> j; + + return j.get>(); +} + +// ============================================================================ +// Sanitize filename +// ============================================================================ + +std::string sanitizeFilename(const std::string& text, int max_len) { + std::string result; + int char_count = 0; + size_t i = 0; + + while (i < text.size() && char_count < max_len) { + unsigned char c = static_cast(text[i]); + + // Check if it's ASCII alphanumeric or underscore + if (std::isalnum(c) || c == '_') { + result += text[i]; + i++; + char_count++; + } + // Check for UTF-8 multi-byte sequences (preserve Unicode letters/numbers) + else if ((c & 0xE0) == 0xC0 && i + 1 < text.size()) { + // 2-byte UTF-8 sequence + result += text.substr(i, 2); + i += 2; + char_count++; + } + else if ((c & 0xF0) == 0xE0 && i + 2 < text.size()) { + // 3-byte UTF-8 sequence (includes Korean, Japanese, Chinese) + result += text.substr(i, 3); + i 
+= 3; + char_count++; + } + else if ((c & 0xF8) == 0xF0 && i + 3 < text.size()) { + // 4-byte UTF-8 sequence + result += text.substr(i, 4); + i += 4; + char_count++; + } + else { + // Replace other characters with underscore + result += '_'; + i++; + char_count++; + } + } + return result; +} + +// ============================================================================ +// Chunk text +// ============================================================================ + +std::vector chunkText(const std::string& text, int max_len) { + std::vector chunks; + + // Split by paragraph (two or more newlines) + std::regex paragraph_regex(R"(\n\s*\n+)"); + std::sregex_token_iterator iter(text.begin(), text.end(), paragraph_regex, -1); + std::sregex_token_iterator end; + + std::vector paragraphs; + for (; iter != end; ++iter) { + std::string para = trim(*iter); + if (!para.empty()) { + paragraphs.push_back(para); + } + } + + // Split by sentence boundaries, excluding abbreviations + // This is a simplified version - C++ negative lookbehind is more complex + std::regex sentence_regex(R"([.!?]\s+)"); + + for (const auto& paragraph : paragraphs) { + std::sregex_token_iterator sent_iter(paragraph.begin(), paragraph.end(), sentence_regex, -1); + std::sregex_token_iterator sent_end; + + std::vector sentences; + std::string current = ""; + + for (; sent_iter != sent_end; ++sent_iter) { + std::string sentence = *sent_iter; + if (!sentence.empty()) { + // Add back the punctuation + if (sent_iter != sent_end) { + std::smatch match; + if (std::regex_search(sent_iter->first, paragraph.end(), match, sentence_regex)) { + sentence += match.str(); + } + } + sentences.push_back(sentence); + } + } + + // Combine sentences into chunks + std::string current_chunk = ""; + + for (const auto& sentence : sentences) { + if (static_cast(current_chunk.length() + sentence.length() + 1) <= max_len) { + if (!current_chunk.empty()) { + current_chunk += " "; + } + current_chunk += sentence; + } else { + if 
(!current_chunk.empty()) { + chunks.push_back(trim(current_chunk)); + } + current_chunk = sentence; + } + } + + if (!current_chunk.empty()) { + chunks.push_back(trim(current_chunk)); + } + } + + // If no chunks were created, return the original text + if (chunks.empty()) { + chunks.push_back(trim(text)); + } + + return chunks; +} diff --git a/cpp/helper.h b/cpp/helper.h new file mode 100644 index 0000000..34f3941 --- /dev/null +++ b/cpp/helper.h @@ -0,0 +1,229 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// Available languages for multilingual TTS +extern const std::vector AVAILABLE_LANGS; + +/** + * Configuration structure + */ +struct Config { + struct AEConfig { + int sample_rate; + int base_chunk_size; + } ae; + + struct TTLConfig { + int chunk_compress_factor; + int latent_dim; + } ttl; +}; + +/** + * Unicode text processor + */ +class UnicodeProcessor { +public: + explicit UnicodeProcessor(const std::string& unicode_indexer_json_path); + + // Process text list to text IDs and mask + void call( + const std::vector& text_list, + const std::vector& lang_list, + std::vector>& text_ids, + std::vector>>& text_mask + ); + +private: + std::vector indexer_; + + std::string preprocessText(const std::string& text, const std::string& lang); + std::vector textToUnicodeValues(const std::string& text); + std::vector>> getTextMask( + const std::vector& text_ids_lengths + ); +}; + +/** + * Style class + */ +class Style { +public: + Style(const std::vector& ttl_data, const std::vector& ttl_shape, + const std::vector& dp_data, const std::vector& dp_shape); + + const std::vector& getTtlData() const { return ttl_data_; } + const std::vector& getDpData() const { return dp_data_; } + const std::vector& getTtlShape() const { return ttl_shape_; } + const std::vector& getDpShape() const { return dp_shape_; } + +private: + std::vector ttl_data_; + std::vector dp_data_; + std::vector ttl_shape_; + std::vector dp_shape_; +}; + +/** + * 
TextToSpeech class + */ +class TextToSpeech { +public: + TextToSpeech( + const Config& cfgs, + UnicodeProcessor* text_processor, + Ort::Session* dp_ort, + Ort::Session* text_enc_ort, + Ort::Session* vector_est_ort, + Ort::Session* vocoder_ort + ); + + struct SynthesisResult { + std::vector wav; + std::vector duration; + }; + + SynthesisResult call( + Ort::MemoryInfo& memory_info, + const std::string& text, + const std::string& lang, + const Style& style, + int total_step, + float speed = 1.05f, + float silence_duration = 0.3f + ); + + SynthesisResult batch( + Ort::MemoryInfo& memory_info, + const std::vector& text_list, + const std::vector& lang_list, + const Style& style, + int total_step, + float speed = 1.05f + ); + + int getSampleRate() const { return sample_rate_; } + +private: + SynthesisResult _infer( + Ort::MemoryInfo& memory_info, + const std::vector& text_list, + const std::vector& lang_list, + const Style& style, + int total_step, + float speed = 1.05f + ); + Config cfgs_; + UnicodeProcessor* text_processor_; + Ort::Session* dp_ort_; + Ort::Session* text_enc_ort_; + Ort::Session* vector_est_ort_; + Ort::Session* vocoder_ort_; + int sample_rate_; + int base_chunk_size_; + int chunk_compress_factor_; + int ldim_; + + void sampleNoisyLatent( + const std::vector& duration, + std::vector>>& noisy_latent, + std::vector>>& latent_mask + ); +}; + +// Utility functions +std::vector>> lengthToMask( + const std::vector& lengths, int max_len = -1 +); + +std::vector>> getLatentMask( + const std::vector& wav_lengths, + int base_chunk_size, + int chunk_compress_factor +); + +// ONNX model loading +struct OnnxModels { + std::unique_ptr dp; + std::unique_ptr text_enc; + std::unique_ptr vector_est; + std::unique_ptr vocoder; +}; + +std::unique_ptr loadOnnx( + Ort::Env& env, + const std::string& onnx_path, + const Ort::SessionOptions& opts +); + +OnnxModels loadOnnxAll( + Ort::Env& env, + const std::string& onnx_dir, + const Ort::SessionOptions& opts +); + +// 
Configuration and processor loading +Config loadCfgs(const std::string& onnx_dir); + +std::unique_ptr loadTextProcessor(const std::string& onnx_dir); + +// Voice style loading +Style loadVoiceStyle(const std::vector& voice_style_paths, bool verbose = false); + +// TextToSpeech loading +std::unique_ptr loadTextToSpeech( + Ort::Env& env, + const std::string& onnx_dir, + bool use_gpu = false +); + +// WAV file writing +void writeWavFile( + const std::string& filename, + const std::vector& audio_data, + int sample_rate +); + +// Tensor conversion utilities +void clearTensorBuffers(); + +Ort::Value arrayToTensor( + Ort::MemoryInfo& memory_info, + const std::vector>>& array, + const std::vector& dims +); + +Ort::Value intArrayToTensor( + Ort::MemoryInfo& memory_info, + const std::vector>& array, + const std::vector& dims +); + +// JSON loading helpers +std::vector loadJsonInt64(const std::string& file_path); + +// Timer utility +template +auto timer(const std::string& name, Func&& func) -> decltype(func()) { + auto start = std::chrono::high_resolution_clock::now(); + std::cout << name << "..." 
<< std::endl; + auto result = func(); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << " -> " << name << " completed in " + << std::fixed << std::setprecision(2) << elapsed.count() << " sec" << std::endl; + return result; +} + +// Sanitize filename +std::string sanitizeFilename(const std::string& text, int max_len); + +// Chunk text into manageable segments +std::vector chunkText(const std::string& text, int max_len = 300); diff --git a/csharp/.gitignore b/csharp/.gitignore new file mode 100644 index 0000000..2c7f8fe --- /dev/null +++ b/csharp/.gitignore @@ -0,0 +1,41 @@ +# Build results +bin/ +obj/ +[Dd]ebug/ +[Rr]elease/ +x64/ +x86/ +[Aa]rm/ +[Aa]rm64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ + +# Visual Studio files +.vs/ +*.suo +*.user +*.userosscache +*.sln.docstates +*.userprefs + +# Rider +.idea/ +*.sln.iml + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# Output directory +results/*.wav + +# OS files +.DS_Store +Thumbs.db + + diff --git a/csharp/ExampleONNX.cs b/csharp/ExampleONNX.cs new file mode 100644 index 0000000..143072c --- /dev/null +++ b/csharp/ExampleONNX.cs @@ -0,0 +1,171 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Media; + +namespace Supertonic +{ + class Program + { + class Args + { + public bool UseGpu { get; set; } = false; + public string OnnxDir { get; set; } = "./assets/onnx"; + public int TotalStep { get; set; } = 5; + public float Speed { get; set; } = 1.05f; + public int NTest { get; set; } = 4; + public List VoiceStyle { get; set; } = new List { "assets/voice_styles/F2.json" }; + public List Text { get; set; } = new List + { + "๋™ํ•ด๋ฌผ๊ณผ ๋ฐฑ๋‘์‚ฐ์ด ๋งˆ๋ฅด๊ณ  ๋‹ณ๋„๋ก ํ•˜๋А๋‹˜์ด ๋ณด์šฐํ•˜์‚ฌ. 
์šฐ๋ฆฌ ๋‚˜๋ผ ๋งŒ์„ธ~~" + }; + public List Lang { get; set; } = new List { "ko" }; + public string SaveDir { get; set; } = "results"; + public bool Batch { get; set; } = false; + public int? Seed { get; set; } = null; + public float PreSilence { get; set; } = 0.2f; + } + + static Args ParseArgs(string[] args) + { + var result = new Args(); + + for (int i = 0; i < args.Length; i++) + { + switch (args[i]) + { + case "--use-gpu": + result.UseGpu = true; + break; + case "--batch": + result.Batch = true; + break; + case "--onnx-dir" when i + 1 < args.Length: + result.OnnxDir = args[++i]; + break; + case "--total-step" when i + 1 < args.Length: + result.TotalStep = int.Parse(args[++i]); + break; + case "--speed" when i + 1 < args.Length: + result.Speed = float.Parse(args[++i]); + break; + case "--n-test" when i + 1 < args.Length: + result.NTest = int.Parse(args[++i]); + break; + case "--voice-style" when i + 1 < args.Length: + result.VoiceStyle = args[++i].Split(',').ToList(); + break; + case "--text" when i + 1 < args.Length: + result.Text = args[++i].Split('|').ToList(); + break; + case "--lang" when i + 1 < args.Length: + result.Lang = args[++i].Split(',').ToList(); + break; + case "--save-dir" when i + 1 < args.Length: + result.SaveDir = args[++i]; + break; + case "--seed" when i + 1 < args.Length: + result.Seed = int.Parse(args[++i]); + break; + case "--pre-silence" when i + 1 < args.Length: + result.PreSilence = float.Parse(args[++i]); + break; + } + } + + return result; + } + + static void Main(string[] args) + { + Console.WriteLine("=== TTS Inference with ONNX Runtime (C#) ===\n"); + Console.WriteLine("sample seed : 371279630"); + + // --- 1. 
Parse arguments --- // + var parsedArgs = ParseArgs(args); + int totalStep = parsedArgs.TotalStep; + float speed = parsedArgs.Speed; + int nTest = parsedArgs.NTest; + string saveDir = parsedArgs.SaveDir; + var voiceStylePaths = parsedArgs.VoiceStyle; + var textList = parsedArgs.Text; + var langList = parsedArgs.Lang; + bool batch = parsedArgs.Batch; + + if (voiceStylePaths.Count != textList.Count) + { + throw new ArgumentException( + $"Number of voice styles ({voiceStylePaths.Count}) must match number of texts ({textList.Count})"); + } + int bsz = voiceStylePaths.Count; + + // --- 2. Load Text to Speech --- // + var textToSpeech = Helper.LoadTextToSpeech(parsedArgs.OnnxDir, parsedArgs.UseGpu); + Console.WriteLine(); + + // --- 3. Load Voice Style --- // + var style = Helper.LoadVoiceStyle(voiceStylePaths, verbose: true); + + // --- 4. Synthesize speech --- // + Random seedGenerator = new Random(); + for (int n = 0; n < nTest; n++) + { + int currentSeed = parsedArgs.Seed ?? seedGenerator.Next(); + Console.WriteLine($"\n[{n + 1}/{nTest}] Starting synthesis (Seed: {currentSeed})..."); + + var (wav, duration) = Helper.Timer("Generating speech from text", () => + { + if (batch) + { + return textToSpeech.Batch(textList, langList, style, totalStep, speed, currentSeed); + } + else + { + return textToSpeech.Call(textList[0], langList[0], style, totalStep, speed, seed: currentSeed); + } + }); + + if (!Directory.Exists(saveDir)) + { + Directory.CreateDirectory(saveDir); + } + + for (int b = 0; b < bsz; b++) + { + string fname = $"{Helper.SanitizeFilename(textList[b], 20)}_{n + 1}_s{currentSeed}.wav"; + + int wavLen = (int)(textToSpeech.SampleRate * duration[b]); + + // --- Add Pre-Silence (Delay) --- // + int silenceSamples = (int)(textToSpeech.SampleRate * parsedArgs.PreSilence); + var wavOut = new float[wavLen + silenceSamples]; + + // The array is initialized to 0 by default, so we just copy the audio after the silence + Array.Copy(wav, b * wav.Length / bsz, wavOut, 
silenceSamples, Math.Min(wavLen, wav.Length / bsz)); + + string outputPath = Path.Combine(saveDir, fname); + Helper.WriteWavFile(outputPath, wavOut, textToSpeech.SampleRate); + Console.WriteLine($"Saved: {outputPath}"); + + // --- Play the generated audio --- // + try + { + using (var player = new SoundPlayer(outputPath)) + { + Console.WriteLine("Playing audio..."); + player.PlaySync(); + } + } + catch (Exception ex) + { + Console.WriteLine($"Warning: Could not play audio. {ex.Message}"); + } + } + } + + Console.WriteLine("\n=== Synthesis completed successfully! ==="); + } + } +} + diff --git a/csharp/Helper.cs b/csharp/Helper.cs new file mode 100644 index 0000000..8f57c59 --- /dev/null +++ b/csharp/Helper.cs @@ -0,0 +1,861 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Text.RegularExpressions; +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.Tensors; + +namespace Supertonic +{ + // Available languages for multilingual TTS + public static class Languages + { + public static readonly string[] Available = { "en", "ko", "es", "pt", "fr" }; + } + + // ============================================================================ + // Configuration classes + // ============================================================================ + + public class Config + { + public AEConfig AE { get; set; } = null!; + public TTLConfig TTL { get; set; } = null!; + + public class AEConfig + { + public int SampleRate { get; set; } + public int BaseChunkSize { get; set; } + } + + public class TTLConfig + { + public int ChunkCompressFactor { get; set; } + public int LatentDim { get; set; } + } + } + + // ============================================================================ + // Style class + // ============================================================================ + + public class Style + { + public float[] Ttl { get; set; } + public long[] TtlShape { get; 
set; } + public float[] Dp { get; set; } + public long[] DpShape { get; set; } + + public Style(float[] ttl, long[] ttlShape, float[] dp, long[] dpShape) + { + Ttl = ttl; + TtlShape = ttlShape; + Dp = dp; + DpShape = dpShape; + } + } + + // ============================================================================ + // Unicode text processor + // ============================================================================ + + public class UnicodeProcessor + { + private readonly Dictionary _indexer; + + public UnicodeProcessor(string unicodeIndexerPath) + { + var json = File.ReadAllText(unicodeIndexerPath); + var indexerArray = JsonSerializer.Deserialize(json) ?? throw new Exception("Failed to load indexer"); + _indexer = new Dictionary(); + for (int i = 0; i < indexerArray.Length; i++) + { + _indexer[i] = indexerArray[i]; + } + } + + private static string RemoveEmojis(string text) + { + var result = new StringBuilder(); + for (int i = 0; i < text.Length; i++) + { + int codePoint; + if (char.IsHighSurrogate(text[i]) && i + 1 < text.Length && char.IsLowSurrogate(text[i + 1])) + { + // Get the full code point from surrogate pair + codePoint = char.ConvertToUtf32(text[i], text[i + 1]); + i++; // Skip the low surrogate + } + else + { + codePoint = text[i]; + } + + // Check if code point is in emoji ranges + bool isEmoji = (codePoint >= 0x1F600 && codePoint <= 0x1F64F) || + (codePoint >= 0x1F300 && codePoint <= 0x1F5FF) || + (codePoint >= 0x1F680 && codePoint <= 0x1F6FF) || + (codePoint >= 0x1F700 && codePoint <= 0x1F77F) || + (codePoint >= 0x1F780 && codePoint <= 0x1F7FF) || + (codePoint >= 0x1F800 && codePoint <= 0x1F8FF) || + (codePoint >= 0x1F900 && codePoint <= 0x1F9FF) || + (codePoint >= 0x1FA00 && codePoint <= 0x1FA6F) || + (codePoint >= 0x1FA70 && codePoint <= 0x1FAFF) || + (codePoint >= 0x2600 && codePoint <= 0x26FF) || + (codePoint >= 0x2700 && codePoint <= 0x27BF) || + (codePoint >= 0x1F1E6 && codePoint <= 0x1F1FF); + + if (!isEmoji) + { + if (codePoint > 
0xFFFF) + { + // Add back as surrogate pair + result.Append(char.ConvertFromUtf32(codePoint)); + } + else + { + result.Append((char)codePoint); + } + } + } + return result.ToString(); + } + + private string PreprocessText(string text, string lang) + { + // TODO: Need advanced normalizer for better performance + text = text.Normalize(NormalizationForm.FormKD); + + // Remove emojis (wide Unicode range) + // C# doesn't support \u{...} syntax in regex, so we use character filtering instead + text = RemoveEmojis(text); + + // Replace various dashes and symbols + var replacements = new Dictionary + { + {"โ€“", "-"}, // en dash + {"โ€‘", "-"}, // non-breaking hyphen + {"โ€”", "-"}, // em dash + {"_", " "}, // underscore + {"\u201C", "\""}, // left double quote + {"\u201D", "\""}, // right double quote + {"\u2018", "'"}, // left single quote + {"\u2019", "'"}, // right single quote + {"ยด", "'"}, // acute accent + {"`", "'"}, // grave accent + {"[", " "}, // left bracket + {"]", " "}, // right bracket + {"|", " "}, // vertical bar + {"/", " "}, // slash + {"#", " "}, // hash + {"โ†’", " "}, // right arrow + {"โ†", " "}, // left arrow + }; + + foreach (var kvp in replacements) + { + text = text.Replace(kvp.Key, kvp.Value); + } + + // Remove special symbols + text = Regex.Replace(text, @"[โ™ฅโ˜†โ™กยฉ\\]", ""); + + // Replace known expressions + var exprReplacements = new Dictionary + { + {"@", " at "}, + {"e.g.,", "for example, "}, + {"i.e.,", "that is, "}, + }; + + foreach (var kvp in exprReplacements) + { + text = text.Replace(kvp.Key, kvp.Value); + } + + // Fix spacing around punctuation + text = Regex.Replace(text, @" ,", ","); + text = Regex.Replace(text, @" \.", "."); + text = Regex.Replace(text, @" !", "!"); + text = Regex.Replace(text, @" \?", "?"); + text = Regex.Replace(text, @" ;", ";"); + text = Regex.Replace(text, @" :", ":"); + text = Regex.Replace(text, @" '", "'"); + + // Remove duplicate quotes + while (text.Contains("\"\"")) + { + text = 
text.Replace("\"\"", "\""); + } + while (text.Contains("''")) + { + text = text.Replace("''", "'"); + } + while (text.Contains("``")) + { + text = text.Replace("``", "`"); + } + + // Remove extra spaces + text = Regex.Replace(text, @"\s+", " ").Trim(); + + // If text doesn't end with punctuation, quotes, or closing brackets, add a period + if (!Regex.IsMatch(text, @"[.!?;:,'\u0022\u201C\u201D\u2018\u2019)\]}โ€ฆใ€‚ใ€ใ€ใ€‘ใ€‰ใ€‹โ€บยป]$")) + { + text += "."; + } + + // Validate language + if (!Languages.Available.Contains(lang)) + { + throw new ArgumentException($"Invalid language: {lang}. Available: {string.Join(", ", Languages.Available)}"); + } + + // Wrap text with language tags + text = $"<{lang}>" + text + $""; + + return text; + } + + private int[] TextToUnicodeValues(string text) + { + return text.Select(c => (int)c).ToArray(); + } + + private float[][][] GetTextMask(long[] textIdsLengths) + { + return Helper.LengthToMask(textIdsLengths); + } + + public (long[][] textIds, float[][][] textMask) Call(List textList, List langList) + { + var processedTexts = textList.Select((t, i) => PreprocessText(t, langList[i])).ToList(); + var textIdsLengths = processedTexts.Select(t => (long)t.Length).ToArray(); + long maxLen = textIdsLengths.Max(); + + var textIds = new long[textList.Count][]; + for (int i = 0; i < processedTexts.Count; i++) + { + textIds[i] = new long[maxLen]; + var unicodeVals = TextToUnicodeValues(processedTexts[i]); + for (int j = 0; j < unicodeVals.Length; j++) + { + if (_indexer.TryGetValue(unicodeVals[j], out long val)) + { + textIds[i][j] = val; + } + } + } + + var textMask = GetTextMask(textIdsLengths); + return (textIds, textMask); + } + } + + // ============================================================================ + // TextToSpeech class + // ============================================================================ + + public class TextToSpeech + { + private readonly Config _cfgs; + private readonly UnicodeProcessor _textProcessor; + 
private readonly InferenceSession _dpOrt; + private readonly InferenceSession _textEncOrt; + private readonly InferenceSession _vectorEstOrt; + private readonly InferenceSession _vocoderOrt; + public readonly int SampleRate; + private readonly int _baseChunkSize; + private readonly int _chunkCompressFactor; + private readonly int _ldim; + + public TextToSpeech( + Config cfgs, + UnicodeProcessor textProcessor, + InferenceSession dpOrt, + InferenceSession textEncOrt, + InferenceSession vectorEstOrt, + InferenceSession vocoderOrt) + { + _cfgs = cfgs; + _textProcessor = textProcessor; + _dpOrt = dpOrt; + _textEncOrt = textEncOrt; + _vectorEstOrt = vectorEstOrt; + _vocoderOrt = vocoderOrt; + SampleRate = cfgs.AE.SampleRate; + _baseChunkSize = cfgs.AE.BaseChunkSize; + _chunkCompressFactor = cfgs.TTL.ChunkCompressFactor; + _ldim = cfgs.TTL.LatentDim; + } + + private (float[][][] noisyLatent, float[][][] latentMask) SampleNoisyLatent(float[] duration, int seed) + { + int bsz = duration.Length; + float wavLenMax = duration.Max() * SampleRate; + var wavLengths = duration.Select(d => (long)(d * SampleRate)).ToArray(); + int chunkSize = _baseChunkSize * _chunkCompressFactor; + int latentLen = (int)((wavLenMax + chunkSize - 1) / chunkSize); + int latentDim = _ldim * _chunkCompressFactor; + + // Generate random noise with fixed seed + var random = new Random(seed); + var noisyLatent = new float[bsz][][]; + for (int b = 0; b < bsz; b++) + { + noisyLatent[b] = new float[latentDim][]; + for (int d = 0; d < latentDim; d++) + { + noisyLatent[b][d] = new float[latentLen]; + for (int t = 0; t < latentLen; t++) + { + // Box-Muller transform for normal distribution + double u1 = 1.0 - random.NextDouble(); + double u2 = 1.0 - random.NextDouble(); + noisyLatent[b][d][t] = (float)(Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2)); + } + } + } + + var latentMask = Helper.GetLatentMask(wavLengths, _baseChunkSize, _chunkCompressFactor); + + // Apply mask + for (int b = 0; b < bsz; 
b++) + { + for (int d = 0; d < latentDim; d++) + { + for (int t = 0; t < latentLen; t++) + { + noisyLatent[b][d][t] *= latentMask[b][0][t]; + } + } + } + + return (noisyLatent, latentMask); + } + + private (float[] wav, float[] duration) _Infer(List textList, List langList, Style style, int totalStep, float speed = 1.05f, int seed = 42) + { + int bsz = textList.Count; + if (bsz != style.TtlShape[0]) + { + throw new ArgumentException("Number of texts must match number of style vectors"); + } + + // Process text + var (textIds, textMask) = _textProcessor.Call(textList, langList); + var textIdsShape = new long[] { bsz, textIds[0].Length }; + var textMaskShape = new long[] { bsz, 1, textMask[0][0].Length }; + + var textIdsTensor = Helper.IntArrayToTensor(textIds, textIdsShape); + var textMaskTensor = Helper.ArrayToTensor(textMask, textMaskShape); + + var styleTtlTensor = new DenseTensor(style.Ttl, style.TtlShape.Select(x => (int)x).ToArray()); + var styleDpTensor = new DenseTensor(style.Dp, style.DpShape.Select(x => (int)x).ToArray()); + + // Run duration predictor + var dpInputs = new List + { + NamedOnnxValue.CreateFromTensor("text_ids", textIdsTensor), + NamedOnnxValue.CreateFromTensor("style_dp", styleDpTensor), + NamedOnnxValue.CreateFromTensor("text_mask", textMaskTensor) + }; + using var dpOutputs = _dpOrt.Run(dpInputs); + var durOnnx = dpOutputs.First(o => o.Name == "duration").AsTensor().ToArray(); + + // Apply speed factor to duration + for (int i = 0; i < durOnnx.Length; i++) + { + durOnnx[i] /= speed; + } + + // Run text encoder + var textEncInputs = new List + { + NamedOnnxValue.CreateFromTensor("text_ids", textIdsTensor), + NamedOnnxValue.CreateFromTensor("style_ttl", styleTtlTensor), + NamedOnnxValue.CreateFromTensor("text_mask", textMaskTensor) + }; + using var textEncOutputs = _textEncOrt.Run(textEncInputs); + var textEmbTensor = textEncOutputs.First(o => o.Name == "text_emb").AsTensor(); + + // Sample noisy latent + var (xt, latentMask) = 
SampleNoisyLatent(durOnnx, seed); + var latentShape = new long[] { bsz, xt[0].Length, xt[0][0].Length }; + var latentMaskShape = new long[] { bsz, 1, latentMask[0][0].Length }; + + var totalStepArray = Enumerable.Repeat((float)totalStep, bsz).ToArray(); + + // Iterative denoising + for (int step = 0; step < totalStep; step++) + { + var currentStepArray = Enumerable.Repeat((float)step, bsz).ToArray(); + + var vectorEstInputs = new List + { + NamedOnnxValue.CreateFromTensor("noisy_latent", Helper.ArrayToTensor(xt, latentShape)), + NamedOnnxValue.CreateFromTensor("text_emb", textEmbTensor), + NamedOnnxValue.CreateFromTensor("style_ttl", styleTtlTensor), + NamedOnnxValue.CreateFromTensor("text_mask", textMaskTensor), + NamedOnnxValue.CreateFromTensor("latent_mask", Helper.ArrayToTensor(latentMask, latentMaskShape)), + NamedOnnxValue.CreateFromTensor("total_step", new DenseTensor(totalStepArray, new int[] { bsz })), + NamedOnnxValue.CreateFromTensor("current_step", new DenseTensor(currentStepArray, new int[] { bsz })) + }; + + using var vectorEstOutputs = _vectorEstOrt.Run(vectorEstInputs); + var denoisedLatent = vectorEstOutputs.First(o => o.Name == "denoised_latent").AsTensor(); + + // Update xt + int idx = 0; + for (int b = 0; b < bsz; b++) + { + for (int d = 0; d < xt[b].Length; d++) + { + for (int t = 0; t < xt[b][d].Length; t++) + { + xt[b][d][t] = denoisedLatent.GetValue(idx++); + } + } + } + } + + // Run vocoder + var vocoderInputs = new List + { + NamedOnnxValue.CreateFromTensor("latent", Helper.ArrayToTensor(xt, latentShape)) + }; + using var vocoderOutputs = _vocoderOrt.Run(vocoderInputs); + var wavTensor = vocoderOutputs.First(o => o.Name == "wav_tts").AsTensor(); + + return (wavTensor.ToArray(), durOnnx); + } + + public (float[] wav, float[] duration) Call(string text, string lang, Style style, int totalStep, float speed = 1.05f, float silenceDuration = 0.3f, int seed = 42) + { + if (style.TtlShape[0] != 1) + { + throw new ArgumentException("Single speaker 
text to speech only supports single style"); + } + + int maxLen = lang == "ko" ? 120 : 300; + var textList = Helper.ChunkText(text, maxLen); + var wavCat = new List(); + float durCat = 0.0f; + + foreach (var chunk in textList) + { + var (wav, duration) = _Infer(new List { chunk }, new List { lang }, style, totalStep, speed, seed); + + if (wavCat.Count == 0) + { + wavCat.AddRange(wav); + durCat = duration[0]; + } + else + { + int silenceLen = (int)(silenceDuration * SampleRate); + var silence = new float[silenceLen]; + wavCat.AddRange(silence); + wavCat.AddRange(wav); + durCat += duration[0] + silenceDuration; + } + } + + return (wavCat.ToArray(), new float[] { durCat }); + } + + public (float[] wav, float[] duration) Batch(List textList, List langList, Style style, int totalStep, float speed = 1.05f, int seed = 42) + { + return _Infer(textList, langList, style, totalStep, speed, seed); + } + } + + // ============================================================================ + // Helper class with utility functions + // ============================================================================ + + public static class Helper + { + // ============================================================================ + // Utility functions + // ============================================================================ + + public static float[][][] LengthToMask(long[] lengths, long maxLen = -1) + { + if (maxLen == -1) + { + maxLen = lengths.Max(); + } + + var mask = new float[lengths.Length][][]; + for (int i = 0; i < lengths.Length; i++) + { + mask[i] = new float[1][]; + mask[i][0] = new float[maxLen]; + for (int j = 0; j < maxLen; j++) + { + mask[i][0][j] = j < lengths[i] ? 
1.0f : 0.0f; + } + } + return mask; + } + + public static float[][][] GetLatentMask(long[] wavLengths, int baseChunkSize, int chunkCompressFactor) + { + int latentSize = baseChunkSize * chunkCompressFactor; + var latentLengths = wavLengths.Select(len => (len + latentSize - 1) / latentSize).ToArray(); + return LengthToMask(latentLengths); + } + + // ============================================================================ + // ONNX model loading + // ============================================================================ + + public static InferenceSession LoadOnnx(string onnxPath, SessionOptions opts) + { + return new InferenceSession(onnxPath, opts); + } + + public static (InferenceSession dp, InferenceSession textEnc, InferenceSession vectorEst, InferenceSession vocoder) + LoadOnnxAll(string onnxDir, SessionOptions opts) + { + var dpPath = Path.Combine(onnxDir, "duration_predictor.onnx"); + var textEncPath = Path.Combine(onnxDir, "text_encoder.onnx"); + var vectorEstPath = Path.Combine(onnxDir, "vector_estimator.onnx"); + var vocoderPath = Path.Combine(onnxDir, "vocoder.onnx"); + + return ( + LoadOnnx(dpPath, opts), + LoadOnnx(textEncPath, opts), + LoadOnnx(vectorEstPath, opts), + LoadOnnx(vocoderPath, opts) + ); + } + + // ============================================================================ + // Configuration loading + // ============================================================================ + + public static Config LoadCfgs(string onnxDir) + { + var cfgPath = Path.Combine(onnxDir, "tts.json"); + var json = File.ReadAllText(cfgPath); + + using var doc = JsonDocument.Parse(json); + var root = doc.RootElement; + + return new Config + { + AE = new Config.AEConfig + { + SampleRate = root.GetProperty("ae").GetProperty("sample_rate").GetInt32(), + BaseChunkSize = root.GetProperty("ae").GetProperty("base_chunk_size").GetInt32() + }, + TTL = new Config.TTLConfig + { + ChunkCompressFactor = 
root.GetProperty("ttl").GetProperty("chunk_compress_factor").GetInt32(), + LatentDim = root.GetProperty("ttl").GetProperty("latent_dim").GetInt32() + } + }; + } + + public static UnicodeProcessor LoadTextProcessor(string onnxDir) + { + var unicodeIndexerPath = Path.Combine(onnxDir, "unicode_indexer.json"); + return new UnicodeProcessor(unicodeIndexerPath); + } + + // ============================================================================ + // Voice style loading + // ============================================================================ + + public static Style LoadVoiceStyle(List voiceStylePaths, bool verbose = false) + { + int bsz = voiceStylePaths.Count; + + // Read first file to get dimensions + var firstJson = File.ReadAllText(voiceStylePaths[0]); + using var firstDoc = JsonDocument.Parse(firstJson); + var firstRoot = firstDoc.RootElement; + + var ttlDims = ParseInt64Array(firstRoot.GetProperty("style_ttl").GetProperty("dims")); + var dpDims = ParseInt64Array(firstRoot.GetProperty("style_dp").GetProperty("dims")); + + long ttlDim1 = ttlDims[1]; + long ttlDim2 = ttlDims[2]; + long dpDim1 = dpDims[1]; + long dpDim2 = dpDims[2]; + + // Pre-allocate arrays with full batch size + int ttlSize = (int)(bsz * ttlDim1 * ttlDim2); + int dpSize = (int)(bsz * dpDim1 * dpDim2); + var ttlFlat = new float[ttlSize]; + var dpFlat = new float[dpSize]; + + // Fill in the data + for (int i = 0; i < bsz; i++) + { + var json = File.ReadAllText(voiceStylePaths[i]); + using var doc = JsonDocument.Parse(json); + var root = doc.RootElement; + + // Flatten data + var ttlData3D = ParseFloat3DArray(root.GetProperty("style_ttl").GetProperty("data")); + var ttlDataFlat = new List(); + foreach (var batch in ttlData3D) + { + foreach (var row in batch) + { + ttlDataFlat.AddRange(row); + } + } + + var dpData3D = ParseFloat3DArray(root.GetProperty("style_dp").GetProperty("data")); + var dpDataFlat = new List(); + foreach (var batch in dpData3D) + { + foreach (var row in batch) + { + 
dpDataFlat.AddRange(row); + } + } + + // Copy to pre-allocated array + int ttlOffset = (int)(i * ttlDim1 * ttlDim2); + ttlDataFlat.CopyTo(ttlFlat, ttlOffset); + + int dpOffset = (int)(i * dpDim1 * dpDim2); + dpDataFlat.CopyTo(dpFlat, dpOffset); + } + + var ttlShape = new long[] { bsz, ttlDim1, ttlDim2 }; + var dpShape = new long[] { bsz, dpDim1, dpDim2 }; + + if (verbose) + { + Console.WriteLine($"Loaded {bsz} voice styles"); + } + + return new Style(ttlFlat, ttlShape, dpFlat, dpShape); + } + + private static float[][][] ParseFloat3DArray(JsonElement element) + { + var result = new List(); + foreach (var batch in element.EnumerateArray()) + { + var batch2D = new List(); + foreach (var row in batch.EnumerateArray()) + { + var rowData = new List(); + foreach (var val in row.EnumerateArray()) + { + rowData.Add(val.GetSingle()); + } + batch2D.Add(rowData.ToArray()); + } + result.Add(batch2D.ToArray()); + } + return result.ToArray(); + } + + private static long[] ParseInt64Array(JsonElement element) + { + var result = new List(); + foreach (var val in element.EnumerateArray()) + { + result.Add(val.GetInt64()); + } + return result.ToArray(); + } + + // ============================================================================ + // TextToSpeech loading + // ============================================================================ + + public static TextToSpeech LoadTextToSpeech(string onnxDir, bool useGpu = false) + { + var opts = new SessionOptions(); + if (useGpu) + { + throw new NotImplementedException("GPU mode is not supported yet"); + } + else + { + Console.WriteLine("Using CPU for inference"); + } + + var cfgs = LoadCfgs(onnxDir); + var (dpOrt, textEncOrt, vectorEstOrt, vocoderOrt) = LoadOnnxAll(onnxDir, opts); + var textProcessor = LoadTextProcessor(onnxDir); + + return new TextToSpeech(cfgs, textProcessor, dpOrt, textEncOrt, vectorEstOrt, vocoderOrt); + } + + // ============================================================================ + // WAV file writing 
+ // ============================================================================ + + public static void WriteWavFile(string filename, float[] audioData, int sampleRate) + { + using var writer = new BinaryWriter(File.Open(filename, FileMode.Create)); + + int numChannels = 1; + int bitsPerSample = 16; + int byteRate = sampleRate * numChannels * bitsPerSample / 8; + short blockAlign = (short)(numChannels * bitsPerSample / 8); + int dataSize = audioData.Length * bitsPerSample / 8; + + // RIFF header + writer.Write(Encoding.ASCII.GetBytes("RIFF")); + writer.Write(36 + dataSize); + writer.Write(Encoding.ASCII.GetBytes("WAVE")); + + // fmt chunk + writer.Write(Encoding.ASCII.GetBytes("fmt ")); + writer.Write(16); // fmt chunk size + writer.Write((short)1); // audio format (PCM) + writer.Write((short)numChannels); + writer.Write(sampleRate); + writer.Write(byteRate); + writer.Write(blockAlign); + writer.Write((short)bitsPerSample); + + // data chunk + writer.Write(Encoding.ASCII.GetBytes("data")); + writer.Write(dataSize); + + // Write audio data + foreach (var sample in audioData) + { + float clamped = Math.Max(-1.0f, Math.Min(1.0f, sample)); + short intSample = (short)(clamped * 32767); + writer.Write(intSample); + } + } + + // ============================================================================ + // Tensor conversion utilities + // ============================================================================ + + public static DenseTensor ArrayToTensor(float[][][] array, long[] dims) + { + var flat = new List(); + foreach (var batch in array) + { + foreach (var row in batch) + { + flat.AddRange(row); + } + } + return new DenseTensor(flat.ToArray(), dims.Select(x => (int)x).ToArray()); + } + + public static DenseTensor IntArrayToTensor(long[][] array, long[] dims) + { + var flat = new List(); + foreach (var row in array) + { + flat.AddRange(row); + } + return new DenseTensor(flat.ToArray(), dims.Select(x => (int)x).ToArray()); + } + + // 
============================================================================ + // Timer utility + // ============================================================================ + + public static T Timer(string name, Func func) + { + var start = DateTime.Now; + Console.WriteLine($"{name}..."); + var result = func(); + var elapsed = (DateTime.Now - start).TotalSeconds; + Console.WriteLine($" -> {name} completed in {elapsed:F2} sec"); + return result; + } + + public static string SanitizeFilename(string text, int maxLen) + { + var result = new StringBuilder(); + int count = 0; + foreach (char c in text) + { + if (count >= maxLen) break; + if (char.IsLetterOrDigit(c)) + { + result.Append(c); + } + else + { + result.Append('_'); + } + count++; + } + return result.ToString(); + } + + // ============================================================================ + // Chunk text + // ============================================================================ + + public static List ChunkText(string text, int maxLen = 300) + { + var chunks = new List(); + + // Split by paragraph (two or more newlines) + var paragraphRegex = new Regex(@"\n\s*\n+"); + var paragraphs = paragraphRegex.Split(text.Trim()) + .Select(p => p.Trim()) + .Where(p => !string.IsNullOrEmpty(p)) + .ToList(); + + // Split by sentence boundaries, excluding abbreviations + var sentenceRegex = new Regex(@"(? 
+ + + Exe + net9.0-windows + true + 13.0 + enable + + + + + + + + + + diff --git a/csharp/csharp.sln b/csharp/csharp.sln new file mode 100644 index 0000000..abdd797 --- /dev/null +++ b/csharp/csharp.sln @@ -0,0 +1,24 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.5.2.0 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Supertonic", "Supertonic.csproj", "{869BE631-3CAF-8F33-CD9A-3A5788517967}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {869BE631-3CAF-8F33-CD9A-3A5788517967}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {869BE631-3CAF-8F33-CD9A-3A5788517967}.Debug|Any CPU.Build.0 = Debug|Any CPU + {869BE631-3CAF-8F33-CD9A-3A5788517967}.Release|Any CPU.ActiveCfg = Release|Any CPU + {869BE631-3CAF-8F33-CD9A-3A5788517967}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {2726B6AA-94CF-4D70-899D-0356CF025555} + EndGlobalSection +EndGlobal diff --git a/flutter/.gitignore b/flutter/.gitignore new file mode 100644 index 0000000..3820a95 --- /dev/null +++ b/flutter/.gitignore @@ -0,0 +1,45 @@ +# Miscellaneous +*.class +*.log +*.pyc +*.swp +.DS_Store +.atom/ +.build/ +.buildlog/ +.history +.svn/ +.swiftpm/ +migrate_working_dir/ + +# IntelliJ related +*.iml +*.ipr +*.iws +.idea/ + +# The .vscode folder contains launch configuration and tasks you configure in +# VS Code which you may wish to be included in version control, so this line +# is commented out by default. 
+#.vscode/ + +# Flutter/Dart/Pub related +**/doc/api/ +**/ios/Flutter/.last_build_id +.dart_tool/ +.flutter-plugins-dependencies +.pub-cache/ +.pub/ +/build/ +/coverage/ + +# Symbolication related +app.*.symbols + +# Obfuscation related +app.*.map.json + +# Android Studio will place build artifacts here +/android/app/debug +/android/app/profile +/android/app/release diff --git a/flutter/.metadata b/flutter/.metadata new file mode 100644 index 0000000..79bbc85 --- /dev/null +++ b/flutter/.metadata @@ -0,0 +1,30 @@ +# This file tracks properties of this Flutter project. +# Used by Flutter tool to assess capabilities and perform upgrades etc. +# +# This file should be version controlled and should not be manually edited. + +version: + revision: "19074d12f7eaf6a8180cd4036a430c1d76de904e" + channel: "stable" + +project_type: app + +# Tracks metadata for the flutter migrate command +migration: + platforms: + - platform: root + create_revision: 19074d12f7eaf6a8180cd4036a430c1d76de904e + base_revision: 19074d12f7eaf6a8180cd4036a430c1d76de904e + - platform: macos + create_revision: 19074d12f7eaf6a8180cd4036a430c1d76de904e + base_revision: 19074d12f7eaf6a8180cd4036a430c1d76de904e + + # User provided section + + # List of Local paths (relative to this file) that should be + # ignored by the migrate tool. + # + # Files that are not part of the templates will be ignored by default. + unmanaged_files: + - 'lib/main.dart' + - 'ios/Runner.xcodeproj/project.pbxproj' diff --git a/flutter/README.md b/flutter/README.md new file mode 100644 index 0000000..ab7a7ff --- /dev/null +++ b/flutter/README.md @@ -0,0 +1,38 @@ +# Supertonic Flutter Example + +This example demonstrates how to use Supertonic 2 in a Flutter application using ONNX Runtime. + +> **Note:** This project uses the `flutter_onnxruntime` package ([https://pub.dev/packages/flutter_onnxruntime](https://pub.dev/packages/flutter_onnxruntime)). At the moment, only the macOS platform has been tested. 
Although the flutter_onnxruntime package supports several other platforms, they have not been tested in this project yet and may require additional verification. + + +## ๐Ÿ“ฐ Update News + +**2026.01.06** - ๐ŸŽ‰ **Supertonic 2** released with multilingual support! Now supports English (`en`), Korean (`ko`), Spanish (`es`), Portuguese (`pt`), and French (`fr`). [Demo](https://huggingface.co/spaces/Supertone/supertonic-2) | [Models](https://huggingface.co/Supertone/supertonic-2) + +**2025.12.10** - Added [6 new voice styles](https://huggingface.co/Supertone/supertonic/tree/b10dbaf18b316159be75b34d24f740008fddd381) (M3, M4, M5, F3, F4, F5). See [Voices](https://supertone-inc.github.io/supertonic-py/voices/) for details + +**2025.12.08** - Optimized ONNX models via [OnnxSlim](https://github.com/inisis/OnnxSlim) now available on [Hugging Face Models](https://huggingface.co/Supertone/supertonic) + +**2025.11.23** - Added and tested macos support. + +## Multilingual Support + +Supertonic 2 supports multiple languages. Select the appropriate language from the dropdown: +- **English (en)**: Default language +- **ํ•œ๊ตญ์–ด (ko)**: Korean +- **Espaรฑol (es)**: Spanish +- **Portuguรชs (pt)**: Portuguese +- **Franรงais (fr)**: French + +## Requirements + +- Flutter SDK version ^3.5.0 + +## Running the Demo + +```bash +flutter clean +flutter pub get +flutter run -d macos +``` + diff --git a/flutter/analysis_options.yaml b/flutter/analysis_options.yaml new file mode 100644 index 0000000..0d29021 --- /dev/null +++ b/flutter/analysis_options.yaml @@ -0,0 +1,28 @@ +# This file configures the analyzer, which statically analyzes Dart code to +# check for errors, warnings, and lints. +# +# The issues identified by the analyzer are surfaced in the UI of Dart-enabled +# IDEs (https://dart.dev/tools#ides-and-editors). The analyzer can also be +# invoked from the command line by running `flutter analyze`. 
+ +# The following line activates a set of recommended lints for Flutter apps, +# packages, and plugins designed to encourage good coding practices. +include: package:flutter_lints/flutter.yaml + +linter: + # The lint rules applied to this project can be customized in the + # section below to disable rules from the `package:flutter_lints/flutter.yaml` + # included above or to enable additional rules. A list of all available lints + # and their documentation is published at https://dart.dev/lints. + # + # Instead of disabling a lint rule for the entire project in the + # section below, it can also be suppressed for a single line of code + # or a specific dart file by using the `// ignore: name_of_lint` and + # `// ignore_for_file: name_of_lint` syntax on the line or in the file + # producing the lint. + rules: + # avoid_print: false # Uncomment to disable the `avoid_print` rule + # prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule + +# Additional information about this file can be found at +# https://dart.dev/guides/language/analysis-options diff --git a/flutter/lib/helper.dart b/flutter/lib/helper.dart new file mode 100644 index 0000000..b1f45ec --- /dev/null +++ b/flutter/lib/helper.dart @@ -0,0 +1,695 @@ +import 'dart:io'; +import 'dart:convert'; +import 'dart:math' as math; +import 'dart:typed_data'; +import 'package:flutter/services.dart' show rootBundle; +import 'package:flutter_onnxruntime/flutter_onnxruntime.dart'; +import 'package:logger/logger.dart'; +import 'package:path_provider/path_provider.dart'; + +final logger = Logger( + printer: PrettyPrinter(methodCount: 0, errorMethodCount: 5, lineLength: 80), +); + +// Available languages for multilingual TTS +const List availableLangs = ['en', 'ko', 'es', 'pt', 'fr']; + +bool isValidLang(String lang) => availableLangs.contains(lang); + +// Hangul Jamo constants for NFKD decomposition +const int _hangulSyllableBase = 0xAC00; +const int _hangulSyllableEnd = 0xD7A3; +const int 
_leadingJamoBase = 0x1100; +const int _vowelJamoBase = 0x1161; +const int _trailingJamoBase = 0x11A7; +const int _vowelCount = 21; +const int _trailingCount = 28; + +/// Decompose a Hangul syllable into Jamo (NFKD-like decomposition) +List _decomposeHangulSyllable(int codePoint) { + if (codePoint < _hangulSyllableBase || codePoint > _hangulSyllableEnd) { + return [codePoint]; + } + + final syllableIndex = codePoint - _hangulSyllableBase; + final leadingIndex = syllableIndex ~/ (_vowelCount * _trailingCount); + final vowelIndex = + (syllableIndex % (_vowelCount * _trailingCount)) ~/ _trailingCount; + final trailingIndex = syllableIndex % _trailingCount; + + final result = [ + _leadingJamoBase + leadingIndex, + _vowelJamoBase + vowelIndex, + ]; + + if (trailingIndex > 0) { + result.add(_trailingJamoBase + trailingIndex); + } + + return result; +} + +/// Common Latin character decompositions (NFKD) for es, pt, fr +const Map> _latinDecompositions = { + // Uppercase with acute accent + 0x00C1: [0x0041, 0x0301], // ร โ†’ A + ฬ + 0x00C9: [0x0045, 0x0301], // ร‰ โ†’ E + ฬ + 0x00CD: [0x0049, 0x0301], // ร โ†’ I + ฬ + 0x00D3: [0x004F, 0x0301], // ร“ โ†’ O + ฬ + 0x00DA: [0x0055, 0x0301], // รš โ†’ U + ฬ + // Lowercase with acute accent + 0x00E1: [0x0061, 0x0301], // รก โ†’ a + ฬ + 0x00E9: [0x0065, 0x0301], // รฉ โ†’ e + ฬ + 0x00ED: [0x0069, 0x0301], // รญ โ†’ i + ฬ + 0x00F3: [0x006F, 0x0301], // รณ โ†’ o + ฬ + 0x00FA: [0x0075, 0x0301], // รบ โ†’ u + ฬ + // Grave accent + 0x00C0: [0x0041, 0x0300], // ร€ โ†’ A + ฬ€ + 0x00C8: [0x0045, 0x0300], // รˆ โ†’ E + ฬ€ + 0x00CC: [0x0049, 0x0300], // รŒ โ†’ I + ฬ€ + 0x00D2: [0x004F, 0x0300], // ร’ โ†’ O + ฬ€ + 0x00D9: [0x0055, 0x0300], // ร™ โ†’ U + ฬ€ + 0x00E0: [0x0061, 0x0300], // ร  โ†’ a + ฬ€ + 0x00E8: [0x0065, 0x0300], // รจ โ†’ e + ฬ€ + 0x00EC: [0x0069, 0x0300], // รฌ โ†’ i + ฬ€ + 0x00F2: [0x006F, 0x0300], // รฒ โ†’ o + ฬ€ + 0x00F9: [0x0075, 0x0300], // รน โ†’ u + ฬ€ + // Circumflex + 0x00C2: [0x0041, 0x0302], // ร‚ โ†’ 
A + ฬ‚ + 0x00CA: [0x0045, 0x0302], // รŠ โ†’ E + ฬ‚ + 0x00CE: [0x0049, 0x0302], // รŽ โ†’ I + ฬ‚ + 0x00D4: [0x004F, 0x0302], // ร” โ†’ O + ฬ‚ + 0x00DB: [0x0055, 0x0302], // ร› โ†’ U + ฬ‚ + 0x00E2: [0x0061, 0x0302], // รข โ†’ a + ฬ‚ + 0x00EA: [0x0065, 0x0302], // รช โ†’ e + ฬ‚ + 0x00EE: [0x0069, 0x0302], // รฎ โ†’ i + ฬ‚ + 0x00F4: [0x006F, 0x0302], // รด โ†’ o + ฬ‚ + 0x00FB: [0x0075, 0x0302], // รป โ†’ u + ฬ‚ + // Tilde + 0x00C3: [0x0041, 0x0303], // รƒ โ†’ A + ฬƒ + 0x00D1: [0x004E, 0x0303], // ร‘ โ†’ N + ฬƒ + 0x00D5: [0x004F, 0x0303], // ร• โ†’ O + ฬƒ + 0x00E3: [0x0061, 0x0303], // รฃ โ†’ a + ฬƒ + 0x00F1: [0x006E, 0x0303], // รฑ โ†’ n + ฬƒ + 0x00F5: [0x006F, 0x0303], // รต โ†’ o + ฬƒ + // Diaeresis/Umlaut + 0x00C4: [0x0041, 0x0308], // ร„ โ†’ A + ฬˆ + 0x00CB: [0x0045, 0x0308], // ร‹ โ†’ E + ฬˆ + 0x00CF: [0x0049, 0x0308], // ร โ†’ I + ฬˆ + 0x00D6: [0x004F, 0x0308], // ร– โ†’ O + ฬˆ + 0x00DC: [0x0055, 0x0308], // รœ โ†’ U + ฬˆ + 0x00E4: [0x0061, 0x0308], // รค โ†’ a + ฬˆ + 0x00EB: [0x0065, 0x0308], // รซ โ†’ e + ฬˆ + 0x00EF: [0x0069, 0x0308], // รฏ โ†’ i + ฬˆ + 0x00F6: [0x006F, 0x0308], // รถ โ†’ o + ฬˆ + 0x00FC: [0x0075, 0x0308], // รผ โ†’ u + ฬˆ + // Cedilla + 0x00C7: [0x0043, 0x0327], // ร‡ โ†’ C + ฬง + 0x00E7: [0x0063, 0x0327], // รง โ†’ c + ฬง +}; + +/// Apply NFKD-like decomposition (Hangul + Latin accented characters) +String _applyNfkdDecomposition(String text) { + final result = []; + for (final codePoint in text.runes) { + // Check Hangul first + if (codePoint >= _hangulSyllableBase && codePoint <= _hangulSyllableEnd) { + result.addAll(_decomposeHangulSyllable(codePoint)); + } + // Check Latin decomposition + else if (_latinDecompositions.containsKey(codePoint)) { + result.addAll(_latinDecompositions[codePoint]!); + } + // Keep as-is + else { + result.add(codePoint); + } + } + return String.fromCharCodes(result); +} + +String preprocessText(String text, String lang) { + // Apply NFKD-like decomposition (especially for Hangul syllables โ†’ Jamo) + text = 
_applyNfkdDecomposition(text);

  // Strip emoji and pictograph blocks (emoticons, symbols, transport,
  // supplemental symbols, dingbats, regional-indicator flags).
  text = text.replaceAll(
      RegExp(
        r'[\u{1F600}-\u{1F64F}]|[\u{1F300}-\u{1F5FF}]|[\u{1F680}-\u{1F6FF}]|'
        r'[\u{1F700}-\u{1F77F}]|[\u{1F780}-\u{1F7FF}]|[\u{1F800}-\u{1F8FF}]|'
        r'[\u{1F900}-\u{1F9FF}]|[\u{1FA00}-\u{1FA6F}]|[\u{1FA70}-\u{1FAFF}]|'
        r'[\u{2600}-\u{26FF}]|[\u{2700}-\u{27BF}]|[\u{1F1E6}-\u{1F1FF}]',
        unicode: true,
      ),
      '');

  // Normalize typographic punctuation to plain ASCII equivalents.
  // NOTE(review): these keys were mojibake (UTF-8 bytes shown in a legacy
  // codepage) in the reviewed copy; restored to the intended characters —
  // verify against the sibling helpers (py/helper.py etc.).
  const replacements = {
    '\u2013': '-', // en dash
    '\u2011': '-', // non-breaking hyphen
    '\u2014': '-', // em dash
    '_': ' ',
    '\u201C': '"', // left double quotation mark
    '\u201D': '"', // right double quotation mark
    '\u2018': "'", // left single quotation mark
    '\u2019': "'", // right single quotation mark
    '\u00B4': "'", // acute accent
    '`': "'",
    '[': ' ',
    ']': ' ',
    '|': ' ',
    '/': ' ',
    '#': ' ',
    '\u2192': ' ', // rightwards arrow
    '\u2190': ' ', // leftwards arrow
  };
  for (final entry in replacements.entries) {
    text = text.replaceAll(entry.key, entry.value);
  }

  // Drop decorative symbols the synthesizer has no pronunciation for.
  text = text.replaceAll(RegExp(r'[\u2665\u2606\u2661\u00A9\\]'), '');

  // Expand common written forms into speakable words.
  text = text.replaceAll('@', ' at ');
  text = text.replaceAll('e.g.,', 'for example, ');
  text = text.replaceAll('i.e.,', 'that is, ');

  // Remove stray spaces that precede punctuation.
  text = text.replaceAll(' ,', ',');
  text = text.replaceAll(' .', '.');
  text = text.replaceAll(' !', '!');
  text = text.replaceAll(' ?', '?');
  text = text.replaceAll(' ;', ';');
  text = text.replaceAll(' :', ':');
  text = text.replaceAll(" '", "'");

  // Collapse runs of duplicated quote characters.
  while (text.contains('""')) {
    text = text.replaceAll('""', '"');
  }
  while (text.contains("''")) {
    text = text.replaceAll("''", "'");
  }
  while (text.contains('``')) {
    text = text.replaceAll('``', '`');
  }

  // Collapse whitespace runs and trim the ends.
  text = text.replaceAll(RegExp(r'\s+'), ' ').trim();

  // Ensure the utterance ends with terminal punctuation so the model gets a
  // sentence-final cue. NOTE(review): the CJK closers in this character
  // class were reconstructed from mojibake — confirm against upstream.
  if (text.isNotEmpty &&
      !RegExp(r'[.!?;:,\x27\x22\u2018\u2019)\]}\u2026\u3002\u3001\u300D\u3011\u3009\u300B\u203A\u00BB]$')
          .hasMatch(text)) {
    text += '.';
  }

  // Reject unsupported language codes before tagging.
  if (!isValidLang(lang)) {
    throw ArgumentError(
        'Invalid language: $lang. Available: ${availableLangs.join(", ")}');
  }

  // Prefix the text with its language tag for the tokenizer.
  text = '<$lang>$text';

  return text;
}

/// Maps Unicode code points to model token ids and converts raw text into
/// padded id/mask arrays for the ONNX text models.
class UnicodeProcessor {
  /// code point -> token id; unknown code points fall back to id 0.
  final Map<int, int> indexer;

  UnicodeProcessor._(this.indexer);

  /// Loads the indexer JSON either from the Flutter asset bundle (paths
  /// beginning with `assets/`) or from the local file system.
  ///
  /// The JSON is either a dense list (list index == code point, value ==
  /// token id, negative entries skipped) or a `{codePoint: id}` object with
  /// string keys.
  static Future<UnicodeProcessor> load(String path) async {
    final json = jsonDecode(
      path.startsWith('assets/')
          ? await rootBundle.loadString(path)
          : File(path).readAsStringSync(),
    );

    final indexer = json is List
        ? <int, int>{
            for (var i = 0; i < json.length; i++)
              if (json[i] is int && json[i] >= 0) i: json[i] as int
          }
        : (json as Map).map((k, v) => MapEntry(int.parse(k), v as int));

    return UnicodeProcessor._(indexer);
  }

  /// Tokenizes [textList] (one entry per batch item, with matching
  /// [langList]) after preprocessing, and returns `textIds` (zero-padded
  /// 2-D int ids) plus `textMask` (1.0 for real positions, 0.0 for pad).
  Map<String, dynamic> call(List<String> textList, List<String> langList) {
    final processedTexts = <String>[];
    for (var i = 0; i < textList.length; i++) {
      processedTexts.add(preprocessText(textList[i], langList[i]));
    }

    // Lengths are measured in code points (runes), matching the indexer.
    // NOTE(review): empty input would make reduce() throw — assumed
    // non-empty by callers.
    final lengths = processedTexts.map((t) => t.runes.length).toList();
    final maxLen = lengths.reduce(math.max);

    final textIds = processedTexts.map((text) {
      final row = List.filled(maxLen, 0); // 0 doubles as the pad id
      final runes = text.runes.toList();
      for (var i = 0; i < runes.length; i++) {
        row[i] = indexer[runes[i]] ?? 0;
      }
      return row;
    }).toList();

    return {'textIds': textIds, 'textMask': _lengthToMask(lengths)};
  }

  /// Builds a [bsz, 1, maxLen] float mask from per-item lengths.
  List<List<List<double>>> _lengthToMask(List<int> lengths, [int? maxLen]) {
    maxLen ??= lengths.reduce(math.max);
    return lengths
        .map((len) => [List.generate(maxLen!, (i) => i < len ? 1.0 : 0.0)])
        .toList();
  }
}

/// Speaker style embeddings (ONNX tensors) for the text-to-latent (`ttl`)
/// and duration-predictor (`dp`) models, together with their shapes.
class Style {
  final OrtValue ttl, dp;
  final List<int> ttlShape, dpShape;
  Style(this.ttl, this.dp, this.ttlShape, this.dpShape);
}

/// End-to-end Supertonic TTS pipeline: duration prediction, text encoding,
/// iterative latent denoising, and vocoding, all via ONNX Runtime sessions.
class TextToSpeech {
  final Map<String, dynamic> cfgs;
  final UnicodeProcessor textProcessor;
  final OrtSession dpOrt, textEncOrt, vectorEstOrt, vocoderOrt;
  final int sampleRate, baseChunkSize, chunkCompressFactor, ldim;

  TextToSpeech(this.cfgs, this.textProcessor, this.dpOrt, this.textEncOrt,
      this.vectorEstOrt, this.vocoderOrt)
      : sampleRate = cfgs['ae']['sample_rate'],
        baseChunkSize = cfgs['ae']['base_chunk_size'],
        chunkCompressFactor = cfgs['ttl']['chunk_compress_factor'],
        ldim = cfgs['ttl']['latent_dim'];

  /// Synthesizes [text] in language [lang] using [style], running
  /// [totalStep] denoising iterations per chunk.
  ///
  /// Long input is split into chunks, synthesized one at a time, and joined
  /// with [silenceDuration] seconds of silence. Returns a map with `wav`
  /// (concatenated samples; null if no chunks were produced) and `duration`
  /// (single-element list, total seconds).
  Future<Map<String, dynamic>> call(
      String text, String lang, Style style, int totalStep,
      {double speed = 1.05, double silenceDuration = 0.3}) async {
    // Korean is chunked more aggressively than other languages.
    final maxLen = lang == 'ko' ? 120 : 300;
    final chunks = _chunkText(text, maxLen: maxLen);
    final langList = List.filled(chunks.length, lang);
    List<double>? wavCat;
    double durCat = 0;

    for (var i = 0; i < chunks.length; i++) {
      final result = await _infer([chunks[i]], [langList[i]], style, totalStep,
          speed: speed);
      final wav = _safeCast<double>(result['wav']);
      final duration = _safeCast<double>(result['duration']);

      if (wavCat == null) {
        wavCat = wav;
        durCat = duration[0];
      } else {
        // Insert a fixed stretch of silence between consecutive chunks.
        wavCat = [
          ...wavCat,
          ...List.filled((silenceDuration * sampleRate).floor(), 0.0),
          ...wav
        ];
        durCat += duration[0] + silenceDuration;
      }
    }

    return {
      'wav': wavCat,
      'duration': [durCat]
    };
  }

  /// Runs one batched inference pass over [textList]/[langList]:
  /// duration prediction -> text encoding -> [totalStep]-step latent
  /// denoising -> vocoder. Returns `wav` (flat samples) and `duration`
  /// (seconds per item, already divided by [speed]).
  Future<Map<String, dynamic>> _infer(
      List<String> textList, List<String> langList, Style style, int totalStep,
      {double speed = 1.05}) async {
    final bsz = textList.length;
    final result = textProcessor.call(textList, langList);

    // The processor returns dynamically-typed nested lists; normalize them.
    final textIdsRaw = result['textIds'];
    final textIds = textIdsRaw is List<List<int>>
        ? textIdsRaw
        : (textIdsRaw as List)
            .map((row) => (row as List).cast<int>())
            .toList();

    final textMaskRaw = result['textMask'];
    final textMask = textMaskRaw is List<List<List<double>>>
        ? textMaskRaw
        : (textMaskRaw as List)
            .map((batch) => (batch as List)
                .map((row) => (row as List).cast<double>())
                .toList())
            .toList();

    final textIdsShape = [bsz, textIds[0].length];
    final textMaskShape = [bsz, 1, textMask[0][0].length];
    final textMaskTensor = await _toTensor(textMask, textMaskShape);

    // 1) Predict per-item speech duration, then rescale by speaking speed.
    final dpResult = await dpOrt.run({
      'text_ids': await _intToTensor(textIds, textIdsShape),
      'style_dp': style.dp,
      'text_mask': textMaskTensor,
    });
    final durOnnx = _safeCast<double>(await dpResult.values.first.asList());
    final scaledDur = durOnnx.map((d) => d / speed).toList();

    // 2) Encode the text conditioned on the speaker style.
    final textEncResult = await textEncOrt.run({
      'text_ids': await _intToTensor(textIds, textIdsShape),
      'style_ttl': style.ttl,
      'text_mask': textMaskTensor,
    });

    // 3) Start from Gaussian noise masked to the predicted latent length.
    final latentData = _sampleNoisyLatent(scaledDur);
    final noisyLatent = latentData['noisyLatent'] as List<List<List<double>>>;
    final latentMask = latentData['latentMask'] as List<List<List<double>>>;

    final latentShape = [bsz, noisyLatent[0].length, noisyLatent[0][0].length];
    final latentMaskTensor =
        await _toTensor(latentMask, [bsz, 1, latentMask[0][0].length]);

    final totalStepTensor =
        await _scalarToTensor(List.filled(bsz, totalStep.toDouble()), [bsz]);

    // 4) Iterative denoising: each step feeds the current latent back in.
    for (var step = 0; step < totalStep; step++) {
      final stepResult = await vectorEstOrt.run({
        'noisy_latent': await _toTensor(noisyLatent, latentShape),
        'text_emb': textEncResult.values.first,
        'style_ttl': style.ttl,
        'text_mask': textMaskTensor,
        'latent_mask': latentMaskTensor,
        'total_step': totalStepTensor,
        'current_step':
            await _scalarToTensor(List.filled(bsz, step.toDouble()), [bsz]),
      });

      // Copy the (possibly nested) estimator output back into the 3-D
      // latent buffer in row-major order.
      final denoisedRaw = await stepResult.values.first.asList();
      final denoised = denoisedRaw is List<double>
          ? denoisedRaw
          : _safeCast<double>(denoisedRaw);
      var idx = 0;
      for (var b = 0; b < noisyLatent.length; b++) {
        for (var d = 0; d < noisyLatent[b].length; d++) {
          for (var t = 0; t < noisyLatent[b][d].length; t++) {
            noisyLatent[b][d][t] = denoised[idx++];
          }
        }
      }
    }

    // 5) Vocode the clean latent into waveform samples.
    final vocoderResult = await vocoderOrt
        .run({'latent': await _toTensor(noisyLatent, latentShape)});
    final wavRaw = await vocoderResult.values.first.asList();
    final wav = wavRaw is List<double> ? wavRaw : _safeCast<double>(wavRaw);

    return {'wav': wav, 'duration': scaledDur};
  }

  /// Draws masked standard-normal noise shaped
  /// [bsz, ldim * chunkCompressFactor, latentLen] for the given per-item
  /// durations (seconds). Returns the noise plus its validity mask.
  Map<String, dynamic> _sampleNoisyLatent(List<double> duration) {
    final wavLenMax = duration.reduce(math.max) * sampleRate;
    final wavLengths = duration.map((d) => (d * sampleRate).floor()).toList();
    final chunkSize = baseChunkSize * chunkCompressFactor;
    // Ceiling division: number of latent frames covering the longest wav.
    final latentLen = ((wavLenMax + chunkSize - 1) / chunkSize).floor();
    final latentDim = ldim * chunkCompressFactor;

    // Box-Muller transform: two uniforms -> one standard-normal sample.
    final random = math.Random();
    final noisyLatent = List.generate(
      duration.length,
      (_) => List.generate(
        latentDim,
        (_) => List.generate(latentLen, (_) {
          final u1 = math.max(1e-10, random.nextDouble()); // avoid log(0)
          final u2 = random.nextDouble();
          return math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2);
        }),
      ),
    );

    final latentMask = _getLatentMask(wavLengths);

    // Zero out noise beyond each item's valid latent length.
    for (var b = 0; b < noisyLatent.length; b++) {
      for (var d = 0; d < noisyLatent[b].length; d++) {
        for (var t = 0; t < noisyLatent[b][d].length; t++) {
          noisyLatent[b][d][t] *= latentMask[b][0][t];
        }
      }
    }

    return {'noisyLatent': noisyLatent, 'latentMask': latentMask};
  }

  /// Converts waveform sample lengths to a [bsz, 1, maxLatentLen] mask
  /// (1.0 inside each item's latent length, 0.0 beyond it).
  List<List<List<double>>> _getLatentMask(List<int> wavLengths) {
    final latentSize = baseChunkSize * chunkCompressFactor;
    final latentLengths = wavLengths
        .map((len) => ((len + latentSize - 1) / latentSize).floor())
        .toList();
    final maxLen = latentLengths.reduce(math.max);
    return latentLengths
        .map((len) => [List.generate(maxLen, (i) => i < len ? 1.0 : 0.0)])
        .toList();
  }

  /// Splits [text] into synthesis chunks of at most [maxLen] characters,
  /// keeping sentences intact where possible. Paragraphs (blank-line
  /// separated) are never merged across their boundary.
  ///
  /// NOTE(review): the sentence-grouping loop was lost in the reviewed copy
  /// (extraction ate a span after the lookbehind); this interior is a
  /// reconstruction — verify against the sibling helpers (py/helper.py,
  /// nodejs/helper.js) before relying on exact chunk boundaries.
  List<String> _chunkText(String text, {int maxLen = 300}) {
    final paragraphs = text
        .trim()
        .split(RegExp(r'\n\s*\n+'))
        .where((p) => p.trim().isNotEmpty)
        .toList();

    final chunks = <String>[];
    for (var paragraph in paragraphs) {
      paragraph = paragraph.trim();
      if (paragraph.isEmpty) continue;

      // Split after sentence-ending punctuation followed by whitespace.
      final sentences = paragraph.split(RegExp(r'(?<=[.!?])\s+'));

      var current = '';
      for (final sentence in sentences) {
        final s = sentence.trim();
        if (s.isEmpty) continue;
        if (current.isEmpty) {
          current = s;
        } else if (current.length + 1 + s.length <= maxLen) {
          current = '$current $s';
        } else {
          chunks.add(current);
          current = s;
        }
      }
      if (current.isNotEmpty) chunks.add(current);
    }
    return chunks;
  }

  /// Coerces a dynamically-typed (possibly nested) list coming from ONNX
  /// output into a flat `List<T>`, converting nums to double when
  /// `T == double`. Throws if [raw] is not a list at all.
  List<T> _safeCast<T>(dynamic raw) {
    if (raw is List<T>) return raw;
    if (raw is List) {
      if (raw.isNotEmpty && raw.first is List) {
        return _flattenList<T>(raw);
      }
      if (T == double) {
        return raw
            .map((e) => e is num ? e.toDouble() : double.parse(e.toString()))
            .toList() as List<T>;
      }
      return raw.cast<T>();
    }
    throw Exception('Cannot convert $raw to List<$T>');
  }

  /// Recursively flattens arbitrarily nested lists into a flat `List<T>`,
  /// widening nums to double when `T == double`.
  List<T> _flattenList<T>(dynamic list) {
    if (list is List) {
      return list.expand((e) => _flattenList<T>(e)).toList();
    }
    if (T == double && list is num) {
      return [list.toDouble()] as List<T>;
    }
    return [list as T];
  }

  /// Flattens [array] and packs it into a float32 ONNX tensor of shape
  /// [dims].
  Future<OrtValue> _toTensor(dynamic array, List<int> dims) async {
    final flat = _flattenList<double>(array);
    return await OrtValue.fromList(Float32List.fromList(flat), dims);
  }

  /// Packs an already-flat double list into a float32 ONNX tensor of shape
  /// [dims].
  Future<OrtValue> _scalarToTensor(List<double> array, List<int> dims) async {
    return await OrtValue.fromList(Float32List.fromList(array), dims);
  }

  /// Packs a 2-D int array into an int64 ONNX tensor of shape [dims]
  /// (ONNX text-id inputs are int64).
  Future<OrtValue> _intToTensor(List<List<int>> array, List<int> dims) async {
    final flat = array.expand((row) => row).toList();
    return await OrtValue.fromList(Int64List.fromList(flat), dims);
  }
}

/// Loads configs, the four ONNX sessions, and the unicode indexer from
/// [onnxDir] and wires them into a ready-to-use [TextToSpeech].
///
/// GPU execution is not supported yet; passing `useGpu: true` throws.
Future<TextToSpeech> loadTextToSpeech(String onnxDir,
    {bool useGpu = false}) async {
  if (useGpu) throw Exception('GPU mode not supported yet');

  logger.i('Loading TTS models from $onnxDir');

  final cfgs = await _loadCfgs(onnxDir);
  final sessions = await _loadOnnxAll(onnxDir);
  final textProcessor =
      await UnicodeProcessor.load('$onnxDir/unicode_indexer.json');

  logger.i('TTS models loaded successfully');

  return TextToSpeech(
    cfgs,
    textProcessor,
    sessions['dpOrt']!,
    sessions['textEncOrt']!,
    sessions['vectorEstOrt']!,
    sessions['vocoderOrt']!,
  );
}

Future