initial commit

This commit is contained in:
2026-01-25 18:58:40 +09:00
commit 77af47274c
101 changed files with 16247 additions and 0 deletions

35
java/.gitignore vendored Normal file
View File

@@ -0,0 +1,35 @@
# Maven
target/
pom.xml.tag
pom.xml.releaseBackup
pom.xml.versionsBackup
pom.xml.next
release.properties
dependency-reduced-pom.xml
buildNumber.properties
.mvn/timing.properties
.mvn/wrapper/maven-wrapper.jar
# Compiled class files
*.class
# IntelliJ IDEA
.idea/
*.iml
*.iws
*.ipr
# Eclipse
.classpath
.project
.settings/
# VS Code
.vscode/
# Results
results/*.wav
# Mac
.DS_Store

183
java/ExampleONNX.java Normal file
View File

@@ -0,0 +1,183 @@
import ai.onnxruntime.*;
import java.io.File;
import java.util.*;
/**
 * TTS inference example with ONNX Runtime (Java).
 *
 * Parses command-line flags, loads the ONNX TTS pipeline and the requested
 * voice styles, synthesizes speech {@code nTest} times, and writes each
 * result to a 16-bit WAV file under the save directory.
 */
public class ExampleONNX {
    /**
     * Parsed command line arguments, initialized to their default values.
     */
    static class Args {
        boolean useGpu = false;          // --use-gpu (currently rejected by Helper.loadTextToSpeech)
        String onnxDir = "assets/onnx";  // --onnx-dir: directory holding the .onnx models + configs
        int totalStep = 5;               // --total-step: number of denoising iterations
        float speed = 1.05f;             // --speed: speech-rate multiplier (durations are divided by it)
        int nTest = 4;                   // --n-test: how many times to repeat synthesis
        List<String> voiceStyle = Arrays.asList("assets/voice_styles/M1.json");
        List<String> text = Arrays.asList(
            "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen."
        );
        List<String> lang = Arrays.asList("en");
        String saveDir = "results";      // --save-dir: output directory for WAV files
        boolean batch = false;           // --batch: synthesize all texts in one batched call
    }

    /**
     * Parse command line arguments.
     *
     * Unknown flags are silently ignored; flags that expect a value are
     * skipped when no value follows. List-valued flags are comma-separated,
     * except --text which is '|'-separated (texts may contain commas).
     *
     * @param args raw argv
     * @return populated Args (defaults where a flag was absent)
     */
    private static Args parseArgs(String[] args) {
        Args result = new Args();
        for (int i = 0; i < args.length; i++) {
            switch (args[i]) {
                case "--use-gpu":
                    result.useGpu = true;
                    break;
                case "--onnx-dir":
                    if (i + 1 < args.length) result.onnxDir = args[++i];
                    break;
                case "--total-step":
                    if (i + 1 < args.length) result.totalStep = Integer.parseInt(args[++i]);
                    break;
                case "--speed":
                    if (i + 1 < args.length) result.speed = Float.parseFloat(args[++i]);
                    break;
                case "--n-test":
                    if (i + 1 < args.length) result.nTest = Integer.parseInt(args[++i]);
                    break;
                case "--voice-style":
                    if (i + 1 < args.length) {
                        result.voiceStyle = Arrays.asList(args[++i].split(","));
                    }
                    break;
                case "--text":
                    // '|' separated so texts may themselves contain commas
                    if (i + 1 < args.length) {
                        result.text = Arrays.asList(args[++i].split("\\|"));
                    }
                    break;
                case "--lang":
                    if (i + 1 < args.length) {
                        result.lang = Arrays.asList(args[++i].split(","));
                    }
                    break;
                case "--save-dir":
                    if (i + 1 < args.length) result.saveDir = args[++i];
                    break;
                case "--batch":
                    result.batch = true;
                    break;
            }
        }
        return result;
    }

    /**
     * Main inference entry point: parse args, load models and styles, run
     * synthesis nTest times, write WAVs, then release native resources.
     * Exits with status 1 on any failure.
     */
    public static void main(String[] args) {
        try {
            System.out.println("=== TTS Inference with ONNX Runtime (Java) ===\n");
            // --- 1. Parse arguments --- //
            Args parsedArgs = parseArgs(args);
            int totalStep = parsedArgs.totalStep;
            float speed = parsedArgs.speed;
            int nTest = parsedArgs.nTest;
            String saveDir = parsedArgs.saveDir;
            List<String> voiceStylePaths = parsedArgs.voiceStyle;
            List<String> textList = parsedArgs.text;
            List<String> langList = parsedArgs.lang;
            boolean batch = parsedArgs.batch;
            // Batch mode requires a 1:1 pairing of style/text/lang entries.
            if (batch) {
                if (voiceStylePaths.size() != textList.size()) {
                    throw new RuntimeException("Number of voice styles (" + voiceStylePaths.size() +
                            ") must match number of texts (" + textList.size() + ")");
                }
                if (langList.size() != textList.size()) {
                    throw new RuntimeException("Number of languages (" + langList.size() +
                            ") must match number of texts (" + textList.size() + ")");
                }
            }
            int bsz = voiceStylePaths.size();
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            // --- 2. Load TTS components --- //
            TextToSpeech textToSpeech = Helper.loadTextToSpeech(parsedArgs.onnxDir, parsedArgs.useGpu, env);
            // --- 3. Load voice styles --- //
            Style style = Helper.loadVoiceStyle(voiceStylePaths, true, env);
            // --- 4. Synthesize speech --- //
            File saveDirFile = new File(saveDir);
            if (!saveDirFile.exists()) {
                saveDirFile.mkdirs();
            }
            for (int n = 0; n < nTest; n++) {
                System.out.println("\n[" + (n + 1) + "/" + nTest + "] Starting synthesis...");
                TTSResult ttsResult;
                if (batch) {
                    ttsResult = Helper.timer("Generating speech from text", () -> {
                        try {
                            return textToSpeech.batch(textList, langList, style, totalStep, speed, env);
                        } catch (Exception e) {
                            // Supplier cannot throw checked exceptions; rewrap for the timer lambda
                            throw new RuntimeException(e);
                        }
                    });
                } else {
                    // Non-batch mode synthesizes only the FIRST text/lang pair,
                    // with automatic chunking and 0.3s silence between chunks.
                    ttsResult = Helper.timer("Generating speech from text", () -> {
                        try {
                            return textToSpeech.call(textList.get(0), langList.get(0), style, totalStep, speed, 0.3f, env);
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    });
                }
                float[] wav = ttsResult.wav;
                float[] duration = ttsResult.duration;
                // Save outputs.
                // NOTE(review): in non-batch mode with several --voice-style entries,
                // this loop still iterates bsz times over a single synthesized wav and
                // names files after textList.get(i) — confirm intended behavior.
                for (int i = 0; i < bsz; i++) {
                    String fname = Helper.sanitizeFilename(textList.get(i), 20) + "_" + (n + 1) + ".wav";
                    float[] wavOut;
                    if (batch) {
                        // Batched wav is bsz equal-length segments concatenated;
                        // trim each segment to its predicted duration.
                        int wavLen = wav.length / bsz;
                        int actualLen = (int) (textToSpeech.sampleRate * duration[i]);
                        wavOut = new float[actualLen];
                        System.arraycopy(wav, i * wavLen, wavOut, 0, Math.min(actualLen, wavLen));
                    } else {
                        // For non-batch mode, wav is a single concatenated audio
                        int actualLen = (int) (textToSpeech.sampleRate * duration[0]);
                        wavOut = new float[Math.min(actualLen, wav.length)];
                        System.arraycopy(wav, 0, wavOut, 0, wavOut.length);
                    }
                    String outputPath = saveDir + "/" + fname;
                    Helper.writeWavFile(outputPath, wavOut, textToSpeech.sampleRate);
                    System.out.println("Saved: " + outputPath);
                }
            }
            // Clean up native ONNX resources
            style.close();
            textToSpeech.close();
            System.out.println("\n=== Synthesis completed successfully! ===");
        } catch (Exception e) {
            System.err.println("Error during inference: " + e.getMessage());
            e.printStackTrace();
            System.exit(1);
        }
    }
}

955
java/Helper.java Normal file
View File

@@ -0,0 +1,955 @@
import ai.onnxruntime.*;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import javax.sound.sampled.AudioFileFormat;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.text.Normalizer;
import java.util.*;
import java.util.regex.Pattern;
import java.util.regex.Matcher;
/**
 * Supported language codes for the multilingual TTS pipeline.
 */
class Languages {
    /** The fixed set of language codes the models were trained on. */
    public static final List<String> AVAILABLE = Arrays.asList("en", "ko", "es", "pt", "fr");

