Skip to content
This repository was archived by the owner on May 18, 2019. It is now read-only.

Commit d090ace

Browse files
perostOpenModelica-Hudson
authored andcommitted
[NF] Improve operator overloading.
- Implement scalar*array, array*scalar and array/scalar for overloaded operators. - Improve TypeCheck.implicitConstructAndMatch so that it checks that the constructed argument actually matches the expected type for the operator, to avoid it matching e.g. scalars with operators that only take arrays. Belonging to [master]: - #2837 - OpenModelica/OpenModelica-testsuite#1095
1 parent b8ddb69 commit d090ace

1 file changed

Lines changed: 271 additions & 20 deletions

File tree

Compiler/NFFrontEnd/NFTypeCheck.mo

Lines changed: 271 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,12 @@ algorithm
250250
if oop == Op.ADD or oop == Op.SUB then
251251
(outExp, outType) :=
252252
checkOverloadedBinaryArrayAddSub(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
253+
elseif oop == Op.MUL then
254+
(outExp, outType) :=
255+
checkOverloadedBinaryArrayMul(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
256+
elseif oop == Op.DIV then
257+
(outExp, outType) :=
258+
checkOverloadedBinaryArrayDiv(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
253259
else
254260
printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, type2}, info, showErrors);
255261
end if;
@@ -363,6 +369,210 @@ algorithm
363369
end match;
364370
end checkOverloadedBinaryArrayAddSub2;
365371

372+
function checkOverloadedBinaryArrayMul
373+
input Expression exp1;
374+
input Type type1;
375+
input Variability var1;
376+
input Operator op;
377+
input Expression exp2;
378+
input Type type2;
379+
input Variability var2;
380+
input list<Function> candidates;
381+
input SourceInfo info;
382+
output Expression outExp;
383+
output Type outType;
384+
protected
385+
Boolean valid;
386+
list<Dimension> dims1, dims2;
387+
Dimension dim11, dim12, dim21, dim22;
388+
algorithm
389+
dims1 := Type.arrayDims(type1);
390+
dims2 := Type.arrayDims(type2);
391+
392+
(valid, outExp) := match (dims1, dims2)
393+
// scalar * array = array
394+
case ({}, {_})
395+
algorithm
396+
outExp := checkOverloadedBinaryScalarArray(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
397+
then
398+
(true, outExp);
399+
// array * scalar = array
400+
case ({_}, {})
401+
algorithm
402+
outExp := checkOverloadedBinaryArrayScalar(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
403+
then
404+
(true, outExp);
405+
// matrix[n, m] * vector[m] = vector[n]
406+
case ({dim11, dim12}, {dim21})
407+
algorithm
408+
valid := Dimension.isEqual(dim12, dim21);
409+
// TODO: Implement me!
410+
outExp := Expression.BINARY(exp1, op, exp2);
411+
valid := false;
412+
then
413+
(valid, outExp);
414+
// matrix[n, m] * matrix[m, p] = vector[n, p]
415+
case ({dim11, dim12}, {dim21, dim22})
416+
algorithm
417+
valid := Dimension.isEqual(dim12, dim21);
418+
// TODO: Implement me!
419+
outExp := Expression.BINARY(exp1, op, exp2);
420+
valid := false;
421+
then
422+
(valid, outExp);
423+
// scalar * scalar should never get here.
424+
// vector * vector and vector * matrix are undefined for overloaded operators.
425+
else (false, Expression.BINARY(exp1, op, exp2));
426+
end match;
427+
428+
if not valid then
429+
printUnresolvableTypeError(outExp, {type1, type2}, info);
430+
end if;
431+
432+
outType := Expression.typeOf(outExp);
433+
end checkOverloadedBinaryArrayMul;
434+
435+
function checkOverloadedBinaryScalarArray
436+
input Expression exp1;
437+
input Type type1;
438+
input Variability var1;
439+
input Operator op;
440+
input Expression exp2;
441+
input Type type2;
442+
input Variability var2;
443+
input list<Function> candidates;
444+
input SourceInfo info;
445+
output Expression outExp;
446+
output Type outType;
447+
algorithm
448+
(outExp, outType) := checkOverloadedBinaryScalarArray2(
449+
exp1, type1, var1, op, ExpandExp.expand(exp2), type2, var2, candidates, info);
450+
end checkOverloadedBinaryScalarArray;
451+
452+
function checkOverloadedBinaryScalarArray2
453+
input Expression exp1;
454+
input Type type1;
455+
input Variability var1;
456+
input Operator op;
457+
input Expression exp2;
458+
input Type type2;
459+
input Variability var2;
460+
input list<Function> candidates;
461+
input SourceInfo info;
462+
output Expression outExp;
463+
output Type outType;
464+
protected
465+
list<Expression> expl;
466+
Type ty;
467+
algorithm
468+
(outExp, outType) := match exp2
469+
case Expression.ARRAY(elements = {})
470+
algorithm
471+
try
472+
ty := Type.unliftArray(type2);
473+
(_, outType) := matchOverloadedBinaryOperator(
474+
exp1, type1, var1, op, Expression.EMPTY(type2), ty, var2, candidates, info, showErrors = false);
475+
else
476+
printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, exp2.ty}, info);
477+
end try;
478+
479+
outType := Type.setArrayElementType(exp2.ty, outType);
480+
then
481+
(Expression.makeArray(outType, {}), outType);
482+
483+
case Expression.ARRAY(elements = expl)
484+
algorithm
485+
ty := Type.unliftArray(type2);
486+
expl := list(checkOverloadedBinaryScalarArray2(exp1, type1, var1, op, e, ty, var2, candidates, info) for e in expl);
487+
outType := Type.setArrayElementType(exp2.ty, Expression.typeOf(listHead(expl)));
488+
then
489+
(Expression.makeArray(outType, expl), outType);
490+
491+
else matchOverloadedBinaryOperator(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
492+
end match;
493+
end checkOverloadedBinaryScalarArray2;
494+
495+
function checkOverloadedBinaryArrayScalar
496+
input Expression exp1;
497+
input Type type1;
498+
input Variability var1;
499+
input Operator op;
500+
input Expression exp2;
501+
input Type type2;
502+
input Variability var2;
503+
input list<Function> candidates;
504+
input SourceInfo info;
505+
output Expression outExp;
506+
output Type outType;
507+
algorithm
508+
(outExp, outType) := checkOverloadedBinaryArrayScalar2(
509+
ExpandExp.expand(exp1), type1, var1, op, exp2, type2, var2, candidates, info);
510+
end checkOverloadedBinaryArrayScalar;
511+
512+
function checkOverloadedBinaryArrayScalar2
513+
input Expression exp1;
514+
input Type type1;
515+
input Variability var1;
516+
input Operator op;
517+
input Expression exp2;
518+
input Type type2;
519+
input Variability var2;
520+
input list<Function> candidates;
521+
input SourceInfo info;
522+
output Expression outExp;
523+
output Type outType;
524+
protected
525+
Expression e1;
526+
list<Expression> expl;
527+
Type ty;
528+
algorithm
529+
(outExp, outType) := match exp1
530+
case Expression.ARRAY(elements = {})
531+
algorithm
532+
try
533+
ty := Type.unliftArray(type1);
534+
(_, outType) := matchOverloadedBinaryOperator(
535+
Expression.EMPTY(type1), ty, var1, op, exp2, type2, var2, candidates, info, showErrors = false);
536+
else
537+
printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, exp1.ty}, info);
538+
end try;
539+
540+
outType := Type.setArrayElementType(exp1.ty, outType);
541+
then
542+
(Expression.makeArray(outType, {}), outType);
543+
544+
case Expression.ARRAY(elements = expl)
545+
algorithm
546+
ty := Type.unliftArray(type1);
547+
expl := list(checkOverloadedBinaryArrayScalar2(e, ty, var1, op, exp2, type2, var2, candidates, info) for e in expl);
548+
outType := Type.setArrayElementType(exp1.ty, Expression.typeOf(listHead(expl)));
549+
then
550+
(Expression.makeArray(outType, expl), outType);
551+
552+
else matchOverloadedBinaryOperator(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
553+
end match;
554+
end checkOverloadedBinaryArrayScalar2;
555+
556+
function checkOverloadedBinaryArrayDiv
557+
input Expression exp1;
558+
input Type type1;
559+
input Variability var1;
560+
input Operator op;
561+
input Expression exp2;
562+
input Type type2;
563+
input Variability var2;
564+
input list<Function> candidates;
565+
input SourceInfo info;
566+
output Expression outExp;
567+
output Type outType;
568+
algorithm
569+
if Type.isArray(type1) and Type.isScalar(type2) then
570+
(outExp, outType) := checkOverloadedBinaryArrayScalar(exp1, type1, var1, op, exp2, type2, var2, candidates, info);
571+
else
572+
printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, type2}, info);
573+
end if;
574+
end checkOverloadedBinaryArrayDiv;
575+
366576
function implicitConstructAndMatch
367577
input list<Function> candidates;
368578
input Expression inExp1;
@@ -381,32 +591,32 @@ protected
381591
Function operfn;
382592
list<tuple<Function, list<Expression>, Variability>> matchedfuncs = {};
383593
Expression exp1,exp2;
384-
Type ty;
594+
Type ty, arg1_ty, arg2_ty;
385595
Variability var;
596+
Boolean matched;
597+
SourceInfo arg1_info, arg2_info;
386598
algorithm
387599
exp1 := inExp1; exp2 := inExp2;
388600
for fn in candidates loop
389601
in1 :: in2 :: _ := fn.inputs;
390-
(_, _, mk1) := matchTypes(InstNode.getType(in1),inType1,inExp1,false);
391-
(_, _, mk2) := matchTypes(InstNode.getType(in2),inType2,inExp2,false);
392-
393-
// If the first argument matched the expected one, then we try
394-
// to construct the second argument to the class of the second input.
395-
if mk1 == MatchKind.EXACT then
396-
// We only want overloaded constructors when trying to implicitly construct. Default constructors are not considered.
397-
scope := InstNode.classScope(in2);
398-
fn_ref := Function.instFunction(Absyn.CREF_IDENT("'constructor'",{}),scope,InstNode.info(in2));
399-
exp2 := Expression.CALL(NFCall.UNTYPED_CALL(fn_ref, {inExp2}, {}, scope));
400-
(exp2, ty, var) := Call.typeCall(exp2, 0, InstNode.info(in1));
401-
matchedfuncs := (fn,{inExp1,exp2}, var)::matchedfuncs;
402-
elseif mk2 == MatchKind.EXACT then
403-
// We only want overloaded constructors when trying to implicitly construct. Default constructors are not considered.
404-
scope := InstNode.classScope(in1);
405-
fn_ref := Function.instFunction(Absyn.CREF_IDENT("'constructor'",{}),scope,InstNode.info(in1));
406-
exp1 := Expression.CALL(NFCall.UNTYPED_CALL(fn_ref, {inExp1}, {}, scope));
407-
(exp1, ty, var) := Call.typeCall(exp1, 0, InstNode.info(in2));
408-
matchedfuncs := (fn,{exp1,inExp2},var)::matchedfuncs;
602+
arg1_ty := InstNode.getType(in1);
603+
arg2_ty := InstNode.getType(in2);
604+
arg1_info := InstNode.info(in1);
605+
arg2_info := InstNode.info(in2);
606+
607+
// Try to implicitly construct a matching record from the first argument.
608+
(matchedfuncs, matched) :=
609+
implicitConstructAndMatch2(inExp1, inType1, inExp2, arg1_ty,
610+
arg1_info, arg2_ty, arg2_info, InstNode.classScope(in2), fn, false, matchedfuncs);
611+
612+
if matched then
613+
continue;
409614
end if;
615+
616+
// Try to implicitly construct a matching record from the second argument.
617+
(matchedfuncs, matched) :=
618+
implicitConstructAndMatch2(inExp2, inType2, inExp1, arg2_ty,
619+
arg2_info, arg1_ty, arg1_info, InstNode.classScope(in1), fn, true, matchedfuncs);
410620
end for;
411621

