Skip to content

Commit 5d057aa

Browse files
author
David Freire
committed
fix: find variables recursively in the ast
1 parent ab9da4b commit 5d057aa

3 files changed

Lines changed: 68 additions & 40 deletions

File tree

src/lib.rs

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use regex::Regex;
77
use sqlparser::ast::{SetExpr, Statement};
88
use sqlparser::dialect::GenericDialect;
99
use sqlparser::parser::Parser as SqlParser;
10-
use tree_sitter::Parser as CodeParser;
10+
use tree_sitter::{Node, Parser as CodeParser};
1111
use walkdir::WalkDir;
1212

1313
struct SqlQueryError {
@@ -123,56 +123,61 @@ fn validate_query_with_schema(query: &Vec<Statement>, schema: &TablesAndColumns)
123123
return errors;
124124
}
125125

126-
fn find_queries(code: &[u8]) -> Vec<QueryInCode> {
127-
let mut parser = CodeParser::new();
128-
parser
129-
.set_language(tree_sitter_python::language())
130-
.expect("Error loading Python grammar");
131-
let parsed = parser.parse(code, None);
132-
let mut queries: Vec<QueryInCode> = Vec::new();
126+
fn find_queries_in_tree(node: &Node, code: &[u8], queries: &mut Vec<QueryInCode>) {
127+
let mut cursor = node.walk();
133128
let dialect = GenericDialect {};
134129

135-
if let Some(tree) = parsed {
136-
let mut cursor = tree.walk();
137-
for node in tree.root_node().children(&mut cursor) {
138-
let mut node_cursor = node.walk();
139-
for component in node.children(&mut node_cursor) {
140-
if component.kind() != "assignment" {
141-
continue;
142-
}
130+
for child in node.children(&mut cursor) {
131+
let mut child_cursor = child.walk();
132+
for component in child.children(&mut child_cursor) {
133+
if component.kind() != "assignment" {
134+
continue;
135+
}
143136

144-
if component.child_count() > 3 {
145-
continue;
146-
}
137+
if component.child_count() > 3 {
138+
continue;
139+
}
147140

148-
let identifier = component.child(0).unwrap();
149-
let equal = component.child(1).unwrap();
150-
let var = component.child(2).unwrap();
141+
let identifier = component.child(0).unwrap();
142+
let equal = component.child(1).unwrap();
143+
let var = component.child(2).unwrap();
151144

152-
let is_string_assignment = identifier.kind() == "identifier"
153-
&& equal.kind() == "="
154-
&& var.kind() == "string";
145+
let is_string_assignment =
146+
identifier.kind() == "identifier" && equal.kind() == "=" && var.kind() == "string";
155147

156-
if !is_string_assignment {
157-
continue;
158-
}
148+
if !is_string_assignment {
149+
continue;
150+
}
159151

160-
let content = var.child(1).unwrap();
161-
let content_as_string =
162-
String::from_utf8_lossy(&code[content.start_byte()..content.end_byte()]);
152+
let content = var.child(1).unwrap();
153+
let content_as_string =
154+
String::from_utf8_lossy(&code[content.start_byte()..content.end_byte()]);
163155

164-
let point = component.start_position();
156+
let point = component.start_position();
165157

166-
let statements = SqlParser::parse_sql(&dialect, &content_as_string);
158+
let statements = SqlParser::parse_sql(&dialect, &content_as_string);
167159

168-
if let Ok(statements) = statements {
169-
queries.push(QueryInCode {
170-
line: point.row + 1,
171-
statements,
172-
})
173-
}
160+
if let Ok(statements) = statements {
161+
queries.push(QueryInCode {
162+
line: point.row + 1,
163+
statements,
164+
})
174165
}
175166
}
167+
find_queries_in_tree(&child, code, queries);
168+
}
169+
}
170+
171+
fn find_queries(code: &[u8]) -> Vec<QueryInCode> {
172+
let mut parser = CodeParser::new();
173+
parser
174+
.set_language(tree_sitter_python::language())
175+
.expect("Error loading Python grammar");
176+
let parsed = parser.parse(code, None);
177+
let mut queries: Vec<QueryInCode> = Vec::new();
178+
179+
if let Some(tree) = parsed {
180+
find_queries_in_tree(&tree.root_node(), &code, &mut queries);
176181
}
177182
return queries;
178183
}

tests/acceptance_test.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ fn test_sqlshield_acceptance() {
1818
location: "./tests/main.py:13".to_string(),
1919
description: "Table `admin` not found in schema".to_string(),
2020
},
21+
SqlValidationError {
22+
location: "./tests/main.py:21".to_string(),
23+
description: "Table `admin` not found in schema".to_string(),
24+
},
25+
SqlValidationError {
26+
location: "./tests/main.py:28".to_string(),
27+
description: "Table `admin` not found in schema".to_string(),
28+
},
2129
];
2230
assert_eq!(validation_errors, expected_validation_errors);
2331
}

tests/main.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,23 @@
1010
WHERE id = 1
1111
"""
1212

13-
invalid_query_missing_table = """
13+
INVALID_QUERY_MISSING_TABLE = """
1414
SELECT name
1515
FROM admin
1616
WHERE id = 1
1717
"""
18+
19+
class Repository:
20+
def fn():
21+
invalid_query_missing_table_in_fn = """
22+
SELECT name
23+
FROM admin
24+
WHERE id = 1
25+
"""
26+
27+
def fn():
28+
invalid_query_missing_table_in_fn = """
29+
SELECT name
30+
FROM admin
31+
WHERE id = 1
32+
"""

0 commit comments

Comments
 (0)