-
-
Notifications
You must be signed in to change notification settings - Fork 163
Expand file tree
/
Copy pathWhereClauseBuilder.cs
More file actions
249 lines (195 loc) · 9.27 KB
/
WhereClauseBuilder.cs
File metadata and controls
249 lines (195 loc) · 9.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
using System.Collections;
using System.Linq.Expressions;
using JetBrains.Annotations;
using JsonApiDotNetCore.Configuration;
using JsonApiDotNetCore.Errors;
using JsonApiDotNetCore.Queries.Expressions;
using JsonApiDotNetCore.Resources;
using JsonApiDotNetCore.Resources.Annotations;
namespace JsonApiDotNetCore.Queries.QueryableBuilding;
/// <inheritdoc cref="IWhereClauseBuilder" />
[PublicAPI]
public class WhereClauseBuilder : QueryClauseBuilder, IWhereClauseBuilder
{
    // Untyped null literal; comparisons convert it to the resolved operand type.
    private static readonly ConstantExpression NullConstant = Expression.Constant(null);

    /// <summary>
    /// Builds a predicate lambda from <paramref name="filter" /> over the current lambda scope parameter and applies it to
    /// <c>context.Source</c> through the "Where" method of <c>context.ExtensionType</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="filter" /> is <c>null</c>.</exception>
    public virtual Expression ApplyWhere(FilterExpression filter, QueryClauseBuilderContext context)
    {
        ArgumentNullException.ThrowIfNull(filter);

        LambdaExpression predicate = BuildPredicate(filter, context);
        Type[] typeArguments = [context.LambdaScope.Parameter.Type];

        return Expression.Call(context.ExtensionType, "Where", typeArguments, context.Source, predicate);
    }

