Convert result of group by agg to pyarrow if input is pyarrow #58129

Open
wants to merge 38 commits into
base: main
Changes from 35 commits
9faa460
Set preserve_dtype flag for bool type only when result is also bool
Apr 1, 2024
969d5b1
Update implementation to change type to pyarrow only
Apr 2, 2024
66114f3
Change import order
Apr 2, 2024
b0290ed
Convert numpy array to pandas representation of pyarrow array
Apr 3, 2024
20c8fa0
Add tests
Apr 3, 2024
97b3d54
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 3, 2024
932d737
Change pyarrow to optional import in agg_series() method
Apr 5, 2024
82ddeb5
Separate tests
Apr 5, 2024
d510052
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 5, 2024
62a31d9
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 8, 2024
a54bf58
Revert to old implementation
Apr 8, 2024
64330f0
Update implementation to use pyarrow array method
Apr 8, 2024
0647711
Update test_aggregate tests
Apr 8, 2024
affde38
Move pyarrow import to top of method
Apr 8, 2024
842f561
Update according to pr comments
Apr 12, 2024
93b5bf3
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 20, 2024
6f35c0e
Fallback convert to input dtype if output is all nan or empty array
Apr 20, 2024
abd0adf
Strip na values when inferring pyarrow dtype
Apr 20, 2024
bebc442
Update tests to check expected inferred dtype instead of input dtype
Apr 20, 2024
bb6343b
Override test case for test_arrow.py
Apr 21, 2024
3a3f2a2
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 21, 2024
6dc40f5
Empty commit to trigger build run
Apr 21, 2024
4ef96f7
In agg series, convert to np values, then cast to pyarrow dtype, acco…
Apr 23, 2024
c6a98c0
Update tests
Apr 23, 2024
9181eaf
Update rst docs
Apr 25, 2024
612d7d0
Update impl to fix tests
Apr 25, 2024
3b6696b
Declare variable in outer scope
Apr 25, 2024
680e238
Update impl to use maybe_cast_pointwise_result instead of maybe_cast…
Apr 29, 2024
3a8597e
Fix tests with nested array
Apr 29, 2024
6496b15
Update according to pr comments
May 2, 2024
712c36a
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
May 2, 2024
e1ccef6
Preserve_dtype if argument is passed in, else don't preserve
May 7, 2024
0ce083d
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
undermyumbrella1 May 7, 2024
a1d73f5
Update tests
May 7, 2024
57845a8
Merge branch 'fix/group_by_agg_pyarrow_bool_numpy_same_type' of githu…
May 7, 2024
fa257b0
Remove redundant tests
undermyumbrella1 May 12, 2024
0a9b83f
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
undermyumbrella1 May 12, 2024
139319a
retrigger pipeline
undermyumbrella1 May 12, 2024
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
@@ -449,6 +449,7 @@ Groupby/resample/rolling
- Bug in :meth:`.DataFrameGroupBy.quantile` when ``interpolation="nearest"`` is inconsistent with :meth:`DataFrame.quantile` (:issue:`47942`)
- Bug in :meth:`.Resampler.interpolate` on a :class:`DataFrame` with non-uniform sampling and/or indices not aligning with the resulting resampled index would result in wrong interpolation (:issue:`21351`)
- Bug in :meth:`DataFrame.ewm` and :meth:`Series.ewm` when passed ``times`` and aggregation functions other than mean (:issue:`51695`)
- Bug in :meth:`DataFrameGroupBy.agg` and :meth:`SeriesGroupBy.agg` that was returning numpy dtype values when input values are pyarrow dtype values, instead of returning pyarrow dtype values. (:issue:`53030`)
- Bug in :meth:`DataFrameGroupBy.apply` that was returning a completely empty DataFrame when all return values of ``func`` were ``None`` instead of returning an empty DataFrame with the original columns and dtypes. (:issue:`57775`)
- Bug in :meth:`DataFrameGroupBy.apply` with ``as_index=False`` that was returning :class:`MultiIndex` instead of returning :class:`Index`. (:issue:`58291`)

