// (file listing metadata removed: 1067 lines, 27 KiB, Go — not part of the source)
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"math"
|
||
"math/rand"
|
||
"os"
|
||
"path/filepath"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
"unicode"
|
||
|
||
"github.com/go-audio/audio"
|
||
"github.com/go-audio/wav"
|
||
ort "github.com/yalue/onnxruntime_go"
|
||
"golang.org/x/text/unicode/norm"
|
||
)
|
||
|
||
// AvailableLangs lists the language codes accepted for multilingual TTS;
// preprocessText panics (via isValidLang) on any code not in this list.
var AvailableLangs = []string{"en", "ko", "es", "pt", "fr"}
|
||
|
||
// Config structures mirroring the layout of tts.json.

// SpecProcessorConfig describes the mel-spectrogram front end of the
// autoencoder: STFT sizing plus normalization statistics.
type SpecProcessorConfig struct {
	NFFT      int     `json:"n_fft"`      // FFT size
	WinLength int     `json:"win_length"` // analysis window length in samples
	HopLength int     `json:"hop_length"` // hop between frames in samples
	NMels     int     `json:"n_mels"`     // number of mel bands
	Eps       float64 `json:"eps"`        // numerical-stability epsilon
	NormMean  float64 `json:"norm_mean"`  // spectrogram normalization mean
	NormStd   float64 `json:"norm_std"`   // spectrogram normalization std
}

// EncoderConfig wraps the autoencoder's spectrogram-processor settings.
type EncoderConfig struct {
	SpecProcessor SpecProcessorConfig `json:"spec_processor"`
}

// AEConfig holds the autoencoder settings used here for audio sizing
// (sample rate and the base latent chunk size in samples).
type AEConfig struct {
	SampleRate    int           `json:"sample_rate"`
	BaseChunkSize int           `json:"base_chunk_size"`
	Encoder       EncoderConfig `json:"encoder"`
}

// StyleTokenLayerConfig sizes a style-token layer.
type StyleTokenLayerConfig struct {
	NStyle        int `json:"n_style"`
	StyleValueDim int `json:"style_value_dim"`
}

// StyleEncoderConfig wraps the TTL model's style-token layer settings.
type StyleEncoderConfig struct {
	StyleTokenLayer StyleTokenLayerConfig `json:"style_token_layer"`
}

// ProjOutConfig describes an input/output projection size pair.
type ProjOutConfig struct {
	Idim int `json:"idim"`
	Odim int `json:"odim"`
}

// TextEncoderConfig wraps the text encoder's output projection settings.
type TextEncoderConfig struct {
	ProjOut ProjOutConfig `json:"proj_out"`
}

// TTLConfig holds text-to-latent model settings; ChunkCompressFactor and
// LatentDim determine latent tensor sizing in sampleNoisyLatent.
type TTLConfig struct {
	ChunkCompressFactor int                `json:"chunk_compress_factor"`
	LatentDim           int                `json:"latent_dim"`
	StyleEncoder        StyleEncoderConfig `json:"style_encoder"`
	TextEncoder         TextEncoderConfig  `json:"text_encoder"`
}

// DPStyleEncoderConfig wraps the duration predictor's style-token layer.
type DPStyleEncoderConfig struct {
	StyleTokenLayer StyleTokenLayerConfig `json:"style_token_layer"`
}

// DPConfig holds duration-predictor settings.
type DPConfig struct {
	LatentDim           int                  `json:"latent_dim"`
	ChunkCompressFactor int                  `json:"chunk_compress_factor"`
	StyleEncoder        DPStyleEncoderConfig `json:"style_encoder"`
}

// Config is the root of tts.json: autoencoder, text-to-latent, and
// duration-predictor sections.
type Config struct {
	AE  AEConfig  `json:"ae"`
	TTL TTLConfig `json:"ttl"`
	DP  DPConfig  `json:"dp"`
}
|
||
|
||
// VoiceStyleData holds the voice-style JSON structure: two nested float
// arrays ("style_ttl" and "style_dp"), each carrying explicit dims
// (LoadVoiceStyle reads Dims[1] and Dims[2], so rank 3 is expected) and a
// dtype label string.
type VoiceStyleData struct {
	StyleTTL struct {
		Data [][][]float64 `json:"data"` // nested values, shape described by Dims
		Dims []int64       `json:"dims"`
		Type string        `json:"type"` // dtype label — presumably "float32"; TODO confirm
	} `json:"style_ttl"`
	StyleDP struct {
		Data [][][]float64 `json:"data"`
		Dims []int64       `json:"dims"`
		Type string        `json:"type"`
	} `json:"style_dp"`
}
|
||
|
||
// UnicodeProcessor converts text into model token IDs via a flat lookup
// table loaded from unicode_indexer.json (index = Unicode code point).
type UnicodeProcessor struct {
	indexer []int64 // token ID per code point; Call maps out-of-range code points to -1
}
|
||
|
||
// NewUnicodeProcessor creates a new UnicodeProcessor
|
||
func NewUnicodeProcessor(unicodeIndexerPath string) (*UnicodeProcessor, error) {
|
||
indexer, err := loadJSONInt64(unicodeIndexerPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load unicode indexer: %w", err)
|
||
}
|
||
|
||
return &UnicodeProcessor{indexer: indexer}, nil
|
||
}
|
||
|
||
// Call processes text list to text IDs and mask
|
||
func (up *UnicodeProcessor) Call(textList []string, langList []string) ([][]int64, [][][]float64) {
|
||
// Preprocess texts
|
||
processedTexts := make([]string, len(textList))
|
||
for i, text := range textList {
|
||
processedTexts[i] = preprocessText(text, langList[i])
|
||
}
|
||
|
||
// Get text lengths
|
||
textLengths := make([]int64, len(processedTexts))
|
||
maxLen := 0
|
||
for i, text := range processedTexts {
|
||
textLengths[i] = int64(len([]rune(text)))
|
||
if int(textLengths[i]) > maxLen {
|
||
maxLen = int(textLengths[i])
|
||
}
|
||
}
|
||
|
||
// Create text IDs
|
||
textIDs := make([][]int64, len(processedTexts))
|
||
for i, text := range processedTexts {
|
||
row := make([]int64, maxLen)
|
||
runes := []rune(text)
|
||
for j, r := range runes {
|
||
unicodeVal := int(r)
|
||
if unicodeVal < len(up.indexer) {
|
||
row[j] = up.indexer[unicodeVal]
|
||
} else {
|
||
row[j] = -1
|
||
}
|
||
}
|
||
textIDs[i] = row
|
||
}
|
||
|
||
// Create text mask
|
||
textMask := lengthToMask(textLengths, maxLen)
|
||
|
||
return textIDs, textMask
|
||
}
|
||
|
||
// Text chunking utilities