    /**
     * Returns whether {@code lang} is one of the supported language codes.
     * A null argument yields {@code false}.
     */
    public static boolean isValid(String lang) {
        return AVAILABLE.stream().anyMatch(code -> code.equals(lang));
    }
}
/**
 * Model configuration, populated from {@code <onnxDir>/tts.json}
 * by {@code Helper.loadCfgs}.
 */
class Config {
    /** Audio autoencoder settings. */
    static class AEConfig {
        int sampleRate;    // output waveform sample rate in Hz
        int baseChunkSize; // waveform samples covered by one base latent chunk
    }
    /** Text-to-latent model settings. */
    static class TTLConfig {
        int chunkCompressFactor; // how many base chunks one compressed latent frame spans
        int latentDim;           // latent channels per base chunk (before compression)
    }
    AEConfig ae;
    TTLConfig ttl;
}
/**
 * Voice style payload as stored in a voice-style JSON file
 * (mirrors the {@code style_ttl} / {@code style_dp} entries).
 */
class VoiceStyleData {
    /** One serialized style tensor. */
    static class StyleData {
        float[][][] data; // tensor values, shape described by dims
        long[] dims;      // tensor shape (3 entries per Helper.loadVoiceStyle)
        String type;      // element type tag from the JSON
    }
    StyleData styleTtl; // style embedding consumed by text encoder / vector estimator
    StyleData styleDp;  // style embedding consumed by the duration predictor
}
/**
 * Unicode text processor: normalizes raw input text, maps each code point
 * through a precomputed indexer table, and builds padded token-id / mask
 * batches for the ONNX text models.
 *
 * Fix: the dash/arrow replacement keys had degenerated into empty strings.
 * With an empty-string key, {@code String.replace("", x)} inserts the
 * replacement between every character (corrupting all text), and the
 * duplicate {@code ""} keys silently collapsed in the HashMap. The intended
 * characters (per the original inline comments) are restored as escapes.
 */
class UnicodeProcessor {
    // indexer[codePoint] -> model token id (loaded from unicode_indexer.json)
    private long[] indexer;

    public UnicodeProcessor(String unicodeIndexerJsonPath) throws IOException {
        this.indexer = Helper.loadJsonLongArray(unicodeIndexerJsonPath);
    }

    /**
     * Removes code points in common emoji ranges, keeping everything else.
     * Handles surrogate pairs so supplementary-plane characters survive intact.
     */
    private static String removeEmojis(String text) {
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < text.length(); i++) {
            int codePoint;
            if (Character.isHighSurrogate(text.charAt(i)) && i + 1 < text.length() && Character.isLowSurrogate(text.charAt(i + 1))) {
                codePoint = Character.codePointAt(text, i);
                i++; // Skip the low surrogate
            } else {
                codePoint = text.charAt(i);
            }
            // Check if code point is in emoji ranges
            boolean isEmoji = (codePoint >= 0x1F600 && codePoint <= 0x1F64F) ||
                    (codePoint >= 0x1F300 && codePoint <= 0x1F5FF) ||
                    (codePoint >= 0x1F680 && codePoint <= 0x1F6FF) ||
                    (codePoint >= 0x1F700 && codePoint <= 0x1F77F) ||
                    (codePoint >= 0x1F780 && codePoint <= 0x1F7FF) ||
                    (codePoint >= 0x1F800 && codePoint <= 0x1F8FF) ||
                    (codePoint >= 0x1F900 && codePoint <= 0x1F9FF) ||
                    (codePoint >= 0x1FA00 && codePoint <= 0x1FA6F) ||
                    (codePoint >= 0x1FA70 && codePoint <= 0x1FAFF) ||
                    (codePoint >= 0x2600 && codePoint <= 0x26FF) ||
                    (codePoint >= 0x2700 && codePoint <= 0x27BF) ||
                    (codePoint >= 0x1F1E6 && codePoint <= 0x1F1FF);
            if (!isEmoji) {
                if (codePoint > 0xFFFF) {
                    result.append(Character.toChars(codePoint));
                } else {
                    result.append((char) codePoint);
                }
            }
        }
        return result.toString();
    }