25 changes: 18 additions & 7 deletions pandas/core/groupby/ops.py
@@ -45,12 +45,14 @@
    ensure_uint64,
    is_1d_only_ea_dtype,
)
from pandas.core.dtypes.dtypes import ArrowDtype
from pandas.core.dtypes.missing import (
    isna,
    maybe_fill,
)

from pandas.core.arrays import Categorical
from pandas.core.arrays.arrow.array import ArrowExtensionArray
from pandas.core.frame import DataFrame
from pandas.core.groupby import grouper
from pandas.core.indexes.api import (
@@ -914,20 +916,29 @@ def agg_series(
np.ndarray or ExtensionArray
"""

if not isinstance(obj._values, np.ndarray):
result = self._aggregate_series_pure_python(obj, func)
npvalues = lib.maybe_convert_objects(result, try_float=False)

if isinstance(obj._values, ArrowExtensionArray):
out = maybe_cast_pointwise_result(
npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype
)
import pyarrow as pa

if isinstance(out.dtype, ArrowDtype) and pa.types.is_struct(
out.dtype.pyarrow_dtype
):
out = npvalues
Comment on lines +941 to +944
Member

Is there a test that hits this?

Contributor Author

Resolved; the test_agg_lambda_pyarrow_struct_to_object_dtype_conversion test hits this.

Member

@jbrockmendel - I was surprised that maybe_cast_pointwise_result was giving us back Arrow dtypes we don't have EAs for. I'm thinking the logic here to prevent this should maybe go in dtypes.cast._maybe_cast_to_extension_array in a followup. Any thoughts?

Member

giving us back Arrow dtypes we don't have EAs for

Can you give an example? This confuses me.

should maybe go in dtypes.cast._maybe_cast_to_extension_array

_maybe_cast_to_extension_array is only used in maybe_cast_pointwise_result, so not a huge deal either way.

Member

import numpy as np
import pandas as pd
import pyarrow as pa

from pandas.core.dtypes.cast import maybe_cast_pointwise_result

# A dict-valued result, as an aggregation lambda might return.
arr = np.array([{"number": 1}])
result = maybe_cast_pointwise_result(
    arr,
    dtype=pd.ArrowDtype(pa.int64()),
    numeric_only=True,
    same_dtype=False,
)
print(result)
# Length: 1, dtype: struct<number: int64>[pyarrow]

Member

@jbrockmendel - sorry for the noise, I was not aware we could support struct dtypes. I think everything is okay here.

@undermyumbrella1 - why go with NumPy object dtype instead of struct dtypes here?


elif not isinstance(obj._values, np.ndarray):
# we can preserve a little bit more aggressively with EA dtype
# because maybe_cast_pointwise_result will do a try/except
# with _from_sequence. NB we are assuming here that _from_sequence
# is sufficiently strict that it casts appropriately.
preserve_dtype = True

result = self._aggregate_series_pure_python(obj, func)

npvalues = lib.maybe_convert_objects(result, try_float=False)
if preserve_dtype:
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
else:
out = npvalues

return out

@final
21 changes: 21 additions & 0 deletions pandas/tests/extension/test_arrow.py
@@ -1125,6 +1125,27 @@ def test_comp_masked_numpy(self, masked_dtype, comparison_op):
        expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
        tm.assert_series_equal(result, expected)

    def test_groupby_agg_extension(self, data_for_grouping):
Member

I think this test should behave the same as the one in the base class. If that's the case, this can be removed. Can you confirm?

Contributor Author

resolved

        # GH#38980 groupby agg on extension type fails for non-numeric types
        df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})

        expected_df = pd.DataFrame(
            {"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping}
        )
        expected = expected_df.iloc[[0, 2, 4, 7]]
        expected = expected.set_index("A")
        expected_dtype = expected["B"].dtype
        expected["B"] = expected["B"].astype(expected_dtype)

        result = df.groupby("A").agg({"B": "first"})
        tm.assert_frame_equal(result, expected)

        result = df.groupby("A").agg("first")
        tm.assert_frame_equal(result, expected)

        result = df.groupby("A").first()
        tm.assert_frame_equal(result, expected)


class TestLogicalOps:
"""Various Series and DataFrame logical ops methods."""
86 changes: 86 additions & 0 deletions pandas/tests/groupby/aggregate/test_aggregate.py
@@ -1663,3 +1663,89 @@ def func(x):
    msg = "length must not be 0"
    with pytest.raises(ValueError, match=msg):
        df.groupby("A", observed=False).agg(func)


@pytest.mark.parametrize(
    "input_dtype, output_dtype",
    [
        ("float[pyarrow]", "double[pyarrow]"),
        ("int64[pyarrow]", "int64[pyarrow]"),
        ("uint64[pyarrow]", "int64[pyarrow]"),
        ("bool[pyarrow]", "bool[pyarrow]"),
    ],
)
def test_agg_lambda_pyarrow_dtype_conversion(input_dtype, output_dtype):
    # GH#53030
    # test numpy dtype conversion back to pyarrow dtype
    # complexes, floats, ints, uints, object
    df = DataFrame(
        {
            "A": ["c1", "c2", "c3", "c1", "c2", "c3"],
            "B": pd.array([100, 200, 255, 0, 199, 40392], dtype=input_dtype),
        }
    )
    gb = df.groupby("A")
    result = gb.agg(lambda x: x.min())

    expected = DataFrame(
        {"B": pd.array([0, 199, 255], dtype=output_dtype)},
        index=Index(["c1", "c2", "c3"], name="A"),
    )
    tm.assert_frame_equal(result, expected)


def test_agg_lambda_complex128_dtype_conversion():
    # GH#53030
    df = DataFrame(
        {"A": ["c1", "c2", "c3"], "B": pd.array([100, 200, 255], "int64[pyarrow]")}
    )
    gb = df.groupby("A")
    result = gb.agg(lambda x: complex(x.sum(), x.count()))

    expected = DataFrame(
        {
            "B": pd.array(
                [complex(100, 1), complex(200, 1), complex(255, 1)], dtype="complex128"
            ),
        },
        index=Index(["c1", "c2", "c3"], name="A"),
    )
    tm.assert_frame_equal(result, expected)


def test_agg_lambda_numpy_uint64_to_pyarrow_dtype_conversion():
    # GH#53030
    df = DataFrame(
        {
            "A": ["c1", "c2", "c3"],
            "B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"),
        }
    )
    gb = df.groupby("A")
    result = gb.agg(lambda x: np.uint64(x.sum()))

    expected = DataFrame(
        {
            "B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"),
        },
        index=Index(["c1", "c2", "c3"], name="A"),
    )
    tm.assert_frame_equal(result, expected)


def test_agg_lambda_pyarrow_struct_to_object_dtype_conversion():
    # GH#53030
    df = DataFrame(
        {
            "A": ["c1", "c2", "c3"],
            "B": pd.array([100, 200, 255], dtype="int64[pyarrow]"),
        }
    )
    gb = df.groupby("A")
    result = gb.agg(lambda x: {"number": 1})

    expected = DataFrame(
        {"B": pd.array([{"number": 1}, {"number": 1}, {"number": 1}], dtype="object")},
        index=Index(["c1", "c2", "c3"], name="A"),
    )
    tm.assert_frame_equal(result, expected)
8 changes: 7 additions & 1 deletion pandas/tests/groupby/test_groupby.py
@@ -28,6 +28,7 @@
)
import pandas._testing as tm
from pandas.core.arrays import BooleanArray
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
import pandas.core.common as com

pytestmark = pytest.mark.filterwarnings("ignore:Mean of empty slice:RuntimeWarning")
@@ -2475,9 +2476,14 @@ def test_by_column_values_with_same_starting_value(dtype):
"Mood": [["happy", "sad"], "happy"],
"Credit": [2500, 900],
"Name": ["Thomas", "Thomas John"],
}
},
).set_index("Name")
if dtype == "string[pyarrow_numpy]":
import pyarrow as pa

mood_values = ArrowStringArrayNumpySemantics(pa.array(["happy", "sad"]))
expected_result["Mood"] = [mood_values, "happy"]
expected_result["Mood"] = expected_result["Mood"].astype(dtype)
tm.assert_frame_equal(result, expected_result)

