initial commit
This commit is contained in:
193
go/example_onnx.go
Normal file
193
go/example_onnx.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
ort "github.com/yalue/onnxruntime_go"
|
||||
)
|
||||
|
||||
// Args holds command line arguments
|
||||
type Args struct {
|
||||
useGPU bool
|
||||
onnxDir string
|
||||
totalStep int
|
||||
speed float64
|
||||
nTest int
|
||||
voiceStyle []string
|
||||
text []string
|
||||
lang []string
|
||||
saveDir string
|
||||
batch bool
|
||||
}
|
||||
|
||||
func parseArgs() *Args {
|
||||
args := &Args{}
|
||||
|
||||
flag.BoolVar(&args.useGPU, "use-gpu", false, "Use GPU for inference (default: CPU)")
|
||||
flag.StringVar(&args.onnxDir, "onnx-dir", "assets/onnx", "Path to ONNX model directory")
|
||||
flag.IntVar(&args.totalStep, "total-step", 5, "Number of denoising steps")
|
||||
flag.Float64Var(&args.speed, "speed", 1.05, "Speech speed factor (higher = faster)")
|
||||
flag.IntVar(&args.nTest, "n-test", 4, "Number of times to generate")
|
||||
flag.StringVar(&args.saveDir, "save-dir", "results", "Output directory")
|
||||
flag.BoolVar(&args.batch, "batch", false, "Enable batch mode (multiple text-style pairs)")
|
||||
|
||||
var voiceStyleStr, textStr, langStr string
|
||||
flag.StringVar(&voiceStyleStr, "voice-style", "assets/voice_styles/M1.json", "Voice style file path(s), comma-separated")
|
||||
flag.StringVar(&textStr, "text", "This morning, I took a walk in the park, and the sound of the birds and the breeze was so pleasant that I stopped for a long time just to listen.", "Text(s) to synthesize, pipe-separated")
|
||||
flag.StringVar(&langStr, "lang", "en", "Language(s) for synthesis, comma-separated (en, ko, es, pt, fr)")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// Parse comma-separated voice-style
|
||||
if voiceStyleStr != "" {
|
||||
args.voiceStyle = strings.Split(voiceStyleStr, ",")
|
||||
for i := range args.voiceStyle {
|
||||
args.voiceStyle[i] = strings.TrimSpace(args.voiceStyle[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse pipe-separated text
|
||||
if textStr != "" {
|
||||
args.text = strings.Split(textStr, "|")
|
||||
for i := range args.text {
|
||||
args.text[i] = strings.TrimSpace(args.text[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse comma-separated lang
|
||||
if langStr != "" {
|
||||
args.lang = strings.Split(langStr, ",")
|
||||
for i := range args.lang {
|
||||
args.lang[i] = strings.TrimSpace(args.lang[i])
|
||||
}
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
func main() {
|
||||
fmt.Println("=== TTS Inference with ONNX Runtime (Go) ===\n")
|
||||
|
||||
// --- 1. Parse arguments --- //
|
||||
args := parseArgs()
|
||||
totalStep := args.totalStep
|
||||
speed := float32(args.speed)
|
||||
nTest := args.nTest
|
||||
saveDir := args.saveDir
|
||||
voiceStylePaths := args.voiceStyle
|
||||
textList := args.text
|
||||
langList := args.lang
|
||||
batch := args.batch
|
||||
|
||||
if batch {
|
||||
if len(voiceStylePaths) != len(textList) {
|
||||
fmt.Printf("Error: Number of voice styles (%d) must match number of texts (%d)\n",
|
||||
len(voiceStylePaths), len(textList))
|
||||
os.Exit(1)
|
||||
}
|
||||
if len(langList) != len(textList) {
|
||||
fmt.Printf("Error: Number of languages (%d) must match number of texts (%d)\n",
|
||||
len(langList), len(textList))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
bsz := len(voiceStylePaths)
|
||||
|
||||
// Initialize ONNX Runtime
|
||||
if err := InitializeONNXRuntime(); err != nil {
|
||||
fmt.Printf("Error initializing ONNX Runtime: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer ort.DestroyEnvironment()
|
||||
|
||||
// --- 2. Load config --- //
|
||||
cfg, err := LoadCfgs(args.onnxDir)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// --- 3. Load TTS components --- //
|
||||
textToSpeech, err := LoadTextToSpeech(args.onnxDir, args.useGPU, cfg)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading TTS components: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer textToSpeech.Destroy()
|
||||
|
||||
// --- 4. Load voice styles --- //
|
||||
style, err := LoadVoiceStyle(voiceStylePaths, true)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading voice styles: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer style.Destroy()
|
||||
|
||||
// --- 5. Synthesize speech --- //
|
||||
if err := os.MkdirAll(saveDir, 0755); err != nil {
|
||||
fmt.Printf("Error creating save directory: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
for n := 0; n < nTest; n++ {
|
||||
fmt.Printf("\n[%d/%d] Starting synthesis...\n", n+1, nTest)
|
||||
|
||||
var wav []float32
|
||||
var duration []float32
|
||||
|
||||
if batch {
|
||||
Timer("Generating speech from text", func() interface{} {
|
||||
w, d, err := textToSpeech.Batch(textList, langList, style, totalStep, speed)
|
||||
if err != nil {
|
||||
fmt.Printf("Error generating speech: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
wav = w
|
||||
duration = d
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
Timer("Generating speech from text", func() interface{} {
|
||||
w, d, err := textToSpeech.Call(textList[0], langList[0], style, totalStep, speed, 0.3)
|
||||
if err != nil {
|
||||
fmt.Printf("Error generating speech: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
wav = w
|
||||
duration = []float32{d}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Save outputs
|
||||
for i := 0; i < bsz; i++ {
|
||||
fname := fmt.Sprintf("%s_%d.wav", sanitizeFilename(textList[i], 20), n+1)
|
||||
var wavOut []float64
|
||||
|
||||
if batch {
|
||||
wavOut = extractWavSegment(wav, duration[i], textToSpeech.SampleRate, i, bsz)
|
||||
} else {
|
||||
// For non-batch mode, wav is a single concatenated audio
|
||||
wavLen := int(float32(textToSpeech.SampleRate) * duration[0])
|
||||
wavOut = make([]float64, wavLen)
|
||||
for j := 0; j < wavLen && j < len(wav); j++ {
|
||||
wavOut[j] = float64(wav[j])
|
||||
}
|
||||
}
|
||||
|
||||
outputPath := filepath.Join(saveDir, fname)
|
||||
if err := writeWavFile(outputPath, wavOut, textToSpeech.SampleRate); err != nil {
|
||||
fmt.Printf("Error writing wav file: %v\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Saved: %s\n", outputPath)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n=== Synthesis completed successfully! ===")
|
||||
}
|
||||
Reference in New Issue
Block a user