412622
if listLength(matchedfuncs) == 1 then
@@ -421,6 +631,47 @@ algorithm
421631
end if;
422632
end implicitConstructAndMatch;
423633

634+
function implicitConstructAndMatch2
635+
input Expression exp1;
636+
input Type type1;
637+
input Expression exp2;
638+
input Type paramType1;
639+
input SourceInfo paramInfo1;
640+
input Type paramType2;
641+
input SourceInfo paramInfo2;
642+
input InstNode scope;
643+
input Function fn;
644+
input Boolean reverseArgs;
645+
input output list<tuple<Function, list<Expression>, Variability>> matchedFns;
646+
output Boolean matched;
647+
protected
648+
ComponentRef fn_ref;
649+
Expression e1, e2;
650+
MatchKind mk;
651+
Variability var;
652+
Type ty;
653+
algorithm
654+
(e1, _, mk) := matchTypes(paramType1, type1, exp1, false);
655+
656+
// We only want overloaded constructors when trying to implicitly construct.
657+
// Default constructors are not considered.
658+
if mk == MatchKind.EXACT then
659+
fn_ref := Function.instFunction(Absyn.CREF_IDENT("'constructor'", {}), scope, paramInfo2);
660+
e2 := Expression.CALL(NFCall.UNTYPED_CALL(fn_ref, {exp2}, {}, scope));
661+
(e2, ty, var) := Call.typeCall(e2, 0, paramInfo1);
662+
(_, _, mk) := matchTypes(paramType2, ty, e2, false);
663+
664+
if mk == MatchKind.EXACT then
665+
matchedFns := (fn, if reverseArgs then {e2, e1} else {e1, e2}, var) :: matchedFns;
666+
matched := true;
667+
else
668+
matched := false;
669+
end if;
670+
else
671+
matched := false;
672+
end if;
673+
end implicitConstructAndMatch2;
674+
424675
//function checkValidBinaryOperatorOverload
425676
// input String oper_name;
426677
// input Function oper_func;

0 commit comments

Comments
 (0)