Files
Supertonic/go/helper.go
2026-01-25 18:58:40 +09:00

1067 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"encoding/json"
"fmt"
"math"
"math/rand"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"unicode"
"github.com/go-audio/audio"
"github.com/go-audio/wav"
ort "github.com/yalue/onnxruntime_go"
"golang.org/x/text/unicode/norm"
)
// AvailableLangs lists the language codes accepted by preprocessText for
// multilingual TTS.
var AvailableLangs = []string{"en", "ko", "es", "pt", "fr"}

// Config structures mirroring the layout of tts.json (see LoadCfgs).

// SpecProcessorConfig holds spectrogram-processor parameters.
type SpecProcessorConfig struct {
	NFFT      int     `json:"n_fft"`      // FFT size
	WinLength int     `json:"win_length"` // STFT window length
	HopLength int     `json:"hop_length"` // STFT hop length
	NMels     int     `json:"n_mels"`     // number of mel bins
	Eps       float64 `json:"eps"`        // small epsilon constant
	NormMean  float64 `json:"norm_mean"`  // normalization mean
	NormStd   float64 `json:"norm_std"`   // normalization standard deviation
}

// EncoderConfig wraps the audio encoder's spectrogram settings.
type EncoderConfig struct {
	SpecProcessor SpecProcessorConfig `json:"spec_processor"`
}

// AEConfig describes the audio autoencoder: output sample rate and the base
// chunk size used for latent-length computations (see getLatentMask).
type AEConfig struct {
	SampleRate    int           `json:"sample_rate"`
	BaseChunkSize int           `json:"base_chunk_size"`
	Encoder       EncoderConfig `json:"encoder"`
}

// StyleTokenLayerConfig sizes the style-token layer.
type StyleTokenLayerConfig struct {
	NStyle        int `json:"n_style"`
	StyleValueDim int `json:"style_value_dim"`
}

// StyleEncoderConfig wraps the TTL style encoder's token layer.
type StyleEncoderConfig struct {
	StyleTokenLayer StyleTokenLayerConfig `json:"style_token_layer"`
}

// ProjOutConfig holds the text encoder's output-projection sizes.
type ProjOutConfig struct {
	Idim int `json:"idim"`
	Odim int `json:"odim"`
}

// TextEncoderConfig wraps the text encoder's projection settings.
type TextEncoderConfig struct {
	ProjOut ProjOutConfig `json:"proj_out"`
}

// TTLConfig describes the text-to-latent model; ChunkCompressFactor and
// LatentDim jointly size the denoised latent (see sampleNoisyLatent).
type TTLConfig struct {
	ChunkCompressFactor int                `json:"chunk_compress_factor"`
	LatentDim           int                `json:"latent_dim"`
	StyleEncoder        StyleEncoderConfig `json:"style_encoder"`
	TextEncoder         TextEncoderConfig  `json:"text_encoder"`
}

// DPStyleEncoderConfig wraps the duration predictor's style token layer.
type DPStyleEncoderConfig struct {
	StyleTokenLayer StyleTokenLayerConfig `json:"style_token_layer"`
}

// DPConfig describes the duration predictor.
type DPConfig struct {
	LatentDim           int                  `json:"latent_dim"`
	ChunkCompressFactor int                  `json:"chunk_compress_factor"`
	StyleEncoder        DPStyleEncoderConfig `json:"style_encoder"`
}

// Config is the top-level model configuration loaded from tts.json.
type Config struct {
	AE  AEConfig  `json:"ae"`
	TTL TTLConfig `json:"ttl"`
	DP  DPConfig  `json:"dp"`
}

// VoiceStyleData mirrors a voice-style JSON file: two tensors serialized as
// nested arrays plus their shape. LoadVoiceStyle indexes Dims[1] and Dims[2],
// so a rank-3 shape is expected.
type VoiceStyleData struct {
	StyleTTL struct {
		Data [][][]float64 `json:"data"` // nested tensor values
		Dims []int64       `json:"dims"` // tensor shape
		Type string        `json:"type"` // element type name — presumed; unused here
	} `json:"style_ttl"`
	StyleDP struct {
		Data [][][]float64 `json:"data"`
		Dims []int64       `json:"dims"`
		Type string        `json:"type"`
	} `json:"style_dp"`
}

// UnicodeProcessor maps Unicode code points to model token IDs via a lookup
// table loaded from JSON (slice index = code point, value = token ID).
type UnicodeProcessor struct {
	indexer []int64
}
// NewUnicodeProcessor builds a UnicodeProcessor from the indexer JSON file at
// the given path (an array mapping code points to token IDs).
func NewUnicodeProcessor(unicodeIndexerPath string) (*UnicodeProcessor, error) {
	table, err := loadJSONInt64(unicodeIndexerPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load unicode indexer: %w", err)
	}
	processor := &UnicodeProcessor{indexer: table}
	return processor, nil
}
// Call converts a batch of texts into padded token-ID rows and a matching
// [bsz][1][maxLen] validity mask. Each text is first normalized and wrapped
// in language tags by preprocessText; unknown code points map to -1.
func (up *UnicodeProcessor) Call(textList []string, langList []string) ([][]int64, [][][]float64) {
	n := len(textList)

	// Normalize and language-tag every input text.
	texts := make([]string, n)
	for i := range textList {
		texts[i] = preprocessText(textList[i], langList[i])
	}

	// Measure rune lengths and find the longest text for padding.
	lengths := make([]int64, n)
	longest := 0
	for i, s := range texts {
		count := len([]rune(s))
		lengths[i] = int64(count)
		if count > longest {
			longest = count
		}
	}

	// Map each rune through the indexer; out-of-table runes become -1 and
	// padding positions stay 0.
	ids := make([][]int64, n)
	for i, s := range texts {
		row := make([]int64, longest)
		for j, r := range []rune(s) {
			if code := int(r); code < len(up.indexer) {
				row[j] = up.indexer[code]
			} else {
				row[j] = -1
			}
		}
		ids[i] = row
	}

	return ids, lengthToMask(lengths, longest)
}
// Text chunking utilities

