Files
Supertonic/cpp/helper.cpp
2026-01-25 18:58:40 +09:00

1187 lines
42 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "helper.h"
#include <fstream>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <random>
#include <sstream>
#include <regex>
#include <unordered_map>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
// Available languages for multilingual TTS.
// UnicodeProcessor::preprocessText() rejects any language code not listed here.
const std::vector<std::string> AVAILABLE_LANGS = {"en", "ko", "es", "pt", "fr"};
// Global tensor buffers for memory management.
// arrayToTensor()/intArrayToTensor() append one flattened buffer per call and
// hand its data pointer to the tensor they create, so an entry must stay alive
// for as long as that tensor is in use. Nothing in this part of the file
// releases entries automatically; they accumulate until clearTensorBuffers()
// is called.
static std::vector<std::vector<float>> g_tensor_buffers_float;
static std::vector<std::vector<int64_t>> g_tensor_buffers_int64;
// Releases every buffer retained by arrayToTensor() and intArrayToTensor().
// Tensors created from those buffers must no longer be in use when this runs,
// because they point directly into the storage being dropped.
void clearTensorBuffers() {
    g_tensor_buffers_int64.clear();
    g_tensor_buffers_float.clear();
}
// ============================================================================
// Helper function - trim
// ============================================================================
// Returns a copy of `str` without leading and trailing whitespace, where
// whitespace is whatever std::isspace reports for each byte.
static std::string trim(const std::string& str) {
    const auto is_blank = [](char ch) {
        return std::isspace(static_cast<unsigned char>(ch)) != 0;
    };

    // First byte that must be kept (end() when the string is all whitespace).
    const auto first_kept = std::find_if_not(str.begin(), str.end(), is_blank);

    // Scan backwards, but never past first_kept, to find one-past-the-last kept byte.
    const auto last_kept = std::find_if_not(
        str.rbegin(), std::string::const_reverse_iterator(first_kept), is_blank).base();

    return std::string(first_kept, last_kept);
}
// ============================================================================
// UnicodeProcessor implementation
// ============================================================================
// Builds the processor by loading the code-unit -> token-id table from the
// JSON file at `unicode_indexer_json_path` (see loadJsonInt64 for failure modes).
UnicodeProcessor::UnicodeProcessor(const std::string& unicode_indexer_json_path) {
    auto loaded_table = loadJsonInt64(unicode_indexer_json_path);
    indexer_ = std::move(loaded_table);
}
// Normalizes raw input text for the TTS front end and wraps it in language tags.
//
// Steps, in order: map typographic symbols to ASCII, drop 4-byte emoji
// (U+1F000-U+1FFFF), drop decorative symbols, expand a few abbreviations,
// tidy spacing around punctuation, collapse repeated quotes and whitespace,
// make sure the text ends with punctuation, then wrap as "<lang>...</lang>".
//
// text: UTF-8 input text.
// lang: language code; must be one of AVAILABLE_LANGS.
// Returns the normalized, tagged text.
// Throws std::runtime_error when `lang` is not supported.
//
// TODO: Need advanced normalizer for better performance.
// NOTE: C++ has no built-in Unicode normalization (Python's NFKD); consider
// ICU for full normalization. This implementation does basic preprocessing.
std::string UnicodeProcessor::preprocessText(const std::string& text, const std::string& lang) {
    // Validation has no dependency on the text, so fail before doing any work.
    if (std::find(AVAILABLE_LANGS.begin(), AVAILABLE_LANGS.end(), lang) == AVAILABLE_LANGS.end()) {
        throw std::runtime_error("Invalid language: " + lang + ". Available: en, ko, es, pt, fr");
    }

    std::string result = text;

    // Replaces every non-overlapping occurrence of `from`, scanning left to right.
    // An empty needle matches at every position and would never terminate, so
    // it is ignored.
    const auto replace_all = [&result](const std::string& from, const std::string& to) {
        if (from.empty()) {
            return;
        }
        size_t pos = 0;
        while ((pos = result.find(from, pos)) != std::string::npos) {
            result.replace(pos, from.size(), to);
            pos += to.size();
        }
    };

    struct Replacement {
        const char* from;
        const char* to;
    };

    // IMPORTANT: symbol replacements run FIRST (before emoji removal).
    // Non-ASCII patterns are spelled as explicit UTF-8 bytes so they survive
    // any source-file re-encoding.
    static const Replacement kSymbolReplacements[] = {
        {"\xE2\x80\x93", "-"},   // U+2013 en dash
        {"\xE2\x80\x91", "-"},   // U+2011 non-breaking hyphen
        {"\xE2\x80\x94", "-"},   // U+2014 em dash
        {"_", " "},              // underscore
        {"\xE2\x80\x9C", "\""},  // U+201C left double quote
        {"\xE2\x80\x9D", "\""},  // U+201D right double quote
        {"\xE2\x80\x98", "'"},   // U+2018 left single quote
        {"\xE2\x80\x99", "'"},   // U+2019 right single quote
        {"\xC2\xB4", "'"},       // U+00B4 acute accent
        {"`", "'"},              // grave accent
        {"[", " "},              // left bracket
        {"]", " "},              // right bracket
        {"|", " "},              // vertical bar
        {"/", " "},              // slash
        {"#", " "},              // hash
        {"\xE2\x86\x92", " "},   // U+2192 right arrow
        {"\xE2\x86\x90", " "},   // U+2190 left arrow
    };
    for (const auto& repl : kSymbolReplacements) {
        replace_all(repl.from, repl.to);
    }

    // Remove emojis AFTER symbol replacements. Only 4-byte UTF-8 sequences of
    // the form F0 9F xx xx (U+1F000-U+1FFFF) are dropped; everything else is kept.
    {
        const auto is_continuation = [](char ch) {
            return (static_cast<unsigned char>(ch) & 0xC0) == 0x80;
        };
        std::string kept;
        kept.reserve(result.size());
        size_t i = 0;
        while (i < result.size()) {
            const bool is_emoji =
                i + 3 < result.size() &&
                static_cast<unsigned char>(result[i]) == 0xF0 &&
                static_cast<unsigned char>(result[i + 1]) == 0x9F &&
                is_continuation(result[i + 2]) &&
                is_continuation(result[i + 3]);
            if (is_emoji) {
                i += 4;
            } else {
                kept += result[i];
                ++i;
            }
        }
        result.swap(kept);
    }

    // Remove special symbols.
    // NOTE(review): the first three patterns were unreadable (empty) in the
    // reviewed copy of this file; they are restored here as heart/star symbols
    // on the assumption that this mirrors the reference preprocessing - confirm.
    static const char* const kRemovedSymbols[] = {
        "\xE2\x99\xA5",  // U+2665 black heart
        "\xE2\x98\x86",  // U+2606 white star
        "\xE2\x99\xA1",  // U+2661 white heart
        "\xC2\xA9",      // U+00A9 copyright sign
        "\\",            // backslash
    };
    for (const char* symbol : kRemovedSymbols) {
        replace_all(symbol, "");
    }

    // Replace known expressions.
    static const Replacement kExpressionReplacements[] = {
        {"@", " at "},
        {"e.g.,", "for example, "},
        {"i.e.,", "that is, "},
    };
    for (const auto& repl : kExpressionReplacements) {
        replace_all(repl.from, repl.to);
    }

    // Fix spacing around punctuation: drop one space directly before each mark.
    static const Replacement kSpacingFixes[] = {
        {" ,", ","}, {" .", "."}, {" !", "!"}, {" ?", "?"},
        {" ;", ";"}, {" :", ":"}, {" '", "'"},
    };
    for (const auto& repl : kSpacingFixes) {
        replace_all(repl.from, repl.to);
    }

    // Remove duplicate quotes: any run of the same quote collapses to one.
    // (Backticks were already turned into apostrophes above, so only the two
    // ASCII quote characters can still occur in pairs.)
    const auto collapse_doubled = [&result](const std::string& doubled, const std::string& single) {
        size_t pos = 0;
        while ((pos = result.find(doubled, pos)) != std::string::npos) {
            // pos is not advanced so that a run of three or more keeps shrinking.
            result.replace(pos, doubled.size(), single);
        }
    };
    collapse_doubled("\"\"", "\"");
    collapse_doubled("''", "'");

    // Remove extra spaces: every run of ASCII whitespace becomes a single
    // space. A trailing run is dropped here; a leading one is removed by trim().
    {
        std::string collapsed;
        collapsed.reserve(result.size());
        bool pending_space = false;
        for (const char ch : result) {
            const bool is_space = ch == ' ' || (ch >= '\t' && ch <= '\r');
            if (is_space) {
                pending_space = true;
                continue;
            }
            if (pending_space) {
                collapsed += ' ';
                pending_space = false;
            }
            collapsed += ch;
        }
        result = trim(collapsed);
    }

    // If text doesn't end with punctuation, quotes, or closing brackets, add a period.
    if (!result.empty()) {
        static const std::string kAsciiEndings = ".!?;:,'\")]}>";
        bool ends_with_punct = kAsciiEndings.find(result.back()) != std::string::npos;

        if (!ends_with_punct) {
            // Multibyte endings are compared by their own byte length, so the
            // 2-byte U+00BB is matched correctly alongside the 3-byte ones.
            // NOTE(review): apart from the ellipsis and ideographic full stop
            // named in the original comment, the 3-byte entries were unreadable
            // in the reviewed copy and are restored as CJK/angle closers - confirm.
            static const char* const kMultibyteEndings[] = {
                "\xE2\x80\xA6",  // U+2026 horizontal ellipsis
                "\xE3\x80\x82",  // U+3002 ideographic full stop
                "\xE3\x80\x8D",  // U+300D right corner bracket
                "\xE3\x80\x8F",  // U+300F right white corner bracket
                "\xE3\x80\x91",  // U+3011 right black lenticular bracket
                "\xE3\x80\x89",  // U+3009 right angle bracket
                "\xE3\x80\x8B",  // U+300B right double angle bracket
                "\xE2\x80\xBA",  // U+203A single right-pointing angle quote
                "\xC2\xBB",      // U+00BB right-pointing double angle quote
                "\xE2\x80\x9C",  // U+201C left double quote
                "\xE2\x80\x9D",  // U+201D right double quote
                "\xE2\x80\x98",  // U+2018 left single quote
                "\xE2\x80\x99",  // U+2019 right single quote
            };
            for (const char* ending : kMultibyteEndings) {
                const std::string suffix(ending);
                if (result.size() >= suffix.size() &&
                    result.compare(result.size() - suffix.size(), suffix.size(), suffix) == 0) {
                    ends_with_punct = true;
                    break;
                }
            }
        }

        if (!ends_with_punct) {
            result += ".";
        }
    }

    // Wrap text with language tags.
    return "<" + lang + ">" + result + "</" + lang + ">";
}
// Hangul syllable decomposition constants (Unicode Standard Annex #15).
// decomposeCharacter() uses these to split a precomposed syllable into its
// leading consonant, vowel, and optional trailing consonant arithmetically.
static const uint32_t HANGUL_SBASE = 0xAC00; // Start of Hangul syllables
static const uint32_t HANGUL_LBASE = 0x1100; // Start of Hangul Jamo (leading consonants)
static const uint32_t HANGUL_VBASE = 0x1161; // Start of Hangul Jamo (vowels)
static const uint32_t HANGUL_TBASE = 0x11A7; // One below the first trailing consonant, so trailing index 0 means "none"
static const int HANGUL_LCOUNT = 19; // Number of leading consonants
static const int HANGUL_VCOUNT = 21; // Number of vowels
static const int HANGUL_TCOUNT = 28; // Number of trailing consonants (including none)
static const int HANGUL_NCOUNT = HANGUL_VCOUNT * HANGUL_TCOUNT; // 588 syllables per leading consonant
static const int HANGUL_SCOUNT = HANGUL_LCOUNT * HANGUL_NCOUNT; // 11172 precomposed syllables in total
// Latin character NFKD decompositions for Spanish, Portuguese, French.
// Maps a precomposed accented letter to {base letter, combining mark}.
// Only the code points listed here are decomposed; decomposeCharacter()
// passes every other non-Hangul code point through unchanged.
static const std::unordered_map<uint32_t, std::vector<uint16_t>> LATIN_DECOMPOSITIONS = {
// Acute accent (combining mark U+0301)
{0x00C1, {0x0041, 0x0301}}, // Á → A + U+0301
{0x00C9, {0x0045, 0x0301}}, // É → E + U+0301
{0x00CD, {0x0049, 0x0301}}, // Í → I + U+0301
{0x00D3, {0x004F, 0x0301}}, // Ó → O + U+0301
{0x00DA, {0x0055, 0x0301}}, // Ú → U + U+0301
{0x00E1, {0x0061, 0x0301}}, // á → a + U+0301
{0x00E9, {0x0065, 0x0301}}, // é → e + U+0301
{0x00ED, {0x0069, 0x0301}}, // í → i + U+0301
{0x00F3, {0x006F, 0x0301}}, // ó → o + U+0301
{0x00FA, {0x0075, 0x0301}}, // ú → u + U+0301
// Grave accent (combining mark U+0300)
{0x00C0, {0x0041, 0x0300}}, // À → A + U+0300
{0x00C8, {0x0045, 0x0300}}, // È → E + U+0300
{0x00CC, {0x0049, 0x0300}}, // Ì → I + U+0300
{0x00D2, {0x004F, 0x0300}}, // Ò → O + U+0300
{0x00D9, {0x0055, 0x0300}}, // Ù → U + U+0300
{0x00E0, {0x0061, 0x0300}}, // à → a + U+0300
{0x00E8, {0x0065, 0x0300}}, // è → e + U+0300
{0x00EC, {0x0069, 0x0300}}, // ì → i + U+0300
{0x00F2, {0x006F, 0x0300}}, // ò → o + U+0300
{0x00F9, {0x0075, 0x0300}}, // ù → u + U+0300
// Circumflex (combining mark U+0302)
{0x00C2, {0x0041, 0x0302}}, // Â → A + U+0302
{0x00CA, {0x0045, 0x0302}}, // Ê → E + U+0302
{0x00CE, {0x0049, 0x0302}}, // Î → I + U+0302
{0x00D4, {0x004F, 0x0302}}, // Ô → O + U+0302
{0x00DB, {0x0055, 0x0302}}, // Û → U + U+0302
{0x00E2, {0x0061, 0x0302}}, // â → a + U+0302
{0x00EA, {0x0065, 0x0302}}, // ê → e + U+0302
{0x00EE, {0x0069, 0x0302}}, // î → i + U+0302
{0x00F4, {0x006F, 0x0302}}, // ô → o + U+0302
{0x00FB, {0x0075, 0x0302}}, // û → u + U+0302
// Tilde (combining mark U+0303)
{0x00C3, {0x0041, 0x0303}}, // Ã → A + U+0303
{0x00D1, {0x004E, 0x0303}}, // Ñ → N + U+0303
{0x00D5, {0x004F, 0x0303}}, // Õ → O + U+0303
{0x00E3, {0x0061, 0x0303}}, // ã → a + U+0303
{0x00F1, {0x006E, 0x0303}}, // ñ → n + U+0303
{0x00F5, {0x006F, 0x0303}}, // õ → o + U+0303
// Diaeresis (combining mark U+0308)
{0x00C4, {0x0041, 0x0308}}, // Ä → A + U+0308
{0x00CB, {0x0045, 0x0308}}, // Ë → E + U+0308
{0x00CF, {0x0049, 0x0308}}, // Ï → I + U+0308
{0x00D6, {0x004F, 0x0308}}, // Ö → O + U+0308
{0x00DC, {0x0055, 0x0308}}, // Ü → U + U+0308
{0x00E4, {0x0061, 0x0308}}, // ä → a + U+0308
{0x00EB, {0x0065, 0x0308}}, // ë → e + U+0308
{0x00EF, {0x0069, 0x0308}}, // ï → i + U+0308
{0x00F6, {0x006F, 0x0308}}, // ö → o + U+0308
{0x00FC, {0x0075, 0x0308}}, // ü → u + U+0308
// Cedilla (combining mark U+0327)
{0x00C7, {0x0043, 0x0327}}, // Ç → C + U+0327
{0x00E7, {0x0063, 0x0327}}, // ç → c + U+0327
};
// Appends the decomposed form of `codepoint` to `output`.
//
// - Precomposed Hangul syllables become two or three Jamo code units.
// - Code points present in LATIN_DECOMPOSITIONS become their mapped sequence.
// - Anything else is appended as a single 16-bit value.
//
// NOTE(review): the pass-through branch keeps only the low 16 bits, so a code
// point above U+FFFF is stored as an unrelated 16-bit value rather than being
// dropped or encoded - confirm this is what downstream indexing expects.
static void decomposeCharacter(uint32_t codepoint, std::vector<uint16_t>& output) {
    const bool is_hangul_syllable =
        codepoint >= HANGUL_SBASE && codepoint < HANGUL_SBASE + HANGUL_SCOUNT;

    if (is_hangul_syllable) {
        // Position inside the syllable block determines all three components.
        const uint32_t syllable = codepoint - HANGUL_SBASE;
        const uint32_t lead = syllable / HANGUL_NCOUNT;
        const uint32_t vowel = (syllable % HANGUL_NCOUNT) / HANGUL_TCOUNT;
        const uint32_t tail = syllable % HANGUL_TCOUNT;

        output.push_back(static_cast<uint16_t>(HANGUL_LBASE + lead));
        output.push_back(static_cast<uint16_t>(HANGUL_VBASE + vowel));
        if (tail != 0) {
            // Index 0 means the syllable has no trailing consonant.
            output.push_back(static_cast<uint16_t>(HANGUL_TBASE + tail));
        }
        return;
    }

    const auto latin_entry = LATIN_DECOMPOSITIONS.find(codepoint);
    if (latin_entry == LATIN_DECOMPOSITIONS.end()) {
        // No decomposition known: keep the (truncated) value as-is.
        output.push_back(static_cast<uint16_t>(codepoint & 0xFFFF));
        return;
    }

    const std::vector<uint16_t>& parts = latin_entry->second;
    output.insert(output.end(), parts.begin(), parts.end());
}
// Decodes UTF-8 `text` into code points and appends each one, decomposed via
// decomposeCharacter(), to the returned sequence of 16-bit values.
//
// Malformed input is tolerated: a byte that cannot start a sequence, a
// sequence cut off by the end of the string, or a lead byte whose following
// bytes are not continuation bytes (10xxxxxx) causes only that single byte to
// be skipped, so the characters after it are still decoded. Overlong forms and
// surrogate code points are not rejected.
std::vector<uint16_t> UnicodeProcessor::textToUnicodeValues(const std::string& text) {
    std::vector<uint16_t> unicode_values;
    // With the decompositions defined in this file, no character yields more
    // output values than it has input bytes, so this avoids regrowth.
    unicode_values.reserve(text.size());

    const auto byte_at = [&text](size_t idx) {
        return static_cast<unsigned char>(text[idx]);
    };

    size_t i = 0;
    while (i < text.size()) {
        const unsigned char lead = byte_at(i);
        size_t continuation_count = 0;
        uint32_t codepoint = 0;

        if ((lead & 0x80) == 0) {
            // 1-byte ASCII (0xxxxxxx)
            codepoint = lead;
        } else if ((lead & 0xE0) == 0xC0) {
            // 2-byte UTF-8 (110xxxxx 10xxxxxx)
            continuation_count = 1;
            codepoint = lead & 0x1F;
        } else if ((lead & 0xF0) == 0xE0) {
            // 3-byte UTF-8 (1110xxxx 10xxxxxx 10xxxxxx) - includes Korean
            continuation_count = 2;
            codepoint = lead & 0x0F;
        } else if ((lead & 0xF8) == 0xF0) {
            // 4-byte UTF-8 (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
            continuation_count = 3;
            codepoint = lead & 0x07;
        } else {
            // Stray continuation byte or invalid lead byte: skip it.
            ++i;
            continue;
        }

        // Every trailing byte must exist and carry the 10xxxxxx marker;
        // otherwise the lead byte is dropped and decoding resumes right after it.
        bool well_formed = true;
        for (size_t k = 1; k <= continuation_count; ++k) {
            if (i + k >= text.size() || (byte_at(i + k) & 0xC0) != 0x80) {
                well_formed = false;
                break;
            }
            codepoint = (codepoint << 6) | (byte_at(i + k) & 0x3F);
        }
        if (!well_formed) {
            ++i;
            continue;
        }

        i += continuation_count + 1;
        // Decompose Hangul syllables and Latin accented characters (NFKD)
        decomposeCharacter(codepoint, unicode_values);
    }
    return unicode_values;
}
// Builds the padding mask for a batch of token sequences by delegating to
// lengthToMask() with the per-sequence lengths.
std::vector<std::vector<std::vector<float>>> UnicodeProcessor::getTextMask(
    const std::vector<int64_t>& text_ids_lengths
) {
    auto mask = lengthToMask(text_ids_lengths);
    return mask;
}
// Converts a batch of texts into padded token-id rows plus a padding mask.
//
// text_list: raw input texts.
// lang_list: language code for each text; must have the same length as text_list.
// text_ids:  output, one row per text, zero-padded to the longest sequence.
//            Values with no entry in the indexer table stay 0.
// text_mask: output of getTextMask() for the unpadded sequence lengths.
//
// Both outputs are fully overwritten. An empty batch yields empty outputs.
// Throws std::runtime_error when the two input lists differ in length, and
// propagates the exception from preprocessText() for an unsupported language.
void UnicodeProcessor::call(
    const std::vector<std::string>& text_list,
    const std::vector<std::string>& lang_list,
    std::vector<std::vector<int64_t>>& text_ids,
    std::vector<std::vector<std::vector<float>>>& text_mask
) {
    if (text_list.size() != lang_list.size()) {
        throw std::runtime_error("Number of texts must match number of languages");
    }

    // Start from a clean slate so stale caller data can never leak into the result.
    text_ids.clear();
    text_mask.clear();
    if (text_list.empty()) {
        return;
    }

    // Convert texts to unicode values first to get correct sequence lengths
    // (number of decoded values, not bytes).
    std::vector<std::vector<uint16_t>> all_unicode_vals;
    std::vector<int64_t> text_ids_lengths;
    all_unicode_vals.reserve(text_list.size());
    text_ids_lengths.reserve(text_list.size());
    for (size_t i = 0; i < text_list.size(); i++) {
        auto unicode_vals = textToUnicodeValues(preprocessText(text_list[i], lang_list[i]));
        text_ids_lengths.push_back(static_cast<int64_t>(unicode_vals.size()));
        all_unicode_vals.push_back(std::move(unicode_vals));
    }

    const int64_t max_len = *std::max_element(text_ids_lengths.begin(), text_ids_lengths.end());
    text_ids.assign(text_list.size(), std::vector<int64_t>(static_cast<size_t>(max_len), 0));

    for (size_t i = 0; i < all_unicode_vals.size(); i++) {
        const auto& unicode_vals = all_unicode_vals[i];
        for (size_t j = 0; j < unicode_vals.size(); j++) {
            const size_t table_index = unicode_vals[j];
            if (table_index < indexer_.size()) {
                text_ids[i][j] = indexer_[table_index];
            }
        }
    }

    text_mask = getTextMask(text_ids_lengths);
}
// ============================================================================
// Style implementation
// ============================================================================
// Stores copies of the two flattened style tensors together with their shapes.
// No consistency check between a data vector and its shape is performed here.
Style::Style(
    const std::vector<float>& ttl_data,
    const std::vector<int64_t>& ttl_shape,
    const std::vector<float>& dp_data,
    const std::vector<int64_t>& dp_shape
)
    : ttl_data_(ttl_data),
      ttl_shape_(ttl_shape),
      dp_data_(dp_data),
      dp_shape_(dp_shape) {
}
// ============================================================================
// TextToSpeech implementation
// ============================================================================
// Wires a synthesizer to an already-loaded text processor and four ONNX
// sessions. Only the raw pointers are stored, so the pointed-to objects must
// outlive this instance. Frequently used config values are cached in members.
TextToSpeech::TextToSpeech(
    const Config& cfgs,
    UnicodeProcessor* text_processor,
    Ort::Session* dp_ort,
    Ort::Session* text_enc_ort,
    Ort::Session* vector_est_ort,
    Ort::Session* vocoder_ort
)
    : cfgs_(cfgs),
      text_processor_(text_processor),
      dp_ort_(dp_ort),
      text_enc_ort_(text_enc_ort),
      vector_est_ort_(vector_est_ort),
      vocoder_ort_(vocoder_ort) {
    // Values from the "ae" section of the configuration.
    const auto& ae_cfg = cfgs.ae;
    sample_rate_ = ae_cfg.sample_rate;
    base_chunk_size_ = ae_cfg.base_chunk_size;

    // Values from the "ttl" section of the configuration.
    const auto& ttl_cfg = cfgs.ttl;
    chunk_compress_factor_ = ttl_cfg.chunk_compress_factor;
    ldim_ = ttl_cfg.latent_dim;
}
void TextToSpeech::sampleNoisyLatent(
const std::vector<float>& duration,
std::vector<std::vector<std::vector<float>>>& noisy_latent,
std::vector<std::vector<std::vector<float>>>& latent_mask
) {
int bsz = duration.size();
float wav_len_max = *std::max_element(duration.begin(), duration.end()) * sample_rate_;
std::vector<int64_t> wav_lengths;
for (float d : duration) {
wav_lengths.push_back(static_cast<int64_t>(d * sample_rate_));
}
int chunk_size = base_chunk_size_ * chunk_compress_factor_;
int latent_len = static_cast<int>((wav_len_max + chunk_size - 1) / chunk_size);
int latent_dim = ldim_ * chunk_compress_factor_;
// Generate random noise with normal distribution
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> dist(0.0f, 1.0f);
noisy_latent.resize(bsz);
for (int b = 0; b < bsz; b++) {
noisy_latent[b].resize(latent_dim);
for (int d = 0; d < latent_dim; d++) {
noisy_latent[b][d].resize(latent_len);
for (int t = 0; t < latent_len; t++) {
noisy_latent[b][d][t] = dist(gen);
}
}
}
latent_mask = getLatentMask(wav_lengths, base_chunk_size_, chunk_compress_factor_);
// Apply mask
for (int b = 0; b < bsz; b++) {
for (int d = 0; d < latent_dim; d++) {
for (size_t t = 0; t < noisy_latent[b][d].size(); t++) {
noisy_latent[b][d][t] *= latent_mask[b][0][t];
}
}
}
}
// Runs the full synthesis pipeline for one batch: text processing, duration
// prediction, text encoding, iterative latent denoising, and vocoding.
//
// memory_info: ONNX Runtime memory descriptor used for every input tensor.
// text_list:   texts to synthesize (must not be empty).
// lang_list:   language code per text (same length as text_list).
// style:       voice style whose first TTL shape entry must equal the batch size.
// total_step:  number of denoising iterations; values <= 0 skip denoising.
// speed:       predicted durations are divided by this; must be > 0.
//
// Returns the vocoder output samples for the whole batch (flattened) and the
// per-item durations after the speed adjustment.
// Throws std::runtime_error on invalid arguments or when a model returns fewer
// elements than this function is about to read.
//
// NOTE: arrayToTensor()/intArrayToTensor() keep their backing storage in
// file-level buffers that this function does not release; callers that run
// many inferences should call clearTensorBuffers() between runs.
TextToSpeech::SynthesisResult TextToSpeech::_infer(
    Ort::MemoryInfo& memory_info,
    const std::vector<std::string>& text_list,
    const std::vector<std::string>& lang_list,
    const Style& style,
    int total_step,
    float speed
) {
    if (text_list.empty()) {
        throw std::runtime_error("Text list must not be empty");
    }
    if (lang_list.size() != text_list.size()) {
        throw std::runtime_error("Number of texts must match number of languages");
    }
    // Written as a negated comparison so that NaN is rejected as well.
    if (!(speed > 0.0f)) {
        throw std::runtime_error("Speed must be positive");
    }

    const size_t batch_count = text_list.size();
    const int64_t bsz = static_cast<int64_t>(batch_count);
    if (style.getTtlShape().empty() || bsz != style.getTtlShape()[0]) {
        throw std::runtime_error("Number of texts must match number of style vectors");
    }

    // Wraps one of the style buffers as a tensor. The buffer is only read, but
    // the ONNX Runtime API takes a mutable pointer, hence the const_cast.
    const auto make_style_tensor = [&memory_info](
        const std::vector<float>& data, const std::vector<int64_t>& shape
    ) {
        return Ort::Value::CreateTensor<float>(
            memory_info,
            const_cast<float*>(data.data()),
            data.size(),
            shape.data(),
            shape.size()
        );
    };

    // Wraps a per-batch scalar vector (one value per item) as a 1-D tensor.
    const std::vector<int64_t> step_shape = {bsz};
    const auto make_step_tensor = [&memory_info, &step_shape](std::vector<float>& values) {
        return Ort::Value::CreateTensor<float>(
            memory_info,
            values.data(),
            values.size(),
            step_shape.data(),
            step_shape.size()
        );
    };

    // Process text
    std::vector<std::vector<int64_t>> text_ids;
    std::vector<std::vector<std::vector<float>>> text_mask;
    text_processor_->call(text_list, lang_list, text_ids, text_mask);

    std::vector<int64_t> text_ids_shape = {bsz, static_cast<int64_t>(text_ids[0].size())};
    std::vector<int64_t> text_mask_shape = {bsz, 1, static_cast<int64_t>(text_mask[0][0].size())};

    // Run duration predictor
    const char* dp_input_names[] = {"text_ids", "style_dp", "text_mask"};
    const char* dp_output_names[] = {"duration"};
    std::vector<Ort::Value> dp_inputs;
    dp_inputs.push_back(intArrayToTensor(memory_info, text_ids, text_ids_shape));
    dp_inputs.push_back(make_style_tensor(style.getDpData(), style.getDpShape()));
    dp_inputs.push_back(arrayToTensor(memory_info, text_mask, text_mask_shape));
    auto dp_outputs = dp_ort_->Run(
        Ort::RunOptions{nullptr},
        dp_input_names, dp_inputs.data(), dp_inputs.size(),
        dp_output_names, 1
    );
    if (dp_outputs[0].GetTensorTypeAndShapeInfo().GetElementCount() < batch_count) {
        throw std::runtime_error("Duration predictor returned fewer values than the batch size");
    }
    auto* dur_data = dp_outputs[0].GetTensorMutableData<float>();
    std::vector<float> duration(dur_data, dur_data + batch_count);

    // Apply speed factor to duration
    for (auto& dur : duration) {
        dur /= speed;
    }

    // Run text encoder (fresh input tensors; the previous ones were consumed)
    const char* text_enc_input_names[] = {"text_ids", "style_ttl", "text_mask"};
    const char* text_enc_output_names[] = {"text_emb"};
    std::vector<Ort::Value> text_enc_inputs;
    text_enc_inputs.push_back(intArrayToTensor(memory_info, text_ids, text_ids_shape));
    text_enc_inputs.push_back(make_style_tensor(style.getTtlData(), style.getTtlShape()));
    text_enc_inputs.push_back(arrayToTensor(memory_info, text_mask, text_mask_shape));
    auto text_enc_outputs = text_enc_ort_->Run(
        Ort::RunOptions{nullptr},
        text_enc_input_names, text_enc_inputs.data(), text_enc_inputs.size(),
        text_enc_output_names, 1
    );

    // Sample noisy latent
    std::vector<std::vector<std::vector<float>>> xt, latent_mask;
    sampleNoisyLatent(duration, xt, latent_mask);
    if (xt.empty() || xt[0].empty() || latent_mask.empty() || latent_mask[0].empty()) {
        throw std::runtime_error("Sampled latent is empty");
    }
    std::vector<int64_t> latent_shape = {
        bsz,
        static_cast<int64_t>(xt[0].size()),
        static_cast<int64_t>(xt[0][0].size())
    };
    std::vector<int64_t> latent_mask_shape = {
        bsz, 1,
        static_cast<int64_t>(latent_mask[0][0].size())
    };
    const size_t latent_element_count = batch_count * xt[0].size() * xt[0][0].size();

    // Per-item copy of the total step count, fed to the estimator every iteration
    std::vector<float> total_step_vec(batch_count, static_cast<float>(total_step));

    // Store text_emb data to reuse across iterations
    auto text_emb_info = text_enc_outputs[0].GetTensorTypeAndShapeInfo();
    size_t text_emb_size = text_emb_info.GetElementCount();
    auto* text_emb_data = text_enc_outputs[0].GetTensorMutableData<float>();
    std::vector<float> text_emb_vec(text_emb_data, text_emb_data + text_emb_size);
    auto text_emb_shape = text_emb_info.GetShape();

    const char* vector_est_input_names[] = {
        "noisy_latent", "text_emb", "style_ttl", "text_mask", "latent_mask", "total_step", "current_step"
    };
    const char* vector_est_output_names[] = {"denoised_latent"};

    // Iterative denoising
    for (int step = 0; step < total_step; step++) {
        std::vector<float> current_step_vec(batch_count, static_cast<float>(step));

        // Tensors are built in the same order as the input names above.
        std::vector<Ort::Value> vector_est_inputs;
        vector_est_inputs.push_back(arrayToTensor(memory_info, xt, latent_shape));
        vector_est_inputs.push_back(Ort::Value::CreateTensor<float>(
            memory_info,
            text_emb_vec.data(),
            text_emb_vec.size(),
            text_emb_shape.data(),
            text_emb_shape.size()
        ));
        vector_est_inputs.push_back(make_style_tensor(style.getTtlData(), style.getTtlShape()));
        vector_est_inputs.push_back(arrayToTensor(memory_info, text_mask, text_mask_shape));
        vector_est_inputs.push_back(arrayToTensor(memory_info, latent_mask, latent_mask_shape));
        vector_est_inputs.push_back(make_step_tensor(total_step_vec));
        vector_est_inputs.push_back(make_step_tensor(current_step_vec));

        auto vector_est_outputs = vector_est_ort_->Run(
            Ort::RunOptions{nullptr},
            vector_est_input_names, vector_est_inputs.data(), vector_est_inputs.size(),
            vector_est_output_names, 1
        );
        if (vector_est_outputs[0].GetTensorTypeAndShapeInfo().GetElementCount() < latent_element_count) {
            throw std::runtime_error("Vector estimator returned fewer values than the latent size");
        }

        // Update xt with denoised output (row-major over batch, dim, time)
        auto* denoised_data = vector_est_outputs[0].GetTensorMutableData<float>();
        size_t idx = 0;
        for (size_t b = 0; b < batch_count; b++) {
            for (size_t d = 0; d < xt[b].size(); d++) {
                for (size_t t = 0; t < xt[b][d].size(); t++) {
                    xt[b][d][t] = denoised_data[idx++];
                }
            }
        }
    }

    // Run vocoder
    const char* vocoder_input_names[] = {"latent"};
    const char* vocoder_output_names[] = {"wav_tts"};
    std::vector<Ort::Value> vocoder_inputs;
    vocoder_inputs.push_back(arrayToTensor(memory_info, xt, latent_shape));
    auto vocoder_outputs = vocoder_ort_->Run(
        Ort::RunOptions{nullptr},
        vocoder_input_names, vocoder_inputs.data(), vocoder_inputs.size(),
        vocoder_output_names, 1
    );
    auto wav_info = vocoder_outputs[0].GetTensorTypeAndShapeInfo();
    size_t wav_size = wav_info.GetElementCount();
    auto* wav_data = vocoder_outputs[0].GetTensorMutableData<float>();

    SynthesisResult result;
    result.wav.assign(wav_data, wav_data + wav_size);
    result.duration = duration;
    return result;
}
// Synthesizes a single (possibly long) text with one voice style.
//
// The text is split with chunkText() (120 characters per chunk for "ko", 300
// otherwise), each chunk is synthesized on its own, and the pieces are joined
// with `silence_duration` seconds of silence between them.
//
// Returns the concatenated samples and a one-element duration vector holding
// the summed chunk durations plus the inserted silences.
// Throws std::runtime_error when `style` does not hold exactly one style.
//
// A negative (or NaN) silence_duration is treated as 0. Duration is now
// accumulated for every chunk; previously it was overwritten for as long as no
// samples had been produced yet.
TextToSpeech::SynthesisResult TextToSpeech::call(
    Ort::MemoryInfo& memory_info,
    const std::string& text,
    const std::string& lang,
    const Style& style,
    int total_step,
    float speed,
    float silence_duration
) {
    if (style.getTtlShape().empty() || style.getTtlShape()[0] != 1) {
        throw std::runtime_error("Single speaker text to speech only supports single style");
    }

    const int max_len = (lang == "ko") ? 120 : 300;
    const auto text_list = chunkText(text, max_len);

    // std::max returns its first argument when the comparison is false, so a
    // NaN input also ends up as 0 here.
    const float gap_seconds = std::max(0.0f, silence_duration);
    const size_t gap_samples = static_cast<size_t>(gap_seconds * sample_rate_);

    SynthesisResult final_result;
    float total_duration = 0.0f;
    for (const auto& chunk : text_list) {
        auto part = _infer(memory_info, {chunk}, {lang}, style, total_step, speed);

        // Silence only separates audio; nothing is inserted before the first samples.
        if (!final_result.wav.empty()) {
            final_result.wav.insert(final_result.wav.end(), gap_samples, 0.0f);
            total_duration += gap_seconds;
        }
        final_result.wav.insert(final_result.wav.end(), part.wav.begin(), part.wav.end());
        total_duration += part.duration[0];
    }

    final_result.duration = {total_duration};
    return final_result;
}
// Synthesizes several texts in one pass. This is a direct forward to _infer():
// no chunking is applied and no silence is inserted.
TextToSpeech::SynthesisResult TextToSpeech::batch(
    Ort::MemoryInfo& memory_info,
    const std::vector<std::string>& text_list,
    const std::vector<std::string>& lang_list,
    const Style& style,
    int total_step,
    float speed
) {
    SynthesisResult batch_result =
        _infer(memory_info, text_list, lang_list, style, total_step, speed);
    return batch_result;
}
// ============================================================================
// Utility functions
// ============================================================================
// Builds a padding mask for a batch of sequences.
//
// lengths: valid length of each sequence.
// max_len: mask width; -1 means "use the largest entry of lengths".
//
// Returns one [1][width] row per entry of `lengths`, holding 1.0f for
// positions below that entry's length and 0.0f for the rest.
// An empty `lengths` yields an empty mask, and a non-positive width yields
// empty rows instead of an attempt to allocate a negative size.
std::vector<std::vector<std::vector<float>>> lengthToMask(
    const std::vector<int64_t>& lengths, int max_len
) {
    std::vector<std::vector<std::vector<float>>> mask;
    if (lengths.empty()) {
        return mask;
    }

    if (max_len == -1) {
        max_len = static_cast<int>(*std::max_element(lengths.begin(), lengths.end()));
    }
    const size_t width = max_len > 0 ? static_cast<size_t>(max_len) : 0;

    mask.reserve(lengths.size());
    for (const int64_t len : lengths) {
        // Number of leading ones: the length clamped into [0, width].
        size_t valid = 0;
        if (len > 0) {
            valid = static_cast<uint64_t>(len) < width ? static_cast<size_t>(len) : width;
        }

        std::vector<std::vector<float>> batch_mask(1, std::vector<float>(width, 0.0f));
        std::fill_n(batch_mask[0].begin(), valid, 1.0f);
        mask.push_back(std::move(batch_mask));
    }
    return mask;
}
std::vector<std::vector<std::vector<float>>> getLatentMask(
const std::vector<int64_t>& wav_lengths,
int base_chunk_size,
int chunk_compress_factor
) {
int latent_size = base_chunk_size * chunk_compress_factor;
std::vector<int64_t> latent_lengths;
for (auto len : wav_lengths) {
latent_lengths.push_back((len + latent_size - 1) / latent_size);
}
return lengthToMask(latent_lengths);
}
// ============================================================================
// ONNX model loading
// ============================================================================
// Creates an ONNX Runtime session for the model file at `onnx_path`, using the
// given environment and session options. Ownership passes to the caller.
std::unique_ptr<Ort::Session> loadOnnx(
    Ort::Env& env,
    const std::string& onnx_path,
    const Ort::SessionOptions& opts
) {
    const char* model_path = onnx_path.c_str();
    auto session = std::make_unique<Ort::Session>(env, model_path, opts);
    return session;
}
// Loads the four model files expected inside `onnx_dir` (duration predictor,
// text encoder, vector estimator, vocoder), in that order, and returns them
// bundled together. All sessions share the same environment and options.
OnnxModels loadOnnxAll(
    Ort::Env& env,
    const std::string& onnx_dir,
    const Ort::SessionOptions& opts
) {
    // `relative_path` carries its own leading slash.
    const auto load_from_dir = [&env, &onnx_dir, &opts](const char* relative_path) {
        return loadOnnx(env, onnx_dir + relative_path, opts);
    };

    OnnxModels bundle;
    bundle.dp = load_from_dir("/duration_predictor.onnx");
    bundle.text_enc = load_from_dir("/text_encoder.onnx");
    bundle.vector_est = load_from_dir("/vector_estimator.onnx");
    bundle.vocoder = load_from_dir("/vocoder.onnx");
    return bundle;
}
// ============================================================================
// Configuration and processor loading
// ============================================================================
// Reads "<onnx_dir>/tts.json" and extracts the four values used by the
// synthesizer: ae.sample_rate, ae.base_chunk_size, ttl.chunk_compress_factor
// and ttl.latent_dim.
// Throws std::runtime_error when the file cannot be opened; malformed JSON or
// a missing/mistyped field surfaces as the JSON library's own exception.
Config loadCfgs(const std::string& onnx_dir) {
    const std::string cfg_path = onnx_dir + "/tts.json";

    std::ifstream stream(cfg_path);
    if (!stream.is_open()) {
        throw std::runtime_error("Failed to open config file: " + cfg_path);
    }

    json root;
    stream >> root;

    // Looks up root[section][key]; the conversion happens at the assignment.
    const auto field = [&root](const char* section, const char* key) -> json& {
        return root[section][key];
    };

    Config parsed;
    parsed.ae.sample_rate = field("ae", "sample_rate");
    parsed.ae.base_chunk_size = field("ae", "base_chunk_size");
    parsed.ttl.chunk_compress_factor = field("ttl", "chunk_compress_factor");
    parsed.ttl.latent_dim = field("ttl", "latent_dim");
    return parsed;
}
std::unique_ptr<UnicodeProcessor> loadTextProcessor(const std::string& onnx_dir) {
std::string unicode_indexer_path = onnx_dir + "/unicode_indexer.json";
return std::make_unique<UnicodeProcessor>(unicode_indexer_path);
}
// ============================================================================
// Voice style loading
// ============================================================================
// Loads one or more voice-style JSON files and stacks them into a single
// batched Style.
//
// Each file must contain "style_ttl" and "style_dp" objects with a "dims"
// array (at least 3 entries) and a 3-level nested "data" array. The per-style
// dimensions are taken from the first file; every file must then provide
// exactly dims[1] * dims[2] values for each tensor.
//
// voice_style_paths: files to load, in batch order (must not be empty).
// verbose:           when true, prints the number of loaded styles to stdout.
//
// Returns a Style with shapes {count, dims[1], dims[2]} for both tensors.
// Throws std::runtime_error when the list is empty, a file cannot be opened,
// the dims are unusable, or a file's data size does not match the first file's
// dims (previously a larger file wrote past the end of the output buffer and a
// smaller one was silently zero-padded). JSON errors surface as the JSON
// library's own exceptions.
Style loadVoiceStyle(const std::vector<std::string>& voice_style_paths, bool verbose) {
    if (voice_style_paths.empty()) {
        throw std::runtime_error("No voice style files given");
    }
    const size_t style_count = voice_style_paths.size();
    const int64_t bsz = static_cast<int64_t>(style_count);

    // Flattens a [x][y][z] nested JSON array in row-major order.
    const auto flatten = [](json& node) {
        const auto nested = node.get<std::vector<std::vector<std::vector<float>>>>();
        std::vector<float> flat;
        for (const auto& batch : nested) {
            for (const auto& row : batch) {
                flat.insert(flat.end(), row.begin(), row.end());
            }
        }
        return flat;
    };

    std::vector<int64_t> ttl_shape;
    std::vector<int64_t> dp_shape;
    size_t ttl_per_style = 0;
    size_t dp_per_style = 0;
    std::vector<float> ttl_flat;
    std::vector<float> dp_flat;

    for (size_t i = 0; i < style_count; i++) {
        const std::string& path = voice_style_paths[i];
        std::ifstream file(path);
        if (!file.is_open()) {
            throw std::runtime_error("Failed to open voice style file: " + path);
        }
        json j;
        file >> j;

        if (i == 0) {
            // The first file defines the per-style dimensions for the whole batch.
            const auto ttl_dims = j["style_ttl"]["dims"].get<std::vector<int64_t>>();
            const auto dp_dims = j["style_dp"]["dims"].get<std::vector<int64_t>>();
            if (ttl_dims.size() < 3 || dp_dims.size() < 3) {
                throw std::runtime_error("Voice style dims must have at least 3 entries: " + path);
            }
            if (ttl_dims[1] < 0 || ttl_dims[2] < 0 || dp_dims[1] < 0 || dp_dims[2] < 0) {
                throw std::runtime_error("Voice style dims must not be negative: " + path);
            }
            ttl_shape = {bsz, ttl_dims[1], ttl_dims[2]};
            dp_shape = {bsz, dp_dims[1], dp_dims[2]};
            ttl_per_style = static_cast<size_t>(ttl_dims[1]) * static_cast<size_t>(ttl_dims[2]);
            dp_per_style = static_cast<size_t>(dp_dims[1]) * static_cast<size_t>(dp_dims[2]);
            ttl_flat.reserve(style_count * ttl_per_style);
            dp_flat.reserve(style_count * dp_per_style);
        }

        const std::vector<float> ttl_data = flatten(j["style_ttl"]["data"]);
        const std::vector<float> dp_data = flatten(j["style_dp"]["data"]);
        if (ttl_data.size() != ttl_per_style || dp_data.size() != dp_per_style) {
            throw std::runtime_error("Voice style data size does not match dims: " + path);
        }

        // Styles are stored back to back, so appending keeps batch order.
        ttl_flat.insert(ttl_flat.end(), ttl_data.begin(), ttl_data.end());
        dp_flat.insert(dp_flat.end(), dp_data.begin(), dp_data.end());
    }

    if (verbose) {
        std::cout << "Loaded " << style_count << " voice styles" << std::endl;
    }
    return Style(ttl_flat, ttl_shape, dp_flat, dp_shape);
}
// ============================================================================
// TextToSpeech loading
// ============================================================================
// Loads the configuration, the four ONNX sessions and the text processor from
// `onnx_dir` and returns a ready-to-use TextToSpeech.
//
// env:      ONNX Runtime environment used to create the sessions.
// onnx_dir: directory holding tts.json, unicode_indexer.json and the models.
// use_gpu:  must be false; GPU mode throws std::runtime_error.
//
// TextToSpeech keeps only raw pointers to the sessions and the processor, so
// this function retains ownership of them in function-local static storage
// for the rest of the program. Every call appends to that storage: an earlier
// version overwrote a single slot, which destroyed the objects still used by
// the instance returned from a previous call.
// NOTE: the retained objects are never released, and the static storage is not
// synchronized; concurrent calls need external locking.
std::unique_ptr<TextToSpeech> loadTextToSpeech(
    Ort::Env& env,
    const std::string& onnx_dir,
    bool use_gpu
) {
    Ort::SessionOptions opts;
    if (use_gpu) {
        throw std::runtime_error("GPU mode is not supported yet");
    } else {
        std::cout << "Using CPU for inference" << std::endl;
    }

    auto cfgs = loadCfgs(onnx_dir);
    auto models = loadOnnxAll(env, onnx_dir, opts);
    auto text_processor = loadTextProcessor(onnx_dir);

    // The raw pointers stay valid after the moves below because the pointed-to
    // objects live on the heap behind unique_ptr and are not relocated.
    auto tts = std::make_unique<TextToSpeech>(
        cfgs,
        text_processor.get(),
        models.dp.get(),
        models.text_enc.get(),
        models.vector_est.get(),
        models.vocoder.get()
    );

    // Keep the models and processor alive by storing them
    // (In production, you'd want better lifetime management)
    static std::vector<OnnxModels> retained_models;
    static std::vector<std::unique_ptr<UnicodeProcessor>> retained_text_processors;
    retained_models.push_back(std::move(models));
    retained_text_processors.push_back(std::move(text_processor));

    return tts;
}
// ============================================================================
// WAV file writing
// ============================================================================
// Writes `audio_data` as a mono, 16-bit PCM WAV file.
//
// filename:    output path; an existing file is overwritten.
// audio_data:  samples, nominally in [-1, 1]; values outside are clamped.
// sample_rate: sampling rate stored in the header.
//
// All header fields and samples are serialized explicitly as little-endian,
// so the output does not depend on the host's byte order or the size of int.
// The file image is assembled in memory and written with a single call.
// Throws std::runtime_error when the audio does not fit the 32-bit size fields
// of the format, when the file cannot be opened, or when writing fails.
void writeWavFile(
    const std::string& filename,
    const std::vector<float>& audio_data,
    int sample_rate
) {
    constexpr uint16_t kNumChannels = 1;
    constexpr uint16_t kBitsPerSample = 16;
    constexpr uint16_t kBlockAlign = kNumChannels * kBitsPerSample / 8;
    constexpr uint32_t kFmtChunkSize = 16;
    constexpr uint16_t kFormatPcm = 1;
    // Bytes between the RIFF size field and the start of the sample data.
    constexpr uint32_t kHeaderAfterRiffSize = 36;

    const uint64_t data_bytes = static_cast<uint64_t>(audio_data.size()) * kBlockAlign;
    if (data_bytes > 0xFFFFFFFFull - kHeaderAfterRiffSize) {
        throw std::runtime_error("Audio data too large for WAV file: " + filename);
    }

    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        throw std::runtime_error("Failed to open file for writing: " + filename);
    }

    std::string bytes;
    bytes.reserve(static_cast<size_t>(kHeaderAfterRiffSize + 8 + data_bytes));
    const auto put_u16 = [&bytes](uint16_t value) {
        bytes.push_back(static_cast<char>(value & 0xFF));
        bytes.push_back(static_cast<char>((value >> 8) & 0xFF));
    };
    const auto put_u32 = [&bytes](uint32_t value) {
        for (int shift = 0; shift < 32; shift += 8) {
            bytes.push_back(static_cast<char>((value >> shift) & 0xFF));
        }
    };

    const uint32_t rate = static_cast<uint32_t>(sample_rate);

    // RIFF header
    bytes.append("RIFF", 4);
    put_u32(kHeaderAfterRiffSize + static_cast<uint32_t>(data_bytes));
    bytes.append("WAVE", 4);

    // fmt chunk
    bytes.append("fmt ", 4);
    put_u32(kFmtChunkSize);
    put_u16(kFormatPcm);
    put_u16(kNumChannels);
    put_u32(rate);
    put_u32(rate * kBlockAlign);  // byte rate
    put_u16(kBlockAlign);
    put_u16(kBitsPerSample);

    // data chunk
    bytes.append("data", 4);
    put_u32(static_cast<uint32_t>(data_bytes));
    for (const float sample : audio_data) {
        // This max/min form also maps NaN to a defined value before the cast.
        const float clamped = std::max(-1.0f, std::min(1.0f, sample));
        const int16_t pcm = static_cast<int16_t>(clamped * 32767);
        put_u16(static_cast<uint16_t>(pcm));
    }

    file.write(bytes.data(), static_cast<std::streamsize>(bytes.size()));
    file.flush();
    if (!file) {
        throw std::runtime_error("Failed to write WAV file: " + filename);
    }
}
// ============================================================================
// Tensor conversion utilities
// ============================================================================
// Flattens a 3-level nested float array in row-major order and wraps it as a
// float tensor with shape `dims`.
// The flattened copy is appended to g_tensor_buffers_float and the tensor
// points into it, so the tensor is only usable until clearTensorBuffers()
// is called.
Ort::Value arrayToTensor(
    Ort::MemoryInfo& memory_info,
    const std::vector<std::vector<std::vector<float>>>& array,
    const std::vector<int64_t>& dims
) {
    std::vector<float> flattened;
    for (const auto& plane : array) {
        for (const auto& line : plane) {
            flattened.insert(flattened.end(), line.begin(), line.end());
        }
    }

    // Hand the storage to the global buffer list before taking its address.
    g_tensor_buffers_float.push_back(std::move(flattened));
    std::vector<float>& storage = g_tensor_buffers_float.back();

    return Ort::Value::CreateTensor<float>(
        memory_info, storage.data(), storage.size(), dims.data(), dims.size());
}
// Build an int64 tensor from a two-level nested vector.
//
// Rows are concatenated in order into one contiguous buffer, and the tensor is
// created over that buffer with the shape in `dims`. The buffer is appended to
// the file-level int64 buffer list so that it stays alive after this function
// returns (clearTensorBuffers() empties that list).
Ort::Value intArrayToTensor(
    Ort::MemoryInfo& memory_info,
    const std::vector<std::vector<int64_t>>& array,
    const std::vector<int64_t>& dims
) {
    std::vector<int64_t> values;
    for (const auto& line : array) {
        values.insert(values.end(), line.begin(), line.end());
    }

    // Hand ownership to the global list first, then point the tensor at the
    // stored copy rather than at the local that was moved from.
    g_tensor_buffers_int64.push_back(std::move(values));
    std::vector<int64_t>& stored = g_tensor_buffers_int64.back();

    return Ort::Value::CreateTensor<int64_t>(
        memory_info, stored.data(), stored.size(), dims.data(), dims.size());
}
// ============================================================================
// JSON loading helpers
// ============================================================================
std::vector<int64_t> loadJsonInt64(const std::string& file_path) {
std::ifstream file(file_path);
if (!file.is_open()) {
throw std::runtime_error("Failed to open file: " + file_path);
}
json j;
file >> j;
return j.get<std::vector<int64_t>>();
}
// ============================================================================
// Sanitize filename
// ============================================================================
// Turn arbitrary UTF-8 text into a string that is safe to use as a file name.
//
// ASCII letters, digits and '_' are kept. Well-formed multi-byte UTF-8
// sequences are kept intact. Every other byte (spaces, punctuation, path
// separators, malformed or truncated UTF-8) becomes a single '_'.
//
// text    - input text, expected to be UTF-8.
// max_len - maximum number of output characters; a multi-byte sequence counts
//           as one. Values <= 0 yield an empty string.
std::string sanitizeFilename(const std::string& text, int max_len) {
    // Byte length of the well-formed multi-byte sequence starting at `pos`,
    // or 0 when the lead byte is invalid, the sequence is truncated, or a
    // continuation byte is not of the form 10xxxxxx. Checking continuation
    // bytes matters: without it a stray lead byte would swallow the bytes
    // after it (e.g. '/') and copy them to the output unsanitized.
    const auto utf8SequenceLength = [&text](size_t pos) -> size_t {
        const unsigned char lead = static_cast<unsigned char>(text[pos]);
        size_t length = 0;
        if ((lead & 0xE0) == 0xC0) {
            length = 2;
        } else if ((lead & 0xF0) == 0xE0) {
            length = 3;  // includes Korean, Japanese, Chinese
        } else if ((lead & 0xF8) == 0xF0) {
            length = 4;
        } else {
            return 0;
        }
        if (pos + length > text.size()) {
            return 0;
        }
        for (size_t k = 1; k < length; ++k) {
            if ((static_cast<unsigned char>(text[pos + k]) & 0xC0) != 0x80) {
                return 0;
            }
        }
        return length;
    };

    std::string result;
    int char_count = 0;
    size_t i = 0;
    while (i < text.size() && char_count < max_len) {
        const unsigned char c = static_cast<unsigned char>(text[i]);
        if (c < 0x80) {
            // Restrict isalnum to ASCII so the outcome is locale-independent.
            result += (std::isalnum(c) || c == '_') ? text[i] : '_';
            i += 1;
        } else {
            const size_t seq_len = utf8SequenceLength(i);
            if (seq_len > 0) {
                result.append(text, i, seq_len);
                i += seq_len;
            } else {
                // Malformed byte: replace it alone and resynchronize on the next byte.
                result += '_';
                i += 1;
            }
        }
        ++char_count;
    }
    return result;
}
// ============================================================================
// Chunk text
// ============================================================================
// Split text into chunks of at most max_len bytes where possible.
//
// The text is first split into paragraphs (separated by a blank line), then
// each paragraph into sentences at '.', '!' or '?' followed by whitespace.
// Consecutive sentences of a paragraph are joined with a single space for as
// long as the chunk stays within max_len; chunks never span paragraphs.
//
// A single sentence longer than max_len is emitted as its own (oversized)
// chunk. Abbreviations such as "Mr. " are not special-cased and therefore end
// a sentence. If nothing is produced (e.g. blank input), the result holds one
// element: the trimmed original text.
std::vector<std::string> chunkText(const std::string& text, int max_len) {
    // Compiling a std::regex is expensive; build each pattern only once.
    static const std::regex paragraph_regex(R"(\n\s*\n+)");
    static const std::regex sentence_regex(R"([.!?]\s+)");

    // Compare lengths as size_t so that large inputs cannot overflow an int.
    const size_t limit = max_len > 0 ? static_cast<size_t>(max_len) : 0;
    std::vector<std::string> chunks;

    const std::sregex_token_iterator paragraphs_end;
    for (std::sregex_token_iterator para_it(text.begin(), text.end(), paragraph_regex, -1);
         para_it != paragraphs_end; ++para_it) {
        const std::string paragraph = trim(*para_it);
        if (paragraph.empty()) {
            continue;
        }

        std::string current_chunk;
        // Append a sentence to the open chunk, or close the chunk and start a
        // new one when the sentence (plus a joining space) would not fit.
        const auto addSentence = [&chunks, &current_chunk, limit](const std::string& sentence) {
            if (current_chunk.size() + sentence.size() + 1 <= limit) {
                if (!current_chunk.empty()) {
                    current_chunk += ' ';
                }
                current_chunk += sentence;
            } else {
                if (!current_chunk.empty()) {
                    chunks.push_back(current_chunk);
                }
                current_chunk = sentence;
            }
        };

        // Walk the delimiter matches directly. Each sentence runs from the end
        // of the previous delimiter up to and including the punctuation mark;
        // the whitespace after it is dropped, so joined sentences are
        // separated by exactly one space and need no further trimming.
        size_t sentence_start = 0;
        const std::sregex_iterator matches_end;
        for (std::sregex_iterator match(paragraph.begin(), paragraph.end(), sentence_regex);
             match != matches_end; ++match) {
            const size_t punct_pos = static_cast<size_t>(match->position());
            addSentence(paragraph.substr(sentence_start, punct_pos + 1 - sentence_start));
            sentence_start = punct_pos + static_cast<size_t>(match->length());
        }
        // Remainder after the last delimiter (the paragraph is already trimmed).
        if (sentence_start < paragraph.size()) {
            addSentence(paragraph.substr(sentence_start));
        }

        if (!current_chunk.empty()) {
            chunks.push_back(current_chunk);
        }
    }

    // If no chunks were created, return the original text
    if (chunks.empty()) {
        chunks.push_back(trim(text));
    }
    return chunks;
}