Skip to content

Commit

Permalink
Update impl to use maybe_cast_pointwise_result instead of maybe_cast_to_pyarrow_array
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kei committed Apr 29, 2024
1 parent 3b6696b commit 680e238
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 83 deletions.
75 changes: 1 addition & 74 deletions pandas/core/dtypes/cast.py
Expand Up @@ -478,40 +478,6 @@ def maybe_cast_pointwise_result(
return result


def maybe_cast_to_pyarrow_result(result: ArrayLike, obj_dtype: DtypeObj) -> ArrayLike:
    """
    Try casting result of a pointwise operation to its pyarrow dtype
    and arrow extension array if appropriate. If not possible,
    returns np.ndarray.

    The returned array always has the same length as ``result``.

    Parameters
    ----------
    result : array-like
        Result to cast.
    obj_dtype : DtypeObj
        Dtype of the object the operation was applied to. It is used as the
        cast target when ``result`` holds no non-NA values to infer from.

    Returns
    -------
    result : array-like
        result maybe casted to the dtype.
    """
    from pandas.core.construction import array as pd_array

    # maybe_convert_objects is unable to detect NA as nan
    # (detects it as object instead), so the dtype is inferred from the
    # non-NA values only.
    stripped_result = result[~isna(result)]

    if stripped_result.size == 0:
        # Nothing to infer a dtype from. Cast the full result rather than
        # the stripped one so that an all-NA result keeps its length.
        return maybe_cast_pointwise_result(
            lib.maybe_convert_objects(result, try_float=False),
            obj_dtype,
            numeric_only=True,
        )

    npvalues = lib.maybe_convert_objects(stripped_result, try_float=False)

    # np.ComplexWarning was removed from the top-level namespace in NumPy 2.0
    # (it lives in np.exceptions); resolve it up front so that evaluating the
    # except clause below cannot itself raise AttributeError.
    complex_warning = getattr(np, "ComplexWarning", None)
    if complex_warning is None:
        complex_warning = np.exceptions.ComplexWarning

    try:
        dtype = convert_dtypes(npvalues, dtype_backend="pyarrow")
        out = pd_array(result, dtype=dtype)
    except (TypeError, ValueError, complex_warning):
        # Fall back to a plain ndarray built from the full result; the
        # NA-stripped values would be shorter than the input.
        out = lib.maybe_convert_objects(result, try_float=False)
    return out


def _maybe_cast_to_extension_array(
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
) -> ArrayLike:
Expand Down Expand Up @@ -1061,7 +1027,6 @@ def convert_dtypes(
np.dtype, or ExtensionDtype
"""
inferred_dtype: str | DtypeObj
orig_inferred_dtype = None

if (
convert_string or convert_integer or convert_boolean or convert_floating
Expand All @@ -1070,7 +1035,6 @@ def convert_dtypes(
inferred_dtype = lib.infer_dtype(input_array)
else:
inferred_dtype = input_array.dtype
orig_inferred_dtype = inferred_dtype

if is_string_dtype(inferred_dtype):
if not convert_string or inferred_dtype == "bytes":
Expand Down Expand Up @@ -1168,8 +1132,7 @@ def convert_dtypes(
elif isinstance(inferred_dtype, StringDtype):
base_dtype = np.dtype(str)
else:
base_dtype = _infer_pyarrow_dtype(input_array, orig_inferred_dtype)

base_dtype = inferred_dtype
if (
base_dtype.kind == "O" # type: ignore[union-attr]
and input_array.size > 0
Expand All @@ -1180,10 +1143,8 @@ def convert_dtypes(
pa_type = pa.null()
else:
pa_type = to_pyarrow_type(base_dtype)

if pa_type is not None:
inferred_dtype = ArrowDtype(pa_type)

elif dtype_backend == "numpy_nullable" and isinstance(inferred_dtype, ArrowDtype):
# GH 53648
inferred_dtype = _arrow_dtype_mapping()[inferred_dtype.pyarrow_dtype]
Expand All @@ -1193,40 +1154,6 @@ def convert_dtypes(
return inferred_dtype # type: ignore[return-value]


def _infer_pyarrow_dtype(
    input_array: ArrayLike,
    inferred_dtype: str,
) -> DtypeObj:
    """
    Pick a pyarrow-backed dtype for ``input_array`` from its inferred kind.

    Parameters
    ----------
    input_array : array-like
        Values whose dtype is being determined.
    inferred_dtype : str
        Inferred kind of the values. ``"date"``, ``"time"`` and ``"bytes"``
        map to fixed pyarrow types; ``"decimal"`` is resolved by letting
        pyarrow infer the type from the values themselves.

    Returns
    -------
    DtypeObj
        An ``ArrowDtype`` for the recognised kinds, otherwise
        ``input_array.dtype`` (also used when pyarrow cannot convert
        decimal values).
    """
    import pyarrow as pa

    # Kinds with a fixed pyarrow type, checked in order.
    fixed_types = (
        ("date", pa.date32),
        ("time", lambda: pa.time64("us")),
        ("bytes", pa.binary),
    )
    for kind, make_pa_type in fixed_types:
        if inferred_dtype == kind:
            return ArrowDtype(make_pa_type())

    if not (inferred_dtype == "decimal"):
        return input_array.dtype

    # Decimal precision/scale depend on the data, so let pyarrow infer them.
    conversion_errors = (
        TypeError,
        ValueError,
        pa.ArrowInvalid,
        pa.ArrowMemoryError,
        pa.ArrowNotImplementedError,
    )
    try:
        return ArrowDtype(pa.array(input_array).type)
    except conversion_errors:
        return input_array.dtype


def maybe_infer_to_datetimelike(
value: npt.NDArray[np.object_],
) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:
Expand Down
21 changes: 12 additions & 9 deletions pandas/core/groupby/ops.py
Expand Up @@ -36,7 +36,6 @@

from pandas.core.dtypes.cast import (
maybe_cast_pointwise_result,
maybe_cast_to_pyarrow_result,
maybe_downcast_to_dtype,
)
from pandas.core.dtypes.common import (
Expand All @@ -46,6 +45,7 @@
ensure_uint64,
is_1d_only_ea_dtype,
)
from pandas.core.dtypes.dtypes import ArrowDtype
from pandas.core.dtypes.missing import (
isna,
maybe_fill,
Expand Down Expand Up @@ -917,21 +917,24 @@ def agg_series(
"""

result = self._aggregate_series_pure_python(obj, func)
npvalues = lib.maybe_convert_objects(result, try_float=False)

if isinstance(obj._values, ArrowExtensionArray):
return maybe_cast_to_pyarrow_result(result, obj.dtype)
out = maybe_cast_pointwise_result(
npvalues, obj.dtype, numeric_only=True, same_dtype=False
)
import pyarrow as pa

if isinstance(out.dtype, ArrowDtype) and pa.types.is_struct(
out.dtype.pyarrow_dtype
):
out = npvalues

if not isinstance(obj._values, np.ndarray) and not isinstance(
obj._values, ArrowExtensionArray
):
elif not isinstance(obj._values, np.ndarray):
# we can preserve a little bit more aggressively with EA dtype
# because maybe_cast_pointwise_result will do a try/except
# with _from_sequence. NB we are assuming here that _from_sequence
# is sufficiently strict that it casts appropriately.
preserve_dtype = True

npvalues = lib.maybe_convert_objects(result, try_float=False)
if preserve_dtype:
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
else:
out = npvalues
Expand Down

0 comments on commit 680e238

Please sign in to comment.