// maxChunkLength is the default maximum chunk size (in bytes) used by
// chunkText when the caller passes maxLen == 0.
const maxChunkLength = 300

// abbreviations lists period-terminated tokens that splitSentences must NOT
// treat as sentence endings.
var abbreviations = []string{
	"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
	"St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
	"Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
}
// chunkText splits text into chunks no longer than maxLen for sequential
// synthesis; maxLen == 0 selects the default maxChunkLength. Preference
// order: keep whole paragraphs, else pack sentences greedily, else split an
// oversized sentence at commas, else at spaces. Always returns at least one
// (possibly empty) chunk.
// NOTE(review): all lengths here are byte lengths (len(s)), not rune counts,
// so multi-byte scripts yield chunks with fewer characters than maxLen —
// confirm this is intended (Call() compensates with maxLen=120 for Korean).
func chunkText(text string, maxLen int) []string {
	if maxLen == 0 {
		maxLen = maxChunkLength
	}
	text = strings.TrimSpace(text)
	if text == "" {
		return []string{""}
	}
	// Split by paragraphs (blank-line separated)
	paragraphs := regexp.MustCompile(`\n\s*\n`).Split(text, -1)
	var chunks []string
	for _, para := range paragraphs {
		para = strings.TrimSpace(para)
		if para == "" {
			continue
		}
		// A paragraph that already fits becomes its own chunk.
		if len(para) <= maxLen {
			chunks = append(chunks, para)
			continue
		}
		// Otherwise pack sentences greedily into `current`.
		sentences := splitSentences(para)
		var current strings.Builder
		currentLen := 0
		for _, sentence := range sentences {
			sentence = strings.TrimSpace(sentence)
			if sentence == "" {
				continue
			}
			sentenceLen := len(sentence)
			if sentenceLen > maxLen {
				// If sentence is longer than maxLen, split by comma or space.
				// First flush whatever is already accumulated.
				if current.Len() > 0 {
					chunks = append(chunks, strings.TrimSpace(current.String()))
					current.Reset()
					currentLen = 0
				}
				// Try splitting by comma
				parts := strings.Split(sentence, ",")
				for _, part := range parts {
					part = strings.TrimSpace(part)
					if part == "" {
						continue
					}
					partLen := len(part)
					if partLen > maxLen {
						// Split by space as last resort
						words := strings.Fields(part)
						var wordChunk strings.Builder
						wordChunkLen := 0
						for _, word := range words {
							wordLen := len(word)
							// Flush when adding this word (plus a space) would overflow.
							if wordChunkLen+wordLen+1 > maxLen && wordChunk.Len() > 0 {
								chunks = append(chunks, strings.TrimSpace(wordChunk.String()))
								wordChunk.Reset()
								wordChunkLen = 0
							}
							if wordChunk.Len() > 0 {
								wordChunk.WriteString(" ")
								wordChunkLen++
							}
							wordChunk.WriteString(word)
							wordChunkLen += wordLen
						}
						if wordChunk.Len() > 0 {
							chunks = append(chunks, strings.TrimSpace(wordChunk.String()))
						}
					} else {
						// Pack comma-separated parts, re-joining with ", ".
						if currentLen+partLen+1 > maxLen && current.Len() > 0 {
							chunks = append(chunks, strings.TrimSpace(current.String()))
							current.Reset()
							currentLen = 0
						}
						if current.Len() > 0 {
							current.WriteString(", ")
							currentLen += 2
						}
						current.WriteString(part)
						currentLen += partLen
					}
				}
				continue
			}
			// Normal-sized sentence: flush first if it would overflow.
			if currentLen+sentenceLen+1 > maxLen && current.Len() > 0 {
				chunks = append(chunks, strings.TrimSpace(current.String()))
				current.Reset()
				currentLen = 0
			}
			if current.Len() > 0 {
				current.WriteString(" ")
				currentLen++
			}
			current.WriteString(sentence)
			currentLen += sentenceLen
		}
		if current.Len() > 0 {
			chunks = append(chunks, strings.TrimSpace(current.String()))
		}
	}
	// Guarantee a non-empty result for downstream loops.
	if len(chunks) == 0 {
		return []string{""}
	}
	return chunks
}
// splitSentences breaks text at sentence-terminating punctuation followed by
// whitespace, while refusing to split right after a known abbreviation
// (Dr., e.g., etc.). Go's regexp lacks lookbehind, so candidate boundaries
// are found first and abbreviation-terminated ones are rejected.
func splitSentences(text string) []string {
	boundary := regexp.MustCompile(`([.!?])\s+`)
	spans := boundary.FindAllStringIndex(text, -1)
	if len(spans) == 0 {
		return []string{text}
	}

	var out []string
	start := 0
	for _, span := range spans {
		// Candidate sentence body plus its terminating punctuation mark.
		candidate := strings.TrimSpace(text[start:span[0]] + text[span[0]:span[0]+1])
		abbrev := false
		for _, a := range abbreviations {
			if strings.HasSuffix(candidate, a) {
				abbrev = true
				break
			}
		}
		if abbrev {
			// Not a real boundary; keep accumulating into the same sentence.
			continue
		}
		out = append(out, text[start:span[1]])
		start = span[1]
	}
	// Trailing text after the last accepted boundary.
	if start < len(text) {
		out = append(out, text[start:])
	}
	if len(out) == 0 {
		return []string{text}
	}
	return out
}
// isValidLang reports whether lang appears in AvailableLangs.
func isValidLang(lang string) bool {
	for i := range AvailableLangs {
		if AvailableLangs[i] == lang {
			return true
		}
	}
	return false
}
// Utility functions

