194 lines
5.3 KiB
Go
194 lines
5.3 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
)
|
|
|
|
// Args collects every option accepted on the command line. Each field is
// filled in by parseArgs from the flag named in its comment.
type Args struct {
	// useGPU is set by -use-gpu.
	useGPU bool
	// onnxDir is set by -onnx-dir: path to the ONNX model directory.
	onnxDir string
	// totalStep is set by -total-step: number of denoising steps.
	totalStep int
	// speed is set by -speed: speech speed factor (higher = faster).
	speed float64
	// nTest is set by -n-test: number of times to generate.
	nTest int
	// voiceStyle holds the comma-separated -voice-style entries, trimmed.
	voiceStyle []string
	// text holds the pipe-separated -text entries, trimmed.
	text []string
	// lang holds the comma-separated -lang entries, trimmed.
	lang []string
	// saveDir is set by -save-dir: output directory.
	saveDir string
	// batch is set by -batch: enables batch mode.
	batch bool
}
|
|
|
|
func parseArgs() *Args {
|
|
args := &Args{}
|
|
|
|
flag.BoolVar(&args.useGPU, "use-gpu", false, "Use GPU for inference (default: CPU)")
|
|
flag.StringVar(&args.onnxDir, "onnx-dir", "assets/onnx", "Path to ONNX model directory")
|
|
flag.IntVar(&args.totalStep, "total-step", 5, "Number of denoising steps")
|
|
flag.Float64Var(&args.speed, "speed", 1.05, "Speech speed factor (higher = faster)")
|
|
flag.IntVar(&args.nTest, "n-test", 4, "Number of times to generate")
|
|
flag.StringVar(&args.saveDir, "save-dir", "results", "Output directory")
|
|
flag.BoolVar(&args.batch, "batch", false, "Enable batch mode (multiple text-style pairs)")
|
|
|
|
var voiceStyleStr, textStr, langStr string
|
|
flag.StringVar(&voiceStyleStr, "voice-style", "assets/voice_styles/M1.json", "Voice style file path(s), comma-separated")
|
|
flag.StringVar(&textStr, "text", "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen.", "Text(s) to synthesize, pipe-separated")
|
|
flag.StringVar(&langStr, "lang", "en", "Language(s) for synthesis, comma-separated (en, ko, es, pt, fr)")
|
|
|
|
flag.Parse()
|
|
|
|
// Parse comma-separated voice-style
|
|
if voiceStyleStr != "" {
|
|
args.voiceStyle = strings.Split(voiceStyleStr, ",")
|
|
for i := range args.voiceStyle {
|
|
args.voiceStyle[i] = strings.TrimSpace(args.voiceStyle[i])
|
|
}
|
|
}
|
|
|
|
// Parse pipe-separated text
|
|
if textStr != "" {
|
|
args.text = strings.Split(textStr, "|")
|
|
for i := range args.text {
|
|
args.text[i] = strings.TrimSpace(args.text[i])
|
|
}
|
|
}
|
|
|
|
// Parse comma-separated lang
|
|
if langStr != "" {
|
|
args.lang = strings.Split(langStr, ",")
|
|
for i := range args.lang {
|
|
args.lang[i] = strings.TrimSpace(args.lang[i])
|
|
}
|
|
}
|
|
|
|
return args
|
|
}
|
|
|
|
func main() {
|
|
fmt.Println("=== TTS Inference with ONNX Runtime (Go) ===\n")
|
|
|
|
// --- 1. Parse arguments --- //
|
|
args := parseArgs()
|
|
totalStep := args.totalStep
|
|
speed := float32(args.speed)
|
|
nTest := args.nTest
|
|
saveDir := args.saveDir
|
|
voiceStylePaths := args.voiceStyle
|
|
textList := args.text
|
|
langList := args.lang
|
|
batch := args.batch
|
|
|
|
if batch {
|
|
if len(voiceStylePaths) != len(textList) {
|
|
fmt.Printf("Error: Number of voice styles (%d) must match number of texts (%d)\n",
|
|
len(voiceStylePaths), len(textList))
|
|
os.Exit(1)
|
|
}
|
|
if len(langList) != len(textList) {
|
|
fmt.Printf("Error: Number of languages (%d) must match number of texts (%d)\n",
|
|
len(langList), len(textList))
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
bsz := len(voiceStylePaths)
|
|
|
|
// Initialize ONNX Runtime
|
|
if err := InitializeONNXRuntime(); err != nil {
|
|
fmt.Printf("Error initializing ONNX Runtime: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer ort.DestroyEnvironment()
|
|
|
|
// --- 2. Load config --- //
|
|
cfg, err := LoadCfgs(args.onnxDir)
|
|
if err != nil {
|
|
fmt.Printf("Error loading config: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
// --- 3. Load TTS components --- //
|
|
textToSpeech, err := LoadTextToSpeech(args.onnxDir, args.useGPU, cfg)
|
|
if err != nil {
|
|
fmt.Printf("Error loading TTS components: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer textToSpeech.Destroy()
|
|
|
|
// --- 4. Load voice styles --- //
|
|
style, err := LoadVoiceStyle(voiceStylePaths, true)
|
|
if err != nil {
|
|
fmt.Printf("Error loading voice styles: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer style.Destroy()
|
|
|
|
// --- 5. Synthesize speech --- //
|
|
if err := os.MkdirAll(saveDir, 0755); err != nil {
|
|
fmt.Printf("Error creating save directory: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
for n := 0; n < nTest; n++ {
|
|
fmt.Printf("\n[%d/%d] Starting synthesis...\n", n+1, nTest)
|
|
|
|
var wav []float32
|
|
var duration []float32
|
|
|
|
if batch {
|
|
Timer("Generating speech from text", func() interface{} {
|
|
w, d, err := textToSpeech.Batch(textList, langList, style, totalStep, speed)
|
|
if err != nil {
|
|
fmt.Printf("Error generating speech: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
wav = w
|
|
duration = d
|
|
return nil
|
|
})
|
|
} else {
|
|
Timer("Generating speech from text", func() interface{} {
|
|
w, d, err := textToSpeech.Call(textList[0], langList[0], style, totalStep, speed, 0.3)
|
|
if err != nil {
|
|
fmt.Printf("Error generating speech: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
wav = w
|
|
duration = []float32{d}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// Save outputs
|
|
for i := 0; i < bsz; i++ {
|
|
fname := fmt.Sprintf("%s_%d.wav", sanitizeFilename(textList[i], 20), n+1)
|
|
var wavOut []float64
|
|
|
|
if batch {
|
|
wavOut = extractWavSegment(wav, duration[i], textToSpeech.SampleRate, i, bsz)
|
|
} else {
|
|
// For non-batch mode, wav is a single concatenated audio
|
|
wavLen := int(float32(textToSpeech.SampleRate) * duration[0])
|
|
wavOut = make([]float64, wavLen)
|
|
for j := 0; j < wavLen && j < len(wav); j++ {
|
|
wavOut[j] = float64(wav[j])
|
|
}
|
|
}
|
|
|
|
outputPath := filepath.Join(saveDir, fname)
|
|
if err := writeWavFile(outputPath, wavOut, textToSpeech.SampleRate); err != nil {
|
|
fmt.Printf("Error writing wav file: %v\n", err)
|
|
continue
|
|
}
|
|
fmt.Printf("Saved: %s\n", outputPath)
|
|
}
|
|
}
|
|
|
|
fmt.Println("\n=== Synthesis completed successfully! ===")
|
|
}
|