|
8 | 8 | from mlir.runtime import * |
9 | 9 |
|
10 | 10 | try: |
11 | | - from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3 |
| 11 | + from ml_dtypes import bfloat16, float8_e5m2 |
12 | 12 |
|
13 | 13 | HAS_ML_DTYPES = True |
14 | 14 | except ModuleNotFoundError: |
@@ -623,90 +623,6 @@ def testF8E5M2Memref(): |
623 | 623 | log("TEST: testF8E5M2Memref") |
624 | 624 |
|
625 | 625 |
|
626 | | -# Test f8E3M4 memrefs |
627 | | -# CHECK-LABEL: TEST: testF8E3M4Memref |
628 | | -def testF8E3M4Memref(): |
629 | | - with Context(): |
630 | | - module = Module.parse( |
631 | | - """ |
632 | | - module { |
633 | | - func.func @main(%arg0: memref<1xf8E3M4>, |
634 | | - %arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } { |
635 | | - %0 = arith.constant 0 : index |
636 | | - %1 = memref.load %arg0[%0] : memref<1xf8E3M4> |
637 | | - memref.store %1, %arg1[%0] : memref<1xf8E3M4> |
638 | | - return |
639 | | - } |
640 | | - } """ |
641 | | - ) |
642 | | - |
643 | | - arg1 = np.array([0.5]).astype(float8_e3m4) |
644 | | - arg2 = np.array([0.0]).astype(float8_e3m4) |
645 | | - |
646 | | - arg1_memref_ptr = ctypes.pointer( |
647 | | - ctypes.pointer(get_ranked_memref_descriptor(arg1)) |
648 | | - ) |
649 | | - arg2_memref_ptr = ctypes.pointer( |
650 | | - ctypes.pointer(get_ranked_memref_descriptor(arg2)) |
651 | | - ) |
652 | | - |
653 | | - execution_engine = ExecutionEngine(lowerToLLVM(module)) |
654 | | - execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) |
655 | | - |
656 | | - # test to-numpy utility |
657 | | - x = ranked_memref_to_numpy(arg2_memref_ptr[0]) |
658 | | - assert len(x) == 1 |
659 | | - assert x[0] == 0.5 |
660 | | - |
661 | | - |
662 | | -if HAS_ML_DTYPES: |
663 | | - run(testF8E3M4Memref) |
664 | | -else: |
665 | | - log("TEST: testF8E3M4Memref") |
666 | | - |
667 | | - |
668 | | -# Test f8E4M3 memrefs |
669 | | -# CHECK-LABEL: TEST: testF8E4M3Memref |
670 | | -def testF8E4M3Memref(): |
671 | | - with Context(): |
672 | | - module = Module.parse( |
673 | | - """ |
674 | | - module { |
675 | | - func.func @main(%arg0: memref<1xf8E4M3>, |
676 | | - %arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } { |
677 | | - %0 = arith.constant 0 : index |
678 | | - %1 = memref.load %arg0[%0] : memref<1xf8E4M3> |
679 | | - memref.store %1, %arg1[%0] : memref<1xf8E4M3> |
680 | | - return |
681 | | - } |
682 | | - } """ |
683 | | - ) |
684 | | - |
685 | | - arg1 = np.array([0.5]).astype(float8_e4m3) |
686 | | - arg2 = np.array([0.0]).astype(float8_e4m3) |
687 | | - |
688 | | - arg1_memref_ptr = ctypes.pointer( |
689 | | - ctypes.pointer(get_ranked_memref_descriptor(arg1)) |
690 | | - ) |
691 | | - arg2_memref_ptr = ctypes.pointer( |
692 | | - ctypes.pointer(get_ranked_memref_descriptor(arg2)) |
693 | | - ) |
694 | | - |
695 | | - execution_engine = ExecutionEngine(lowerToLLVM(module)) |
696 | | - execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) |
697 | | - |
698 | | - # test to-numpy utility |
699 | | - x = ranked_memref_to_numpy(arg2_memref_ptr[0]) |
700 | | - assert len(x) == 1 |
701 | | - assert x[0] == 0.5 |
702 | | - |
703 | | - |
704 | | -if HAS_ML_DTYPES: |
705 | | - run(testF8E4M3Memref) |
706 | | -else: |
707 | | - log("TEST: testF8E4M3Memref") |
708 | | - |
709 | | - |
710 | 626 | # Test addition of two 2d_memref |
711 | 627 | # CHECK-LABEL: TEST: testDynamicMemrefAdd2D |
712 | 628 | def testDynamicMemrefAdd2D(): |
|
0 commit comments