Skip to content
5 changes: 5 additions & 0 deletions crates/core/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ impl PyExpr {
expr.into()
}

pub fn try_cast(&self, to: PyArrowType<DataType>) -> PyExpr {
let expr = Expr::TryCast(TryCast::new(Box::new(self.expr.clone()), to.0));
expr.into()
}

#[pyo3(signature = (low, high, negated=false))]
pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr {
let expr = Expr::Between(Between::new(
Expand Down
16 changes: 16 additions & 0 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,19 @@ expr_fn_vec!(named_struct);
expr_fn!(from_unixtime, unixtime);
expr_fn!(arrow_typeof, arg_1);
expr_fn!(arrow_cast, arg_1 datatype);
expr_fn!(arrow_try_cast, arg_1 datatype);
expr_fn!(arrow_field, arg_1);
#[pyfunction]
#[pyo3(signature = (arg_1, reference, *, try_cast = false))]
fn cast_to_type(arg_1: PyExpr, reference: PyExpr, try_cast: bool) -> PyExpr {
if try_cast {
functions::expr_fn::try_cast_to_type(arg_1.into(), reference.into()).into()
} else {
functions::expr_fn::cast_to_type(arg_1.into(), reference.into()).into()
}
}
expr_fn_vec!(arrow_metadata);
expr_fn_vec!(with_metadata);
expr_fn!(union_tag, arg1);
expr_fn!(random);

Expand Down Expand Up @@ -962,7 +974,11 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
m.add_wrapped(wrap_pyfunction!(arrow_try_cast))?;
m.add_wrapped(wrap_pyfunction!(arrow_field))?;
m.add_wrapped(wrap_pyfunction!(cast_to_type))?;
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
m.add_wrapped(wrap_pyfunction!(with_metadata))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
m.add_wrapped(wrap_pyfunction!(asinh))?;
Expand Down
22 changes: 22 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,28 @@ def cast(self, to: pa.DataType[Any] | type) -> Expr:

return Expr(self.expr.cast(to))

def try_cast(self, to: pa.DataType[Any] | type) -> Expr:
"""Cast to a new data type, returning NULL on failure.

Like :py:meth:`cast` but produces NULL instead of erroring when the
cast cannot be performed for a given row.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": ["oops"]})
>>> result = df.select(col("a").try_cast(pa.float64()).alias("c"))
>>> result.collect_column("c")[0].as_py() is None
True
"""
if not isinstance(to, pa.DataType):
try:
to = self._to_pyarrow_types[to]
except KeyError as err:
error_msg = "Expected instance of pyarrow.DataType or builtins.type"
raise TypeError(error_msg) from err

return Expr(self.expr.try_cast(to))

def between(self, low: Any, high: Any, negated: bool = False) -> Expr:
"""Returns ``True`` if this expression is between a given range.

Expand Down
128 changes: 128 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@
"arrays_overlap",
"arrays_zip",
"arrow_cast",
"arrow_field",
"arrow_metadata",
"arrow_try_cast",
"arrow_typeof",
"ascii",
"asin",
Expand All @@ -138,6 +140,7 @@
"btrim",
"cardinality",
"case",
"cast_to_type",
"cbrt",
"ceil",
"char_length",
Expand Down Expand Up @@ -368,6 +371,7 @@
"var_sample",
"version",
"when",
"with_metadata",
]


Expand Down Expand Up @@ -2930,6 +2934,95 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
return Expr(f.arrow_cast(expr.expr, data_type.expr))


def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
"""Casts an expression to a specified data type, returning NULL on failure.

Like :py:func:`arrow_cast` but produces NULL instead of erroring when the
cast cannot be performed. The ``data_type`` may be a string in DataFusion
type syntax (for example ``"Float64"``), a ``pyarrow.DataType``, or an
``Expr`` of string type.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": ["oops"]})
>>> result = df.select(
... dfn.functions.arrow_try_cast(dfn.col("a"), "Float64").alias("c")
... )
>>> result.collect_column("c")[0].as_py() is None
True

>>> result = df.select(
... dfn.functions.arrow_try_cast(
... dfn.col("a"), data_type=pa.float64()
... ).alias("c")
... )
>>> result.collect_column("c")[0].as_py() is None
True
"""
if isinstance(data_type, pa.DataType):
return expr.try_cast(data_type)
if isinstance(data_type, str):
data_type = Expr.string_literal(data_type)
return Expr(f.arrow_try_cast(expr.expr, data_type.expr))


def arrow_field(expr: Expr) -> Expr:
"""Returns the Arrow field information of an expression as a struct.

The returned struct contains the field's name, data type, nullability,
and metadata.

Examples:
>>> field = pa.field("val", pa.int64(), metadata={"k": "v"})
>>> schema = pa.schema([field])
>>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema)
>>> ctx = dfn.SessionContext()
>>> df = ctx.create_dataframe([[batch]])
>>> result = df.select(
... dfn.functions.arrow_field(dfn.col("val")).alias("f")
... )
>>> out = result.collect_column("f")[0].as_py()
>>> out["name"], out["data_type"], out["nullable"], out["metadata"]
('val', 'Int64', True, [('k', 'v')])
"""
return Expr(f.arrow_field(expr.expr))


def cast_to_type(value: Expr, type_ref: Expr, *, try_cast: bool = False) -> Expr:
"""Casts ``value`` to the data type of ``type_ref``.

Only the *type* of ``type_ref`` is used; its value is ignored. This is
useful when the target type comes from another column or expression
rather than being known up-front. When ``try_cast=True``, casts that
fail produce NULL instead of erroring.

If the target type is known statically, prefer :py:func:`arrow_cast`
(or :py:func:`arrow_try_cast` for the NULL-on-failure variant) and
pass a type string or ``pyarrow.DataType`` directly.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1], "b": [1.0]})
>>> result = df.select(
... dfn.functions.cast_to_type(
... dfn.col("a"), dfn.col("b")
... ).alias("c")
... )
>>> result.collect_column("c")[0].as_py()
1.0

