Skip to content

Commit 9e22690

Browse files
authored
Revert "Support float8_e3m4 and float8_e4m3 in np_to_memref (#186453)" (#186677)
This reverts commit 57427f8. For some reason mlir-nvidia CI is failing to import `float8_e3m4` from `ml_dtypes`. See https://lab.llvm.org/buildbot/#/builders/138/builds/27095.
1 parent 27dd55b commit 9e22690

2 files changed

Lines changed: 6 additions & 109 deletions

File tree

mlir/python/mlir/runtime/np_to_memref.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,12 @@ class BF16(ctypes.Structure):
3737

3838
_fields_ = [("bf16", ctypes.c_int16)]
3939

40-
4140
class F8E5M2(ctypes.Structure):
4241
"""A ctype representation for MLIR's Float8E5M2."""
4342

4443
_fields_ = [("f8E5M2", ctypes.c_int8)]
4544

4645

47-
class F8E3M4(ctypes.Structure):
48-
"""A ctype representation for MLIR's Float8E3M4."""
49-
50-
_fields_ = [("f8E3M4", ctypes.c_int8)]
51-
52-
53-
class F8E4M3(ctypes.Structure):
54-
"""A ctype representation for MLIR's Float8E4M3."""
55-
56-
_fields_ = [("f8E4M3", ctypes.c_int8)]
57-
58-
5946
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
6047
def as_ctype(dtp):
6148
"""Converts dtype to ctype."""
@@ -69,10 +56,6 @@ def as_ctype(dtp):
6956
return BF16
7057
if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
7158
return F8E5M2
72-
if ml_dtypes is not None and dtp == ml_dtypes.float8_e3m4:
73-
return F8E3M4
74-
if ml_dtypes is not None and dtp == ml_dtypes.float8_e4m3:
75-
return F8E4M3
7659
return np.ctypeslib.as_ctypes_type(dtp)
7760

7861

@@ -85,17 +68,15 @@ def to_numpy(array):
8568
if array.dtype == F16:
8669
return array.view("float16")
8770
assert not (
88-
array.dtype in (BF16, F8E5M2, F8E3M4, F8E4M3) and ml_dtypes is None
89-
), f"{array.dtype=} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
71+
array.dtype == BF16 and ml_dtypes is None
72+
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
9073
if array.dtype == BF16:
9174
return array.view("bfloat16")
75+
assert not (
76+
array.dtype == F8E5M2 and ml_dtypes is None
77+
), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
9278
if array.dtype == F8E5M2:
9379
return array.view("float8_e5m2")
94-
if array.dtype == F8E3M4:
95-
return array.view("float8_e3m4")
96-
if array.dtype == F8E4M3:
97-
return array.view("float8_e4m3")
98-
9980
return array
10081

10182

mlir/test/python/execution_engine.py

Lines changed: 1 addition & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mlir.runtime import *
99

1010
try:
11-
from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3
11+
from ml_dtypes import bfloat16, float8_e5m2
1212

1313
HAS_ML_DTYPES = True
1414
except ModuleNotFoundError:
@@ -623,90 +623,6 @@ def testF8E5M2Memref():
623623
log("TEST: testF8E5M2Memref")
624624

625625

626-
# Test f8E3M4 memrefs
627-
# CHECK-LABEL: TEST: testF8E3M4Memref
628-
def testF8E3M4Memref():
629-
with Context():
630-
module = Module.parse(
631-
"""
632-
module {
633-
func.func @main(%arg0: memref<1xf8E3M4>,
634-
%arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } {
635-
%0 = arith.constant 0 : index
636-
%1 = memref.load %arg0[%0] : memref<1xf8E3M4>
637-
memref.store %1, %arg1[%0] : memref<1xf8E3M4>
638-
return
639-
}
640-
} """
641-
)
642-
643-
arg1 = np.array([0.5]).astype(float8_e3m4)
644-
arg2 = np.array([0.0]).astype(float8_e3m4)
645-
646-
arg1_memref_ptr = ctypes.pointer(
647-
ctypes.pointer(get_ranked_memref_descriptor(arg1))
648-
)
649-
arg2_memref_ptr = ctypes.pointer(
650-
ctypes.pointer(get_ranked_memref_descriptor(arg2))
651-
)
652-
653-
execution_engine = ExecutionEngine(lowerToLLVM(module))
654-
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
655-
656-
# test to-numpy utility
657-
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
658-
assert len(x) == 1
659-
assert x[0] == 0.5
660-
661-
662-
if HAS_ML_DTYPES:
663-
run(testF8E3M4Memref)
664-
else:
665-
log("TEST: testF8E3M4Memref")
666-
667-
668-
# Test f8E4M3 memrefs
669-
# CHECK-LABEL: TEST: testF8E4M3Memref
670-
def testF8E4M3Memref():
671-
with Context():
672-
module = Module.parse(
673-
"""
674-
module {
675-
func.func @main(%arg0: memref<1xf8E4M3>,
676-
%arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } {
677-
%0 = arith.constant 0 : index
678-
%1 = memref.load %arg0[%0] : memref<1xf8E4M3>
679-
memref.store %1, %arg1[%0] : memref<1xf8E4M3>
680-
return
681-
}
682-
} """
683-
)
684-
685-
arg1 = np.array([0.5]).astype(float8_e4m3)
686-
arg2 = np.array([0.0]).astype(float8_e4m3)
687-
688-
arg1_memref_ptr = ctypes.pointer(
689-
ctypes.pointer(get_ranked_memref_descriptor(arg1))
690-
)
691-
arg2_memref_ptr = ctypes.pointer(
692-
ctypes.pointer(get_ranked_memref_descriptor(arg2))
693-
)
694-
695-
execution_engine = ExecutionEngine(lowerToLLVM(module))
696-
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
697-
698-
# test to-numpy utility
699-
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
700-
assert len(x) == 1
701-
assert x[0] == 0.5
702-
703-
704-
if HAS_ML_DTYPES:
705-
run(testF8E4M3Memref)
706-
else:
707-
log("TEST: testF8E4M3Memref")
708-
709-
710626
# Test addition of two 2d_memref
711627
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
712628
def testDynamicMemrefAdd2D():

0 commit comments

Comments
 (0)