Skip to content

Commit c094006

Browse files
davidsmfreire and claude
committed
fix: walk full projection expressions for column refs
The projection check used a narrow `direct_col_ref` shortcut that only matched bare Identifier and 2-segment CompoundIdentifier — anything wrapped in a function call, CASE, CAST, or arithmetic was silently unvalidated. `SELECT LENGTH(bogus) FROM users` and friends would pass clean while WHERE flagged the same identifier. Replace the projection loop with the same `validate_expr_column_refs` walker that WHERE/HAVING/JOIN ON use. The walker already knows how to descend into BinaryOp / Function args / Case / Cast / Like / Between / IsNull family / etc., so this is a single-line dispatch change once the helpers were unified. Drop the now-unused `direct_col_ref` and merge `resolve_unqualified_for_projection` into `resolve_unqualified` so every clause emits the same "Column X not found in table Y" / "not found in none of: …" message format. Existing tests already match on column-name containment so no breakage. Inlining clippy fix: collapse `SetExpr::Insert(inner) → if let Statement::Insert{…} = inner` (and the equivalent `SetExpr::Update` arm) into a single nested pattern in validate_set_expr. 10 new tests cover function calls (valid + invalid), nested function calls, CASE branches and conditions, CAST, arithmetic BinaryOp, qualified identifiers inside function args, and that COUNT(*) — a function call with a wildcard argument, not a SelectItem wildcard — doesn't false-positive. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b6fb4f2 commit c094006

3 files changed

Lines changed: 115 additions & 84 deletions

File tree

sqlshield/src/validation/clauses/select.rs

Lines changed: 11 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -192,37 +192,16 @@ fn column_in_relation(
192192
None
193193
}
194194

