Skip to content

Commit b36a327

Browse files
authored
.Net: Fix input checking in Cosmos NoSQL, Redis and Weaviate providers (#13629)
1 parent e9641a9 commit b36a327

File tree

7 files changed

+259
-48
lines changed

7 files changed

+259
-48
lines changed

dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollectionQueryBuilder.cs

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
3939
Verify.NotNull(vector);
4040

4141
const string VectorVariableName = "@vector";
42-
// TODO: Use parameterized query for keywords when FullTextScore with parameters is supported.
43-
//const string KeywordsVariableName = "@keywords";
42+
const string KeywordVariablePrefix = "@keyword";
4443

4544
var tableVariableName = CosmosNoSqlConstants.ContainerAlias;
4645

@@ -53,15 +52,6 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
5352
var vectorDistanceArgument = $"VectorDistance({GeneratePropertyAccess(tableVariableName, vectorPropertyName)}, {VectorVariableName})";
5453
var vectorDistanceArgumentWithAlias = $"{vectorDistanceArgument} AS {scorePropertyName}";
5554

56-
// Passing keywords using a parameter is not yet supported for FullTextScore so doing some crude string sanitization in the mean time to frustrate script injection.
57-
var sanitizedKeywords = keywords is not null ? keywords.Select(x => x.Replace("\"", "")) : null;
58-
var formattedKeywords = sanitizedKeywords is not null ? $"\"{string.Join("\", \"", sanitizedKeywords)}\"" : null;
59-
var fullTextScoreArgument = textPropertyName is not null && keywords is not null
60-
? $"FullTextScore({GeneratePropertyAccess(tableVariableName, textPropertyName)}, {formattedKeywords})"
61-
: null;
62-
63-
var rankingArgument = fullTextScoreArgument is null ? vectorDistanceArgument : $"RANK RRF({vectorDistanceArgument}, {fullTextScoreArgument})";
64-
6555
var selectClauseArguments = string.Join(",", [.. fieldsArgument, vectorDistanceArgumentWithAlias]);
6656

6757
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
@@ -86,6 +76,25 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
8676
}
8777
};
8878

79+
string? fullTextScoreArgument = null;
80+
if (textPropertyName is not null && keywords is not null)
81+
{
82+
var fullTextScoreBuilder = new StringBuilder();
83+
fullTextScoreBuilder.Append($"FullTextScore({GeneratePropertyAccess(tableVariableName, textPropertyName)}");
84+
var i = 0;
85+
foreach (var keyword in keywords)
86+
{
87+
var paramName = $"{KeywordVariablePrefix}{i}";
88+
fullTextScoreBuilder.Append(", ").Append(paramName);
89+
queryParameters[paramName] = keyword;
90+
i++;
91+
}
92+
fullTextScoreBuilder.Append(')');
93+
fullTextScoreArgument = fullTextScoreBuilder.ToString();
94+
}
95+
96+
var rankingArgument = fullTextScoreArgument is null ? vectorDistanceArgument : $"RANK RRF({vectorDistanceArgument}, {fullTextScoreArgument})";
97+
8998
// Add score threshold filter if specified.
9099
// For similarity functions (CosineSimilarity, DotProductSimilarity), higher scores are better, so filter with >=.
91100
// For distance functions (EuclideanDistance), lower scores are better, so filter with <=.
@@ -144,12 +153,6 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
144153
builder.AppendLine($"OFFSET {skip} LIMIT {top}");
145154
}
146155

147-
// TODO: Use parameterized query for keywords when FullTextScore with parameters is supported.
148-
//if (fullTextScoreArgument is not null)
149-
//{
150-
// queryParameters.Add(KeywordsVariableName, keywords!.ToArray());
151-
//}
152-
153156
var queryDefinition = new QueryDefinition(builder.ToString());
154157

155158
if (filterParameters is { Count: > 0 })

