-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGenericSymbolReferenceTree.cs
More file actions
288 lines (281 loc) · 12.7 KB
/
GenericSymbolReferenceTree.cs
File metadata and controls
288 lines (281 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
namespace Monkeymoto.GeneratorUtils
{
    /// <summary>
    /// Represents a collection of closed generic symbols for use with an incremental source generator.
    /// </summary>
    /// <remarks>
    /// <para>
    /// This class discovers generic types and generic methods in your compilation and keeps
    /// <see cref="ISymbol">symbolic references</see> to them and the <see cref="SyntaxNode">syntax</see> that produced
    /// those references. This may create pressure on your compilation in terms of memory or time spent discovering
    /// those symbols.
    /// </para><para>
    /// Open generic symbols are resolved to closed generic symbols by calling
    /// <see cref="GetBranchesBySymbol">GetBranchesBySymbol</see>.
    /// </para>
    /// </remarks>
    public sealed class GenericSymbolReferenceTree : IDisposable
    {
        // Cache of results returned by GetBranchesBySymbol, keyed by the symbol that was queried (a closed symbol,
        // an open symbol, or an original definition).
        private readonly Dictionary<ISymbol, ImmutableArray<GenericSymbolReference>> closedBranches =
            new(SymbolEqualityComparer.Default);
        // Every reference collected by the syntax provider starts here, grouped by the symbol its node referenced
        // (whether that symbol is open or already closed). An entry is removed when GetBranchesBySymbol begins
        // resolving it.
        private readonly Dictionary<ISymbol, HashSet<GenericSymbolReference>> openBranches =
            new(SymbolEqualityComparer.Default);

        /// <summary>
        /// Creates a new tree from an incremental generator initialization context.
        /// </summary>
        /// <remarks>
        /// <para>
        /// The returned tree should <b>not</b> be a long-living object. You should extract the symbol references you
        /// need from the tree and then call <see cref="Dispose">Dispose</see> to free the memory used by the tree.
        /// </para>
        /// </remarks>
        /// <param name="context">The context used to create the new tree.</param>
        /// <returns>An <see cref="IncrementalValueProvider{TValue}"/> which provides the newly created tree.</returns>
        public static IncrementalValueProvider<GenericSymbolReferenceTree>
            FromIncrementalGeneratorInitializationContext
        (
            IncrementalGeneratorInitializationContext context
        )
        {
            // No paths are excluded.
            return FromIncrementalGeneratorInitializationContext(context, static x => false);
        }

        /// <inheritdoc cref="FromIncrementalGeneratorInitializationContext(IncrementalGeneratorInitializationContext)"/>
        /// <param name="excludePathPredicate">
        /// A predicate used to selectively exclude certain file paths from the tree. For example, you may choose to
        /// exclude file paths your generator added to the compilation that include only definitions. The predicate
        /// receives the full file path of each <see cref="SyntaxNode">SyntaxNode</see>'s
        /// <see cref="SyntaxTree">SyntaxTree</see> considered for inclusion in the
        /// <see cref="GenericSymbolReferenceTree">GenericSymbolReferenceTree</see>.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="excludePathPredicate"/> is <c>null</c>.</exception>
        public static IncrementalValueProvider<GenericSymbolReferenceTree>
            FromIncrementalGeneratorInitializationContext
        (
            IncrementalGeneratorInitializationContext context,
            Func<string, bool> excludePathPredicate
        )
        {
            // Fail fast here rather than with a NullReferenceException deep inside the generator pipeline.
            if (excludePathPredicate is null)
            {
                throw new ArgumentNullException(nameof(excludePathPredicate));
            }
            var symbolsProvider = context.SyntaxProvider.CreateSyntaxProvider
            (
                (node, cancellationToken) =>
                {
                    cancellationToken.ThrowIfCancellationRequested();
                    // Cheap syntactic filter: generic names, plus bare identifiers in positions where a generic
                    // method may be referenced with inferred type arguments (arguments, initializers, invocations).
                    return node switch
                    {
                        GenericNameSyntax => true,
                        IdentifierNameSyntax identifierName =>
                            identifierName.Parent switch
                            {
                                ArgumentSyntax or EqualsValueClauseSyntax or InvocationExpressionSyntax => true,
                                MemberAccessExpressionSyntax memberAccessExpression =>
                                    memberAccessExpression.Parent is InvocationExpressionSyntax,
                                _ => false
                            },
                        _ => false
                    } && !excludePathPredicate(node.SyntaxTree.FilePath);
                },
                static (context, cancellationToken) =>
                {
                    cancellationToken.ThrowIfCancellationRequested();
                    return GenericSymbolReference.FromSyntaxNodeInternal
                    (
                        context.Node,
                        context.SemanticModel,
                        cancellationToken
                    );
                }
            );
            return symbolsProvider.Collect()
                .Select
                (
                    static (references, cancellationToken) =>
                        new GenericSymbolReferenceTree(references, cancellationToken)
                );
        }

        /// <summary>
        /// Groups the collected references by the symbol each one refers to. No resolution happens here; it is
        /// deferred to <see cref="GetBranchesBySymbol"/>.
        /// </summary>
        /// <param name="references">The transform results; <c>null</c> entries are skipped.</param>
        /// <param name="cancellationToken">Observed once per reference.</param>
        private GenericSymbolReferenceTree
        (
            ImmutableArray<GenericSymbolReference?> references,
            CancellationToken cancellationToken
        )
        {
            foreach (var reference in references)
            {
                cancellationToken.ThrowIfCancellationRequested();
                if (reference is null)
                {
                    continue;
                }
                if (openBranches.TryGetValue(reference.Symbol, out var set))
                {
                    _ = set.Add(reference);
                }
                else
                {
                    openBranches[reference.Symbol] = [reference];
                }
            }
        }

        /// <summary>
        /// Removes all references in the tree, releasing its memory.
        /// </summary>
        public void Dispose()
        {
            closedBranches.Clear();
            openBranches.Clear();
        }

        /// <summary>
        /// Returns a collection of all branches in the tree that match the given symbol.
        /// </summary>
        /// <remarks>
        /// <para>
        /// If <paramref name="symbol"/> is an open generic symbol, this method will discover all branches that match
        /// <paramref name="symbol"/> after type substitutions. If <paramref name="symbol"/> is the
        /// <see cref="ISymbol.OriginalDefinition">original symbol definition</see>, this method will discover all
        /// branches that share the same original symbol.
        /// </para><para>
        /// If <paramref name="symbol"/> is a closed generic symbol, then the returned collection will only represent
        /// those syntax nodes which reference this closed symbol.
        /// </para><para>
        /// Results are cached per queried symbol. Resolving a symbol consumes its open branch, so the work for each
        /// open symbol is done at most once.
        /// </para>
        /// </remarks>
        /// <param name="symbol">
        /// The generic symbol to find in the tree. A <c>null</c> symbol yields an empty collection.
        /// </param>
        /// <param name="cancellationToken">
        /// The <see cref="CancellationToken"/> that will be observed while searching the tree.
        /// </param>
        /// <returns>
        /// A flattened collection of all branches in the tree that match <paramref name="symbol"/>, regardless of the
        /// syntax node. The returned collection will only contain closed generic symbols.
        /// </returns>
        /// <exception cref="OperationCanceledException">
        /// <paramref name="cancellationToken"/> was canceled during the search.
        /// </exception>
        public IEnumerable<GenericSymbolReference> GetBranchesBySymbol
        (
            ISymbol symbol,
            CancellationToken cancellationToken
        )
        {
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            bool SymbolEquals(ISymbol? other)
            {
                cancellationToken.ThrowIfCancellationRequested();
                return SymbolEqualityComparer.Default.Equals(symbol, other);
            }

            // Must precede the cache lookup: Dictionary.TryGetValue throws ArgumentNullException for a null key.
            if (symbol is null)
            {
                return [];
            }
            if (closedBranches.TryGetValue(symbol, out var branches))
            {
                return branches;
            }
            int typeArgumentCount;
            switch (symbol)
            {
                case IMethodSymbol { IsGenericMethod: false }:
                case INamedTypeSymbol { IsGenericType: false }:
                    return [];
                case ISymbol { IsDefinition: true }:
                    // A definition's branches are the union of every already-resolved branch and every
                    // still-open branch that shares this original definition. Both dictionary snapshots are
                    // materialized (ToImmutableArray) before the recursive calls below mutate the dictionaries.
                    // The definition itself is skipped as an open key: recursing into it would re-enter this
                    // same case with the same open key still present, and never terminate.
                    // NOTE(review): if this definition is reached recursively while one of its open branches is
                    // mid-resolution (already removed from openBranches, not yet added to closedBranches), the
                    // cached result omits that branch's references.
                    var newBranches = closedBranches
                        .Where(x => SymbolEquals(x.Key.OriginalDefinition))
                        .ToImmutableArray()
                        .SelectMany(static x => x.Value)
                        .Concat
                        (
                            openBranches
                                .Where(x => !SymbolEquals(x.Key) && SymbolEquals(x.Key.OriginalDefinition))
                                .ToImmutableArray()
                                .SelectMany(x => GetBranchesBySymbol(x.Key, cancellationToken))
                        );
                    branches = [.. newBranches];
                    closedBranches[symbol] = branches;
                    return branches;
                case IMethodSymbol methodSymbol:
                    typeArgumentCount = methodSymbol.TypeArguments.Length;
                    break;
                case INamedTypeSymbol namedTypeSymbol:
                    typeArgumentCount = namedTypeSymbol.TypeArguments.Length;
                    break;
                default:
                    return [];
            }
            // IsGenericType is also true for a non-generic type nested in a generic type. Such a type has no type
            // arguments of its own to substitute, and Construct would throw on it.
            if (typeArgumentCount == 0 || !openBranches.TryGetValue(symbol, out var openBranch))
            {
                return [];
            }
            // Removed before resolving, so any recursive path that leads back to this symbol (e.g. through a type
            // parameter's containing definition) finds neither an open nor a cached entry and gets an empty result.
            _ = openBranches.Remove(symbol);

            // For each type-argument position, gather every closed type that may be substituted there. ITypeSymbol
            // (not INamedTypeSymbol) is used so array, pointer and dynamic type arguments are accepted.
            var typeArgumentSetList = new List<HashSet<ITypeSymbol>>(typeArgumentCount);
            for (int i = 0; i < typeArgumentCount; ++i)
            {
                typeArgumentSetList.Add(new HashSet<ITypeSymbol>(SymbolEqualityComparer.Default));
            }
            foreach (var reference in openBranch)
            {
                var typeArguments = reference.TypeArguments;
                for (int i = 0; i < typeArgumentCount; ++i)
                {
                    cancellationToken.ThrowIfCancellationRequested();
                    var typeArgument = typeArguments[i];
                    var typeArgumentSet = typeArgumentSetList[i];
                    if (GenericSymbolReference.IsOpenTypeOrMethodSymbol(typeArgument))
                    {
                        typeArgumentSet.UnionWith
                        (
                            typeArgument switch
                            {
                                // A bare type parameter takes whatever its declaring symbol was closed with.
                                ITypeParameterSymbol typeParameter =>
                                    GetBranchesBySymbol(typeParameter.ContainingSymbol, cancellationToken)
                                        .Select(x => (ITypeSymbol)x.TypeArguments[typeParameter.Ordinal]),
                                // Any other open type (e.g. List<T>) is resolved recursively.
                                _ => GetBranchesBySymbol(typeArgument, cancellationToken)
                                    .Select(static x => (ITypeSymbol)x.Symbol)
                            }
                        );
                    }
                    else
                    {
                        _ = typeArgumentSet.Add((ITypeSymbol)typeArgument);
                    }
                }
            }

            // Construct from ConstructedFrom rather than OriginalDefinition. ConstructedFrom keeps type arguments
            // already supplied to containing types (e.g. the int in List<int>.ConvertAll<TOutput>);
            // OriginalDefinition would discard them and produce a symbol that is still open.
            Func<ITypeSymbol[], ISymbol> construct = symbol switch
            {
                IMethodSymbol methodSymbol => methodSymbol.ConstructedFrom.Construct,
                INamedTypeSymbol namedTypeSymbol => namedTypeSymbol.ConstructedFrom.Construct,
                _ => throw new UnreachableException()
            };
            var constructedSymbols = new List<(ISymbol Symbol, ImmutableArray<ITypeSymbol> TypeArguments)>();
            foreach (var typeArgumentList in typeArgumentSetList.CartesianProduct())
            {
                cancellationToken.ThrowIfCancellationRequested();
                var constructedSymbol = construct([.. typeArgumentList]);
                var constructedTypeArguments = constructedSymbol switch
                {
                    IMethodSymbol methodSymbol => methodSymbol.TypeArguments,
                    INamedTypeSymbol namedTypeSymbol => namedTypeSymbol.TypeArguments,
                    _ => throw new UnreachableException()
                };
                constructedSymbols.Add((constructedSymbol, constructedTypeArguments));
            }

            // Every syntax node of the open branch is paired with every closed substitution found above.
            var newReferences = new HashSet<GenericSymbolReference>();
            foreach (var reference in openBranch)
            {
                cancellationToken.ThrowIfCancellationRequested();
                foreach (var (constructedSymbol, constructedTypeArguments) in constructedSymbols)
                {
                    _ = newReferences.Add
                    (
                        new GenericSymbolReference
                        (
                            reference.Node,
                            reference.SemanticModel,
                            constructedSymbol,
                            constructedTypeArguments,
                            cancellationToken
                        )
                    );
                }
            }
            branches = [.. newReferences];
            closedBranches[symbol] = branches;
            return branches;
        }
    }
}