Skip to content

Commit

Permalink
Move calculation of the observed grouper to groupby initialisation
Browse files: browse the repository at this point in the history
  • Loading branch information
Kei committed Apr 20, 2024
1 parent baa1b28 commit 30013ee
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 36 deletions.
6 changes: 4 additions & 2 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2040,12 +2040,13 @@ def _gotitem(self, key, ndim: int, subset=None):
group_keys=self.group_keys,
observed=self.observed,
dropna=self.dropna,
observed_grouper=self.observed_grouper,
observed_exclusions=self.observed_exclusions,
)
elif ndim == 1:
if subset is None:
subset = self.obj[key]

orig_obj = self.orig_obj if not self.observed else None
return SeriesGroupBy(
subset,
self.keys,
Expand All @@ -2058,7 +2059,8 @@ def _gotitem(self, key, ndim: int, subset=None):
group_keys=self.group_keys,
observed=self.observed,
dropna=self.dropna,
orig_obj=orig_obj,
observed_grouper=self.observed_grouper,
observed_exclusions=self.observed_exclusions,
)

raise AssertionError("invalid ndim for _gotitem")
Expand Down
56 changes: 22 additions & 34 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,10 @@ class BaseGroupBy(PandasObject, SelectionMixin[NDFrameT], GroupByIndexingMixin):
"keys",
"level",
"obj",
"orig_obj",
"observed",
"sort",
"observed_grouper",
"observed_exclusions",
}

_grouper: ops.BaseGrouper
Expand Down Expand Up @@ -1107,7 +1108,8 @@ def __init__(
group_keys: bool = True,
observed: bool = False,
dropna: bool = True,
orig_obj: NDFrameT | None = None,
observed_grouper: ops.BaseGrouper | None = None,
observed_exclusions: frozenset[Hashable] | None = None,
) -> None:
self._selection = selection

Expand All @@ -1119,8 +1121,8 @@ def __init__(
self.sort = sort
self.group_keys = group_keys
self.dropna = dropna
self.orig_obj = obj if orig_obj is None else orig_obj

orig_obj = obj
if grouper is None:
grouper, exclusions, obj = get_grouper(
obj,
Expand All @@ -1136,6 +1138,21 @@ def __init__(
self._grouper = grouper
self.exclusions = frozenset(exclusions) if exclusions else frozenset()

if not observed and observed_grouper is None:
observed_grouper, observed_exclusions, _ = get_grouper(
orig_obj,
self.keys,
level=self.level,
sort=self.sort,
observed=True,
dropna=self.dropna,
)

self.observed_grouper = observed_grouper
self.observed_exclusions = (
frozenset(observed_exclusions) if observed_exclusions else frozenset()
)

def __getattr__(self, attr: str):
if attr in self._internal_names_set:
return object.__getattribute__(self, attr)
Expand Down Expand Up @@ -1887,44 +1904,15 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

grouper, exclusions, obj = get_grouper(
self.orig_obj,
self.keys,
level=self.level,
sort=self.sort,
observed=True,
dropna=self.dropna,
)
exclusions = frozenset(exclusions) if exclusions else frozenset()
obj_has_not_changed = self.orig_obj.equals(self.obj)

with (
com.temp_setattr(self, "observed", True),
com.temp_setattr(self, "_grouper", grouper),
com.temp_setattr(self, "exclusions", exclusions),
com.temp_setattr(self, "obj", obj, condition=obj_has_not_changed),
com.temp_setattr(self, "_grouper", self.observed_grouper),
com.temp_setattr(self, "exclusions", self.observed_exclusions),
):
return self._reduction_kernel_transform(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

# with com.temp_setattr(self, "as_index", True):
# # GH#49834 - result needs groups in the index for
# # _wrap_transform_fast_result
# if func in ["idxmin", "idxmax"]:
# func = cast(Literal["idxmin", "idxmax"], func)
# result = self._idxmax_idxmin(func, True, *args, **kwargs)
# else:
# if engine is not None:
# kwargs["engine"] = engine
# kwargs["engine_kwargs"] = engine_kwargs
# result = getattr(self, func)(*args, **kwargs)

# print("result with observed = False\n", result.to_string())
# r = self._wrap_transform_fast_result(result)
# print("reindexed result", r.to_string())
# return r

@final
def _reduction_kernel_transform(
self, func, *args, engine=None, engine_kwargs=None, **kwargs
Expand Down

0 comments on commit 30013ee

Please sign in to comment.