-
-
Notifications
You must be signed in to change notification settings - Fork 163
Expand file tree
/
Copy pathCarExpressionRewriter.cs
More file actions
162 lines (131 loc) · 6.65 KB
/
CarExpressionRewriter.cs
File metadata and controls
162 lines (131 loc) · 6.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
using System.Collections.Immutable;
using System.Reflection;
using JsonApiDotNetCore.Configuration;
using JsonApiDotNetCore.Queries.Expressions;
using JsonApiDotNetCore.Resources;
using JsonApiDotNetCore.Resources.Annotations;
namespace JsonApiDotNetCoreTests.IntegrationTests.CompositeKeys;
/// <summary>
/// Query expression rewriter for the <see cref="Car" /> resource, whose <see cref="Car.Id" /> is a combination of <see cref="Car.RegionId" /> and
/// <see cref="Car.LicensePlate" />. Filter and sort references to <see cref="Car.Id" /> are replaced by equivalent expressions on those two attributes.
/// </summary>
/// <remarks>
/// <see cref="Car.Id" /> is not mapped in the database, so queries that mention it must be translated first. Supported: equality comparisons,
/// <c>any()</c> filters and sorting. Non-equality comparisons and partial text matching on Car IDs throw <see cref="NotSupportedException" />.
/// </remarks>
internal sealed class CarExpressionRewriter : QueryExpressionRewriter<object?>
{
    // The two attributes that together make up the composite key, looked up once from the resource graph.
    private readonly AttrAttribute _regionIdAttribute;
    private readonly AttrAttribute _licensePlateAttribute;

    public CarExpressionRewriter(IResourceGraph resourceGraph)
    {
        ResourceType carResourceType = resourceGraph.GetResourceType<Car>();

        _regionIdAttribute = carResourceType.GetAttributeByPropertyName(nameof(Car.RegionId));
        _licensePlateAttribute = carResourceType.GetAttributeByPropertyName(nameof(Car.LicensePlate));
    }

    /// <summary>
    /// Rewrites a comparison of a Car ID field chain against a literal into a comparison on the composite key. Any other comparison is handed to the
    /// base implementation.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// The comparison targets a Car ID with an operator other than <see cref="ComparisonOperator.Equals" />.
    /// </exception>
    public override QueryExpression? VisitComparison(ComparisonExpression expression, object? argument)
    {
        if (expression is not { Left: ResourceFieldChainExpression carIdChain, Right: LiteralConstantExpression idConstant } ||
            !IsCarId(carIdChain.Fields[^1].Property))
        {
            return base.VisitComparison(expression, argument);
        }

        if (expression.Operator != ComparisonOperator.Equals)
        {
            throw new NotSupportedException("Only equality comparisons are possible on Car IDs.");
        }

        var carStringId = (string)idConstant.TypedValue;
        return RewriteFilterOnCarStringIds(carIdChain, [carStringId]);
    }

    /// <summary>
    /// Rewrites <c>any()</c> on a Car ID into one composite-key comparison per listed ID. Any other <c>any()</c> is handed to the base implementation.
    /// </summary>
    public override QueryExpression? VisitAny(AnyExpression expression, object? argument)
    {
        if (expression.MatchTarget is ResourceFieldChainExpression carIdChain && IsCarId(carIdChain.Fields[^1].Property))
        {
            string[] carStringIds = [.. expression.Constants.Select(constant => (string)constant.TypedValue)];
            return RewriteFilterOnCarStringIds(carIdChain, carStringIds);
        }

        return base.VisitAny(expression, argument);
    }

    /// <summary>
    /// Rejects text matching on a Car ID. Any other text match is handed to the base implementation.
    /// </summary>
    /// <exception cref="NotSupportedException">The match target is a Car ID.</exception>
    public override QueryExpression? VisitMatchText(MatchTextExpression expression, object? argument)
    {
        bool targetsCarId = expression.MatchTarget is ResourceFieldChainExpression carIdChain && IsCarId(carIdChain.Fields[^1].Property);

        if (targetsCarId)
        {
            throw new NotSupportedException("Partial text matching on Car IDs is not possible.");
        }

        return base.VisitMatchText(expression, argument);
    }

    /// <summary>
    /// Tells whether <paramref name="property" /> is the <c>Id</c> property declared on <see cref="Car" />.
    /// </summary>
    private static bool IsCarId(PropertyInfo property) =>
        property.DeclaringType == typeof(Car) && property.Name == nameof(Identifiable<>.Id);