    /**
     * Preprocesses a batch of texts and converts them to padded id matrices.
     *
     * @param textList one raw text per batch item
     * @param langList matching language code per batch item
     * @return padded token ids [bsz][maxLen] and float mask [bsz][1][maxLen]
     */
    public TextProcessResult call(List<String> textList, List<String> langList) {
        List<String> processedTexts = new ArrayList<>();
        for (int i = 0; i < textList.size(); i++) {
            processedTexts.add(preprocessText(textList.get(i), langList.get(i)));
        }
        // Convert texts to unicode values first to get correct character counts
        List<int[]> allUnicodeVals = new ArrayList<>();
        for (String text : processedTexts) {
            allUnicodeVals.add(textToUnicodeValues(text));
        }
        int[] textIdsLengths = new int[processedTexts.size()];
        int maxLen = 0;
        for (int i = 0; i < allUnicodeVals.size(); i++) {
            textIdsLengths[i] = allUnicodeVals.get(i).length; // Use code point count, not char count
            maxLen = Math.max(maxLen, textIdsLengths[i]);
        }
        long[][] textIds = new long[processedTexts.size()][maxLen];
        for (int i = 0; i < allUnicodeVals.size(); i++) {
            int[] unicodeVals = allUnicodeVals.get(i);
            for (int j = 0; j < unicodeVals.length; j++) {
                // NOTE(review): assumes every surviving code point fits in the
                // indexer table; out-of-range code points would throw
                // ArrayIndexOutOfBoundsException — confirm table coverage.
                textIds[i][j] = indexer[unicodeVals[j]];
            }
        }
        float[][][] textMask = getTextMask(textIdsLengths);
        return new TextProcessResult(textIds, textMask);
    }

    /**
     * Normalizes one text: NFKD, emoji removal, symbol/punctuation cleanup,
     * and wrapping in {@code <lang>...</lang>} tags.
     *
     * @throws IllegalArgumentException when {@code lang} is not supported
     */
    private String preprocessText(String text, String lang) {
        // TODO: Need advanced normalizer for better performance
        text = Normalizer.normalize(text, Normalizer.Form.NFKD);
        // Remove emojis (wide Unicode range)
        // Java Pattern doesn't support \x{...} syntax for Unicode above \uFFFF
        // Use character filtering instead
        text = removeEmojis(text);
        // Replace various dashes and symbols
        Map<String, String> replacements = new HashMap<>();
        replacements.put("\u2013", "-");  // en dash
        replacements.put("\u2011", "-");  // non-breaking hyphen (NFKD may already map it to U+2010)
        replacements.put("\u2014", "-");  // em dash
        replacements.put("_", " ");       // underscore
        replacements.put("\u201C", "\""); // left double quote
        replacements.put("\u201D", "\""); // right double quote
        replacements.put("\u2018", "'");  // left single quote
        replacements.put("\u2019", "'");  // right single quote
        replacements.put("´", "'");       // acute accent
        replacements.put("`", "'");       // grave accent
        replacements.put("[", " ");       // left bracket
        replacements.put("]", " ");       // right bracket
        replacements.put("|", " ");       // vertical bar
        replacements.put("/", " ");       // slash
        replacements.put("#", " ");       // hash
        replacements.put("\u2192", " ");  // right arrow
        replacements.put("\u2190", " ");  // left arrow
        for (Map.Entry<String, String> entry : replacements.entrySet()) {
            text = text.replace(entry.getKey(), entry.getValue());
        }
        // Remove special symbols
        text = text.replaceAll("[♥☆♡©\\\\]", "");
        // Replace known expressions
        Map<String, String> exprReplacements = new HashMap<>();
        exprReplacements.put("@", " at ");
        exprReplacements.put("e.g.,", "for example, ");
        exprReplacements.put("i.e.,", "that is, ");
        for (Map.Entry<String, String> entry : exprReplacements.entrySet()) {
            text = text.replace(entry.getKey(), entry.getValue());
        }
        // Fix spacing around punctuation
        text = text.replaceAll(" ,", ",");
        text = text.replaceAll(" \\.", ".");
        text = text.replaceAll(" !", "!");
        text = text.replaceAll(" \\?", "?");
        text = text.replaceAll(" ;", ";");
        text = text.replaceAll(" :", ":");
        text = text.replaceAll(" '", "'");
        // Remove duplicate quotes
        while (text.contains("\"\"")) {
            text = text.replace("\"\"", "\"");
        }
        while (text.contains("''")) {
            text = text.replace("''", "'");
        }
        while (text.contains("``")) {
            text = text.replace("``", "`");
        }
        // Remove extra spaces
        text = text.replaceAll("\\s+", " ").trim();
        // If text doesn't end with punctuation, quotes, or closing brackets, add a period
        if (!text.matches(".*[.!?;:,'\"\\u201C\\u201D\\u2018\\u2019)\\]}…。」』】〉》›»]$")) {
            text += ".";
        }
        // Validate language
        if (!Languages.isValid(lang)) {
            throw new IllegalArgumentException("Invalid language: " + lang + ". Available: " + Languages.AVAILABLE);
        }
        // Wrap text with language tags
        text = "<" + lang + ">" + text + "</" + lang + ">";
        return text;
    }

    /** Converts a string to its code points (surrogate pairs become one value). */
    private int[] textToUnicodeValues(String text) {
        // Use codePoints() stream to correctly handle surrogate pairs
        return text.codePoints().toArray();
    }

    /** Builds a [bsz][1][maxLen] 0/1 float mask from per-item lengths. */
    private float[][][] getTextMask(int[] lengths) {
        int bsz = lengths.length;
        int maxLen = 0;
        for (int len : lengths) {
            maxLen = Math.max(maxLen, len);
        }
        float[][][] mask = new float[bsz][1][maxLen];
        for (int i = 0; i < bsz; i++) {
            for (int j = 0; j < maxLen; j++) {
                mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f;
            }
        }
        return mask;
    }

    /** Result pair: padded token ids plus the matching attention mask. */
    static class TextProcessResult {
        long[][] textIds;     // [bsz][maxLen] token ids, zero-padded
        float[][][] textMask; // [bsz][1][maxLen] 1.0 for real tokens, 0.0 for padding
        TextProcessResult(long[][] textIds, float[][][] textMask) {
            this.textIds = textIds;
            this.textMask = textMask;
        }
    }
}
/**
 * Text-to-Speech inference pipeline built from four ONNX sessions:
 * duration predictor, text encoder, vector estimator (iterative denoiser)
 * and vocoder. Owns the sessions; call {@link #close()} when done.
 */
class TextToSpeech {
    private Config config;
    private UnicodeProcessor textProcessor;
    private OrtSession dpSession;        // duration predictor
    private OrtSession textEncSession;   // text encoder
    private OrtSession vectorEstSession; // denoising vector estimator
    private OrtSession vocoderSession;   // latent -> waveform decoder
    public int sampleRate;               // output sample rate in Hz (from config.ae)
    private int baseChunkSize;           // waveform samples per base latent chunk
    private int chunkCompress;           // latent chunk compression factor
    private int ldim;                    // latent channels per base chunk

