11import logging
22
33from concurrent .futures import ThreadPoolExecutor , Future
4- from typing import List , Union
4+ from typing import Callable , List , Optional , Union
55
66from databricks .sql .cloudfetch .downloader import (
77 ResultSetDownloadHandler ,
@@ -22,6 +22,7 @@ def __init__(
2222 max_download_threads : int ,
2323 lz4_compressed : bool ,
2424 ssl_options : SSLOptions ,
25+ expiry_callback : Callable [[TSparkArrowResultLink ], None ],
2526 ):
2627 self ._pending_links : List [TSparkArrowResultLink ] = []
2728 for link in links :
@@ -40,6 +41,7 @@ def __init__(
4041
4142 self ._downloadable_result_settings = DownloadableResultSettings (lz4_compressed )
4243 self ._ssl_options = ssl_options
44+ self ._expiry_callback = expiry_callback
4345
4446 def get_next_downloaded_file (
4547 self , next_row_offset : int
@@ -62,7 +64,6 @@ def get_next_downloaded_file(
6264
6365 # No more files to download from this batch of links
6466 if len (self ._download_tasks ) == 0 :
65- self ._shutdown_manager ()
6667 return None
6768
6869 task = self ._download_tasks .pop (0 )
@@ -81,6 +82,34 @@ def get_next_downloaded_file(
8182
8283 return file
8384
def cancel_tasks_from_offset(self, start_row_offset: int):
    """
    Cancel and discard all downloads at or beyond a given row offset.

    This is used when links expire and we need to restart from a certain point:
    every scheduled task and every pending link whose ``startRowOffset`` is
    greater than or equal to ``start_row_offset`` is dropped, while work for
    earlier rows is left untouched.

    Args:
        start_row_offset: First row offset whose downloads must be discarded.

    Note:
        ``task.cancel()`` is best-effort. For a ``concurrent.futures.Future``
        it only prevents execution if the task has not started running yet;
        the task is removed from the queue either way.
    """

    def should_cancel(link) -> bool:
        # "From offset" means this offset and everything after it.
        return link.startRowOffset >= start_row_offset

    # Single pass over the queue: cancel matching tasks, keep the rest in order.
    kept_tasks = []
    cancelled_count = 0
    for task in self._download_tasks:
        if should_cancel(task.link):
            task.cancel()
            cancelled_count += 1
        else:
            kept_tasks.append(task)
    self._download_tasks = kept_tasks
    logger.info(
        f"ResultFileDownloadManager: cancelled {cancelled_count} tasks from offset {start_row_offset}"
    )

    # Capture the size *before* reassigning so the log reports the real
    # number of removed links instead of always 0.
    pending_before = len(self._pending_links)
    self._pending_links = [
        link for link in self._pending_links if not should_cancel(link)
    ]
    logger.info(
        f"ResultFileDownloadManager: removed {pending_before - len(self._pending_links)} links from pending links"
    )
112+
84113 def _schedule_downloads (self ):
85114 """
86115 While download queue has a capacity, peek pending links and submit them to thread pool.
@@ -97,8 +126,10 @@ def _schedule_downloads(self):
97126 settings = self ._downloadable_result_settings ,
98127 link = link ,
99128 ssl_options = self ._ssl_options ,
129+ expiry_callback = self ._expiry_callback ,
100130 )
101131 task = self ._thread_pool .submit (handler .run )
132+ task .link = link
102133 self ._download_tasks .append (task )
103134
104135 def add_link (self , link : TSparkArrowResultLink ):
0 commit comments