@@ -2,7 +2,9 @@ use sqlparser::ast::{Expr, ObjectName, Query, SetExpr, Statement, TableFactor, V
22use std:: result:: Result ;
33use std:: { collections:: HashSet , ops:: ControlFlow } ;
44
5- use super :: { constants:: ALLOWED_FUNCTIONS , Schema } ;
5+ use crate :: relational:: Layout ;
6+
7+ use super :: constants:: ALLOWED_FUNCTIONS ;
68
79#[ derive( thiserror:: Error , Debug , PartialEq ) ]
810pub enum Error {
@@ -17,14 +19,14 @@ pub enum Error {
1719}
1820
1921pub struct Validator < ' a > {
20- schema : & ' a Schema ,
22+ layout : & ' a Layout ,
2123 ctes : HashSet < String > ,
2224}
2325
2426impl < ' a > Validator < ' a > {
25- pub fn new ( schema : & ' a Schema ) -> Self {
27+ pub fn new ( layout : & ' a Layout ) -> Self {
2628 Self {
27- schema ,
29+ layout ,
2830 ctes : Default :: default ( ) ,
2931 }
3032 }
@@ -54,9 +56,9 @@ impl<'a> Validator<'a> {
5456
5557 fn validate_table_name ( & mut self , name : & ObjectName ) -> ControlFlow < Error > {
5658 if let Some ( table_name) = name. 0 . last ( ) {
57- let table_name = table_name. to_string ( ) . to_lowercase ( ) ;
58- if !self . schema . contains_key ( & table_name ) && !self . ctes . contains ( & table_name ) {
59- return ControlFlow :: Break ( Error :: UnknownTable ( table_name ) ) ;
59+ let name = & table_name. value ;
60+ if !self . layout . table ( name ) . is_some ( ) && !self . ctes . contains ( name ) {
61+ return ControlFlow :: Break ( Error :: UnknownTable ( name . to_string ( ) ) ) ;
6062 }
6163 }
6264 ControlFlow :: Continue ( ( ) )
@@ -114,38 +116,35 @@ impl Visitor for Validator<'_> {
114116#[ cfg( test) ]
115117mod test {
116118 use super :: * ;
117- use crate :: sql:: constants:: SQL_DIALECT ;
118- use std:: collections:: { HashMap , HashSet } ;
119+ use crate :: sql:: { constants:: SQL_DIALECT , test:: make_layout} ;
119120
120121 fn validate ( sql : & str ) -> Result < ( ) , Error > {
121122 let statements = sqlparser:: parser:: Parser :: parse_sql ( & SQL_DIALECT , sql) . unwrap ( ) ;
122123
123- let schema: Schema = HashMap :: from ( [ (
124- "swap" . to_owned ( ) ,
125- HashSet :: from ( [
126- "vid" . to_owned ( ) ,
127- "block$" . to_owned ( ) ,
128- "id" . to_owned ( ) ,
129- "sender" . to_owned ( ) ,
130- "input_amount" . to_owned ( ) ,
131- "input_token" . to_owned ( ) ,
132- "amount_out" . to_owned ( ) ,
133- "output_token" . to_owned ( ) ,
134- "slippage" . to_owned ( ) ,
135- "referral_code" . to_owned ( ) ,
136- "block_number" . to_owned ( ) ,
137- "block_timestamp" . to_owned ( ) ,
138- "transaction_hash" . to_owned ( ) ,
139- ] ) ,
140- ) ] ) ;
141-
142- let mut validator = Validator :: new ( & schema) ;
124+ const GQL : & str = "
125+ type Swap @entity {
126+ id: ID!
127+ sender: Bytes!
128+ inputAmount: BigDecimal!
129+ inputToken: Bytes!
130+ amountOut: BigDecimal!
131+ outputToken: Bytes!
132+ slippage: BigDecimal!
133+ referralCode: String
134+ blockNumber: Int!
135+ blockTimestamp: Timestamp!
136+ transactionHash: Bytes!
137+ }" ;
138+
139+ let layout = make_layout ( GQL ) ;
140+
141+ let mut validator = Validator :: new ( & layout) ;
143142
144143 validator. validate_statements ( & statements)
145144 }
146145
147146 #[ test]
148- fn test_function_blacklisted ( ) {
147+ fn test_function_disallowed ( ) {
149148 let result = validate (
150149 "
151150 SELECT
@@ -161,7 +160,7 @@ mod test {
161160 }
162161
163162 #[ test]
164- fn test_table_function_blacklisted ( ) {
163+ fn test_table_function_disallowed ( ) {
165164 let result = validate (
166165 "
167166 SELECT
@@ -181,7 +180,7 @@ mod test {
181180 }
182181
183182 #[ test]
184- fn test_function_blacklisted_without_paranthesis ( ) {
183+ fn test_function_disallowed_without_paranthesis ( ) {
185184 let result = validate (
186185 "
187186 SELECT
@@ -195,7 +194,7 @@ mod test {
195194 }
196195
197196 #[ test]
198- fn test_function_whitelisted ( ) {
197+ fn test_function_allowed ( ) {
199198 let result = validate (
200199 "
201200 SELECT
0 commit comments