add supertonic

backuppc
2025-12-10 11:39:15 +09:00
parent 3695ab0044
commit 868fa2deec
22 changed files with 54223 additions and 1 deletion


@@ -0,0 +1,881 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
namespace Supertonic
{
// ============================================================================
// Configuration classes
// ============================================================================
public class Config
{
public AEConfig AE { get; set; } = null;
public TTLConfig TTL { get; set; } = null;
public class AEConfig
{
public int SampleRate { get; set; }
public int BaseChunkSize { get; set; }
}
public class TTLConfig
{
public int ChunkCompressFactor { get; set; }
public int LatentDim { get; set; }
}
}
// ============================================================================
// Style class
// ============================================================================
public class Style
{
public float[] Ttl { get; set; }
public long[] TtlShape { get; set; }
public float[] Dp { get; set; }
public long[] DpShape { get; set; }
public Style(float[] ttl, long[] ttlShape, float[] dp, long[] dpShape)
{
Ttl = ttl;
TtlShape = ttlShape;
Dp = dp;
DpShape = dpShape;
}
}
// ============================================================================
// Unicode text processor
// ============================================================================
public class UnicodeProcessor
{
private readonly Dictionary<int, long> _indexer;
public UnicodeProcessor(string unicodeIndexerPath)
{
var json = File.ReadAllText(unicodeIndexerPath);
var indexerArray = JsonSerializer.Deserialize<long[]>(json) ?? throw new InvalidDataException($"Failed to load unicode indexer from {unicodeIndexerPath}");
_indexer = new Dictionary<int, long>();
for (int i = 0; i < indexerArray.Length; i++)
{
_indexer[i] = indexerArray[i];
}
}
private static string RemoveEmojis(string text)
{
var result = new StringBuilder();
for (int i = 0; i < text.Length; i++)
{
int codePoint;
if (char.IsHighSurrogate(text[i]) && i + 1 < text.Length && char.IsLowSurrogate(text[i + 1]))
{
// Get the full code point from surrogate pair
codePoint = char.ConvertToUtf32(text[i], text[i + 1]);
i++; // Skip the low surrogate
}
else
{
codePoint = text[i];
}
// Check if code point is in emoji ranges
bool isEmoji = (codePoint >= 0x1F600 && codePoint <= 0x1F64F) ||
(codePoint >= 0x1F300 && codePoint <= 0x1F5FF) ||
(codePoint >= 0x1F680 && codePoint <= 0x1F6FF) ||
(codePoint >= 0x1F700 && codePoint <= 0x1F77F) ||
(codePoint >= 0x1F780 && codePoint <= 0x1F7FF) ||
(codePoint >= 0x1F800 && codePoint <= 0x1F8FF) ||
(codePoint >= 0x1F900 && codePoint <= 0x1F9FF) ||
(codePoint >= 0x1FA00 && codePoint <= 0x1FA6F) ||
(codePoint >= 0x1FA70 && codePoint <= 0x1FAFF) ||
(codePoint >= 0x2600 && codePoint <= 0x26FF) ||
(codePoint >= 0x2700 && codePoint <= 0x27BF) ||
(codePoint >= 0x1F1E6 && codePoint <= 0x1F1FF);
if (!isEmoji)
{
if (codePoint > 0xFFFF)
{
// Add back as surrogate pair
result.Append(char.ConvertFromUtf32(codePoint));
}
else
{
result.Append((char)codePoint);
}
}
}
return result.ToString();
}
private string PreprocessText(string text)
{
// TODO: Need advanced normalizer for better performance
text = text.Normalize(NormalizationForm.FormKD);
// FIXME: this should be fixed for non-English languages
// Remove emojis (wide Unicode range)
// C# doesn't support \u{...} syntax in regex, so we use character filtering instead
text = RemoveEmojis(text);
// Replace various dashes and symbols
var replacements = new Dictionary<string, string>
{
{"", "-"}, // en dash
{"", "-"}, // non-breaking hyphen
{"—", "-"}, // em dash
{"¯", " "}, // macron
{"_", " "}, // underscore
{"\u201C", "\""}, // left double quote
{"\u201D", "\""}, // right double quote
{"\u2018", "'"}, // left single quote
{"\u2019", "'"}, // right single quote
{"´", "'"}, // acute accent
{"`", "'"}, // grave accent
{"[", " "}, // left bracket
{"]", " "}, // right bracket
{"|", " "}, // vertical bar
{"/", " "}, // slash
{"#", " "}, // hash
{"→", " "}, // right arrow
{"←", " "}, // left arrow
};
foreach (var kvp in replacements)
{
text = text.Replace(kvp.Key, kvp.Value);
}
// Remove combining diacritics // FIXME: this should be fixed for non-English languages
text = Regex.Replace(text, @"[\u0302\u0303\u0304\u0305\u0306\u0307\u0308\u030A\u030B\u030C\u0327\u0328\u0329\u032A\u032B\u032C\u032D\u032E\u032F]", "");
// Remove special symbols
text = Regex.Replace(text, @"[♥☆♡©\\]", "");
// Replace known expressions
var exprReplacements = new Dictionary<string, string>
{
{"@", " at "},
{"e.g.,", "for example, "},
{"i.e.,", "that is, "},
};
foreach (var kvp in exprReplacements)
{
text = text.Replace(kvp.Key, kvp.Value);
}
// Fix spacing around punctuation
text = Regex.Replace(text, @" ,", ",");
text = Regex.Replace(text, @" \.", ".");
text = Regex.Replace(text, @" !", "!");
text = Regex.Replace(text, @" \?", "?");
text = Regex.Replace(text, @" ;", ";");
text = Regex.Replace(text, @" :", ":");
text = Regex.Replace(text, @" '", "'");
// Remove duplicate quotes
while (text.Contains("\"\""))
{
text = text.Replace("\"\"", "\"");
}
while (text.Contains("''"))
{
text = text.Replace("''", "'");
}
while (text.Contains("``"))
{
text = text.Replace("``", "`");
}
// Remove extra spaces
text = Regex.Replace(text, @"\s+", " ").Trim();
// If text doesn't end with punctuation, quotes, or closing brackets, add a period
if (!Regex.IsMatch(text, @"[.!?;:,'\u0022\u201C\u201D\u2018\u2019)\]}…。」』】〉》›»]$"))
{
text += ".";
}
return text;
}
private int[] TextToUnicodeValues(string text)
{
// Operates on UTF-16 code units; astral code points that survive
// preprocessing are indexed as surrogate pairs.
return text.Select(c => (int)c).ToArray();
}
private float[][][] GetTextMask(long[] textIdsLengths)
{
return Helper.LengthToMask(textIdsLengths);
}
public (long[][] textIds, float[][][] textMask) Call(List<string> textList)
{
var processedTexts = textList.Select(t => PreprocessText(t)).ToList();
var textIdsLengths = processedTexts.Select(t => (long)t.Length).ToArray();
long maxLen = textIdsLengths.Max();
var textIds = new long[textList.Count][];
for (int i = 0; i < processedTexts.Count; i++)
{
textIds[i] = new long[maxLen];
var unicodeVals = TextToUnicodeValues(processedTexts[i]);
for (int j = 0; j < unicodeVals.Length; j++)
{
if (_indexer.TryGetValue(unicodeVals[j], out long val))
{
textIds[i][j] = val;
}
}
}
var textMask = GetTextMask(textIdsLengths);
return (textIds, textMask);
}
}
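// Usage sketch (indexer path is hypothetical): maps a batch of raw strings to
// padded ID rows plus a [batch, 1, maxLen] validity mask.
// var proc = new UnicodeProcessor("assets/unicode_indexer.json");
// var (ids, mask) = proc.Call(new List<string> { "Hello world" });
// ids[0] holds indexer-mapped code units, zero-padded to the batch max length.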
// ============================================================================
// TextToSpeech class
// ============================================================================
public class TextToSpeech
{
private readonly Config _cfgs;
private readonly UnicodeProcessor _textProcessor;
private readonly InferenceSession _dpOrt;
private readonly InferenceSession _textEncOrt;
private readonly InferenceSession _vectorEstOrt;
private readonly InferenceSession _vocoderOrt;
public readonly int SampleRate;
private readonly int _baseChunkSize;
private readonly int _chunkCompressFactor;
private readonly int _ldim;
public TextToSpeech(
Config cfgs,
UnicodeProcessor textProcessor,
InferenceSession dpOrt,
InferenceSession textEncOrt,
InferenceSession vectorEstOrt,
InferenceSession vocoderOrt)
{
_cfgs = cfgs;
_textProcessor = textProcessor;
_dpOrt = dpOrt;
_textEncOrt = textEncOrt;
_vectorEstOrt = vectorEstOrt;
_vocoderOrt = vocoderOrt;
SampleRate = cfgs.AE.SampleRate;
_baseChunkSize = cfgs.AE.BaseChunkSize;
_chunkCompressFactor = cfgs.TTL.ChunkCompressFactor;
_ldim = cfgs.TTL.LatentDim;
}
private (float[][][] noisyLatent, float[][][] latentMask) SampleNoisyLatent(float[] duration)
{
int bsz = duration.Length;
float wavLenMax = duration.Max() * SampleRate;
var wavLengths = duration.Select(d => (long)(d * SampleRate)).ToArray();
int chunkSize = _baseChunkSize * _chunkCompressFactor;
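// One latent frame covers chunkSize waveform samples; the division below is a
// ceiling so a partial trailing chunk still gets a frame.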
int latentLen = (int)((wavLenMax + chunkSize - 1) / chunkSize);
int latentDim = _ldim * _chunkCompressFactor;
// Generate random noise
var random = new Random();
var noisyLatent = new float[bsz][][];
for (int b = 0; b < bsz; b++)
{
noisyLatent[b] = new float[latentDim][];
for (int d = 0; d < latentDim; d++)
{
noisyLatent[b][d] = new float[latentLen];
for (int t = 0; t < latentLen; t++)
{
// Box-Muller transform for normal distribution
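// u1 is flipped to (0, 1] so Log never sees zero; sqrt(-2 ln u1) * cos(2*pi*u2)
// is one standard-normal draw.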
double u1 = 1.0 - random.NextDouble();
double u2 = 1.0 - random.NextDouble();
noisyLatent[b][d][t] = (float)(Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2));
}
}
}
var latentMask = Helper.GetLatentMask(wavLengths, _baseChunkSize, _chunkCompressFactor);
// Apply mask
for (int b = 0; b < bsz; b++)
{
for (int d = 0; d < latentDim; d++)
{
for (int t = 0; t < latentLen; t++)
{
noisyLatent[b][d][t] *= latentMask[b][0][t];
}
}
}
return (noisyLatent, latentMask);
}
private (float[] wav, float[] duration) _Infer(List<string> textList, Style style, int totalStep, float speed = 1.05f)
{
int bsz = textList.Count;
if (bsz != style.TtlShape[0])
{
throw new ArgumentException("Number of texts must match number of style vectors");
}
// Process text
var (textIds, textMask) = _textProcessor.Call(textList);
var textIdsShape = new long[] { bsz, textIds[0].Length };
var textMaskShape = new long[] { bsz, 1, textMask[0][0].Length };
var textIdsTensor = Helper.IntArrayToTensor(textIds, textIdsShape);
var textMaskTensor = Helper.ArrayToTensor(textMask, textMaskShape);
var styleTtlTensor = new DenseTensor<float>(style.Ttl, style.TtlShape.Select(x => (int)x).ToArray());
var styleDpTensor = new DenseTensor<float>(style.Dp, style.DpShape.Select(x => (int)x).ToArray());
// Run duration predictor
var dpInputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("text_ids", textIdsTensor),
NamedOnnxValue.CreateFromTensor("style_dp", styleDpTensor),
NamedOnnxValue.CreateFromTensor("text_mask", textMaskTensor)
};
using (var dpOutputs = _dpOrt.Run(dpInputs))
{
var durOnnx = dpOutputs.First(o => o.Name == "duration").AsTensor<float>().ToArray();
// Apply speed factor to duration
for (int i = 0; i < durOnnx.Length; i++)
{
durOnnx[i] /= speed;
}
// Run text encoder
var textEncInputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("text_ids", textIdsTensor),
NamedOnnxValue.CreateFromTensor("style_ttl", styleTtlTensor),
NamedOnnxValue.CreateFromTensor("text_mask", textMaskTensor)
};
using (var textEncOutputs = _textEncOrt.Run(textEncInputs))
{
var textEmbTensor = textEncOutputs.First(o => o.Name == "text_emb").AsTensor<float>();
// Sample noisy latent
var (xt, latentMask) = SampleNoisyLatent(durOnnx);
var latentShape = new long[] { bsz, xt[0].Length, xt[0][0].Length };
var latentMaskShape = new long[] { bsz, 1, latentMask[0][0].Length };
var totalStepArray = Enumerable.Repeat((float)totalStep, bsz).ToArray();
// Iterative denoising
for (int step = 0; step < totalStep; step++)
{
var currentStepArray = Enumerable.Repeat((float)step, bsz).ToArray();
var vectorEstInputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("noisy_latent", Helper.ArrayToTensor(xt, latentShape)),
NamedOnnxValue.CreateFromTensor("text_emb", textEmbTensor),
NamedOnnxValue.CreateFromTensor("style_ttl", styleTtlTensor),
NamedOnnxValue.CreateFromTensor("text_mask", textMaskTensor),
NamedOnnxValue.CreateFromTensor("latent_mask", Helper.ArrayToTensor(latentMask, latentMaskShape)),
NamedOnnxValue.CreateFromTensor("total_step", new DenseTensor<float>(totalStepArray, new int[] { bsz })),
NamedOnnxValue.CreateFromTensor("current_step", new DenseTensor<float>(currentStepArray, new int[] { bsz }))
};
using (var vectorEstOutputs = _vectorEstOrt.Run(vectorEstInputs))
{
var denoisedLatent = vectorEstOutputs.First(o => o.Name == "denoised_latent").AsTensor<float>();
// Update xt
int idx = 0;
for (int b = 0; b < bsz; b++)
{
for (int d = 0; d < xt[b].Length; d++)
{
for (int t = 0; t < xt[b][d].Length; t++)
{
xt[b][d][t] = denoisedLatent.GetValue(idx++);
}
}
}
}
}
// Run vocoder
var vocoderInputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("latent", Helper.ArrayToTensor(xt, latentShape))
};
using (var vocoderOutputs = _vocoderOrt.Run(vocoderInputs))
{
var wavTensor = vocoderOutputs.First(o => o.Name == "wav_tts").AsTensor<float>();
return (wavTensor.ToArray(), durOnnx);
}
}
}
}
public (float[] wav, float[] duration) Call(string text, Style style, int totalStep, float speed = 1.05f, float silenceDuration = 0.3f)
{
if (style.TtlShape[0] != 1)
{
throw new ArgumentException("Single speaker text to speech only supports single style");
}
var textList = Helper.ChunkText(text);
var wavCat = new List<float>();
float durCat = 0.0f;
foreach (var chunk in textList)
{
var (wav, duration) = _Infer(new List<string> { chunk }, style, totalStep, speed);
if (wavCat.Count == 0)
{
wavCat.AddRange(wav);
durCat = duration[0];
}
else
{
int silenceLen = (int)(silenceDuration * SampleRate);
var silence = new float[silenceLen];
wavCat.AddRange(silence);
wavCat.AddRange(wav);
durCat += duration[0] + silenceDuration;
}
}
return (wavCat.ToArray(), new float[] { durCat });
}
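// Example: Call("First paragraph.\n\nSecond paragraph.", style, totalStep: 4)
// synthesizes each chunk separately and joins them with silenceDuration
// (0.3 s by default) of silence; the returned duration includes that silence.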
public (float[] wav, float[] duration) Batch(List<string> textList, Style style, int totalStep, float speed = 1.05f)
{
return _Infer(textList, style, totalStep, speed);
}
}
// ============================================================================
// Helper class with utility functions
// ============================================================================
public static class Helper
{
// ============================================================================
// Utility functions
// ============================================================================
public static float[][][] LengthToMask(long[] lengths, long maxLen = -1)
{
if (maxLen == -1)
{
maxLen = lengths.Max();
}
var mask = new float[lengths.Length][][];
for (int i = 0; i < lengths.Length; i++)
{
mask[i] = new float[1][];
mask[i][0] = new float[maxLen];
for (int j = 0; j < maxLen; j++)
{
mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f;
}
}
return mask;
}
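// Example: LengthToMask(new long[] { 2, 4 }) yields shape [2][1][4]:
// mask[0][0] = { 1, 1, 0, 0 } and mask[1][0] = { 1, 1, 1, 1 }.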
public static float[][][] GetLatentMask(long[] wavLengths, int baseChunkSize, int chunkCompressFactor)
{
int latentSize = baseChunkSize * chunkCompressFactor;
var latentLengths = wavLengths.Select(len => (len + latentSize - 1) / latentSize).ToArray();
return LengthToMask(latentLengths);
}
// ============================================================================
// ONNX model loading
// ============================================================================
public static InferenceSession LoadOnnx(string onnxPath, SessionOptions opts)
{
return new InferenceSession(onnxPath, opts);
}
public static (InferenceSession dp, InferenceSession textEnc, InferenceSession vectorEst, InferenceSession vocoder)
LoadOnnxAll(string onnxDir, SessionOptions opts)
{
var dpPath = Path.Combine(onnxDir, "duration_predictor.onnx");
var textEncPath = Path.Combine(onnxDir, "text_encoder.onnx");
var vectorEstPath = Path.Combine(onnxDir, "vector_estimator.onnx");
var vocoderPath = Path.Combine(onnxDir, "vocoder.onnx");
return (
LoadOnnx(dpPath, opts),
LoadOnnx(textEncPath, opts),
LoadOnnx(vectorEstPath, opts),
LoadOnnx(vocoderPath, opts)
);
}
// ============================================================================
// Configuration loading
// ============================================================================
public static Config LoadCfgs(string onnxDir)
{
var cfgPath = Path.Combine(onnxDir, "tts.json");
var json = File.ReadAllText(cfgPath);
using (var doc = JsonDocument.Parse(json))
{
var root = doc.RootElement;
return new Config
{
AE = new Config.AEConfig
{
SampleRate = root.GetProperty("ae").GetProperty("sample_rate").GetInt32(),
BaseChunkSize = root.GetProperty("ae").GetProperty("base_chunk_size").GetInt32()
},
TTL = new Config.TTLConfig
{
ChunkCompressFactor = root.GetProperty("ttl").GetProperty("chunk_compress_factor").GetInt32(),
LatentDim = root.GetProperty("ttl").GetProperty("latent_dim").GetInt32()
}
};
}
}
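// Expected tts.json layout (inferred from the reads above):
// { "ae": { "sample_rate": ..., "base_chunk_size": ... },
//   "ttl": { "chunk_compress_factor": ..., "latent_dim": ... } }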
public static UnicodeProcessor LoadTextProcessor(string onnxDir)
{
var unicodeIndexerPath = Path.Combine(onnxDir, "unicode_indexer.json");
return new UnicodeProcessor(unicodeIndexerPath);
}
// ============================================================================
// Voice style loading
// ============================================================================
public static Style LoadVoiceStyle(List<string> voiceStylePaths, bool verbose = false)
{
int bsz = voiceStylePaths.Count;
// Read first file to get dimensions
var firstJson = File.ReadAllText(voiceStylePaths[0]);
using (var firstDoc = JsonDocument.Parse(firstJson))
{
var firstRoot = firstDoc.RootElement;
var ttlDims = ParseInt64Array(firstRoot.GetProperty("style_ttl").GetProperty("dims"));
var dpDims = ParseInt64Array(firstRoot.GetProperty("style_dp").GetProperty("dims"));
long ttlDim1 = ttlDims[1];
long ttlDim2 = ttlDims[2];
long dpDim1 = dpDims[1];
long dpDim2 = dpDims[2];
// Pre-allocate arrays with full batch size
int ttlSize = (int)(bsz * ttlDim1 * ttlDim2);
int dpSize = (int)(bsz * dpDim1 * dpDim2);
var ttlFlat = new float[ttlSize];
var dpFlat = new float[dpSize];
// Fill in the data
for (int i = 0; i < bsz; i++)
{
var json = File.ReadAllText(voiceStylePaths[i]);
using (var doc = JsonDocument.Parse(json))
{
var root = doc.RootElement;
// Flatten data
var ttlData3D = ParseFloat3DArray(root.GetProperty("style_ttl").GetProperty("data"));
var ttlDataFlat = new List<float>();
foreach (var batch in ttlData3D)
{
foreach (var row in batch)
{
ttlDataFlat.AddRange(row);
}
}
var dpData3D = ParseFloat3DArray(root.GetProperty("style_dp").GetProperty("data"));
var dpDataFlat = new List<float>();
foreach (var batch in dpData3D)
{
foreach (var row in batch)
{
dpDataFlat.AddRange(row);
}
}
// Copy to pre-allocated array
int ttlOffset = (int)(i * ttlDim1 * ttlDim2);
ttlDataFlat.CopyTo(ttlFlat, ttlOffset);
int dpOffset = (int)(i * dpDim1 * dpDim2);
dpDataFlat.CopyTo(dpFlat, dpOffset);
}
}
var ttlShape = new long[] { bsz, ttlDim1, ttlDim2 };
var dpShape = new long[] { bsz, dpDim1, dpDim2 };
if (verbose)
{
Console.WriteLine($"Loaded {bsz} voice styles");
}
return new Style(ttlFlat, ttlShape, dpFlat, dpShape);
}
}
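// Expected voice-style JSON layout (inferred from the parser): each file holds
// { "style_ttl": { "dims": [1, D1, D2], "data": [[[...]]] }, "style_dp": { ... } },
// and every file in the batch must share the first file's dims.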
private static float[][][] ParseFloat3DArray(JsonElement element)
{
var result = new List<float[][]>();
foreach (var batch in element.EnumerateArray())
{
var batch2D = new List<float[]>();
foreach (var row in batch.EnumerateArray())
{
var rowData = new List<float>();
foreach (var val in row.EnumerateArray())
{
rowData.Add(val.GetSingle());
}
batch2D.Add(rowData.ToArray());
}
result.Add(batch2D.ToArray());
}
return result.ToArray();
}
private static long[] ParseInt64Array(JsonElement element)
{
var result = new List<long>();
foreach (var val in element.EnumerateArray())
{
result.Add(val.GetInt64());
}
return result.ToArray();
}
// ============================================================================
// TextToSpeech loading
// ============================================================================
public static TextToSpeech LoadTextToSpeech(string onnxDir, bool useGpu = false)
{
var opts = new SessionOptions();
if (useGpu)
{
throw new NotImplementedException("GPU mode is not supported yet");
}
else
{
Console.WriteLine("Using CPU for inference");
}
var cfgs = LoadCfgs(onnxDir);
var (dpOrt, textEncOrt, vectorEstOrt, vocoderOrt) = LoadOnnxAll(onnxDir, opts);
var textProcessor = LoadTextProcessor(onnxDir);
return new TextToSpeech(cfgs, textProcessor, dpOrt, textEncOrt, vectorEstOrt, vocoderOrt);
}
// ============================================================================
// WAV file writing
// ============================================================================
public static void WriteWavFile(string filename, float[] audioData, int sampleRate)
{
using (var writer = new BinaryWriter(File.Open(filename, FileMode.Create)))
{
int numChannels = 1;
int bitsPerSample = 16;
int byteRate = sampleRate * numChannels * bitsPerSample / 8;
short blockAlign = (short)(numChannels * bitsPerSample / 8);
int dataSize = audioData.Length * bitsPerSample / 8;
// RIFF header
writer.Write(Encoding.ASCII.GetBytes("RIFF"));
writer.Write(36 + dataSize);
writer.Write(Encoding.ASCII.GetBytes("WAVE"));
// fmt chunk
writer.Write(Encoding.ASCII.GetBytes("fmt "));
writer.Write(16); // fmt chunk size
writer.Write((short)1); // audio format (PCM)
writer.Write((short)numChannels);
writer.Write(sampleRate);
writer.Write(byteRate);
writer.Write(blockAlign);
writer.Write((short)bitsPerSample);
// data chunk
writer.Write(Encoding.ASCII.GetBytes("data"));
writer.Write(dataSize);
// Write audio data
foreach (var sample in audioData)
{
float clamped = Math.Max(-1.0f, Math.Min(1.0f, sample));
short intSample = (short)(clamped * 32767);
writer.Write(intSample);
}
}
}
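// Example: Helper.WriteWavFile("out.wav", samples, 44100) writes mono 16-bit
// PCM, clamping each float sample to [-1, 1] before scaling to Int16.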
// ============================================================================
// Tensor conversion utilities
// ============================================================================
public static DenseTensor<float> ArrayToTensor(float[][][] array, long[] dims)
{
var flat = new List<float>();
foreach (var batch in array)
{
foreach (var row in batch)
{
flat.AddRange(row);
}
}
return new DenseTensor<float>(flat.ToArray(), dims.Select(x => (int)x).ToArray());
}
public static DenseTensor<long> IntArrayToTensor(long[][] array, long[] dims)
{
var flat = new List<long>();
foreach (var row in array)
{
flat.AddRange(row);
}
return new DenseTensor<long>(flat.ToArray(), dims.Select(x => (int)x).ToArray());
}
// ============================================================================
// Timer utility
// ============================================================================
public static T Timer<T>(string name, Func<T> func)
{
var start = DateTime.Now;
Console.WriteLine($"{name}...");
var result = func();
var elapsed = (DateTime.Now - start).TotalSeconds;
Console.WriteLine($" -> {name} completed in {elapsed:F2} sec");
return result;
}
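// Example: var tts = Helper.Timer("Loading models", () => Helper.LoadTextToSpeech(onnxDir));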
public static string SanitizeFilename(string text, int maxLen)
{
var result = new StringBuilder();
int count = 0;
foreach (char c in text)
{
if (count >= maxLen) break;
if (char.IsLetterOrDigit(c))
{
result.Append(c);
}
else
{
result.Append('_');
}
count++;
}
return result.ToString();
}
// ============================================================================
// Chunk text
// ============================================================================
public static List<string> ChunkText(string text, int maxLen = 300)
{
var chunks = new List<string>();
// Split by paragraph (two or more newlines)
var paragraphRegex = new Regex(@"\n\s*\n+");
var paragraphs = paragraphRegex.Split(text.Trim())
.Select(p => p.Trim())
.Where(p => !string.IsNullOrEmpty(p))
.ToList();
// Split by sentence boundaries, excluding abbreviations
var sentenceRegex = new Regex(@"(?<!Mr\.|Mrs\.|Ms\.|Dr\.|Prof\.|Sr\.|Jr\.|Ph\.D\.|etc\.|e\.g\.|i\.e\.|vs\.|Inc\.|Ltd\.|Co\.|Corp\.|St\.|Ave\.|Blvd\.)(?<!\b[A-Z]\.)(?<=[.!?])\s+");
foreach (var paragraph in paragraphs)
{
var sentences = sentenceRegex.Split(paragraph);
string currentChunk = "";
foreach (var sentence in sentences)
{
if (string.IsNullOrEmpty(sentence)) continue;
if (currentChunk.Length + sentence.Length + 1 <= maxLen)
{
if (!string.IsNullOrEmpty(currentChunk))
{
currentChunk += " ";
}
currentChunk += sentence;
}
else
{
if (!string.IsNullOrEmpty(currentChunk))
{
chunks.Add(currentChunk.Trim());
}
currentChunk = sentence;
}
}
if (!string.IsNullOrEmpty(currentChunk))
{
chunks.Add(currentChunk.Trim());
}
}
// If no chunks were created, return the original text
if (chunks.Count == 0)
{
chunks.Add(text.Trim());
}
return chunks;
}
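// Example: ChunkText("Hello world. How are you?", maxLen: 15) returns
// ["Hello world.", "How are you?"], since packing both sentences into one
// chunk would exceed maxLen.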
}
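// ============================================================================
// Usage sketch (illustrative only). Assumes onnxDir contains the four .onnx
// models, tts.json, and unicode_indexer.json loaded above; the voice style
// file name and the totalStep value below are assumptions, not requirements.
// ============================================================================
public static class Example
{
public static void Run(string onnxDir)
{
var tts = Helper.Timer("Loading models", () => Helper.LoadTextToSpeech(onnxDir));
var style = Helper.LoadVoiceStyle(new List<string> { Path.Combine(onnxDir, "voice_style.json") }, verbose: true);
var (wav, duration) = tts.Call("Hello from Supertonic.", style, totalStep: 4);
Helper.WriteWavFile("output.wav", wav, tts.SampleRate);
Console.WriteLine($"Wrote output.wav ({duration[0]:F2} sec of speech)");
}
}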
}