// preprocessText normalizes raw input text for the TTS front end and wraps it
// in language tags, e.g. "<en>...</en>". Steps: NFKD normalization, emoji
// removal, dash/quote/symbol normalization, punctuation-spacing fixes,
// whitespace collapse, and appending a final period when the text lacks
// terminal punctuation. Panics when lang is not in AvailableLangs.
func preprocessText(text string, lang string) string {
	// TODO: Need advanced normalizer for better performance
	// Apply NFKD normalization using golang.org/x/text/unicode/norm
	text = norm.NFKD.String(text)
	// Remove emojis and various Unicode symbols
	emojiPattern := regexp.MustCompile(`[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+`)
	text = emojiPattern.ReplaceAllString(text, "")
	// Replace various dashes and symbols. The dash keys are written as \u
	// escapes: the literal characters are invisible/ambiguous and were being
	// lost in review (they rendered as empty strings).
	replacements := map[string]string{
		"\u2013": "-",  // en dash
		"\u2011": "-",  // non-breaking hyphen
		"—":      "-",  // em dash
		"_":      " ",  // underscore
		"\u201C": "\"", // left double quote
		"\u201D": "\"", // right double quote
		"\u2018": "'",  // left single quote
		"\u2019": "'",  // right single quote
		"´":      "'",  // acute accent
		"`":      "'",  // grave accent
		"[":      " ",  // left bracket
		"]":      " ",  // right bracket
		"|":      " ",  // vertical bar
		"/":      " ",  // slash
		"#":      " ",  // hash
		"→":      " ",  // right arrow
		"←":      " ",  // left arrow
	}
	// Map iteration order is random, but these patterns are disjoint, so the
	// result is order-independent.
	for from, to := range replacements {
		text = strings.ReplaceAll(text, from, to)
	}
	// Remove special symbols entirely.
	specialSymbols := []string{"♥", "☆", "♡", "©", "\\"}
	for _, symbol := range specialSymbols {
		text = strings.ReplaceAll(text, symbol, "")
	}
	// Expand known expressions into speakable words.
	exprReplacements := map[string]string{
		"@":     " at ",
		"e.g.,": "for example, ",
		"i.e.,": "that is, ",
	}
	for from, to := range exprReplacements {
		text = strings.ReplaceAll(text, from, to)
	}
	// Drop the stray space before punctuation. These are literal patterns,
	// so plain string replacement is equivalent to (and cheaper than) the
	// per-call regexp compiles used previously.
	for _, p := range []string{",", ".", "!", "?", ";", ":", "'"} {
		text = strings.ReplaceAll(text, " "+p, p)
	}
	// Remove duplicate quotes
	for strings.Contains(text, `""`) {
		text = strings.ReplaceAll(text, `""`, `"`)
	}
	for strings.Contains(text, "''") {
		text = strings.ReplaceAll(text, "''", "'")
	}
	for strings.Contains(text, "``") {
		text = strings.ReplaceAll(text, "``", "`")
	}
	// Collapse whitespace runs to single spaces.
	text = regexp.MustCompile(`\s+`).ReplaceAllString(text, " ")
	text = strings.TrimSpace(text)
	// If text doesn't end with punctuation, quotes, or closing brackets, add a period
	if text != "" {
		endsWithPunct := regexp.MustCompile(`[.!?;:,'"\x{201C}\x{201D}\x{2018}\x{2019})\]}…。」』】〉》›»]$`)
		if !endsWithPunct.MatchString(text) {
			text += "."
		}
	}
	// Validate language (panic preserves the original contract for callers).
	if !isValidLang(lang) {
		panic(fmt.Sprintf("Invalid language: %s. Available: %v", lang, AvailableLangs))
	}
	// Wrap text with language tags
	text = fmt.Sprintf("<%s>%s</%s>", lang, text, lang)
	return text
}
// lengthToMask converts per-item lengths into a binary validity mask of
// shape [len(lengths)][1][maxLen]: entry [i][0][j] is 1.0 when j < lengths[i]
// and 0.0 otherwise.
func lengthToMask(lengths []int64, maxLen int) [][][]float64 {
	mask := make([][][]float64, len(lengths))
	for i, n := range lengths {
		row := make([]float64, maxLen) // zero-initialized; only valid slots set
		for j := range row {
			if int64(j) < n {
				row[j] = 1.0
			}
		}
		mask[i] = [][]float64{row}
	}
	return mask
}
// getTextMask returns a [bsz][1][maxLen] binary mask for the given text
// lengths; it is a thin alias of lengthToMask kept for call-site readability.
func getTextMask(textLengths []int64, maxLen int) [][][]float64 {
	return lengthToMask(textLengths, maxLen)
}
// getLatentMask converts waveform sample lengths into a latent-frame mask.
// One latent frame covers BaseChunkSize*ChunkCompressFactor samples; frame
// counts are rounded up, and the mask is padded to the longest item.
func getLatentMask(wavLengths []int64, cfg Config) [][][]float64 {
	frame := int64(cfg.AE.BaseChunkSize) * int64(cfg.TTL.ChunkCompressFactor)
	frameCounts := make([]int64, len(wavLengths))
	var longest int64
	for i, samples := range wavLengths {
		frameCounts[i] = (samples + frame - 1) / frame // ceiling division
		if frameCounts[i] > longest {
			longest = frameCounts[i]
		}
	}
	return lengthToMask(frameCounts, int(longest))
}
func writeWavFile(filename string, audioData []float64, sampleRate int) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
// Convert float64 to int
intData := make([]int, len(audioData))
for i, sample := range audioData {
// Clamp to [-1, 1] and convert to 16-bit int
clamped := math.Max(-1.0, math.Min(1.0, sample))
intData[i] = int(clamped * 32767)
}
encoder := wav.NewEncoder(file, sampleRate, 16, 1, 1)
buf := &audio.IntBuffer{
Data: intData,
Format: &audio.Format{SampleRate: sampleRate, NumChannels: 1},
SourceBitDepth: 16,
}
if err := encoder.Write(buf); err != nil {
return err
}
return encoder.Close()
}
// Style holds the ONNX style tensors consumed by the text-to-latent models
// (TtlTensor) and the duration predictor (DpTensor). Obtain via
// LoadVoiceStyle; the owner must call Destroy to release native memory.
type Style struct {
	TtlTensor *ort.Tensor[float32]
	DpTensor  *ort.Tensor[float32]
}

