Files
Supertonic/go/example_onnx.go
2026-01-25 18:58:40 +09:00

194 lines
5.3 KiB
Go

package main
import (
"flag"
"fmt"
"os"
"path/filepath"
"strings"
ort "github.com/yalue/onnxruntime_go"
)
// Args collects every option accepted on the command line.
//
// The scalar fields each correspond to a single flag. The three slice fields
// are populated by parseArgs from delimiter-separated flag values.
type Args struct {
	// Inference setup.
	useGPU  bool   // -use-gpu
	onnxDir string // -onnx-dir

	// Generation parameters.
	totalStep int     // -total-step
	speed     float64 // -speed
	nTest     int     // -n-test

	// Per-utterance inputs (-voice-style, -text, -lang).
	voiceStyle, text, lang []string

	// Output handling.
	saveDir string // -save-dir
	batch   bool   // -batch
}
// parseArgs registers the program's flags on the default flag set, parses the
// command line, and returns the resulting Args.
//
// The list-valued flags arrive as single strings: -voice-style and -lang are
// split on commas and -text on pipes, with surrounding whitespace trimmed from
// every piece. A list flag set to the empty string leaves its slice nil.
func parseArgs() *Args {
	var (
		parsed                        Args
		rawStyles, rawTexts, rawLangs string
	)

	flag.BoolVar(&parsed.useGPU, "use-gpu", false, "Use GPU for inference (default: CPU)")
	flag.StringVar(&parsed.onnxDir, "onnx-dir", "assets/onnx", "Path to ONNX model directory")
	flag.IntVar(&parsed.totalStep, "total-step", 5, "Number of denoising steps")
	flag.Float64Var(&parsed.speed, "speed", 1.05, "Speech speed factor (higher = faster)")
	flag.IntVar(&parsed.nTest, "n-test", 4, "Number of times to generate")
	flag.StringVar(&parsed.saveDir, "save-dir", "results", "Output directory")
	flag.BoolVar(&parsed.batch, "batch", false, "Enable batch mode (multiple text-style pairs)")
	flag.StringVar(&rawStyles, "voice-style", "assets/voice_styles/M1.json", "Voice style file path(s), comma-separated")
	flag.StringVar(&rawTexts, "text", "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen.", "Text(s) to synthesize, pipe-separated")
	flag.StringVar(&rawLangs, "lang", "en", "Language(s) for synthesis, comma-separated (en, ko, es, pt, fr)")
	flag.Parse()

	// explode turns one delimited flag value into its trimmed pieces; an
	// empty value yields nil so the corresponding field stays unset.
	explode := func(raw, sep string) []string {
		if raw == "" {
			return nil
		}
		pieces := strings.Split(raw, sep)
		for idx, piece := range pieces {
			pieces[idx] = strings.TrimSpace(piece)
		}
		return pieces
	}

	parsed.voiceStyle = explode(rawStyles, ",")
	parsed.text = explode(rawTexts, "|")
	parsed.lang = explode(rawLangs, ",")

	return &parsed
}
// main is the program entry point. It runs the example and exits with the
// status code runExample returns. os.Exit is called only here, after
// runExample has returned, so that runExample's deferred cleanup has already
// run by the time the process terminates.
func main() {
	os.Exit(runExample())
}

