-
Notifications
You must be signed in to change notification settings - Fork 141
Expand file tree
/
Copy pathsession.py
More file actions
267 lines (226 loc) · 9.65 KB
/
session.py
File metadata and controls
267 lines (226 loc) · 9.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import logging
from typing import Dict, Tuple, List, Optional, Any, Type
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.auth.common import ClientContext
from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError
from databricks.sql import __version__
from databricks.sql import USER_AGENT_NAME
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.types import SessionId, BackendType
from databricks.sql.common.unified_http_client import UnifiedHttpClient
from databricks.sql.common.agent import detect as detect_agent
logger = logging.getLogger(__name__)
class Session:
    """Manage one session against a Databricks SQL endpoint or cluster.

    Owns the backend client (Thrift or SEA), the auth provider, TLS options,
    User-Agent construction, SPOG header extraction, and the session
    lifecycle (open/close). Created by `Connection`, which supplies the
    shared `UnifiedHttpClient`.
    """

    def __init__(
        self,
        server_hostname: str,
        http_path: str,
        http_client: UnifiedHttpClient,
        http_headers: Optional[List[Tuple[str, str]]] = None,
        session_configuration: Optional[Dict[str, Any]] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        _use_arrow_native_complex_types: Optional[bool] = True,
        **kwargs,
    ) -> None:
        """
        Create a session to a Databricks SQL endpoint or a Databricks cluster.
        This class handles all session-related behavior and communication with the backend.
        """
        self.is_open = False
        self.host = server_hostname
        self.port = kwargs.get("_port", 443)
        self.session_configuration = session_configuration
        self.catalog = catalog
        self.schema = schema
        self.http_path = http_path

        # Initialize autocommit state (JDBC default is True)
        self._autocommit = True

        # Prefer the public 'user_agent_entry' kwarg; fall back to the
        # deprecated underscore-prefixed spelling with a warning.
        user_agent_entry = kwargs.get("user_agent_entry")
        if user_agent_entry is None:
            user_agent_entry = kwargs.get("_user_agent_entry")
            if user_agent_entry is not None:
                logger.warning(
                    "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
                    "This parameter will be removed in the upcoming releases."
                )

        if user_agent_entry:
            self.useragent_header = "{}/{} ({})".format(
                USER_AGENT_NAME, __version__, user_agent_entry
            )
        else:
            self.useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

        # Append an "agent/<product>" token when a coding agent is detected.
        agent_product = detect_agent()
        if agent_product:
            self.useragent_header += " agent/{}".format(agent_product)

        base_headers = [("User-Agent", self.useragent_header)]
        all_headers = (http_headers or []) + base_headers

        # Extract ?o=<workspaceId> from http_path for SPOG routing.
        # On SPOG hosts, the httpPath contains ?o=<workspaceId> which routes Thrift
        # requests via the URL. For SEA, telemetry, and feature flags (which use
        # separate endpoints), we inject x-databricks-org-id as an HTTP header.
        self._spog_headers = self._extract_spog_headers(http_path, all_headers)
        if self._spog_headers:
            all_headers = all_headers + list(self._spog_headers.items())

        self.ssl_options = SSLOptions(
            # Double negation is generally a bad thing, but we have to keep backward compatibility
            tls_verify=not kwargs.get(
                "_tls_no_verify", False
            ),  # by default - verify cert and host
            tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
            tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
            tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
            tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
            tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
        )

        # Use the provided HTTP client (created in Connection)
        self.http_client = http_client

        # Create auth provider with HTTP client context
        self.auth_provider = get_python_sql_connector_auth_provider(
            server_hostname, http_client=self.http_client, **kwargs
        )

        self.backend = self._create_backend(
            server_hostname,
            http_path,
            all_headers,
            self.auth_provider,
            _use_arrow_native_complex_types,
            kwargs,
        )

        # Populated by open() from the SessionId returned by the backend.
        self.protocol_version = None

    def _create_backend(
        self,
        server_hostname: str,
        http_path: str,
        all_headers: List[Tuple[str, str]],
        auth_provider,
        _use_arrow_native_complex_types: Optional[bool],
        kwargs: dict,
    ) -> DatabricksClient:
        """Create and return the appropriate backend client.

        Selects the SEA backend when the 'use_sea' kwarg is truthy,
        otherwise the Thrift backend; both are constructed with the
        same common argument set plus any remaining kwargs.
        """
        self.use_sea = kwargs.get("use_sea", False)

        databricks_client_class: Type[DatabricksClient]
        if self.use_sea:
            logger.debug("Creating SEA backend client")
            databricks_client_class = SeaDatabricksClient
        else:
            logger.debug("Creating Thrift backend client")
            databricks_client_class = ThriftDatabricksClient

        common_args = {
            "server_hostname": server_hostname,
            "port": self.port,
            "http_path": http_path,
            "http_headers": all_headers,
            "auth_provider": auth_provider,
            "ssl_options": self.ssl_options,
            "http_client": self.http_client,
            "_use_arrow_native_complex_types": _use_arrow_native_complex_types,
            **kwargs,
        }
        return databricks_client_class(**common_args)

    @staticmethod
    def _extract_spog_headers(http_path, existing_headers):
        """Extract ?o=<workspaceId> from http_path and return as a header dict for SPOG routing.

        Returns an empty dict when http_path has no query string, no 'o'
        parameter, or when the caller already supplied an
        x-databricks-org-id header (caller wins; we never override it).
        """
        if not http_path or "?" not in http_path:
            return {}

        from urllib.parse import parse_qs

        query_string = http_path.split("?", 1)[1]
        params = parse_qs(query_string)
        org_id = params.get("o", [None])[0]
        if not org_id:
            logger.debug(
                "SPOG header extraction: http_path has query string but no ?o= param, "
                "skipping x-databricks-org-id injection"
            )
            return {}

        # Don't override if explicitly set
        if any(k == "x-databricks-org-id" for k, _ in existing_headers):
            logger.debug(
                "SPOG header extraction: x-databricks-org-id already set by caller, "
                "not overriding with ?o=%s from http_path",
                org_id,
            )
            return {}

        logger.debug(
            "SPOG header extraction: injecting x-databricks-org-id=%s "
            "(extracted from ?o= in http_path)",
            org_id,
        )
        return {"x-databricks-org-id": org_id}

    def get_spog_headers(self):
        """Returns SPOG routing headers (x-databricks-org-id) if ?o= was in http_path."""
        # Return a copy so callers cannot mutate the cached dict.
        return dict(self._spog_headers)

    def open(self):
        """Open the session on the backend and record its protocol version."""
        self._session_id = self.backend.open_session(
            session_configuration=self.session_configuration,
            catalog=self.catalog,
            schema=self.schema,
        )
        self.protocol_version = self.get_protocol_version(self._session_id)
        self.is_open = True
        logger.info("Successfully opened session %s", str(self.guid_hex))

    @staticmethod
    def get_protocol_version(session_id: SessionId):
        """Return the protocol version carried by a SessionId."""
        return session_id.protocol_version

    @staticmethod
    def server_parameterized_queries_enabled(protocolVersion):
        """True if the server protocol (>= SPARK_CLI_SERVICE_PROTOCOL_V8)
        supports server-side parameterized queries."""
        if (
            protocolVersion
            and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
        ):
            return True
        else:
            return False

    @property
    def session_id(self) -> SessionId:
        """Get the normalized session ID"""
        return self._session_id

    @property
    def guid(self) -> Any:
        """Get the raw session ID (backend-specific)"""
        return self._session_id.guid

    @property
    def guid_hex(self) -> str:
        """Get the session ID in hex format"""
        return self._session_id.hex_guid

    def get_autocommit(self) -> bool:
        """
        Get the cached autocommit state for this session.

        Returns:
            bool: True if autocommit is enabled, False otherwise
        """
        return self._autocommit

    def set_autocommit(self, value: bool) -> None:
        """
        Update the cached autocommit state for this session.

        Args:
            value: True to cache autocommit as enabled, False as disabled
        """
        self._autocommit = value

    def close(self) -> None:
        """Close the underlying session.

        Best-effort: backend errors indicating the session is already
        closed are logged and swallowed; the local state is always
        marked closed afterwards.
        """
        logger.info("Closing session %s", self.guid_hex)
        if not self.is_open:
            logger.debug("Session appears to have been closed already")
            return

        try:
            self.backend.close_session(self._session_id)
        except RequestError as e:
            # Guard the args lookup: a RequestError with fewer than two
            # args would otherwise raise IndexError inside this handler
            # and mask the original failure.
            if len(e.args) > 1 and isinstance(e.args[1], SessionAlreadyClosedError):
                logger.info("Session was closed by a prior request")
        except DatabaseError as e:
            if "Invalid SessionHandle" in str(e):
                logger.warning(
                    "Attempted to close session that was already closed: %s", e
                )
            else:
                logger.warning(
                    "Attempt to close session raised an exception at the server: %s", e
                )
        except Exception as e:
            logger.error("Attempt to close session raised a local exception: %s", e)

        self.is_open = False