    /**
     * @param config         model configuration (sample rate, chunk sizes, latent dims)
     * @param textProcessor  text normalizer / tokenizer
     * @param dpSession      duration predictor session
     * @param textEncSession text encoder session
     * @param vectorEstSession vector estimator (denoiser) session
     * @param vocoderSession vocoder session
     */
    public TextToSpeech(Config config, UnicodeProcessor textProcessor,
                        OrtSession dpSession, OrtSession textEncSession,
                        OrtSession vectorEstSession, OrtSession vocoderSession) {
        this.config = config;
        this.textProcessor = textProcessor;
        this.dpSession = dpSession;
        this.textEncSession = textEncSession;
        this.vectorEstSession = vectorEstSession;
        this.vocoderSession = vocoderSession;
        this.sampleRate = config.ae.sampleRate;
        this.baseChunkSize = config.ae.baseChunkSize;
        this.chunkCompress = config.ttl.chunkCompressFactor;
        this.ldim = config.ttl.latentDim;
    }

    /**
     * Core inference: tokenize -> predict durations -> encode text ->
     * sample noisy latent -> run totalStep denoising iterations -> vocode.
     *
     * @return all batch waveforms concatenated into one array, plus the
     *         per-item durations in seconds (already divided by speed)
     */
    private TTSResult _infer(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
            throws OrtException {
        int bsz = textList.size();
        // Process text
        UnicodeProcessor.TextProcessResult textResult = textProcessor.call(textList, langList);
        long[][] textIds = textResult.textIds;
        float[][][] textMask = textResult.textMask;
        // Create tensors
        OnnxTensor textIdsTensor = Helper.createLongTensor(textIds, env);
        OnnxTensor textMaskTensor = Helper.createFloatTensor(textMask, env);
        // Predict duration
        Map<String, OnnxTensor> dpInputs = new HashMap<>();
        dpInputs.put("text_ids", textIdsTensor);
        dpInputs.put("style_dp", style.dpTensor);
        dpInputs.put("text_mask", textMaskTensor);
        OrtSession.Result dpResult = dpSession.run(dpInputs);
        Object dpValue = dpResult.get(0).getValue();
        float[] duration;
        // Output may come back as [1][bsz] or flat [bsz] depending on the model export
        if (dpValue instanceof float[][]) {
            duration = ((float[][]) dpValue)[0];
        } else {
            duration = (float[]) dpValue;
        }
        // Apply speed factor to duration (higher speed -> shorter audio)
        for (int i = 0; i < duration.length; i++) {
            duration[i] /= speed;
        }
        // Encode text
        Map<String, OnnxTensor> textEncInputs = new HashMap<>();
        textEncInputs.put("text_ids", textIdsTensor);
        textEncInputs.put("style_ttl", style.ttlTensor);
        textEncInputs.put("text_mask", textMaskTensor);
        OrtSession.Result textEncResult = textEncSession.run(textEncInputs);
        // textEmbTensor is owned by textEncResult; closing the result below frees it
        OnnxTensor textEmbTensor = (OnnxTensor) textEncResult.get(0);
        // Sample noisy latent
        NoisyLatentResult noisyLatentResult = sampleNoisyLatent(duration);
        float[][][] xt = noisyLatentResult.noisyLatent;
        float[][][] latentMask = noisyLatentResult.latentMask;
        // Prepare constant tensors
        float[] totalStepArray = new float[bsz];
        Arrays.fill(totalStepArray, (float) totalStep);
        OnnxTensor totalStepTensor = OnnxTensor.createTensor(env, totalStepArray);
        // Denoising loop: per-step tensors are created and closed inside the loop
        for (int step = 0; step < totalStep; step++) {
            float[] currentStepArray = new float[bsz];
            Arrays.fill(currentStepArray, (float) step);
            OnnxTensor currentStepTensor = OnnxTensor.createTensor(env, currentStepArray);
            OnnxTensor noisyLatentTensor = Helper.createFloatTensor(xt, env);
            OnnxTensor latentMaskTensor = Helper.createFloatTensor(latentMask, env);
            OnnxTensor textMaskTensor2 = Helper.createFloatTensor(textMask, env);
            Map<String, OnnxTensor> vectorEstInputs = new HashMap<>();
            vectorEstInputs.put("noisy_latent", noisyLatentTensor);
            vectorEstInputs.put("text_emb", textEmbTensor);
            vectorEstInputs.put("style_ttl", style.ttlTensor);
            vectorEstInputs.put("latent_mask", latentMaskTensor);
            vectorEstInputs.put("text_mask", textMaskTensor2);
            vectorEstInputs.put("current_step", currentStepTensor);
            vectorEstInputs.put("total_step", totalStepTensor);
            OrtSession.Result vectorEstResult = vectorEstSession.run(vectorEstInputs);
            float[][][] denoised = (float[][][]) vectorEstResult.get(0).getValue();
            // Update latent
            xt = denoised;
            // Clean up
            currentStepTensor.close();
            noisyLatentTensor.close();
            latentMaskTensor.close();
            textMaskTensor2.close();
            vectorEstResult.close();
        }
        // Generate waveform
        OnnxTensor finalLatentTensor = Helper.createFloatTensor(xt, env);
        Map<String, OnnxTensor> vocoderInputs = new HashMap<>();
        vocoderInputs.put("latent", finalLatentTensor);
        OrtSession.Result vocoderResult = vocoderSession.run(vocoderInputs);
        float[][] wavBatch = (float[][]) vocoderResult.get(0).getValue();
        // Flatten all batch audio into a single array for batch processing
        int totalSamples = 0;
        for (float[] w : wavBatch) {
            totalSamples += w.length;
        }
        float[] wav = new float[totalSamples];
        int offset = 0;
        for (float[] w : wavBatch) {
            System.arraycopy(w, 0, wav, offset, w.length);
            offset += w.length;
        }
        // Clean up (must happen after all tensor values were copied out)
        textIdsTensor.close();
        textMaskTensor.close();
        dpResult.close();
        textEncResult.close();
        totalStepTensor.close();
        finalLatentTensor.close();
        vocoderResult.close();
        return new TTSResult(wav, duration);
    }

    /**
     * Samples standard-normal noise of shape [bsz][latentDim][latentLen]
     * (sized from the longest predicted duration) and zeroes it beyond each
     * item's own latent length via the mask.
     */
    private NoisyLatentResult sampleNoisyLatent(float[] duration) {
        int bsz = duration.length;
        float maxDur = 0;
        for (float d : duration) {
            maxDur = Math.max(maxDur, d);
        }
        long wavLenMax = (long) (maxDur * sampleRate);
        long[] wavLengths = new long[bsz];
        for (int i = 0; i < bsz; i++) {
            wavLengths[i] = (long) (duration[i] * sampleRate);
        }
        int chunkSize = baseChunkSize * chunkCompress;
        // ceil division: latent frames needed to cover the longest waveform
        int latentLen = (int) ((wavLenMax + chunkSize - 1) / chunkSize);
        int latentDim = ldim * chunkCompress;
        Random rng = new Random(); // unseeded: every call produces a different sample
        float[][][] noisyLatent = new float[bsz][latentDim][latentLen];
        for (int b = 0; b < bsz; b++) {
            for (int d = 0; d < latentDim; d++) {
                for (int t = 0; t < latentLen; t++) {
                    // Box-Muller transform
                    double u1 = Math.max(1e-10, rng.nextDouble());
                    double u2 = rng.nextDouble();
                    noisyLatent[b][d][t] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2));
                }
            }
        }
        float[][][] latentMask = Helper.getLatentMask(wavLengths, config);
        // Apply mask
        for (int b = 0; b < bsz; b++) {
            for (int d = 0; d < latentDim; d++) {
                for (int t = 0; t < latentLen; t++) {
                    noisyLatent[b][d][t] *= latentMask[b][0][t];
                }
            }
        }
        return new NoisyLatentResult(noisyLatent, latentMask);
    }

    /**
     * Synthesize speech from a single text with automatic chunking.
     * Long texts are split (120 chars for Korean, 300 otherwise), each chunk
     * synthesized separately, and joined with silenceDuration seconds of silence.
     *
     * @return concatenated waveform and a single total duration (seconds)
     */
    public TTSResult call(String text, String lang, Style style, int totalStep, float speed, float silenceDuration, OrtEnvironment env)
            throws OrtException {
        int maxLen = lang.equals("ko") ? 120 : 300;
        List<String> chunks = Helper.chunkText(text, maxLen);
        List<Float> wavCat = new ArrayList<>();
        float durCat = 0.0f;
        for (int i = 0; i < chunks.size(); i++) {
            TTSResult result = _infer(Arrays.asList(chunks.get(i)), Arrays.asList(lang), style, totalStep, speed, env);
            float dur = result.duration[0];
            // Trim (or zero-pad) the chunk's audio to its predicted duration
            int wavLen = (int) (sampleRate * dur);
            float[] wavChunk = new float[wavLen];
            System.arraycopy(result.wav, 0, wavChunk, 0, Math.min(wavLen, result.wav.length));
            if (i == 0) {
                for (float val : wavChunk) {
                    wavCat.add(val);
                }
                durCat = dur;
            } else {
                // Insert silence between chunks for a natural pause
                int silenceLen = (int) (silenceDuration * sampleRate);
                for (int j = 0; j < silenceLen; j++) {
                    wavCat.add(0.0f);
                }
                for (float val : wavChunk) {
                    wavCat.add(val);
                }
                durCat += silenceDuration + dur;
            }
        }
        float[] wavArray = new float[wavCat.size()];
        for (int i = 0; i < wavCat.size(); i++) {
            wavArray[i] = wavCat.get(i);
        }
        return new TTSResult(wavArray, new float[]{durCat});
    }

    /**
     * Batch synthesize speech from multiple texts in a single inference pass
     * (no chunking; all items padded to the longest).
     */
    public TTSResult batch(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
            throws OrtException {
        return _infer(textList, langList, style, totalStep, speed, env);
    }

    /** Releases all four ONNX sessions. */
    public void close() throws OrtException {
        if (dpSession != null) dpSession.close();
        if (textEncSession != null) textEncSession.close();
        if (vectorEstSession != null) vectorEstSession.close();
        if (vocoderSession != null) vocoderSession.close();
    }
}
/**
 * Holder for the two batched voice-style tensors: one consumed by the
 * text encoder / vector estimator ("ttl") and one by the duration
 * predictor ("dp"). Owns both tensors; call {@link #close()} to free them.
 */
