Skip to content

Commit abda48f

Browse files
authored
Merge pull request #2408 from dolthub/jennifer/2388
support `int2vector` and `oidvector` types
1 parent d9d9a96 commit abda48f

25 files changed

Lines changed: 1098 additions & 122 deletions

server/ast/resolvable_type_reference.go

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,24 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference,
5555
case *types.T:
5656
columnTypeName = columnType.SQLStandardName()
5757
if columnType.Family() == types.ArrayFamily {
58-
_, baseResolvedType, err := nodeResolvableTypeReference(ctx, columnType.ArrayContents(), mayBeTrigger)
59-
if err != nil {
60-
return nil, nil, err
61-
}
62-
if baseResolvedType.IsResolvedType() {
63-
// currently the built-in types will be resolved, so it can retrieve its array type
64-
doltgresType = baseResolvedType.ToArrayType()
65-
} else {
66-
// TODO: handle array type of non-built-in types
67-
baseResolvedType.TypCategory = pgtypes.TypeCategory_ArrayTypes
68-
doltgresType = baseResolvedType
58+
switch columnType.Oid() {
59+
case oid.T_int2vector:
60+
doltgresType = pgtypes.Int16vector
61+
case oid.T_oidvector:
62+
doltgresType = pgtypes.Oidvector
63+
default:
64+
_, baseResolvedType, err := nodeResolvableTypeReference(ctx, columnType.ArrayContents(), mayBeTrigger)
65+
if err != nil {
66+
return nil, nil, err
67+
}
68+
if baseResolvedType.IsResolvedType() {
69+
// currently the built-in types will be resolved, so it can retrieve its array type
70+
doltgresType = baseResolvedType.ToArrayType()
71+
} else {
72+
// TODO: handle array type of non-built-in types
73+
baseResolvedType.TypCategory = pgtypes.TypeCategory_ArrayTypes
74+
doltgresType = baseResolvedType
75+
}
6976
}
7077
} else if columnType.Family() == types.GeometryFamily {
7178
return nil, nil, errors.Errorf("geometry types are not yet supported")
@@ -109,6 +116,8 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference,
109116
doltgresType = pgtypes.Float64
110117
case oid.T_int2:
111118
doltgresType = pgtypes.Int16
119+
case oid.T_int2vector:
120+
doltgresType = pgtypes.Int16vector
112121
case oid.T_int4:
113122
doltgresType = pgtypes.Int32
114123
case oid.T_int8:
@@ -132,6 +141,8 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference,
132141
}
133142
case oid.T_oid:
134143
doltgresType = pgtypes.Oid
144+
case oid.T_oidvector:
145+
doltgresType = pgtypes.Oidvector
135146
case oid.T_regclass:
136147
doltgresType = pgtypes.Regclass
137148
case oid.T_regproc:

server/functions/array.go

Lines changed: 78 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -169,34 +169,7 @@ var array_recv = framework.Function3{
169169
baseType := pgtypes.IDToBuiltInDoltgresType[id.Type(baseTypeOid)]
170170
typmod := val3.(int32)
171171
baseType = baseType.WithAttTypMod(typmod)
172-
// Check for the nil value, then ensure the minimum length of the slice
173-
if len(data) == 0 {
174-
return nil, nil
175-
}
176-
if len(data) < 4 {
177-
return nil, errors.Errorf("deserializing non-nil array value has invalid length of %d", len(data))
178-
}
179-
// Grab the number of elements and construct an output slice of the appropriate size
180-
elementCount := binary.LittleEndian.Uint32(data)
181-
output := make([]any, elementCount)
182-
// Read all elements
183-
for i := uint32(0); i < elementCount; i++ {
184-
// We read from i+1 to account for the element count at the beginning
185-
offset := binary.LittleEndian.Uint32(data[(i+1)*4:])
186-
// If the value is null, then we can skip it, since the output slice default initializes all values to nil
187-
if data[offset] == 1 {
188-
continue
189-
}
190-
// The element data is everything from the offset to the next offset, excluding the null determinant
191-
nextOffset := binary.LittleEndian.Uint32(data[(i+2)*4:])
192-
o, err := baseType.DeserializeValue(ctx, data[offset+1:nextOffset])
193-
if err != nil {
194-
return nil, err
195-
}
196-
output[i] = o
197-
}
198-
// Returns all read elements
199-
return output, nil
172+
return deserializeArray(ctx, data, baseType)
200173
},
201174
}
202175