dotnet/src/VectorData/Redis/RedisFilterTranslator.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,7 @@ bool TryProcessEqualityComparison(Expression first, Expression second)
115115
{
116116
ExpressionType.Equal when constantValue is byte or short or int or long or float or double => $" == {constantValue}",
117117
ExpressionType.Equal when constantValue is string stringValue
118-
#if NET
119-
=> $$""":{"{{stringValue.Replace("\"", "\\\"", StringComparison.Ordinal)}}"}""",
120-
#else
121-
=> $$""":{"{{stringValue.Replace("\"", "\"\"")}}"}""",
122-
#endif
118+
=> $$""":{"{{SanitizeStringConstant(stringValue)}}"}""",
123119
ExpressionType.Equal when constantValue is null => throw new NotSupportedException("Null value type not supported"), // TODO
124120

125121
ExpressionType.NotEqual when constantValue is int or long or float or double => $" != {constantValue}",
@@ -177,9 +173,9 @@ private void TranslateContains(Expression source, Expression item)
177173
this._filter
178174
.Append('@')
179175
.Append(property.StorageName)
180-
.Append(":{")
181-
.Append(stringConstant)
182-
.Append('}');
176+
.Append(":{\"")
177+
.Append(SanitizeStringConstant(stringConstant))
178+
.Append("\"}");
183179
return;
184180
}
185181

@@ -238,7 +234,7 @@ private void TranslateAny(Expression source, LambdaExpression lambda)
238234
this._filter.Append(" | ");
239235
}
240236

241-
this._filter.Append(stringElement);
237+
this._filter.Append('"').Append(SanitizeStringConstant(stringElement)).Append('"');
242238
}
243239

244240
this._filter.Append('}');
@@ -259,4 +255,11 @@ private void TranslateAny(Expression source, LambdaExpression lambda)
259255
return result;
260256
}
261257
}
258+
259+
/// <summary>
/// Escapes a string constant so that it can be embedded between double quotes in a generated Redis filter.
/// </summary>
/// <remarks>
/// Backslashes are escaped before double quotes. Escaping only the quotes would let a backslash that is
/// already part of the value (for example a trailing one, or one directly preceding a quote) cancel the
/// escaping and terminate the quoted literal early, since callers append the closing quote themselves.
/// </remarks>
/// <param name="value">The raw string constant taken from the filter expression.</param>
/// <returns>The value with every backslash and every double quote prefixed by a backslash.</returns>
private static string SanitizeStringConstant(string value)
#if NET
    => value
        .Replace("\\", "\\\\", StringComparison.Ordinal)
        .Replace("\"", "\\\"", StringComparison.Ordinal);
#else
    => value
        .Replace("\\", "\\\\")
        .Replace("\"", "\\\"");
#endif
262265
}