class Style {
    OnnxTensor ttlTensor;
    OnnxTensor dpTensor;

    Style(OnnxTensor ttl, OnnxTensor dp) {
        this.ttlTensor = ttl;
        this.dpTensor = dp;
    }

    /** Releases both native tensors; null entries are skipped. */
    public void close() throws OrtException {
        if (ttlTensor != null) {
            ttlTensor.close();
        }
        if (dpTensor != null) {
            dpTensor.close();
        }
    }
}
/**
 * Result of one synthesis run: the raw waveform samples and the
 * per-batch-item durations in seconds.
 */
class TTSResult {
    float[] wav;      // mono waveform samples (batched runs: segments concatenated)
    float[] duration; // predicted duration per batch item, in seconds

    TTSResult(float[] samples, float[] durationsSec) {
        this.wav = samples;
        this.duration = durationsSec;
    }
}
/**
 * Pair returned by latent sampling: the masked Gaussian noise tensor and
 * the 0/1 validity mask that was applied to it.
 */
class NoisyLatentResult {
    float[][][] noisyLatent; // [bsz][latentDim][latentLen] masked noise
    float[][][] latentMask;  // [bsz][1][latentLen] 1.0 inside each item's length

    NoisyLatentResult(float[][][] noise, float[][][] mask) {
        this.noisyLatent = noise;
        this.latentMask = mask;
    }
}
/**
 * Static utilities for the TTS pipeline: text chunking, model/style/config
 * loading, ONNX tensor construction, WAV output and small helpers.
 */
public class Helper {
    // Default maximum characters per chunk when chunkText receives maxLen == 0
    private static final int MAX_CHUNK_LENGTH = 300;
    // Abbreviations that must NOT be treated as sentence terminators
    private static final String[] ABBREVIATIONS = {
            "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
            "St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
            "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D."
    };