@@ -207,49 +180,9 @@ var array_send = framework.Function1{
207180
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyArray},
208181
Strict: true,
209182
Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) {
210-
arrType := t[0]
211-
baseType := arrType.ArrayBaseType()
212183
vals := val.([]any)
213-
214-
bb := bytes.Buffer{}
215-
// Write the element count to a buffer. We're using an array since it's stack-allocated, so no need for pooling.
216-
var elementCount [4]byte
217-
binary.LittleEndian.PutUint32(elementCount[:], uint32(len(vals)))
218-
bb.Write(elementCount[:])
219-
// Create an array that contains the offsets for each value. Since we can't update the offset portion of the buffer
220-
// as we determine the offsets, we have to track them outside the buffer. We'll overwrite the buffer later with the
221-
// correct offsets. The last offset represents the end of the slice, which simplifies the logic for reading elements
222-
// using the "current offset to next offset" strategy. We use a byte slice since the buffer only works with byte
223-
// slices.
224-
offsets := make([]byte, (len(vals)+1)*4)
225-
bb.Write(offsets)
226-
// The starting offset for the first element is Count(uint32) + (NumberOfElementOffsets * sizeof(uint32))
227-
currentOffset := uint32(4 + (len(vals)+1)*4)
228-
for i := range vals {
229-
// Write the current offset
230-
binary.LittleEndian.PutUint32(offsets[i*4:], currentOffset)
231-
// Handle serialization of the value
232-
// TODO: ARRAYs may be multidimensional, such as ARRAY[[4,2],[6,3]], which isn't accounted for here
233-
serializedVal, err := baseType.SerializeValue(ctx, vals[i])
234-
if err != nil {
235-
return nil, err
236-
}
237-
// Handle the nil case and non-nil case
238-
if serializedVal == nil {
239-
bb.WriteByte(1)
240-
currentOffset += 1
241-
} else {
242-
bb.WriteByte(0)
243-
bb.Write(serializedVal)
244-
currentOffset += 1 + uint32(len(serializedVal))
245-
}
246-
}
247-
// Write the final offset, which will equal the length of the serialized slice
248-
binary.LittleEndian.PutUint32(offsets[len(offsets)-4:], currentOffset)
249-
// Get the final output, and write the updated offsets to it
250-
outputBytes := bb.Bytes()
251-
copy(outputBytes[4:], offsets)
252-
return outputBytes, nil
184+
arrType := t[0]
185+
return serializeArray(ctx, vals, arrType.ArrayBaseType())
253186
},
254187
}
255188

