// Supertonic/java/Helper.java
import ai.onnxruntime.*;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import javax.sound.sampled.AudioFileFormat;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.text.Normalizer;
import java.util.*;
import java.util.regex.Pattern;
/**
* Available languages for multilingual TTS
*/
class Languages {
public static final List<String> AVAILABLE = Arrays.asList("en", "ko", "es", "pt", "fr");
public static boolean isValid(String lang) {
return AVAILABLE.contains(lang);
}
}
/**
* Configuration classes
*/
class Config {
static class AEConfig {
int sampleRate;
int baseChunkSize;
}
static class TTLConfig {
int chunkCompressFactor;
int latentDim;
}
AEConfig ae;
TTLConfig ttl;
}
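// Field values are read from tts.json by Helper.loadCfgs below; they correspond
// to the "sample_rate", "base_chunk_size", "chunk_compress_factor" and
// "latent_dim" entries of that file.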
/**
* Voice Style Data from JSON
*/
class VoiceStyleData {
static class StyleData {
float[][][] data;
long[] dims;
String type;
}
StyleData styleTtl;
StyleData styleDp;
}
/**
* Unicode text processor
*/
class UnicodeProcessor {
private long[] indexer;
public UnicodeProcessor(String unicodeIndexerJsonPath) throws IOException {
this.indexer = Helper.loadJsonLongArray(unicodeIndexerJsonPath);
}
private static String removeEmojis(String text) {
StringBuilder result = new StringBuilder();
for (int i = 0; i < text.length(); i++) {
int codePoint;
if (Character.isHighSurrogate(text.charAt(i)) && i + 1 < text.length() && Character.isLowSurrogate(text.charAt(i + 1))) {
codePoint = Character.codePointAt(text, i);
i++; // Skip the low surrogate
} else {
codePoint = text.charAt(i);
}
// Check if code point is in emoji ranges
boolean isEmoji = (codePoint >= 0x1F600 && codePoint <= 0x1F64F) ||
(codePoint >= 0x1F300 && codePoint <= 0x1F5FF) ||
(codePoint >= 0x1F680 && codePoint <= 0x1F6FF) ||
(codePoint >= 0x1F700 && codePoint <= 0x1F77F) ||
(codePoint >= 0x1F780 && codePoint <= 0x1F7FF) ||
(codePoint >= 0x1F800 && codePoint <= 0x1F8FF) ||
(codePoint >= 0x1F900 && codePoint <= 0x1F9FF) ||
(codePoint >= 0x1FA00 && codePoint <= 0x1FA6F) ||
(codePoint >= 0x1FA70 && codePoint <= 0x1FAFF) ||
(codePoint >= 0x2600 && codePoint <= 0x26FF) ||
(codePoint >= 0x2700 && codePoint <= 0x27BF) ||
(codePoint >= 0x1F1E6 && codePoint <= 0x1F1FF);
if (!isEmoji) {
if (codePoint > 0xFFFF) {
result.append(Character.toChars(codePoint));
} else {
result.append((char) codePoint);
}
}
}
return result.toString();
}
public TextProcessResult call(List<String> textList, List<String> langList) {
List<String> processedTexts = new ArrayList<>();
for (int i = 0; i < textList.size(); i++) {
processedTexts.add(preprocessText(textList.get(i), langList.get(i)));
}
// Convert texts to unicode values first to get correct character counts
List<int[]> allUnicodeVals = new ArrayList<>();
for (String text : processedTexts) {
allUnicodeVals.add(textToUnicodeValues(text));
}
int[] textIdsLengths = new int[processedTexts.size()];
int maxLen = 0;
for (int i = 0; i < allUnicodeVals.size(); i++) {
textIdsLengths[i] = allUnicodeVals.get(i).length; // Use code point count, not char count
maxLen = Math.max(maxLen, textIdsLengths[i]);
}
long[][] textIds = new long[processedTexts.size()][maxLen];
for (int i = 0; i < allUnicodeVals.size(); i++) {
int[] unicodeVals = allUnicodeVals.get(i);
for (int j = 0; j < unicodeVals.length; j++) {
textIds[i][j] = indexer[unicodeVals[j]];
}
}
float[][][] textMask = getTextMask(textIdsLengths);
return new TextProcessResult(textIds, textMask);
}
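// Sketch of the flow (exact IDs depend on unicode_indexer.json): each text is
// normalized and wrapped as "<lang>...</lang>", converted to code points,
// mapped through the indexer to textIds[batch][maxLen], and paired with a
// [batch][1][maxLen] mask holding 1.0 over real tokens and 0.0 over padding.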
private String preprocessText(String text, String lang) {
// TODO: Need advanced normalizer for better performance
text = Normalizer.normalize(text, Normalizer.Form.NFKD);
// Remove emojis (wide Unicode range)
// Java Pattern doesn't support \x{...} syntax for Unicode above \uFFFF
// Use character filtering instead
text = removeEmojis(text);
// Replace various dashes and symbols
Map<String, String> replacements = new HashMap<>();
replacements.put("", "-"); // en dash
replacements.put("", "-"); // non-breaking hyphen
replacements.put("", "-"); // em dash
replacements.put("_", " "); // underscore
replacements.put("\u201C", "\""); // left double quote
replacements.put("\u201D", "\""); // right double quote
replacements.put("\u2018", "'"); // left single quote
replacements.put("\u2019", "'"); // right single quote
replacements.put("´", "'"); // acute accent
replacements.put("`", "'"); // grave accent
replacements.put("[", " "); // left bracket
replacements.put("]", " "); // right bracket
replacements.put("|", " "); // vertical bar
replacements.put("/", " "); // slash
replacements.put("#", " "); // hash
replacements.put("", " "); // right arrow
replacements.put("", " "); // left arrow
for (Map.Entry<String, String> entry : replacements.entrySet()) {
text = text.replace(entry.getKey(), entry.getValue());
}
// Remove special symbols
text = text.replaceAll("[♥☆♡©\\\\]", "");
// Replace known expressions
Map<String, String> exprReplacements = new HashMap<>();
exprReplacements.put("@", " at ");
exprReplacements.put("e.g.,", "for example, ");
exprReplacements.put("i.e.,", "that is, ");
for (Map.Entry<String, String> entry : exprReplacements.entrySet()) {
text = text.replace(entry.getKey(), entry.getValue());
}
// Fix spacing around punctuation
text = text.replaceAll(" ,", ",");
text = text.replaceAll(" \\.", ".");
text = text.replaceAll(" !", "!");
text = text.replaceAll(" \\?", "?");
text = text.replaceAll(" ;", ";");
text = text.replaceAll(" :", ":");
text = text.replaceAll(" '", "'");
// Remove duplicate quotes
while (text.contains("\"\"")) {
text = text.replace("\"\"", "\"");
}
while (text.contains("''")) {
text = text.replace("''", "'");
}
while (text.contains("``")) {
text = text.replace("``", "`");
}
// Remove extra spaces
text = text.replaceAll("\\s+", " ").trim();
// If text doesn't end with punctuation, quotes, or closing brackets, add a period
if (!text.matches(".*[.!?;:,'\"\\u201C\\u201D\\u2018\\u2019)\\]}…。」』】〉》›»]$")) {
text += ".";
}
// Validate language
if (!Languages.isValid(lang)) {
throw new IllegalArgumentException("Invalid language: " + lang + ". Available: " + Languages.AVAILABLE);
}
// Wrap text with language tags
text = "<" + lang + ">" + text + "</" + lang + ">";
return text;
}
private int[] textToUnicodeValues(String text) {
// Use codePoints() stream to correctly handle surrogate pairs
return text.codePoints().toArray();
}
private float[][][] getTextMask(int[] lengths) {
int bsz = lengths.length;
int maxLen = 0;
for (int len : lengths) {
maxLen = Math.max(maxLen, len);
}
float[][][] mask = new float[bsz][1][maxLen];
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < maxLen; j++) {
mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f;
}
}
return mask;
}
static class TextProcessResult {
long[][] textIds;
float[][][] textMask;
TextProcessResult(long[][] textIds, float[][][] textMask) {
this.textIds = textIds;
this.textMask = textMask;
}
}
}
/**
* Text-to-Speech inference class
*/
class TextToSpeech {
private Config config;
private UnicodeProcessor textProcessor;
private OrtSession dpSession;
private OrtSession textEncSession;
private OrtSession vectorEstSession;
private OrtSession vocoderSession;
public int sampleRate;
private int baseChunkSize;
private int chunkCompress;
private int ldim;
public TextToSpeech(Config config, UnicodeProcessor textProcessor,
OrtSession dpSession, OrtSession textEncSession,
OrtSession vectorEstSession, OrtSession vocoderSession) {
this.config = config;
this.textProcessor = textProcessor;
this.dpSession = dpSession;
this.textEncSession = textEncSession;
this.vectorEstSession = vectorEstSession;
this.vocoderSession = vocoderSession;
this.sampleRate = config.ae.sampleRate;
this.baseChunkSize = config.ae.baseChunkSize;
this.chunkCompress = config.ttl.chunkCompressFactor;
this.ldim = config.ttl.latentDim;
}
private TTSResult _infer(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
throws OrtException {
int bsz = textList.size();
// Process text
UnicodeProcessor.TextProcessResult textResult = textProcessor.call(textList, langList);
long[][] textIds = textResult.textIds;
float[][][] textMask = textResult.textMask;
// Create tensors
OnnxTensor textIdsTensor = Helper.createLongTensor(textIds, env);
OnnxTensor textMaskTensor = Helper.createFloatTensor(textMask, env);
// Predict duration
Map<String, OnnxTensor> dpInputs = new HashMap<>();
dpInputs.put("text_ids", textIdsTensor);
dpInputs.put("style_dp", style.dpTensor);
dpInputs.put("text_mask", textMaskTensor);
OrtSession.Result dpResult = dpSession.run(dpInputs);
Object dpValue = dpResult.get(0).getValue();
float[] duration;
if (dpValue instanceof float[][]) {
duration = ((float[][]) dpValue)[0];
} else {
duration = (float[]) dpValue;
}
// Apply speed factor to duration
for (int i = 0; i < duration.length; i++) {
duration[i] /= speed;
}
// Encode text
Map<String, OnnxTensor> textEncInputs = new HashMap<>();
textEncInputs.put("text_ids", textIdsTensor);
textEncInputs.put("style_ttl", style.ttlTensor);
textEncInputs.put("text_mask", textMaskTensor);
OrtSession.Result textEncResult = textEncSession.run(textEncInputs);
OnnxTensor textEmbTensor = (OnnxTensor) textEncResult.get(0);
// Sample noisy latent
NoisyLatentResult noisyLatentResult = sampleNoisyLatent(duration);
float[][][] xt = noisyLatentResult.noisyLatent;
float[][][] latentMask = noisyLatentResult.latentMask;
// Prepare constant tensors
float[] totalStepArray = new float[bsz];
Arrays.fill(totalStepArray, (float) totalStep);
OnnxTensor totalStepTensor = OnnxTensor.createTensor(env, totalStepArray);
// Denoising loop
for (int step = 0; step < totalStep; step++) {
float[] currentStepArray = new float[bsz];
Arrays.fill(currentStepArray, (float) step);
OnnxTensor currentStepTensor = OnnxTensor.createTensor(env, currentStepArray);
OnnxTensor noisyLatentTensor = Helper.createFloatTensor(xt, env);
OnnxTensor latentMaskTensor = Helper.createFloatTensor(latentMask, env);
OnnxTensor textMaskTensor2 = Helper.createFloatTensor(textMask, env);
Map<String, OnnxTensor> vectorEstInputs = new HashMap<>();
vectorEstInputs.put("noisy_latent", noisyLatentTensor);
vectorEstInputs.put("text_emb", textEmbTensor);
vectorEstInputs.put("style_ttl", style.ttlTensor);
vectorEstInputs.put("latent_mask", latentMaskTensor);
vectorEstInputs.put("text_mask", textMaskTensor2);
vectorEstInputs.put("current_step", currentStepTensor);
vectorEstInputs.put("total_step", totalStepTensor);
OrtSession.Result vectorEstResult = vectorEstSession.run(vectorEstInputs);
float[][][] denoised = (float[][][]) vectorEstResult.get(0).getValue();
// Update latent
xt = denoised;
// Clean up
currentStepTensor.close();
noisyLatentTensor.close();
latentMaskTensor.close();
textMaskTensor2.close();
vectorEstResult.close();
}
// Generate waveform
OnnxTensor finalLatentTensor = Helper.createFloatTensor(xt, env);
Map<String, OnnxTensor> vocoderInputs = new HashMap<>();
vocoderInputs.put("latent", finalLatentTensor);
OrtSession.Result vocoderResult = vocoderSession.run(vocoderInputs);
float[][] wavBatch = (float[][]) vocoderResult.get(0).getValue();
// Concatenate the per-item waveforms into a single flat array
int totalSamples = 0;
for (float[] w : wavBatch) {
totalSamples += w.length;
}
float[] wav = new float[totalSamples];
int offset = 0;
for (float[] w : wavBatch) {
System.arraycopy(w, 0, wav, offset, w.length);
offset += w.length;
}
// Clean up
textIdsTensor.close();
textMaskTensor.close();
dpResult.close();
textEncResult.close();
totalStepTensor.close();
finalLatentTensor.close();
vocoderResult.close();
return new TTSResult(wav, duration);
}
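// Pipeline recap: duration predictor -> text encoder -> totalStep passes of the
// vector estimator over a masked Gaussian latent -> vocoder, whose per-item
// waveforms are concatenated into a single output array.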
private NoisyLatentResult sampleNoisyLatent(float[] duration) {
int bsz = duration.length;
float maxDur = 0;
for (float d : duration) {
maxDur = Math.max(maxDur, d);
}
long wavLenMax = (long) (maxDur * sampleRate);
long[] wavLengths = new long[bsz];
for (int i = 0; i < bsz; i++) {
wavLengths[i] = (long) (duration[i] * sampleRate);
}
int chunkSize = baseChunkSize * chunkCompress;
int latentLen = (int) ((wavLenMax + chunkSize - 1) / chunkSize);
int latentDim = ldim * chunkCompress;
Random rng = new Random();
float[][][] noisyLatent = new float[bsz][latentDim][latentLen];
for (int b = 0; b < bsz; b++) {
for (int d = 0; d < latentDim; d++) {
for (int t = 0; t < latentLen; t++) {
// Box-Muller transform
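// (maps two independent uniforms to a standard normal:
//  z = sqrt(-2 ln u1) * cos(2*pi*u2); u1 is clamped above 0 so log() stays finite)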
double u1 = Math.max(1e-10, rng.nextDouble());
double u2 = rng.nextDouble();
noisyLatent[b][d][t] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2));
}
}
}
float[][][] latentMask = Helper.getLatentMask(wavLengths, config);
// Apply mask
for (int b = 0; b < bsz; b++) {
for (int d = 0; d < latentDim; d++) {
for (int t = 0; t < latentLen; t++) {
noisyLatent[b][d][t] *= latentMask[b][0][t];
}
}
}
return new NoisyLatentResult(noisyLatent, latentMask);
}
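// Shape note: one latent frame covers baseChunkSize * chunkCompressFactor
// waveform samples, so latentLen = ceil(maxWavLen / chunkSize) and the channel
// dimension is latentDim * chunkCompressFactor.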
/**
* Synthesize speech from a single text with automatic chunking
*/
public TTSResult call(String text, String lang, Style style, int totalStep, float speed, float silenceDuration, OrtEnvironment env)
throws OrtException {
int maxLen = lang.equals("ko") ? 120 : 300;
List<String> chunks = Helper.chunkText(text, maxLen);
List<Float> wavCat = new ArrayList<>();
float durCat = 0.0f;
for (int i = 0; i < chunks.size(); i++) {
TTSResult result = _infer(Arrays.asList(chunks.get(i)), Arrays.asList(lang), style, totalStep, speed, env);
float dur = result.duration[0];
int wavLen = (int) (sampleRate * dur);
float[] wavChunk = new float[wavLen];
System.arraycopy(result.wav, 0, wavChunk, 0, Math.min(wavLen, result.wav.length));
if (i == 0) {
for (float val : wavChunk) {
wavCat.add(val);
}
durCat = dur;
} else {
int silenceLen = (int) (silenceDuration * sampleRate);
for (int j = 0; j < silenceLen; j++) {
wavCat.add(0.0f);
}
for (float val : wavChunk) {
wavCat.add(val);
}
durCat += silenceDuration + dur;
}
}
float[] wavArray = new float[wavCat.size()];
for (int i = 0; i < wavCat.size(); i++) {
wavArray[i] = wavCat.get(i);
}
return new TTSResult(wavArray, new float[]{durCat});
}
/**
* Batch synthesize speech from multiple texts
*/
public TTSResult batch(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
throws OrtException {
return _infer(textList, langList, style, totalStep, speed, env);
}
public void close() throws OrtException {
if (dpSession != null) dpSession.close();
if (textEncSession != null) textEncSession.close();
if (vectorEstSession != null) vectorEstSession.close();
if (vocoderSession != null) vocoderSession.close();
}
}
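/*
 * Minimal end-to-end sketch. The asset paths, step count and timing values
 * below are illustrative assumptions, not fixed by this file:
 *
 *   OrtEnvironment env = OrtEnvironment.getEnvironment();
 *   TextToSpeech tts = Helper.loadTextToSpeech("assets/onnx", false, env);
 *   Style style = Helper.loadVoiceStyle(
 *       Arrays.asList("assets/voice_style.json"), true, env);
 *   TTSResult out = tts.call("Hello there.", "en", style,
 *       16,     // totalStep: denoising iterations
 *       1.0f,   // speed factor
 *       0.3f,   // silenceDuration between chunks, in seconds
 *       env);
 *   Helper.writeWavFile("out.wav", out.wav, tts.sampleRate);
 *   style.close();
 *   tts.close();
 */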
/**
* Style holder class
*/
class Style {
OnnxTensor ttlTensor;
OnnxTensor dpTensor;
Style(OnnxTensor ttlTensor, OnnxTensor dpTensor) {
this.ttlTensor = ttlTensor;
this.dpTensor = dpTensor;
}
public void close() throws OrtException {
if (ttlTensor != null) ttlTensor.close();
if (dpTensor != null) dpTensor.close();
}
}
/**
* TTS result holder
*/
class TTSResult {
float[] wav;
float[] duration;
TTSResult(float[] wav, float[] duration) {
this.wav = wav;
this.duration = duration;
}
}
/**
* Noisy latent result holder
*/
class NoisyLatentResult {
float[][][] noisyLatent;
float[][][] latentMask;
NoisyLatentResult(float[][][] noisyLatent, float[][][] latentMask) {
this.noisyLatent = noisyLatent;
this.latentMask = latentMask;
}
}
/**
* Helper utility class
*/
public class Helper {
private static final int MAX_CHUNK_LENGTH = 300;
private static final String[] ABBREVIATIONS = {
"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
"St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
"Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D."
};
/**
* Chunk text into smaller segments based on paragraphs and sentences
*/
public static List<String> chunkText(String text, int maxLen) {
if (maxLen == 0) {
maxLen = MAX_CHUNK_LENGTH;
}
text = text.trim();
if (text.isEmpty()) {
return Arrays.asList("");
}
// Split by paragraphs
String[] paragraphs = text.split("\\n\\s*\\n");
List<String> chunks = new ArrayList<>();
for (String para : paragraphs) {
para = para.trim();
if (para.isEmpty()) {
continue;
}
if (para.length() <= maxLen) {
chunks.add(para);
continue;
}
// Split by sentences
List<String> sentences = splitSentences(para);
StringBuilder current = new StringBuilder();
int currentLen = 0;
for (String sentence : sentences) {
sentence = sentence.trim();
if (sentence.isEmpty()) {
continue;
}
int sentenceLen = sentence.length();
if (sentenceLen > maxLen) {
// If sentence is longer than maxLen, split by comma or space
if (current.length() > 0) {
chunks.add(current.toString().trim());
current.setLength(0);
currentLen = 0;
}
// Try splitting by comma
String[] parts = sentence.split(",");
for (String part : parts) {
part = part.trim();
if (part.isEmpty()) {
continue;
}
int partLen = part.length();
if (partLen > maxLen) {
// Split by space as last resort
String[] words = part.split("\\s+");
StringBuilder wordChunk = new StringBuilder();
int wordChunkLen = 0;
for (String word : words) {
int wordLen = word.length();
if (wordChunkLen + wordLen + 1 > maxLen && wordChunk.length() > 0) {
chunks.add(wordChunk.toString().trim());
wordChunk.setLength(0);
wordChunkLen = 0;
}
if (wordChunk.length() > 0) {
wordChunk.append(" ");
wordChunkLen++;
}
wordChunk.append(word);
wordChunkLen += wordLen;
}
if (wordChunk.length() > 0) {
chunks.add(wordChunk.toString().trim());
}
} else {
if (currentLen + partLen + 1 > maxLen && current.length() > 0) {
chunks.add(current.toString().trim());
current.setLength(0);
currentLen = 0;
}
if (current.length() > 0) {
current.append(", ");
currentLen += 2;
}
current.append(part);
currentLen += partLen;
}
}
continue;
}
if (currentLen + sentenceLen + 1 > maxLen && current.length() > 0) {
chunks.add(current.toString().trim());
current.setLength(0);
currentLen = 0;
}
if (current.length() > 0) {
current.append(" ");
currentLen++;
}
current.append(sentence);
currentLen += sentenceLen;
}
if (current.length() > 0) {
chunks.add(current.toString().trim());
}
}
if (chunks.isEmpty()) {
return Arrays.asList("");
}
return chunks;
}
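// Example (a sketch): chunkText("First sentence. Second sentence.", 20) yields
// ["First sentence.", "Second sentence."]; paragraphs are split first, then
// sentences, then commas, then whitespace as a last resort.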
/**
* Split text into sentences, avoiding common abbreviations
*/
private static List<String> splitSentences(String text) {
// Build pattern that avoids abbreviations
StringBuilder abbrevPattern = new StringBuilder();
for (int i = 0; i < ABBREVIATIONS.length; i++) {
if (i > 0) abbrevPattern.append("|");
abbrevPattern.append(Pattern.quote(ABBREVIATIONS[i]));
}
// Match sentence endings, but not abbreviations
String patternStr = "(?<!(?:" + abbrevPattern.toString() + "))(?<=[.!?])\\s+";
Pattern pattern = Pattern.compile(patternStr);
return Arrays.asList(pattern.split(text));
}
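// The split points are runs of whitespace preceded by ., ! or ? that do not
// terminate a known abbreviation, e.g. "Dr. Smith left. He sat." splits into
// ["Dr. Smith left.", "He sat."] rather than after "Dr.".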
/**
* Load voice style from JSON files
*/
public static Style loadVoiceStyle(List<String> voiceStylePaths, boolean verbose, OrtEnvironment env)
throws IOException, OrtException {
int bsz = voiceStylePaths.size();
// Read first file to get dimensions
ObjectMapper mapper = new ObjectMapper();
JsonNode firstRoot = mapper.readTree(new File(voiceStylePaths.get(0)));
long[] ttlDims = new long[3];
for (int i = 0; i < 3; i++) {
ttlDims[i] = firstRoot.get("style_ttl").get("dims").get(i).asLong();
}
long[] dpDims = new long[3];
for (int i = 0; i < 3; i++) {
dpDims[i] = firstRoot.get("style_dp").get("dims").get(i).asLong();
}
long ttlDim1 = ttlDims[1];
long ttlDim2 = ttlDims[2];
long dpDim1 = dpDims[1];
long dpDim2 = dpDims[2];
// Pre-allocate arrays with full batch size
int ttlSize = (int) (bsz * ttlDim1 * ttlDim2);
int dpSize = (int) (bsz * dpDim1 * dpDim2);
float[] ttlFlat = new float[ttlSize];
float[] dpFlat = new float[dpSize];
// Fill in the data
for (int i = 0; i < bsz; i++) {
JsonNode root = mapper.readTree(new File(voiceStylePaths.get(i)));
// Flatten TTL data
int ttlOffset = (int) (i * ttlDim1 * ttlDim2);
int idx = 0;
JsonNode ttlData = root.get("style_ttl").get("data");
for (JsonNode batch : ttlData) {
for (JsonNode row : batch) {
for (JsonNode val : row) {
ttlFlat[ttlOffset + idx++] = (float) val.asDouble();
}
}
}
// Flatten DP data
int dpOffset = (int) (i * dpDim1 * dpDim2);
idx = 0;
JsonNode dpData = root.get("style_dp").get("data");
for (JsonNode batch : dpData) {
for (JsonNode row : batch) {
for (JsonNode val : row) {
dpFlat[dpOffset + idx++] = (float) val.asDouble();
}
}
}
}
long[] ttlShape = {bsz, ttlDim1, ttlDim2};
long[] dpShape = {bsz, dpDim1, dpDim2};
OnnxTensor ttlTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(ttlFlat), ttlShape);
OnnxTensor dpTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(dpFlat), dpShape);
if (verbose) {
System.out.println("Loaded " + bsz + " voice styles\n");
}
return new Style(ttlTensor, dpTensor);
}
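// Expected JSON layout, inferred from the reads above; every file in the batch
// is assumed to share the first file's dims:
//   { "style_ttl": { "dims": [1, d1, d2], "data": [[[ ... ]]] },
//     "style_dp":  { "dims": [1, e1, e2], "data": [[[ ... ]]] } }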
/**
* Load TTS components
*/
public static TextToSpeech loadTextToSpeech(String onnxDir, boolean useGpu, OrtEnvironment env)
throws IOException, OrtException {
if (useGpu) {
throw new RuntimeException("GPU mode is not supported yet");
}
System.out.println("Using CPU for inference\n");
// Load config
Config config = loadCfgs(onnxDir);
// Create session options
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Load models
OrtSession dpSession = env.createSession(onnxDir + "/duration_predictor.onnx", opts);
OrtSession textEncSession = env.createSession(onnxDir + "/text_encoder.onnx", opts);
OrtSession vectorEstSession = env.createSession(onnxDir + "/vector_estimator.onnx", opts);
OrtSession vocoderSession = env.createSession(onnxDir + "/vocoder.onnx", opts);
// Load text processor
UnicodeProcessor textProcessor = new UnicodeProcessor(onnxDir + "/unicode_indexer.json");
return new TextToSpeech(config, textProcessor, dpSession, textEncSession, vectorEstSession, vocoderSession);
}
/**
* Load configuration from JSON
*/
public static Config loadCfgs(String onnxDir) throws IOException {
ObjectMapper mapper = new ObjectMapper();
JsonNode root = mapper.readTree(new File(onnxDir + "/tts.json"));
Config config = new Config();
config.ae = new Config.AEConfig();
config.ae.sampleRate = root.get("ae").get("sample_rate").asInt();
config.ae.baseChunkSize = root.get("ae").get("base_chunk_size").asInt();
config.ttl = new Config.TTLConfig();
config.ttl.chunkCompressFactor = root.get("ttl").get("chunk_compress_factor").asInt();
config.ttl.latentDim = root.get("ttl").get("latent_dim").asInt();
return config;
}
/**
* Get latent mask from wav lengths
*/
public static float[][][] getLatentMask(long[] wavLengths, Config config) {
long baseChunkSize = config.ae.baseChunkSize;
long chunkCompressFactor = config.ttl.chunkCompressFactor;
long latentSize = baseChunkSize * chunkCompressFactor;
long[] latentLengths = new long[wavLengths.length];
long maxLen = 0;
for (int i = 0; i < wavLengths.length; i++) {
latentLengths[i] = (wavLengths[i] + latentSize - 1) / latentSize;
maxLen = Math.max(maxLen, latentLengths[i]);
}
float[][][] mask = new float[wavLengths.length][1][(int) maxLen];
for (int i = 0; i < wavLengths.length; i++) {
for (int j = 0; j < maxLen; j++) {
mask[i][0][j] = j < latentLengths[i] ? 1.0f : 0.0f;
}
}
return mask;
}
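// Worked example (illustrative numbers): baseChunkSize=512 and
// chunkCompressFactor=4 give latentSize=2048, so a 44100-sample wav maps to
// ceil(44100/2048) = 22 latent frames; its mask row is 22 ones padded with
// zeros up to the batch maximum.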
/**
* Write WAV file
*/
public static void writeWavFile(String filename, float[] audioData, int sampleRate) throws IOException {
// Convert float to byte array
byte[] bytes = new byte[audioData.length * 2];
ByteBuffer buffer = ByteBuffer.wrap(bytes);
buffer.order(ByteOrder.LITTLE_ENDIAN);
for (float sample : audioData) {
short val = (short) Math.max(-32768, Math.min(32767, sample * 32767));
buffer.putShort(val);
}
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false);
AudioInputStream ais = new AudioInputStream(bais, format, audioData.length);
AudioSystem.write(ais, AudioFileFormat.Type.WAVE, new File(filename));
}
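// The output is mono 16-bit little-endian PCM: each float is scaled by 32767
// and clamped to the short range before being framed by AudioSystem.write.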
/**
* Sanitize filename (supports Unicode characters)
*/
public static String sanitizeFilename(String text, int maxLen) {
// Take the first maxLen code points (not UTF-16 chars, so surrogate pairs stay intact)
int[] codePoints = text.codePoints().limit(maxLen).toArray();
StringBuilder result = new StringBuilder();
for (int codePoint : codePoints) {
if (Character.isLetterOrDigit(codePoint)) {
result.appendCodePoint(codePoint);
} else {
result.append('_');
}
}
return result.toString();
}
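// Example: sanitizeFilename("Hello, world!", 8) returns "Hello__w"; limiting by
// code points rather than UTF-16 chars keeps surrogate pairs intact.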
/**
* Timer utility
*/
public static <T> T timer(String name, java.util.function.Supplier<T> fn) {
long start = System.currentTimeMillis();
System.out.println(name + "...");
T result = fn.get();
long elapsed = System.currentTimeMillis() - start;
System.out.printf(" -> %s completed in %.2f sec\n", name, elapsed / 1000.0);
return result;
}
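// Usage sketch (synthesize() is a hypothetical supplier of the timed result):
//   float[] wav = Helper.timer("synthesis", () -> synthesize());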
/**
* Create float tensor from 3D array
*/
public static OnnxTensor createFloatTensor(float[][][] array, OrtEnvironment env) throws OrtException {
int dim0 = array.length;
int dim1 = array[0].length;
int dim2 = array[0][0].length;
float[] flat = new float[dim0 * dim1 * dim2];
int idx = 0;
for (int i = 0; i < dim0; i++) {
for (int j = 0; j < dim1; j++) {
for (int k = 0; k < dim2; k++) {
flat[idx++] = array[i][j][k];
}
}
}
long[] shape = {dim0, dim1, dim2};
return OnnxTensor.createTensor(env, FloatBuffer.wrap(flat), shape);
}
/**
* Create long tensor from 2D array
*/
public static OnnxTensor createLongTensor(long[][] array, OrtEnvironment env) throws OrtException {
int dim0 = array.length;
int dim1 = array[0].length;
long[] flat = new long[dim0 * dim1];
int idx = 0;
for (int i = 0; i < dim0; i++) {
for (int j = 0; j < dim1; j++) {
flat[idx++] = array[i][j];
}
}
long[] shape = {dim0, dim1};
return OnnxTensor.createTensor(env, LongBuffer.wrap(flat), shape);
}
/**
* Load JSON long array
*/
public static long[] loadJsonLongArray(String filePath) throws IOException {
ObjectMapper mapper = new ObjectMapper();
JsonNode root = mapper.readTree(new File(filePath));
long[] result = new long[root.size()];
for (int i = 0; i < root.size(); i++) {
result[i] = root.get(i).asLong();
}
return result;
}
}