>>> df = ctx.from_pydict({"a": ["oops"], "b": [1.0]})
>>> result = df.select(
... dfn.functions.cast_to_type(
... dfn.col("a"), dfn.col("b"), try_cast=True
... ).alias("c")
... )
>>> result.collect_column("c")[0].as_py() is None
True
"""
return Expr(f.cast_to_type(value.expr, type_ref.expr, try_cast=try_cast))


def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
"""Returns the metadata of the input expression.

Expand Down Expand Up @@ -2963,6 +3056,41 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
return Expr(f.arrow_metadata(expr.expr, key.expr))


def with_metadata(expr: Expr, metadata: dict[str, str]) -> Expr:
"""Attaches Arrow field metadata (key/value pairs) to the input expression.

This is the inverse of :py:func:`arrow_metadata`. Existing metadata on the
input field is preserved; new keys overwrite on collision. Keys must be
non-empty strings; empty values are allowed.

An empty ``metadata`` dict is a no-op and returns the input expression
unchanged. Empty keys raise :py:class:`ValueError`.

Examples:
>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1]})
>>> result = df.select(
... dfn.functions.with_metadata(
... dfn.col("a"), {"unit": "ms"}
... ).alias("a")
... )
>>> result.select(
... dfn.functions.arrow_metadata(dfn.col("a"), "unit").alias("u")
... ).collect_column("u")[0].as_py()
'ms'
"""
if not metadata:
return expr
args = [expr.expr]
for k, v in metadata.items():
if not k:
msg = "with_metadata keys must be non-empty strings"
raise ValueError(msg)
args.append(Expr.string_literal(k).expr)
args.append(Expr.string_literal(v).expr)
return Expr(f.with_metadata(*args))