@@ -301,3 +234,78 @@ var array_subscript_handler = framework.Function1{
301234
return []byte{}, nil
302235
},
303236
}
237+
238+
// deserializeArray serializes an array of given base type.
239+
func serializeArray(ctx *sql.Context, vals []any, baseType *pgtypes.DoltgresType) ([]byte, error) {
240+
bb := bytes.Buffer{}
241+
// Write the element count to a buffer. We're using an array since it's stack-allocated, so no need for pooling.
242+
var elementCount [4]byte
243+
binary.LittleEndian.PutUint32(elementCount[:], uint32(len(vals)))
244+
bb.Write(elementCount[:])
245+
// Create an array that contains the offsets for each value. Since we can't update the offset portion of the buffer
246+
// as we determine the offsets, we have to track them outside the buffer. We'll overwrite the buffer later with the
247+
// correct offsets. The last offset represents the end of the slice, which simplifies the logic for reading elements
248+
// using the "current offset to next offset" strategy. We use a byte slice since the buffer only works with byte
249+
// slices.
250+
offsets := make([]byte, (len(vals)+1)*4)
251+
bb.Write(offsets)
252+
// The starting offset for the first element is Count(uint32) + (NumberOfElementOffsets * sizeof(uint32))
253+
currentOffset := uint32(4 + (len(vals)+1)*4)
254+
for i := range vals {
255+
// Write the current offset
256+
binary.LittleEndian.PutUint32(offsets[i*4:], currentOffset)
257+
// Handle serialization of the value
258+
// TODO: ARRAYs may be multidimensional, such as ARRAY[[4,2],[6,3]], which isn't accounted for here
259+
serializedVal, err := baseType.SerializeValue(ctx, vals[i])
260+
if err != nil {
261+
return nil, err
262+
}
263+
// Handle the nil case and non-nil case
264+
if serializedVal == nil {
265+
bb.WriteByte(1)
266+
currentOffset += 1
267+
} else {
268+
bb.WriteByte(0)
269+
bb.Write(serializedVal)
270+
currentOffset += 1 + uint32(len(serializedVal))
271+
}
272+
}
273+
// Write the final offset, which will equal the length of the serialized slice
274+
binary.LittleEndian.PutUint32(offsets[len(offsets)-4:], currentOffset)
275+
// Get the final output, and write the updated offsets to it
276+
outputBytes := bb.Bytes()
277+
copy(outputBytes[4:], offsets)
278+
return outputBytes, nil
279+
}
280+
281+
// deserializeArray deserializes an array of given base type.
282+
func deserializeArray(ctx *sql.Context, data []byte, baseType *pgtypes.DoltgresType) ([]any, error) {
283+
// Check for the nil value, then ensure the minimum length of the slice
284+
if len(data) == 0 {
285+
return nil, nil
286+
}
287+
if len(data) < 4 {
288+
return nil, errors.Errorf("deserializing non-nil array value has invalid length of %d", len(data))
289+
}
290+
// Grab the number of elements and construct an output slice of the appropriate size
291+
elementCount := binary.LittleEndian.Uint32(data)
292+
output := make([]any, elementCount)
293+
// Read all elements
294+
for i := uint32(0); i < elementCount; i++ {
295+
// We read from i+1 to account for the element count at the beginning
296+
offset := binary.LittleEndian.Uint32(data[(i+1)*4:])
297+
// If the value is null, then we can skip it, since the output slice default initializes all values to nil
298+
if data[offset] == 1 {
299+
continue
300+
}
301+
// The element data is everything from the offset to the next offset, excluding the null determinant
302+
nextOffset := binary.LittleEndian.Uint32(data[(i+2)*4:])
303+
o, err := baseType.DeserializeValue(ctx, data[offset+1:nextOffset])
304+
if err != nil {
305+
return nil, err
306+
}
307+
output[i] = o
308+
}
309+
// Returns all read elements
310+
return output, nil
311+
}

server/functions/binary/equal.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func initBinaryEqual() {
6161
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, nameeqtext)
6262
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, numeric_eq)
6363
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, oideq)
64+
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, oidvectoreq)
6465
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, texteqname)
6566
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, text_eq)
6667
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, record_eq)
@@ -469,25 +470,30 @@ var numeric_eq = framework.Function2{
469470
Callable: numeric_eq_callable,
470471
}
471472

472-
// oideq_callable is the callable logic for the oideq function.
473-
// This method doesn't use DotlgresType.Compare because it's on the critical path for many tooling queries that
474-
// examine the pg_catalog tables.
475-
func oideq_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
476-
if val1 == nil || val2 == nil {
477-
return false, nil
478-
}
479-
480-
val1id, val2id := val1.(id.Id), val2.(id.Id)
481-
return val1id == val2id, nil
482-
}
483-
484473
// oideq represents the PostgreSQL function of the same name, taking the same parameters.
485474
var oideq = framework.Function2{
486475
Name: "oideq",
487476
Return: pgtypes.Bool,
488477
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oid, pgtypes.Oid},
489478
Strict: true,
490-
Callable: oideq_callable,
479+
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
480+
// This method doesn't use DoltgresType.Compare because it's on the critical path for many tooling queries that
481+
// examine the pg_catalog tables.
482+
val1id, val2id := val1.(id.Id), val2.(id.Id)
483+
return val1id == val2id, nil
484+
},
485+
}
486+
487+
// oidvectoreq represents the PostgreSQL function of the same name, taking the same parameters.
488+
var oidvectoreq = framework.Function2{
489+
Name: "oidvectoreq",
490+
Return: pgtypes.Bool,
491+
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
492+
Strict: true,
493+
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
494+
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
495+
return res == 0, err
496+
},
491497
}
492498

