diff --git a/.gitignore b/.gitignore
index 2ae38dbc6..9d02d5775 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,7 @@
+# Profiling credentials
+examples/credentials.env
+
 # Created by https://www.toptal.com/developers/gitignore/api/python,macos
 # Edit at https://www.toptal.com/developers/gitignore?templates=python,macos
diff --git a/examples/PROFILING_README.md b/examples/PROFILING_README.md
new file mode 100644
index 000000000..15226c9e9
--- /dev/null
+++ b/examples/PROFILING_README.md
@@ -0,0 +1,167 @@
+# SET TAGS Profiling Scripts
+
+## Context
+
+These scripts measure the performance of two approaches for managing tags on Databricks tables and columns:
+
+**Approach A (direct ALTER)**: Call `ALTER TABLE SET TAGS` or `ALTER TABLE ALTER COLUMN SET TAGS` directly, overwriting any existing tags without reading them first. SET TAGS is an idempotent upsert: setting a key that already exists simply overwrites its value.
+
+**Approach B (read then write)**: First query `system.information_schema.column_tags` or `system.information_schema.table_tags` to read existing tags, compute a diff, then issue ALTERs only for changed keys.
+
+The goal is to determine whether the information_schema read step pays for itself, or whether direct ALTER is faster even though it may redundantly set unchanged tags.
+
+## Prerequisites
+
+- Python 3.x with `databricks-sql-connector` installed (`pip install -e .` from repo root)
+- A Databricks SQL warehouse
+- 128 tables (`table1` through `table128`, matching `NUM_TABLES = 128` in the profiling scripts) with 128 STRING columns each. Create them by running `profile_column_tags.py` without `--skip-setup`.
+- Credentials in `examples/credentials.env` (gitignored). Copy and edit:
+
+```bash
+cp examples/credentials.env.example examples/credentials.env
+# Edit credentials.env with your workspace details
+```
+
+The file format is:
+```
+SERVER_HOSTNAME=your-workspace.cloud.databricks.com
+HTTP_PATH=/sql/1.0/warehouses/your_warehouse_id
+ACCESS_TOKEN=your_token
+CATALOG=your_catalog
+SCHEMA=your_schema
+```
+
+All scripts read from this file via `load_credentials.py`. To switch workspaces, just edit `credentials.env`.
+
+## Scripts
+
+### profile_column_tags.py — Direct ALTER column tags
+
+Sets tags on columns directly via ALTER statements. No information_schema reads.
+
+```bash
+# Create tables + validate
+python examples/profile_column_tags.py --columns 1 --tags 1 --threads 1 --iterations 1 --validate
+
+# Full experiment: 100 columns, 9 tags each, 8 threads, 3 iterations
+python examples/profile_column_tags.py --columns 100 --tags 9 --threads 8 --iterations 3 --skip-setup
+```
+
+Arguments:
+- `--columns`: Number of columns to tag per table
+- `--tags`: Number of tags per ALTER command
+- `--threads`: Concurrent connections
+- `--iterations`: Times to repeat the table sweep
+- `--tables-per-iteration`: Tables to process per iteration (default = `--threads`, i.e. one table per thread)
+- `--skip-setup`: Skip table creation
+- `--validate`: Force 1 iteration for quick validation
+
+Output: `examples/results/column_tags/`
+
+### profile_table_tags.py — Direct ALTER table tags
+
+Sets tags on tables directly via ALTER statements. One ALTER per table.
+
+```bash
+python examples/profile_table_tags.py --tags 1 --threads 8 --iterations 3
+```
+
+Arguments:
+- `--tags`: Number of tags per ALTER command
+- `--threads`: Concurrent connections
+- `--iterations`: Times to repeat the table sweep
+- `--validate`: Force 1 iteration
+
+Output: `examples/results/table_tags/`
+
+### profile_read_then_write_tags.py — information_schema column_tags SELECT
+
+Queries `system.information_schema.column_tags` for each table. No ALTER — measures the read cost only.
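+
+The timed per-table read has this shape (illustrative only; the exact SELECT text and column list live in the script, and the filters match those used by `cleanup_column_tags.py`):
+
+```python
+# Assumed shape of the per-table read, not copied from the script.
+SELECT_SQL = """SELECT column_name, tag_name, tag_value
+FROM system.information_schema.column_tags
+WHERE catalog_name = '{catalog}'
+  AND schema_name = '{schema}'
+  AND table_name = '{table}'"""
+```
+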
+
+```bash
+python examples/profile_read_then_write_tags.py --threads 1 --iterations 3
+```
+
+Arguments:
+- `--threads`: Concurrent connections
+- `--iterations`: Times to repeat the table sweep
+- `--validate`: Force 1 iteration
+
+Output: `examples/results/read_then_write/`
+
+### profile_read_then_write_table_tags.py — information_schema table_tags SELECT
+
+Queries `system.information_schema.table_tags` for each table. No ALTER — measures the read cost only.
+
+```bash
+python examples/profile_read_then_write_table_tags.py --threads 1 --iterations 3
+```
+
+Arguments:
+- `--threads`: Concurrent connections
+- `--iterations`: Times to repeat the table sweep
+- `--tables-per-iteration`: Tables to query per iteration (default = `--threads`)
+- `--validate`: Force 1 iteration
+
+Output: `examples/results/read_then_write_table_tags/`
+
+### cleanup_column_tags.py — Remove all tags
+
+Removes all column tags and table tags from tables `table1` through `table64` using 32 threads. Run this to reset state between experiments.
+
+```bash
+python examples/cleanup_column_tags.py
+```
+
+### plot_comparison.py — Generate charts
+
+Reads all report files and generates comparison charts as PNGs.
+
+```bash
+pip install matplotlib
+python examples/plot_comparison.py
+```
+
+Output (two PNGs each for column tags and table tags):
+- `examples/results/comparison_column_tags_tables.png` / `comparison_table_tags_tables.png`: table-level comparison (wall-clock time and tables/sec vs thread count)
+- `examples/results/comparison_column_tags_ops.png` / `comparison_table_tags_ops.png`: per-operation detail (throughput, P50, P99, and max latency vs thread count)
+
+Every chart plots the info_schema SELECT series against the direct ALTER series by thread count.
+
+## Running the definitive experiment
+
+```bash
+# Step 1: Create tables (once)
+python examples/profile_column_tags.py --columns 1 --tags 1 --threads 1 --iterations 1 --validate
+
+# Step 2: Run info_schema reads across thread counts
+# Stop early if latency is already unacceptable
+for n in 1 2 4 8 16 32 64; do
+  python examples/profile_read_then_write_tags.py --threads $n --iterations 3
+done
+
+# Step 3: Run direct ALTERs across thread counts
+for n in 1 2 4 8 16 32 64; do
+  python examples/profile_column_tags.py --columns 100 --tags 9 --threads $n --iterations 3 --skip-setup
+done
+
+# Step 4: Generate charts
+python examples/plot_comparison.py
+```
+
+## Connector instrumentation
+
+The scripts capture retry behavior via `[PROFILE]` log lines added to the connector:
+- `src/databricks/sql/backend/thrift_backend.py` — logs per-attempt timing, success, statement IDs, and retry sleeps in `make_request()`
+- `src/databricks/sql/auth/retry.py` — logs urllib3-level retry decisions (`should_retry`) and sleep durations with HTTP status codes, Thrift method names, and SQL text
+
+These are written to `*_retries.log` files alongside each report. Use `grep -F "[PROFILE]"` to filter (`-F` keeps grep from treating the brackets as a character class).
+
+## Output structure
+
+Each script run produces three files:
+- `*_report.md` — Markdown report with latency percentiles, throughput, error analysis, retry analysis
+- `*_data.jsonl` — Raw per-operation data (one JSON line per ALTER or SELECT)
+- `*_retries.log` — Full connector debug logs with `[PROFILE]` instrumentation
+
+## Key finding
+
+On Azure workspaces, `system.information_schema.column_tags` queries can take 60-110 seconds under concurrency due to server-side queuing (visible as repeated `GetOperationStatus` polling in the logs). Direct ALTER SET TAGS consistently completes in ~500ms regardless of concurrency. The information_schema read alone is slower than performing all the writes it was meant to optimize.
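+
+For concreteness, here is a minimal sketch of the two approaches under comparison. It assumes a cursor from `databricks-sql-connector` opened as in the scripts above; the helper names are illustrative and do not appear in this repo:
+
+```python
+def set_column_tags_direct(cursor, table_fqn, column, tags):
+    """Approach A: one ALTER, no read; existing values for the same keys are overwritten."""
+    kv = ", ".join(f"'{k}' = '{v}'" for k, v in tags.items())
+    cursor.execute(f"ALTER TABLE {table_fqn} ALTER COLUMN {column} SET TAGS ({kv})")
+
+
+def set_column_tags_diffed(cursor, catalog, schema, table, table_fqn, column, tags):
+    """Approach B: read current tags from information_schema, then ALTER only the diff."""
+    cursor.execute(
+        "SELECT tag_name, tag_value FROM system.information_schema.column_tags "
+        f"WHERE catalog_name = '{catalog}' AND schema_name = '{schema}' "
+        f"AND table_name = '{table}' AND column_name = '{column}'"
+    )
+    current = {row[0]: row[1] for row in cursor.fetchall()}
+    changed = {k: v for k, v in tags.items() if current.get(k) != v}
+    if changed:  # skip the ALTER entirely when nothing changed
+        kv = ", ".join(f"'{k}' = '{v}'" for k, v in changed.items())
+        cursor.execute(f"ALTER TABLE {table_fqn} ALTER COLUMN {column} SET TAGS ({kv})")
+```
+
+Given the finding above, the SELECT in Approach B routinely costs more than simply issuing every ALTER it could have saved.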
\ No newline at end of file diff --git a/examples/cleanup_column_tags.py b/examples/cleanup_column_tags.py new file mode 100644 index 000000000..5fadd97df --- /dev/null +++ b/examples/cleanup_column_tags.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +"""Remove all column tags and table tags from all 64 tables using 32 threads.""" + +import sys +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed + +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql +from load_credentials import load_credentials + +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] + +NUM_TABLES = 64 +NUM_THREADS = 32 + + +def cleanup_table(table_name): + table_fqn = f"`{CATALOG}`.`{SCHEMA}`.{table_name}" + total_removed = 0 + + with sql.connect( + server_hostname=SERVER_HOSTNAME, + http_path=HTTP_PATH, + access_token=ACCESS_TOKEN, + _tls_no_verify=True, + enable_telemetry=False, + ) as conn: + with conn.cursor() as cursor: + # --- Clean up column tags --- + cursor.execute( + f"SELECT column_name, tag_name FROM system.information_schema.column_tags " + f"WHERE catalog_name = '{CATALOG}' AND schema_name = '{SCHEMA}' AND table_name = '{table_name}'" + ) + col_rows = cursor.fetchall() + + if col_rows: + col_tags = defaultdict(list) + for row in col_rows: + col_tags[row[0]].append(row[1]) + + for col, tags in col_tags.items(): + tag_list = ", ".join(f"'{tag}'" for tag in tags) + cursor.execute(f"ALTER TABLE {table_fqn} ALTER COLUMN {col} UNSET TAGS ({tag_list})") + + total_removed += len(col_rows) + + # --- Clean up table tags --- + cursor.execute( + f"SELECT tag_name FROM system.information_schema.table_tags " + f"WHERE catalog_name = '{CATALOG}' AND schema_name = '{SCHEMA}' AND table_name = '{table_name}'" + ) + tbl_rows = cursor.fetchall() + + if tbl_rows: + tag_list = ", ".join(f"'{row[0]}'" for row in tbl_rows) + cursor.execute(f"ALTER TABLE {table_fqn} UNSET TAGS ({tag_list})") + total_removed += len(tbl_rows) + + print(f"{table_name}: removed {len(col_rows)} column tags, {len(tbl_rows)} table tags") + return total_removed + + +total_removed = 0 +with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + futures = { + executor.submit(cleanup_table, f"table{t}"): t + for t in range(1, NUM_TABLES + 1) + } + for f in as_completed(futures): + total_removed += f.result() + +print(f"\nDone. Removed {total_removed} total tags (column + table).") diff --git a/examples/credentials.env.example b/examples/credentials.env.example new file mode 100644 index 000000000..406417236 --- /dev/null +++ b/examples/credentials.env.example @@ -0,0 +1,9 @@ +# SET TAGS Profiling — Workspace Credentials +# Copy this file to credentials.env and fill in your values. +# credentials.env is gitignored and will not be committed. 
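+# Lines are parsed as simple KEY=VALUE pairs by load_credentials.py;
+# blank lines and '#' comments are ignored.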
+
+SERVER_HOSTNAME=your-workspace.cloud.databricks.com
+HTTP_PATH=/sql/1.0/warehouses/your_warehouse_id
+ACCESS_TOKEN=your_access_token
+CATALOG=your_catalog
+SCHEMA=your_schema
diff --git a/examples/load_credentials.py b/examples/load_credentials.py
new file mode 100644
index 000000000..ef919cd8a
--- /dev/null
+++ b/examples/load_credentials.py
@@ -0,0 +1,31 @@
+"""Load credentials from examples/credentials.env"""
+
+import os
+
+
+def load_credentials(env_path=None):
+    """Read credentials.env and return a dict of key=value pairs."""
+    if env_path is None:
+        env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "credentials.env")
+
+    if not os.path.exists(env_path):
+        raise FileNotFoundError(
+            f"Credentials file not found: {env_path}\n"
+            f"Copy examples/credentials.env.example to examples/credentials.env and fill in your values."
+        )
+
+    creds = {}
+    with open(env_path) as f:
+        for line in f:
+            line = line.strip()
+            if not line or line.startswith("#"):
+                continue
+            key, _, value = line.partition("=")
+            creds[key.strip()] = value.strip()
+
+    required = ["SERVER_HOSTNAME", "HTTP_PATH", "ACCESS_TOKEN", "CATALOG", "SCHEMA"]
+    missing = [k for k in required if k not in creds]
+    if missing:
+        raise ValueError(f"Missing required credentials: {', '.join(missing)}")
+
+    return creds
diff --git a/examples/plot_comparison.py b/examples/plot_comparison.py
new file mode 100644
index 000000000..a57853da2
--- /dev/null
+++ b/examples/plot_comparison.py
@@ -0,0 +1,269 @@
+#!/usr/bin/env python3
+"""
+Plot all profiling results: info_schema SELECTs vs direct ALTERs.
+
+Auto-discovers all report MD files across all result directories.
+Generates TWO PNGs each for column tags and table tags:
+a table-level comparison (wall-clock, tables/sec) and a
+per-operation detail (throughput, P50, P99, max latency).
+
+Usage:
+    python examples/plot_comparison.py
+"""
+
+import re
+import os
+import sys
+from collections import defaultdict
+
+sys.stdout.reconfigure(line_buffering=True)
+
+try:
+    import matplotlib.pyplot as plt
+except ImportError:
+    print("Install matplotlib: pip install matplotlib")
+    sys.exit(1)
+
+RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
+
+
+def parse_report(filepath):
+    """Extract key metrics from a report MD file."""
+    metrics = {}
+    with open(filepath) as f:
+        content = f.read()
+
+    m = re.search(r"\*\*Total wall-clock\*\*:\s*([\d.]+)s", content)
+    if m:
+        metrics["wall_clock_s"] = float(m.group(1))
+
+    m = re.search(r"\*\*(ALTERs/sec|SELECTs/sec|Operations/sec)\*\*:\s*([\d.]+)", content)
+    if m:
+        metrics["throughput_ops"] = float(m.group(2))
+
+    for pct in ["p50", "p90", "p95", "p99"]:
+        m = re.search(rf"\|\s*{pct}\s*\|\s*([\d.]+)\s*\|", content)
+        if m:
+            metrics[pct] = float(m.group(1))
+
+    m = re.search(r"\|\s*max\s*\|\s*([\d.]+)\s*\|", content)
+    if m:
+        metrics["max"] = float(m.group(1))
+
+    m = re.search(r"\|\s*count\s*\|\s*([\d.]+)\s*\|", content)
+    if m:
+        metrics["count"] = int(float(m.group(1)))
+
+    m = re.search(r"\*\*Threads\*\*:\s*(\d+)", content)
+    if m:
+        metrics["threads"] = int(m.group(1))
+
+    m = re.search(r"\*\*Iterations\*\*:\s*(\d+)", content)
+    if m:
+        metrics["iterations"] = int(m.group(1))
+
+    m = re.search(r"\*\*Columns tagged per table\*\*:\s*(\d+)", content)
+    if m:
+        metrics["columns"] = int(m.group(1))
+
+    m = re.search(r"\*\*Tables per iteration\*\*:\s*(\d+)", content)
+    if m:
+        metrics["tables_per_iteration"] = int(m.group(1))
+
+    # Fallback for older reports: derive tables/iteration from the Total SELECTs count
+    if "tables_per_iteration" not in metrics:
+        m = re.search(r"\*\*Total
SELECTs\*\*:\s*(\d+)", content) + iters = metrics.get("iterations", 1) + if m and iters: + metrics["tables_per_iteration"] = int(float(m.group(1))) // iters + + m = re.search(r"\*\*Tags per ALTER\*\*:\s*(\d+)", content) + if m: + metrics["tags"] = int(m.group(1)) + + return metrics + + +def classify_report(dirpath, filename): + """Classify a report: (category, type) where category is 'column' or 'table'.""" + dirpath_lower = dirpath.lower() + filename_lower = filename.lower() + + if "read_then_write_table_tags" in dirpath_lower or filename_lower.startswith("rwtt_"): + return "table", "info_schema" + elif "read_then_write" in dirpath_lower or filename_lower.startswith("rw_"): + return "column", "info_schema" + elif "table_tags" in dirpath_lower or filename_lower.startswith("tt_"): + return "table", "alter" + elif "column_tags" in dirpath_lower or filename_lower.startswith("c"): + return "column", "alter" + else: + return "unknown", "unknown" + + +def discover_reports(): + """Walk all result directories and collect report data, split by category.""" + # {category: {series_label: {threads: metrics}}} + categories = defaultdict(lambda: defaultdict(dict)) + + for dirpath, _, filenames in os.walk(RESULTS_DIR): + for fname in sorted(filenames): + if not fname.endswith("_report.md"): + continue + + filepath = os.path.join(dirpath, fname) + metrics = parse_report(filepath) + + if "threads" not in metrics: + continue + + category, report_type = classify_report(dirpath, fname) + if category == "unknown": + continue + + threads = metrics["threads"] + + tbl = metrics.get("tables_per_iteration", "?") + + if report_type == "alter" and category == "column": + cols = metrics.get("columns", "?") + tags = metrics.get("tags", "?") + label = f"ALTER column tags (columns={cols}, tags_per_column={tags}, tables={tbl})" + elif report_type == "alter" and category == "table": + tags = metrics.get("tags", "?") + label = f"ALTER table tags (tags={tags}, tables={tbl})" + elif report_type == "info_schema" and category == "column": + label = f"info_schema column_tags SELECT (tables={tbl})" + elif report_type == "info_schema" and category == "table": + label = f"info_schema table_tags SELECT (tables={tbl})" + else: + continue + + # Keep the one with more iterations + existing = categories[category][label].get(threads) + if existing and metrics.get("iterations", 0) <= existing.get("iterations", 0): + continue + + # Compute tables/sec from wall-clock and tables_per_iteration + tpi = metrics.get("tables_per_iteration") + wc = metrics.get("wall_clock_s") + if tpi and wc and wc > 0: + metrics["tables_per_sec"] = round(tpi / wc, 2) + + categories[category][label][threads] = metrics + print(f" [{category}] {label} threads={threads}: " + f"wall={metrics.get('wall_clock_s', '?')}s, " + f"p50={metrics.get('p50', '?')}ms, " + f"tables/s={metrics.get('tables_per_sec', '?')} " + f"[{fname}]") + + return categories + + +def build_style_map(series): + """Assign colors and styles to series labels.""" + colors_info = ["#d62728", "#ff7f0e"] + colors_alter = ["#1f77b4", "#2ca02c", "#9467bd", "#17becf", "#8c564b"] + info_idx = 0 + alter_idx = 0 + style_map = {} + + for label in sorted(series.keys()): + if "info_schema" in label: + style_map[label] = {"color": colors_info[info_idx % len(colors_info)], "marker": "o", "linestyle": "--"} + info_idx += 1 + else: + style_map[label] = {"color": colors_alter[alter_idx % len(colors_alter)], "marker": "s", "linestyle": "-"} + alter_idx += 1 + + return style_map + + +def plot_charts(series, style_map, 
chart_configs, suptitle, output_path): + """Generate a chart PNG with len(chart_configs) subplots.""" + n = len(chart_configs) + cols = 2 + rows = (n + 1) // 2 + fig, axes = plt.subplots(rows, cols, figsize=(16, 6 * rows)) + if rows == 1: + axes = [axes] + + for idx, (metric_key, ylabel, title) in enumerate(chart_configs): + ax = axes[idx // cols][idx % cols] + for label, thread_data in sorted(series.items()): + threads = sorted(thread_data.keys()) + values = [thread_data[t].get(metric_key) for t in threads] + if any(v is not None for v in values): + s = style_map[label] + ax.plot(threads, values, marker=s["marker"], linestyle=s["linestyle"], + color=s["color"], linewidth=2, label=label, markersize=8) + ax.set_xlabel("Thread Count") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # Hide unused subplot if odd number of charts + if n % 2 == 1: + axes[rows - 1][1].set_visible(False) + + plt.suptitle(suptitle, fontsize=14, fontweight="bold") + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Chart saved to: {output_path}") + + +def plot_category(category_name, series, output_dir): + """Generate two PNGs per category: table-level comparison + individual operation detail.""" + if not series: + print(f" No data for {category_name}, skipping.") + return + + style_map = build_style_map(series) + title_label = "Column Tags" if category_name == "column" else "Table Tags" + + # Chart 1: Table-level comparison (apples-to-apples across approaches) + table_charts = [ + ("wall_clock_s", "Wall-Clock Time (seconds)", "Wall-Clock Time vs Thread Count (Lower is better)"), + ("tables_per_sec", "Tables / second", "Tables Processed per Second vs Thread Count (Higher is better)"), + ] + plot_charts( + series, style_map, table_charts, + f"{title_label}: Table-Level Comparison — info_schema SELECT vs Direct ALTER", + os.path.join(output_dir, f"comparison_{category_name}_tags_tables.png"), + ) + + # Chart 2: Individual operation detail (per-op latency) + op_charts = [ + ("throughput_ops", "Individual Operations / second", "Individual Op Throughput vs Thread Count (Higher is better)"), + ("p50", "P50 Latency per Op (ms)", "P50 Latency vs Thread Count (Lower is better)"), + ("p99", "P99 Latency per Op (ms)", "P99 Latency vs Thread Count (Lower is better)"), + ("max", "Max Latency per Op (ms)", "Max Latency vs Thread Count (Lower is better)"), + ] + plot_charts( + series, style_map, op_charts, + f"{title_label}: Individual Operation Detail", + os.path.join(output_dir, f"comparison_{category_name}_tags_ops.png"), + ) + + +if __name__ == "__main__": + print("Discovering results...\n") + categories = discover_reports() + + total_series = sum(len(v) for v in categories.values()) + total_points = sum(len(td) for cat in categories.values() for td in cat.values()) + print(f"\nFound {total_series} series across {total_points} data points.\n") + + if "column" in categories: + print("Generating column tags charts...") + plot_category("column", categories["column"], RESULTS_DIR) + + if "table" in categories: + print("Generating table tags charts...") + plot_category("table", categories["table"], RESULTS_DIR) + + if not categories: + print("No results found. 
Run experiments first.") + else: + print("\nDrag the PNGs into Google Docs.") diff --git a/examples/profile_column_tags.py b/examples/profile_column_tags.py new file mode 100644 index 000000000..49a19ebc4 --- /dev/null +++ b/examples/profile_column_tags.py @@ -0,0 +1,687 @@ +#!/usr/bin/env python3 +""" +Profile SET COLUMN TAGS performance on Databricks. + +Usage: + # Quick validation (1 col x 1 tag x 1 thread x 1 iteration = 20 ALTERs) + python examples/profile_column_tags.py --columns 1 --tags 1 --threads 1 --iterations 1 --validate + + # Single experiment + python examples/profile_column_tags.py --columns 2 --tags 4 --threads 8 --iterations 10 + + # Full sweep + for c in 1 2 4; do + for t in 1 2 4; do + for n in 1 2 4 8 16; do + python examples/profile_column_tags.py --columns $c --tags $t --threads $n --iterations 10 + done + done + done +""" + +import argparse +import json +import logging +import os +import random +import re +import statistics +import string +import sys +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from queue import Empty, Queue + +# Force unbuffered stdout so output is visible when piped through grep +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql + +# ============================================================ +# CONFIGURATION — loaded from examples/credentials.env +# ============================================================ +from load_credentials import load_credentials +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] +# ============================================================ + +NUM_TABLES = 128 # total tables available (table1..table128) +MAX_COLUMNS = 128 # tables always created with this many columns +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "column_tags") + + +# --------------------------------------------------------------------------- +# Logging setup +# --------------------------------------------------------------------------- + +class ProfileLogHandler(logging.Handler): + """Captures [PROFILE] log lines for retry analysis.""" + + def __init__(self): + super().__init__() + self.records: list = [] + + def emit(self, record): + msg = record.getMessage() + if "[PROFILE]" in msg: + self.records.append( + { + "timestamp": record.created, + "thread": record.threadName, + "message": msg, + } + ) + + +def setup_logging(log_path: str) -> ProfileLogHandler: + """Configure logging: file handler for all connector logs, profile handler for [PROFILE] lines.""" + profile_handler = ProfileLogHandler() + profile_handler.setLevel(logging.INFO) + + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s") + ) + + for logger_name in [ + "databricks.sql.backend.thrift_backend", + "databricks.sql.auth.retry", + "databricks.sql.client", + ]: + lgr = logging.getLogger(logger_name) + lgr.setLevel(logging.DEBUG) + lgr.addHandler(profile_handler) + lgr.addHandler(file_handler) + + return profile_handler + + +# --------------------------------------------------------------------------- +# Helpers +# 
--------------------------------------------------------------------------- + +def conn_params() -> dict: + return { + "server_hostname": SERVER_HOSTNAME, + "http_path": HTTP_PATH, + "access_token": ACCESS_TOKEN, + "_tls_no_verify": True, + "enable_telemetry": False, + } + + +def random_tag_value(length: int = 5) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def build_alter_sql(table_fqn: str, column_name: str, num_tags: int) -> str: + tags = ", ".join(f"'key{i}' = '{random_tag_value()}'" for i in range(1, num_tags + 1)) + return f"ALTER TABLE {table_fqn} ALTER COLUMN {column_name} SET TAGS ({tags})" + + +def percentile(data: list, p: float) -> float: + """Return the p-th percentile (0-100) of data.""" + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (p / 100.0) + f = int(k) + c = f + 1 + if c >= len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def latency_stats(latencies: list) -> dict: + """Compute full latency statistics for a list of ms values.""" + if not latencies: + return {k: 0.0 for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]} + return { + "count": len(latencies), + "min": min(latencies), + "max": max(latencies), + "mean": statistics.mean(latencies), + "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0, + "p50": percentile(latencies, 50), + "p90": percentile(latencies, 90), + "p95": percentile(latencies, 95), + "p99": percentile(latencies, 99), + } + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + +def setup_tables(): + """Create NUM_TABLES tables with MAX_COLUMNS STRING columns each.""" + print(f"Setting up {NUM_TABLES} tables with {MAX_COLUMNS} columns each...") + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + cursor.execute(f"USE CATALOG `{CATALOG}`") + cursor.execute(f"USE SCHEMA `{SCHEMA}`") + for t in range(1, NUM_TABLES + 1): + cols = ", ".join(f"column{c} STRING" for c in range(1, MAX_COLUMNS + 1)) + ddl = f"CREATE TABLE IF NOT EXISTS table{t} ({cols})" + print(f" Creating table{t}...", end=" ", flush=True) + cursor.execute(ddl) + print("done") + print("Setup complete.\n") + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + +def worker( + thread_id: int, + table_queue: Queue, + num_columns: int, + num_tags: int, + alter_results: list, + table_results: list, + results_lock: threading.Lock, +): + """Worker thread: pulls tables from queue, ALTERs all columns, records metrics.""" + local_alter_results = [] + local_table_results = [] + table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`" + + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + # Warmup + cursor.execute("SELECT 1") + + while True: + try: + table_name = table_queue.get_nowait() + except Empty: + break + + table_fqn = f"{table_fqn_prefix}.{table_name}" + table_start = time.perf_counter() + table_errors = 0 + + for c in range(1, num_columns + 1): + column_name = f"column{c}" + alter_sql = build_alter_sql(table_fqn, column_name, num_tags) + + cmd_start = time.perf_counter() + success = True + error_type = None + error_message = None + error_context = None + + try: + cursor.execute(alter_sql) + except Exception as e: + 
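+                        # Capture failure details (exception type, truncated message, and the
+                        # exception's context attribute when present) so the report can group
+                        # errors; latency is still recorded for failed ALTERs below.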
success = False + error_type = type(e).__name__ + error_message = str(e)[:500] + error_context = getattr(e, "context", None) + table_errors += 1 + + cmd_end = time.perf_counter() + latency_ms = (cmd_end - cmd_start) * 1000 + + local_alter_results.append( + { + "table": table_name, + "column": column_name, + "thread_id": thread_id, + "latency_ms": round(latency_ms, 2), + "success": success, + "error_type": error_type, + "error_message": error_message, + "error_context": str(error_context) if error_context else None, + "timestamp": cmd_start, + } + ) + + table_end = time.perf_counter() + table_latency_ms = (table_end - table_start) * 1000 + + local_table_results.append( + { + "table": table_name, + "thread_id": thread_id, + "latency_ms": round(table_latency_ms, 2), + "num_alters": num_columns, + "num_errors": table_errors, + "alters_per_sec": round(num_columns / (table_latency_ms / 1000), 2) + if table_latency_ms > 0 + else 0, + } + ) + + with results_lock: + alter_results.extend(local_alter_results) + table_results.extend(local_table_results) + + +# --------------------------------------------------------------------------- +# Run one iteration +# --------------------------------------------------------------------------- + +def run_iteration( + iteration: int, + num_columns: int, + num_tags: int, + num_threads: int, + tables_per_iteration: int, +) -> tuple: + """Run a single iteration: tables_per_iteration tables distributed across num_threads threads.""" + table_queue = Queue() + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") + + alter_results: list = [] + table_results: list = [] + results_lock = threading.Lock() + + iter_start = time.perf_counter() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + for tid in range(num_threads): + f = executor.submit( + worker, + tid, + table_queue, + num_columns, + num_tags, + alter_results, + table_results, + results_lock, + ) + futures.append(f) + + for f in as_completed(futures): + f.result() # raise any thread exceptions + + iter_end = time.perf_counter() + duration_s = iter_end - iter_start + + # Tag each result with iteration number + for r in alter_results: + r["iteration"] = iteration + for r in table_results: + r["iteration"] = iteration + + return alter_results, table_results, duration_s + + +# --------------------------------------------------------------------------- +# Report generation +# --------------------------------------------------------------------------- + +def generate_report( + args, + all_alter_results: list, + all_table_results: list, + iteration_durations: list, + profile_handler: ProfileLogHandler, + report_path: str, +): + """Generate the markdown report.""" + lines = [] + + def w(text=""): + lines.append(text) + + total_alters = len(all_alter_results) + total_duration = sum(iteration_durations) + successful = [r for r in all_alter_results if r["success"]] + failed = [r for r in all_alter_results if not r["success"]] + success_latencies = [r["latency_ms"] for r in successful] + all_latencies = [r["latency_ms"] for r in all_alter_results] + + # --- Header --- + w(f"# Profile: C={args.columns}, T={args.tags}, N={args.threads}, I={args.iterations}") + w() + w("## Configuration") + w(f"- **Server**: `{SERVER_HOSTNAME}`") + w(f"- **HTTP Path**: `{HTTP_PATH}`") + w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`") + w(f"- **Tables per iteration**: {args.tables_per_iteration}") + 
w(f"- **Columns tagged per table**: {args.columns}") + w(f"- **Tags per ALTER**: {args.tags}") + w(f"- **Threads**: {args.threads}") + w(f"- **Iterations**: {args.iterations}") + w(f"- **Total ALTERs**: {total_alters}") + w(f"- **Date**: {datetime.now().isoformat()}") + w() + + # --- Overall ALTER Latency --- + w("## Per-ALTER Latency — All Iterations (ms)") + w() + stats = latency_stats(success_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k, v in stats.items(): + w(f"| {k} | {v:.2f} |") + w() + + # --- Throughput --- + w("## Throughput") + w() + w(f"- **Total ALTERs**: {total_alters}") + w(f"- **Successful**: {len(successful)}") + w(f"- **Failed**: {len(failed)}") + w(f"- **Total wall-clock**: {total_duration:.2f}s") + if total_duration > 0: + w(f"- **ALTERs/sec**: {total_alters / total_duration:.2f}") + w() + + # --- Cold Start vs Steady State --- + if args.iterations > 1: + w("## Cold Start vs Steady State") + w() + iter1 = [r["latency_ms"] for r in successful if r["iteration"] == 1] + iter_rest = [r["latency_ms"] for r in successful if r["iteration"] > 1] + w("| Phase | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) |") + w("|-------|--------|-----------|----------|----------|") + s1 = latency_stats(iter1) + sr = latency_stats(iter_rest) + w(f"| Iteration 1 | {s1['count']:.0f} | {s1['mean']:.2f} | {s1['p50']:.2f} | {s1['p99']:.2f} |") + w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['mean']:.2f} | {sr['p50']:.2f} | {sr['p99']:.2f} |") + w() + + # --- Per-Iteration Summary --- + w("## Per-Iteration Summary") + w() + w("| Iteration | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) | Errors | Duration (s) | ALTERs/sec |") + w("|-----------|--------|-----------|----------|----------|--------|--------------|------------|") + for i in range(1, args.iterations + 1): + iter_lats = [r["latency_ms"] for r in successful if r["iteration"] == i] + iter_errs = len([r for r in failed if r["iteration"] == i]) + s = latency_stats(iter_lats) + dur = iteration_durations[i - 1] + alters_in_iter = len([r for r in all_alter_results if r["iteration"] == i]) + rps = alters_in_iter / dur if dur > 0 else 0 + w( + f"| {i} | {s['count']:.0f} | {s['mean']:.2f} | {s['p50']:.2f} | {s['p99']:.2f} " + f"| {iter_errs} | {dur:.2f} | {rps:.2f} |" + ) + w() + + # --- Per-ALTER Latency by Table --- + w("## Per-ALTER Latency by Table (ms)") + w() + w("| Table | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |") + w("|-------|-------|-----|-----|------|-----|-----|-----|-----|") + tables_seen = sorted(set(r["table"] for r in successful), key=lambda x: int(x.replace("table", ""))) + for tbl in tables_seen: + tbl_lats = [r["latency_ms"] for r in successful if r["table"] == tbl] + s = latency_stats(tbl_lats) + w( + f"| {tbl} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} " + f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |" + ) + w() + + # --- Per-ALTER Latency by Thread --- + w("## Per-ALTER Latency by Thread (ms)") + w() + w("| Thread | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |") + w("|--------|-------|-----|-----|------|-----|-----|-----|-----|") + threads_seen = sorted(set(r["thread_id"] for r in successful)) + for tid in threads_seen: + thr_lats = [r["latency_ms"] for r in successful if r["thread_id"] == tid] + s = latency_stats(thr_lats) + w( + f"| {tid} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} " + f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |" + ) + w() + + # --- Per-Table Latency 
(all columns in one table) --- + w("## Per-Table Latency — Time to Tag All Columns in One Table (ms)") + w() + w("| Table | Iteration | Thread | Latency (ms) | ALTERs/sec | Errors |") + w("|-------|-----------|--------|--------------|------------|--------|") + for r in sorted(all_table_results, key=lambda x: (x["iteration"], int(x["table"].replace("table", "")))): + w( + f"| {r['table']} | {r['iteration']} | {r['thread_id']} " + f"| {r['latency_ms']:.2f} | {r['alters_per_sec']:.2f} | {r['num_errors']} |" + ) + w() + + # --- Per-Table Aggregate Stats --- + w("## Per-Table Aggregate Stats (ms)") + w() + table_latencies = [r["latency_ms"] for r in all_table_results] + s = latency_stats(table_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k, v in s.items(): + w(f"| {k} | {v:.2f} |") + w() + + # --- Error Analysis --- + w("## Error Analysis") + w() + if not failed: + w("No errors encountered.") + else: + error_groups = defaultdict(list) + for r in failed: + error_groups[r["error_type"]].append(r) + w("| Error Type | Count | % of Total | Sample Message |") + w("|------------|-------|------------|----------------|") + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + pct = len(records) / total_alters * 100 + sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A" + w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |") + w() + + w("### Error Detail") + w() + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + w(f"**{etype}** ({len(records)} occurrences)") + w() + # Show up to 3 samples + for r in records[:3]: + w(f"- Table: {r['table']}, Column: {r['column']}, Iteration: {r['iteration']}") + w(f" Latency: {r['latency_ms']:.2f}ms") + w(f" Message: {r['error_message']}") + if r["error_context"]: + w(f" Context: {r['error_context']}") + if len(records) > 3: + w(f"- ... 
and {len(records) - 3} more")
+            w()
+        w()
+
+    # --- Retry Analysis ---
+    ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<max_attempts>\d+)")
+    SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<duration>[0-9.]+)s")
+    SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),")
+    RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<max_attempts>\d+)")
+
+    parsed_events = []
+    for r in profile_handler.records:
+        msg = r["message"]
+        for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE),
+                             ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]:
+            m = regex.search(msg)
+            if m:
+                event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg}
+                event.update(m.groupdict())
+                if "attempt" in event:
+                    event["attempt"] = int(event["attempt"])
+                if "status" in event:
+                    event["status"] = int(event["status"])
+                parsed_events.append(event)
+                break
+
+    # Filter to ExecuteStatement only (excludes OpenSession, CloseSession, GetOperationStatus)
+    exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"]
+    exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"]
+    exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"]
+    exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1]
+    exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"]
+    exec_successes = [e for e in exec_events if e["type"] == "success"]
+
+    w("## Statement Retry Analysis (ExecuteStatement only)")
+    w()
+    w("*Includes benchmarked ALTERs + one warmup SELECT 1 per worker thread.*")
+    w()
+    w(f"- **Total [PROFILE] events (all commands)**: {len(parsed_events)}")
+    w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}")
+    w(f"- **ExecuteStatement successes**: {len(exec_successes)}")
+    w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}")
+    w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: {len(exec_success_after_retry)}")
+    w(f"- **should_retry evaluations**: {len(exec_should_retry)}")
+    w()
+
+    if exec_retry_sleeps:
+        w("### Retry Events")
+        w()
+        w("| Timestamp | Thread | Attempt | Sleep (s) | Message |")
+        w("|-----------|--------|---------|-----------|---------|")
+        for e in exec_retry_sleeps[:50]:
+            ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3]
+            w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |")
+        if len(exec_retry_sleeps) > 50:
+            w(f"| ... | ... | ... | ...
| {len(exec_retry_sleeps) - 50} more |") + w() + + if exec_should_retry: + w("### should_retry Decisions") + w() + status_counts = defaultdict(int) + for e in exec_should_retry: + status_counts[e["status"]] += 1 + w("| HTTP Status | Count |") + w("|-------------|-------|") + for status, count in sorted(status_counts.items()): + w(f"| {status} | {count} |") + w() + + # --- Footer --- + w("---") + w(f"*Generated by profile_column_tags.py on {datetime.now().isoformat()}*") + + report_text = "\n".join(lines) + with open(report_path, "w") as f: + f.write(report_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="Profile SET COLUMN TAGS performance") + parser.add_argument("--columns", type=int, required=True, help="Number of columns to tag per table") + parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command") + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables to process per iteration (default = --threads, i.e. 1 table per thread)") + parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration, print result") + parser.add_argument("--skip-setup", action="store_true", help="Skip table creation (tables already exist)") + args = parser.parse_args() + + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + + if args.columns > MAX_COLUMNS: + print(f"Error: --columns {args.columns} exceeds MAX_COLUMNS={MAX_COLUMNS}") + sys.exit(1) + + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + + if args.validate: + args.iterations = 1 + print("=== VALIDATION MODE: 1 iteration only ===\n") + + # File paths + os.makedirs(RESULTS_DIR, exist_ok=True) + prefix = f"c{args.columns}_t{args.tags}_n{args.threads}_i{args.iterations}" + report_path = os.path.join(RESULTS_DIR, f"{prefix}_report.md") + data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl") + log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log") + + # Logging + profile_handler = setup_logging(log_path) + + print(f"Profile: columns={args.columns}, tags={args.tags}, threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"ALTERs per iteration: {args.tables_per_iteration * args.columns} ({args.tables_per_iteration} tables x {args.columns} columns)") + print(f"Total ALTERs: {args.tables_per_iteration * args.columns * args.iterations}") + print(f"Output: {report_path}") + print() + + # Setup + if not args.skip_setup: + setup_tables() + + # Run iterations + all_alter_results = [] + all_table_results = [] + iteration_durations = [] + + for i in range(1, args.iterations + 1): + print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) + alter_results, table_results, duration = run_iteration( + iteration=i, + num_columns=args.columns, + num_tags=args.tags, + num_threads=args.threads, + tables_per_iteration=args.tables_per_iteration, + ) + all_alter_results.extend(alter_results) + all_table_results.extend(table_results) + iteration_durations.append(duration) + + 
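+        # Per-iteration progress: ALTER count, error count, and ALTERs/sec
+        # derived from this iteration's measured wall-clock duration.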
alters_count = len(alter_results) + errors = len([r for r in alter_results if not r["success"]]) + rps = alters_count / duration if duration > 0 else 0 + print(f"done in {duration:.2f}s ({alters_count} ALTERs, {errors} errors, {rps:.1f} ALTERs/sec)") + + print() + + # Write raw data + with open(data_path, "w") as f: + for r in all_alter_results: + f.write(json.dumps(r) + "\n") + # separator + f.write("\n") + for r in all_table_results: + f.write(json.dumps(r) + "\n") + + # Generate report + generate_report(args, all_alter_results, all_table_results, iteration_durations, profile_handler, report_path) + + print(f"Report written to: {report_path}") + print(f"Raw data written to: {data_path}") + print(f"Retry log written to: {log_path}") + + # Print summary to stdout + success_lats = [r["latency_ms"] for r in all_alter_results if r["success"]] + if success_lats: + s = latency_stats(success_lats) + total_dur = sum(iteration_durations) + print() + print("=== Summary ===") + print(f" ALTERs: {len(all_alter_results)} ({len(success_lats)} ok, {len(all_alter_results) - len(success_lats)} failed)") + print(f" Latency: p50={s['p50']:.1f}ms p90={s['p90']:.1f}ms p95={s['p95']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms") + print(f" Throughput: {len(all_alter_results) / total_dur:.1f} ALTERs/sec") + + +if __name__ == "__main__": + main() diff --git a/examples/profile_read_then_write_table_tags.py b/examples/profile_read_then_write_table_tags.py new file mode 100644 index 000000000..5aaf729c4 --- /dev/null +++ b/examples/profile_read_then_write_table_tags.py @@ -0,0 +1,556 @@ +#!/usr/bin/env python3 +""" +Profile information_schema.table_tags SELECT performance. + +For each table, this script SELECTs existing table tags from +system.information_schema.table_tags. No ALTER/write operations. 
+ +Usage: + python examples/profile_read_then_write_table_tags.py --threads 1 --iterations 1 --validate + python examples/profile_read_then_write_table_tags.py --threads 8 --iterations 10 +""" + +import argparse +import json +import logging +import os +import random +import re +import statistics +import string +import sys +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from queue import Empty, Queue + +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql + +# ============================================================ +# CONFIGURATION — loaded from examples/credentials.env +# ============================================================ +from load_credentials import load_credentials +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] +# ============================================================ + +NUM_TABLES = 128 +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "read_then_write_table_tags") + +SELECT_TEMPLATE = """SELECT tag_name, tag_value +FROM system.information_schema.table_tags +WHERE catalog_name = '{catalog}' + AND schema_name = '{schema}' + AND table_name = '{table}'""" + + +# --------------------------------------------------------------------------- +# Logging setup +# --------------------------------------------------------------------------- + +class ProfileLogHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records: list = [] + + def emit(self, record): + msg = record.getMessage() + if "[PROFILE]" in msg: + self.records.append( + {"timestamp": record.created, "thread": record.threadName, "message": msg} + ) + + +def setup_logging(log_path: str) -> ProfileLogHandler: + profile_handler = ProfileLogHandler() + profile_handler.setLevel(logging.INFO) + + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s") + ) + + for logger_name in [ + "databricks.sql.backend.thrift_backend", + "databricks.sql.auth.retry", + "databricks.sql.client", + ]: + lgr = logging.getLogger(logger_name) + lgr.setLevel(logging.DEBUG) + lgr.addHandler(profile_handler) + lgr.addHandler(file_handler) + + return profile_handler + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def conn_params() -> dict: + return { + "server_hostname": SERVER_HOSTNAME, + "http_path": HTTP_PATH, + "access_token": ACCESS_TOKEN, + "_tls_no_verify": True, + "enable_telemetry": False, + } + + +def random_tag_value(length: int = 5) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def percentile(data: list, p: float) -> float: + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (p / 100.0) + f = int(k) + c = f + 1 + if c >= len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def latency_stats(latencies: list) -> dict: + if not latencies: + return {k: 0.0 for k in ["count", "min", "max", "mean", 
"stdev", "p50", "p90", "p95", "p99"]} + return { + "count": len(latencies), + "min": min(latencies), + "max": max(latencies), + "mean": statistics.mean(latencies), + "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0, + "p50": percentile(latencies, 50), + "p90": percentile(latencies, 90), + "p95": percentile(latencies, 95), + "p99": percentile(latencies, 99), + } + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + +def worker( + thread_id: int, + table_queue: Queue, + results: list, + results_lock: threading.Lock, +): + local_results = [] + table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`" + + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + # Warmup + cursor.execute("SELECT 1") + + while True: + try: + table_name = table_queue.get_nowait() + except Empty: + break + + table_fqn = f"{table_fqn_prefix}.{table_name}" + + # --- Step 1: Read table tags from information_schema --- + select_sql = SELECT_TEMPLATE.format( + catalog=CATALOG, schema=SCHEMA, table=table_name + ) + + op_start = time.perf_counter() + select_start = time.perf_counter() + select_success = True + select_error_type = None + select_error_message = None + select_rows = 0 + select_statement_id = None + + try: + cursor.execute(select_sql) + select_statement_id = str(cursor.active_command_id) if cursor.active_command_id else None + rows = cursor.fetchall() + select_rows = len(rows) + except Exception as e: + select_success = False + select_error_type = type(e).__name__ + select_error_message = str(e)[:500] + + select_end = time.perf_counter() + select_latency_ms = (select_end - select_start) * 1000 + + local_results.append( + { + "table": table_name, + "thread_id": thread_id, + "select_latency_ms": round(select_latency_ms, 2), + "select_success": select_success, + "select_error_type": select_error_type, + "select_error_message": select_error_message, + "select_rows": select_rows, + "select_statement_id": select_statement_id, + "timestamp": op_start, + } + ) + + with results_lock: + results.extend(local_results) + + +# --------------------------------------------------------------------------- +# Run one iteration +# --------------------------------------------------------------------------- + +def run_iteration(iteration: int, num_threads: int, tables_per_iteration: int) -> tuple: + table_queue = Queue() + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") + + results: list = [] + results_lock = threading.Lock() + + iter_start = time.perf_counter() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(worker, tid, table_queue, results, results_lock) + for tid in range(num_threads) + ] + for f in as_completed(futures): + f.result() + + iter_end = time.perf_counter() + duration_s = iter_end - iter_start + + for r in results: + r["iteration"] = iteration + + return results, duration_s + + +# --------------------------------------------------------------------------- +# Report generation +# --------------------------------------------------------------------------- + +def generate_report( + args, + all_results: list, + iteration_durations: list, + profile_handler: ProfileLogHandler, + report_path: str, +): + lines = [] + + def w(text=""): + lines.append(text) + + total_ops = len(all_results) + 
total_duration = sum(iteration_durations) + + select_ok = [r for r in all_results if r["select_success"]] + select_fail = [r for r in all_results if not r["select_success"]] + select_latencies = [r["select_latency_ms"] for r in select_ok] + + # --- Header --- + w(f"# Information Schema Table Tags Profile: N={args.threads}, I={args.iterations}") + w() + w("## Configuration") + w(f"- **Server**: `{SERVER_HOSTNAME}`") + w(f"- **HTTP Path**: `{HTTP_PATH}`") + w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`") + w(f"- **Tables**: {NUM_TABLES}") + w(f"- **Pattern**: SELECT from system.information_schema.table_tags per table") + w(f"- **Threads**: {args.threads}") + w(f"- **Iterations**: {args.iterations}") + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Date**: {datetime.now().isoformat()}") + w() + + # --- Latency --- + w("## SELECT Latency (ms)") + w() + ss = latency_stats(select_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]: + w(f"| {k} | {ss[k]:.2f} |") + w() + + # --- Throughput --- + w("## Throughput") + w() + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Successes**: {len(select_ok)} / {total_ops}") + w(f"- **Failures**: {len(select_fail)} / {total_ops}") + w(f"- **Total wall-clock**: {total_duration:.2f}s") + if total_duration > 0: + w(f"- **SELECTs/sec**: {total_ops / total_duration:.2f}") + w() + + # --- Cold Start vs Steady State --- + if args.iterations > 1: + w("## Cold Start vs Steady State") + w() + iter1 = [r["select_latency_ms"] for r in select_ok if r["iteration"] == 1] + iter_rest = [r["select_latency_ms"] for r in select_ok if r["iteration"] > 1] + s1 = latency_stats(iter1) + sr = latency_stats(iter_rest) + w("| Phase | Count | P50 (ms) | P90 (ms) | P99 (ms) |") + w("|-------|-------|----------|----------|----------|") + w(f"| Iteration 1 | {s1['count']:.0f} | {s1['p50']:.2f} | {s1['p90']:.2f} | {s1['p99']:.2f} |") + w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['p50']:.2f} | {sr['p90']:.2f} | {sr['p99']:.2f} |") + w() + + # --- Per-Iteration Summary --- + w("## Per-Iteration Summary") + w() + w("| Iteration | SELECTs | P50 (ms) | P90 (ms) | P99 (ms) | Errors | Duration (s) | SELECTs/sec |") + w("|-----------|---------|----------|----------|----------|--------|--------------|-------------|") + for i in range(1, args.iterations + 1): + i_lats = [r["select_latency_ms"] for r in select_ok if r["iteration"] == i] + i_errs = len([r for r in all_results if r["iteration"] == i and not r["select_success"]]) + dur = iteration_durations[i - 1] + ops_in_iter = len([r for r in all_results if r["iteration"] == i]) + rps = ops_in_iter / dur if dur > 0 else 0 + s = latency_stats(i_lats) + w(f"| {i} | {s['count']:.0f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p99']:.2f} | {i_errs} | {dur:.2f} | {rps:.2f} |") + w() + + # --- All SELECTs with Statement IDs --- + w("## All SELECTs with Statement IDs") + w() + w("| Table | Iteration | Latency (ms) | Rows | Statement ID |") + w("|-------|-----------|-------------|------|--------------|") + sorted_results = sorted(all_results, key=lambda r: (r.get("iteration", 0), int(r["table"].replace("table", "")))) + for r in sorted_results: + w( + f"| {r['table']} | {r.get('iteration', '?')} " + f"| {r['select_latency_ms']:.2f} | {r['select_rows']} " + f"| {r.get('select_statement_id', 'N/A')} |" + ) + w() + + # --- By Thread --- + w("## Latency by Thread (ms)") + w() + w("| Thread | SELECTs | P50 | P90 | P99 |") + 
w("|--------|---------|-----|-----|-----|") + threads_seen = sorted(set(r["thread_id"] for r in select_ok)) + for tid in threads_seen: + t_lats = latency_stats([r["select_latency_ms"] for r in select_ok if r["thread_id"] == tid]) + w(f"| {tid} | {t_lats['count']:.0f} | {t_lats['p50']:.2f} | {t_lats['p90']:.2f} | {t_lats['p99']:.2f} |") + w() + + # --- Rows returned by SELECT --- + w("## Information Schema Rows Returned") + w() + row_counts = [r["select_rows"] for r in select_ok] + if row_counts: + w(f"- **Min rows**: {min(row_counts)}") + w(f"- **Max rows**: {max(row_counts)}") + w(f"- **Mean rows**: {statistics.mean(row_counts):.1f}") + w() + + # --- Error Analysis --- + w("## Error Analysis") + w() + all_errors = [] + for r in all_results: + if not r["select_success"]: + all_errors.append({"table": r["table"], "iteration": r.get("iteration", "?"), + "error_type": r["select_error_type"], "error_message": r["select_error_message"]}) + + if not all_errors: + w("No errors encountered.") + else: + error_groups = defaultdict(list) + for e in all_errors: + error_groups[e["error_type"]].append(e) + w("| Error Type | Count | % of Total | Sample Message |") + w("|------------|-------|------------|----------------|") + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + pct = len(records) / total_ops * 100 + sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A" + w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |") + w() + + w("### Error Detail") + w() + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + w(f"**{etype}** ({len(records)} occurrences)") + w() + for e in records[:3]: + w(f"- Table: {e['table']}, Iteration: {e['iteration']}") + w(f" Message: {e['error_message']}") + if len(records) > 3: + w(f"- ... 
and {len(records) - 3} more") + w() + w() + + # --- Retry Analysis --- + ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<max>\d+)") + SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<duration>[0-9.]+)s") + SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),") + RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<max>\d+)") + + parsed_events = [] + for r in profile_handler.records: + msg = r["message"] + for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE), + ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]: + m = regex.search(msg) + if m: + event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg} + event.update(m.groupdict()) + if "attempt" in event: + event["attempt"] = int(event["attempt"]) + if "status" in event: + event["status"] = int(event["status"]) + parsed_events.append(event) + break + + exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"] + exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"] + exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"] + exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1] + exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"] + exec_successes = [e for e in exec_events if e["type"] == "success"] + + w("## Statement Retry Analysis (ExecuteStatement only)") + w() + w("*Includes information_schema SELECTs and one warmup SELECT 1 per thread.*") + w() + w(f"- **Total [PROFILE] events (all commands)**: {len(parsed_events)}") + w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}") + w(f"- **ExecuteStatement successes**: {len(exec_successes)}") + w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}") + w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: {len(exec_success_after_retry)}") + w(f"- **should_retry evaluations**: {len(exec_should_retry)}") + w() + + if exec_retry_sleeps: + w("### Retry Events") + w() + w("| Timestamp | Thread | Attempt | Sleep (s) | Message |") + w("|-----------|--------|---------|-----------|---------|") + for e in exec_retry_sleeps[:50]: + ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3] + w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |") + if len(exec_retry_sleeps) > 50: + w(f"| ... | ... | ... | ...
| {len(exec_retry_sleeps) - 50} more |") + w() + + if exec_should_retry: + w("### should_retry Decisions") + w() + status_counts = defaultdict(int) + for e in exec_should_retry: + status_counts[e["status"]] += 1 + w("| HTTP Status | Count |") + w("|-------------|-------|") + for status, count in sorted(status_counts.items()): + w(f"| {status} | {count} |") + w() + + # --- Footer --- + w("---") + w(f"*Generated by profile_read_then_write_table_tags.py on {datetime.now().isoformat()}*") + + report_text = "\n".join(lines) + with open(report_path, "w") as f: + f.write(report_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Profile read-from-information_schema then write-table-tag pattern" + ) + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables per iteration (default = --threads)") + parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration") + args = parser.parse_args() + + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + + if args.validate: + args.iterations = 1 + print("=== VALIDATION MODE: 1 iteration only ===\n") + + os.makedirs(RESULTS_DIR, exist_ok=True) + prefix = f"rwtt_n{args.threads}_i{args.iterations}" + report_path = os.path.join(RESULTS_DIR, f"{prefix}_report.md") + data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl") + log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log") + + profile_handler = setup_logging(log_path) + + print(f"Profile (information_schema.table_tags): threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"SELECTs per iteration: {args.tables_per_iteration} (1 per table)") + print(f"Total SELECTs: {args.tables_per_iteration * args.iterations}") + print(f"Output: {report_path}") + print() + + all_results = [] + iteration_durations = [] + + for i in range(1, args.iterations + 1): + print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) + results, duration = run_iteration(iteration=i, num_threads=args.threads, tables_per_iteration=args.tables_per_iteration) + all_results.extend(results) + iteration_durations.append(duration) + + errs = len([r for r in results if not r["select_success"]]) + rps = len(results) / duration if duration > 0 else 0 + print(f"done in {duration:.2f}s ({len(results)} SELECTs, {errs} errors, {rps:.1f} SELECTs/sec)") + + print() + + with open(data_path, "w") as f: + for r in all_results: + f.write(json.dumps(r) + "\n") + + generate_report(args, all_results, iteration_durations, profile_handler, report_path) + + print(f"Report written to: {report_path}") + print(f"Raw data written to: {data_path}") + print(f"Retry log written to: {log_path}") + + ok = [r for r in all_results if r["select_success"]] + if ok: + s = latency_stats([r["select_latency_ms"] for r in ok]) + total_dur = sum(iteration_durations) + print() + print("=== Summary ===") + print(f" SELECTs: {len(all_results)} ({len(ok)} ok, {len(all_results) - 
len(ok)} failed)") + print(f" Latency: p50={s['p50']:.1f}ms p90={s['p90']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms") + print(f" Throughput: {len(all_results) / total_dur:.1f} SELECTs/sec") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/profile_read_then_write_tags.py b/examples/profile_read_then_write_tags.py new file mode 100644 index 000000000..0ac138a18 --- /dev/null +++ b/examples/profile_read_then_write_tags.py @@ -0,0 +1,557 @@ +#!/usr/bin/env python3 +""" +Profile information_schema.column_tags SELECT performance. + +For each table, this script SELECTs existing column tags from +system.information_schema.column_tags. No ALTER/write operations. + +Usage: + python examples/profile_read_then_write_tags.py --threads 1 --iterations 1 --validate + python examples/profile_read_then_write_tags.py --threads 8 --iterations 10 + python examples/profile_read_then_write_tags.py --threads 32 --iterations 10 +""" + +import argparse +import json +import logging +import os +import random +import re +import statistics +import string +import sys +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from queue import Empty, Queue + +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql + +# ============================================================ +# CONFIGURATION — loaded from examples/credentials.env +# ============================================================ +from load_credentials import load_credentials +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] +# ============================================================ + +NUM_TABLES = 64  # table1..table64 exist (see PROFILING_README.md) +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "read_then_write") + +SELECT_TEMPLATE = """SELECT column_name, tag_name, tag_value +FROM system.information_schema.column_tags +WHERE catalog_name = '{catalog}' + AND schema_name = '{schema}' + AND table_name = '{table}'""" + + +# --------------------------------------------------------------------------- +# Logging setup +# --------------------------------------------------------------------------- + +class ProfileLogHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records: list = [] + + def emit(self, record): + msg = record.getMessage() + if "[PROFILE]" in msg: + self.records.append( + {"timestamp": record.created, "thread": record.threadName, "message": msg} + ) + + +def setup_logging(log_path: str) -> ProfileLogHandler: + profile_handler = ProfileLogHandler() + profile_handler.setLevel(logging.INFO) + + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s") + ) + + for logger_name in [ + "databricks.sql.backend.thrift_backend", + "databricks.sql.auth.retry", + "databricks.sql.client", + ]: + lgr = logging.getLogger(logger_name) + lgr.setLevel(logging.DEBUG) + lgr.addHandler(profile_handler) + lgr.addHandler(file_handler) + + return profile_handler + + +# --------------------------------------------------------------------------- +# Helpers +#
--------------------------------------------------------------------------- + +def conn_params() -> dict: + return { + "server_hostname": SERVER_HOSTNAME, + "http_path": HTTP_PATH, + "access_token": ACCESS_TOKEN, + "_tls_no_verify": True, + "enable_telemetry": False, + } + + +def random_tag_value(length: int = 5) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def percentile(data: list, p: float) -> float: + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (p / 100.0) + f = int(k) + c = f + 1 + if c >= len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def latency_stats(latencies: list) -> dict: + if not latencies: + return {k: 0.0 for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]} + return { + "count": len(latencies), + "min": min(latencies), + "max": max(latencies), + "mean": statistics.mean(latencies), + "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0, + "p50": percentile(latencies, 50), + "p90": percentile(latencies, 90), + "p95": percentile(latencies, 95), + "p99": percentile(latencies, 99), + } + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + +def worker( + thread_id: int, + table_queue: Queue, + results: list, + results_lock: threading.Lock, +): + local_results = [] + table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`" + + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + # Warmup + cursor.execute("SELECT 1") + + while True: + try: + table_name = table_queue.get_nowait() + except Empty: + break + + table_fqn = f"{table_fqn_prefix}.{table_name}" + + # --- Step 1: Read column tags from information_schema --- + select_sql = SELECT_TEMPLATE.format( + catalog=CATALOG, schema=SCHEMA, table=table_name + ) + + op_start = time.perf_counter() + select_start = time.perf_counter() + select_success = True + select_error_type = None + select_error_message = None + select_rows = 0 + select_statement_id = None + + try: + cursor.execute(select_sql) + select_statement_id = str(cursor.active_command_id) if cursor.active_command_id else None + rows = cursor.fetchall() + select_rows = len(rows) + except Exception as e: + select_success = False + select_error_type = type(e).__name__ + select_error_message = str(e)[:500] + + select_end = time.perf_counter() + select_latency_ms = (select_end - select_start) * 1000 + + local_results.append( + { + "table": table_name, + "thread_id": thread_id, + "select_latency_ms": round(select_latency_ms, 2), + "select_success": select_success, + "select_error_type": select_error_type, + "select_error_message": select_error_message, + "select_rows": select_rows, + "select_statement_id": select_statement_id, + "timestamp": op_start, + } + ) + + with results_lock: + results.extend(local_results) + + +# --------------------------------------------------------------------------- +# Run one iteration +# --------------------------------------------------------------------------- + +def run_iteration(iteration: int, num_threads: int, tables_per_iteration: int) -> tuple: + table_queue = Queue() + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") + + results: list = [] + results_lock = threading.Lock() + + iter_start = 
time.perf_counter() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(worker, tid, table_queue, results, results_lock) + for tid in range(num_threads) + ] + for f in as_completed(futures): + f.result() + + iter_end = time.perf_counter() + duration_s = iter_end - iter_start + + for r in results: + r["iteration"] = iteration + + return results, duration_s + + +# --------------------------------------------------------------------------- +# Report generation +# --------------------------------------------------------------------------- + +def generate_report( + args, + all_results: list, + iteration_durations: list, + profile_handler: ProfileLogHandler, + report_path: str, +): + lines = [] + + def w(text=""): + lines.append(text) + + total_ops = len(all_results) + total_duration = sum(iteration_durations) + + select_ok = [r for r in all_results if r["select_success"]] + select_fail = [r for r in all_results if not r["select_success"]] + select_latencies = [r["select_latency_ms"] for r in select_ok] + + # --- Header --- + w(f"# Information Schema Column Tags Profile: N={args.threads}, I={args.iterations}") + w() + w("## Configuration") + w(f"- **Server**: `{SERVER_HOSTNAME}`") + w(f"- **HTTP Path**: `{HTTP_PATH}`") + w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`") + w(f"- **Tables**: {NUM_TABLES}") + w(f"- **Pattern**: SELECT from system.information_schema.column_tags per table") + w(f"- **Threads**: {args.threads}") + w(f"- **Iterations**: {args.iterations}") + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Date**: {datetime.now().isoformat()}") + w() + + # --- Latency --- + w("## SELECT Latency (ms)") + w() + ss = latency_stats(select_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]: + w(f"| {k} | {ss[k]:.2f} |") + w() + + # --- Throughput --- + w("## Throughput") + w() + w(f"- **Total SELECTs**: {total_ops}") + w(f"- **Successes**: {len(select_ok)} / {total_ops}") + w(f"- **Failures**: {len(select_fail)} / {total_ops}") + w(f"- **Total wall-clock**: {total_duration:.2f}s") + if total_duration > 0: + w(f"- **SELECTs/sec**: {total_ops / total_duration:.2f}") + w() + + # --- Cold Start vs Steady State --- + if args.iterations > 1: + w("## Cold Start vs Steady State") + w() + iter1 = [r["select_latency_ms"] for r in select_ok if r["iteration"] == 1] + iter_rest = [r["select_latency_ms"] for r in select_ok if r["iteration"] > 1] + s1 = latency_stats(iter1) + sr = latency_stats(iter_rest) + w("| Phase | Count | P50 (ms) | P90 (ms) | P99 (ms) |") + w("|-------|-------|----------|----------|----------|") + w(f"| Iteration 1 | {s1['count']:.0f} | {s1['p50']:.2f} | {s1['p90']:.2f} | {s1['p99']:.2f} |") + w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['p50']:.2f} | {sr['p90']:.2f} | {sr['p99']:.2f} |") + w() + + # --- Per-Iteration Summary --- + w("## Per-Iteration Summary") + w() + w("| Iteration | SELECTs | P50 (ms) | P90 (ms) | P99 (ms) | Errors | Duration (s) | SELECTs/sec |") + w("|-----------|---------|----------|----------|----------|--------|--------------|-------------|") + for i in range(1, args.iterations + 1): + i_lats = [r["select_latency_ms"] for r in select_ok if r["iteration"] == i] + i_errs = len([r for r in all_results if r["iteration"] == i and not r["select_success"]]) + dur = iteration_durations[i - 1] + ops_in_iter = len([r for r in all_results if r["iteration"] == i]) + rps = ops_in_iter / dur if dur > 0 else 0 + s = 
latency_stats(i_lats) + w(f"| {i} | {s['count']:.0f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p99']:.2f} | {i_errs} | {dur:.2f} | {rps:.2f} |") + w() + + # --- All Operations with Statement IDs --- + w("## All SELECTs with Statement IDs") + w() + w("| Table | Iteration | Latency (ms) | Rows | Statement ID |") + w("|-------|-----------|-------------|------|--------------|") + sorted_results = sorted(all_results, key=lambda r: (r.get("iteration", 0), int(r["table"].replace("table", "")))) + for r in sorted_results: + w( + f"| {r['table']} | {r.get('iteration', '?')} " + f"| {r['select_latency_ms']:.2f} | {r['select_rows']} " + f"| {r.get('select_statement_id', 'N/A')} |" + ) + w() + + # --- By Thread --- + w("## Latency by Thread (ms)") + w() + w("| Thread | SELECTs | P50 | P90 | P99 |") + w("|--------|---------|-----|-----|-----|") + threads_seen = sorted(set(r["thread_id"] for r in select_ok)) + for tid in threads_seen: + t_lats = latency_stats([r["select_latency_ms"] for r in select_ok if r["thread_id"] == tid]) + w(f"| {tid} | {t_lats['count']:.0f} | {t_lats['p50']:.2f} | {t_lats['p90']:.2f} | {t_lats['p99']:.2f} |") + w() + + # --- Rows returned by SELECT --- + w("## Information Schema Rows Returned") + w() + row_counts = [r["select_rows"] for r in select_ok] + if row_counts: + w(f"- **Min rows**: {min(row_counts)}") + w(f"- **Max rows**: {max(row_counts)}") + w(f"- **Mean rows**: {statistics.mean(row_counts):.1f}") + w() + + # --- Error Analysis --- + w("## Error Analysis") + w() + all_errors = [] + for r in all_results: + if not r["select_success"]: + all_errors.append({"table": r["table"], "iteration": r.get("iteration", "?"), + "error_type": r["select_error_type"], "error_message": r["select_error_message"]}) + + if not all_errors: + w("No errors encountered.") + else: + error_groups = defaultdict(list) + for e in all_errors: + error_groups[e["error_type"]].append(e) + w("| Error Type | Count | % of Total | Sample Message |") + w("|------------|-------|------------|----------------|") + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + pct = len(records) / total_ops * 100 + sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A" + w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |") + w() + + w("### Error Detail") + w() + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + w(f"**{etype}** ({len(records)} occurrences)") + w() + for e in records[:3]: + w(f"- Table: {e['table']}, Iteration: {e['iteration']}") + w(f" Message: {e['error_message']}") + if len(records) > 3: + w(f"- ... 
and {len(records) - 3} more") + w() + w() + + # --- Retry Analysis --- + ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<max>\d+)") + SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<duration>[0-9.]+)s") + SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),") + RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<max>\d+)") + + parsed_events = [] + for r in profile_handler.records: + msg = r["message"] + for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE), + ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]: + m = regex.search(msg) + if m: + event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg} + event.update(m.groupdict()) + if "attempt" in event: + event["attempt"] = int(event["attempt"]) + if "status" in event: + event["status"] = int(event["status"]) + parsed_events.append(event) + break + + exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"] + exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"] + exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"] + exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1] + exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"] + exec_successes = [e for e in exec_events if e["type"] == "success"] + + w("## Statement Retry Analysis (ExecuteStatement only)") + w() + w("*Includes information_schema SELECTs and one warmup SELECT 1 per thread.*") + w() + w(f"- **Total [PROFILE] events (all commands)**: {len(parsed_events)}") + w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}") + w(f"- **ExecuteStatement successes**: {len(exec_successes)}") + w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}") + w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: {len(exec_success_after_retry)}") + w(f"- **should_retry evaluations**: {len(exec_should_retry)}") + w() + + if exec_retry_sleeps: + w("### Retry Events") + w() + w("| Timestamp | Thread | Attempt | Sleep (s) | Message |") + w("|-----------|--------|---------|-----------|---------|") + for e in exec_retry_sleeps[:50]: + ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3] + w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |") + if len(exec_retry_sleeps) > 50: + w(f"| ... | ... | ... | ...
| {len(exec_retry_sleeps) - 50} more |") + w() + + if exec_should_retry: + w("### should_retry Decisions") + w() + status_counts = defaultdict(int) + for e in exec_should_retry: + status_counts[e["status"]] += 1 + w("| HTTP Status | Count |") + w("|-------------|-------|") + for status, count in sorted(status_counts.items()): + w(f"| {status} | {count} |") + w() + + # --- Footer --- + w("---") + w(f"*Generated by profile_read_then_write_tags.py on {datetime.now().isoformat()}*") + + report_text = "\n".join(lines) + with open(report_path, "w") as f: + f.write(report_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Profile read-from-information_schema then write-column-tag pattern" + ) + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables per iteration (default = --threads)") + parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration") + args = parser.parse_args() + + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + + if args.validate: + args.iterations = 1 + print("=== VALIDATION MODE: 1 iteration only ===\n") + + os.makedirs(RESULTS_DIR, exist_ok=True) + prefix = f"rw_n{args.threads}_i{args.iterations}" + report_path = os.path.join(RESULTS_DIR, f"{prefix}_report.md") + data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl") + log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log") + + profile_handler = setup_logging(log_path) + + print(f"Profile (information_schema.column_tags): threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"SELECTs per iteration: {args.tables_per_iteration} (1 per table)") + print(f"Total SELECTs: {args.tables_per_iteration * args.iterations}") + print(f"Output: {report_path}") + print() + + all_results = [] + iteration_durations = [] + + for i in range(1, args.iterations + 1): + print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) + results, duration = run_iteration(iteration=i, num_threads=args.threads, tables_per_iteration=args.tables_per_iteration) + all_results.extend(results) + iteration_durations.append(duration) + + errs = len([r for r in results if not r["select_success"]]) + rps = len(results) / duration if duration > 0 else 0 + print(f"done in {duration:.2f}s ({len(results)} SELECTs, {errs} errors, {rps:.1f} SELECTs/sec)") + + print() + + with open(data_path, "w") as f: + for r in all_results: + f.write(json.dumps(r) + "\n") + + generate_report(args, all_results, iteration_durations, profile_handler, report_path) + + print(f"Report written to: {report_path}") + print(f"Raw data written to: {data_path}") + print(f"Retry log written to: {log_path}") + + ok = [r for r in all_results if r["select_success"]] + if ok: + s = latency_stats([r["select_latency_ms"] for r in ok]) + total_dur = sum(iteration_durations) + print() + print("=== Summary ===") + print(f" SELECTs: {len(all_results)} ({len(ok)} ok, {len(all_results) - len(ok)} 
failed)") + print(f" Latency: p50={s['p50']:.1f}ms p90={s['p90']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms") + print(f" Throughput: {len(all_results) / total_dur:.1f} SELECTs/sec") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/profile_table_tags.py b/examples/profile_table_tags.py new file mode 100644 index 000000000..f73d6074e --- /dev/null +++ b/examples/profile_table_tags.py @@ -0,0 +1,586 @@ +#!/usr/bin/env python3 +""" +Profile SET TABLE TAGS performance on Databricks. + +Uses existing tables (table1..table64). No column tags — only table-level tags. + +Usage: + # Quick validation + python examples/profile_table_tags.py --tags 1 --threads 1 --iterations 1 --validate + + # Single experiment + python examples/profile_table_tags.py --tags 4 --threads 8 --iterations 10 + + # Full sweep + for t in 1 2 4; do + for n in 1 2 4 8 16; do + python examples/profile_table_tags.py --tags $t --threads $n --iterations 10 + done + done +""" + +import argparse +import json +import logging +import os +import random +import re +import statistics +import string +import sys +import threading +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from queue import Empty, Queue + +# Force unbuffered stdout so output is visible when piped through grep +sys.stdout.reconfigure(line_buffering=True) + +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +from databricks import sql + +# ============================================================ +# CONFIGURATION — loaded from examples/credentials.env +# ============================================================ +from load_credentials import load_credentials +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] +# ============================================================ + +NUM_TABLES = 128 +RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", "table_tags") + + +# --------------------------------------------------------------------------- +# Logging setup +# --------------------------------------------------------------------------- + +class ProfileLogHandler(logging.Handler): + """Captures [PROFILE] log lines for retry analysis.""" + + def __init__(self): + super().__init__() + self.records: list = [] + + def emit(self, record): + msg = record.getMessage() + if "[PROFILE]" in msg: + self.records.append( + { + "timestamp": record.created, + "thread": record.threadName, + "message": msg, + } + ) + + +def setup_logging(log_path: str) -> ProfileLogHandler: + """Configure logging: file handler for all connector logs, profile handler for [PROFILE] lines.""" + profile_handler = ProfileLogHandler() + profile_handler.setLevel(logging.INFO) + + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(threadName)s %(name)s %(levelname)s %(message)s") + ) + + for logger_name in [ + "databricks.sql.backend.thrift_backend", + "databricks.sql.auth.retry", + "databricks.sql.client", + ]: + lgr = logging.getLogger(logger_name) + lgr.setLevel(logging.DEBUG) + lgr.addHandler(profile_handler) + lgr.addHandler(file_handler) + + return profile_handler + + +# --------------------------------------------------------------------------- 
+# Helpers +# --------------------------------------------------------------------------- + +def conn_params() -> dict: + return { + "server_hostname": SERVER_HOSTNAME, + "http_path": HTTP_PATH, + "access_token": ACCESS_TOKEN, + "_tls_no_verify": True, + "enable_telemetry": False, + } + + +def random_tag_value(length: int = 5) -> str: + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def build_table_tag_sql(table_fqn: str, num_tags: int) -> str: + tags = ", ".join(f"'key{i}' = '{random_tag_value()}'" for i in range(1, num_tags + 1)) + return f"ALTER TABLE {table_fqn} SET TAGS ({tags})" + + +def percentile(data: list, p: float) -> float: + """Return the p-th percentile (0-100) of data.""" + if not data: + return 0.0 + sorted_data = sorted(data) + k = (len(sorted_data) - 1) * (p / 100.0) + f = int(k) + c = f + 1 + if c >= len(sorted_data): + return sorted_data[f] + return sorted_data[f] + (k - f) * (sorted_data[c] - sorted_data[f]) + + +def latency_stats(latencies: list) -> dict: + """Compute full latency statistics for a list of ms values.""" + if not latencies: + return {k: 0.0 for k in ["count", "min", "max", "mean", "stdev", "p50", "p90", "p95", "p99"]} + return { + "count": len(latencies), + "min": min(latencies), + "max": max(latencies), + "mean": statistics.mean(latencies), + "stdev": statistics.stdev(latencies) if len(latencies) > 1 else 0.0, + "p50": percentile(latencies, 50), + "p90": percentile(latencies, 90), + "p95": percentile(latencies, 95), + "p99": percentile(latencies, 99), + } + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + +def worker( + thread_id: int, + table_queue: Queue, + num_tags: int, + alter_results: list, + results_lock: threading.Lock, +): + """Worker thread: pulls tables from queue, sets table-level tags, records metrics.""" + local_results = [] + table_fqn_prefix = f"`{CATALOG}`.`{SCHEMA}`" + + with sql.connect(**conn_params()) as connection: + with connection.cursor() as cursor: + # Warmup + cursor.execute("SELECT 1") + + while True: + try: + table_name = table_queue.get_nowait() + except Empty: + break + + table_fqn = f"{table_fqn_prefix}.{table_name}" + alter_sql = build_table_tag_sql(table_fqn, num_tags) + + cmd_start = time.perf_counter() + success = True + error_type = None + error_message = None + error_context = None + + try: + cursor.execute(alter_sql) + except Exception as e: + success = False + error_type = type(e).__name__ + error_message = str(e)[:500] + error_context = getattr(e, "context", None) + + cmd_end = time.perf_counter() + latency_ms = (cmd_end - cmd_start) * 1000 + + local_results.append( + { + "table": table_name, + "thread_id": thread_id, + "latency_ms": round(latency_ms, 2), + "success": success, + "error_type": error_type, + "error_message": error_message, + "error_context": str(error_context) if error_context else None, + "timestamp": cmd_start, + } + ) + + with results_lock: + alter_results.extend(local_results) + + +# --------------------------------------------------------------------------- +# Run one iteration +# --------------------------------------------------------------------------- + +def run_iteration( + iteration: int, + num_tags: int, + num_threads: int, + tables_per_iteration: int, +) -> tuple: + """Run a single iteration: tables_per_iteration tables distributed across threads.""" + table_queue = Queue() + start = ((iteration - 1) * tables_per_iteration) % NUM_TABLES + 
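# Rotate the starting table each iteration so repeated sweeps cover the whole set. + # Only the start index wraps (table_idx itself is not reduced modulo NUM_TABLES), + # so this assumes tables_per_iteration divides NUM_TABLES evenly. +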
for i in range(tables_per_iteration): + table_idx = start + i + 1 + table_queue.put(f"table{table_idx}") + + alter_results: list = [] + results_lock = threading.Lock() + + iter_start = time.perf_counter() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + for tid in range(num_threads): + f = executor.submit( + worker, + tid, + table_queue, + num_tags, + alter_results, + results_lock, + ) + futures.append(f) + + for f in as_completed(futures): + f.result() # raise any thread exceptions + + iter_end = time.perf_counter() + duration_s = iter_end - iter_start + + for r in alter_results: + r["iteration"] = iteration + + return alter_results, duration_s + + +# --------------------------------------------------------------------------- +# Report generation +# --------------------------------------------------------------------------- + +def generate_report( + args, + all_results: list, + iteration_durations: list, + profile_handler: ProfileLogHandler, + report_path: str, +): + """Generate the markdown report.""" + lines = [] + + def w(text=""): + lines.append(text) + + total_alters = len(all_results) + total_duration = sum(iteration_durations) + successful = [r for r in all_results if r["success"]] + failed = [r for r in all_results if not r["success"]] + success_latencies = [r["latency_ms"] for r in successful] + + # --- Header --- + w(f"# Table Tags Profile: T={args.tags}, N={args.threads}, I={args.iterations}") + w() + w("## Configuration") + w(f"- **Server**: `{SERVER_HOSTNAME}`") + w(f"- **HTTP Path**: `{HTTP_PATH}`") + w(f"- **Catalog.Schema**: `{CATALOG}.{SCHEMA}`") + w(f"- **Tables**: {NUM_TABLES}") + w(f"- **Tags per ALTER**: {args.tags}") + w(f"- **Threads**: {args.threads}") + w(f"- **Iterations**: {args.iterations}") + w(f"- **Total ALTERs**: {total_alters}") + w(f"- **Date**: {datetime.now().isoformat()}") + w() + + # --- Overall ALTER Latency --- + w("## Per-ALTER Latency — All Iterations (ms)") + w() + stats = latency_stats(success_latencies) + w("| Metric | Value |") + w("|--------|-------|") + for k, v in stats.items(): + w(f"| {k} | {v:.2f} |") + w() + + # --- Throughput --- + w("## Throughput") + w() + w(f"- **Total ALTERs**: {total_alters}") + w(f"- **Successful**: {len(successful)}") + w(f"- **Failed**: {len(failed)}") + w(f"- **Total wall-clock**: {total_duration:.2f}s") + if total_duration > 0: + w(f"- **ALTERs/sec**: {total_alters / total_duration:.2f}") + w() + + # --- Cold Start vs Steady State --- + if args.iterations > 1: + w("## Cold Start vs Steady State") + w() + iter1 = [r["latency_ms"] for r in successful if r["iteration"] == 1] + iter_rest = [r["latency_ms"] for r in successful if r["iteration"] > 1] + w("| Phase | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) |") + w("|-------|--------|-----------|----------|----------|") + s1 = latency_stats(iter1) + sr = latency_stats(iter_rest) + w(f"| Iteration 1 | {s1['count']:.0f} | {s1['mean']:.2f} | {s1['p50']:.2f} | {s1['p99']:.2f} |") + w(f"| Iterations 2-{args.iterations} | {sr['count']:.0f} | {sr['mean']:.2f} | {sr['p50']:.2f} | {sr['p99']:.2f} |") + w() + + # --- Per-Iteration Summary --- + w("## Per-Iteration Summary") + w() + w("| Iteration | ALTERs | Mean (ms) | P50 (ms) | P99 (ms) | Errors | Duration (s) | ALTERs/sec |") + w("|-----------|--------|-----------|----------|----------|--------|--------------|------------|") + for i in range(1, args.iterations + 1): + iter_lats = [r["latency_ms"] for r in successful if r["iteration"] == i] + iter_errs = len([r for r in failed if 
r["iteration"] == i]) + s = latency_stats(iter_lats) + dur = iteration_durations[i - 1] + alters_in_iter = len([r for r in all_results if r["iteration"] == i]) + rps = alters_in_iter / dur if dur > 0 else 0 + w( + f"| {i} | {s['count']:.0f} | {s['mean']:.2f} | {s['p50']:.2f} | {s['p99']:.2f} " + f"| {iter_errs} | {dur:.2f} | {rps:.2f} |" + ) + w() + + # --- Per-ALTER Latency by Table --- + w("## Per-ALTER Latency by Table (ms)") + w() + w("| Table | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |") + w("|-------|-------|-----|-----|------|-----|-----|-----|-----|") + tables_seen = sorted(set(r["table"] for r in successful), key=lambda x: int(x.replace("table", ""))) + for tbl in tables_seen: + tbl_lats = [r["latency_ms"] for r in successful if r["table"] == tbl] + s = latency_stats(tbl_lats) + w( + f"| {tbl} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} " + f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |" + ) + w() + + # --- Per-ALTER Latency by Thread --- + w("## Per-ALTER Latency by Thread (ms)") + w() + w("| Thread | Count | Min | Max | Mean | P50 | P90 | P95 | P99 |") + w("|--------|-------|-----|-----|------|-----|-----|-----|-----|") + threads_seen = sorted(set(r["thread_id"] for r in successful)) + for tid in threads_seen: + thr_lats = [r["latency_ms"] for r in successful if r["thread_id"] == tid] + s = latency_stats(thr_lats) + w( + f"| {tid} | {s['count']:.0f} | {s['min']:.2f} | {s['max']:.2f} " + f"| {s['mean']:.2f} | {s['p50']:.2f} | {s['p90']:.2f} | {s['p95']:.2f} | {s['p99']:.2f} |" + ) + w() + + # --- Error Analysis --- + w("## Error Analysis") + w() + if not failed: + w("No errors encountered.") + else: + error_groups = defaultdict(list) + for r in failed: + error_groups[r["error_type"]].append(r) + w("| Error Type | Count | % of Total | Sample Message |") + w("|------------|-------|------------|----------------|") + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + pct = len(records) / total_alters * 100 + sample = records[0]["error_message"][:200] if records[0]["error_message"] else "N/A" + w(f"| {etype} | {len(records)} | {pct:.1f}% | {sample} |") + w() + + w("### Error Detail") + w() + for etype, records in sorted(error_groups.items(), key=lambda x: -len(x[1])): + w(f"**{etype}** ({len(records)} occurrences)") + w() + for r in records[:3]: + w(f"- Table: {r['table']}, Iteration: {r['iteration']}") + w(f" Latency: {r['latency_ms']:.2f}ms") + w(f" Message: {r['error_message']}") + if r["error_context"]: + w(f" Context: {r['error_context']}") + if len(records) > 3: + w(f"- ... 
and {len(records) - 3} more") + w() + w() + + # --- Retry Analysis --- + ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<max>\d+)") + SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<duration>[0-9.]+)s") + SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),") + RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<max>\d+)") + + parsed_events = [] + for r in profile_handler.records: + msg = r["message"] + for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE), + ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]: + m = regex.search(msg) + if m: + event = {"type": etype, "thread": r["thread"], "timestamp": r["timestamp"], "message": msg} + event.update(m.groupdict()) + if "attempt" in event: + event["attempt"] = int(event["attempt"]) + if "status" in event: + event["status"] = int(event["status"]) + parsed_events.append(event) + break + + exec_events = [e for e in parsed_events if e.get("cmd") == "ExecuteStatement"] + exec_retry_sleeps = [e for e in exec_events if e["type"] == "retry_sleep"] + exec_should_retry = [e for e in exec_events if e["type"] == "should_retry"] + exec_success_after_retry = [e for e in exec_events if e["type"] == "success" and e["attempt"] > 1] + exec_total_attempts = [e for e in exec_events if e["type"] == "attempt"] + exec_successes = [e for e in exec_events if e["type"] == "success"] + + w("## Statement Retry Analysis (ExecuteStatement only)") + w() + w("*Includes benchmarked ALTERs + one warmup SELECT 1 per worker thread.*") + w() + w(f"- **Total [PROFILE] events (all commands)**: {len(parsed_events)}") + w(f"- **ExecuteStatement attempts**: {len(exec_total_attempts)}") + w(f"- **ExecuteStatement successes**: {len(exec_successes)}") + w(f"- **ExecuteStatement retry sleeps**: {len(exec_retry_sleeps)}") + w(f"- **ExecuteStatement succeeded after retry (attempt > 1)**: {len(exec_success_after_retry)}") + w(f"- **should_retry evaluations**: {len(exec_should_retry)}") + w() + + if exec_retry_sleeps: + w("### Retry Events") + w() + w("| Timestamp | Thread | Attempt | Sleep (s) | Message |") + w("|-----------|--------|---------|-----------|---------|") + for e in exec_retry_sleeps[:50]: + ts = datetime.fromtimestamp(e["timestamp"]).strftime("%H:%M:%S.%f")[:-3] + w(f"| {ts} | {e['thread']} | {e['attempt']} | {e.get('sleep', '?')} | {e['message'][:150]} |") + if len(exec_retry_sleeps) > 50: + w(f"| ... | ... | ... | ...
| {len(exec_retry_sleeps) - 50} more |") + w() + + if exec_should_retry: + w("### should_retry Decisions") + w() + status_counts = defaultdict(int) + for e in exec_should_retry: + status_counts[e["status"]] += 1 + w("| HTTP Status | Count |") + w("|-------------|-------|") + for status, count in sorted(status_counts.items()): + w(f"| {status} | {count} |") + w() + + # --- Footer --- + w("---") + w(f"*Generated by profile_table_tags.py on {datetime.now().isoformat()}*") + + report_text = "\n".join(lines) + with open(report_path, "w") as f: + f.write(report_text) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="Profile SET TABLE TAGS performance") + parser.add_argument("--tags", type=int, required=True, help="Number of tags per ALTER command") + parser.add_argument("--threads", type=int, required=True, help="Number of concurrent threads") + parser.add_argument("--iterations", type=int, required=True, help="Number of iterations") + parser.add_argument("--tables-per-iteration", type=int, default=None, help="Tables per iteration (default = --threads)") + parser.add_argument("--validate", action="store_true", help="Quick validation: override to 1 iteration, print result") + args = parser.parse_args() + + if args.tables_per_iteration is None: + args.tables_per_iteration = args.threads + + if args.tables_per_iteration > NUM_TABLES: + print(f"Error: --tables-per-iteration {args.tables_per_iteration} exceeds NUM_TABLES={NUM_TABLES}") + sys.exit(1) + + if args.validate: + args.iterations = 1 + print("=== VALIDATION MODE: 1 iteration only ===\n") + + # File paths + os.makedirs(RESULTS_DIR, exist_ok=True) + prefix = f"tt_t{args.tags}_n{args.threads}_i{args.iterations}" + report_path = os.path.join(RESULTS_DIR, f"{prefix}_report.md") + data_path = os.path.join(RESULTS_DIR, f"{prefix}_data.jsonl") + log_path = os.path.join(RESULTS_DIR, f"{prefix}_retries.log") + + # Logging + profile_handler = setup_logging(log_path) + + print(f"Profile (TABLE TAGS): tags={args.tags}, threads={args.threads}, iterations={args.iterations}, tables_per_iteration={args.tables_per_iteration}") + print(f"ALTERs per iteration: {args.tables_per_iteration} (one per table)") + print(f"Total ALTERs: {args.tables_per_iteration * args.iterations}") + print(f"Output: {report_path}") + print() + + # Run iterations + all_results = [] + iteration_durations = [] + + for i in range(1, args.iterations + 1): + print(f"Iteration {i}/{args.iterations}...", end=" ", flush=True) + results, duration = run_iteration( + iteration=i, + num_tags=args.tags, + num_threads=args.threads, + tables_per_iteration=args.tables_per_iteration, + ) + all_results.extend(results) + iteration_durations.append(duration) + + errors = len([r for r in results if not r["success"]]) + rps = len(results) / duration if duration > 0 else 0 + print(f"done in {duration:.2f}s ({len(results)} ALTERs, {errors} errors, {rps:.1f} ALTERs/sec)") + + print() + + # Write raw data + with open(data_path, "w") as f: + for r in all_results: + f.write(json.dumps(r) + "\n") + + # Generate report + generate_report(args, all_results, iteration_durations, profile_handler, report_path) + + print(f"Report written to: {report_path}") + print(f"Raw data written to: {data_path}") + print(f"Retry log written to: {log_path}") + + # Print summary to stdout + success_lats = [r["latency_ms"] for r in all_results if 
r["success"]] + if success_lats: + s = latency_stats(success_lats) + total_dur = sum(iteration_durations) + print() + print("=== Summary ===") + print(f" ALTERs: {len(all_results)} ({len(success_lats)} ok, {len(all_results) - len(success_lats)} failed)") + print(f" Latency: p50={s['p50']:.1f}ms p90={s['p90']:.1f}ms p95={s['p95']:.1f}ms p99={s['p99']:.1f}ms max={s['max']:.1f}ms") + print(f" Throughput: {len(all_results) / total_dur:.1f} ALTERs/sec") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/test_connection.py b/examples/test_connection.py new file mode 100644 index 000000000..0332a2980 --- /dev/null +++ b/examples/test_connection.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +"""Quick smoke test: connect and run one CREATE TABLE.""" + +import logging +import time +from databricks import sql +from load_credentials import load_credentials + +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(name)s %(levelname)s %(message)s") + +_creds = load_credentials() +SERVER_HOSTNAME = _creds["SERVER_HOSTNAME"] +HTTP_PATH = _creds["HTTP_PATH"] +ACCESS_TOKEN = _creds["ACCESS_TOKEN"] +CATALOG = _creds["CATALOG"] +SCHEMA = _creds["SCHEMA"] + +print("Connecting...") +t0 = time.time() + +with sql.connect( + server_hostname=SERVER_HOSTNAME, + http_path=HTTP_PATH, + access_token=ACCESS_TOKEN, + _tls_no_verify=True, + enable_telemetry=False, +) as conn: + print(f"Connected in {time.time() - t0:.2f}s") + + with conn.cursor() as cursor: + cursor.execute(f"USE CATALOG {CATALOG}") + print(f"USE CATALOG done in {time.time() - t0:.2f}s") + + cursor.execute(f"USE SCHEMA {SCHEMA}") + print(f"USE SCHEMA done in {time.time() - t0:.2f}s") + + t1 = time.time() + cursor.execute("CREATE TABLE IF NOT EXISTS test_conn_check (col1 STRING, col2 STRING)") + print(f"CREATE TABLE done in {time.time() - t1:.2f}s") + + t1 = time.time() + cursor.execute("SELECT 1") + print(f"SELECT 1 done in {time.time() - t1:.2f}s") + + t1 = time.time() + cursor.execute("DROP TABLE IF EXISTS test_conn_check") + print(f"DROP TABLE done in {time.time() - t1:.2f}s") + +print(f"Total: {time.time() - t0:.2f}s") \ No newline at end of file diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index b0c2f497d..a8b921296 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -231,11 +231,15 @@ def new( # Include urllib3's current state in our __init__ params databricks_init_params["urllib3_kwargs"].update(**urllib3_init_params) # type: ignore[attr-defined] - return type(self).__private_init__( + new_instance = type(self).__private_init__( retry_start_time=self._retry_start_time, command_type=self.command_type, **databricks_init_params, ) + # Carry profiling state across retries + new_instance.thrift_method_name = getattr(self, "thrift_method_name", None) + new_instance.last_sql_statement = getattr(self, "last_sql_statement", None) + return new_instance @property def command_type(self) -> Optional[CommandType]: @@ -294,9 +298,16 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: else: proposed_wait = self.get_backoff_time() - proposed_wait = max(proposed_wait, self.delay_max) + proposed_wait = min(proposed_wait, self.delay_max) self.check_proposed_wait(proposed_wait) - logger.debug(f"Retrying after {proposed_wait} seconds") + logger.info( + "[PROFILE] urllib3_retry sleep=%.1fs, command=%s, method=%s, sql=%s, retry_after=%s", + proposed_wait, + self.command_type and self.command_type.value, + getattr(self, "thrift_method_name", 
"unknown"), + getattr(self, "last_sql_statement", None), + retry_after, + ) time.sleep(proposed_wait) return True @@ -358,6 +369,14 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if status_code // 100 <= 3: return False, "2xx/3xx codes are not retried" + logger.info( + "[PROFILE] should_retry: status=%d, command=%s, method=%s, sql=%s, evaluating_retry", + status_code, + self.command_type and self.command_type.value, + getattr(self, "thrift_method_name", "unknown"), + getattr(self, "last_sql_statement", None), + ) + if status_code == 400: return ( False, @@ -416,6 +435,13 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: logger.debug( f"This request should be retried: {self.command_type and self.command_type.value}" ) + logger.info( + "[PROFILE] should_retry: status=%d, command=%s, method=%s, sql=%s, decision=True, reason=default_retry_policy", + status_code, + self.command_type and self.command_type.value, + getattr(self, "thrift_method_name", "unknown"), + getattr(self, "last_sql_statement", None), + ) return ( True, "Failed requests are retried by default per configured DatabricksRetryPolicy", diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e23f3389b..171ef4f0d 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -354,6 +354,16 @@ def _handle_request_error(self, error_info, attempt, elapsed): error_info.retry_delay, full_error_info_context ) ) + logger.info( + "[PROFILE] %s retry sleep=%.1fs, attempt=%d/%d, elapsed=%.1fs/%ds, http_code=%s", + error_info.method, + error_info.retry_delay, + attempt, + max_attempts, + elapsed, + max_duration_s, + error_info.http_code, + ) time.sleep(error_info.retry_delay) # FUTURE: Consider moving to https://github.com/litl/backoff or @@ -410,6 +420,14 @@ def attempt_request(attempt): logger.debug("Sending request: {}()".format(this_method_name)) unsafe_logger.debug("Sending request: {}".format(request)) + # Always set the method name and SQL text for profiling + if hasattr(self._transport, "retry_policy") and self._transport.retry_policy: + self._transport.retry_policy.thrift_method_name = this_method_name + sql_statement = getattr(request, "statement", None) + self._transport.retry_policy.last_sql_statement = ( + sql_statement[:200] if sql_statement else None + ) + # These three lines are no-ops if the v3 retry policy is not in use if self.enable_v3_retries: this_command_type = CommandType.get(this_method_name) @@ -506,6 +524,13 @@ def attempt_request(attempt): # use index-1 counting for logging/human consistency for attempt in range(1, max_attempts + 1): + logger.info( + "[PROFILE] %s attempt %d/%d (elapsed=%.3fs)", + getattr(method, "__name__", "unknown"), + attempt, + max_attempts, + get_elapsed(), + ) # We have a lock here because .cancel can be called from a separate thread. 
# We do not want threads to be simultaneously sharing the Thrift Transport # because we use its state to determine retries @@ -515,7 +540,12 @@ def attempt_request(attempt): # conditions: success, non-retry-able, no-attempts-left, no-time-left, delay+retry if not isinstance(response_or_error_info, RequestErrorInfo): - # log nothing here, presume that main request logging covers + logger.info( + "[PROFILE] %s succeeded on attempt %d in %.3fs", + getattr(method, "__name__", "unknown"), + attempt, + elapsed, + ) response = response_or_error_info ThriftDatabricksClient._check_response_for_error(response, self._host) return response @@ -1059,6 +1089,14 @@ def execute_command( ) resp = self.make_request(self._client.ExecuteStatement, req) + if resp.operationHandle: + _cmd_id = CommandId.from_thrift_handle(resp.operationHandle) + logger.info( + "[PROFILE] ExecuteStatement statement_id=%s, sql=%s", + _cmd_id, + operation[:200] if operation else None, + ) + if async_op: self._handle_execute_response_async(resp, cursor) return None
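For reference, the `[PROFILE]` line shapes emitted by the instrumentation above can be checked offline against the named-group regexes the report generators use. A minimal sketch, assuming only the format strings shown in this diff; the sample attempt counts, timings, and SQL text are invented:

```python
# Offline check: report-generator regexes vs. sample [PROFILE] lines.
# Line shapes follow the format strings added to thrift_backend.py and retry.py;
# the concrete values in `samples` are invented.
import re

ATTEMPT_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) attempt (?P<attempt>\d+)/(?P<max>\d+)")
SUCCESS_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) succeeded on attempt (?P<attempt>\d+) in (?P<duration>[0-9.]+)s")
SHOULD_RETRY_RE = re.compile(r"\[PROFILE\] should_retry: status=(?P<status>\d+), command=(?P<cmd>[^,]+),")
RETRY_SLEEP_RE = re.compile(r"\[PROFILE\] (?P<cmd>\w+) retry sleep=(?P<sleep>[0-9.]+)s, attempt=(?P<attempt>\d+)/(?P<max>\d+)")

samples = [
    "[PROFILE] ExecuteStatement attempt 1/30 (elapsed=0.000s)",
    "[PROFILE] ExecuteStatement succeeded on attempt 2 in 1.204s",
    "[PROFILE] ExecuteStatement retry sleep=5.0s, attempt=1/30, elapsed=1.2s/900s, http_code=503",
    "[PROFILE] should_retry: status=503, command=ExecuteStatement, method=ExecuteStatement, sql=SELECT 1, evaluating_retry",
]

for line in samples:
    # Same match order as the report generators: first pattern wins.
    for etype, regex in [("attempt", ATTEMPT_RE), ("success", SUCCESS_RE),
                         ("should_retry", SHOULD_RETRY_RE), ("retry_sleep", RETRY_SLEEP_RE)]:
        m = regex.search(line)
        if m:
            print(etype, m.groupdict())  # e.g. attempt {'cmd': 'ExecuteStatement', 'attempt': '1', 'max': '30'}
            break
```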
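The `percentile` helper shared by the scripts interpolates linearly between the two nearest ranks instead of snapping to a single sample. A quick worked check of that behavior, using invented data:

```python
def percentile(data, p):
    # Same linear-interpolation definition as in the profiling scripts.
    if not data:
        return 0.0
    s = sorted(data)
    k = (len(s) - 1) * (p / 100.0)
    f = int(k)
    c = f + 1
    if c >= len(s):
        return s[f]
    return s[f] + (k - f) * (s[c] - s[f])

# Worked check with invented values:
assert percentile([10, 20, 30, 40], 50) == 25.0               # k = 1.5, midway between 20 and 30
assert abs(percentile([10, 20, 30, 40], 99) - 39.7) < 1e-9    # k = 2.97, close to the max
```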
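Each `*_data.jsonl` line serializes the per-operation dict built in `worker()`. A representative record from the read scripts, shown as the Python dict before `json.dumps`; all values here, including the statement ID format, are invented:

```python
# Shape of one *_data.jsonl record from the read scripts (invented values;
# json.dumps renders True/None as true/null in the actual file).
record = {
    "table": "table7",
    "thread_id": 3,
    "select_latency_ms": 812.44,
    "select_success": True,
    "select_error_type": None,
    "select_error_message": None,
    "select_rows": 128,                # one row per (column, tag) pair
    "select_statement_id": "01ef-...", # hypothetical ID format
    "timestamp": 4321.987,             # time.perf_counter() value, not epoch time
    "iteration": 2,
}
```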