    /// <summary>
    /// Builds a filter matching any of <paramref name="carStringIds" />. It returns the single composite-key comparison when there is one ID, or
    /// an <c>or</c> over all of them otherwise.
    /// </summary>
    private QueryExpression RewriteFilterOnCarStringIds(ResourceFieldChainExpression existingCarIdChain, IEnumerable<string> carStringIds)
    {
        ImmutableArray<FilterExpression> keyComparisons = carStringIds
            .Select(carStringId => CreateComparisonForCarStringId(existingCarIdChain, carStringId))
            .ToImmutableArray<FilterExpression>();

        return keyComparisons.Length == 1 ? keyComparisons[0] : new LogicalExpression(LogicalOperator.Or, keyComparisons);
    }

    /// <summary>
    /// Derives the key parts of one Car ID and builds the matching composite-key comparison. It does this by assigning the ID to
    /// <see cref="Car.StringId" /> on a throwaway instance and reading back <see cref="Car.RegionId" /> and <see cref="Car.LicensePlate" />.
    /// </summary>
    private LogicalExpression CreateComparisonForCarStringId(ResourceFieldChainExpression existingCarIdChain, string carStringId)
    {
        var carWithKeyParts = new Car
        {
            StringId = carStringId
        };

        return CreateEqualityComparisonOnCompositeKey(existingCarIdChain, carWithKeyParts.RegionId, carWithKeyParts.LicensePlate!);
    }

    /// <summary>
    /// Builds <c>and(equals(…regionId, value), equals(…licensePlate, value))</c>. Both chains keep the path of
    /// <paramref name="existingCarIdChain" /> and only swap its last field.
    /// </summary>
    private LogicalExpression CreateEqualityComparisonOnCompositeKey(ResourceFieldChainExpression existingCarIdChain, long regionIdValue,
        string licensePlateValue)
    {
        var regionIdComparison = new ComparisonExpression(ComparisonOperator.Equals, ReplaceLastAttributeInChain(existingCarIdChain, _regionIdAttribute),
            new LiteralConstantExpression(regionIdValue));

        var licensePlateComparison = new ComparisonExpression(ComparisonOperator.Equals,
            ReplaceLastAttributeInChain(existingCarIdChain, _licensePlateAttribute), new LiteralConstantExpression(licensePlateValue));

        return new LogicalExpression(LogicalOperator.And, regionIdComparison, licensePlateComparison);
    }

    /// <summary>
    /// Replaces each sort element on a Car ID with two elements, region ID first and then license plate, both keeping the original direction. All
    /// other elements are kept unchanged and in order. Sort element targets are not visited any further.
    /// </summary>
    public override QueryExpression VisitSort(SortExpression expression, object? argument)
    {
        ImmutableArray<SortElementExpression> expandedElements = expression.Elements
            .SelectMany(sortElement => ExpandSortElement(sortElement))
            .ToImmutableArray();

        return new SortExpression(expandedElements);
    }

    /// <summary>
    /// Yields the replacement sort element(s) for <paramref name="sortElement" />.
    /// </summary>
    private IEnumerable<SortElementExpression> ExpandSortElement(SortElementExpression sortElement)
    {
        if (!IsSortOnCarId(sortElement))
        {
            yield return sortElement;
            yield break;
        }

        var carIdChain = (ResourceFieldChainExpression)sortElement.Target;

        yield return new SortElementExpression(ReplaceLastAttributeInChain(carIdChain, _regionIdAttribute), sortElement.IsAscending);
        yield return new SortElementExpression(ReplaceLastAttributeInChain(carIdChain, _licensePlateAttribute), sortElement.IsAscending);
    }

    /// <summary>
    /// Tells whether the sort element targets a field chain that ends in the Car ID attribute.
    /// </summary>
    private static bool IsSortOnCarId(SortElementExpression sortElement) =>
        sortElement.Target is ResourceFieldChainExpression fieldChain && fieldChain.Fields[^1] is AttrAttribute attribute && IsCarId(attribute.Property);

    /// <summary>
    /// Returns a copy of <paramref name="resourceFieldChain" /> whose last field is replaced by <paramref name="attribute" />.
    /// </summary>
    private static ResourceFieldChainExpression ReplaceLastAttributeInChain(ResourceFieldChainExpression resourceFieldChain, AttrAttribute attribute)
    {
        int lastIndex = resourceFieldChain.Fields.Count - 1;
        return new ResourceFieldChainExpression(resourceFieldChain.Fields.SetItem(lastIndex, attribute));
    }
}