493499
// texteqname_callable is the callable logic for the texteqname function.

server/functions/binary/greater.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func initBinaryGreaterThan() {
6161
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, namegttext)
6262
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, numeric_gt)
6363
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, oidgt)
64+
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, oidvectorgt)
6465
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, textgtname)
6566
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, text_gt)
6667
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, time_gt)
@@ -399,6 +400,18 @@ var oidgt = framework.Function2{
399400
},
400401
}
401402

403+
// oidvectorgt represents the PostgreSQL function of the same name, taking the same parameters.
404+
var oidvectorgt = framework.Function2{
405+
Name: "oidvectorgt",
406+
Return: pgtypes.Bool,
407+
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
408+
Strict: true,
409+
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
410+
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
411+
return res == 1, err
412+
},
413+
}
414+
402415
// textgtname represents the PostgreSQL function of the same name, taking the same parameters.
403416
var textgtname = framework.Function2{
404417
Name: "textgtname",

server/functions/binary/greater_equal.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func initBinaryGreaterOrEqual() {
6161
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, namegetext)
6262
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, numeric_ge)
6363
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, oidge)
64+
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, oidvectorge)
6465
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, textgename)
6566
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, text_ge)
6667
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, time_ge)
@@ -399,6 +400,18 @@ var oidge = framework.Function2{
399400
},
400401
}
401402

403+
// oidvectorge represents the PostgreSQL function of the same name, taking the same parameters.
404+
var oidvectorge = framework.Function2{
405+
Name: "oidvectorge",
406+
Return: pgtypes.Bool,
407+
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
408+
Strict: true,
409+
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
410+
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
411+
return res >= 0, err
412+
},
413+
}
414+
402415
// textgename represents the PostgreSQL function of the same name, taking the same parameters.
403416
var textgename = framework.Function2{
404417
Name: "textgename",

server/functions/binary/less.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func initBinaryLessThan() {
6161
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, namelttext)
6262
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, numeric_lt)
6363
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, oidlt)
64+
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, oidvectorlt)
6465
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, textltname)
6566
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, text_lt)
6667
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, time_lt)
@@ -399,6 +400,18 @@ var oidlt = framework.Function2{
399400
},
400401
}
401402

403+
// oidvectorlt represents the PostgreSQL function of the same name, taking the same parameters.
404+
var oidvectorlt = framework.Function2{
405+
Name: "oidvectorlt",
406+
Return: pgtypes.Bool,
407+
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
408+
Strict: true,
409+
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
410+
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
411+
return res == -1, err
412+
},
413+
}
414+
402415
// textltname represents the PostgreSQL function of the same name, taking the same parameters.
403416
var textltname = framework.Function2{
404417
Name: "textltname",

server/functions/binary/less_equal.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func initBinaryLessOrEqual() {
6161
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, nameletext)
6262
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, numeric_le)
6363
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, oidle)
64+
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, oidvectorle)
6465
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, textlename)
6566
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, text_le)
6667
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, time_le)
@@ -399,6 +400,18 @@ var oidle = framework.Function2{
399400
},
400401
}
401402

403+
// oidvectorle represents the PostgreSQL function of the same name, taking the same parameters.
404+
var oidvectorle = framework.Function2{
405+
Name: "oidvectorle",
406+
Return: pgtypes.Bool,
407+
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
408+
Strict: true,
409+
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
410+
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
411+
return res <= 0, err
412+
},
413+
}
414+
402415
// textlename represents the PostgreSQL function of the same name, taking the same parameters.
403416
var textlename = framework.Function2{
404417
Name: "textlename",

0 commit comments

Comments
 (0)