// Destroy releases both tensors; safe to call when either is nil.
func (s *Style) Destroy() {
	if s.TtlTensor != nil {
		s.TtlTensor.Destroy()
	}
	if s.DpTensor != nil {
		s.DpTensor.Destroy()
	}
}
// LoadVoiceStyle loads per-speaker style data from one JSON file per batch
// item and stacks it into batched [bsz, dim1, dim2] float32 tensors. All
// files must be shaped like the first file; files whose payload exceeds the
// declared dims are rejected instead of silently corrupting neighboring
// batch rows. The caller owns the returned Style and must call Destroy.
func LoadVoiceStyle(voiceStylePaths []string, verbose bool) (*Style, error) {
	bsz := len(voiceStylePaths)
	if bsz == 0 {
		return nil, fmt.Errorf("no voice style paths provided")
	}
	// Read the first file only to discover tensor dimensions.
	firstData, err := os.ReadFile(voiceStylePaths[0])
	if err != nil {
		return nil, fmt.Errorf("failed to read voice style file: %w", err)
	}
	var firstStyle VoiceStyleData
	if err := json.Unmarshal(firstData, &firstStyle); err != nil {
		return nil, fmt.Errorf("failed to parse voice style JSON: %w", err)
	}
	ttlDims := firstStyle.StyleTTL.Dims
	dpDims := firstStyle.StyleDP.Dims
	// Guard against malformed files: the code below indexes dims[1]/dims[2].
	if len(ttlDims) < 3 || len(dpDims) < 3 {
		return nil, fmt.Errorf("voice style dims must have rank 3, got ttl=%v dp=%v", ttlDims, dpDims)
	}
	ttlDim1, ttlDim2 := ttlDims[1], ttlDims[2]
	dpDim1, dpDim2 := dpDims[1], dpDims[2]
	ttlSlot := int(ttlDim1 * ttlDim2)
	dpSlot := int(dpDim1 * dpDim2)
	// Pre-allocate flattened buffers for the whole batch.
	ttlFlat := make([]float32, bsz*ttlSlot)
	dpFlat := make([]float32, bsz*dpSlot)
	// Fill each file's slot.
	for i := 0; i < bsz; i++ {
		data, err := os.ReadFile(voiceStylePaths[i])
		if err != nil {
			return nil, fmt.Errorf("failed to read voice style file: %w", err)
		}
		var voiceStyle VoiceStyleData
		if err := json.Unmarshal(data, &voiceStyle); err != nil {
			return nil, fmt.Errorf("failed to parse voice style JSON: %w", err)
		}
		if err := flattenStyle(voiceStyle.StyleTTL.Data, ttlFlat[i*ttlSlot:(i+1)*ttlSlot]); err != nil {
			return nil, fmt.Errorf("voice style %s (style_ttl): %w", voiceStylePaths[i], err)
		}
		if err := flattenStyle(voiceStyle.StyleDP.Data, dpFlat[i*dpSlot:(i+1)*dpSlot]); err != nil {
			return nil, fmt.Errorf("voice style %s (style_dp): %w", voiceStylePaths[i], err)
		}
	}
	ttlTensor, err := ort.NewTensor([]int64{int64(bsz), ttlDim1, ttlDim2}, ttlFlat)
	if err != nil {
		return nil, fmt.Errorf("failed to create TTL tensor: %w", err)
	}
	dpTensor, err := ort.NewTensor([]int64{int64(bsz), dpDim1, dpDim2}, dpFlat)
	if err != nil {
		ttlTensor.Destroy() // don't leak the first tensor on failure
		return nil, fmt.Errorf("failed to create DP tensor: %w", err)
	}
	if verbose {
		fmt.Printf("Loaded %d voice styles\n\n", bsz)
	}
	return &Style{
		TtlTensor: ttlTensor,
		DpTensor:  dpTensor,
	}, nil
}

