Skip to content

Commit e4e711c

Browse files
committed
feat!: avoid string cloning
1 parent 58bc72f commit e4e711c

7 files changed

Lines changed: 155 additions & 167 deletions

File tree

sqlshield-gui/src/main.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ fn main() -> Result<(), slint::PlatformError> {
1818
// println!("schema -> {}", ui.get_schema());
1919
}
2020

21-
let errors = match sqlshield::validate_query(
22-
ui.get_queries().to_string(),
23-
ui.get_schema().to_string(),
24-
) {
21+
let queries = ui.get_queries();
22+
let schema = ui.get_schema();
23+
let errors = match sqlshield::validate_query(queries.as_str(), schema.as_str()) {
2524
Ok(errors) => errors,
2625
Err(err) => vec![err],
2726
};

sqlshield-py/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ fn validate_files(dir: String, schema_file_path: String) -> PyResult<Vec<PySqlVa
3737
}
3838

3939
#[pyfunction]
40-
fn validate_query(query: String, schema: String) -> PyResult<Vec<String>> {
41-
sqlshield_rs::validate_query(query, schema).map_err(|err| PyValueError::new_err(err))
40+
fn validate_query(query: &str, schema: &str) -> PyResult<Vec<String>> {
41+
sqlshield_rs::validate_query(query, schema).map_err(PyValueError::new_err)
4242
}
4343

4444
/// A Python module implemented in Rust.

sqlshield/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ use regex::Regex;
99
use validation::{validate_queries_in_code, validate_statements_with_schema, SqlValidationError};
1010
use walkdir::WalkDir;
1111

12-
pub fn validate_query(query: String, schema: String) -> Result<Vec<String>, String> {
12+
pub fn validate_query(query: &str, schema: &str) -> Result<Vec<String>, String> {
1313
let dialect = sqlparser::dialect::GenericDialect {};
1414

15-
let statements = match sqlparser::parser::Parser::parse_sql(&dialect, &query) {
15+
let statements = match sqlparser::parser::Parser::parse_sql(&dialect, query) {
1616
Ok(statements) => statements,
1717
Err(err) => return Err(err.to_string()),
1818
};
1919

20-
let loaded_schema = match schema::load_schema(&schema.into_bytes(), "sql") {
20+
let loaded_schema = match schema::load_schema(schema.as_bytes(), "sql") {
2121
Ok(loaded_schema) => loaded_schema,
2222
Err(err) => return Err(err),
2323
};
Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1-
use std::collections::HashSet;
1+
use std::collections::{HashMap, HashSet};
2+
3+
use crate::schema;
24

35
pub fn is_relation_in_schema(
46
relation: &sqlparser::ast::TableFactor,
5-
tables: &HashSet<String>,
7+
schema: &schema::TablesAndColumns,
8+
extras: &HashMap<&str, HashSet<&str>>,
69
) -> Option<String> {
7-
// returns table_name if not in schema
8-
match &relation {
10+
match relation {
911
sqlparser::ast::TableFactor::Table { name, .. } => {
1012
// TODO support table name with schema prefixed instead of using last ident
11-
let table_name = name.0.last().unwrap();
12-
let table_name_str = table_name.value.as_str();
13-
if tables.contains(table_name_str) {
13+
let table_name = name.0.last().unwrap().value.as_str();
14+
if schema.contains_key(table_name) || extras.contains_key(table_name) {
1415
return None;
1516
}
16-
let name_full: String = name
17+
let name_full = name
1718
.0
1819
.iter()
1920
.map(|e| e.value.as_str())
2021
.collect::<Vec<&str>>()
2122
.join(".");
22-
return Some(name_full);
23+
Some(name_full)
2324
}
24-
_ => {}
25+
_ => None,
2526
}
26-
None
2727
}
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
mod select;
22

3+
use std::collections::{HashMap, HashSet};
4+
35
use crate::schema;
46

57
pub trait ClauseValidation {
6-
fn validate(&self, schema: &schema::TablesAndColumns) -> Vec<String>;
8+
fn validate(
9+
&self,
10+
schema: &schema::TablesAndColumns,
11+
extras: &HashMap<&str, HashSet<&str>>,
12+
) -> Vec<String>;
713
}

sqlshield/src/validation/clauses/select.rs

Lines changed: 71 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,30 @@ use crate::{schema, validation::asserts};
22

33
use super::ClauseValidation;
44

5-
use std::collections::HashSet;
5+
use std::collections::{HashMap, HashSet};
66

77
impl ClauseValidation for sqlparser::ast::Select {
8-
fn validate(&self, schema: &schema::TablesAndColumns) -> Vec<String> {
9-
let tables_in_schema: HashSet<String> = HashSet::from_iter(schema.keys().cloned());
10-
8+
fn validate(
9+
&self,
10+
schema: &schema::TablesAndColumns,
11+
extras: &HashMap<&str, HashSet<&str>>,
12+
) -> Vec<String> {
1113
let select = self;
1214
let mut errors = vec![];
1315

1416
for item in &select.from {
15-
let relation_name = asserts::is_relation_in_schema(&item.relation, &tables_in_schema);
16-
17-
if let Some(relation_name) = relation_name {
17+
if let Some(relation_name) =
18+
asserts::is_relation_in_schema(&item.relation, schema, extras)
19+
{
1820
errors.push(format!(
1921
"Table `{relation_name}` not found in schema nor subqueries"
2022
))
2123
}
2224

2325
for join in &item.joins {
24-
let relation_name =
25-
asserts::is_relation_in_schema(&join.relation, &tables_in_schema);
26-
if let Some(relation_name) = relation_name {
26+
if let Some(relation_name) =
27+
asserts::is_relation_in_schema(&join.relation, schema, extras)
28+
{
2729
errors.push(format!(
2830
"Table `{relation_name}` not found in schema nor subqueries"
2931
))
@@ -32,7 +34,7 @@ impl ClauseValidation for sqlparser::ast::Select {
3234
}
3335

3436
for item in &select.projection {
35-
let result = is_select_item_in_relations(item, &select.from, &schema);
37+
let result = is_select_item_in_relations(item, &select.from, schema, extras);
3638

3739
if let Some((item_name, relations_not_found_in)) = result {
3840
if relations_not_found_in.len() == 1 {
@@ -50,25 +52,28 @@ impl ClauseValidation for sqlparser::ast::Select {
5052
}
5153
}
5254

53-
fn is_select_item_in_relations(
54-
item: &sqlparser::ast::SelectItem,
55-
tables: &Vec<sqlparser::ast::TableWithJoins>,
56-
schema: &schema::TablesAndColumns,
57-
) -> Option<(String, Vec<String>)> {
58-
let mut tables_searched_where_not_found: Vec<String> = vec![];
59-
let mut item_name: Option<String> = None;
55+
fn is_select_item_in_relations<'a>(
56+
item: &'a sqlparser::ast::SelectItem,
57+
tables: &'a [sqlparser::ast::TableWithJoins],
58+
schema: &'a schema::TablesAndColumns,
59+
extras: &HashMap<&'a str, HashSet<&'a str>>,
60+
) -> Option<(&'a str, Vec<&'a str>)> {
61+
let mut tables_searched_where_not_found: Vec<&str> = vec![];
62+
let mut item_name: Option<&str> = None;
6063

6164
for relation in tables {
62-
let result = could_select_item_be_in_relation(&item, &relation.relation, &schema);
63-
if let Some((col_name, table_name)) = result {
65+
if let Some((col_name, table_name)) =
66+
could_select_item_be_in_relation(item, &relation.relation, schema, extras)
67+
{
6468
tables_searched_where_not_found.push(table_name);
6569
if item_name.is_none() {
6670
item_name = Some(col_name);
6771
}
6872
}
6973
for join in &relation.joins {
70-
let result = could_select_item_be_in_relation(&item, &join.relation, &schema);
71-
if let Some((col_name, table_name)) = result {
74+
if let Some((col_name, table_name)) =
75+
could_select_item_be_in_relation(item, &join.relation, schema, extras)
76+
{
7277
tables_searched_where_not_found.push(table_name);
7378
if item_name.is_none() {
7479
item_name = Some(col_name);
@@ -83,68 +88,59 @@ fn is_select_item_in_relations(
8388
Some((item_name?, tables_searched_where_not_found))
8489
}
8590

86-
fn could_select_item_be_in_relation(
87-
item: &sqlparser::ast::SelectItem,
88-
table: &sqlparser::ast::TableFactor,
89-
schema: &schema::TablesAndColumns,
90-
) -> Option<(String, String)> {
91+
fn could_select_item_be_in_relation<'a>(
92+
item: &'a sqlparser::ast::SelectItem,
93+
table: &'a sqlparser::ast::TableFactor,
94+
schema: &'a schema::TablesAndColumns,
95+
extras: &HashMap<&'a str, HashSet<&'a str>>,
96+
) -> Option<(&'a str, &'a str)> {
9197
// returns item_name, table_name if item could be in table but is not
9298

93-
let mut columns: Option<&HashSet<String>> = None;
94-
let mut col_name: Option<String> = None;
95-
let mut col_table_alias: Option<String> = None;
96-
// let mut col_alias: Option<String> = None;
97-
98-
let mut table_name: Option<String> = None;
99-
100-
match &item {
101-
sqlparser::ast::SelectItem::UnnamedExpr(expression) => {
102-
match expression {
103-
sqlparser::ast::Expr::Identifier(identifier) => {
104-
col_name = Some(identifier.value.clone());
105-
}
106-
sqlparser::ast::Expr::CompoundIdentifier(identifier) => {
107-
// for now only supports table alias
108-
if identifier.len() == 2 {
109-
col_table_alias = Some(identifier[0].value.clone());
110-
col_name = Some(identifier[1].value.clone());
111-
}
112-
}
113-
_ => {}
114-
}
115-
}
99+
let (col_name, col_table_alias): (Option<&str>, Option<&str>) = match item {
100+
sqlparser::ast::SelectItem::UnnamedExpr(expression) => match expression {
101+
sqlparser::ast::Expr::Identifier(identifier) => (Some(identifier.value.as_str()), None),
102+
sqlparser::ast::Expr::CompoundIdentifier(identifier) if identifier.len() == 2 => (
103+
Some(identifier[1].value.as_str()),
104+
Some(identifier[0].value.as_str()),
105+
),
106+
_ => (None, None),
107+
},
116108
// TODO: aliased columns
117109
// sqlparser::ast::SelectItem::ExprWithAlias { expr, alias } => {},
118-
_ => {}
119-
}
110+
_ => (None, None),
111+
};
120112

121-
match &table {
113+
let (table_name, alias) = match table {
122114
sqlparser::ast::TableFactor::Table { name, alias, .. } => {
123-
let name = &name.0.last().unwrap().value;
124-
125-
match (alias, col_table_alias) {
126-
(None, None) => {
127-
columns = schema.get(name);
128-
}
129-
(None, Some(_)) => {}
130-
(Some(_), None) => {}
131-
(Some(alias), Some(col_table_alias)) => {
132-
if alias.name.value == col_table_alias {
133-
columns = schema.get(name);
134-
}
135-
}
136-
}
137-
table_name = Some(name.clone());
115+
(name.0.last().unwrap().value.as_str(), alias.as_ref())
138116
}
139117
// TODO Implement for others
140-
_ => (),
141-
}
118+
_ => return None,
119+
};
142120

143-
if let (Some(columns), Some(col_name)) = (columns, col_name) {
144-
if !columns.contains(col_name.as_str()) {
145-
return Some((col_name, table_name?));
146-
}
121+
let should_check = match (alias, col_table_alias) {
122+
(None, None) => true,
123+
(Some(table_alias), Some(col_alias)) => table_alias.name.value == col_alias,
124+
_ => false,
125+
};
126+
127+
if !should_check {
128+
return None;
147129
}
148130

149-
None
131+
let col_name = col_name?;
132+
133+
let column_present = if let Some(cols) = schema.get(table_name) {
134+
cols.contains(col_name)
135+
} else if let Some(cols) = extras.get(table_name) {
136+
cols.contains(col_name)
137+
} else {
138+
return None;
139+
};
140+
141+
if column_present {
142+
None
143+
} else {
144+
Some((col_name, table_name))
145+
}
150146
}

0 commit comments

Comments
 (0)