initial commit
This commit is contained in:
35
java/.gitignore
vendored
Normal file
35
java/.gitignore
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
# Maven
|
||||
target/
|
||||
pom.xml.tag
|
||||
pom.xml.releaseBackup
|
||||
pom.xml.versionsBackup
|
||||
pom.xml.next
|
||||
release.properties
|
||||
dependency-reduced-pom.xml
|
||||
buildNumber.properties
|
||||
.mvn/timing.properties
|
||||
.mvn/wrapper/maven-wrapper.jar
|
||||
|
||||
# Compiled class files
|
||||
*.class
|
||||
|
||||
# IntelliJ IDEA
|
||||
.idea/
|
||||
*.iml
|
||||
*.iws
|
||||
*.ipr
|
||||
|
||||
# Eclipse
|
||||
.classpath
|
||||
.project
|
||||
.settings/
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
|
||||
# Results
|
||||
results/*.wav
|
||||
|
||||
# Mac
|
||||
.DS_Store
|
||||
|
||||
183
java/ExampleONNX.java
Normal file
183
java/ExampleONNX.java
Normal file
@@ -0,0 +1,183 @@
|
||||
import ai.onnxruntime.*;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* TTS Inference Example with ONNX Runtime (Java)
|
||||
*/
|
||||
public class ExampleONNX {
|
||||
|
||||
/**
|
||||
* Command line arguments
|
||||
*/
|
||||
static class Args {
|
||||
boolean useGpu = false;
|
||||
String onnxDir = "assets/onnx";
|
||||
int totalStep = 5;
|
||||
float speed = 1.05f;
|
||||
int nTest = 4;
|
||||
List<String> voiceStyle = Arrays.asList("assets/voice_styles/M1.json");
|
||||
List<String> text = Arrays.asList(
|
||||
"This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen."
|
||||
);
|
||||
List<String> lang = Arrays.asList("en");
|
||||
String saveDir = "results";
|
||||
boolean batch = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse command line arguments
|
||||
*/
|
||||
private static Args parseArgs(String[] args) {
|
||||
Args result = new Args();
|
||||
|
||||
for (int i = 0; i < args.length; i++) {
|
||||
switch (args[i]) {
|
||||
case "--use-gpu":
|
||||
result.useGpu = true;
|
||||
break;
|
||||
case "--onnx-dir":
|
||||
if (i + 1 < args.length) result.onnxDir = args[++i];
|
||||
break;
|
||||
case "--total-step":
|
||||
if (i + 1 < args.length) result.totalStep = Integer.parseInt(args[++i]);
|
||||
break;
|
||||
case "--speed":
|
||||
if (i + 1 < args.length) result.speed = Float.parseFloat(args[++i]);
|
||||
break;
|
||||
case "--n-test":
|
||||
if (i + 1 < args.length) result.nTest = Integer.parseInt(args[++i]);
|
||||
break;
|
||||
case "--voice-style":
|
||||
if (i + 1 < args.length) {
|
||||
result.voiceStyle = Arrays.asList(args[++i].split(","));
|
||||
}
|
||||
break;
|
||||
case "--text":
|
||||
if (i + 1 < args.length) {
|
||||
result.text = Arrays.asList(args[++i].split("\\|"));
|
||||
}
|
||||
break;
|
||||
case "--lang":
|
||||
if (i + 1 < args.length) {
|
||||
result.lang = Arrays.asList(args[++i].split(","));
|
||||
}
|
||||
break;
|
||||
case "--save-dir":
|
||||
if (i + 1 < args.length) result.saveDir = args[++i];
|
||||
break;
|
||||
case "--batch":
|
||||
result.batch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main inference function
|
||||
*/
|
||||
public static void main(String[] args) {
|
||||
try {
|
||||
System.out.println("=== TTS Inference with ONNX Runtime (Java) ===\n");
|
||||
|
||||
// --- 1. Parse arguments --- //
|
||||
Args parsedArgs = parseArgs(args);
|
||||
int totalStep = parsedArgs.totalStep;
|
||||
float speed = parsedArgs.speed;
|
||||
int nTest = parsedArgs.nTest;
|
||||
String saveDir = parsedArgs.saveDir;
|
||||
List<String> voiceStylePaths = parsedArgs.voiceStyle;
|
||||
List<String> textList = parsedArgs.text;
|
||||
List<String> langList = parsedArgs.lang;
|
||||
boolean batch = parsedArgs.batch;
|
||||
|
||||
if (batch) {
|
||||
if (voiceStylePaths.size() != textList.size()) {
|
||||
throw new RuntimeException("Number of voice styles (" + voiceStylePaths.size() +
|
||||
") must match number of texts (" + textList.size() + ")");
|
||||
}
|
||||
if (langList.size() != textList.size()) {
|
||||
throw new RuntimeException("Number of languages (" + langList.size() +
|
||||
") must match number of texts (" + textList.size() + ")");
|
||||
}
|
||||
}
|
||||
|
||||
int bsz = voiceStylePaths.size();
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
// --- 2. Load TTS components --- //
|
||||
TextToSpeech textToSpeech = Helper.loadTextToSpeech(parsedArgs.onnxDir, parsedArgs.useGpu, env);
|
||||
|
||||
// --- 3. Load voice styles --- //
|
||||
Style style = Helper.loadVoiceStyle(voiceStylePaths, true, env);
|
||||
|
||||
// --- 4. Synthesize speech --- //
|
||||
File saveDirFile = new File(saveDir);
|
||||
if (!saveDirFile.exists()) {
|
||||
saveDirFile.mkdirs();
|
||||
}
|
||||
|
||||
for (int n = 0; n < nTest; n++) {
|
||||
System.out.println("\n[" + (n + 1) + "/" + nTest + "] Starting synthesis...");
|
||||
|
||||
TTSResult ttsResult;
|
||||
if (batch) {
|
||||
ttsResult = Helper.timer("Generating speech from text", () -> {
|
||||
try {
|
||||
return textToSpeech.batch(textList, langList, style, totalStep, speed, env);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
ttsResult = Helper.timer("Generating speech from text", () -> {
|
||||
try {
|
||||
return textToSpeech.call(textList.get(0), langList.get(0), style, totalStep, speed, 0.3f, env);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
float[] wav = ttsResult.wav;
|
||||
float[] duration = ttsResult.duration;
|
||||
|
||||
// Save outputs
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
String fname = Helper.sanitizeFilename(textList.get(i), 20) + "_" + (n + 1) + ".wav";
|
||||
float[] wavOut;
|
||||
|
||||
if (batch) {
|
||||
int wavLen = wav.length / bsz;
|
||||
int actualLen = (int) (textToSpeech.sampleRate * duration[i]);
|
||||
wavOut = new float[actualLen];
|
||||
System.arraycopy(wav, i * wavLen, wavOut, 0, Math.min(actualLen, wavLen));
|
||||
} else {
|
||||
// For non-batch mode, wav is a single concatenated audio
|
||||
int actualLen = (int) (textToSpeech.sampleRate * duration[0]);
|
||||
wavOut = new float[Math.min(actualLen, wav.length)];
|
||||
System.arraycopy(wav, 0, wavOut, 0, wavOut.length);
|
||||
}
|
||||
|
||||
String outputPath = saveDir + "/" + fname;
|
||||
Helper.writeWavFile(outputPath, wavOut, textToSpeech.sampleRate);
|
||||
System.out.println("Saved: " + outputPath);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
style.close();
|
||||
textToSpeech.close();
|
||||
|
||||
System.out.println("\n=== Synthesis completed successfully! ===");
|
||||
|
||||
} catch (Exception e) {
|
||||
System.err.println("Error during inference: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
System.exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
955
java/Helper.java
Normal file
955
java/Helper.java
Normal file
@@ -0,0 +1,955 @@
|
||||
import ai.onnxruntime.*;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import javax.sound.sampled.AudioFileFormat;
|
||||
import javax.sound.sampled.AudioFormat;
|
||||
import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.*;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.text.Normalizer;
|
||||
import java.util.*;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.regex.Matcher;
|
||||
|
||||
/**
|
||||
* Available languages for multilingual TTS
|
||||
*/
|
||||
class Languages {
    // Language codes the shipped TTS models understand.
    public static final List<String> AVAILABLE = Arrays.asList("en", "ko", "es", "pt", "fr");

    /**
     * Returns true when {@code lang} is one of the supported language codes.
     */
    public static boolean isValid(String lang) {
        return AVAILABLE.indexOf(lang) >= 0;
    }
}
|
||||
|
||||
/**
|
||||
* Configuration classes
|
||||
*/
|
||||
class Config {
    // Auto-encoder (audio) section of the exported model config.
    static class AEConfig {
        int sampleRate;    // output sample rate (Hz) — multiplied by seconds to get sample counts
        int baseChunkSize; // audio samples covered by one base latent chunk
    }

    // Text-to-latent section of the exported model config.
    static class TTLConfig {
        int chunkCompressFactor; // multiplier applied to both chunk size and latent dim
        int latentDim;           // latent vector dimensionality before chunk compression
    }

    AEConfig ae;   // audio/auto-encoder settings
    TTLConfig ttl; // text-to-latent settings
}
|
||||
|
||||
/**
|
||||
* Voice Style Data from JSON
|
||||
*/
|
||||
class VoiceStyleData {
    // Mirror of a single tensor entry inside a voice-style JSON file.
    static class StyleData {
        float[][][] data; // nested tensor values; outer dims match `dims`
        long[] dims;      // tensor dimensions (3 entries are read by Helper.loadVoiceStyle)
        String type;      // element-type tag as written in the JSON — presumably "float"; verify against the export
    }

    StyleData styleTtl; // "style_ttl" entry — style conditioning for the text-to-latent models
    StyleData styleDp;  // "style_dp" entry — style conditioning for the duration predictor
}
|
||||
|
||||
/**
|
||||
* Unicode text processor
|
||||
*/
|
||||
class UnicodeProcessor {
    // Maps a Unicode code point (used as array index) to a model token id.
    private long[] indexer;

    /**
     * Loads the code-point -> token-id table from a JSON file.
     *
     * @param unicodeIndexerJsonPath path to the indexer JSON
     *        (assumed to be a flat array of longs — see Helper.loadJsonLongArray)
     * @throws IOException if the file cannot be read or parsed
     */
    public UnicodeProcessor(String unicodeIndexerJsonPath) throws IOException {
        this.indexer = Helper.loadJsonLongArray(unicodeIndexerJsonPath);
    }

    /**
     * Removes characters falling into common emoji code-point ranges.
     * Surrogate pairs are decoded first so supplementary-plane emojis
     * (above U+FFFF) are detected and dropped as a unit.
     */
    private static String removeEmojis(String text) {
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < text.length(); i++) {
            int codePoint;
            // Decode a surrogate pair into one supplementary code point.
            if (Character.isHighSurrogate(text.charAt(i)) && i + 1 < text.length() && Character.isLowSurrogate(text.charAt(i + 1))) {
                codePoint = Character.codePointAt(text, i);
                i++; // Skip the low surrogate
            } else {
                codePoint = text.charAt(i);
            }

            // Check if code point is in emoji ranges
            boolean isEmoji = (codePoint >= 0x1F600 && codePoint <= 0x1F64F) ||
                    (codePoint >= 0x1F300 && codePoint <= 0x1F5FF) ||
                    (codePoint >= 0x1F680 && codePoint <= 0x1F6FF) ||
                    (codePoint >= 0x1F700 && codePoint <= 0x1F77F) ||
                    (codePoint >= 0x1F780 && codePoint <= 0x1F7FF) ||
                    (codePoint >= 0x1F800 && codePoint <= 0x1F8FF) ||
                    (codePoint >= 0x1F900 && codePoint <= 0x1F9FF) ||
                    (codePoint >= 0x1FA00 && codePoint <= 0x1FA6F) ||
                    (codePoint >= 0x1FA70 && codePoint <= 0x1FAFF) ||
                    (codePoint >= 0x2600 && codePoint <= 0x26FF) ||
                    (codePoint >= 0x2700 && codePoint <= 0x27BF) ||
                    (codePoint >= 0x1F1E6 && codePoint <= 0x1F1FF);

            if (!isEmoji) {
                // Re-encode: supplementary code points need two chars.
                if (codePoint > 0xFFFF) {
                    result.append(Character.toChars(codePoint));
                } else {
                    result.append((char) codePoint);
                }
            }
        }
        return result.toString();
    }

    /**
     * Normalizes each text, converts it to token ids via the indexer table,
     * and builds a zero-padded id matrix plus a 0/1 validity mask.
     *
     * @param textList one raw input text per batch item
     * @param langList language code per batch item (same length as textList)
     * @return padded token ids [bsz][maxLen] and mask [bsz][1][maxLen]
     */
    public TextProcessResult call(List<String> textList, List<String> langList) {
        List<String> processedTexts = new ArrayList<>();
        for (int i = 0; i < textList.size(); i++) {
            processedTexts.add(preprocessText(textList.get(i), langList.get(i)));
        }

        // Convert texts to unicode values first to get correct character counts
        List<int[]> allUnicodeVals = new ArrayList<>();
        for (String text : processedTexts) {
            allUnicodeVals.add(textToUnicodeValues(text));
        }

        int[] textIdsLengths = new int[processedTexts.size()];
        int maxLen = 0;
        for (int i = 0; i < allUnicodeVals.size(); i++) {
            textIdsLengths[i] = allUnicodeVals.get(i).length; // Use code point count, not char count
            maxLen = Math.max(maxLen, textIdsLengths[i]);
        }

        long[][] textIds = new long[processedTexts.size()][maxLen];
        for (int i = 0; i < allUnicodeVals.size(); i++) {
            int[] unicodeVals = allUnicodeVals.get(i);
            for (int j = 0; j < unicodeVals.length; j++) {
                // NOTE(review): code points >= indexer.length would throw
                // ArrayIndexOutOfBoundsException here — presumably the indexer
                // covers the full range that survives preprocessing; verify.
                textIds[i][j] = indexer[unicodeVals[j]];
            }
        }

        float[][][] textMask = getTextMask(textIdsLengths);
        return new TextProcessResult(textIds, textMask);
    }

    /**
     * Cleans one text for the model: NFKD-normalizes, strips emojis,
     * canonicalizes punctuation, collapses whitespace, appends a final
     * period if needed, and wraps the result in <lang>...</lang> tags.
     *
     * @throws IllegalArgumentException if {@code lang} is not supported
     */
    private String preprocessText(String text, String lang) {
        // TODO: Need advanced normalizer for better performance
        text = Normalizer.normalize(text, Normalizer.Form.NFKD);

        // Remove emojis (wide Unicode range)
        // Java Pattern doesn't support \x{...} syntax for Unicode above \uFFFF
        // Use character filtering instead
        text = removeEmojis(text);

        // Replace various dashes and symbols
        // (replacements are independent of each other, so HashMap iteration
        // order does not affect the result)
        Map<String, String> replacements = new HashMap<>();
        replacements.put("–", "-"); // en dash
        replacements.put("‑", "-"); // non-breaking hyphen
        replacements.put("—", "-"); // em dash
        replacements.put("_", " "); // underscore
        replacements.put("\u201C", "\""); // left double quote
        replacements.put("\u201D", "\""); // right double quote
        replacements.put("\u2018", "'"); // left single quote
        replacements.put("\u2019", "'"); // right single quote
        replacements.put("´", "'"); // acute accent
        replacements.put("`", "'"); // grave accent
        replacements.put("[", " "); // left bracket
        replacements.put("]", " "); // right bracket
        replacements.put("|", " "); // vertical bar
        replacements.put("/", " "); // slash
        replacements.put("#", " "); // hash
        replacements.put("→", " "); // right arrow
        replacements.put("←", " "); // left arrow

        for (Map.Entry<String, String> entry : replacements.entrySet()) {
            text = text.replace(entry.getKey(), entry.getValue());
        }

        // Remove special symbols
        text = text.replaceAll("[♥☆♡©\\\\]", "");

        // Replace known expressions
        Map<String, String> exprReplacements = new HashMap<>();
        exprReplacements.put("@", " at ");
        exprReplacements.put("e.g.,", "for example, ");
        exprReplacements.put("i.e.,", "that is, ");

        for (Map.Entry<String, String> entry : exprReplacements.entrySet()) {
            text = text.replace(entry.getKey(), entry.getValue());
        }

        // Fix spacing around punctuation
        text = text.replaceAll(" ,", ",");
        text = text.replaceAll(" \\.", ".");
        text = text.replaceAll(" !", "!");
        text = text.replaceAll(" \\?", "?");
        text = text.replaceAll(" ;", ";");
        text = text.replaceAll(" :", ":");
        text = text.replaceAll(" '", "'");

        // Remove duplicate quotes
        while (text.contains("\"\"")) {
            text = text.replace("\"\"", "\"");
        }
        while (text.contains("''")) {
            text = text.replace("''", "'");
        }
        while (text.contains("``")) {
            text = text.replace("``", "`");
        }

        // Remove extra spaces
        text = text.replaceAll("\\s+", " ").trim();

        // If text doesn't end with punctuation, quotes, or closing brackets, add a period
        if (!text.matches(".*[.!?;:,'\"\\u201C\\u201D\\u2018\\u2019)\\]}…。」』】〉》›»]$")) {
            text += ".";
        }

        // Validate language
        if (!Languages.isValid(lang)) {
            throw new IllegalArgumentException("Invalid language: " + lang + ". Available: " + Languages.AVAILABLE);
        }

        // Wrap text with language tags
        text = "<" + lang + ">" + text + "</" + lang + ">";

        return text;
    }

    // Returns the text's Unicode code points (surrogate pairs collapse to one value).
    private int[] textToUnicodeValues(String text) {
        // Use codePoints() stream to correctly handle surrogate pairs
        return text.codePoints().toArray();
    }

    /**
     * Builds a [bsz][1][maxLen] mask: 1.0 for positions inside each item's
     * length, 0.0 for padding positions.
     */
    private float[][][] getTextMask(int[] lengths) {
        int bsz = lengths.length;
        int maxLen = 0;
        for (int len : lengths) {
            maxLen = Math.max(maxLen, len);
        }

        float[][][] mask = new float[bsz][1][maxLen];
        for (int i = 0; i < bsz; i++) {
            for (int j = 0; j < maxLen; j++) {
                mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f;
            }
        }
        return mask;
    }

    // Pair of padded token ids and the matching validity mask.
    static class TextProcessResult {
        long[][] textIds;     // [bsz][maxLen], zero-padded
        float[][][] textMask; // [bsz][1][maxLen], 1.0 = real token, 0.0 = pad

        TextProcessResult(long[][] textIds, float[][][] textMask) {
            this.textIds = textIds;
            this.textMask = textMask;
        }
    }
}
|
||||
|
||||
/**
|
||||
* Text-to-Speech inference class
|
||||
*/
|
||||
class TextToSpeech {
    // Export-time configuration (sample rate, chunk sizes, latent dim).
    private Config config;
    // Converts raw text into padded token-id matrices and masks.
    private UnicodeProcessor textProcessor;
    private OrtSession dpSession;        // duration predictor
    private OrtSession textEncSession;   // text encoder
    private OrtSession vectorEstSession; // denoising vector estimator
    private OrtSession vocoderSession;   // latent -> waveform vocoder
    public int sampleRate;               // output sample rate (Hz); read by callers when saving WAVs
    private int baseChunkSize;           // audio samples per base latent chunk
    private int chunkCompress;           // chunk compression factor
    private int ldim;                    // latent dim before compression

    /**
     * Wires together the four ONNX sessions and caches the config values
     * needed for latent sizing.
     */
    public TextToSpeech(Config config, UnicodeProcessor textProcessor,
                        OrtSession dpSession, OrtSession textEncSession,
                        OrtSession vectorEstSession, OrtSession vocoderSession) {
        this.config = config;
        this.textProcessor = textProcessor;
        this.dpSession = dpSession;
        this.textEncSession = textEncSession;
        this.vectorEstSession = vectorEstSession;
        this.vocoderSession = vocoderSession;
        this.sampleRate = config.ae.sampleRate;
        this.baseChunkSize = config.ae.baseChunkSize;
        this.chunkCompress = config.ttl.chunkCompressFactor;
        this.ldim = config.ttl.latentDim;
    }

    /**
     * Core inference: text -> duration -> text embedding -> iterative
     * denoising of a Gaussian latent -> vocoder waveform.
     *
     * @return concatenated waveform for all batch items plus per-item durations
     * @throws OrtException on any ONNX Runtime failure
     *
     * NOTE(review): tensors created here are not wrapped in try/finally, so an
     * OrtException mid-way leaks the native buffers created before it.
     */
    private TTSResult _infer(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
            throws OrtException {
        int bsz = textList.size();

        // Process text
        UnicodeProcessor.TextProcessResult textResult = textProcessor.call(textList, langList);
        long[][] textIds = textResult.textIds;
        float[][][] textMask = textResult.textMask;

        // Create tensors
        OnnxTensor textIdsTensor = Helper.createLongTensor(textIds, env);
        OnnxTensor textMaskTensor = Helper.createFloatTensor(textMask, env);

        // Predict duration
        Map<String, OnnxTensor> dpInputs = new HashMap<>();
        dpInputs.put("text_ids", textIdsTensor);
        dpInputs.put("style_dp", style.dpTensor);
        dpInputs.put("text_mask", textMaskTensor);

        OrtSession.Result dpResult = dpSession.run(dpInputs);
        Object dpValue = dpResult.get(0).getValue();
        float[] duration;
        if (dpValue instanceof float[][]) {
            // NOTE(review): taking row [0] keeps only the first batch row; if the
            // model emits shape [bsz][1] this drops per-item durations for
            // bsz > 1 — confirm the duration_predictor output shape.
            duration = ((float[][]) dpValue)[0];
        } else {
            duration = (float[]) dpValue;
        }

        // Apply speed factor to duration
        for (int i = 0; i < duration.length; i++) {
            duration[i] /= speed;
        }

        // Encode text
        Map<String, OnnxTensor> textEncInputs = new HashMap<>();
        textEncInputs.put("text_ids", textIdsTensor);
        textEncInputs.put("style_ttl", style.ttlTensor);
        textEncInputs.put("text_mask", textMaskTensor);

        OrtSession.Result textEncResult = textEncSession.run(textEncInputs);
        // Kept alive (textEncResult is closed only after the denoising loop).
        OnnxTensor textEmbTensor = (OnnxTensor) textEncResult.get(0);

        // Sample noisy latent
        NoisyLatentResult noisyLatentResult = sampleNoisyLatent(duration);
        float[][][] xt = noisyLatentResult.noisyLatent;
        float[][][] latentMask = noisyLatentResult.latentMask;

        // Prepare constant tensors
        float[] totalStepArray = new float[bsz];
        Arrays.fill(totalStepArray, (float) totalStep);
        OnnxTensor totalStepTensor = OnnxTensor.createTensor(env, totalStepArray);

        // Denoising loop: each iteration feeds the current latent back in
        // together with the step index.
        for (int step = 0; step < totalStep; step++) {
            float[] currentStepArray = new float[bsz];
            Arrays.fill(currentStepArray, (float) step);
            OnnxTensor currentStepTensor = OnnxTensor.createTensor(env, currentStepArray);
            OnnxTensor noisyLatentTensor = Helper.createFloatTensor(xt, env);
            OnnxTensor latentMaskTensor = Helper.createFloatTensor(latentMask, env);
            OnnxTensor textMaskTensor2 = Helper.createFloatTensor(textMask, env);

            Map<String, OnnxTensor> vectorEstInputs = new HashMap<>();
            vectorEstInputs.put("noisy_latent", noisyLatentTensor);
            vectorEstInputs.put("text_emb", textEmbTensor);
            vectorEstInputs.put("style_ttl", style.ttlTensor);
            vectorEstInputs.put("latent_mask", latentMaskTensor);
            vectorEstInputs.put("text_mask", textMaskTensor2);
            vectorEstInputs.put("current_step", currentStepTensor);
            vectorEstInputs.put("total_step", totalStepTensor);

            OrtSession.Result vectorEstResult = vectorEstSession.run(vectorEstInputs);
            // getValue() copies the data out, so closing vectorEstResult below is safe.
            float[][][] denoised = (float[][][]) vectorEstResult.get(0).getValue();

            // Update latent
            xt = denoised;

            // Clean up
            currentStepTensor.close();
            noisyLatentTensor.close();
            latentMaskTensor.close();
            textMaskTensor2.close();
            vectorEstResult.close();
        }

        // Generate waveform
        OnnxTensor finalLatentTensor = Helper.createFloatTensor(xt, env);
        Map<String, OnnxTensor> vocoderInputs = new HashMap<>();
        vocoderInputs.put("latent", finalLatentTensor);

        OrtSession.Result vocoderResult = vocoderSession.run(vocoderInputs);
        float[][] wavBatch = (float[][]) vocoderResult.get(0).getValue();

        // Flatten all batch audio into a single array for batch processing
        int totalSamples = 0;
        for (float[] w : wavBatch) {
            totalSamples += w.length;
        }
        float[] wav = new float[totalSamples];
        int offset = 0;
        for (float[] w : wavBatch) {
            System.arraycopy(w, 0, wav, offset, w.length);
            offset += w.length;
        }

        // Clean up
        textIdsTensor.close();
        textMaskTensor.close();
        dpResult.close();
        textEncResult.close();
        totalStepTensor.close();
        finalLatentTensor.close();
        vocoderResult.close();

        return new TTSResult(wav, duration);
    }

    /**
     * Draws a standard-normal latent [bsz][latentDim][latentLen] sized from
     * the longest predicted duration, then zeroes positions past each item's
     * own length using the latent mask.
     *
     * Note: uses an unseeded Random, so output is nondeterministic across runs.
     */
    private NoisyLatentResult sampleNoisyLatent(float[] duration) {
        int bsz = duration.length;
        float maxDur = 0;
        for (float d : duration) {
            maxDur = Math.max(maxDur, d);
        }

        long wavLenMax = (long) (maxDur * sampleRate);
        long[] wavLengths = new long[bsz];
        for (int i = 0; i < bsz; i++) {
            wavLengths[i] = (long) (duration[i] * sampleRate);
        }

        int chunkSize = baseChunkSize * chunkCompress;
        // Ceiling division: number of latent frames covering the longest waveform.
        int latentLen = (int) ((wavLenMax + chunkSize - 1) / chunkSize);
        int latentDim = ldim * chunkCompress;

        Random rng = new Random();
        float[][][] noisyLatent = new float[bsz][latentDim][latentLen];
        for (int b = 0; b < bsz; b++) {
            for (int d = 0; d < latentDim; d++) {
                for (int t = 0; t < latentLen; t++) {
                    // Box-Muller transform
                    double u1 = Math.max(1e-10, rng.nextDouble());
                    double u2 = rng.nextDouble();
                    noisyLatent[b][d][t] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2));
                }
            }
        }

        float[][][] latentMask = Helper.getLatentMask(wavLengths, config);

        // Apply mask
        for (int b = 0; b < bsz; b++) {
            for (int d = 0; d < latentDim; d++) {
                for (int t = 0; t < latentLen; t++) {
                    noisyLatent[b][d][t] *= latentMask[b][0][t];
                }
            }
        }

        return new NoisyLatentResult(noisyLatent, latentMask);
    }

    /**
     * Synthesize speech from a single text with automatic chunking
     *
     * Long inputs are split via Helper.chunkText (Korean uses a shorter limit),
     * each chunk is synthesized separately, and the pieces are joined with
     * {@code silenceDuration} seconds of silence.
     *
     * @return concatenated waveform and its total duration (single element)
     */
    public TTSResult call(String text, String lang, Style style, int totalStep, float speed, float silenceDuration, OrtEnvironment env)
            throws OrtException {
        int maxLen = lang.equals("ko") ? 120 : 300;
        List<String> chunks = Helper.chunkText(text, maxLen);

        List<Float> wavCat = new ArrayList<>();
        float durCat = 0.0f;

        for (int i = 0; i < chunks.size(); i++) {
            TTSResult result = _infer(Arrays.asList(chunks.get(i)), Arrays.asList(lang), style, totalStep, speed, env);

            // Trim (or zero-pad) the chunk's audio to its predicted duration.
            float dur = result.duration[0];
            int wavLen = (int) (sampleRate * dur);
            float[] wavChunk = new float[wavLen];
            System.arraycopy(result.wav, 0, wavChunk, 0, Math.min(wavLen, result.wav.length));

            if (i == 0) {
                for (float val : wavChunk) {
                    wavCat.add(val);
                }
                durCat = dur;
            } else {
                // Insert silence between chunks.
                int silenceLen = (int) (silenceDuration * sampleRate);
                for (int j = 0; j < silenceLen; j++) {
                    wavCat.add(0.0f);
                }
                for (float val : wavChunk) {
                    wavCat.add(val);
                }
                durCat += silenceDuration + dur;
            }
        }

        float[] wavArray = new float[wavCat.size()];
        for (int i = 0; i < wavCat.size(); i++) {
            wavArray[i] = wavCat.get(i);
        }

        return new TTSResult(wavArray, new float[]{durCat});
    }

    /**
     * Batch synthesize speech from multiple texts
     *
     * Runs a single _infer pass over all items (no chunking).
     */
    public TTSResult batch(List<String> textList, List<String> langList, Style style, int totalStep, float speed, OrtEnvironment env)
            throws OrtException {
        return _infer(textList, langList, style, totalStep, speed, env);
    }

    /** Closes all four ONNX sessions; null sessions are skipped. */
    public void close() throws OrtException {
        if (dpSession != null) dpSession.close();
        if (textEncSession != null) textEncSession.close();
        if (vectorEstSession != null) vectorEstSession.close();
        if (vocoderSession != null) vocoderSession.close();
    }
}
|
||||
|
||||
/**
|
||||
* Style holder class
|
||||
*/
|
||||
class Style {
    OnnxTensor ttlTensor; // "style_ttl" input for text encoder and vector estimator
    OnnxTensor dpTensor;  // "style_dp" input for the duration predictor

    Style(OnnxTensor ttlTensor, OnnxTensor dpTensor) {
        this.ttlTensor = ttlTensor;
        this.dpTensor = dpTensor;
    }

    /** Releases both native tensors; a null tensor is skipped. */
    public void close() throws OrtException {
        if (ttlTensor != null) ttlTensor.close();
        if (dpTensor != null) dpTensor.close();
    }
}
|
||||
|
||||
/**
|
||||
* TTS result holder
|
||||
*/
|
||||
class TTSResult {
    float[] wav;      // synthesized audio samples (batch items concatenated)
    float[] duration; // per-item durations in seconds (single element for non-batch calls)

    TTSResult(float[] wav, float[] duration) {
        this.wav = wav;
        this.duration = duration;
    }
}
|
||||
|
||||
/**
|
||||
* Noisy latent result holder
|
||||
*/
|
||||
class NoisyLatentResult {
    float[][][] noisyLatent; // masked Gaussian noise [bsz][latentDim][latentLen]
    float[][][] latentMask;  // validity mask [bsz][1][latentLen]

    NoisyLatentResult(float[][][] noisyLatent, float[][][] latentMask) {
        this.noisyLatent = noisyLatent;
        this.latentMask = latentMask;
    }
}
|
||||
|
||||
/**
|
||||
* Helper utility class
|
||||
*/
|
||||
public class Helper {
|
||||
|
||||
    // Fallback chunk length (in characters) used by chunkText when the caller
    // does not supply a positive maximum.
    private static final int MAX_CHUNK_LENGTH = 300;
    // Abbreviations ending in '.' that must NOT be treated as sentence
    // boundaries by splitSentences.
    private static final String[] ABBREVIATIONS = {
            "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
            "St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
            "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D."
    };
|
||||
|
||||
/**
|
||||
* Chunk text into smaller segments based on paragraphs and sentences
|
||||
*/
|
||||
public static List<String> chunkText(String text, int maxLen) {
|
||||
if (maxLen == 0) {
|
||||
maxLen = MAX_CHUNK_LENGTH;
|
||||
}
|
||||
|
||||
text = text.trim();
|
||||
if (text.isEmpty()) {
|
||||
return Arrays.asList("");
|
||||
}
|
||||
|
||||
// Split by paragraphs
|
||||
String[] paragraphs = text.split("\\n\\s*\\n");
|
||||
List<String> chunks = new ArrayList<>();
|
||||
|
||||
for (String para : paragraphs) {
|
||||
para = para.trim();
|
||||
if (para.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (para.length() <= maxLen) {
|
||||
chunks.add(para);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Split by sentences
|
||||
List<String> sentences = splitSentences(para);
|
||||
StringBuilder current = new StringBuilder();
|
||||
int currentLen = 0;
|
||||
|
||||
for (String sentence : sentences) {
|
||||
sentence = sentence.trim();
|
||||
if (sentence.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int sentenceLen = sentence.length();
|
||||
if (sentenceLen > maxLen) {
|
||||
// If sentence is longer than maxLen, split by comma or space
|
||||
if (current.length() > 0) {
|
||||
chunks.add(current.toString().trim());
|
||||
current.setLength(0);
|
||||
currentLen = 0;
|
||||
}
|
||||
|
||||
// Try splitting by comma
|
||||
String[] parts = sentence.split(",");
|
||||
for (String part : parts) {
|
||||
part = part.trim();
|
||||
if (part.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int partLen = part.length();
|
||||
if (partLen > maxLen) {
|
||||
// Split by space as last resort
|
||||
String[] words = part.split("\\s+");
|
||||
StringBuilder wordChunk = new StringBuilder();
|
||||
int wordChunkLen = 0;
|
||||
|
||||
for (String word : words) {
|
||||
int wordLen = word.length();
|
||||
if (wordChunkLen + wordLen + 1 > maxLen && wordChunk.length() > 0) {
|
||||
chunks.add(wordChunk.toString().trim());
|
||||
wordChunk.setLength(0);
|
||||
wordChunkLen = 0;
|
||||
}
|
||||
|
||||
if (wordChunk.length() > 0) {
|
||||
wordChunk.append(" ");
|
||||
wordChunkLen++;
|
||||
}
|
||||
wordChunk.append(word);
|
||||
wordChunkLen += wordLen;
|
||||
}
|
||||
|
||||
if (wordChunk.length() > 0) {
|
||||
chunks.add(wordChunk.toString().trim());
|
||||
}
|
||||
} else {
|
||||
if (currentLen + partLen + 1 > maxLen && current.length() > 0) {
|
||||
chunks.add(current.toString().trim());
|
||||
current.setLength(0);
|
||||
currentLen = 0;
|
||||
}
|
||||
|
||||
if (current.length() > 0) {
|
||||
current.append(", ");
|
||||
currentLen += 2;
|
||||
}
|
||||
current.append(part);
|
||||
currentLen += partLen;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (currentLen + sentenceLen + 1 > maxLen && current.length() > 0) {
|
||||
chunks.add(current.toString().trim());
|
||||
current.setLength(0);
|
||||
currentLen = 0;
|
||||
}
|
||||
|
||||
if (current.length() > 0) {
|
||||
current.append(" ");
|
||||
currentLen++;
|
||||
}
|
||||
current.append(sentence);
|
||||
currentLen += sentenceLen;
|
||||
}
|
||||
|
||||
if (current.length() > 0) {
|
||||
chunks.add(current.toString().trim());
|
||||
}
|
||||
}
|
||||
|
||||
if (chunks.isEmpty()) {
|
||||
return Arrays.asList("");
|
||||
}
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
/**
|
||||
* Split text into sentences, avoiding common abbreviations
|
||||
*/
|
||||
private static List<String> splitSentences(String text) {
|
||||
// Build pattern that avoids abbreviations
|
||||
StringBuilder abbrevPattern = new StringBuilder();
|
||||
for (int i = 0; i < ABBREVIATIONS.length; i++) {
|
||||
if (i > 0) abbrevPattern.append("|");
|
||||
abbrevPattern.append(Pattern.quote(ABBREVIATIONS[i]));
|
||||
}
|
||||
|
||||
// Match sentence endings, but not abbreviations
|
||||
String patternStr = "(?<!(?:" + abbrevPattern.toString() + "))(?<=[.!?])\\s+";
|
||||
Pattern pattern = Pattern.compile(patternStr);
|
||||
return Arrays.asList(pattern.split(text));
|
||||
}
|
||||
|
||||
    /**
     * Load voice style from JSON files
     *
     * Reads each style JSON, flattens its "style_ttl" and "style_dp" nested
     * arrays into contiguous buffers, and builds two batched ONNX tensors of
     * shape [bsz, dim1, dim2].
     *
     * NOTE(review): the tensor dims are taken from the FIRST file only and
     * assumed identical across all files; a file with different dims would
     * overflow or underfill its slot — confirm style files are uniform.
     *
     * @param voiceStylePaths one JSON file per batch item
     * @param verbose         print a summary line when true
     * @throws IOException  if a JSON file cannot be read or parsed
     * @throws OrtException if tensor creation fails
     */
    public static Style loadVoiceStyle(List<String> voiceStylePaths, boolean verbose, OrtEnvironment env)
            throws IOException, OrtException {
        int bsz = voiceStylePaths.size();

        // Read first file to get dimensions
        ObjectMapper mapper = new ObjectMapper();
        JsonNode firstRoot = mapper.readTree(new File(voiceStylePaths.get(0)));

        long[] ttlDims = new long[3];
        for (int i = 0; i < 3; i++) {
            ttlDims[i] = firstRoot.get("style_ttl").get("dims").get(i).asLong();
        }
        long[] dpDims = new long[3];
        for (int i = 0; i < 3; i++) {
            dpDims[i] = firstRoot.get("style_dp").get("dims").get(i).asLong();
        }

        // dims[0] is each file's own batch axis; only the trailing two
        // dimensions are kept per item.
        long ttlDim1 = ttlDims[1];
        long ttlDim2 = ttlDims[2];
        long dpDim1 = dpDims[1];
        long dpDim2 = dpDims[2];

        // Pre-allocate arrays with full batch size
        int ttlSize = (int) (bsz * ttlDim1 * ttlDim2);
        int dpSize = (int) (bsz * dpDim1 * dpDim2);
        float[] ttlFlat = new float[ttlSize];
        float[] dpFlat = new float[dpSize];

        // Fill in the data
        for (int i = 0; i < bsz; i++) {
            JsonNode root = mapper.readTree(new File(voiceStylePaths.get(i)));

            // Flatten TTL data (row-major over the nested [batch][row][col] arrays)
            int ttlOffset = (int) (i * ttlDim1 * ttlDim2);
            int idx = 0;
            JsonNode ttlData = root.get("style_ttl").get("data");
            for (JsonNode batch : ttlData) {
                for (JsonNode row : batch) {
                    for (JsonNode val : row) {
                        ttlFlat[ttlOffset + idx++] = (float) val.asDouble();
                    }
                }
            }

            // Flatten DP data
            int dpOffset = (int) (i * dpDim1 * dpDim2);
            idx = 0;
            JsonNode dpData = root.get("style_dp").get("data");
            for (JsonNode batch : dpData) {
                for (JsonNode row : batch) {
                    for (JsonNode val : row) {
                        dpFlat[dpOffset + idx++] = (float) val.asDouble();
                    }
                }
            }
        }

        long[] ttlShape = {bsz, ttlDim1, ttlDim2};
        long[] dpShape = {bsz, dpDim1, dpDim2};

        OnnxTensor ttlTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(ttlFlat), ttlShape);
        OnnxTensor dpTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(dpFlat), dpShape);

        if (verbose) {
            System.out.println("Loaded " + bsz + " voice styles\n");
        }

        return new Style(ttlTensor, dpTensor);
    }
|
||||
|
||||
/**
|
||||
* Load TTS components
|
||||
*/
|
||||
public static TextToSpeech loadTextToSpeech(String onnxDir, boolean useGpu, OrtEnvironment env)
|
||||
throws IOException, OrtException {
|
||||
if (useGpu) {
|
||||
throw new RuntimeException("GPU mode is not supported yet");
|
||||
}
|
||||
System.out.println("Using CPU for inference\n");
|
||||
|
||||
// Load config
|
||||
Config config = loadCfgs(onnxDir);
|
||||
|
||||
// Create session options
|
||||
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
||||
|
||||
// Load models
|
||||
OrtSession dpSession = env.createSession(onnxDir + "/duration_predictor.onnx", opts);
|
||||
OrtSession textEncSession = env.createSession(onnxDir + "/text_encoder.onnx", opts);
|
||||
OrtSession vectorEstSession = env.createSession(onnxDir + "/vector_estimator.onnx", opts);
|
||||
OrtSession vocoderSession = env.createSession(onnxDir + "/vocoder.onnx", opts);
|
||||
|
||||
// Load text processor
|
||||
UnicodeProcessor textProcessor = new UnicodeProcessor(onnxDir + "/unicode_indexer.json");
|
||||
|
||||
return new TextToSpeech(config, textProcessor, dpSession, textEncSession, vectorEstSession, vocoderSession);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load configuration from JSON
|
||||
*/
|
||||
public static Config loadCfgs(String onnxDir) throws IOException {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
JsonNode root = mapper.readTree(new File(onnxDir + "/tts.json"));
|
||||
|
||||
Config config = new Config();
|
||||
config.ae = new Config.AEConfig();
|
||||
config.ae.sampleRate = root.get("ae").get("sample_rate").asInt();
|
||||
config.ae.baseChunkSize = root.get("ae").get("base_chunk_size").asInt();
|
||||
|
||||
config.ttl = new Config.TTLConfig();
|
||||
config.ttl.chunkCompressFactor = root.get("ttl").get("chunk_compress_factor").asInt();
|
||||
config.ttl.latentDim = root.get("ttl").get("latent_dim").asInt();
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get latent mask from wav lengths
|
||||
*/
|
||||
public static float[][][] getLatentMask(long[] wavLengths, Config config) {
|
||||
long baseChunkSize = config.ae.baseChunkSize;
|
||||
long chunkCompressFactor = config.ttl.chunkCompressFactor;
|
||||
long latentSize = baseChunkSize * chunkCompressFactor;
|
||||
|
||||
long[] latentLengths = new long[wavLengths.length];
|
||||
long maxLen = 0;
|
||||
for (int i = 0; i < wavLengths.length; i++) {
|
||||
latentLengths[i] = (wavLengths[i] + latentSize - 1) / latentSize;
|
||||
maxLen = Math.max(maxLen, latentLengths[i]);
|
||||
}
|
||||
|
||||
float[][][] mask = new float[wavLengths.length][1][(int) maxLen];
|
||||
for (int i = 0; i < wavLengths.length; i++) {
|
||||
for (int j = 0; j < maxLen; j++) {
|
||||
mask[i][0][j] = j < latentLengths[i] ? 1.0f : 0.0f;
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
/**
|
||||
* Write WAV file
|
||||
*/
|
||||
public static void writeWavFile(String filename, float[] audioData, int sampleRate) throws IOException {
|
||||
// Convert float to byte array
|
||||
byte[] bytes = new byte[audioData.length * 2];
|
||||
ByteBuffer buffer = ByteBuffer.wrap(bytes);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
for (float sample : audioData) {
|
||||
short val = (short) Math.max(-32768, Math.min(32767, sample * 32767));
|
||||
buffer.putShort(val);
|
||||
}
|
||||
|
||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
||||
AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false);
|
||||
AudioInputStream ais = new AudioInputStream(bais, format, audioData.length);
|
||||
AudioSystem.write(ais, AudioFileFormat.Type.WAVE, new File(filename));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize filename (supports Unicode characters)
|
||||
*/
|
||||
public static String sanitizeFilename(String text, int maxLen) {
|
||||
// Get first maxLen characters (code points, not chars for surrogate pairs)
|
||||
int[] codePoints = text.codePoints().limit(maxLen).toArray();
|
||||
StringBuilder result = new StringBuilder();
|
||||
for (int codePoint : codePoints) {
|
||||
if (Character.isLetterOrDigit(codePoint)) {
|
||||
result.appendCodePoint(codePoint);
|
||||
} else {
|
||||
result.append('_');
|
||||
}
|
||||
}
|
||||
return result.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Timer utility
|
||||
*/
|
||||
public static <T> T timer(String name, java.util.function.Supplier<T> fn) {
|
||||
long start = System.currentTimeMillis();
|
||||
System.out.println(name + "...");
|
||||
T result = fn.get();
|
||||
long elapsed = System.currentTimeMillis() - start;
|
||||
System.out.printf(" -> %s completed in %.2f sec\n", name, elapsed / 1000.0);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create float tensor from 3D array
|
||||
*/
|
||||
public static OnnxTensor createFloatTensor(float[][][] array, OrtEnvironment env) throws OrtException {
|
||||
int dim0 = array.length;
|
||||
int dim1 = array[0].length;
|
||||
int dim2 = array[0][0].length;
|
||||
|
||||
float[] flat = new float[dim0 * dim1 * dim2];
|
||||
int idx = 0;
|
||||
for (int i = 0; i < dim0; i++) {
|
||||
for (int j = 0; j < dim1; j++) {
|
||||
for (int k = 0; k < dim2; k++) {
|
||||
flat[idx++] = array[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
long[] shape = {dim0, dim1, dim2};
|
||||
return OnnxTensor.createTensor(env, FloatBuffer.wrap(flat), shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create long tensor from 2D array
|
||||
*/
|
||||
public static OnnxTensor createLongTensor(long[][] array, OrtEnvironment env) throws OrtException {
|
||||
int dim0 = array.length;
|
||||
int dim1 = array[0].length;
|
||||
|
||||
long[] flat = new long[dim0 * dim1];
|
||||
int idx = 0;
|
||||
for (int i = 0; i < dim0; i++) {
|
||||
for (int j = 0; j < dim1; j++) {
|
||||
flat[idx++] = array[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
long[] shape = {dim0, dim1};
|
||||
return OnnxTensor.createTensor(env, LongBuffer.wrap(flat), shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load JSON long array
|
||||
*/
|
||||
public static long[] loadJsonLongArray(String filePath) throws IOException {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
JsonNode root = mapper.readTree(new File(filePath));
|
||||
|
||||
long[] result = new long[root.size()];
|
||||
for (int i = 0; i < root.size(); i++) {
|
||||
result[i] = root.get(i).asLong();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
130
java/README.md
Normal file
130
java/README.md
Normal file
@@ -0,0 +1,130 @@
|
||||
# TTS ONNX Inference Examples
|
||||
|
||||
This guide provides examples for running TTS inference using `ExampleONNX.java`.
|
||||
|
||||
## 📰 Update News
|
||||
|
||||
**2026.01.06** - 🎉 **Supertonic 2** released with multilingual support! Now supports English (`en`), Korean (`ko`), Spanish (`es`), Portuguese (`pt`), and French (`fr`). [Demo](https://huggingface.co/spaces/Supertone/supertonic-2) | [Models](https://huggingface.co/Supertone/supertonic-2)
|
||||
|
||||
**2025.12.10** - Added [6 new voice styles](https://huggingface.co/Supertone/supertonic/tree/b10dbaf18b316159be75b34d24f740008fddd381) (M3, M4, M5, F3, F4, F5). See [Voices](https://supertone-inc.github.io/supertonic-py/voices/) for details
|
||||
|
||||
**2025.12.08** - Optimized ONNX models via [OnnxSlim](https://github.com/inisis/OnnxSlim) now available on [Hugging Face Models](https://huggingface.co/Supertone/supertonic)
|
||||
|
||||
**2025.11.23** - Enhanced text preprocessing with comprehensive normalization, emoji removal, symbol replacement, and punctuation handling for improved synthesis quality.
|
||||
|
||||
**2025.11.19** - Added `--speed` parameter to control speech synthesis speed (default: 1.05, recommended range: 0.9-1.5).
|
||||
|
||||
**2025.11.19** - Added automatic text chunking for long-form inference. Long texts are split into chunks and synthesized with natural pauses.
|
||||
|
||||
## Installation
|
||||
|
||||
This project uses [Maven](https://maven.apache.org/) for dependency management.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Java 11 or higher
|
||||
- Maven 3.6 or higher
|
||||
|
||||
### Install dependencies
|
||||
|
||||
```bash
|
||||
mvn clean install
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Example 1: Default Inference
|
||||
Run inference with default settings:
|
||||
```bash
|
||||
mvn exec:java
|
||||
```
|
||||
|
||||
This will use:
|
||||
- Voice style: `assets/voice_styles/M1.json`
|
||||
- Text: "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen."
|
||||
- Output directory: `results/`
|
||||
- Total steps: 5
|
||||
- Number of generations: 4
|
||||
|
||||
### Example 2: Batch Inference
|
||||
Process multiple voice styles and texts at once:
|
||||
```bash
|
||||
mvn exec:java -Dexec.args="--batch --voice-style assets/voice_styles/M1.json,assets/voice_styles/F1.json --text 'The sun sets behind the mountains, painting the sky in shades of pink and orange.|오늘 아침에 공원을 산책했는데, 새소리와 바람 소리가 너무 기분 좋았어요.' --lang en,ko"
|
||||
```
|
||||
|
||||
This will:
|
||||
- Generate speech for 2 different voice-text-language pairs
|
||||
- Use male voice (M1.json) for the first text in English
|
||||
- Use female voice (F1.json) for the second text in Korean
|
||||
- Process both samples in a single batch
|
||||
|
||||
### Example 3: High Quality Inference
|
||||
Increase denoising steps for better quality:
|
||||
```bash
|
||||
mvn exec:java -Dexec.args="--total-step 10 --voice-style assets/voice_styles/M1.json --text 'Increasing the number of denoising steps improves the output fidelity and overall quality.'"
|
||||
```
|
||||
|
||||
This will:
|
||||
- Use 10 denoising steps instead of the default 5
|
||||
- Produce higher quality output at the cost of slower inference
|
||||
|
||||
### Example 4: Long-Form Inference
|
||||
The system automatically chunks long texts into manageable segments, synthesizes each segment separately, and concatenates them with natural pauses (0.3 seconds by default) into a single audio file. This happens by default when you don't use the `--batch` flag:
|
||||
|
||||
```bash
|
||||
mvn exec:java -Dexec.args="--voice-style assets/voice_styles/M1.json --text 'This is a very long text that will be automatically split into multiple chunks. The system will process each chunk separately and then concatenate them together with natural pauses between segments. This ensures that even very long texts can be processed efficiently while maintaining natural speech flow and avoiding memory issues.'"
|
||||
```
|
||||
|
||||
This will:
|
||||
- Automatically split the text into chunks based on paragraph and sentence boundaries
|
||||
- Synthesize each chunk separately
|
||||
- Add 0.3 seconds of silence between chunks for natural pauses
|
||||
- Concatenate all chunks into a single audio file
|
||||
|
||||
**Note**: Automatic text chunking is disabled when using `--batch` mode. In batch mode, each text is processed as-is without chunking.
|
||||
|
||||
**Tip**: If your text contains apostrophes, use escaping or run the JAR directly:
|
||||
```bash
|
||||
java -jar target/tts-example.jar --total-step 10 --text "Text with apostrophe's here"
|
||||
```
|
||||
|
||||
## Building a Fat JAR
|
||||
|
||||
To create a standalone JAR with all dependencies:
|
||||
```bash
|
||||
mvn clean package
|
||||
```
|
||||
|
||||
Then run it directly:
|
||||
```bash
|
||||
java -jar target/tts-example.jar
|
||||
```
|
||||
|
||||
Or with arguments:
|
||||
```bash
|
||||
java -jar target/tts-example.jar --total-step 10 --text "Your custom text here"
|
||||
```
|
||||
|
||||
## Available Arguments
|
||||
|
||||
| Argument | Type | Default | Description |
|
||||
|----------|------|---------|-------------|
|
||||
| `--use-gpu` | flag | False | Use GPU for inference (default: CPU) |
|
||||
| `--onnx-dir` | str | `assets/onnx` | Path to ONNX model directory |
|
||||
| `--total-step` | int | 5 | Number of denoising steps (higher = better quality, slower) |
|
||||
| `--n-test` | int | 4 | Number of times to generate each sample |
|
||||
| `--voice-style` | str+ | `assets/voice_styles/M1.json` | Voice style file path(s), comma-separated |
|
||||
| `--text` | str+ | (long default text) | Text(s) to synthesize, pipe-separated |
|
||||
| `--lang` | str+ | `en` | Language(s) for synthesis, comma-separated (en, ko, es, pt, fr) |
|
||||
| `--save-dir` | str | `results` | Output directory |
|
||||
| `--batch` | flag | False | Enable batch mode (multiple text-style pairs, disables automatic chunking) |
|
||||
|
||||
## Notes
|
||||
|
||||
- **Multilingual Support**: Use `--lang` to specify the language for each text. Available: `en` (English), `ko` (Korean), `es` (Spanish), `pt` (Portuguese), `fr` (French)
|
||||
- **Batch Processing**: When using `--batch`, the number of `--voice-style`, `--text`, and `--lang` entries must match
|
||||
- **Automatic Chunking**: Without `--batch`, long texts are automatically split and concatenated with 0.3s pauses
|
||||
- **Quality vs Speed**: Higher `--total-step` values produce better quality but take longer
|
||||
- **GPU Support**: GPU mode is not supported yet
|
||||
- **Voice Styles**: Uses pre-extracted voice style JSON files for fast inference
|
||||
|
||||
110
java/pom.xml
Normal file
110
java/pom.xml
Normal file
@@ -0,0 +1,110 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
|
||||
http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>ai.supertonic</groupId>
|
||||
<artifactId>tts-onnx-java</artifactId>
|
||||
<version>1.0.0</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>TTS ONNX Java Example</name>
|
||||
<description>Text-to-Speech inference using ONNX Runtime in Java</description>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>11</maven.compiler.source>
|
||||
<maven.compiler.target>11</maven.compiler.target>
|
||||
<onnxruntime.version>1.23.1</onnxruntime.version>
|
||||
<jackson.version>2.15.2</jackson.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<!-- ONNX Runtime -->
|
||||
<dependency>
|
||||
<groupId>com.microsoft.onnxruntime</groupId>
|
||||
<artifactId>onnxruntime</artifactId>
|
||||
<version>${onnxruntime.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Jackson for JSON parsing -->
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- JTransforms for Fast FFT -->
|
||||
<dependency>
|
||||
<groupId>com.github.wendykierp</groupId>
|
||||
<artifactId>JTransforms</artifactId>
|
||||
<version>3.1</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<sourceDirectory>.</sourceDirectory>
|
||||
<plugins>
|
||||
<!-- Maven Compiler Plugin -->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<version>3.11.0</version>
|
||||
<configuration>
|
||||
<source>11</source>
|
||||
<target>11</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<!-- Maven Exec Plugin for running the example -->
|
||||
<plugin>
|
||||
<groupId>org.codehaus.mojo</groupId>
|
||||
<artifactId>exec-maven-plugin</artifactId>
|
||||
<version>3.1.0</version>
|
||||
<configuration>
|
||||
<mainClass>ExampleONNX</mainClass>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<!-- Maven Jar Plugin -->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<version>3.3.0</version>
|
||||
<configuration>
|
||||
<archive>
|
||||
<manifest>
|
||||
<mainClass>ExampleONNX</mainClass>
|
||||
</manifest>
|
||||
</archive>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<!-- Maven Shade Plugin for creating fat JAR -->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>3.5.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
<mainClass>ExampleONNX</mainClass>
|
||||
</transformer>
|
||||
</transformers>
|
||||
<finalName>tts-example</finalName>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
||||
|
||||
Reference in New Issue
Block a user