    /**
     * Chunk text into smaller segments based on paragraphs and sentences.
     * Falls back from paragraph -> sentence -> comma -> word splits so no
     * chunk exceeds maxLen characters (0 selects MAX_CHUNK_LENGTH).
     *
     * @param text   raw input text
     * @param maxLen maximum chunk length in chars; 0 means use the default
     * @return non-empty list of chunks ("" for empty input)
     */
    public static List<String> chunkText(String text, int maxLen) {
        if (maxLen == 0) {
            maxLen = MAX_CHUNK_LENGTH;
        }
        text = text.trim();
        if (text.isEmpty()) {
            return Arrays.asList("");
        }
        // Split by paragraphs (blank-line separated)
        String[] paragraphs = text.split("\\n\\s*\\n");
        List<String> chunks = new ArrayList<>();
        for (String para : paragraphs) {
            para = para.trim();
            if (para.isEmpty()) {
                continue;
            }
            // Short paragraphs become a chunk as-is
            if (para.length() <= maxLen) {
                chunks.add(para);
                continue;
            }
            // Split by sentences and greedily pack them up to maxLen
            List<String> sentences = splitSentences(para);
            StringBuilder current = new StringBuilder();
            int currentLen = 0;
            for (String sentence : sentences) {
                sentence = sentence.trim();
                if (sentence.isEmpty()) {
                    continue;
                }
                int sentenceLen = sentence.length();
                if (sentenceLen > maxLen) {
                    // If sentence is longer than maxLen, split by comma or space
                    if (current.length() > 0) {
                        chunks.add(current.toString().trim());
                        current.setLength(0);
                        currentLen = 0;
                    }
                    // Try splitting by comma
                    String[] parts = sentence.split(",");
                    for (String part : parts) {
                        part = part.trim();
                        if (part.isEmpty()) {
                            continue;
                        }
                        int partLen = part.length();
                        if (partLen > maxLen) {
                            // Split by space as last resort
                            String[] words = part.split("\\s+");
                            StringBuilder wordChunk = new StringBuilder();
                            int wordChunkLen = 0;
                            for (String word : words) {
                                int wordLen = word.length();
                                if (wordChunkLen + wordLen + 1 > maxLen && wordChunk.length() > 0) {
                                    chunks.add(wordChunk.toString().trim());
                                    wordChunk.setLength(0);
                                    wordChunkLen = 0;
                                }
                                if (wordChunk.length() > 0) {
                                    wordChunk.append(" ");
                                    wordChunkLen++;
                                }
                                wordChunk.append(word);
                                wordChunkLen += wordLen;
                            }
                            if (wordChunk.length() > 0) {
                                chunks.add(wordChunk.toString().trim());
                            }
                        } else {
                            // Pack comma-parts, re-joining with ", "
                            if (currentLen + partLen + 1 > maxLen && current.length() > 0) {
                                chunks.add(current.toString().trim());
                                current.setLength(0);
                                currentLen = 0;
                            }
                            if (current.length() > 0) {
                                current.append(", ");
                                currentLen += 2;
                            }
                            current.append(part);
                            currentLen += partLen;
                        }
                    }
                    continue;
                }
                // Flush the running chunk when adding this sentence would overflow
                if (currentLen + sentenceLen + 1 > maxLen && current.length() > 0) {
                    chunks.add(current.toString().trim());
                    current.setLength(0);
                    currentLen = 0;
                }
                if (current.length() > 0) {
                    current.append(" ");
                    currentLen++;
                }
                current.append(sentence);
                currentLen += sentenceLen;
            }
            if (current.length() > 0) {
                chunks.add(current.toString().trim());
            }
        }
        if (chunks.isEmpty()) {
            return Arrays.asList("");
        }
        return chunks;
    }

    /**
     * Split text into sentences, avoiding common abbreviations.
     * Splits on whitespace that follows [.!?], unless the preceding text
     * matches one of ABBREVIATIONS (negative lookbehind; each quoted
     * alternative has fixed width, which Java's bounded lookbehind allows).
     */
    private static List<String> splitSentences(String text) {
        // Build pattern that avoids abbreviations
        StringBuilder abbrevPattern = new StringBuilder();
        for (int i = 0; i < ABBREVIATIONS.length; i++) {
            if (i > 0) abbrevPattern.append("|");
            abbrevPattern.append(Pattern.quote(ABBREVIATIONS[i]));
        }
        // Match sentence endings, but not abbreviations
        String patternStr = "(?<!(?:" + abbrevPattern.toString() + "))(?<=[.!?])\\s+";
        Pattern pattern = Pattern.compile(patternStr);
        return Arrays.asList(pattern.split(text));
    }

