@@ -697,79 +697,102 @@ def test_sample_profiler_sample_accepts_async_aware(self):
697697 sig = inspect .signature (SampleProfiler .sample )
698698 self .assertIn ("async_aware" , sig .parameters )
699699
700- def test_async_aware_all_uses_get_all_awaited_by (self ):
701- """Test that async_aware='all' calls get_all_awaited_by on unwinder."""
702- from unittest .mock import Mock , patch
703- from profiling .sampling .sample import SampleProfiler
704-
705- with patch ('profiling.sampling.sample._remote_debugging' ) as mock_rd :
706- mock_unwinder = Mock ()
707- mock_unwinder .get_all_awaited_by .return_value = []
708- mock_rd .RemoteUnwinder .return_value = mock_unwinder
709-
710- profiler = SampleProfiler (
711- pid = 12345 ,
712- sample_interval_usec = 1000 ,
713- all_threads = False
714- )
715- profiler .unwinder = mock_unwinder
716-
717- mock_collector = Mock ()
718- mock_collector .running = False # Stop immediately
700+ def test_async_aware_all_sees_sleeping_and_running_tasks (self ):
701+ """Test async_aware='all' captures both sleeping and CPU-running tasks."""
702+ # Sleeping task (awaiting)
703+ sleeping_task = MockTaskInfo (
704+ task_id = 1 ,
705+ task_name = "SleepingTask" ,
706+ coroutine_stack = [
707+ MockCoroInfo (
708+ task_name = "SleepingTask" ,
709+ call_stack = [MockFrameInfo ("sleeper.py" , 10 , "sleep_work" )]
710+ )
711+ ],
712+ awaited_by = []
713+ )
719714
720- # Sample with async_aware="all"
721- profiler .sample (mock_collector , duration_sec = 0.001 , async_aware = "all" )
715+ # CPU-running task (active)
716+ running_task = MockTaskInfo (
717+ task_id = 2 ,
718+ task_name = "RunningTask" ,
719+ coroutine_stack = [
720+ MockCoroInfo (
721+ task_name = "RunningTask" ,
722+ call_stack = [MockFrameInfo ("runner.py" , 20 , "cpu_work" )]
723+ )
724+ ],
725+ awaited_by = []
726+ )
722727
723- # Should have called get_all_awaited_by
724- mock_unwinder . get_all_awaited_by . assert_called ()
728+ # Both tasks returned by get_all_awaited_by
729+ awaited_info_list = [ MockAwaitedInfo ( thread_id = 100 , awaited_by = [ sleeping_task , running_task ])]
725730
726- def test_async_aware_running_uses_get_async_stack_trace (self ):
727- """Test that async_aware='running' calls get_async_stack_trace on unwinder."""
728- from unittest .mock import Mock , patch
729- from profiling .sampling .sample import SampleProfiler
731+ collector = PstatsCollector (sample_interval_usec = 1000 )
732+ collector .collect (awaited_info_list )
733+ collector .create_stats ()
730734
731- with patch ('profiling.sampling.sample._remote_debugging' ) as mock_rd :
732- mock_unwinder = Mock ()
733- mock_unwinder .get_async_stack_trace .return_value = []
734- mock_rd .RemoteUnwinder .return_value = mock_unwinder
735+ # Both tasks should be visible
736+ sleeping_key = ("sleeper.py" , 10 , "sleep_work" )
737+ running_key = ("runner.py" , 20 , "cpu_work" )
735738
736- profiler = SampleProfiler (
737- pid = 12345 ,
738- sample_interval_usec = 1000 ,
739- all_threads = False
740- )
741- profiler .unwinder = mock_unwinder
739+ self .assertIn (sleeping_key , collector .stats )
740+ self .assertIn (running_key , collector .stats )
742741
743- mock_collector = Mock ()
744- mock_collector .running = False
742+ # Task markers should also be present
743+ task_keys = [k for k in collector .stats if k [0 ] == "<task>" ]
744+ self .assertGreater (len (task_keys ), 0 , "Should have <task> markers in stats" )
745745
746- profiler .sample (mock_collector , duration_sec = 0.001 , async_aware = "running" )
746+ # Verify task names are in the markers
747+ task_names = [k [2 ] for k in task_keys ]
748+ self .assertTrue (
749+ any ("SleepingTask" in name for name in task_names ),
750+ "SleepingTask should be in task markers"
751+ )
752+ self .assertTrue (
753+ any ("RunningTask" in name for name in task_names ),
754+ "RunningTask should be in task markers"
755+ )
747756
748- mock_unwinder .get_async_stack_trace .assert_called ()
757+ def test_async_aware_running_sees_only_running_task (self ):
758+ """Test async_aware='running' only shows the currently running task stack."""
759+ # Only the running task's stack is returned by get_async_stack_trace
760+ running_task = MockTaskInfo (
761+ task_id = 2 ,
762+ task_name = "RunningTask" ,
763+ coroutine_stack = [
764+ MockCoroInfo (
765+ task_name = "RunningTask" ,
766+ call_stack = [MockFrameInfo ("runner.py" , 20 , "cpu_work" )]
767+ )
768+ ],
769+ awaited_by = []
770+ )
749771
750- def test_async_aware_none_uses_get_stack_trace (self ):
751- """Test that async_aware=None uses regular get_stack_trace."""
752- from unittest .mock import Mock , patch
753- from profiling .sampling .sample import SampleProfiler
772+ # get_async_stack_trace only returns the running task
773+ awaited_info_list = [MockAwaitedInfo (thread_id = 100 , awaited_by = [running_task ])]
754774
755- with patch ('profiling.sampling.sample._remote_debugging' ) as mock_rd :
756- mock_unwinder = Mock ()
757- mock_unwinder .get_stack_trace .return_value = []
758- mock_rd .RemoteUnwinder .return_value = mock_unwinder
775+ collector = PstatsCollector (sample_interval_usec = 1000 )
776+ collector .collect (awaited_info_list )
777+ collector .create_stats ()
759778
760- profiler = SampleProfiler (
761- pid = 12345 ,
762- sample_interval_usec = 1000 ,
763- all_threads = False
764- )
765- profiler .unwinder = mock_unwinder
779+ # Only running task should be visible
780+ running_key = ("runner.py" , 20 , "cpu_work" )
781+ self .assertIn (running_key , collector .stats )
766782
767- mock_collector = Mock ()
768- mock_collector .running = False
783+ # Verify we don't see the sleeping task (it wasn't in the input)
784+ sleeping_key = ("sleeper.py" , 10 , "sleep_work" )
785+ self .assertNotIn (sleeping_key , collector .stats )
769786
770- profiler .sample (mock_collector , duration_sec = 0.001 , async_aware = None )
787+ # Task marker for running task should be present
788+ task_keys = [k for k in collector .stats if k [0 ] == "<task>" ]
789+ self .assertGreater (len (task_keys ), 0 , "Should have <task> markers in stats" )
771790
772- mock_unwinder .get_stack_trace .assert_called ()
791+ task_names = [k [2 ] for k in task_keys ]
792+ self .assertTrue (
793+ any ("RunningTask" in name for name in task_names ),
794+ "RunningTask should be in task markers"
795+ )
773796
774797
775798if __name__ == "__main__" :
0 commit comments