def get_field(expr: Expr, *names: Expr | str) -> Expr:
"""Extracts a (possibly nested) field from a struct or map by name.

Expand Down
99 changes: 81 additions & 18 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,30 +1299,93 @@ def test_make_time(df):
assert result.column(0)[0].as_py() == time(12, 30)


def test_arrow_cast(df):
df = df.select(
f.arrow_cast(column("b"), "Float64").alias("b_as_float"),
f.arrow_cast(column("b"), "Int32").alias("b_as_int"),
@pytest.mark.parametrize("cast_fn", [f.arrow_cast, f.arrow_try_cast])
@pytest.mark.parametrize(
("data_type", "expected"),
[
("Float64", pa.array([4.0, 5.0, 6.0], type=pa.float64())),
("Int32", pa.array([4, 5, 6], type=pa.int32())),
(pa.float64(), pa.array([4.0, 5.0, 6.0], type=pa.float64())),
(pa.int32(), pa.array([4, 5, 6], type=pa.int32())),
(pa.string(), pa.array(["4", "5", "6"], type=pa.string())),
],
)
def test_arrow_cast_variants(df, cast_fn, data_type, expected):
"""arrow_cast / arrow_try_cast accept str and pyarrow target types."""
result = df.select(cast_fn(column("b"), data_type).alias("c")).collect()[0]
assert result.column(0) == expected


@pytest.mark.parametrize("data_type", ["Float64", pa.float64()])
def test_arrow_try_cast_null_on_failure(data_type):
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array(["1.5", "oops", "3"])], names=["s"])
df = ctx.create_dataframe([[batch]])

result = df.select(f.arrow_try_cast(column("s"), data_type).alias("c")).collect()[0]

assert result.column(0).to_pylist() == [1.5, None, 3.0]


def test_arrow_field():
ctx = SessionContext()
field = pa.field("val", pa.int64(), metadata={"k": "v"})
schema = pa.schema([field])
batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema)
df = ctx.create_dataframe([[batch]])

out = (
df.select(f.arrow_field(column("val")).alias("f"))
.collect_column("f")[0]
.as_py()
)
result = df.collect()
assert len(result) == 1
result = result[0]
assert out == {
"name": "val",
"data_type": "Int64",
"nullable": True,
"metadata": [("k", "v")],
}

assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())

@pytest.mark.parametrize(
("values", "try_cast", "expected"),
[
(pa.array([4, 5, 6]), False, [4.0, 5.0, 6.0]),
(pa.array(["oops", "2", "3"]), True, [None, 2.0, 3.0]),
],
)
def test_cast_to_type(values, try_cast, expected):
"""cast_to_type takes target type from ``type_ref``; try_cast nullifies failures."""
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
[values, pa.array([1.0, 2.0, 3.0])], names=["v", "fl"]
)
df = ctx.create_dataframe([[batch]])

def test_arrow_cast_with_pyarrow_type(df):
df = df.select(
f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"),
f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"),
f.arrow_cast(column("b"), pa.string()).alias("b_as_str"),
result = df.select(
f.cast_to_type(column("v"), column("fl"), try_cast=try_cast).alias("c")
).collect()[0]

assert result.column(0).to_pylist() == expected
assert result.column(0).type == pa.float64()


def test_with_metadata_round_trip(df):
df = df.select(f.with_metadata(column("b"), {"unit": "ms"}).alias("b"))
result = df.select(f.arrow_metadata(column("b"), "unit").alias("u")).collect_column(
"u"
)
result = df.collect()[0]
assert result[0].as_py() == "ms"


def test_with_metadata_empty_dict_noop(df):
out = df.select(f.with_metadata(column("b"), {}).alias("b")).collect()[0]
assert out.column(0) == pa.array([4, 5, 6])


assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string())
def test_with_metadata_empty_key_raises():
with pytest.raises(ValueError, match="non-empty"):
f.with_metadata(column("b"), {"": "v"})


def test_case(df):
Expand Down
7 changes: 6 additions & 1 deletion skills/datafusion_python/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,12 @@ F.left(col("c_phone"), lit(2)) # prefix shortcut

**Hash**: `md5`, `sha224`, `sha256`, `sha384`, `sha512`, `digest`

**Type**: `arrow_typeof`, `arrow_cast`, `arrow_metadata`
**Type**: `arrow_typeof`, `arrow_cast`, `arrow_try_cast`, `arrow_field`,
`arrow_metadata`, `cast_to_type`, `with_metadata`

Note: ``cast_to_type(value, type_ref, *, try_cast=False)`` is the single
Python entry point for both upstream ``cast_to_type`` and ``try_cast_to_type``;
pass ``try_cast=True`` for the variant that returns NULL on failure.

**Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`,
`to_hex`, `to_char`, `uuid`, `version`, `bit_length`, `octet_length`