@@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10071007 return b.create <arith::SelectOp>(loc, pred, lhs, rhs);
10081008 }
10091009 if (auto clamp = dyn_cast<AtenClampOp>(op)) {
1010- Type dtype = converter->convertType (clamp.getType ())
1011- .cast <RankedTensorType>()
1012- .getElementType ();
1013- if (!dtype.isa <mlir::FloatType>()) {
1014- clamp.emitError (" unimplemented: non-floating point dtype" );
1015- return nullptr ;
1016- }
10171010 AtenClampOp::Adaptor adaptor (operands);
10181011 auto min = adaptor.getMin ();
10191012 auto max = adaptor.getMax ();
@@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10221015 clamp.emitError (" unimplemented: runtime optional type" );
10231016 return nullptr ;
10241017 }
1025- auto result = payloadArgs[0 ];
1026- if (!min.getType ().isa <Torch::NoneType>()) {
1027- auto minPromoted = convertScalarToDtype (b, loc, min, dtype);
1028- auto pred = b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
1029- result, minPromoted);
1030- result = b.create <arith::SelectOp>(loc, pred, minPromoted, result);
1018+
1019+ Type dtype = converter->convertType (clamp.getType ())
1020+ .cast <RankedTensorType>()
1021+ .getElementType ();
1022+ if (!dtype.isa <mlir::FloatType, mlir::IntegerType>()) {
1023+ clamp.emitError (" unimplement type for clamp" );
1024+ return nullptr ;
10311025 }
1032- if (!max. getType (). isa <Torch::NoneType>()) {
1033- auto maxPromoted = convertScalarToDtype (b, loc, max, dtype );
1034- auto pred = b. create <arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
1035- result, maxPromoted);
1036- result = b. create <arith::SelectOp>(loc, pred, maxPromoted, result );
1026+
1027+ Type dstOriginalDtype = clamp. getType (). cast <BaseTensorType>(). getDtype ( );
1028+ bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
1029+ if ( auto intTy = dstOriginalDtype. dyn_cast <IntegerType>()) {
1030+ isUnsigned = intTy. isUnsigned ( );
10371031 }
1032+ auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
1033+ clamp = convertScalarToDtype (b, loc, clamp, dtype,
1034+ /* srcOriginalDtype=*/ std::nullopt ,
1035+ /* dstOriginalDtype=*/ dstOriginalDtype);
1036+
1037+ Value pred;
1038+ if (dtype.isa <mlir::FloatType>()) {
1039+ auto cmp =
1040+ getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
1041+ pred = b.create <arith::CmpFOp>(loc, cmp, input, clamp);
1042+ } else if (dtype.isa <mlir::IntegerType>()) {
1043+ auto cmp =
1044+ isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
1045+ if (getMax)
1046+ cmp = arith::invertPredicate (cmp);
1047+ pred = b.create <arith::CmpIOp>(loc, cmp, input, clamp);
1048+ }
1049+ return b.create <arith::SelectOp>(loc, pred, clamp, input);
1050+ };
1051+
1052+ auto result = payloadArgs[0 ];
1053+ if (!min.getType ().isa <Torch::NoneType>())
1054+ result = cmpSelect (result, min, /* getMax=*/ false );
1055+ if (!max.getType ().isa <Torch::NoneType>())
1056+ result = cmpSelect (result, max, /* getMax=*/ true );
10381057 return result;
10391058 }
10401059 if (auto clampTensor = dyn_cast<AtenClampTensorOp>(op)) {
0 commit comments