-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathDiffSingerTensorCache.cs
More file actions
273 lines (250 loc) · 13.4 KB
/
Copy pathDiffSingerTensorCache.cs
File metadata and controls
273 lines (250 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Hashing;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using TuneLab.Foundation;
using TuneLab.SDK;
namespace DiffSingerForTuneLab;
// DiffSinger tensor cache: caches the outputs of one ONNX model invocation on disk, keyed by
// "model file hash + serialized inputs", so repeated synthesis (undo/redo, reopening a project,
// edits that do not touch a given segment, linguistic outputs shared across segments/speakers, ...)
// reuses them instead of recomputing.
// Ported from OpenUtau's DiffSingerCache (same serialization, file format and name-sorted key
// construction). Differences:
//  - hashing uses System.IO.Hashing.XxHash64 (OpenUtau uses K4os.Hash.xxHash); the values differ,
//    but this plugin does not need cross-tool cache compatibility;
//  - the cache directory is the plugin's own UserDataRoot/Cache (OpenUtau uses PathManager.Inst.CachePath);
//  - Save writes atomically (temp file + move), so readers never see a partially written file.
// Extra orchestration helpers: Run (build key -> Load -> on miss model.Run + Save; returns managed
// tensors detached from native memory), Clone (deep-copies native outputs into managed tensors),
// HashFile (model identifier) and EnforceSizeLimit (LRU eviction).
public sealed class DiffSingerTensorCache
{
    const string FormatHeader = "TENSORCACHE";

    readonly ulong mHash;
    readonly string mFilename;

    public ulong Hash => mHash;
    public string Filename => mFilename;

    // Cache directory: "Cache" under the plugin's user-data root (next to Voices/Vocoders).
    public static string CacheDirectory => Path.Combine(DiffSingerDeclarations.UserDataRoot, "Cache");

    // Builds the cache key for one model call: XxHash64 over the model identifier followed by every
    // input serialized in name order (InvariantCulture, as in OpenUtau), so the key does not depend
    // on the order in which the caller listed the inputs.
    // Throws NotSupportedException if an input is not a supported tensor.
    public DiffSingerTensorCache(ulong identifier, IReadOnlyCollection<NamedOnnxValue> inputs)
    {
        using var stream = new MemoryStream();
        using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
        {
            writer.Write(identifier);
            foreach (var onnxValue in inputs.OrderBy(v => v.Name, StringComparer.InvariantCulture))
                SerializeNamedOnnxValue(writer, onnxValue);
        }
        // Hash the written bytes in place rather than copying them out with ToArray().
        mHash = XxHash64.HashToUInt64(stream.GetBuffer().AsSpan(0, checked((int)stream.Length)));
        mFilename = $"ds-{mHash:x16}.tensorcache";
    }

    // ---- Orchestration helpers ----

    // Runs a model through the cache: a hit returns the cached outputs; a miss runs the model and then
    // tries to save the outputs. Every returned tensor is a managed DenseTensor detached from native
    // memory, so callers may read it after this call returns.
    // enabled == false skips the disk cache entirely but still deep-copies the native outputs.
    // A failure to write the cache is logged and ignored: the computed result is still returned.
    public static IReadOnlyList<NamedOnnxValue> Run(
        InferenceSession model, ulong identifier, IReadOnlyCollection<NamedOnnxValue> inputs, bool enabled)
    {
        if (!enabled)
        {
            using var raw = model.Run(inputs);
            return Clone(raw);
        }

        var cache = new DiffSingerTensorCache(identifier, inputs);
        var loaded = cache.Load();
        if (loaded != null)
            return loaded;

        using var run = model.Run(inputs);
        // Detach from native memory first; Save then only reads managed tensors.
        var result = Clone(run);
        try
        {
            cache.Save(result);
        }
        catch (Exception e)
        {
            // Caching is an optimization: a full disk or a locked file must not fail synthesis.
            TuneLabContext.Global.GetLogger().Warning($"DiffSinger:写入缓存 {cache.Filename} 失败、本次结果不落盘:{e.Message}");
        }
        return result;
    }

    // Deep-copies (possibly native-OrtValue-backed) outputs into managed DenseTensors, so they stay
    // readable after the native collection is disposed. It reuses the serializer's type dispatch by
    // round-tripping each value through a MemoryStream instead of duplicating the per-type code.
    public static List<NamedOnnxValue> Clone(IEnumerable<NamedOnnxValue> values)
    {
        var list = new List<NamedOnnxValue>();
        foreach (var v in values)
        {
            using var ms = new MemoryStream();
            using (var w = new BinaryWriter(ms, Encoding.UTF8, leaveOpen: true))
                SerializeNamedOnnxValue(w, v);
            ms.Position = 0;
            using var r = new BinaryReader(ms);
            list.Add(DeserializeNamedOnnxValue(r));
        }
        return list;
    }

    // Model identifier: XxHash64 of the .onnx file contents, computed by streaming the file rather
    // than loading it whole. Used as part of the cache key so different weights never share a key,
    // and replacing the model file automatically invalidates its old entries.
    public static ulong HashFile(string path)
    {
        var h = new XxHash64();
        using var fs = File.OpenRead(path);
        h.Append(fs);
        return h.GetCurrentHashAsUInt64();
    }

    // LRU size cap: while the cache directory exceeds maxSizeMb, delete the *.tensorcache files with
    // the oldest last-access time until the total drops back under the limit.
    // maxSizeMb <= 0 means "no limit".
    // Best-effort: every exception is deliberately swallowed, because a failed eviction must never
    // affect synthesis.
    public static void EnforceSizeLimit(long maxSizeMb)
    {
        if (maxSizeMb <= 0)
            return;
        try
        {
            var dir = CacheDirectory;
            if (!Directory.Exists(dir))
                return;
            var files = new DirectoryInfo(dir).GetFiles("*.tensorcache");
            long total = files.Sum(f => f.Length);
            long limit = maxSizeMb * 1024L * 1024L;
            if (total <= limit)
                return;
            foreach (var f in files.OrderBy(f => f.LastAccessTimeUtc))
            {
                // Only subtract a file's size once it has actually been deleted.
                try { long len = f.Length; f.Delete(); total -= len; } catch { }
                if (total <= limit)
                    break;
            }
        }
        catch { }
    }

    // Loads the cached outputs for this key.
    // Returns null on a miss, including a file that vanished after the existence check (e.g. evicted).
    // Also returns null when the file is unreadable or corrupt; it is then deleted, so the next call
    // recomputes and rewrites it.
    // A hit refreshes the file's last-access time so eviction follows real use.
    public IReadOnlyList<NamedOnnxValue>? Load()
    {
        var cachePath = Path.Join(CacheDirectory, mFilename);
        if (!File.Exists(cachePath))
            return null;
        var result = new List<NamedOnnxValue>();
        try
        {
            using var stream = new FileStream(cachePath, FileMode.Open, FileAccess.Read);
            using var reader = new BinaryReader(stream);
            if (reader.ReadString() != FormatHeader)
                throw new InvalidDataException($"[TensorCache] 缓存文件头异常:{mFilename}。");
            var count = reader.ReadInt32();
            // A negative count would otherwise be accepted as an empty (bogus) hit.
            if (count < 0)
                throw new InvalidDataException($"[TensorCache] 缓存条目数异常:{count}。");
            for (var i = 0; i < count; ++i)
                result.Add(DeserializeNamedOnnxValue(reader));
        }
        catch (Exception e) when (e is FileNotFoundException or DirectoryNotFoundException)
        {
            // Removed between File.Exists and the open: an ordinary miss, nothing to clean up.
            return null;
        }
        catch (Exception e)
        {
            TuneLabContext.Global.GetLogger().Warning($"DiffSinger:反序列化缓存 {mFilename} 失败、丢弃重算:{e.Message}");
            Delete();
            return null;
        }
        // A hit counts as an access. Set the time explicitly so LRU eviction does not depend on the
        // filesystem's automatic last-access policy (e.g. NTFS).
        try { File.SetLastAccessTimeUtc(cachePath, DateTime.UtcNow); } catch { }
        return result;
    }

    // Best-effort removal of this key's cache file; errors are ignored.
    public void Delete()
    {
        var cachePath = Path.Join(CacheDirectory, mFilename);
        if (File.Exists(cachePath))
        {
            try { File.Delete(cachePath); } catch { }
        }
    }

    // Writes the outputs under this key. The data goes to a uniquely named temporary file in the
    // cache directory, which is then moved over the final name. A crash mid-write or a concurrent
    // Load therefore never sees a half-written file. On failure the temporary file is removed and
    // the exception propagates; Run treats it as non-fatal.
    public void Save(IReadOnlyCollection<NamedOnnxValue> outputs)
    {
        Directory.CreateDirectory(CacheDirectory);
        var cachePath = Path.Join(CacheDirectory, mFilename);
        // Same directory as the target, so the final move is a rename on the same volume.
        var tempPath = $"{cachePath}.{Guid.NewGuid():N}.tmp";
        try
        {
            using (var stream = new FileStream(tempPath, FileMode.CreateNew, FileAccess.Write))
            using (var writer = new BinaryWriter(stream))
            {
                writer.Write(FormatHeader);
                writer.Write(outputs.Count);
                foreach (var onnxValue in outputs)
                    SerializeNamedOnnxValue(writer, onnxValue);
            }
            File.Move(tempPath, cachePath, overwrite: true);
        }
        catch
        {
            try { File.Delete(tempPath); } catch { }
            throw;
        }
    }

    // Wire format of one value:
    //   name (string), element type (int), rank (int), dims (int x rank), element count (int),
    //   then the payload.
    // The payload is either the raw element bytes or, for string tensors, one string per element.
    // Only tensor values are supported.
    static void SerializeNamedOnnxValue(BinaryWriter writer, NamedOnnxValue namedOnnxValue)
    {
        if (namedOnnxValue.ValueType != OnnxValueType.ONNX_TYPE_TENSOR)
            throw new NotSupportedException(
                $"[TensorCache] 仅支持张量类型 {OnnxValueType.ONNX_TYPE_TENSOR},遇 {namedOnnxValue.ValueType}。");
        writer.Write(namedOnnxValue.Name);
        var tensorBase = (TensorBase)namedOnnxValue.Value;
        var elementType = tensorBase.GetTypeInfo().ElementType;
        writer.Write((int)elementType);
        switch (elementType)
        {
            case TensorElementType.Float: SerializeTensor(writer, namedOnnxValue.AsTensor<float>()); break;
            case TensorElementType.UInt8: SerializeTensor(writer, namedOnnxValue.AsTensor<byte>()); break;
            case TensorElementType.Int8: SerializeTensor(writer, namedOnnxValue.AsTensor<sbyte>()); break;
            case TensorElementType.UInt16: SerializeTensor(writer, namedOnnxValue.AsTensor<ushort>()); break;
            case TensorElementType.Int16: SerializeTensor(writer, namedOnnxValue.AsTensor<short>()); break;
            case TensorElementType.Int32: SerializeTensor(writer, namedOnnxValue.AsTensor<int>()); break;
            case TensorElementType.Int64: SerializeTensor(writer, namedOnnxValue.AsTensor<long>()); break;
            case TensorElementType.String: SerializeStringTensor(writer, namedOnnxValue.AsTensor<string>()); break;
            case TensorElementType.Bool: SerializeTensor(writer, namedOnnxValue.AsTensor<bool>()); break;
            case TensorElementType.Float16: SerializeTensor(writer, namedOnnxValue.AsTensor<Float16>()); break;
            case TensorElementType.Double: SerializeTensor(writer, namedOnnxValue.AsTensor<double>()); break;
            case TensorElementType.UInt32: SerializeTensor(writer, namedOnnxValue.AsTensor<uint>()); break;
            case TensorElementType.UInt64: SerializeTensor(writer, namedOnnxValue.AsTensor<ulong>()); break;
            case TensorElementType.BFloat16: SerializeTensor(writer, namedOnnxValue.AsTensor<BFloat16>()); break;
            default:
                throw new NotSupportedException($"[TensorCache] 不支持的张量元素类型:{elementType}。");
        }
    }

    // Writes rank, dimensions and element count, and returns the element count.
    // Reversed-stride tensors are rejected: the reader always rebuilds a default-stride DenseTensor,
    // so their flat element order would not round-trip.
    static int WriteTensorHeader<T>(BinaryWriter writer, Tensor<T> tensor)
    {
        if (tensor.IsReversedStride)
            throw new NotSupportedException("[TensorCache] 不支持反序步幅张量。");
        writer.Write(tensor.Rank);
        foreach (var dim in tensor.Dimensions)
            writer.Write(dim);
        var size = checked((int)tensor.Length);
        writer.Write(size);
        return size;
    }

    // Numeric/bool/half-precision payload: the elements' raw in-memory bytes.
    // MemoryMarshal.AsBytes works for any unmanaged element type. Buffer.BlockCopy, by contrast,
    // accepts only arrays of primitives and so threw for the Float16/BFloat16 structs. The bytes
    // written are identical to the former BlockCopy output, so existing cache files remain valid.
    static void SerializeTensor<T>(BinaryWriter writer, Tensor<T> tensor) where T : unmanaged
    {
        WriteTensorHeader(writer, tensor);
        ReadOnlySpan<T> elements = tensor.ToArray();
        writer.Write(MemoryMarshal.AsBytes(elements));
    }

    // String payload: one length-prefixed string per element; null elements are written as "".
    static void SerializeStringTensor(BinaryWriter writer, Tensor<string> tensor)
    {
        WriteTensorHeader(writer, tensor);
        foreach (var element in tensor.ToArray())
            writer.Write(element ?? string.Empty);
    }

    // Reads one value written by SerializeNamedOnnxValue into a managed DenseTensor.
    // Throws on unknown element types or malformed/truncated data; Load turns those into a discard.
    static NamedOnnxValue DeserializeNamedOnnxValue(BinaryReader reader)
    {
        var name = reader.ReadString();
        var dtype = (TensorElementType)reader.ReadInt32();
        var rank = reader.ReadInt32();
        if (rank < 0)
            throw new InvalidDataException($"[TensorCache] 张量秩异常:{rank}。");
        var shape = new int[rank];
        for (var i = 0; i < rank; ++i)
            shape[i] = reader.ReadInt32();
        var size = reader.ReadInt32();
        if (size < 0)
            throw new InvalidDataException($"[TensorCache] 张量元素数异常:{size}。");
        return dtype switch
        {
            TensorElementType.Float => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<float>(reader, size, shape)),
            TensorElementType.UInt8 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<byte>(reader, size, shape)),
            TensorElementType.Int8 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<sbyte>(reader, size, shape)),
            TensorElementType.UInt16 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<ushort>(reader, size, shape)),
            TensorElementType.Int16 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<short>(reader, size, shape)),
            TensorElementType.Int32 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<int>(reader, size, shape)),
            TensorElementType.Int64 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<long>(reader, size, shape)),
            TensorElementType.String => NamedOnnxValue.CreateFromTensor(name, DeserializeStringTensor(reader, size, shape)),
            TensorElementType.Bool => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<bool>(reader, size, shape)),
            TensorElementType.Float16 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<Float16>(reader, size, shape)),
            TensorElementType.Double => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<double>(reader, size, shape)),
            TensorElementType.UInt32 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<uint>(reader, size, shape)),
            TensorElementType.UInt64 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<ulong>(reader, size, shape)),
            TensorElementType.BFloat16 => NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<BFloat16>(reader, size, shape)),
            _ => throw new NotSupportedException($"[TensorCache] 不支持的张量元素类型:{dtype}。"),
        };
    }

    // Reads `size` elements of raw bytes directly into a T[] and wraps it as a default-stride DenseTensor.
    static Tensor<T> DeserializeTensor<T>(BinaryReader reader, int size, ReadOnlySpan<int> shape) where T : unmanaged
    {
        var data = new T[size];
        var target = MemoryMarshal.AsBytes(data.AsSpan());
        var bytes = reader.ReadBytes(target.Length);
        // ReadBytes returns a short array at end of stream instead of throwing. Treat truncation as an
        // error so Load discards the file instead of returning zero-padded tensors as a cache hit.
        if (bytes.Length != target.Length)
            throw new EndOfStreamException($"[TensorCache] 张量数据截断:期望 {target.Length} 字节,实得 {bytes.Length}。");
        bytes.AsSpan().CopyTo(target);
        return new DenseTensor<T>(data, shape);
    }

    // Reads `size` length-prefixed strings and wraps them in a DenseTensor of the given shape.
    static Tensor<string> DeserializeStringTensor(BinaryReader reader, int size, ReadOnlySpan<int> shape)
    {
        var data = new string[size];
        for (var i = 0; i < size; ++i)
            data[i] = reader.ReadString();
        return new DenseTensor<string>(data, shape);
    }
}