    /**
     * Load voice style from JSON files and stack them into one batched pair
     * of tensors (style_ttl, style_dp) of shape [bsz, d1, d2].
     *
     * NOTE(review): dims are read only from the FIRST file; all listed style
     * files are assumed to share the same tensor shapes — confirm, since a
     * mismatched file would overflow or underfill the flat buffers.
     *
     * @param voiceStylePaths one JSON path per batch item
     * @param verbose         print a summary line when true
     * @param env             ONNX runtime environment used to create tensors
     */
    public static Style loadVoiceStyle(List<String> voiceStylePaths, boolean verbose, OrtEnvironment env)
            throws IOException, OrtException {
        int bsz = voiceStylePaths.size();
        // Read first file to get dimensions
        ObjectMapper mapper = new ObjectMapper();
        JsonNode firstRoot = mapper.readTree(new File(voiceStylePaths.get(0)));
        long[] ttlDims = new long[3];
        for (int i = 0; i < 3; i++) {
            ttlDims[i] = firstRoot.get("style_ttl").get("dims").get(i).asLong();
        }
        long[] dpDims = new long[3];
        for (int i = 0; i < 3; i++) {
            dpDims[i] = firstRoot.get("style_dp").get("dims").get(i).asLong();
        }
        long ttlDim1 = ttlDims[1];
        long ttlDim2 = ttlDims[2];
        long dpDim1 = dpDims[1];
        long dpDim2 = dpDims[2];
        // Pre-allocate arrays with full batch size
        int ttlSize = (int) (bsz * ttlDim1 * ttlDim2);
        int dpSize = (int) (bsz * dpDim1 * dpDim2);
        float[] ttlFlat = new float[ttlSize];
        float[] dpFlat = new float[dpSize];
        // Fill in the data (row-major flatten of each file's nested arrays)
        for (int i = 0; i < bsz; i++) {
            JsonNode root = mapper.readTree(new File(voiceStylePaths.get(i)));
            // Flatten TTL data
            int ttlOffset = (int) (i * ttlDim1 * ttlDim2);
            int idx = 0;
            JsonNode ttlData = root.get("style_ttl").get("data");
            for (JsonNode batch : ttlData) {
                for (JsonNode row : batch) {
                    for (JsonNode val : row) {
                        ttlFlat[ttlOffset + idx++] = (float) val.asDouble();
                    }
                }
            }
            // Flatten DP data
            int dpOffset = (int) (i * dpDim1 * dpDim2);
            idx = 0;
            JsonNode dpData = root.get("style_dp").get("data");
            for (JsonNode batch : dpData) {
                for (JsonNode row : batch) {
                    for (JsonNode val : row) {
                        dpFlat[dpOffset + idx++] = (float) val.asDouble();
                    }
                }
            }
        }
        long[] ttlShape = {bsz, ttlDim1, ttlDim2};
        long[] dpShape = {bsz, dpDim1, dpDim2};
        OnnxTensor ttlTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(ttlFlat), ttlShape);
        OnnxTensor dpTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(dpFlat), dpShape);
        if (verbose) {
            System.out.println("Loaded " + bsz + " voice styles\n");
        }
        return new Style(ttlTensor, dpTensor);
    }

    /**
     * Load all TTS components (config, four ONNX sessions, text processor)
     * from {@code onnxDir}.
     *
     * @throws RuntimeException when useGpu is requested (not supported yet)
     */
    public static TextToSpeech loadTextToSpeech(String onnxDir, boolean useGpu, OrtEnvironment env)
            throws IOException, OrtException {
        if (useGpu) {
            throw new RuntimeException("GPU mode is not supported yet");
        }
        System.out.println("Using CPU for inference\n");
        // Load config
        Config config = loadCfgs(onnxDir);
        // Create session options
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        // Load models
        OrtSession dpSession = env.createSession(onnxDir + "/duration_predictor.onnx", opts);
        OrtSession textEncSession = env.createSession(onnxDir + "/text_encoder.onnx", opts);
        OrtSession vectorEstSession = env.createSession(onnxDir + "/vector_estimator.onnx", opts);
        OrtSession vocoderSession = env.createSession(onnxDir + "/vocoder.onnx", opts);
        // Load text processor
        UnicodeProcessor textProcessor = new UnicodeProcessor(onnxDir + "/unicode_indexer.json");
        return new TextToSpeech(config, textProcessor, dpSession, textEncSession, vectorEstSession, vocoderSession);
    }

    /**
     * Load configuration from {@code <onnxDir>/tts.json} into a Config.
     */
    public static Config loadCfgs(String onnxDir) throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        JsonNode root = mapper.readTree(new File(onnxDir + "/tts.json"));
        Config config = new Config();
        config.ae = new Config.AEConfig();
        config.ae.sampleRate = root.get("ae").get("sample_rate").asInt();
        config.ae.baseChunkSize = root.get("ae").get("base_chunk_size").asInt();
        config.ttl = new Config.TTLConfig();
        config.ttl.chunkCompressFactor = root.get("ttl").get("chunk_compress_factor").asInt();
        config.ttl.latentDim = root.get("ttl").get("latent_dim").asInt();
        return config;
    }

    /**
     * Build a [bsz][1][maxLatentLen] 0/1 mask from per-item waveform lengths,
     * where each latent frame covers baseChunkSize * chunkCompressFactor samples
     * (ceil division).
     */
    public static float[][][] getLatentMask(long[] wavLengths, Config config) {
        long baseChunkSize = config.ae.baseChunkSize;
        long chunkCompressFactor = config.ttl.chunkCompressFactor;
        long latentSize = baseChunkSize * chunkCompressFactor;
        long[] latentLengths = new long[wavLengths.length];
        long maxLen = 0;
        for (int i = 0; i < wavLengths.length; i++) {
            latentLengths[i] = (wavLengths[i] + latentSize - 1) / latentSize;
            maxLen = Math.max(maxLen, latentLengths[i]);
        }
        float[][][] mask = new float[wavLengths.length][1][(int) maxLen];
        for (int i = 0; i < wavLengths.length; i++) {
            for (int j = 0; j < maxLen; j++) {
                mask[i][0][j] = j < latentLengths[i] ? 1.0f : 0.0f;
            }
        }
        return mask;
    }

    /**
     * Write float samples as a mono 16-bit little-endian PCM WAV file.
     * Samples are clamped to [-32768, 32767] after scaling by 32767.
     *
     * @param filename   output path
     * @param audioData  mono samples, nominally in [-1, 1]
     * @param sampleRate sample rate in Hz
     */
    public static void writeWavFile(String filename, float[] audioData, int sampleRate) throws IOException {
        // Convert float to byte array (2 bytes per 16-bit sample)
        byte[] bytes = new byte[audioData.length * 2];
        ByteBuffer buffer = ByteBuffer.wrap(bytes);
        buffer.order(ByteOrder.LITTLE_ENDIAN);
        for (float sample : audioData) {
            short val = (short) Math.max(-32768, Math.min(32767, sample * 32767));
            buffer.putShort(val);
        }
        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
        AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false);
        // Frame count == sample count for mono 16-bit audio
        AudioInputStream ais = new AudioInputStream(bais, format, audioData.length);
        AudioSystem.write(ais, AudioFileFormat.Type.WAVE, new File(filename));
    }

    /**
     * Sanitize a filename fragment: keep the first maxLen code points,
     * replacing every non-letter/digit with '_' (Unicode-aware).
     */
    public static String sanitizeFilename(String text, int maxLen) {
        // Get first maxLen characters (code points, not chars for surrogate pairs)
        int[] codePoints = text.codePoints().limit(maxLen).toArray();
        StringBuilder result = new StringBuilder();
        for (int codePoint : codePoints) {
            if (Character.isLetterOrDigit(codePoint)) {
                result.appendCodePoint(codePoint);
            } else {
                result.append('_');
            }
        }
        return result.toString();
    }

    /**
     * Run {@code fn}, printing its name before and the wall-clock time after.
     *
     * @return whatever fn returned
     */
    public static <T> T timer(String name, java.util.function.Supplier<T> fn) {
        long start = System.currentTimeMillis();
        System.out.println(name + "...");
        T result = fn.get();
        long elapsed = System.currentTimeMillis() - start;
        System.out.printf(" -> %s completed in %.2f sec\n", name, elapsed / 1000.0);
        return result;
    }

    /**
     * Create a float ONNX tensor from a 3D array (assumed rectangular:
     * dims taken from the first element of each level), row-major flattened.
     */
    public static OnnxTensor createFloatTensor(float[][][] array, OrtEnvironment env) throws OrtException {
        int dim0 = array.length;
        int dim1 = array[0].length;
        int dim2 = array[0][0].length;
        float[] flat = new float[dim0 * dim1 * dim2];
        int idx = 0;
        for (int i = 0; i < dim0; i++) {
            for (int j = 0; j < dim1; j++) {
                for (int k = 0; k < dim2; k++) {
                    flat[idx++] = array[i][j][k];
                }
            }
        }
        long[] shape = {dim0, dim1, dim2};
        return OnnxTensor.createTensor(env, FloatBuffer.wrap(flat), shape);
    }

    /**
     * Create an int64 ONNX tensor from a 2D array (assumed rectangular),
     * row-major flattened.
     */
    public static OnnxTensor createLongTensor(long[][] array, OrtEnvironment env) throws OrtException {
        int dim0 = array.length;
        int dim1 = array[0].length;
        long[] flat = new long[dim0 * dim1];
        int idx = 0;
        for (int i = 0; i < dim0; i++) {
            for (int j = 0; j < dim1; j++) {
                flat[idx++] = array[i][j];
            }
        }
        long[] shape = {dim0, dim1};
        return OnnxTensor.createTensor(env, LongBuffer.wrap(flat), shape);
    }

    /**
     * Load a JSON file containing a flat array of integers into a long[].
     */
    public static long[] loadJsonLongArray(String filePath) throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        JsonNode root = mapper.readTree(new File(filePath));
        long[] result = new long[root.size()];
        for (int i = 0; i < root.size(); i++) {
            result[i] = root.get(i).asLong();
        }
        return result;
    }
}

130
java/README.md Normal file
View File