    /// <summary>
    /// Visits <paramref name="filter" /> and wraps the resulting body in a lambda whose single parameter is the scope parameter of
    /// <paramref name="context" />.
    /// </summary>
    private LambdaExpression BuildPredicate(FilterExpression filter, QueryClauseBuilderContext context)
    {
        Expression predicateBody = Visit(filter, context);
        ParameterExpression lambdaParameter = context.LambdaScope.Parameter;

        return Expression.Lambda(predicateBody, lambdaParameter);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Produces <c>Enumerable.Any(collection)</c>, or <c>Enumerable.Any(collection, predicate)</c> when a nested filter is present. The nested
    /// predicate is built in a fresh lambda scope over the collection element type.
    /// </remarks>
    public override Expression VisitHas(HasExpression expression, QueryClauseBuilderContext context)
    {
        Expression collection = Visit(expression.TargetCollection, context);

        Type elementType = CollectionConverter.Instance.FindCollectionElementType(collection.Type) ??
            throw new InvalidOperationException("Expression must be a collection.");

        if (expression.Filter == null)
        {
            return Expression.Call(typeof(Enumerable), "Any", [elementType], collection);
        }

        // The last field in the chain is cast to a to-many relationship; its right side describes the element resources.
        ResourceType elementResourceType = ((HasManyAttribute)expression.TargetCollection.Fields[^1]).RightType;
        LambdaExpression nestedPredicate;

        using (LambdaScope elementScope = context.LambdaScopeFactory.CreateScope(elementType))
        {
            var nestedContext = new QueryClauseBuilderContext(collection, elementResourceType, typeof(Enumerable), context.EntityModel,
                context.LambdaScopeFactory, elementScope, context.QueryableBuilder, context.State);

            nestedPredicate = BuildPredicate(expression.Filter, nestedContext);
        }

        return Expression.Call(typeof(Enumerable), "Any", [elementType], collection, nestedPredicate);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Emits a type test against the derived CLR type. When a child filter exists, it is visited against the target cast to that derived type
    /// and combined with the type test using <c>AndAlso</c>.
    /// </remarks>
    public override Expression VisitIsType(IsTypeExpression expression, QueryClauseBuilderContext context)
    {
        Expression target = expression.TargetToOneRelationship is null
            ? context.LambdaScope.Accessor
            : Visit(expression.TargetToOneRelationship, context);

        Type derivedClrType = expression.DerivedType.ClrType;
        TypeBinaryExpression isDerived = Expression.TypeIs(target, derivedClrType);

        if (expression.Child is null)
        {
            return isDerived;
        }

        UnaryExpression castToDerived = Expression.Convert(target, derivedClrType);
        QueryClauseBuilderContext derivedContext = context.WithLambdaScope(context.LambdaScope.WithAccessor(castToDerived));
        Expression childFilter = Visit(expression.Child, derivedContext);

        return Expression.AndAlso(isDerived, childFilter);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Calls <c>StartsWith</c>, <c>EndsWith</c> or (for any other match kind) <c>Contains</c> on the string-typed target.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The visited target is not of type <see cref="string" />.</exception>
    public override Expression VisitMatchText(MatchTextExpression expression, QueryClauseBuilderContext context)
    {
        Expression target = Visit(expression.MatchTarget, context);

        if (target.Type != typeof(string))
        {
            throw new InvalidOperationException("Expression must be a string.");
        }

        Expression searchText = Visit(expression.TextValue, context);

        string stringMethodName = expression.MatchKind switch
        {
            TextMatchKind.StartsWith => "StartsWith",
            TextMatchKind.EndsWith => "EndsWith",
            _ => "Contains"
        };

        return Expression.Call(target, stringMethodName, null, searchText);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Collects the typed constant values into a <c>List&lt;T&gt;</c> (where <c>T</c> is the target's type) and emits
    /// <c>Enumerable.Contains(list, target)</c>.
    /// </remarks>
    public override Expression VisitAny(AnyExpression expression, QueryClauseBuilderContext context)
    {
        Expression target = Visit(expression.MatchTarget, context);

        Type listType = typeof(List<>).MakeGenericType(target.Type);
        var candidateValues = (IList)Activator.CreateInstance(listType)!;

        foreach (LiteralConstantExpression literal in expression.Constants)
        {
            candidateValues.Add(literal.TypedValue);
        }

        ConstantExpression candidates = Expression.Constant(candidateValues);

        return Expression.Call(typeof(Enumerable), "Contains", [target.Type], candidates, target);
    }

    /// <inheritdoc />
    /// <remarks>
    /// All terms are visited first (in order), then folded left-to-right with <c>AndAlso</c> or <c>OrElse</c>.
    /// </remarks>
    public override Expression VisitLogical(LogicalExpression expression, QueryClauseBuilderContext context)
    {
        var pendingTerms = new Queue<Expression>();

        foreach (var term in expression.Terms)
        {
            pendingTerms.Enqueue(Visit(term, context));
        }

        return expression.Operator switch
        {
            LogicalOperator.And => FoldLeft(pendingTerms, Expression.AndAlso),
            LogicalOperator.Or => FoldLeft(pendingTerms, Expression.OrElse),
            _ => throw new InvalidOperationException($"Unknown logical operator '{expression.Operator}'.")
        };
    }

    /// <summary>
    /// Combines the queued terms left-to-right with <paramref name="applyOperator" />, consuming the queue. At least two terms must be queued;
    /// otherwise the queue's <c>Dequeue</c> throws.
    /// </summary>
    private static BinaryExpression FoldLeft(Queue<Expression> pendingTerms, Func<Expression, Expression, BinaryExpression> applyOperator)
    {
        // C# evaluates arguments left-to-right, so the first dequeued term becomes the left operand.
        BinaryExpression accumulated = applyOperator(pendingTerms.Dequeue(), pendingTerms.Dequeue());

        while (pendingTerms.TryDequeue(out Expression? nextTerm))
        {
            accumulated = applyOperator(accumulated, nextTerm);
        }

        return accumulated;
    }

    /// <inheritdoc />
    public override Expression VisitNot(NotExpression expression, QueryClauseBuilderContext context) =>
        Expression.Not(Visit(expression.Child, context));

    /// <inheritdoc />
    /// <remarks>
    /// Both operands are converted to a common type before the binary comparison node is created.
    /// </remarks>
    /// <exception cref="InvalidQueryException">An operand cannot be converted to the common type.</exception>
    public override Expression VisitComparison(ComparisonExpression expression, QueryClauseBuilderContext context)
    {
        Type operandType = ResolveCommonType(expression.Left, expression.Right, context);

        Expression leftOperand = ConvertIfNeeded(Visit(expression.Left, context), operandType);
        Expression rightOperand = ConvertIfNeeded(Visit(expression.Right, context), operandType);

        switch (expression.Operator)
        {
            case ComparisonOperator.Equals:
                return Expression.Equal(leftOperand, rightOperand);
            case ComparisonOperator.LessThan:
                return Expression.LessThan(leftOperand, rightOperand);
            case ComparisonOperator.LessOrEqual:
                return Expression.LessThanOrEqual(leftOperand, rightOperand);
            case ComparisonOperator.GreaterThan:
                return Expression.GreaterThan(leftOperand, rightOperand);
            case ComparisonOperator.GreaterOrEqual:
                return Expression.GreaterThanOrEqual(leftOperand, rightOperand);
            default:
                throw new InvalidOperationException($"Unknown comparison operator '{expression.Operator}'.");
        }
    }

    /// <summary>
    /// Picks the type both comparison operands are converted to:
    /// <list type="bullet">
    /// <item>the left operand's type, when that type can contain null;</item>
    /// <item><c>Nullable&lt;left type&gt;</c>, when the right operand is the null literal;</item>
    /// <item>the right operand's type, when the right operand is a count (<see cref="int" />) or a field chain and that type can contain null;</item>
    /// <item>the left operand's type otherwise.</item>
    /// </list>
    /// </summary>
    private Type ResolveCommonType(QueryExpression left, QueryExpression right, QueryClauseBuilderContext context)
    {
        Type leftType = Visit(left, context).Type;

        if (RuntimeTypeConverter.CanContainNull(leftType))
        {
            return leftType;
        }

        if (right is NullConstantExpression)
        {
            return typeof(Nullable<>).MakeGenericType(leftType);
        }

        // Only counts and field chains have a type that is known without a target type; other right-hand operands are left unresolved.
        Type? rightType = right switch
        {
            CountExpression => typeof(int),
            ResourceFieldChainExpression chain => Visit(chain, context).Type,
            _ => null
        };

        bool preferRightType = rightType != null && RuntimeTypeConverter.CanContainNull(rightType);
        return preferRightType ? rightType! : leftType;
    }

    /// <summary>
    /// Returns <paramref name="source" /> unchanged when it already has <paramref name="targetType" />; otherwise wraps it in a conversion node.
    /// </summary>
    /// <exception cref="InvalidQueryException">No conversion between the two types exists.</exception>
    private static Expression ConvertIfNeeded(Expression source, Type targetType)
    {
        if (source.Type == targetType)
        {
            return source;
        }

        try
        {
            return Expression.Convert(source, targetType);
        }
        catch (InvalidOperationException exception)
        {
            throw new InvalidQueryException("Query creation failed due to incompatible types.", exception);
        }
    }

    /// <inheritdoc />
    public override Expression VisitNullConstant(NullConstantExpression expression, QueryClauseBuilderContext context) => NullConstant;

    /// <inheritdoc />
    public override Expression VisitLiteralConstant(LiteralConstantExpression expression, QueryClauseBuilderContext context) =>
        SystemExpressionBuilder.CloseOver(expression.TypedValue);
}