dotnet/src/VectorData/Weaviate/WeaviateQueryBuilder.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ public static string BuildHybridSearchQuery<TRecord, TVector>(
149149
#pragma warning restore CS0618
150150

151151
var vectorArray = JsonSerializer.Serialize(vector, jsonSerializerOptions);
152+
var sanitizedKeywords = keywords.Replace("\\", "\\\\").Replace("\"", "\\\"");
152153

153154
return $$"""
154155
{
@@ -158,7 +159,7 @@ public static string BuildHybridSearchQuery<TRecord, TVector>(
158159
offset: {{searchOptions.Skip}}
159160
{{(filter is null ? "" : "where: " + filter)}}
160161
hybrid: {
161-
query: "{{keywords}}"
162+
query: "{{sanitizedKeywords}}"
162163
properties: ["{{textProperty.StorageName}}"]
163164
{{GetTargetVectorsQuery(hasNamedVectors, vectorProperty.StorageName)}}
164165
vector: {{vectorArray}}

dotnet/test/VectorData/CosmosNoSql.UnitTests/CosmosNoSqlCollectionQueryBuilderTests.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,20 @@ public void BuildSearchQueryWithHybridFieldsReturnsValidHybridQueryDefinition()
221221
Assert.Contains("SELECT x[\"id\"],x[\"TestProperty1\"],x[\"TestProperty2\"],x[\"TestProperty3\"],VectorDistance(x[\"TestProperty1\"], @vector) AS TestScore", queryText);
222222
Assert.Contains("FROM x", queryText);
223223
Assert.Contains("WHERE x[\"TestProperty2\"] = @cv0 AND ARRAY_CONTAINS(x[\"TestProperty3\"], @cv1)", queryText);
224-
Assert.Contains("ORDER BY RANK RRF(VectorDistance(x[\"TestProperty1\"], @vector), FullTextScore(x[\"TestProperty2\"], \"hybrid\"))", queryText);
224+
Assert.Contains("ORDER BY RANK RRF(VectorDistance(x[\"TestProperty1\"], @vector), FullTextScore(x[\"TestProperty2\"], @keyword0))", queryText);
225225
Assert.Contains("OFFSET 5 LIMIT 10", queryText);
226226

227227
Assert.Equal("@vector", queryParameters[0].Name);
228228
Assert.Equal(vector, queryParameters[0].Value);
229229

230-
Assert.Equal("@cv0", queryParameters[1].Name);
231-
Assert.Equal("test-value-2", queryParameters[1].Value);
230+
Assert.Equal("@keyword0", queryParameters[1].Name);
231+
Assert.Equal("hybrid", queryParameters[1].Value);
232232

233-
Assert.Equal("@cv1", queryParameters[2].Name);
234-
Assert.Equal("test-value-3", queryParameters[2].Value);
233+
Assert.Equal("@cv0", queryParameters[2].Name);
234+
Assert.Equal("test-value-2", queryParameters[2].Value);
235+
236+
Assert.Equal("@cv1", queryParameters[3].Name);
237+
Assert.Equal("test-value-3", queryParameters[3].Value);
235238
}
236239

237240
#pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly.

dotnet/test/VectorData/Redis.ConformanceTests/RedisFilterTests.cs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,6 @@ public class RedisJsonFilterTests(RedisJsonFilterTests.Fixture fixture)
7777

7878
protected override string CollectionNameBase => "JsonFilterTests";
7979

80-
public override string SpecialCharactersText
81-
#if NET
82-
=> base.SpecialCharactersText;
83-
#else
84-
// Redis client doesn't properly escape '"' on Full Framework.
85-
=> base.SpecialCharactersText.Replace("\"", "");
86-
#endif
87-
8880
// Override to remove the bool property, which isn't (currently) supported on Redis/JSON
8981
public override VectorStoreCollectionDefinition CreateRecordDefinition()
9082
=> new()
@@ -163,14 +155,6 @@ public override Task Any_over_List_with_Contains_over_captured_string_array()
163155

164156
protected override string CollectionNameBase => "HashSetCollectionFilterTests";
165157

166-
public override string SpecialCharactersText
167-
#if NET
168-
=> base.SpecialCharactersText;
169-
#else
170-
// Redis client doesn't properly escape '"' on Full Framework.
171-
=> base.SpecialCharactersText.Replace("\"", "");
172-
#endif
173-
174158
// Override to remove the bool property, which isn't (currently) supported on Redis
175159
public override VectorStoreCollectionDefinition CreateRecordDefinition()
176160
=> new()
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System;
4+
using System.Linq;
5+
using System.Linq.Expressions;
6+
using Microsoft.Extensions.VectorData;
7+
using Microsoft.Extensions.VectorData.ProviderServices;
8+
using Xunit;
9+
10+
namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests;
11+
12+
#pragma warning disable MEVD9001 // Experimental
13+
14+
/// <summary>
/// Unit tests for the string-constant handling of <see cref="RedisFilterTranslator"/>.
/// The expected filters show that string values are emitted inside double quotes, that embedded
/// double quotes are backslash-escaped, and that other characters are passed through unchanged.
/// </summary>
public sealed class RedisFilterTranslatorTests
{
    // ---- Contains(...) with a single string constant over the Tags array ----

    [Fact]
    public void Contains_with_simple_string()
        => AssertTranslatesTo("""@Tags:{"foo"}""", r => r.Tags.Contains("foo"));

    [Fact]
    public void Contains_with_curly_brace_in_value()
        => AssertTranslatesTo("""@Tags:{"foo}bar"}""", r => r.Tags.Contains("foo}bar"));

    [Fact]
    public void Contains_with_pipe_in_value()
        => AssertTranslatesTo("""@Tags:{"foo|bar"}""", r => r.Tags.Contains("foo|bar"));

    [Fact]
    public void Contains_with_double_quote_in_value()
        => AssertTranslatesTo("""@Tags:{"foo\"bar"}""", r => r.Tags.Contains("foo\"bar"));

    [Fact]
    public void Contains_with_asterisk_in_value()
        => AssertTranslatesTo("""@Tags:{"foo*bar"}""", r => r.Tags.Contains("foo*bar"));

    [Fact]
    public void Contains_with_at_sign_in_value()
        => AssertTranslatesTo("""@Tags:{"foo@bar"}""", r => r.Tags.Contains("foo@bar"));

    [Fact]
    public void Contains_with_injection_attempt()
        => AssertTranslatesTo("""@Tags:{"evil} | @secret:{*"}""", r => r.Tags.Contains("evil} | @secret:{*"));

    // ---- Any(...) over the Tags array against a captured string array ----

    [Fact]
    public void Any_with_simple_strings()
    {
        string[] candidates = ["a", "b"];

        AssertTranslatesTo("""@Tags:{"a" | "b"}""", r => r.Tags.Any(tag => candidates.Contains(tag)));
    }

    [Fact]
    public void Any_with_metacharacters_in_values()
    {
        string[] candidates = ["a|b", "c}d"];

        AssertTranslatesTo("""@Tags:{"a|b" | "c}d"}""", r => r.Tags.Any(tag => candidates.Contains(tag)));
    }

    [Fact]
    public void Any_with_double_quotes_in_values()
    {
        string[] candidates = ["a\"b", "c\"d"];

        AssertTranslatesTo("""@Tags:{"a\"b" | "c\"d"}""", r => r.Tags.Any(tag => candidates.Contains(tag)));
    }

    [Fact]
    public void Any_with_injection_attempt()
    {
        string[] candidates = ["safe", "x} | @admin:{true"];

        AssertTranslatesTo("""@Tags:{"safe" | "x} | @admin:{true"}""", r => r.Tags.Any(tag => candidates.Contains(tag)));
    }

    // ---- Equality comparison against a string constant ----

    [Fact]
    public void Equal_with_simple_string()
        => AssertTranslatesTo("""@Name:{"foo"}""", r => r.Name == "foo");

    [Fact]
    public void Equal_with_double_quote_in_value()
        => AssertTranslatesTo("""@Name:{"foo\"bar"}""", r => r.Name == "foo\"bar");

    /// <summary>
    /// Translates <paramref name="predicate"/> against a freshly built model and asserts that the
    /// produced filter text equals <paramref name="expectedFilter"/>.
    /// </summary>
    private static void AssertTranslatesTo(string expectedFilter, Expression<Func<TestRecord, bool>> predicate)
    {
        var model = CreateModel();
        var translator = new RedisFilterTranslator();

        var actualFilter = translator.Translate(predicate, model);

        Assert.Equal(expectedFilter, actualFilter);
    }

    /// <summary>
    /// Builds the collection model used by every test: a string key, a string data property,
    /// a string-array data property and a 10-dimensional vector property.
    /// </summary>
    private static CollectionModel CreateModel()
    {
        var modelBuilder = new RedisJsonModelBuilder(RedisJsonCollection<string, object>.ModelBuildingOptions);

        return modelBuilder.BuildDynamic(
            new VectorStoreCollectionDefinition
            {
                Properties =
                [
                    new VectorStoreKeyProperty("Id", typeof(string)),
                    new VectorStoreDataProperty("Name", typeof(string)),
                    new VectorStoreDataProperty("Tags", typeof(string[])),
                    new VectorStoreVectorProperty("Embedding", typeof(ReadOnlyMemory<float>), 10)
                ]
            },
            defaultEmbeddingGenerator: null);
    }

#pragma warning disable CA1812
    // Only used as the parameter type of the filter lambdas; it is never instantiated.
    private sealed record TestRecord
    {
        public string Id { get; set; } = string.Empty;

        public string Name { get; set; } = string.Empty;

        public string[] Tags { get; set; } = [];

        public ReadOnlyMemory<float> Embedding { get; set; }
    }
#pragma warning restore CA1812
}

0 commit comments

Comments
 (0)