195+
/// Resolve an unqualified column reference against the visible relations.
196+
/// The error message names the specific table(s) the column was missing
197+
/// from when at least one visible relation is known to the schema; if none
198+
/// of the visible relations is known, table-not-found errors emitted by
199+
/// the FROM walk already covered the situation, so we stay quiet.
195200
fn resolve_unqualified(
196201
col: &str,
197202
relations: &[VisibleRelation<'_>],
198203
schema: &schema::TablesAndColumns,
199204
extras: &HashMap<&str, HashSet<&str>>,
200-
) -> Option<String> {
201-
let mut any_known = false;
202-
for rel in relations {
203-
match column_in_relation(col, rel, schema, extras) {
204-
Some(true) => return None,
205-
Some(false) => any_known = true,
206-
None => {}
207-
}
208-
}
209-
if !any_known {
210-
// None of the visible relations are in the schema: table-not-found
211-
// errors from the FROM check already covered this; don't pile on.
212-
return None;
213-
}
214-
Some(format!("Column `{col}` not found in any visible table"))
215-
}
216-
217-
/// Like [`resolve_unqualified`] but carries a richer error message that
218-
/// names the specific table(s) the column was searched in. Used by the
219-
/// projection check, which historically reported "not found in table X"
220-
/// rather than the more generic "not found in any visible table".
221-
fn resolve_unqualified_for_projection(
222-
col: &str,
223-
relations: &[VisibleRelation<'_>],
224-
schema: &schema::TablesAndColumns,
225-
extras: &HashMap<&str, HashSet<&str>>,
226205
) -> Option<String> {
227206
let mut not_found_in: Vec<&str> = Vec::new();
228207
for rel in relations {
@@ -432,21 +411,17 @@ impl ClauseValidation for Select {
432411
}
433412
}
434413

414+
// Walk every projection expression with the same scope-aware visitor
415+
// used by WHERE/HAVING/etc. — catches column refs inside function
416+
// calls, CASE branches, CAST, arithmetic, and nested expressions
417+
// that the old `direct_col_ref` shortcut silently skipped.
418+
// Wildcard / QualifiedWildcard items have no column to validate.
435419
for item in &select.projection {
436420
let expr = match item {
437421
SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => e,
438422
_ => continue,
439423
};
440-
let (col_name, col_qualifier) = direct_col_ref(expr);
441-
let Some(col_name) = col_name else { continue };
442-
443-
let err = match col_qualifier {
444-
Some(qual) => resolve_qualified(qual, col_name, &visible, schema, extras),
445-
None => resolve_unqualified_for_projection(col_name, &visible, schema, extras),
446-
};
447-
if let Some(err) = err {
448-
errors.push(err);
449-
}
424+
validate_expr_column_refs(expr, &visible, schema, extras, &no_aliases, &mut errors);
450425
}
451426

452427
// WHERE / HAVING / GROUP BY column references. `visible` and
@@ -475,18 +450,3 @@ impl ClauseValidation for Select {
475450
errors
476451
}
477452
}
478-
479-
/// Pull a direct column reference out of an expression, ignoring wrappers
480-
/// we don't yet drill into (function calls, CASE, casts, etc.). Returns
481-
/// `(column, qualifier)` — the qualifier is the table-or-alias prefix in a
482-
/// 2-segment compound identifier.
483-
fn direct_col_ref(expr: &Expr) -> (Option<&str>, Option<&str>) {
484-
match expr {
485-
Expr::Identifier(identifier) => (Some(identifier.value.as_str()), None),
486-
Expr::CompoundIdentifier(identifier) if identifier.len() == 2 => (
487-
Some(identifier[1].value.as_str()),
488-
Some(identifier[0].value.as_str()),
489-
),
490-
_ => (None, None),
491-
}
492-
}

sqlshield/src/validation/mod.rs

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -199,44 +199,38 @@ fn validate_set_expr<'a>(
199199
// body. Dispatch back to the DML validators so the inner statement
200200
// is checked and the surrounding CTEs (carried in `extras`) stay in
201201
// scope.
202-
SetExpr::Insert(inner) => {
203-
if let Statement::Insert {
204-
table_name,
205-
columns,
206-
source,
207-
..
208-
} = inner
209-
{
210-
errors.extend(clauses::insert::validate_insert(
211-
table_name, columns, schema,
202+
SetExpr::Insert(Statement::Insert {
203+
table_name,
204+
columns,
205+
source,
206+
..
207+
}) => {
208+
errors.extend(clauses::insert::validate_insert(
209+
table_name, columns, schema,
210+
));
211+
if let Some(source_query) = source {
212+
errors.extend(validate_query_with_scope(
213+
source_query.as_ref(),
214+
schema,
215+
extras,
212216
));
213-
if let Some(source_query) = source {
214-
errors.extend(validate_query_with_scope(
215-
source_query.as_ref(),
216-
schema,
217-
extras,
218-
));
219-
}
220217
}
221218
}
222-
SetExpr::Update(inner) => {
223-
if let Statement::Update {
219+
SetExpr::Update(Statement::Update {
220+
table,
221+
assignments,
222+
from,
223+
selection,
224+
..
225+
}) => {
226+
errors.extend(clauses::update::validate_update(
224227
table,
225228
assignments,
226-
from,
227-
selection,
228-
..
229-
} = inner
230-
{
231-
errors.extend(clauses::update::validate_update(
232-
table,
233-
assignments,
234-
from.as_ref(),
235-
selection.as_ref(),
236-
schema,
237-
extras,
238-
));
239-
}
229+
from.as_ref(),
230+
selection.as_ref(),
231+
schema,
232+
extras,
233+
));
240234
}
241235
_ => {}
242236
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//! Projection items beyond bare identifiers — function calls, CASE, CAST,
2+
//! arithmetic. The WHERE-side walker already covers these; previously the
3+
//! projection check used a `direct_col_ref` shortcut that only matched
4+
//! Identifier / 2-segment CompoundIdentifier and silently passed everything
5+
//! else.
6+
7+
use sqlshield::validate_query;
8+
9+
const SCHEMA: &str = "
10+
CREATE TABLE users (id INT, name VARCHAR(255), age INT);
11+
";
12+
13+
fn run(sql: &str) -> Vec<String> {
14+
validate_query(sql, SCHEMA).expect("SQL/schema should parse")
15+
}
16+
17+
#[test]
18+
fn function_call_on_valid_column() {
19+
assert!(run("SELECT LENGTH(name) FROM users").is_empty());
20+
}
21+
22+
#[test]
23+
fn function_call_on_unknown_column_is_reported() {
24+
let errs = run("SELECT LENGTH(bogus) FROM users");
25+
assert!(errs.iter().any(|e| e.contains("`bogus`")), "got: {errs:?}");
26+
}
27+
28+
#[test]
29+
fn case_when_branch_unknown_column_is_reported() {
30+
let errs = run("SELECT CASE WHEN id > 0 THEN bogus ELSE name END FROM users");
31+
assert!(errs.iter().any(|e| e.contains("`bogus`")), "got: {errs:?}");
32+
}
33+
34+
#[test]
35+
fn case_condition_unknown_column_is_reported() {
36+
let errs = run("SELECT CASE WHEN typo > 0 THEN name END FROM users");
37+
assert!(errs.iter().any(|e| e.contains("`typo`")), "got: {errs:?}");
38+
}
39+
40+
#[test]
41+
fn cast_unknown_column_is_reported() {
42+
let errs = run("SELECT CAST(bogus AS TEXT) FROM users");
43+
assert!(errs.iter().any(|e| e.contains("`bogus`")), "got: {errs:?}");
44+
}
45+
46+
#[test]
47+
fn arithmetic_unknown_column_is_reported() {
48+
let errs = run("SELECT id + bogus FROM users");
49+
assert!(errs.iter().any(|e| e.contains("`bogus`")), "got: {errs:?}");
50+
}
51+
52+
#[test]
53+
fn nested_function_call() {
54+
let errs = run("SELECT LENGTH(UPPER(bogus)) FROM users");
55+
assert!(errs.iter().any(|e| e.contains("`bogus`")), "got: {errs:?}");
56+
}
57+
58+
#[test]
59+
fn function_with_aliased_qualifier_resolves() {
60+
// `u.name` inside a function — qualified identifier in nested context.
61+
assert!(run("SELECT LENGTH(u.name) FROM users u").is_empty());
62+
}
63+
64+
#[test]
65+
fn function_with_aliased_qualifier_unknown_column() {
66+
let errs = run("SELECT LENGTH(u.bogus) FROM users u");
67+
assert!(
68+
errs.iter()
69+
.any(|e| e.contains("`bogus`") && e.contains("`users`")),
70+
"got: {errs:?}"
71+
);
72+
}
73+
74+
#[test]
75+
fn count_star_does_not_error() {
76+
assert!(run("SELECT COUNT(*) FROM users").is_empty());
77+
}

0 commit comments

Comments
 (0)