// flattenStyle copies nested style values into dst, erroring when the data
// holds more values than dst can take (dims mismatch with the first file).
// Fewer values than len(dst) leave trailing zeros, matching prior behavior.
func flattenStyle(data [][][]float64, dst []float32) error {
	idx := 0
	for _, batch := range data {
		for _, row := range batch {
			for _, val := range row {
				if idx >= len(dst) {
					return fmt.Errorf("style data has more than %d values declared by dims", len(dst))
				}
				dst[idx] = float32(val)
				idx++
			}
		}
	}
	return nil
}
// TextToSpeech bundles the four ONNX sessions and derived parameters for the
// full text-to-speech pipeline. Construct with LoadTextToSpeech and release
// with Destroy.
type TextToSpeech struct {
	cfg           Config
	textProcessor *UnicodeProcessor
	dpOrt         *ort.DynamicAdvancedSession // duration predictor
	textEncOrt    *ort.DynamicAdvancedSession // text encoder
	vectorEstOrt  *ort.DynamicAdvancedSession // denoising vector estimator
	vocoderOrt    *ort.DynamicAdvancedSession // latent-to-waveform vocoder
	SampleRate    int                         // output sample rate (cfg.AE.SampleRate)
	baseChunkSize int                         // cfg.AE.BaseChunkSize
	chunkCompress int                         // cfg.TTL.ChunkCompressFactor
	ldim          int                         // cfg.TTL.LatentDim
}
// sampleNoisyLatent draws standard-normal noise to seed the denoising loop.
// durOnnx holds per-item durations in seconds. It returns the noise of shape
// [bsz][latentDim][latentLen] (latentLen derived from the longest duration)
// and the latent validity mask from getLatentMask; the mask is already
// applied, so padded frames start at zero.
func (tts *TextToSpeech) sampleNoisyLatent(durOnnx []float32) ([][][]float64, [][][]float64) {
	bsz := len(durOnnx)
	// The longest duration determines the padded latent length of the batch.
	maxDur := float64(0)
	for _, d := range durOnnx {
		if float64(d) > maxDur {
			maxDur = float64(d)
		}
	}
	wavLenMax := maxDur * float64(tts.SampleRate)
	// Per-item waveform lengths in samples (truncated toward zero).
	wavLengths := make([]int64, bsz)
	for i, d := range durOnnx {
		wavLengths[i] = int64(float64(d) * float64(tts.SampleRate))
	}
	// One latent frame spans chunkSize samples; ceiling-divide for the count.
	chunkSize := tts.baseChunkSize * tts.chunkCompress
	latentLen := int((wavLenMax + float64(chunkSize) - 1) / float64(chunkSize))
	latentDim := tts.ldim * tts.chunkCompress
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	noisyLatent := make([][][]float64, bsz)
	for b := 0; b < bsz; b++ {
		batch := make([][]float64, latentDim)
		for d := 0; d < latentDim; d++ {
			row := make([]float64, latentLen)
			for t := 0; t < latentLen; t++ {
				// Box-Muller transform for normal distribution
				// Add epsilon to avoid log(0)
				const eps = 1e-10
				u1 := math.Max(eps, rng.Float64())
				u2 := rng.Float64()
				row[t] = math.Sqrt(-2.0*math.Log(u1)) * math.Cos(2.0*math.Pi*u2)
			}
			batch[d] = row
		}
		noisyLatent[b] = batch
	}
	// NOTE(review): the mask length comes from the truncated int64 wavLengths
	// while latentLen comes from the float wavLenMax; these are assumed to
	// agree — a larger latentLen would index past latentMask below. Confirm.
	latentMask := getLatentMask(wavLengths, tts.cfg)
	// Zero the noise in padded (invalid) frames.
	for b := 0; b < bsz; b++ {
		for d := 0; d < latentDim; d++ {
			for t := 0; t < latentLen; t++ {
				noisyLatent[b][d][t] *= latentMask[b][0][t]
			}
		}
	}
	return noisyLatent, latentMask
}
// _infer runs the full batched TTS pipeline: duration prediction, text
// encoding, iterative latent denoising over totalStep steps, and vocoding.
// It returns the flat batched waveform and per-item durations in seconds
// (already divided by speed). All intermediate tensors are destroyed before
// returning; the waveform slice is copied out of the vocoder tensor's data.
func (tts *TextToSpeech) _infer(textList []string, langList []string, style *Style, totalStep int, speed float32) ([]float32, []float32, error) {
	bsz := len(textList)
	// Tokenize text into padded ID rows plus a [bsz][1][maxLen] mask.
	textIDs, textMask := tts.textProcessor.Call(textList, langList)
	textIDsShape := []int64{int64(bsz), int64(len(textIDs[0]))}
	textMaskShape := []int64{int64(bsz), 1, int64(len(textMask[0][0]))}
	textIDsTensor := IntArrayToTensor(textIDs, textIDsShape)
	defer textIDsTensor.Destroy()
	textMaskTensor := ArrayToTensor(textMask, textMaskShape)
	defer textMaskTensor.Destroy()
	// Predict per-item durations.
	dpOutputs := []ort.Value{nil}
	err := tts.dpOrt.Run(
		[]ort.Value{textIDsTensor, style.DpTensor, textMaskTensor},
		dpOutputs,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run duration predictor: %w", err)
	}
	durTensor := dpOutputs[0].(*ort.Tensor[float32])
	defer durTensor.Destroy()
	durOnnx := durTensor.GetData()
	// Apply speed factor: speed > 1 shortens durations, < 1 lengthens them.
	for i := range durOnnx {
		durOnnx[i] /= speed
	}
	// Encode text into embeddings conditioned on the TTL style.
	textIDsTensor2 := IntArrayToTensor(textIDs, textIDsShape)
	defer textIDsTensor2.Destroy()
	textEncOutputs := []ort.Value{nil}
	err = tts.textEncOrt.Run(
		[]ort.Value{textIDsTensor2, style.TtlTensor, textMaskTensor},
		textEncOutputs,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run text encoder: %w", err)
	}
	textEmbTensor := textEncOutputs[0].(*ort.Tensor[float32])
	defer textEmbTensor.Destroy()
	// Start from masked Gaussian noise sized by the predicted durations.
	xt, latentMask := tts.sampleNoisyLatent(durOnnx)
	latentShape := []int64{int64(bsz), int64(len(xt[0])), int64(len(xt[0][0]))}
	latentMaskShape := []int64{int64(bsz), 1, int64(len(latentMask[0][0]))}
	// The total_step model input is a per-item constant.
	// NOTE(review): tensor-creation errors are discarded with _ here; a nil
	// tensor would surface as a Run failure below. Consider checking.
	totalStepArray := make([]float32, bsz)
	for b := 0; b < bsz; b++ {
		totalStepArray[b] = float32(totalStep)
	}
	scalarShape := []int64{int64(bsz)}
	totalStepTensor, _ := ort.NewTensor(scalarShape, totalStepArray)
	defer totalStepTensor.Destroy()
	// Denoising loop: each pass rewrites xt with the estimator output.
	// Per-iteration tensors are destroyed explicitly at the end of each pass
	// (a defer inside the loop would delay cleanup until function return).
	for step := 0; step < totalStep; step++ {
		currentStepArray := make([]float32, bsz)
		for b := 0; b < bsz; b++ {
			currentStepArray[b] = float32(step)
		}
		currentStepTensor, _ := ort.NewTensor(scalarShape, currentStepArray)
		noisyLatentTensor := ArrayToTensor(xt, latentShape)
		latentMaskTensor := ArrayToTensor(latentMask, latentMaskShape)
		textMaskTensor2 := ArrayToTensor(textMask, textMaskShape)
		vectorEstOutputs := []ort.Value{nil}
		err = tts.vectorEstOrt.Run(
			[]ort.Value{noisyLatentTensor, textEmbTensor, style.TtlTensor, latentMaskTensor, textMaskTensor2,
				currentStepTensor, totalStepTensor},
			vectorEstOutputs,
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to run vector estimator: %w", err)
		}
		denoisedTensor := vectorEstOutputs[0].(*ort.Tensor[float32])
		denoisedData := denoisedTensor.GetData()
		// Copy the denoised output back into xt for the next iteration.
		idx := 0
		for b := 0; b < bsz; b++ {
			for d := 0; d < len(xt[b]); d++ {
				for t := 0; t < len(xt[b][d]); t++ {
					xt[b][d][t] = float64(denoisedData[idx])
					idx++
				}
			}
		}
		noisyLatentTensor.Destroy()
		latentMaskTensor.Destroy()
		textMaskTensor2.Destroy()
		currentStepTensor.Destroy()
		denoisedTensor.Destroy()
	}
	// Decode the final latent into the batched waveform.
	finalLatentTensor := ArrayToTensor(xt, latentShape)
	defer finalLatentTensor.Destroy()
	vocoderOutputs := []ort.Value{nil}
	err = tts.vocoderOrt.Run(
		[]ort.Value{finalLatentTensor},
		vocoderOutputs,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run vocoder: %w", err)
	}
	wavBatchTensor := vocoderOutputs[0].(*ort.Tensor[float32])
	defer wavBatchTensor.Destroy()
	wav := wavBatchTensor.GetData()
	return wav, durOnnx, nil
}
// Call synthesizes speech for a single text, automatically chunking long
// input (120 bytes for Korean, 300 otherwise) and inserting silenceDuration
// seconds of silence between chunks. It returns the concatenated waveform
// and the total duration in seconds.
func (tts *TextToSpeech) Call(text string, lang string, style *Style, totalStep int, speed float32, silenceDuration float32) ([]float32, float32, error) {
	// Korean expands more audio per byte, so use a shorter chunk limit.
	maxLen := 300
	if lang == "ko" {
		maxLen = 120
	}
	chunks := chunkText(text, maxLen)
	var wavCat []float32
	var durCat float32
	for i, chunk := range chunks {
		wav, duration, err := tts._infer([]string{chunk}, []string{lang}, style, totalStep, speed)
		if err != nil {
			return nil, 0, err
		}
		dur := duration[0]
		wavLen := int(float32(tts.SampleRate) * dur)
		if wavLen > len(wav) {
			// The vocoder can emit fewer samples than duration*sampleRate
			// predicts; clamp instead of panicking on wav[:wavLen].
			wavLen = len(wav)
		}
		if i > 0 {
			// Inter-chunk silence.
			silence := make([]float32, int(silenceDuration*float32(tts.SampleRate)))
			wavCat = append(wavCat, silence...)
			durCat += silenceDuration
		}
		// Appending (rather than aliasing wav[:wavLen]) copies the samples,
		// so wavCat never shares backing storage with the inference output.
		wavCat = append(wavCat, wav[:wavLen]...)
		durCat += dur
	}
	return wavCat, durCat, nil
}
// Batch synthesizes speech for multiple texts in one batched inference pass.
// It returns the flat batched waveform and per-item durations in seconds;
// individual items can be sliced out with extractWavSegment.
func (tts *TextToSpeech) Batch(textList []string, langList []string, style *Style, totalStep int, speed float32) ([]float32, []float32, error) {
	return tts._infer(textList, langList, style, totalStep, speed)
}
// Destroy releases every ONNX session owned by the TextToSpeech. Sessions
// that were never created (nil) are skipped, so partial setups are safe.
func (tts *TextToSpeech) Destroy() {
	sessions := []*ort.DynamicAdvancedSession{
		tts.dpOrt,
		tts.textEncOrt,
		tts.vectorEstOrt,
		tts.vocoderOrt,
	}
	for _, s := range sessions {
		if s != nil {
			s.Destroy()
		}
	}
}
// LoadTextToSpeech loads the four ONNX sessions (duration predictor, text
// encoder, vector estimator, vocoder) and the unicode indexer from onnxDir
// and assembles a ready-to-use TextToSpeech. Only CPU inference is supported;
// useGPU=true returns an error. On any failure, sessions created so far are
// destroyed before returning, so nothing leaks.
func LoadTextToSpeech(onnxDir string, useGPU bool, cfg Config) (*TextToSpeech, error) {
	if useGPU {
		return nil, fmt.Errorf("GPU mode is not supported yet")
	}
	// Same bytes as the previous Println("...\n") (message + blank line),
	// but vet-clean.
	fmt.Print("Using CPU for inference\n\n")
	// Load models
	dpPath := filepath.Join(onnxDir, "duration_predictor.onnx")
	textEncPath := filepath.Join(onnxDir, "text_encoder.onnx")
	vectorEstPath := filepath.Join(onnxDir, "vector_estimator.onnx")
	vocoderPath := filepath.Join(onnxDir, "vocoder.onnx")
	dpOrt, err := ort.NewDynamicAdvancedSession(dpPath, []string{"text_ids", "style_dp", "text_mask"},
		[]string{"duration"}, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to load duration predictor: %w", err)
	}
	textEncOrt, err := ort.NewDynamicAdvancedSession(textEncPath, []string{"text_ids", "style_ttl", "text_mask"},
		[]string{"text_emb"}, nil)
	if err != nil {
		dpOrt.Destroy() // release already-created sessions on failure
		return nil, fmt.Errorf("failed to load text encoder: %w", err)
	}
	vectorEstOrt, err := ort.NewDynamicAdvancedSession(vectorEstPath,
		[]string{"noisy_latent", "text_emb", "style_ttl", "latent_mask", "text_mask", "current_step", "total_step"},
		[]string{"denoised_latent"}, nil)
	if err != nil {
		dpOrt.Destroy()
		textEncOrt.Destroy()
		return nil, fmt.Errorf("failed to load vector estimator: %w", err)
	}
	vocoderOrt, err := ort.NewDynamicAdvancedSession(vocoderPath, []string{"latent"},
		[]string{"wav_tts"}, nil)
	if err != nil {
		dpOrt.Destroy()
		textEncOrt.Destroy()
		vectorEstOrt.Destroy()
		return nil, fmt.Errorf("failed to load vocoder: %w", err)
	}
	// Load text processor
	unicodeIndexerPath := filepath.Join(onnxDir, "unicode_indexer.json")
	textProcessor, err := NewUnicodeProcessor(unicodeIndexerPath)
	if err != nil {
		dpOrt.Destroy()
		textEncOrt.Destroy()
		vectorEstOrt.Destroy()
		vocoderOrt.Destroy()
		return nil, err
	}
	textToSpeech := &TextToSpeech{
		cfg:           cfg,
		textProcessor: textProcessor,
		dpOrt:         dpOrt,
		textEncOrt:    textEncOrt,
		vectorEstOrt:  vectorEstOrt,
		vocoderOrt:    vocoderOrt,
		SampleRate:    cfg.AE.SampleRate,
		baseChunkSize: cfg.AE.BaseChunkSize,
		chunkCompress: cfg.TTL.ChunkCompressFactor,
		ldim:          cfg.TTL.LatentDim,
	}
	return textToSpeech, nil
}
// InitializeONNXRuntime points the ORT bindings at a shared library and
// initializes the runtime environment. The library path comes from the
// ONNXRUNTIME_LIB_PATH environment variable when set; otherwise a few common
// install locations are probed in order.
func InitializeONNXRuntime() error {
	libPath := os.Getenv("ONNXRUNTIME_LIB_PATH")
	if libPath == "" {
		// Linux default, then probe alternatives (macOS dylib first).
		libPath = "/usr/local/lib/libonnxruntime.so"
		candidates := []string{
			"/usr/local/lib/libonnxruntime.dylib",
			"/usr/lib/libonnxruntime.so",
		}
		for _, candidate := range candidates {
			if _, err := os.Stat(candidate); err == nil {
				libPath = candidate
				break
			}
		}
	}
	ort.SetSharedLibraryPath(libPath)
	if err := ort.InitializeEnvironment(); err != nil {
		return fmt.Errorf("failed to initialize ONNX Runtime: %w\nHint: Set ONNXRUNTIME_LIB_PATH environment variable", err)
	}
	return nil
}
// sanitizeFilename builds a filesystem-safe name from text: it keeps at most
// maxLen runes and replaces every rune that is not a Unicode letter or digit
// with '_'. Letters and digits from any script are preserved.
func sanitizeFilename(text string, maxLen int) string {
	runes := []rune(text)
	if len(runes) > maxLen {
		runes = runes[:maxLen]
	}
	out := make([]rune, len(runes))
	for i, r := range runes {
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			out[i] = r
		} else {
			out[i] = '_'
		}
	}
	return string(out)
}
// extractWavSegment slices one item's audio out of a flat batched waveform.
// The batch buffer is assumed to be evenly divided across batchSize items;
// duration (seconds) determines how many samples of that item are kept.
// Positions past the end of wav are left at zero.
func extractWavSegment(wav []float32, duration float32, sampleRate int, index int, batchSize int) []float64 {
	segLen := int(float64(sampleRate) * float64(duration))
	start := index * (len(wav) / batchSize)
	out := make([]float64, segLen)
	for j := range out {
		src := start + j
		if src >= len(wav) {
			break // remaining entries stay zero
		}
		out[j] = float64(wav[src])
	}
	return out
}
// Timer runs fn, printing its name before and the elapsed wall-clock seconds
// after, and returns fn's result unchanged.
func Timer(name string, fn func() interface{}) interface{} {
	begin := time.Now()
	fmt.Printf("%s...\n", name)
	out := fn()
	fmt.Printf(" -> %s completed in %.2f sec\n", name, time.Since(begin).Seconds())
	return out
}
// LoadCfgs reads and decodes the model configuration from <onnxDir>/tts.json.
func LoadCfgs(onnxDir string) (Config, error) {
	raw, err := os.ReadFile(filepath.Join(onnxDir, "tts.json"))
	if err != nil {
		return Config{}, err
	}
	var cfg Config
	if err := json.Unmarshal(raw, &cfg); err != nil {
		return Config{}, err
	}
	return cfg, nil
}
// JSON loading helpers
func loadJSONInt64(filePath string) ([]int64, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
var result []int64
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
// Tensor conversion utilities

// ArrayToTensor flattens a [b][d][t] float64 array into a float32 ONNX
// tensor with the given shape. It panics if tensor creation fails, so the
// shape must be consistent with the array contents.
func ArrayToTensor(array [][][]float64, shape []int64) *ort.Tensor[float32] {
	size := int64(1)
	for _, dim := range shape {
		size *= dim
	}
	flat := make([]float32, size)
	pos := 0
	for _, plane := range array {
		for _, row := range plane {
			for _, v := range row {
				flat[pos] = float32(v)
				pos++
			}
		}
	}
	tensor, err := ort.NewTensor(shape, flat)
	if err != nil {
		panic(err)
	}
	return tensor
}
// IntArrayToTensor flattens a [b][t] int64 array into an int64 ONNX tensor
// with the given shape. It panics if tensor creation fails, so the shape
// must be consistent with the array contents.
func IntArrayToTensor(array [][]int64, shape []int64) *ort.Tensor[int64] {
	size := int64(1)
	for _, dim := range shape {
		size *= dim
	}
	flat := make([]int64, size)
	pos := 0
	for _, row := range array {
		for _, v := range row {
			flat[pos] = v
			pos++
		}
	}
	tensor, err := ort.NewTensor(shape, flat)
	if err != nil {
		panic(err)
	}
	return tensor
}