@@ -0,0 +1,130 @@
# TTS ONNX Inference Examples
This guide provides examples for running TTS inference using `ExampleONNX.java`.
## 📰 Update News
**2026.01.06** - 🎉 **Supertonic 2** released with multilingual support! Now supports English (`en`), Korean (`ko`), Spanish (`es`), Portuguese (`pt`), and French (`fr`). [Demo](https://huggingface.co/spaces/Supertone/supertonic-2) | [Models](https://huggingface.co/Supertone/supertonic-2)
**2025.12.10** - Added [6 new voice styles](https://huggingface.co/Supertone/supertonic/tree/b10dbaf18b316159be75b34d24f740008fddd381) (M3, M4, M5, F3, F4, F5). See [Voices](https://supertone-inc.github.io/supertonic-py/voices/) for details
**2025.12.08** - Optimized ONNX models via [OnnxSlim](https://github.com/inisis/OnnxSlim) now available on [Hugging Face Models](https://huggingface.co/Supertone/supertonic)
**2025.11.23** - Enhanced text preprocessing with comprehensive normalization, emoji removal, symbol replacement, and punctuation handling for improved synthesis quality.
**2025.11.19** - Added `--speed` parameter to control speech synthesis speed (default: 1.05, recommended range: 0.9-1.5).
**2025.11.19** - Added automatic text chunking for long-form inference. Long texts are split into chunks and synthesized with natural pauses.
## Installation
This project uses [Maven](https://maven.apache.org/) for dependency management.
### Prerequisites
- Java 11 or higher
- Maven 3.6 or higher
### Install dependencies
```bash
mvn clean install
```
## Basic Usage
### Example 1: Default Inference
Run inference with default settings:
```bash
mvn exec:java
```
This will use:
- Voice style: `assets/voice_styles/M1.json`
- Text: "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen."
- Output directory: `results/`
- Total steps: 5
- Number of generations: 4
### Example 2: Batch Inference
Process multiple voice styles and texts at once:
```bash
mvn exec:java -Dexec.args="--batch --voice-style assets/voice_styles/M1.json,assets/voice_styles/F1.json --text 'The sun sets behind the mountains, painting the sky in shades of pink and orange.|오늘 아침에 공원을 산책했는데, 새소리와 바람 소리가 너무 기분 좋았어요.' --lang en,ko"
```
This will:
- Generate speech for 2 different voice-text-language pairs
- Use male voice (M1.json) for the first text in English
- Use female voice (F1.json) for the second text in Korean
- Process both samples in a single batch
### Example 3: High Quality Inference
Increase denoising steps for better quality:
```bash
mvn exec:java -Dexec.args="--total-step 10 --voice-style assets/voice_styles/M1.json --text 'Increasing the number of denoising steps improves the output fidelity and overall quality.'"
```
This will:
- Use 10 denoising steps instead of the default 5
- Produce higher quality output at the cost of slower inference
### Example 4: Long-Form Inference
The system automatically chunks long texts into manageable segments, synthesizes each segment separately, and concatenates them with natural pauses (0.3 seconds by default) into a single audio file. This happens by default when you don't use the `--batch` flag:
```bash
mvn exec:java -Dexec.args="--voice-style assets/voice_styles/M1.json --text 'This is a very long text that will be automatically split into multiple chunks. The system will process each chunk separately and then concatenate them together with natural pauses between segments. This ensures that even very long texts can be processed efficiently while maintaining natural speech flow and avoiding memory issues.'"
```
This will:
- Automatically split the text into chunks based on paragraph and sentence boundaries
- Synthesize each chunk separately
- Add 0.3 seconds of silence between chunks for natural pauses
- Concatenate all chunks into a single audio file
**Note**: Automatic text chunking is disabled when using `--batch` mode. In batch mode, each text is processed as-is without chunking.
**Tip**: If your text contains apostrophes, use escaping or run the JAR directly:
```bash
java -jar target/tts-example.jar --total-step 10 --text "Text with apostrophe's here"
```
## Building a Fat JAR
To create a standalone JAR with all dependencies:
```bash
mvn clean package
```
Then run it directly:
```bash
java -jar target/tts-example.jar
```
Or with arguments:
```bash
java -jar target/tts-example.jar --total-step 10 --text "Your custom text here"
```
## Available Arguments
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--use-gpu` | flag | False | Use GPU for inference (default: CPU) |
| `--onnx-dir` | str | `assets/onnx` | Path to ONNX model directory |
| `--total-step` | int | 5 | Number of denoising steps (higher = better quality, slower) |
| `--n-test` | int | 4 | Number of times to generate each sample |
| `--voice-style` | str+ | `assets/voice_styles/M1.json` | Voice style file path(s), comma-separated |
| `--text` | str+ | (long default text) | Text(s) to synthesize, pipe-separated |
| `--lang` | str+ | `en` | Language(s) for synthesis, comma-separated (en, ko, es, pt, fr) |
| `--save-dir` | str | `results` | Output directory |
| `--batch` | flag | False | Enable batch mode (multiple text-style pairs, disables automatic chunking) |
## Notes
- **Multilingual Support**: Use `--lang` to specify the language for each text. Available: `en` (English), `ko` (Korean), `es` (Spanish), `pt` (Portuguese), `fr` (French)
- **Batch Processing**: When using `--batch`, the number of `--voice-style`, `--text`, and `--lang` entries must match
- **Automatic Chunking**: Without `--batch`, long texts are automatically split and concatenated with 0.3s pauses
- **Quality vs Speed**: Higher `--total-step` values produce better quality but take longer
- **GPU Support**: GPU mode is not supported yet
- **Voice Styles**: Uses pre-extracted voice style JSON files for fast inference

110
java/pom.xml Normal file
View File

@@ -0,0 +1,110 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>ai.supertonic</groupId>
<artifactId>tts-onnx-java</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<name>TTS ONNX Java Example</name>
<description>Text-to-Speech inference using ONNX Runtime in Java</description>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<onnxruntime.version>1.23.1</onnxruntime.version>
<jackson.version>2.15.2</jackson.version>
</properties>
<dependencies>
<!-- ONNX Runtime -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>${onnxruntime.version}</version>
</dependency>
<!-- Jackson for JSON parsing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<!-- JTransforms for Fast FFT -->
<dependency>
<groupId>com.github.wendykierp</groupId>
<artifactId>JTransforms</artifactId>
<version>3.1</version>
</dependency>
</dependencies>
<build>
<sourceDirectory>.</sourceDirectory>
<plugins>
<!-- Maven Compiler Plugin -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>11</source>
<target>11</target>
</configuration>
</plugin>
<!-- Maven Exec Plugin for running the example -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>3.1.0</version>
<configuration>
<mainClass>ExampleONNX</mainClass>
</configuration>
</plugin>
<!-- Maven Jar Plugin -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<configuration>
<archive>
<manifest>
<mainClass>ExampleONNX</mainClass>
</manifest>
</archive>
</configuration>
</plugin>
<!-- Maven Shade Plugin for creating fat JAR -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.5.0</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>ExampleONNX</mainClass>
</transformer>
</transformers>
<finalName>tts-example</finalName>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>