// runExample performs the whole example run: it parses the command line,
// initializes ONNX Runtime, loads the config, the TTS components and the voice
// styles, synthesizes speech nTest times, and writes the resulting .wav files
// into the save directory.
//
// It returns the process exit status: 0 on success, 1 after a fatal error
// (the error has already been printed). Failing to write an individual .wav
// file is reported but not fatal. Errors are signalled by returning rather
// than by calling os.Exit, because os.Exit would skip the deferred
// DestroyEnvironment/Destroy calls.
func runExample() int {
	// Print (not Println) with an explicit blank line: same output, and it
	// avoids the vet "redundant newline" diagnostic.
	fmt.Print("=== TTS Inference with ONNX Runtime (Go) ===\n\n")

	// --- 1. Parse arguments --- //
	args := parseArgs()
	totalStep := args.totalStep
	speed := float32(args.speed)
	nTest := args.nTest
	saveDir := args.saveDir
	voiceStylePaths := args.voiceStyle
	textList := args.text
	langList := args.lang
	batch := args.batch

	if batch {
		if len(voiceStylePaths) != len(textList) {
			fmt.Printf("Error: Number of voice styles (%d) must match number of texts (%d)\n",
				len(voiceStylePaths), len(textList))
			return 1
		}
		if len(langList) != len(textList) {
			fmt.Printf("Error: Number of languages (%d) must match number of texts (%d)\n",
				len(langList), len(textList))
			return 1
		}
	} else if len(textList) == 0 || len(langList) == 0 {
		// Single mode reads textList[0] and langList[0]; an empty -text or
		// -lang value would otherwise cause an index-out-of-range panic.
		fmt.Println("Error: At least one text and one language are required")
		return 1
	}

	bsz := len(voiceStylePaths)

	// Number of files written per run. Batch mode produces one waveform per
	// voice style. Single mode synthesizes only textList[0], so at most one
	// file is written regardless of how many styles were listed.
	numOut := bsz
	if !batch && numOut > 1 {
		numOut = 1
	}

	// Initialize ONNX Runtime
	if err := InitializeONNXRuntime(); err != nil {
		fmt.Printf("Error initializing ONNX Runtime: %v\n", err)
		return 1
	}
	defer ort.DestroyEnvironment()

	// --- 2. Load config --- //
	cfg, err := LoadCfgs(args.onnxDir)
	if err != nil {
		fmt.Printf("Error loading config: %v\n", err)
		return 1
	}

	// --- 3. Load TTS components --- //
	textToSpeech, err := LoadTextToSpeech(args.onnxDir, args.useGPU, cfg)
	if err != nil {
		fmt.Printf("Error loading TTS components: %v\n", err)
		return 1
	}
	defer textToSpeech.Destroy()

	// --- 4. Load voice styles --- //
	style, err := LoadVoiceStyle(voiceStylePaths, true)
	if err != nil {
		fmt.Printf("Error loading voice styles: %v\n", err)
		return 1
	}
	defer style.Destroy()

	// --- 5. Synthesize speech --- //
	if err := os.MkdirAll(saveDir, 0755); err != nil {
		fmt.Printf("Error creating save directory: %v\n", err)
		return 1
	}

	for n := 0; n < nTest; n++ {
		fmt.Printf("\n[%d/%d] Starting synthesis...\n", n+1, nTest)

		var (
			wav      []float32
			duration []float32
			synthErr error
		)

		// The closures hand their results back through the variables above.
		// Timer is defined elsewhere; this code relies on it having invoked
		// the closure by the time it returns.
		if batch {
			Timer("Generating speech from text", func() interface{} {
				w, d, err := textToSpeech.Batch(textList, langList, style, totalStep, speed)
				if err != nil {
					synthErr = err
					return nil
				}
				wav = w
				duration = d
				return nil
			})
		} else {
			Timer("Generating speech from text", func() interface{} {
				w, d, err := textToSpeech.Call(textList[0], langList[0], style, totalStep, speed, 0.3)
				if err != nil {
					synthErr = err
					return nil
				}
				wav = w
				duration = []float32{d}
				return nil
			})
		}
		if synthErr != nil {
			fmt.Printf("Error generating speech: %v\n", synthErr)
			return 1
		}

		// Save outputs
		for i := 0; i < numOut; i++ {
			fname := fmt.Sprintf("%s_%d.wav", sanitizeFilename(textList[i], 20), n+1)

			var wavOut []float64
			if batch {
				wavOut = extractWavSegment(wav, duration[i], textToSpeech.SampleRate, i, bsz)
			} else {
				// For non-batch mode, wav is a single concatenated audio.
				// Keep duration*sampleRate samples; if wav is shorter, the
				// remainder of wavOut stays zero.
				wavLen := int(float32(textToSpeech.SampleRate) * duration[0])
				wavOut = make([]float64, wavLen)
				for j := 0; j < wavLen && j < len(wav); j++ {
					wavOut[j] = float64(wav[j])
				}
			}

			outputPath := filepath.Join(saveDir, fname)
			if err := writeWavFile(outputPath, wavOut, textToSpeech.SampleRate); err != nil {
				// Best effort: report the failure and move on to the next file.
				fmt.Printf("Error writing wav file: %v\n", err)
				continue
			}
			fmt.Printf("Saved: %s\n", outputPath)
		}
	}

	fmt.Println("\n=== Synthesis completed successfully! ===")
	return 0
}