// maxChunkLength is the default byte-length budget for a single TTS chunk
// (note: bytes, not runes — see chunkText).
const maxChunkLength = 300

// abbreviations lists dotted abbreviations that must NOT be treated as
// sentence boundaries by splitSentences.
var abbreviations = []string{
	"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
	"St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
	"Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
}
|
||
|
||
func chunkText(text string, maxLen int) []string {
|
||
if maxLen == 0 {
|
||
maxLen = maxChunkLength
|
||
}
|
||
|
||
text = strings.TrimSpace(text)
|
||
if text == "" {
|
||
return []string{""}
|
||
}
|
||
|
||
// Split by paragraphs
|
||
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(text, -1)
|
||
var chunks []string
|
||
|
||
for _, para := range paragraphs {
|
||
para = strings.TrimSpace(para)
|
||
if para == "" {
|
||
continue
|
||
}
|
||
|
||
if len(para) <= maxLen {
|
||
chunks = append(chunks, para)
|
||
continue
|
||
}
|
||
|
||
// Split by sentences
|
||
sentences := splitSentences(para)
|
||
var current strings.Builder
|
||
currentLen := 0
|
||
|
||
for _, sentence := range sentences {
|
||
sentence = strings.TrimSpace(sentence)
|
||
if sentence == "" {
|
||
continue
|
||
}
|
||
|
||
sentenceLen := len(sentence)
|
||
if sentenceLen > maxLen {
|
||
// If sentence is longer than maxLen, split by comma or space
|
||
if current.Len() > 0 {
|
||
chunks = append(chunks, strings.TrimSpace(current.String()))
|
||
current.Reset()
|
||
currentLen = 0
|
||
}
|
||
|
||
// Try splitting by comma
|
||
parts := strings.Split(sentence, ",")
|
||
for _, part := range parts {
|
||
part = strings.TrimSpace(part)
|
||
if part == "" {
|
||
continue
|
||
}
|
||
|
||
partLen := len(part)
|
||
if partLen > maxLen {
|
||
// Split by space as last resort
|
||
words := strings.Fields(part)
|
||
var wordChunk strings.Builder
|
||
wordChunkLen := 0
|
||
|
||
for _, word := range words {
|
||
wordLen := len(word)
|
||
if wordChunkLen+wordLen+1 > maxLen && wordChunk.Len() > 0 {
|
||
chunks = append(chunks, strings.TrimSpace(wordChunk.String()))
|
||
wordChunk.Reset()
|
||
wordChunkLen = 0
|
||
}
|
||
|
||
if wordChunk.Len() > 0 {
|
||
wordChunk.WriteString(" ")
|
||
wordChunkLen++
|
||
}
|
||
wordChunk.WriteString(word)
|
||
wordChunkLen += wordLen
|
||
}
|
||
|
||
if wordChunk.Len() > 0 {
|
||
chunks = append(chunks, strings.TrimSpace(wordChunk.String()))
|
||
}
|
||
} else {
|
||
if currentLen+partLen+1 > maxLen && current.Len() > 0 {
|
||
chunks = append(chunks, strings.TrimSpace(current.String()))
|
||
current.Reset()
|
||
currentLen = 0
|
||
}
|
||
|
||
if current.Len() > 0 {
|
||
current.WriteString(", ")
|
||
currentLen += 2
|
||
}
|
||
current.WriteString(part)
|
||
currentLen += partLen
|
||
}
|
||
}
|
||
continue
|
||
}
|
||
|
||
if currentLen+sentenceLen+1 > maxLen && current.Len() > 0 {
|
||
chunks = append(chunks, strings.TrimSpace(current.String()))
|
||
current.Reset()
|
||
currentLen = 0
|
||
}
|
||
|
||
if current.Len() > 0 {
|
||
current.WriteString(" ")
|
||
currentLen++
|
||
}
|
||
current.WriteString(sentence)
|
||
currentLen += sentenceLen
|
||
}
|
||
|
||
if current.Len() > 0 {
|
||
chunks = append(chunks, strings.TrimSpace(current.String()))
|
||
}
|
||
}
|
||
|
||
if len(chunks) == 0 {
|
||
return []string{""}
|
||
}
|
||
|
||
return chunks
|
||
}
|
||
|
||
func splitSentences(text string) []string {
|
||
// Go's regexp doesn't support lookbehind, so we use a simpler approach
|
||
// Split on sentence boundaries and then check if they're abbreviations
|
||
re := regexp.MustCompile(`([.!?])\s+`)
|
||
|
||
// Find all matches
|
||
matches := re.FindAllStringIndex(text, -1)
|
||
if len(matches) == 0 {
|
||
return []string{text}
|
||
}
|
||
|
||
var sentences []string
|
||
lastEnd := 0
|
||
|
||
for _, match := range matches {
|
||
// Get the text before the punctuation
|
||
beforePunc := text[lastEnd:match[0]]
|
||
|
||
// Check if this ends with an abbreviation
|
||
isAbbrev := false
|
||
for _, abbrev := range abbreviations {
|
||
if strings.HasSuffix(strings.TrimSpace(beforePunc+text[match[0]:match[0]+1]), abbrev) {
|
||
isAbbrev = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !isAbbrev {
|
||
// This is a real sentence boundary
|
||
sentences = append(sentences, text[lastEnd:match[1]])
|
||
lastEnd = match[1]
|
||
}
|
||
}
|
||
|
||
// Add the remaining text
|
||
if lastEnd < len(text) {
|
||
sentences = append(sentences, text[lastEnd:])
|
||
}
|
||
|
||
if len(sentences) == 0 {
|
||
return []string{text}
|
||
}
|
||
|
||
return sentences
|
||
}
|
||
|
||
// isValidLang checks if a language is in the available languages list
|
||
func isValidLang(lang string) bool {
|
||
for _, l := range AvailableLangs {
|
||
if l == lang {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// Utility functions
|
||
func preprocessText(text string, lang string) string {
|
||
// TODO: Need advanced normalizer for better performance
|
||
// Apply NFKD normalization using golang.org/x/text/unicode/norm
|
||
text = norm.NFKD.String(text)
|
||
|
||
// Remove emojis and various Unicode symbols
|
||
emojiPattern := regexp.MustCompile(`[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+`)
|
||
text = emojiPattern.ReplaceAllString(text, "")
|
||
|
||
// Replace various dashes and symbols
|
||
replacements := map[string]string{
|
||
"–": "-", // en dash
|
||
"‑": "-", // non-breaking hyphen
|
||
"—": "-", // em dash
|
||
"_": " ", // underscore
|
||
"\u201C": "\"", // left double quote
|
||
"\u201D": "\"", // right double quote
|
||
"\u2018": "'", // left single quote
|
||
"\u2019": "'", // right single quote
|
||
"´": "'", // acute accent
|
||
"`": "'", // grave accent
|
||
"[": " ", // left bracket
|
||
"]": " ", // right bracket
|
||
"|": " ", // vertical bar
|
||
"/": " ", // slash
|
||
"#": " ", // hash
|
||
"→": " ", // right arrow
|
||
"←": " ", // left arrow
|
||
}
|
||
|
||
for old, new := range replacements {
|
||
text = strings.ReplaceAll(text, old, new)
|
||
}
|
||
|
||
// Remove special symbols
|
||
specialSymbols := []string{"♥", "☆", "♡", "©", "\\"}
|
||
for _, symbol := range specialSymbols {
|
||
text = strings.ReplaceAll(text, symbol, "")
|
||
}
|
||
|
||
// Replace known expressions
|
||
exprReplacements := map[string]string{
|
||
"@": " at ",
|
||
"e.g.,": "for example, ",
|
||
"i.e.,": "that is, ",
|
||
}
|
||
|
||
for old, new := range exprReplacements {
|
||
text = strings.ReplaceAll(text, old, new)
|
||
}
|
||
|
||
// Fix spacing around punctuation
|
||
text = regexp.MustCompile(` ,`).ReplaceAllString(text, ",")
|
||
text = regexp.MustCompile(` \.`).ReplaceAllString(text, ".")
|
||
text = regexp.MustCompile(` !`).ReplaceAllString(text, "!")
|
||
text = regexp.MustCompile(` \?`).ReplaceAllString(text, "?")
|
||
text = regexp.MustCompile(` ;`).ReplaceAllString(text, ";")
|
||
text = regexp.MustCompile(` :`).ReplaceAllString(text, ":")
|
||
text = regexp.MustCompile(` '`).ReplaceAllString(text, "'")
|
||
|
||
// Remove duplicate quotes
|
||
for strings.Contains(text, `""`) {
|
||
text = strings.ReplaceAll(text, `""`, `"`)
|
||
}
|
||
for strings.Contains(text, "''") {
|
||
text = strings.ReplaceAll(text, "''", "'")
|
||
}
|
||
for strings.Contains(text, "``") {
|
||
text = strings.ReplaceAll(text, "``", "`")
|
||
}
|
||
|
||
// Remove extra spaces
|
||
text = regexp.MustCompile(`\s+`).ReplaceAllString(text, " ")
|
||
text = strings.TrimSpace(text)
|
||
|
||
// If text doesn't end with punctuation, quotes, or closing brackets, add a period
|
||
if text != "" {
|
||
endsWithPunct := regexp.MustCompile(`[.!?;:,'"\x{201C}\x{201D}\x{2018}\x{2019})\]}…。」』】〉》›»]$`)
|
||
if !endsWithPunct.MatchString(text) {
|
||
text += "."
|
||
}
|
||
}
|
||
|
||
// Validate language
|
||
if !isValidLang(lang) {
|
||
panic(fmt.Sprintf("Invalid language: %s. Available: %v", lang, AvailableLangs))
|
||
}
|
||
|
||
// Wrap text with language tags
|
||
text = fmt.Sprintf("<%s>%s</%s>", lang, text, lang)
|
||
|
||
return text
|
||
}
|
||
|
||
// lengthToMask builds a [len(lengths)][1][maxLen] validity mask where entry
// t of row i is 1.0 while t < lengths[i] and 0.0 afterwards.
func lengthToMask(lengths []int64, maxLen int) [][][]float64 {
	mask := make([][][]float64, len(lengths))
	for i, n := range lengths {
		row := make([]float64, maxLen) // zero value already covers the padded tail
		for j := range row {
			if int64(j) < n {
				row[j] = 1.0
			}
		}
		mask[i] = [][]float64{row}
	}
	return mask
}
|
||
|
||
// getTextMask returns the [bsz][1][maxLen] validity mask for a batch of
// text lengths; it is a thin alias over lengthToMask.
func getTextMask(textLengths []int64, maxLen int) [][][]float64 {
	return lengthToMask(textLengths, maxLen)
}
|
||
|
||
func getLatentMask(wavLengths []int64, cfg Config) [][][]float64 {
|
||
baseChunkSize := int64(cfg.AE.BaseChunkSize)
|
||
chunkCompressFactor := int64(cfg.TTL.ChunkCompressFactor)
|
||
latentSize := baseChunkSize * chunkCompressFactor
|
||
|
||
latentLengths := make([]int64, len(wavLengths))
|
||
maxLen := int64(0)
|
||
for i, wavLen := range wavLengths {
|
||
latentLengths[i] = (wavLen + latentSize - 1) / latentSize
|
||
if latentLengths[i] > maxLen {
|
||
maxLen = latentLengths[i]
|
||
}
|
||
}
|
||
|
||
return lengthToMask(latentLengths, int(maxLen))
|
||
}
|
||
|
||
func writeWavFile(filename string, audioData []float64, sampleRate int) error {
|
||
file, err := os.Create(filename)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer file.Close()
|
||
|
||
// Convert float64 to int
|
||
intData := make([]int, len(audioData))
|
||
for i, sample := range audioData {
|
||
// Clamp to [-1, 1] and convert to 16-bit int
|
||
clamped := math.Max(-1.0, math.Min(1.0, sample))
|
||
intData[i] = int(clamped * 32767)
|
||
}
|
||
|
||
encoder := wav.NewEncoder(file, sampleRate, 16, 1, 1)
|
||
buf := &audio.IntBuffer{
|
||
Data: intData,
|
||
Format: &audio.Format{SampleRate: sampleRate, NumChannels: 1},
|
||
SourceBitDepth: 16,
|
||
}
|
||
|
||
if err := encoder.Write(buf); err != nil {
|
||
return err
|
||
}
|
||
|
||
return encoder.Close()
|
||
}
|
||
|
||
// Style holds the two batched style tensors produced by LoadVoiceStyle: one
// conditioning the text-to-latent model and one the duration predictor.
// Call Destroy to release the underlying ONNX buffers.
type Style struct {
	TtlTensor *ort.Tensor[float32] // style input for the TTL model / text encoder
	DpTensor  *ort.Tensor[float32] // style input for the duration predictor
}
|
||
|
||
func (s *Style) Destroy() {
|
||
if s.TtlTensor != nil {
|
||
s.TtlTensor.Destroy()
|
||
}
|
||
if s.DpTensor != nil {
|
||
s.DpTensor.Destroy()
|
||
}
|
||
}
|
||
|
||
// LoadVoiceStyle loads voice style from JSON files
|
||
func LoadVoiceStyle(voiceStylePaths []string, verbose bool) (*Style, error) {
|
||
bsz := len(voiceStylePaths)
|
||
|
||
// Read first file to get dimensions
|
||
firstData, err := os.ReadFile(voiceStylePaths[0])
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read voice style file: %w", err)
|
||
}
|
||
|
||
var firstStyle VoiceStyleData
|
||
if err := json.Unmarshal(firstData, &firstStyle); err != nil {
|
||
return nil, fmt.Errorf("failed to parse voice style JSON: %w", err)
|
||
}
|
||
|
||
ttlDims := firstStyle.StyleTTL.Dims
|
||
dpDims := firstStyle.StyleDP.Dims
|
||
|
||
ttlDim1 := ttlDims[1]
|
||
ttlDim2 := ttlDims[2]
|
||
dpDim1 := dpDims[1]
|
||
dpDim2 := dpDims[2]
|
||
|
||
// Pre-allocate arrays with full batch size
|
||
ttlSize := int(int64(bsz) * ttlDim1 * ttlDim2)
|
||
dpSize := int(int64(bsz) * dpDim1 * dpDim2)
|
||
ttlFlat := make([]float32, ttlSize)
|
||
dpFlat := make([]float32, dpSize)
|
||
|
||
// Fill in the data
|
||
for i := 0; i < bsz; i++ {
|
||
data, err := os.ReadFile(voiceStylePaths[i])
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read voice style file: %w", err)
|
||
}
|
||
|
||
var voiceStyle VoiceStyleData
|
||
if err := json.Unmarshal(data, &voiceStyle); err != nil {
|
||
return nil, fmt.Errorf("failed to parse voice style JSON: %w", err)
|
||
}
|
||
|
||
// Flatten TTL data
|
||
ttlOffset := int(int64(i) * ttlDim1 * ttlDim2)
|
||
idx := 0
|
||
for _, batch := range voiceStyle.StyleTTL.Data {
|
||
for _, row := range batch {
|
||
for _, val := range row {
|
||
ttlFlat[ttlOffset+idx] = float32(val)
|
||
idx++
|
||
}
|
||
}
|
||
}
|
||
|
||
// Flatten DP data
|
||
dpOffset := int(int64(i) * dpDim1 * dpDim2)
|
||
idx = 0
|
||
for _, batch := range voiceStyle.StyleDP.Data {
|
||
for _, row := range batch {
|
||
for _, val := range row {
|
||
dpFlat[dpOffset+idx] = float32(val)
|
||
idx++
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
ttlShape := []int64{int64(bsz), ttlDim1, ttlDim2}
|
||
dpShape := []int64{int64(bsz), dpDim1, dpDim2}
|
||
|
||
ttlTensor, err := ort.NewTensor(ttlShape, ttlFlat)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create TTL tensor: %w", err)
|
||
}
|
||
|
||
dpTensor, err := ort.NewTensor(dpShape, dpFlat)
|
||
if err != nil {
|
||
ttlTensor.Destroy()
|
||
return nil, fmt.Errorf("failed to create DP tensor: %w", err)
|
||
}
|
||
|
||
if verbose {
|
||
fmt.Printf("Loaded %d voice styles\n\n", bsz)
|
||
}
|
||
|
||
return &Style{
|
||
TtlTensor: ttlTensor,
|
||
DpTensor: dpTensor,
|
||
}, nil
|
||
}
|
||
|
||
// TextToSpeech bundles the ONNX sessions and config needed for synthesis:
// duration prediction, text encoding, iterative latent denoising, and
// vocoding. Create via LoadTextToSpeech; release with Destroy.
type TextToSpeech struct {
	cfg           Config
	textProcessor *UnicodeProcessor
	dpOrt         *ort.DynamicAdvancedSession // duration predictor
	textEncOrt    *ort.DynamicAdvancedSession // text encoder
	vectorEstOrt  *ort.DynamicAdvancedSession // denoising vector estimator
	vocoderOrt    *ort.DynamicAdvancedSession // latent -> waveform
	SampleRate    int                         // output sample rate (cfg.AE.SampleRate)
	baseChunkSize int                         // samples per base latent chunk (cfg.AE.BaseChunkSize)
	chunkCompress int                         // latent compression factor (cfg.TTL.ChunkCompressFactor)
	ldim          int                         // latent dim before compression (cfg.TTL.LatentDim)
}
|
||
|
||
// sampleNoisyLatent draws the initial Gaussian latent for the diffusion
// denoiser, sized from the predicted per-item durations (seconds), together
// with the matching [bsz][1][T] validity mask. Padded (masked-out)
// positions are zeroed.
func (tts *TextToSpeech) sampleNoisyLatent(durOnnx []float32) ([][][]float64, [][][]float64) {
	bsz := len(durOnnx)
	// The longest duration in the batch determines the padded latent length.
	maxDur := float64(0)
	for _, d := range durOnnx {
		if float64(d) > maxDur {
			maxDur = float64(d)
		}
	}

	// Per-item waveform lengths in samples (truncated toward zero).
	wavLenMax := maxDur * float64(tts.SampleRate)
	wavLengths := make([]int64, bsz)
	for i, d := range durOnnx {
		wavLengths[i] = int64(float64(d) * float64(tts.SampleRate))
	}

	// One latent frame covers baseChunkSize*chunkCompress samples; the
	// float expression below is a ceiling division of wavLenMax by that,
	// which matches the integer ceiling used by getLatentMask on the
	// truncated wavLengths, so the mask's time axis equals latentLen.
	chunkSize := tts.baseChunkSize * tts.chunkCompress
	latentLen := int((wavLenMax + float64(chunkSize) - 1) / float64(chunkSize))
	latentDim := tts.ldim * tts.chunkCompress

	// Time-seeded RNG: output is intentionally non-deterministic per call.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	noisyLatent := make([][][]float64, bsz)
	for b := 0; b < bsz; b++ {
		batch := make([][]float64, latentDim)
		for d := 0; d < latentDim; d++ {
			row := make([]float64, latentLen)
			for t := 0; t < latentLen; t++ {
				// Box-Muller transform for normal distribution
				// Add epsilon to avoid log(0)
				const eps = 1e-10
				u1 := math.Max(eps, rng.Float64())
				u2 := rng.Float64()
				row[t] = math.Sqrt(-2.0*math.Log(u1)) * math.Cos(2.0*math.Pi*u2)
			}
			batch[d] = row
		}
		noisyLatent[b] = batch
	}

	// Mask derived from the same wavLengths (see latentLen note above).
	latentMask := getLatentMask(wavLengths, tts.cfg)

	// Apply mask: zero the noise in each item's padded tail.
	for b := 0; b < bsz; b++ {
		for d := 0; d < latentDim; d++ {
			for t := 0; t < latentLen; t++ {
				noisyLatent[b][d][t] *= latentMask[b][0][t]
			}
		}
	}

	return noisyLatent, latentMask
}
|
||
|
||
// _infer runs the full synthesis pipeline for a batch: token IDs ->
// duration prediction -> text encoding -> iterative latent denoising ->
// vocoder. It returns the batched (padded) waveform samples and the
// per-item durations in seconds, already divided by speed.
//
// NOTE(review): durOnnx and wav are slices obtained from tensors destroyed
// via defer before this function returns — confirm onnxruntime_go's GetData
// copies into Go-owned memory; otherwise these returned slices may alias
// freed buffers.
func (tts *TextToSpeech) _infer(textList []string, langList []string, style *Style, totalStep int, speed float32) ([]float32, []float32, error) {
	bsz := len(textList)

	// Process text into padded token IDs plus a validity mask.
	textIDs, textMask := tts.textProcessor.Call(textList, langList)
	textIDsShape := []int64{int64(bsz), int64(len(textIDs[0]))}
	textMaskShape := []int64{int64(bsz), 1, int64(len(textMask[0][0]))}

	textIDsTensor := IntArrayToTensor(textIDs, textIDsShape)
	defer textIDsTensor.Destroy()
	textMaskTensor := ArrayToTensor(textMask, textMaskShape)
	defer textMaskTensor.Destroy()

	// Predict duration
	dpOutputs := []ort.Value{nil}
	err := tts.dpOrt.Run(
		[]ort.Value{textIDsTensor, style.DpTensor, textMaskTensor},
		dpOutputs,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run duration predictor: %w", err)
	}
	durTensor := dpOutputs[0].(*ort.Tensor[float32])
	defer durTensor.Destroy()
	durOnnx := durTensor.GetData()

	// Apply speed factor to duration (higher speed -> shorter audio).
	for i := range durOnnx {
		durOnnx[i] /= speed
	}

	// Encode text.
	// NOTE(review): a second IDs tensor is built from the same data as
	// textIDsTensor — looks redundant; confirm whether the first tensor can
	// be reused across Run calls before consolidating.
	textIDsTensor2 := IntArrayToTensor(textIDs, textIDsShape)
	defer textIDsTensor2.Destroy()
	textEncOutputs := []ort.Value{nil}
	err = tts.textEncOrt.Run(
		[]ort.Value{textIDsTensor2, style.TtlTensor, textMaskTensor},
		textEncOutputs,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run text encoder: %w", err)
	}
	textEmbTensor := textEncOutputs[0].(*ort.Tensor[float32])
	defer textEmbTensor.Destroy()

	// Sample the initial Gaussian latent sized from the durations.
	xt, latentMask := tts.sampleNoisyLatent(durOnnx)
	latentShape := []int64{int64(bsz), int64(len(xt[0])), int64(len(xt[0][0]))}
	latentMaskShape := []int64{int64(bsz), 1, int64(len(latentMask[0][0]))}

	// The total-step tensor is constant across the loop; build it once.
	totalStepArray := make([]float32, bsz)
	for b := 0; b < bsz; b++ {
		totalStepArray[b] = float32(totalStep)
	}
	scalarShape := []int64{int64(bsz)}

	totalStepTensor, _ := ort.NewTensor(scalarShape, totalStepArray)
	defer totalStepTensor.Destroy()

	// Denoising loop: each iteration feeds the current latent back in and
	// overwrites xt with the estimator's output.
	for step := 0; step < totalStep; step++ {
		currentStepArray := make([]float32, bsz)
		for b := 0; b < bsz; b++ {
			currentStepArray[b] = float32(step)
		}

		// Per-iteration tensors are destroyed explicitly at the bottom of
		// the loop (a defer here would accumulate until function return).
		// NOTE(review): these leak if Run fails mid-loop — consider cleanup
		// on the error path.
		currentStepTensor, _ := ort.NewTensor(scalarShape, currentStepArray)
		noisyLatentTensor := ArrayToTensor(xt, latentShape)
		latentMaskTensor := ArrayToTensor(latentMask, latentMaskShape)
		textMaskTensor2 := ArrayToTensor(textMask, textMaskShape)

		vectorEstOutputs := []ort.Value{nil}
		err = tts.vectorEstOrt.Run(
			[]ort.Value{noisyLatentTensor, textEmbTensor, style.TtlTensor, latentMaskTensor, textMaskTensor2,
				currentStepTensor, totalStepTensor},
			vectorEstOutputs,
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to run vector estimator: %w", err)
		}

		denoisedTensor := vectorEstOutputs[0].(*ort.Tensor[float32])
		denoisedData := denoisedTensor.GetData()

		// Copy the flat denoised output back into the nested latent.
		idx := 0
		for b := 0; b < bsz; b++ {
			for d := 0; d < len(xt[b]); d++ {
				for t := 0; t < len(xt[b][d]); t++ {
					xt[b][d][t] = float64(denoisedData[idx])
					idx++
				}
			}
		}

		noisyLatentTensor.Destroy()
		latentMaskTensor.Destroy()
		textMaskTensor2.Destroy()
		currentStepTensor.Destroy()
		denoisedTensor.Destroy()
	}

	// Generate the waveform from the final denoised latent.
	finalLatentTensor := ArrayToTensor(xt, latentShape)
	defer finalLatentTensor.Destroy()

	vocoderOutputs := []ort.Value{nil}
	err = tts.vocoderOrt.Run(
		[]ort.Value{finalLatentTensor},
		vocoderOutputs,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run vocoder: %w", err)
	}

	wavBatchTensor := vocoderOutputs[0].(*ort.Tensor[float32])
	defer wavBatchTensor.Destroy()
	wav := wavBatchTensor.GetData()

	return wav, durOnnx, nil
}
|
||
|
||
// Call synthesizes speech from a single text with automatic chunking
|
||
func (tts *TextToSpeech) Call(text string, lang string, style *Style, totalStep int, speed float32, silenceDuration float32) ([]float32, float32, error) {
|
||
maxLen := 300
|
||
if lang == "ko" {
|
||
maxLen = 120
|
||
}
|
||
chunks := chunkText(text, maxLen)
|
||
|
||
var wavCat []float32
|
||
var durCat float32
|
||
|
||
for i, chunk := range chunks {
|
||
wav, duration, err := tts._infer([]string{chunk}, []string{lang}, style, totalStep, speed)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
dur := duration[0]
|
||
wavLen := int(float32(tts.SampleRate) * dur)
|
||
wavChunk := wav[:wavLen]
|
||
|
||
if i == 0 {
|
||
wavCat = wavChunk
|
||
durCat = dur
|
||
} else {
|
||
silenceLen := int(silenceDuration * float32(tts.SampleRate))
|
||
silence := make([]float32, silenceLen)
|
||
|
||
wavCat = append(wavCat, silence...)
|
||
wavCat = append(wavCat, wavChunk...)
|
||
durCat += silenceDuration + dur
|
||
}
|
||
}
|
||
|
||
return wavCat, durCat, nil
|
||
}
|
||
|
||
// Batch synthesizes speech for multiple texts in one batched inference pass.
// Unlike Call, no chunking or silence insertion is applied, so each text
// should already fit the model's length budget. Returns the batched
// waveform samples and per-item durations (seconds) as produced by _infer.
func (tts *TextToSpeech) Batch(textList []string, langList []string, style *Style, totalStep int, speed float32) ([]float32, []float32, error) {
	return tts._infer(textList, langList, style, totalStep, speed)
}
|
||
|
||
func (tts *TextToSpeech) Destroy() {
|
||
if tts.dpOrt != nil {
|
||
tts.dpOrt.Destroy()
|
||
}
|
||
if tts.textEncOrt != nil {
|
||
tts.textEncOrt.Destroy()
|
||
}
|
||
if tts.vectorEstOrt != nil {
|
||
tts.vectorEstOrt.Destroy()
|
||
}
|
||
if tts.vocoderOrt != nil {
|
||
tts.vocoderOrt.Destroy()
|
||
}
|
||
}
|
||
|
||
// LoadTextToSpeech loads TTS components
|
||
func LoadTextToSpeech(onnxDir string, useGPU bool, cfg Config) (*TextToSpeech, error) {
|
||
if useGPU {
|
||
return nil, fmt.Errorf("GPU mode is not supported yet")
|
||
}
|
||
fmt.Println("Using CPU for inference\n")
|
||
|
||
// Load models
|
||
dpPath := filepath.Join(onnxDir, "duration_predictor.onnx")
|
||
textEncPath := filepath.Join(onnxDir, "text_encoder.onnx")
|
||
vectorEstPath := filepath.Join(onnxDir, "vector_estimator.onnx")
|
||
vocoderPath := filepath.Join(onnxDir, "vocoder.onnx")
|
||
|
||
dpOrt, err := ort.NewDynamicAdvancedSession(dpPath, []string{"text_ids", "style_dp", "text_mask"},
|
||
[]string{"duration"}, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load duration predictor: %w", err)
|
||
}
|
||
|
||
textEncOrt, err := ort.NewDynamicAdvancedSession(textEncPath, []string{"text_ids", "style_ttl", "text_mask"},
|
||
[]string{"text_emb"}, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load text encoder: %w", err)
|
||
}
|
||
|
||
vectorEstOrt, err := ort.NewDynamicAdvancedSession(vectorEstPath,
|
||
[]string{"noisy_latent", "text_emb", "style_ttl", "latent_mask", "text_mask", "current_step", "total_step"},
|
||
[]string{"denoised_latent"}, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load vector estimator: %w", err)
|
||
}
|
||
|
||
vocoderOrt, err := ort.NewDynamicAdvancedSession(vocoderPath, []string{"latent"},
|
||
[]string{"wav_tts"}, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load vocoder: %w", err)
|
||
}
|
||
|
||
// Load text processor
|
||
unicodeIndexerPath := filepath.Join(onnxDir, "unicode_indexer.json")
|
||
textProcessor, err := NewUnicodeProcessor(unicodeIndexerPath)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
textToSpeech := &TextToSpeech{
|
||
cfg: cfg,
|
||
textProcessor: textProcessor,
|
||
dpOrt: dpOrt,
|
||
textEncOrt: textEncOrt,
|
||
vectorEstOrt: vectorEstOrt,
|
||
vocoderOrt: vocoderOrt,
|
||
SampleRate: cfg.AE.SampleRate,
|
||
baseChunkSize: cfg.AE.BaseChunkSize,
|
||
chunkCompress: cfg.TTL.ChunkCompressFactor,
|
||
ldim: cfg.TTL.LatentDim,
|
||
}
|
||
|
||
return textToSpeech, nil
|
||
}
|
||
|
||
// InitializeONNXRuntime initializes ONNX Runtime environment
|
||
func InitializeONNXRuntime() error {
|
||
libPath := os.Getenv("ONNXRUNTIME_LIB_PATH")
|
||
if libPath == "" {
|
||
libPath = "/usr/local/lib/libonnxruntime.so"
|
||
if _, err := os.Stat("/usr/local/lib/libonnxruntime.dylib"); err == nil {
|
||
libPath = "/usr/local/lib/libonnxruntime.dylib"
|
||
} else if _, err := os.Stat("/usr/lib/libonnxruntime.so"); err == nil {
|
||
libPath = "/usr/lib/libonnxruntime.so"
|
||
}
|
||
}
|
||
ort.SetSharedLibraryPath(libPath)
|
||
|
||
if err := ort.InitializeEnvironment(); err != nil {
|
||
return fmt.Errorf("failed to initialize ONNX Runtime: %w\nHint: Set ONNXRUNTIME_LIB_PATH environment variable", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// sanitizeFilename truncates text to at most maxLen runes and replaces every
// rune that is not a Unicode letter or digit with '_', yielding a
// filesystem-safe name. Letters and digits from any script are preserved.
func sanitizeFilename(text string, maxLen int) string {
	var b strings.Builder
	count := 0
	for _, r := range text {
		if count >= maxLen {
			break
		}
		count++
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			b.WriteRune(r)
		} else {
			b.WriteRune('_')
		}
	}
	return b.String()
}
|
||
|
||
// extractWavSegment extracts one item's audio from a batched vocoder output.
// The batch is assumed to be evenly partitioned (len(wav)/batchSize samples
// per item); the segment is trimmed to duration seconds at sampleRate. If
// the requested length reaches past the available data, the remainder of
// the returned slice stays zero.
//
// The original computed a clamped wavEnd but never used it, bounds-checking
// inside the loop instead; the dead variable now drives the copy loop.
func extractWavSegment(wav []float32, duration float32, sampleRate int, index int, batchSize int) []float64 {
	wavLen := int(float64(sampleRate) * float64(duration))
	wavPerBatch := len(wav) / batchSize

	wavStart := index * wavPerBatch
	wavEnd := wavStart + wavLen
	if wavEnd > len(wav) {
		wavEnd = len(wav)
	}

	// Copy the segment; positions beyond the available data remain 0.
	wavOut := make([]float64, wavLen)
	for j := wavStart; j < wavEnd; j++ {
		wavOut[j-wavStart] = float64(wav[j])
	}

	return wavOut
}
|
||
|
||
// Timer runs fn, printing its name before and the elapsed wall-clock
// seconds after, and returns fn's result unchanged.
func Timer(name string, fn func() interface{}) interface{} {
	start := time.Now()
	fmt.Printf("%s...\n", name)
	out := fn()
	fmt.Printf(" -> %s completed in %.2f sec\n", name, time.Since(start).Seconds())
	return out
}
|
||
|
||
// LoadCfgs loads configuration from JSON file
|
||
func LoadCfgs(onnxDir string) (Config, error) {
|
||
cfgPath := filepath.Join(onnxDir, "tts.json")
|
||
data, err := os.ReadFile(cfgPath)
|
||
if err != nil {
|
||
return Config{}, err
|
||
}
|
||
|
||
var cfg Config
|
||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||
return Config{}, err
|
||
}
|
||
|
||
return cfg, nil
|
||
}
|
||
|
||
// JSON loading helpers

// loadJSONInt64 parses a JSON file containing a flat array of integers.
func loadJSONInt64(filePath string) ([]int64, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return nil, err
	}

	var values []int64
	if err := json.Unmarshal(data, &values); err != nil {
		return nil, err
	}
	return values, nil
}
|
||
|
||
// Tensor conversion utilities
|
||
func ArrayToTensor(array [][][]float64, shape []int64) *ort.Tensor[float32] {
|
||
// Flatten array
|
||
totalSize := int64(1)
|
||
for _, dim := range shape {
|
||
totalSize *= dim
|
||
}
|
||
|
||
flat := make([]float32, totalSize)
|
||
idx := 0
|
||
for b := 0; b < len(array); b++ {
|
||
for d := 0; d < len(array[b]); d++ {
|
||
for t := 0; t < len(array[b][d]); t++ {
|
||
flat[idx] = float32(array[b][d][t])
|
||
idx++
|
||
}
|
||
}
|
||
}
|
||
|
||
tensor, err := ort.NewTensor(shape, flat)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
return tensor
|
||
}
|
||
|
||
func IntArrayToTensor(array [][]int64, shape []int64) *ort.Tensor[int64] {
|
||
// Flatten array
|
||
totalSize := int64(1)
|
||
for _, dim := range shape {
|
||
totalSize *= dim
|
||
}
|
||
|
||
flat := make([]int64, totalSize)
|
||
idx := 0
|
||
for b := 0; b < len(array); b++ {
|
||
for t := 0; t < len(array[b]); t++ {
|
||
flat[idx] = array[b][t]
|
||
idx++
|
||
}
|
||
}
|
||
|
||
tensor, err := ort.NewTensor(shape, flat)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
return tensor
|
||
}
|