Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion KERNEL_REV
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ec2288742cbac0cd9fab50da353e8405972eefe9
b4d88220cdfad8dba1cfa89892269342ae26feeb
69 changes: 69 additions & 0 deletions src/databricks/sql/backend/kernel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def __init__(
# OAuth secret is consumed during ``auth_provider`` construction
# and isn't recoverable from the built provider.
self._auth_options = kwargs.get("auth_options") or {}
# Connector retry-tuning kwargs (the ``_retry_*`` family),
# forwarded so the kernel's own retry loop honours them. Mapped
# to the kernel ``Session``'s ``retry_*`` kwargs in
# ``open_session`` via ``_kernel_retry_kwargs``.
self._retry_options = kwargs.get("retry_options") or {}
self._catalog = catalog
self._schema = schema
# ``_use_arrow_native_complex_types`` is the connector-side
Expand Down Expand Up @@ -179,6 +184,9 @@ def open_session(
# Translate the connector's SSLOptions into the kernel's
# ``tls_*`` Session kwargs. Empty when TLS is left at defaults.
tls_kwargs = _kernel_tls_kwargs(self._ssl_options)
# Translate the connector's ``_retry_*`` kwargs into the kernel's
# ``retry_*`` Session kwargs. Empty when retry is left at defaults.
retry_kwargs = _kernel_retry_kwargs(self._retry_options)
try:
self._kernel_session = _kernel.Session(
host=self._server_hostname,
Expand All @@ -199,6 +207,7 @@ def open_session(
intervals_as_string=True,
**auth_kwargs,
**tls_kwargs,
**retry_kwargs,
)
except Exception as exc:
raise _wrap_kernel_exception("open_session", exc) from exc
Expand Down Expand Up @@ -729,6 +738,66 @@ def _kernel_tls_kwargs(ssl_options) -> Dict[str, Any]:
return kwargs


def _kernel_retry_kwargs(retry_options: Dict[str, Any]) -> Dict[str, Any]:
"""Translate the connector's ``_retry_*`` tuning into the kernel
``Session``'s ``retry_*`` kwargs.

Only knobs the caller actually set are emitted, so an untuned
connection produces an empty dict (kernel keeps its default policy:
1s/60s backoff, 6 total attempts, 900s budget).

Mappings (connector → kernel):

- ``retry_delay_min`` (float secs) → ``retry_min_wait_secs``
- ``retry_delay_max`` (float secs) → ``retry_max_wait_secs``
- ``retry_stop_after_attempts_count`` (int, **total** attempts) →
``retry_max_attempts`` (1:1 — the kernel converts to its
retries-after-first internally)
- ``retry_stop_after_attempts_duration`` (float secs) →
``retry_overall_timeout_secs``

The connector expresses delays/durations as **floats in seconds**;
the kernel takes **whole seconds** (``u64``). We round to the
nearest second, with a floor of 1s for any positive sub-second
value so a configured delay never collapses to "no wait".

``_retry_delay_default`` has no kernel counterpart and is ignored:
the kernel's no-``Retry-After`` backoff is exponential from
``retry_min_wait``, which already plays that role.
"""
kwargs: Dict[str, Any] = {}

def _secs(value: Any) -> Optional[int]:
if value is None:
return None
rounded = round(float(value))
# Never round a positive delay down to 0 — that would turn a
# configured backoff into a busy-retry. Floor at 1s.
if rounded <= 0 and float(value) > 0:
return 1
return rounded

min_wait = _secs(retry_options.get("retry_delay_min"))
if min_wait is not None:
kwargs["retry_min_wait_secs"] = min_wait

max_wait = _secs(retry_options.get("retry_delay_max"))
if max_wait is not None:
kwargs["retry_max_wait_secs"] = max_wait

count = retry_options.get("retry_stop_after_attempts_count")
if count is not None:
# Total-attempts count, forwarded 1:1; the kernel converts to
# its retries-after-first representation.
kwargs["retry_max_attempts"] = int(count)

duration = _secs(retry_options.get("retry_stop_after_attempts_duration"))
if duration is not None:
kwargs["retry_overall_timeout_secs"] = duration

return kwargs


def _read_pem_bytes(path: str, label: str) -> bytes:
"""Read a PEM file into bytes, mapping IO errors to a clear
``ProgrammingError`` that names the offending TLS option. An empty
Expand Down
18 changes: 18 additions & 0 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,23 @@ def _create_backend(
"oauth_scopes": kwargs.get("oauth_scopes"),
"credentials_provider": kwargs.get("credentials_provider"),
}
# Forward the connector's retry-tuning kwargs so the kernel's
# own retry policy honours them (the kernel owns the retry
# loop on this path). Only the keys with a kernel counterpart
# are passed; `_retry_delay_default` is intentionally omitted
# (the kernel's no-Retry-After backoff is exponential from
# its min-wait, so a flat default delay has no equivalent).
# Kernel-only; Thrift / SEA are unaffected.
kernel_retry_options = {
"retry_delay_min": kwargs.get("_retry_delay_min"),
"retry_delay_max": kwargs.get("_retry_delay_max"),
"retry_stop_after_attempts_count": kwargs.get(
"_retry_stop_after_attempts_count"
),
"retry_stop_after_attempts_duration": kwargs.get(
"_retry_stop_after_attempts_duration"
),
}
return KernelDatabricksClient(
server_hostname=server_hostname,
http_path=http_path,
Expand All @@ -185,6 +202,7 @@ def _create_backend(
schema=kwargs.get("schema"),
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
auth_options=kernel_auth_options,
retry_options=kernel_retry_options,
)

databricks_client_class: Type[DatabricksClient]
Expand Down
41 changes: 40 additions & 1 deletion tests/e2e/test_kernel_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import databricks.sql as sql
from databricks.sql.exc import DatabaseError


# Skip the whole module unless the kernel wheel is genuinely installed.
# ``pytest.importorskip`` alone isn't enough: the kernel unit tests inject a
# fake ``databricks_sql_kernel`` ModuleType into ``sys.modules`` so the
Expand Down Expand Up @@ -183,6 +182,46 @@ def test_session_configuration_round_trips(kernel_conn_params):
assert kv.get("ANSI_MODE") == "false", f"got {rows!r}"


def test_retry_params_accepted_end_to_end(kernel_conn_params):
"""The connector's `_retry_*` tuning kwargs are translated to the
kernel `Session`'s `retry_*` kwargs and accepted end-to-end. We
can't easily force a retry against a live warehouse, so this is a
smoke test: a connection configured with explicit retry params
opens and runs a query successfully (proving the kwargs reach and
are accepted by the kernel)."""
params = dict(kernel_conn_params)
params.update(
_retry_delay_min=2,
_retry_delay_max=30,
_retry_stop_after_attempts_count=4,
_retry_stop_after_attempts_duration=120,
)
with sql.connect(**params) as c:
with c.cursor() as cur:
cur.execute("SELECT 1 AS n")
assert cur.fetchall()[0][0] == 1


def test_enable_metric_view_metadata_lists_metric_view_table_type(kernel_conn_params):
"""`enable_metric_view_metadata=True` injects the
`spark.sql.thriftserver.metadata.metricview.enabled` session conf,
which the kernel now passes through (verbatim) so the server
surfaces `METRIC_VIEW` in `cursor.tables()`'s table-type column.

We assert the connection opens and `tables()` runs; the kernel
already lists `METRIC_VIEW` among its table types, and the conf
enables the server side. Not asserting a specific metric view
exists in the catalog (workspace-dependent)."""
params = dict(kernel_conn_params)
params["enable_metric_view_metadata"] = True
with sql.connect(**params) as c:
with c.cursor() as cur:
# Smoke: the conf was accepted (no SqlError on open) and a
# metadata call works with it set.
cur.tables()
cur.fetchall()


# ── Error mapping ─────────────────────────────────────────────────


Expand Down
59 changes: 59 additions & 0 deletions tests/unit/test_kernel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,3 +956,62 @@ def test_empty_ca_file_raises_programming_error(self, tmp_path):
ca.write_bytes(b" \n")
with pytest.raises(ProgrammingError, match="is empty"):
kernel_client._kernel_tls_kwargs(SSLOptions(tls_trusted_ca_file=str(ca)))


# ---------------------------------------------------------------------------
# Retry translation: connector _retry_* -> kernel Session retry kwargs.
# ---------------------------------------------------------------------------


class TestKernelRetryKwargs:
"""``_kernel_retry_kwargs`` maps the connector's ``_retry_*`` tuning
onto the kernel ``Session``'s ``retry_*`` kwargs, rounding float
seconds to whole seconds and forwarding the total-attempts count
1:1 (the kernel does the retries-after-first conversion)."""

def test_empty_options_emit_no_kwargs(self):
assert kernel_client._kernel_retry_kwargs({}) == {}

def test_all_options_mapped(self):
out = kernel_client._kernel_retry_kwargs(
{
"retry_delay_min": 2.0,
"retry_delay_max": 90.0,
"retry_stop_after_attempts_count": 10,
"retry_stop_after_attempts_duration": 600.0,
}
)
assert out == {
"retry_min_wait_secs": 2,
"retry_max_wait_secs": 90,
"retry_max_attempts": 10,
"retry_overall_timeout_secs": 600,
}

def test_count_forwarded_one_to_one(self):
# Total-attempts count is passed verbatim; the kernel converts
# to retries-after-first internally (so 1 means a single attempt).
out = kernel_client._kernel_retry_kwargs({"retry_stop_after_attempts_count": 1})
assert out == {"retry_max_attempts": 1}

def test_float_seconds_rounded(self):
out = kernel_client._kernel_retry_kwargs(
{"retry_delay_min": 2.4, "retry_delay_max": 2.6}
)
assert out == {"retry_min_wait_secs": 2, "retry_max_wait_secs": 3}

def test_subsecond_delay_floored_to_one(self):
# A positive sub-second delay (the connector allows 0.1) must
# not round down to 0 — that would turn backoff into busy-retry.
out = kernel_client._kernel_retry_kwargs({"retry_delay_min": 0.1})
assert out == {"retry_min_wait_secs": 1}

def test_only_set_keys_emitted(self):
out = kernel_client._kernel_retry_kwargs({"retry_delay_max": 30.0})
assert out == {"retry_max_wait_secs": 30}

def test_retry_delay_default_has_no_mapping(self):
# _retry_delay_default isn't forwarded by session.py and isn't a
# recognised key here — it has no kernel equivalent.
out = kernel_client._kernel_retry_kwargs({"retry_delay_default": 5.0})
assert out == {}
58 changes: 58 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,61 @@ def test_use_kernel_pat_builds_minimal_access_token_provider(self):
# PAT path: a minimal AccessTokenAuthProvider, not the
# federation-wrapped connector provider.
assert isinstance(sess.auth_provider, AccessTokenAuthProvider)


class TestKernelRetryOptionsThreading:
"""The connector's ``_retry_*`` kwargs must be forwarded into the
kernel client's ``retry_options`` on the use_kernel path (the kernel
owns the retry loop). Stubs ``_create_backend`` so the kernel client
is never imported — wheel-independent — and inspects the
``retry_options`` dict session.py builds by patching the kernel
client and capturing its call args.
"""

PACKAGE = "databricks.sql"

def test_retry_kwargs_threaded_into_kernel_client(self):
import sys
import types

# The lazy ``from databricks.sql.backend.kernel.client import
# KernelDatabricksClient`` triggers ``import databricks_sql_kernel``
# at module load; the unit-test job has no Rust wheel, so inject
# a fake module (scoped via patch.dict) before connect() runs.
fake = types.ModuleType("databricks_sql_kernel")
fake.KernelError = type("KernelError", (Exception,), {})
fake.Session = MagicMock()

# Patch the kernel client class (imported lazily inside
# _create_backend) and the provider builder; capture the kwargs
# session.py passes to the kernel client.
with patch.dict(sys.modules, {"databricks_sql_kernel": fake}), patch(
"databricks.sql.backend.kernel.client.KernelDatabricksClient"
) as mock_kernel_client, patch(
"%s.session.get_python_sql_connector_auth_provider" % self.PACKAGE
):
instance = mock_kernel_client.return_value
instance.open_session.return_value = SessionId(
BackendType.SEA, "sess-id", None
)

conn = databricks.sql.connect(
server_hostname="foo",
http_path="/sql/1.0/warehouses/abc",
use_kernel=True,
access_token="dapi-xyz",
enable_telemetry=False,
_retry_delay_min=2.0,
_retry_delay_max=90.0,
_retry_stop_after_attempts_count=10,
_retry_stop_after_attempts_duration=600.0,
)
try:
_, kwargs = mock_kernel_client.call_args
opts = kwargs["retry_options"]
assert opts["retry_delay_min"] == 2.0
assert opts["retry_delay_max"] == 90.0
assert opts["retry_stop_after_attempts_count"] == 10
assert opts["retry_stop_after_attempts_duration"] == 600.0
finally:
conn.close()
Loading