Skip to content

Commit 79b3e8e

Browse files
committed
allow parsing past the first delimiter for complete statement
1 parent 9ce209f commit 79b3e8e

5 files changed

Lines changed: 105 additions & 30 deletions

File tree

postgres/parser/parser/parse.go

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ func (p *Parser) parseOneWithDepth(depth int, sql string) (Statement, error) {
129129
return stmts[0], nil
130130
}
131131

132-
func (p *Parser) scanOneStmt() (sql string, tokens []sqlSymType, done bool) {
132+
func (p *Parser) scanOneStmt(parse func(sqlStr string, tokens []sqlSymType) error) bool {
133133
var lval sqlSymType
134-
tokens = p.tokBuf[:0]
134+
tokens := p.tokBuf[:0]
135135

136136
// Scan the first token.
137137
for {
138138
p.scanner.scan(&lval)
139139
if lval.id == 0 {
140-
return "", nil, true
140+
return true
141141
}
142142
if lval.id != ';' {
143143
break
@@ -150,12 +150,21 @@ func (p *Parser) scanOneStmt() (sql string, tokens []sqlSymType, done bool) {
150150
tokens = append(tokens, lval)
151151
for {
152152
if lval.id == ERROR {
153-
return p.scanner.in[startPos:], tokens, true
153+
_ = parse(p.scanner.in[startPos:], tokens)
154+
return true
154155
}
155156
posBeforeScan := p.scanner.pos
156157
p.scanner.scan(&lval)
157158
if lval.id == 0 || lval.id == ';' {
158-
return p.scanner.in[startPos:posBeforeScan], tokens, (lval.id == 0)
159+
err := parse(p.scanner.in[startPos:posBeforeScan], tokens)
160+
if lval.id == 0 || (err != nil && !strings.Contains(err.Error(), "EOF")) {
161+
// done scanning all statements OR due to non EOF error
162+
return true
163+
} else if err == nil {
164+
// done scanning single statement
165+
return false
166+
}
167+
// continue scanning if it's EOF error
159168
}
160169
lval.pos -= startPos
161170
tokens = append(tokens, lval)
@@ -166,19 +175,27 @@ func (p *Parser) parseWithDepth(depth int, sql string, nakedIntType *types.T) (S
166175
stmts := Statements(p.stmtBuf[:0])
167176
p.scanner.init(sql)
168177
defer p.scanner.cleanup()
178+
var err error
169179
for {
170-
sql, tokens, done := p.scanOneStmt()
171-
stmt, err := p.parse(depth+1, sql, tokens, nakedIntType)
172-
if err != nil {
173-
return nil, err
174-
}
175-
if stmt.AST != nil {
176-
stmts = append(stmts, stmt)
177-
}
180+
done := p.scanOneStmt(func(sqlStr string, tokens []sqlSymType) error {
181+
var stmt Statement
182+
stmt, err = p.parse(depth+1, sqlStr, tokens, nakedIntType)
183+
if err != nil {
184+
// if it's EOF syntax error, try running again
185+
return err
186+
}
187+
if stmt.AST != nil {
188+
stmts = append(stmts, stmt)
189+
}
190+
return nil
191+
})
178192
if done {
179193
break
180194
}
181195
}
196+
if err != nil {
197+
return nil, err
198+
}
182199
return stmts, nil
183200
}
184201

@@ -259,7 +276,9 @@ func HasMultipleStatements(sql string) bool {
259276
defer p.scanner.cleanup()
260277
count := 0
261278
for {
262-
_, _, done := p.scanOneStmt()
279+
done := p.scanOneStmt(func(sqlStr string, tokens []sqlSymType) error {
280+
return nil
281+
})
263282
if done {
264283
break
265284
}

postgres/parser/sem/tree/create_function.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -280,14 +280,12 @@ type BeginEndBlock struct {
280280
}
281281

282282
func (node *BeginEndBlock) Format(ctx *FmtCtx) {
283-
ctx.WriteString("BEGIN ATOMIC")
284-
for i, s := range node.Statements {
285-
if i != 0 {
286-
ctx.WriteString("; ")
287-
}
283+
ctx.WriteString("BEGIN ATOMIC ")
284+
for _, s := range node.Statements {
288285
ctx.FormatNode(s)
286+
ctx.WriteString("; ")
289287
}
290-
ctx.WriteString(" END")
288+
ctx.WriteString("END")
291289
}
292290

293291
var _ Statement = &Return{}

server/ast/create_function.go

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,33 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
112112
}
113113
case "sql":
114114
as, ok := options[tree.OptionAs1]
115-
if !ok {
116-
return nil, errors.Errorf("CREATE FUNCTION definition needed for LANGUAGE SQL")
115+
if ok {
116+
sqlDef, sqlDefParsedStmts, err = handleLanguageSQLAs(as.Definition, params)
117+
if err != nil {
118+
return nil, err
119+
}
120+
break
117121
}
118-
sqlDef, sqlDefParsedStmts, err = handleLanguageSQL(as.Definition, params)
119-
if err != nil {
120-
return nil, err
122+
sqlBody, ok := options[tree.OptionSqlBody]
123+
if ok {
124+
beginAtomic, ok := sqlBody.SqlBody.(*tree.BeginEndBlock)
125+
if !ok {
126+
return nil, errors.Errorf("Expected BEGIN ATOMIC in CREATE FUNCTION definition, got %T", sqlBody.SqlBody)
127+
}
128+
stmts := make([]parser.Statement, len(beginAtomic.Statements))
129+
for i, s := range beginAtomic.Statements {
130+
stmts[i] = parser.Statement{
131+
AST: s,
132+
SQL: s.String(),
133+
}
134+
}
135+
sqlDef, sqlDefParsedStmts, err = convertSQLStmts(stmts, params)
136+
if err != nil {
137+
return nil, err
138+
}
139+
break
121140
}
141+
return nil, errors.Errorf("CREATE FUNCTION definition needed for LANGUAGE SQL")
122142
case "c":
123143
symbolOption, ok := options[tree.OptionAs2]
124144
if !ok {
@@ -185,13 +205,17 @@ func createAnonymousCompositeType(fieldTypes []tree.SimpleColumnDef) *pgtypes.Do
185205
return pgtypes.NewCompositeType(context.Background(), id.Null, id.NullType, typeId, attrs)
186206
}
187207

188-
// handleLanguageSQL handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE.
189-
func handleLanguageSQL(definition string, params []pgnodes.RoutineArg) (string, []vitess.Statement, error) {
208+
// handleLanguageSQLAs handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE.
209+
func handleLanguageSQLAs(definition string, params []pgnodes.RoutineArg) (string, []vitess.Statement, error) {
190210
stmts, err := parser.Parse(definition)
191211
if err != nil {
192212
return "", nil, err
193213
}
194214

215+
return convertSQLStmts(stmts, params)
216+
}
217+
218+
func convertSQLStmts(stmts []parser.Statement, params []pgnodes.RoutineArg) (string, []vitess.Statement, error) {
195219
paramMap := make(map[string]*framework.ParamTypAndValue, len(params))
196220
for i, param := range params {
197221
tv := &framework.ParamTypAndValue{
@@ -212,14 +236,17 @@ func handleLanguageSQL(definition string, params []pgnodes.RoutineArg) (string,
212236
var vitessASTs = make([]vitess.Statement, len(stmts))
213237
for i, stmt := range stmts {
214238
sqlDefs[i] = stmt.AST.String()
215-
err = framework.ReplaceFunctionColumn(stmt.AST, paramMap)
239+
err := framework.ReplaceFunctionColumn(stmt.AST, paramMap)
216240
if err != nil {
217241
return "", nil, err
218242
}
219243
// stmt.AST is updated at this point with FunctionColumn
220244
vitessASTs[i], err = Convert(stmt)
245+
if err != nil {
246+
return "", nil, err
247+
}
221248
}
222-
return strings.Join(sqlDefs, ";"), vitessASTs, err
249+
return strings.Join(sqlDefs, ";"), vitessASTs, nil
223250
}
224251

225252
// validateRoutineOptions ensures that each option is defined only once. Returns a map containing all options, or an

server/ast/create_procedure.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func nodeCreateProcedure(ctx *Context, node *tree.CreateProcedure) (vitess.State
106106
if !ok {
107107
return nil, errors.Errorf("CREATE PROCEDURE definition needed for LANGUAGE SQL")
108108
}
109-
sqlDef, sqlDefParsedStmts, err = handleLanguageSQL(as.Definition, params)
109+
sqlDef, sqlDefParsedStmts, err = handleLanguageSQLAs(as.Definition, params)
110110
if err != nil {
111111
return nil, err
112112
}

testing/go/create_function_sql_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,5 +361,36 @@ func TestCreateFunctionsLanguageSQL(t *testing.T) {
361361
},
362362
},
363363
},
364+
{
365+
Name: "use BEGIN ATOMIC ... END in sql_body",
366+
SetUpScript: []string{},
367+
Assertions: []ScriptTestAssertion{
368+
{
369+
Query: `CREATE FUNCTION match_default() RETURNS jsonb
370+
LANGUAGE sql
371+
BEGIN ATOMIC
372+
SELECT jsonb_build_object('k', 6, 'm', 2048, 'include_original', true, 'tokenizer', json_build_object('kind', 'ngram', 'token_length', 3), 'token_filters', json_build_array(json_build_object('kind', 'downcase'))) AS jsonb_build_object;
373+
END;`,
374+
Expected: []sql.Row{},
375+
},
376+
{
377+
Skip: true, // TODO
378+
Query: `SELECT public.match_default();`,
379+
Expected: []sql.Row{{`{"k": 6, "m": 2048, "tokenizer": {"kind": "ngram", "token_length": 3}, "token_filters": [{"kind": "downcase"}], "include_original": true}`}},
380+
},
381+
{
382+
Query: `CREATE FUNCTION select1() RETURNS int
383+
LANGUAGE sql
384+
BEGIN ATOMIC
385+
SELECT 1;
386+
END;`,
387+
Expected: []sql.Row{},
388+
},
389+
{
390+
Query: `SELECT select1();`,
391+
Expected: []sql.Row{{1}},
392+
},
393+
},
394+
},
364395
})
365396
}

0 commit comments

Comments
 (0)