initial commit

This commit is contained in:
2026-01-25 18:58:40 +09:00
commit 77af47274c
101 changed files with 16247 additions and 0 deletions

955
java/Helper.java Normal file
View File

@@ -0,0 +1,955 @@
import ai.onnxruntime.*;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import javax.sound.sampled.AudioFileFormat;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.text.Normalizer;
import java.util.*;
import java.util.regex.Pattern;
import java.util.regex.Matcher;
/**
* Available languages for multilingual TTS
*/
class Languages {
public static final List<String> AVAILABLE = Arrays.asList("en", "ko", "es", "pt", "fr");
public static boolean isValid(String lang) {
return AVAILABLE.contains(lang);
}
}
/**
* Configuration classes
*/
class Config {
static class AEConfig {
int sampleRate;
int baseChunkSize;
}
static class TTLConfig {
int chunkCompressFactor;
int latentDim;
}
AEConfig ae;
TTLConfig ttl;
}
/**
* Voice Style Data from JSON
*/
class VoiceStyleData {
static class StyleData {
float[][][] data;
long[] dims;
String type;
}
StyleData styleTtl;
StyleData styleDp;
}
/**
* Unicode text processor
*/
class UnicodeProcessor {
private long[] indexer;
public UnicodeProcessor(String unicodeIndexerJsonPath) throws IOException {
this.indexer = Helper.loadJsonLongArray(unicodeIndexerJsonPath);
}
private static String removeEmojis(String text) {
StringBuilder result = new StringBuilder();
for (int i = 0; i < text.length(); i++) {
int codePoint;
if (Character.isHighSurrogate(text.charAt(i)) && i + 1 < text.length() && Character.isLowSurrogate(text.charAt(i + 1))) {
codePoint = Character.codePointAt(text, i);
i++; // Skip the low surrogate
} else {
codePoint = text.charAt(i);
}
// Check if code point is in emoji ranges
boolean isEmoji = (codePoint >= 0x1F600 && codePoint <= 0x1F64F) ||
(codePoint >= 0x1F300 && codePoint <= 0x1F5FF) ||
(codePoint >= 0x1F680 && codePoint <= 0x1F6FF) ||
(codePoint >= 0x1F700 && codePoint <= 0x1F77F) ||
(codePoint >= 0x1F780 && codePoint <= 0x1F7FF) ||
(codePoint >= 0x1F800 && codePoint <= 0x1F8FF) ||
(codePoint >= 0x1F900 && codePoint <= 0x1F9FF) ||
(codePoint >= 0x1FA00 && codePoint <= 0x1FA6F) ||
(codePoint >= 0x1FA70 && codePoint <= 0x1FAFF) ||
(codePoint >= 0x2600 && codePoint <= 0x26FF) ||
(codePoint >= 0x2700 && codePoint <= 0x27BF) ||
(codePoint >= 0x1F1E6 && codePoint <= 0x1F1FF);
if (!isEmoji) {
if (codePoint > 0xFFFF) {
result.append(Character.toChars(codePoint));
} else {
result.append((char) codePoint);
}
}
}
return result.toString();
}
public TextProcessResult call(List<String> textList, List<String> langList) {
List<String> processedTexts = new ArrayList<>();
for (int i = 0; i < textList.size(); i++) {
processedTexts.add(preprocessText(textList.get(i), langList.get(i)));
}
// Convert texts to unicode values first to get correct character counts
List<int[]> allUnicodeVals = new ArrayList<>();
for (String text : processedTexts) {
allUnicodeVals.add(textToUnicodeValues(text));
}
int[] textIdsLengths = new int[processedTexts.size()];
int maxLen = 0;
for (int i = 0; i < allUnicodeVals.size(); i++) {
textIdsLengths[i] = allUnicodeVals.get(i).length; // Use code point count, not char count
maxLen = Math.max(maxLen, textIdsLengths[i]);
}
long[][] textIds = new long[processedTexts.size()][maxLen];
for (int i = 0; i < allUnicodeVals.size(); i++) {
int[] unicodeVals = allUnicodeVals.get(i);
for (int j = 0; j < unicodeVals.length; j++) {
textIds[i][j] = indexer[unicodeVals[j]];
}
}
float[][][] textMask = getTextMask(textIdsLengths);
return new TextProcessResult(textIds, textMask);
}
private String preprocessText(String text, String lang) {
// TODO: Need advanced normalizer for better performance
text = Normalizer.normalize(text, Normalizer.Form.NFKD);
// Remove emojis (wide Unicode range)
// Java Pattern doesn't support \x{...} syntax for Unicode above \uFFFF
// Use character filtering instead
text = removeEmojis(text);
// Replace various dashes and symbols
Map<String, String> replacements = new HashMap<>();
replacements.put("", "-"); // en dash
replacements.put("", "-"); // non-breaking hyphen
replacements.put("", "-"); // em dash
replacements.put("_", " "); // underscore
replacements.put("\u201C", "\""); // left double quote
replacements.put("\u201D", "\""); // right double quote
replacements.put("\u2018", "'"); // left single quote
replacements.put("\u2019", "'"); // right single quote
replacements.put("´", "'"); // acute accent
replacements.put("`", "'"); // grave accent
replacements.put("[", " "); // left bracket
replacements.put("]", " "); // right bracket
replacements.put("|", " "); // vertical bar
replacements.put("/", " "); // slash
replacements.put("#", " "); // hash
replacements.put("", " "); // right arrow
replacements.put("", " "); // left arrow
for (Map.Entry<String, String> entry : replacements.entrySet()) {
text = text.replace(entry.getKey(), entry.getValue());
}
// Remove special symbols
text = text.replaceAll("[♥☆♡©\\\\]", "");
// Replace known expressions
Map<String, String> exprReplacements = new HashMap<>();
exprReplacements.put("@", " at ");
exprReplacements.put("e.g.,", "for example, ");
exprReplacements.put("i.e.,", "that is, ");
for (Map.Entry<String, String> entry : exprReplacements.entrySet()) {
text = text.replace(entry.getKey(), entry.getValue());
}
// Fix spacing around punctuation
text = text.replaceAll(" ,", ",");
text = text.replaceAll(" \\.", ".");
text = text.replaceAll(" !", "!");
text = text.replaceAll(" \\?", "?");
text = text.replaceAll(" ;", ";");
text = text.replaceAll(" :", ":");
text = text.replaceAll(" '", "'");
// Remove duplicate quotes
while (text.contains("\"\"")) {
text = text.replace("\"\"", "\"");
}
while (text.contains("''")) {
text = text.replace("''", "'");
}
while (text.contains("``")) {
text = text.replace("``", "`");
}
// Remove extra spaces
text = text.replaceAll("\\s+", " ").trim();
// If text doesn't end with punctuation, quotes, or closing brackets, add a period
if (!text.matches(".*[.!?;:,'\"\\u201C\\u201D\\u2018\\u2019)\\]}…。」』】〉》›»]$")) {
text += ".";
}
// Validate language
if (!Languages.isValid(lang)) {
throw new IllegalArgumentException("Invalid language: " + lang + ". Available: " + Languages.AVAILABLE);
}
// Wrap text with language tags
text = "<" + lang + ">" + text + "</" + lang + ">";
return text;
}
private int[] textToUnicodeValues(String text) {
// Use codePoints() stream to correctly handle surrogate pairs
return text.codePoints().toArray();
}
private float[][][] getTextMask(int[] lengths) {
int bsz = lengths.length;
int maxLen = 0;
for (int len : lengths) {
maxLen = Math.max(maxLen, len);
}
float[][][] mask = new float[bsz][1][maxLen];
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < maxLen; j++) {
mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f;
}
}
return mask;
}
static class TextProcessResult {
long[][] textIds;
float[][][] textMask;
TextProcessResult(long[][] textIds, float[][][] textMask) {
this.textIds = textIds;
this.textMask = textMask;
}
}
}
/**
* Text-to-Speech inference class
*/
class TextToSpeech {
private Config config;
private UnicodeProcessor textProcessor;
private OrtSession dpSession;
private OrtSession textEncSession;
private OrtSession vectorEstSession;
private OrtSession vocoderSession;
public int sampleRate;
private int baseChunkSize;
private int chunkCompress;
private int ldim;
public TextToSpeech(Config config, UnicodeProcessor textProcessor,
OrtSession dpSession, OrtSession textEncSession,
OrtSession vectorEstSession, OrtSession vocoderSession) {
this.config = config;
this.textProcessor = textProcessor;
this.dpSession = dpSession;
this.textEncSession = textEncSession;
this.vectorEstSession = vectorEstSession;
this.vocoderSession = vocoderSession;
this.sampleRate = config.ae.sampleRate;
this.baseChunkSize = config.ae.baseChunkSize;
this.chunkCompress = config.ttl.chunkCompressFactor;
this.ldim = config.ttl.latentDim;
}
private TTSResult _infer(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
throws OrtException {
int bsz = textList.size();
// Process text
UnicodeProcessor.TextProcessResult textResult = textProcessor.call(textList, langList);
long[][] textIds = textResult.textIds;
float[][][] textMask = textResult.textMask;
// Create tensors
OnnxTensor textIdsTensor = Helper.createLongTensor(textIds, env);
OnnxTensor textMaskTensor = Helper.createFloatTensor(textMask, env);
// Predict duration
Map<String, OnnxTensor> dpInputs = new HashMap<>();
dpInputs.put("text_ids", textIdsTensor);
dpInputs.put("style_dp", style.dpTensor);
dpInputs.put("text_mask", textMaskTensor);
OrtSession.Result dpResult = dpSession.run(dpInputs);
Object dpValue = dpResult.get(0).getValue();
float[] duration;
if (dpValue instanceof float[][]) {
duration = ((float[][]) dpValue)[0];
} else {
duration = (float[]) dpValue;
}
// Apply speed factor to duration
for (int i = 0; i < duration.length; i++) {
duration[i] /= speed;
}
// Encode text
Map<String, OnnxTensor> textEncInputs = new HashMap<>();
textEncInputs.put("text_ids", textIdsTensor);
textEncInputs.put("style_ttl", style.ttlTensor);
textEncInputs.put("text_mask", textMaskTensor);
OrtSession.Result textEncResult = textEncSession.run(textEncInputs);
OnnxTensor textEmbTensor = (OnnxTensor) textEncResult.get(0);
// Sample noisy latent
NoisyLatentResult noisyLatentResult = sampleNoisyLatent(duration);
float[][][] xt = noisyLatentResult.noisyLatent;
float[][][] latentMask = noisyLatentResult.latentMask;
// Prepare constant tensors
float[] totalStepArray = new float[bsz];
Arrays.fill(totalStepArray, (float) totalStep);
OnnxTensor totalStepTensor = OnnxTensor.createTensor(env, totalStepArray);
// Denoising loop
for (int step = 0; step < totalStep; step++) {
float[] currentStepArray = new float[bsz];
Arrays.fill(currentStepArray, (float) step);
OnnxTensor currentStepTensor = OnnxTensor.createTensor(env, currentStepArray);
OnnxTensor noisyLatentTensor = Helper.createFloatTensor(xt, env);
OnnxTensor latentMaskTensor = Helper.createFloatTensor(latentMask, env);
OnnxTensor textMaskTensor2 = Helper.createFloatTensor(textMask, env);
Map<String, OnnxTensor> vectorEstInputs = new HashMap<>();
vectorEstInputs.put("noisy_latent", noisyLatentTensor);
vectorEstInputs.put("text_emb", textEmbTensor);
vectorEstInputs.put("style_ttl", style.ttlTensor);
vectorEstInputs.put("latent_mask", latentMaskTensor);
vectorEstInputs.put("text_mask", textMaskTensor2);
vectorEstInputs.put("current_step", currentStepTensor);
vectorEstInputs.put("total_step", totalStepTensor);
OrtSession.Result vectorEstResult = vectorEstSession.run(vectorEstInputs);
float[][][] denoised = (float[][][]) vectorEstResult.get(0).getValue();
// Update latent
xt = denoised;
// Clean up
currentStepTensor.close();
noisyLatentTensor.close();
latentMaskTensor.close();
textMaskTensor2.close();
vectorEstResult.close();
}
// Generate waveform
OnnxTensor finalLatentTensor = Helper.createFloatTensor(xt, env);
Map<String, OnnxTensor> vocoderInputs = new HashMap<>();
vocoderInputs.put("latent", finalLatentTensor);
OrtSession.Result vocoderResult = vocoderSession.run(vocoderInputs);
float[][] wavBatch = (float[][]) vocoderResult.get(0).getValue();
// Flatten all batch audio into a single array for batch processing
int totalSamples = 0;
for (float[] w : wavBatch) {
totalSamples += w.length;
}
float[] wav = new float[totalSamples];
int offset = 0;
for (float[] w : wavBatch) {
System.arraycopy(w, 0, wav, offset, w.length);
offset += w.length;
}
// Clean up
textIdsTensor.close();
textMaskTensor.close();
dpResult.close();
textEncResult.close();
totalStepTensor.close();
finalLatentTensor.close();
vocoderResult.close();
return new TTSResult(wav, duration);
}
private NoisyLatentResult sampleNoisyLatent(float[] duration) {
int bsz = duration.length;
float maxDur = 0;
for (float d : duration) {
maxDur = Math.max(maxDur, d);
}
long wavLenMax = (long) (maxDur * sampleRate);
long[] wavLengths = new long[bsz];
for (int i = 0; i < bsz; i++) {
wavLengths[i] = (long) (duration[i] * sampleRate);
}
int chunkSize = baseChunkSize * chunkCompress;
int latentLen = (int) ((wavLenMax + chunkSize - 1) / chunkSize);
int latentDim = ldim * chunkCompress;
Random rng = new Random();
float[][][] noisyLatent = new float[bsz][latentDim][latentLen];
for (int b = 0; b < bsz; b++) {
for (int d = 0; d < latentDim; d++) {
for (int t = 0; t < latentLen; t++) {
// Box-Muller transform
double u1 = Math.max(1e-10, rng.nextDouble());
double u2 = rng.nextDouble();
noisyLatent[b][d][t] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2));
}
}
}
float[][][] latentMask = Helper.getLatentMask(wavLengths, config);
// Apply mask
for (int b = 0; b < bsz; b++) {
for (int d = 0; d < latentDim; d++) {
for (int t = 0; t < latentLen; t++) {
noisyLatent[b][d][t] *= latentMask[b][0][t];
}
}
}
return new NoisyLatentResult(noisyLatent, latentMask);
}
/**
* Synthesize speech from a single text with automatic chunking
*/
public TTSResult call(String text, String lang, Style style, int totalStep, float speed, float silenceDuration, OrtEnvironment env)
throws OrtException {
int maxLen = lang.equals("ko") ? 120 : 300;
List<String> chunks = Helper.chunkText(text, maxLen);
List<Float> wavCat = new ArrayList<>();
float durCat = 0.0f;
for (int i = 0; i < chunks.size(); i++) {
TTSResult result = _infer(Arrays.asList(chunks.get(i)), Arrays.asList(lang), style, totalStep, speed, env);
float dur = result.duration[0];
int wavLen = (int) (sampleRate * dur);
float[] wavChunk = new float[wavLen];
System.arraycopy(result.wav, 0, wavChunk, 0, Math.min(wavLen, result.wav.length));
if (i == 0) {
for (float val : wavChunk) {
wavCat.add(val);
}
durCat = dur;
} else {
int silenceLen = (int) (silenceDuration * sampleRate);
for (int j = 0; j < silenceLen; j++) {
wavCat.add(0.0f);
}
for (float val : wavChunk) {
wavCat.add(val);
}
durCat += silenceDuration + dur;
}
}
float[] wavArray = new float[wavCat.size()];
for (int i = 0; i < wavCat.size(); i++) {
wavArray[i] = wavCat.get(i);
}
return new TTSResult(wavArray, new float[]{durCat});
}
/**
* Batch synthesize speech from multiple texts
*/
public TTSResult batch(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
throws OrtException {
return _infer(textList, langList, style, totalStep, speed, env);
}
public void close() throws OrtException {
if (dpSession != null) dpSession.close();
if (textEncSession != null) textEncSession.close();
if (vectorEstSession != null) vectorEstSession.close();
if (vocoderSession != null) vocoderSession.close();
}
}
/**
* Style holder class
*/
class Style {
OnnxTensor ttlTensor;
OnnxTensor dpTensor;
Style(OnnxTensor ttlTensor, OnnxTensor dpTensor) {
this.ttlTensor = ttlTensor;
this.dpTensor = dpTensor;
}
public void close() throws OrtException {
if (ttlTensor != null) ttlTensor.close();
if (dpTensor != null) dpTensor.close();
}
}
/**
* TTS result holder
*/
class TTSResult {
float[] wav;
float[] duration;
TTSResult(float[] wav, float[] duration) {
this.wav = wav;
this.duration = duration;
}
}
/**
* Noisy latent result holder
*/
class NoisyLatentResult {
float[][][] noisyLatent;
float[][][] latentMask;
NoisyLatentResult(float[][][] noisyLatent, float[][][] latentMask) {
this.noisyLatent = noisyLatent;
this.latentMask = latentMask;
}
}
/**
* Helper utility class
*/
public class Helper {
private static final int MAX_CHUNK_LENGTH = 300;
private static final String[] ABBREVIATIONS = {
"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
"St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
"Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D."
};
/**
* Chunk text into smaller segments based on paragraphs and sentences
*/
public static List<String> chunkText(String text, int maxLen) {
if (maxLen == 0) {
maxLen = MAX_CHUNK_LENGTH;
}
text = text.trim();
if (text.isEmpty()) {
return Arrays.asList("");
}
// Split by paragraphs
String[] paragraphs = text.split("\\n\\s*\\n");
List<String> chunks = new ArrayList<>();
for (String para : paragraphs) {
para = para.trim();
if (para.isEmpty()) {
continue;
}
if (para.length() <= maxLen) {
chunks.add(para);
continue;
}
// Split by sentences
List<String> sentences = splitSentences(para);
StringBuilder current = new StringBuilder();
int currentLen = 0;
for (String sentence : sentences) {
sentence = sentence.trim();
if (sentence.isEmpty()) {
continue;
}
int sentenceLen = sentence.length();
if (sentenceLen > maxLen) {
// If sentence is longer than maxLen, split by comma or space
if (current.length() > 0) {
chunks.add(current.toString().trim());
current.setLength(0);
currentLen = 0;
}
// Try splitting by comma
String[] parts = sentence.split(",");
for (String part : parts) {
part = part.trim();
if (part.isEmpty()) {
continue;
}
int partLen = part.length();
if (partLen > maxLen) {
// Split by space as last resort
String[] words = part.split("\\s+");
StringBuilder wordChunk = new StringBuilder();
int wordChunkLen = 0;
for (String word : words) {
int wordLen = word.length();
if (wordChunkLen + wordLen + 1 > maxLen && wordChunk.length() > 0) {
chunks.add(wordChunk.toString().trim());
wordChunk.setLength(0);
wordChunkLen = 0;
}
if (wordChunk.length() > 0) {
wordChunk.append(" ");
wordChunkLen++;
}
wordChunk.append(word);
wordChunkLen += wordLen;
}
if (wordChunk.length() > 0) {
chunks.add(wordChunk.toString().trim());
}
} else {
if (currentLen + partLen + 1 > maxLen && current.length() > 0) {
chunks.add(current.toString().trim());
current.setLength(0);
currentLen = 0;
}
if (current.length() > 0) {
current.append(", ");
currentLen += 2;
}
current.append(part);
currentLen += partLen;
}
}
continue;
}
if (currentLen + sentenceLen + 1 > maxLen && current.length() > 0) {
chunks.add(current.toString().trim());
current.setLength(0);
currentLen = 0;
}
if (current.length() > 0) {
current.append(" ");
currentLen++;
}
current.append(sentence);
currentLen += sentenceLen;
}
if (current.length() > 0) {
chunks.add(current.toString().trim());
}
}
if (chunks.isEmpty()) {
return Arrays.asList("");
}
return chunks;
}
/**
* Split text into sentences, avoiding common abbreviations
*/
private static List<String> splitSentences(String text) {
// Build pattern that avoids abbreviations
StringBuilder abbrevPattern = new StringBuilder();
for (int i = 0; i < ABBREVIATIONS.length; i++) {
if (i > 0) abbrevPattern.append("|");
abbrevPattern.append(Pattern.quote(ABBREVIATIONS[i]));
}
// Match sentence endings, but not abbreviations
String patternStr = "(?<!(?:" + abbrevPattern.toString() + "))(?<=[.!?])\\s+";
Pattern pattern = Pattern.compile(patternStr);
return Arrays.asList(pattern.split(text));
}
/**
* Load voice style from JSON files
*/
public static Style loadVoiceStyle(List<String> voiceStylePaths, boolean verbose, OrtEnvironment env)
throws IOException, OrtException {
int bsz = voiceStylePaths.size();
// Read first file to get dimensions
ObjectMapper mapper = new ObjectMapper();
JsonNode firstRoot = mapper.readTree(new File(voiceStylePaths.get(0)));
long[] ttlDims = new long[3];
for (int i = 0; i < 3; i++) {
ttlDims[i] = firstRoot.get("style_ttl").get("dims").get(i).asLong();
}
long[] dpDims = new long[3];
for (int i = 0; i < 3; i++) {
dpDims[i] = firstRoot.get("style_dp").get("dims").get(i).asLong();
}
long ttlDim1 = ttlDims[1];
long ttlDim2 = ttlDims[2];
long dpDim1 = dpDims[1];
long dpDim2 = dpDims[2];
// Pre-allocate arrays with full batch size
int ttlSize = (int) (bsz * ttlDim1 * ttlDim2);
int dpSize = (int) (bsz * dpDim1 * dpDim2);
float[] ttlFlat = new float[ttlSize];
float[] dpFlat = new float[dpSize];
// Fill in the data
for (int i = 0; i < bsz; i++) {
JsonNode root = mapper.readTree(new File(voiceStylePaths.get(i)));
// Flatten TTL data
int ttlOffset = (int) (i * ttlDim1 * ttlDim2);
int idx = 0;
JsonNode ttlData = root.get("style_ttl").get("data");
for (JsonNode batch : ttlData) {
for (JsonNode row : batch) {
for (JsonNode val : row) {
ttlFlat[ttlOffset + idx++] = (float) val.asDouble();
}
}
}
// Flatten DP data
int dpOffset = (int) (i * dpDim1 * dpDim2);
idx = 0;
JsonNode dpData = root.get("style_dp").get("data");
for (JsonNode batch : dpData) {
for (JsonNode row : batch) {
for (JsonNode val : row) {
dpFlat[dpOffset + idx++] = (float) val.asDouble();
}
}
}
}
long[] ttlShape = {bsz, ttlDim1, ttlDim2};
long[] dpShape = {bsz, dpDim1, dpDim2};
OnnxTensor ttlTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(ttlFlat), ttlShape);
OnnxTensor dpTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(dpFlat), dpShape);
if (verbose) {
System.out.println("Loaded " + bsz + " voice styles\n");
}
return new Style(ttlTensor, dpTensor);
}
/**
* Load TTS components
*/
public static TextToSpeech loadTextToSpeech(String onnxDir, boolean useGpu, OrtEnvironment env)
throws IOException, OrtException {
if (useGpu) {
throw new RuntimeException("GPU mode is not supported yet");
}
System.out.println("Using CPU for inference\n");
// Load config
Config config = loadCfgs(onnxDir);
// Create session options
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Load models
OrtSession dpSession = env.createSession(onnxDir + "/duration_predictor.onnx", opts);
OrtSession textEncSession = env.createSession(onnxDir + "/text_encoder.onnx", opts);
OrtSession vectorEstSession = env.createSession(onnxDir + "/vector_estimator.onnx", opts);
OrtSession vocoderSession = env.createSession(onnxDir + "/vocoder.onnx", opts);
// Load text processor
UnicodeProcessor textProcessor = new UnicodeProcessor(onnxDir + "/unicode_indexer.json");
return new TextToSpeech(config, textProcessor, dpSession, textEncSession, vectorEstSession, vocoderSession);
}
/**
* Load configuration from JSON
*/
public static Config loadCfgs(String onnxDir) throws IOException {
ObjectMapper mapper = new ObjectMapper();
JsonNode root = mapper.readTree(new File(onnxDir + "/tts.json"));
Config config = new Config();
config.ae = new Config.AEConfig();
config.ae.sampleRate = root.get("ae").get("sample_rate").asInt();
config.ae.baseChunkSize = root.get("ae").get("base_chunk_size").asInt();
config.ttl = new Config.TTLConfig();
config.ttl.chunkCompressFactor = root.get("ttl").get("chunk_compress_factor").asInt();
config.ttl.latentDim = root.get("ttl").get("latent_dim").asInt();
return config;
}
/**
* Get latent mask from wav lengths
*/
public static float[][][] getLatentMask(long[] wavLengths, Config config) {
long baseChunkSize = config.ae.baseChunkSize;
long chunkCompressFactor = config.ttl.chunkCompressFactor;
long latentSize = baseChunkSize * chunkCompressFactor;
long[] latentLengths = new long[wavLengths.length];
long maxLen = 0;
for (int i = 0; i < wavLengths.length; i++) {
latentLengths[i] = (wavLengths[i] + latentSize - 1) / latentSize;
maxLen = Math.max(maxLen, latentLengths[i]);
}
float[][][] mask = new float[wavLengths.length][1][(int) maxLen];
for (int i = 0; i < wavLengths.length; i++) {
for (int j = 0; j < maxLen; j++) {
mask[i][0][j] = j < latentLengths[i] ? 1.0f : 0.0f;
}
}
return mask;
}
/**
* Write WAV file
*/
public static void writeWavFile(String filename, float[] audioData, int sampleRate) throws IOException {
// Convert float to byte array
byte[] bytes = new byte[audioData.length * 2];
ByteBuffer buffer = ByteBuffer.wrap(bytes);
buffer.order(ByteOrder.LITTLE_ENDIAN);
for (float sample : audioData) {
short val = (short) Math.max(-32768, Math.min(32767, sample * 32767));
buffer.putShort(val);
}
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false);
AudioInputStream ais = new AudioInputStream(bais, format, audioData.length);
AudioSystem.write(ais, AudioFileFormat.Type.WAVE, new File(filename));
}
/**
* Sanitize filename (supports Unicode characters)
*/
public static String sanitizeFilename(String text, int maxLen) {
// Get first maxLen characters (code points, not chars for surrogate pairs)
int[] codePoints = text.codePoints().limit(maxLen).toArray();
StringBuilder result = new StringBuilder();
for (int codePoint : codePoints) {
if (Character.isLetterOrDigit(codePoint)) {
result.appendCodePoint(codePoint);
} else {
result.append('_');
}
}
return result.toString();
}
/**
* Timer utility
*/
public static <T> T timer(String name, java.util.function.Supplier<T> fn) {
long start = System.currentTimeMillis();
System.out.println(name + "...");
T result = fn.get();
long elapsed = System.currentTimeMillis() - start;
System.out.printf(" -> %s completed in %.2f sec\n", name, elapsed / 1000.0);
return result;
}
/**
* Create float tensor from 3D array
*/
public static OnnxTensor createFloatTensor(float[][][] array, OrtEnvironment env) throws OrtException {
int dim0 = array.length;
int dim1 = array[0].length;
int dim2 = array[0][0].length;
float[] flat = new float[dim0 * dim1 * dim2];
int idx = 0;
for (int i = 0; i < dim0; i++) {
for (int j = 0; j < dim1; j++) {
for (int k = 0; k < dim2; k++) {
flat[idx++] = array[i][j][k];
}
}
}
long[] shape = {dim0, dim1, dim2};
return OnnxTensor.createTensor(env, FloatBuffer.wrap(flat), shape);
}
/**
* Create long tensor from 2D array
*/
public static OnnxTensor createLongTensor(long[][] array, OrtEnvironment env) throws OrtException {
int dim0 = array.length;
int dim1 = array[0].length;
long[] flat = new long[dim0 * dim1];
int idx = 0;
for (int i = 0; i < dim0; i++) {
for (int j = 0; j < dim1; j++) {
flat[idx++] = array[i][j];
}
}
long[] shape = {dim0, dim1};
return OnnxTensor.createTensor(env, LongBuffer.wrap(flat), shape);
}
/**
* Load JSON long array
*/
public static long[] loadJsonLongArray(String filePath) throws IOException {
ObjectMapper mapper = new ObjectMapper();
JsonNode root = mapper.readTree(new File(filePath));
long[] result = new long[root.size()];
for (int i = 0; i < root.size(); i++) {
result[i] = root.get(i).asLong();
}
return result;
}
}