-
-
Notifications
You must be signed in to change notification settings - Fork 163
Expand file tree
/
Copy pathNullSafeExpressionRewriter.cs
More file actions
315 lines (257 loc) · 11.6 KB
/
NullSafeExpressionRewriter.cs
File metadata and controls
315 lines (257 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
using System.Linq.Expressions;
using System.Reflection;
namespace NoEntityFrameworkExample;
/// <summary>
/// Inserts a null check on member dereference and extension method invocation, to prevent a <see cref="NullReferenceException" /> from being thrown when
/// the expression is compiled and executed.
/// </summary>
/// <example>
/// For example, <code><![CDATA[
/// Database.TodoItems.Where(todoItem => todoItem.Assignee.Id == todoItem.Owner.Id)
/// ]]></code> would throw if the database contains a TodoItem that doesn't have an assignee. The rewritten predicate is: <code><![CDATA[
/// todoItem => todoItem.Assignee != null && todoItem.Owner != null && todoItem.Assignee.Id == todoItem.Owner.Id
/// ]]></code>
/// </example>
/// <remarks>
/// Only member chains inside calls to methods named Where, OrderBy, OrderByDescending, ThenBy and ThenByDescending are rewritten. The lambda parameter
/// at the root of a member chain is never null-checked. An instance tracks the enclosing method in a mutable field while visiting, so a single instance
/// must not be used from multiple threads at the same time.
/// </remarks>
public sealed class NullSafeExpressionRewriter : ExpressionVisitor
{
    private const string MinValueName = nameof(long.MinValue);

    private static readonly ConstantExpression Int32MinValueConstant = Expression.Constant(int.MinValue, typeof(int));

    private static readonly ExpressionType[] ComparisonExpressionTypes =
    [
        ExpressionType.LessThan,
        ExpressionType.LessThanOrEqual,
        ExpressionType.GreaterThan,
        ExpressionType.GreaterThanOrEqual,
        ExpressionType.Equal
        // ExpressionType.NotEqual is excluded because WhereClauseBuilder never produces that.
    ];

    // The innermost Where/ordering call that is currently being visited sits on top.
    private readonly Stack<MethodType> _callStack = new();

    /// <summary>
    /// Returns a copy of <paramref name="expression" /> in which member chains are guarded against null dereference.
    /// </summary>
    /// <typeparam name="TExpression">
    /// The static type of the expression, which is preserved in the return value.
    /// </typeparam>
    /// <param name="expression">
    /// The expression tree to rewrite.
    /// </param>
    /// <returns>
    /// The rewritten expression tree.
    /// </returns>
    public TExpression Rewrite<TExpression>(TExpression expression)
        where TExpression : Expression
    {
        // Start from a clean slate, in case an earlier visit was aborted by an exception.
        _callStack.Clear();

        return (TExpression)Visit(expression);
    }

    /// <summary>
    /// Tracks entry into Where/ordering calls and rewrites Count()/Any() invocations that occur inside them.
    /// </summary>
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.Name == "Where")
        {
            return VisitInScope(MethodType.Where, node);
        }

        if (node.Method.Name is "OrderBy" or "OrderByDescending" or "ThenBy" or "ThenByDescending")
        {
            // Ordering can be improved by expanding into multiple OrderBy/ThenBy() calls, as described at
            // https://stackoverflow.com/questions/26186527/linq-order-by-descending-with-null-values-on-bottom/26186585#26186585.
            // For example:
            //   .OrderBy(element => element.First.Second.CharValue)
            // Could be translated to:
            //   .OrderBy(element => element.First != null)
            //   .ThenBy(element => element.First == null ? false : element.First.Second != null)
            //   .ThenBy(element => element.First == null ? '\0' : element.First.Second == null ? '\0' : element.First.Second.CharValue)
            // Which correctly orders 'element.First == null' before 'element.First.Second == null'.
            // The current implementation translates to:
            //   .OrderBy(element => element.First == null ? '\0' : element.First.Second == null ? '\0' : element.First.Second.CharValue)
            // in which the order of these two rows is nondeterministic.
            return VisitInScope(MethodType.Ordering, node);
        }

        if (IsInside(MethodType.Ordering) && node.Method.Name == "Count")
        {
            return ToNullSafeCountInvocationInOrderBy(node);
        }

        if (IsInside(MethodType.Where) && node.Method.Name == "Any")
        {
            return ToNullSafeAnyInvocationInWhere(node);
        }

        return base.VisitMethodCall(node);
    }

    /// <summary>
    /// Visits the children of <paramref name="node" /> while <paramref name="methodType" /> is the active scope.
    /// </summary>
    private Expression VisitInScope(MethodType methodType, MethodCallExpression node)
    {
        _callStack.Push(methodType);

        try
        {
            return base.VisitMethodCall(node);
        }
        finally
        {
            // Always restore the outer scope, even when visiting a child throws.
            _callStack.Pop();
        }
    }

    /// <summary>
    /// Indicates whether the innermost enclosing Where/ordering call is of the specified kind.
    /// </summary>
    private bool IsInside(MethodType methodType)
    {
        return _callStack.Count > 0 && _callStack.Peek() == methodType;
    }

    /// <summary>
    /// Guards a Count() call inside an ordering, yielding <see cref="int.MinValue" /> when any member in the chain is null.
    /// </summary>
    private static Expression ToNullSafeCountInvocationInOrderBy(MethodCallExpression countMethodCall)
    {
        Expression thisArgument = countMethodCall.Arguments.Single();

        if (thisArgument is MemberExpression memberArgument)
        {
            // OrderClauseBuilder never produces nested Count() calls.

            // SRC: some.Other.Children.Count()
            // DST: some.Other == null ? int.MinValue : some.Other.Children == null ? int.MinValue : some.Other.Children.Count()
            return ToConditionalMemberAccessInOrderBy(countMethodCall, memberArgument, Int32MinValueConstant);
        }

        return countMethodCall;
    }

    /// <summary>
    /// Wraps <paramref name="outer" /> in a conditional for each nullable member in the chain that starts at <paramref name="innerMember" />, yielding
    /// <paramref name="defaultValue" /> when that member is null.
    /// </summary>
    private static Expression ToConditionalMemberAccessInOrderBy(Expression outer, MemberExpression innerMember, ConstantExpression defaultValue)
    {
        Expression result = outer;

        // Walk from the innermost member towards the root. The walk ends at the first non-member expression, such as the lambda parameter.
        for (MemberExpression? currentMember = innerMember; currentMember != null; currentMember = currentMember.Expression as MemberExpression)
        {
            // Static property/field invocations can never be null (though unlikely we'll ever encounter those).
            // Members of a non-nullable value type can't be null either, and building a null constant for them would throw.
            if (IsStaticMemberAccess(currentMember) || !CanContainNull(currentMember.Type))
            {
                continue;
            }

            // SRC: first.Second.StringValue
            // DST: first.Second == null ? null : first.Second.StringValue
            ConstantExpression nullConstant = Expression.Constant(null, currentMember.Type);
            BinaryExpression isNull = Expression.Equal(currentMember, nullConstant);
            result = Expression.Condition(isNull, defaultValue, result);
        }

        return result;
    }

    /// <summary>
    /// Indicates whether <paramref name="member" /> accesses a static field or a property with a static public getter.
    /// </summary>
    private static bool IsStaticMemberAccess(MemberExpression member)
    {
        return member.Member switch
        {
            FieldInfo field => field.IsStatic,
            PropertyInfo property => property.GetGetMethod() is { IsStatic: true },
            _ => false
        };
    }

    /// <summary>
    /// Indicates whether a value of <paramref name="type" /> can be null, which is the case for reference types and <see cref="Nullable{T}" />.
    /// </summary>
    private static bool CanContainNull(Type type)
    {
        return !type.IsValueType || Nullable.GetUnderlyingType(type) != null;
    }

    /// <summary>
    /// Guards an Any() call inside a Where, so it evaluates to false when a member in the chain it is invoked on is null.
    /// </summary>
    private Expression ToNullSafeAnyInvocationInWhere(MethodCallExpression anyMethodCall)
    {
        Expression thisArgument = anyMethodCall.Arguments.First();

        if (thisArgument is not MemberExpression memberArgument)
        {
            return anyMethodCall;
        }

        MethodCallExpression newAnyMethodCall = anyMethodCall;

        if (anyMethodCall.Arguments.Count > 1)
        {
            // Rewrite the predicate too, while leaving the source argument as-is (it gets guarded below).

            // SRC: .Any(first => first.Second.Value == 1)
            // DST: .Any(first => first.Second != null && first.Second.Value == 1)
            List<Expression> newArguments = anyMethodCall.Arguments.Skip(1).Select(Visit).Cast<Expression>().ToList();
            newArguments.Insert(0, thisArgument);

            newAnyMethodCall = anyMethodCall.Update(anyMethodCall.Object, newArguments);
        }

        // SRC: some.Other.Any()
        // DST: some.Other != null && some.Other.Any()
        return ToConditionalMemberAccessInBooleanExpression(newAnyMethodCall, memberArgument, false);
    }

    /// <summary>
    /// Prefixes <paramref name="outer" /> with a short-circuiting not-null test for each nullable member in the chain that starts at
    /// <paramref name="innerMember" />.
    /// </summary>
    /// <param name="outer">
    /// The boolean expression to guard.
    /// </param>
    /// <param name="innerMember">
    /// The innermost member access of the chain.
    /// </param>
    /// <param name="skipNullCheckOnLastAccess">
    /// When <c>true</c>, <paramref name="innerMember" /> itself is not tested for null, only the members it is accessed on.
    /// </param>
    private static Expression ToConditionalMemberAccessInBooleanExpression(Expression outer, MemberExpression innerMember, bool skipNullCheckOnLastAccess)
    {
        Expression result = outer;

        // The first member access in the chain is not null-checked, because that's the lambda parameter itself (which is not a MemberExpression).
        // For example, in: item => item.First.Second, 'item' does not require a null-check.
        for (MemberExpression? currentMember = innerMember; currentMember != null; currentMember = currentMember.Expression as MemberExpression)
        {
            // Null-check the last member access in the chain on extension method invocation. For example: a.b.c.Count() requires a null-check on 'c'.
            // This is unneeded for boolean comparisons. For example: a.b.c == d does not require a null-check on 'c'.
            if (skipNullCheckOnLastAccess && currentMember == innerMember)
            {
                continue;
            }

            // Static property/field invocations can never be null (though unlikely we'll ever encounter those).
            // Members of a non-nullable value type can't be null either, and building a null constant for them would throw.
            if (IsStaticMemberAccess(currentMember) || !CanContainNull(currentMember.Type))
            {
                continue;
            }

            // SRC: first.Second.Value == 1
            // DST: first.Second != null && first.Second.Value == 1
            ConstantExpression nullConstant = Expression.Constant(null, currentMember.Type);
            BinaryExpression isNotNull = Expression.NotEqual(currentMember, nullConstant);
            result = Expression.AndAlso(isNotNull, result);
        }

        return result;
    }

    /// <summary>
    /// Guards both operands of a comparison inside a Where. The operands themselves are not visited any further.
    /// </summary>
    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (IsInside(MethodType.Where) && ComparisonExpressionTypes.Contains(node.NodeType))
        {
            // The right term is processed first, so that the guards for the left term end up in front.
            Expression result = node;
            result = ToNullSafeTermInBinary(node.Right, result);
            result = ToNullSafeTermInBinary(node.Left, result);
            return result;
        }

        return base.VisitBinary(node);
    }

    /// <summary>
    /// Prefixes <paramref name="result" /> with the null checks needed to safely evaluate <paramref name="binaryTerm" />.
    /// </summary>
    private static Expression ToNullSafeTermInBinary(Expression binaryTerm, Expression result)
    {
        if (binaryTerm is MemberExpression member)
        {
            // SRC: some.Other.Value == 1
            // DST: some.Other != null && some.Other.Value == 1
            return ToConditionalMemberAccessInBooleanExpression(result, member, true);
        }

        if (binaryTerm is MethodCallExpression { Method.Name: "Count" } countMethodCall)
        {
            Expression thisArgument = countMethodCall.Arguments.Single();

            if (thisArgument is MemberExpression memberArgument)
            {
                // SRC: some.Other.Count() == 1
                // DST: some.Other != null && some.Other.Count() == 1
                return ToConditionalMemberAccessInBooleanExpression(result, memberArgument, false);
            }
        }

        return result;
    }

    /// <summary>
    /// Guards a member chain inside an ordering, yielding a fallback sort value when a member in the chain is null.
    /// </summary>
    protected override Expression VisitMember(MemberExpression node)
    {
        if (!IsInside(MethodType.Ordering))
        {
            return base.VisitMember(node);
        }

        if (node.Expression is MemberExpression innerMember)
        {
            ConstantExpression defaultValue = CreateConstantForMemberIsNull(node.Type);
            return ToConditionalMemberAccessInOrderBy(node, innerMember, defaultValue);
        }

        // Accessed directly on something other than a member (such as the lambda parameter), so there's nothing to guard.
        return node;
    }

    /// <summary>
    /// Creates the constant to sort on when a member in the chain is null: null when <paramref name="type" /> permits that, otherwise the static
    /// MinValue of the type when it has one, otherwise the default value of the type.
    /// </summary>
    private static ConstantExpression CreateConstantForMemberIsNull(Type type)
    {
        if (CanContainNull(type))
        {
            return Expression.Constant(null, type);
        }

        // From here on, the type is known to be a non-nullable value type.
        ConstantExpression? minValueConstant = TryCreateConstantForStaticMinValue(type);

        if (minValueConstant != null)
        {
            return minValueConstant;
        }

        object? defaultValue = Activator.CreateInstance(type);
        return Expression.Constant(defaultValue, type);
    }

    /// <summary>
    /// Creates a constant from the public static MinValue field or property of <paramref name="type" />, or returns <c>null</c> when it has neither.
    /// </summary>
    private static ConstantExpression? TryCreateConstantForStaticMinValue(Type type)
    {
        // Int32.MinValue is a field, while Int128.MinValue is a property.

        FieldInfo? field = type.GetField(MinValueName, BindingFlags.Public | BindingFlags.Static);

        if (field != null)
        {
            object? value = field.GetValue(null);
            return Expression.Constant(value, type);
        }

        PropertyInfo? property = type.GetProperty(MinValueName, BindingFlags.Public | BindingFlags.Static);
        MethodInfo? getter = property?.GetGetMethod();

        if (getter != null)
        {
            object? value = getter.Invoke(null, []);
            return Expression.Constant(value, type);
        }

        return null;
    }

    /// <summary>
    /// The kind of query method whose arguments are being visited.
    /// </summary>
    private enum MethodType
    {
        Where,
        Ordering
    }
}