1616# under the License.
1717
1818import pytest
19- from datafusion import ExecutionPlan , LogicalPlan , SessionContext
19+ from datafusion import (
20+ ExecutionPlan ,
21+ LogicalPlan ,
22+ Metric ,
23+ MetricsSet ,
24+ SessionContext ,
25+ )
2026
2127
2228# Note: We must use CSV because memory tables are currently not supported for
@@ -40,3 +46,147 @@ def test_logical_plan_to_proto(ctx, df) -> None:
4046 execution_plan = ExecutionPlan .from_proto (ctx , execution_plan_bytes )
4147
4248 assert str (original_execution_plan ) == str (execution_plan )
49+
50+
def test_execution_plan_metrics() -> None:
    """Check that an executed plan has a node reporting a positive row count.

    The query is collected first; the plan tree is then walked from the root
    through ``children()``, inspecting ``metrics()`` on every node.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")

    df.collect()

    # Explicit stack instead of recursion. Children are pushed in reverse so
    # nodes are still visited parent-first, in their original child order.
    pending = [df.execution_plan()]
    nodes_with_rows = 0
    while pending:
        node = pending.pop()
        metrics_set = node.metrics()
        if metrics_set is not None:
            rows = metrics_set.output_rows
            if rows is not None and rows > 0:
                nodes_with_rows += 1
        pending.extend(reversed(list(node.children())))

    assert nodes_with_rows > 0
71+
72+
def test_metric_properties() -> None:
    """Validate the Python-visible attributes of every collected ``Metric``.

    Runs a filtered query, then checks each metric returned by
    ``MetricsSet.metrics()`` for every entry of ``collect_metrics()``:
    its type, a non-empty string ``name``, an optional integer ``partition``
    and a ``dict`` returned by ``labels()``.

    The test is skipped when the executed plan exposes no metrics at all.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")

    df.collect()
    plan = df.execution_plan()

    # All metrics are inspected rather than stopping at the first one, so a
    # malformed entry further down the plan cannot slip through unnoticed.
    checked = 0
    for _, ms in plan.collect_metrics():
        for metric in ms.metrics():
            assert isinstance(metric, Metric)
            assert isinstance(metric.name, str)
            assert len(metric.name) > 0
            assert metric.partition is None or isinstance(metric.partition, int)
            assert isinstance(metric.labels(), dict)
            checked += 1

    if checked == 0:
        pytest.skip("No metrics found")
90+
91+
def test_metrics_tree_walk() -> None:
    """Check the shape of ``collect_metrics()`` for an aggregation query.

    After collecting a GROUP BY query, the plan must yield at least two
    entries, each a ``(str, MetricsSet)`` pair.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'a'), (4, 'b')")
    grouped = ctx.sql("SELECT column2, COUNT(*) FROM t GROUP BY column2")

    grouped.collect()
    entries = grouped.execution_plan().collect_metrics()

    assert len(entries) >= 2
    for entry in entries:
        entry_name, metrics_set = entry
        assert isinstance(entry_name, str)
        assert isinstance(metrics_set, MetricsSet)
105+
106+
def test_no_metrics_before_execution() -> None:
    """Check that a plan that was never executed reports no output rows.

    The root node's ``metrics()`` may be ``None``; otherwise its
    ``output_rows`` must be ``None`` or zero.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
    unexecuted_plan = ctx.sql("SELECT * FROM t").execution_plan()

    metrics_set = unexecuted_plan.metrics()
    if metrics_set is None:
        # No metrics object at all is an acceptable outcome.
        return

    rows = metrics_set.output_rows
    assert rows is None or rows == 0
114+
115+
def test_metrics_repr() -> None:
    """Check ``repr()`` of every collected ``MetricsSet`` and ``Metric``.

    After collecting a simple scan query, each ``MetricsSet`` from
    ``collect_metrics()`` must have a string repr, and each of its metrics
    must have a non-empty string repr.

    The test is skipped when the executed plan exposes no individual metrics.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
    df = ctx.sql("SELECT * FROM t")

    df.collect()
    plan = df.execution_plan()

    # Every set and every metric is exercised rather than stopping early, so
    # a repr failure on any later entry is caught as well.
    metrics_seen = 0
    for _, ms in plan.collect_metrics():
        assert isinstance(repr(ms), str)
        for metric in ms.metrics():
            metric_repr = repr(metric)
            assert isinstance(metric_repr, str)
            assert len(metric_repr) > 0
            metrics_seen += 1

    if metrics_seen == 0:
        pytest.skip("No metrics found")
133+
134+
def test_collect_partitioned_metrics() -> None:
    """Check that metrics are populated after ``collect_partitioned()``.

    Also checks that the number of returned partitions equals the plan's
    ``partition_count``.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")

    batches_by_partition = df.collect_partitioned()
    plan = df.execution_plan()
    assert plan.partition_count == len(batches_by_partition)

    # Metrics should be populated after collecting: gather every reported
    # row count first, then require at least one positive value.
    row_counts = [metrics_set.output_rows for _, metrics_set in plan.collect_metrics()]
    assert any(rows is not None and rows > 0 for rows in row_counts)
150+
151+
def test_execute_stream_metrics() -> None:
    """Check that metrics are populated after draining ``execute_stream()``.

    The stream must yield at least one batch; afterwards every entry of
    ``collect_metrics()`` must be a ``(str, MetricsSet)`` pair and at least
    one set must report a positive ``output_rows``.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")

    # Materialising the stream consumes it completely.
    batches = list(df.execute_stream())
    assert len(batches) > 0

    # Metrics should be populated once the stream has been consumed.
    sets_with_rows = 0
    for entry_name, metrics_set in df.execution_plan().collect_metrics():
        assert isinstance(entry_name, str)
        assert isinstance(metrics_set, MetricsSet)
        rows = metrics_set.output_rows
        if rows is not None and rows > 0:
            sets_with_rows += 1
    assert sets_with_rows > 0
172+
173+
def test_execute_stream_partitioned_metrics() -> None:
    """Check metrics after draining all ``execute_stream_partitioned()`` streams.

    Every partition stream is consumed to the end; afterwards at least one
    entry of ``collect_metrics()`` must report a positive ``output_rows``.
    """
    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")

    # Drain each partition stream; the batches themselves are not needed.
    for partition_stream in df.execute_stream_partitioned():
        for _batch in partition_stream:
            pass

    # Metrics should be populated (FilterExec reports output_rows).
    reported_rows = [
        metrics_set.output_rows
        for _, metrics_set in df.execution_plan().collect_metrics()
    ]
    assert any(rows is not None and rows > 0 for rows in reported_rows)
0 commit comments