Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

make things work with pandas 2.2 #7226

Merged
merged 3 commits into from
Mar 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 23 additions & 16 deletions rdkit/Chem/PandasPatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
orig_get_adjustment = getattr(get_adjustment_module, get_adjustment_name)
break
if orig_get_adjustment is not None:
break
break
if orig_get_adjustment is None:
log.warning("Failed to find the pandas get_adjustment() function to patch")
raise AttributeError
Expand Down Expand Up @@ -88,6 +88,13 @@
log.warning("Failed to import pandas")
raise

dataframe_applymap = pd.DataFrame.applymap
try:
if tuple(map(int, (pd.__version__.split(".")))) >= (2, 1, 0):
dataframe_applymap = pd.DataFrame.map
except:
pass

orig_to_html = getattr(to_html_class, "to_html")
pprint_thing = pandas_formats.printing.pprint_thing

Expand Down Expand Up @@ -134,7 +141,7 @@ def get_formatters(cls, df, orig_formatters):
df_subset = df.select_dtypes("object")
return {
col: cls(orig_formatters.get(col, None))
for col in df_subset.columns[df_subset.applymap(MolFormatter.is_mol).any()]
for col in df_subset.columns[dataframe_applymap(df_subset, MolFormatter.is_mol).any()]
}

def __call__(self, x):
Expand Down Expand Up @@ -195,17 +202,16 @@ def patched_to_html(self, *args, **kwargs):


def patched_get_formatter(self, i, *args, **kwargs):
if (isinstance(self.formatters, dict) and isinstance(i, int)
and i not in self.columns and hasattr(self, "tr_col_num")
and i >= self.tr_col_num):
max_cols = 0
if hasattr(self, "max_cols_fitted"):
max_cols = self.max_cols_fitted
elif hasattr(self, "max_cols_adj"):
max_cols = self.max_cols_adj
n_trunc_cols = len(self.columns) - max_cols
if n_trunc_cols > 0:
i += n_trunc_cols
if (isinstance(self.formatters, dict) and isinstance(i, int) and i not in self.columns
and hasattr(self, "tr_col_num") and i >= self.tr_col_num):
max_cols = 0
if hasattr(self, "max_cols_fitted"):
max_cols = self.max_cols_fitted
elif hasattr(self, "max_cols_adj"):
max_cols = self.max_cols_adj
n_trunc_cols = len(self.columns) - max_cols
if n_trunc_cols > 0:
i += n_trunc_cols
return orig_get_formatter(self, i, *args, **kwargs)


Expand Down Expand Up @@ -260,9 +266,10 @@ def patchPandas():
setattr(to_html_class, "to_html", patched_to_html)
if getattr(html_formatter_class, "_write_cell") != patched_write_cell:
setattr(html_formatter_class, "_write_cell", patched_write_cell)
if getattr(pandas_formats.format, get_adjustment_name) != patched_get_adjustment:
setattr(pandas_formats.format, get_adjustment_name, patched_get_adjustment)
if (orig_get_formatter and getattr(dataframeformatter_class, "_get_formatter") != patched_get_formatter):
if getattr(get_adjustment_module, get_adjustment_name) != patched_get_adjustment:
setattr(get_adjustment_module, get_adjustment_name, patched_get_adjustment)
if (orig_get_formatter
and getattr(dataframeformatter_class, "_get_formatter") != patched_get_formatter):
setattr(dataframeformatter_class, "_get_formatter", patched_get_formatter)


Expand Down