Skip to content

Commit

Permalink
make things work with pandas 2.2 (#7226)
Browse files Browse the repository at this point in the history
* make things work with pandas 2.2

* suppress a deprecation warning in pandas >= 2.1.0

---------

Co-authored-by: ptosco <paolo.tosco@novartis.com>
  • Loading branch information
greglandrum and ptosco committed Mar 10, 2024
1 parent 029b1e3 commit 462ed8f
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions rdkit/Chem/PandasPatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
orig_get_adjustment = getattr(get_adjustment_module, get_adjustment_name)
break
if orig_get_adjustment is not None:
break
break
if orig_get_adjustment is None:
log.warning("Failed to find the pandas get_adjustment() function to patch")
raise AttributeError
Expand Down Expand Up @@ -88,6 +88,13 @@
log.warning("Failed to import pandas")
raise

dataframe_applymap = pd.DataFrame.applymap
try:
if tuple(map(int, (pd.__version__.split(".")))) >= (2, 1, 0):
dataframe_applymap = pd.DataFrame.map
except:
pass

orig_to_html = getattr(to_html_class, "to_html")
pprint_thing = pandas_formats.printing.pprint_thing

Expand Down Expand Up @@ -134,7 +141,7 @@ def get_formatters(cls, df, orig_formatters):
df_subset = df.select_dtypes("object")
return {
col: cls(orig_formatters.get(col, None))
for col in df_subset.columns[df_subset.applymap(MolFormatter.is_mol).any()]
for col in df_subset.columns[dataframe_applymap(df_subset, MolFormatter.is_mol).any()]
}

def __call__(self, x):
Expand Down Expand Up @@ -195,17 +202,16 @@ def patched_to_html(self, *args, **kwargs):


def patched_get_formatter(self, i, *args, **kwargs):
if (isinstance(self.formatters, dict) and isinstance(i, int)
and i not in self.columns and hasattr(self, "tr_col_num")
and i >= self.tr_col_num):
max_cols = 0
if hasattr(self, "max_cols_fitted"):
max_cols = self.max_cols_fitted
elif hasattr(self, "max_cols_adj"):
max_cols = self.max_cols_adj
n_trunc_cols = len(self.columns) - max_cols
if n_trunc_cols > 0:
i += n_trunc_cols
if (isinstance(self.formatters, dict) and isinstance(i, int) and i not in self.columns
and hasattr(self, "tr_col_num") and i >= self.tr_col_num):
max_cols = 0
if hasattr(self, "max_cols_fitted"):
max_cols = self.max_cols_fitted
elif hasattr(self, "max_cols_adj"):
max_cols = self.max_cols_adj
n_trunc_cols = len(self.columns) - max_cols
if n_trunc_cols > 0:
i += n_trunc_cols
return orig_get_formatter(self, i, *args, **kwargs)


Expand Down Expand Up @@ -260,9 +266,10 @@ def patchPandas():
setattr(to_html_class, "to_html", patched_to_html)
if getattr(html_formatter_class, "_write_cell") != patched_write_cell:
setattr(html_formatter_class, "_write_cell", patched_write_cell)
if getattr(pandas_formats.format, get_adjustment_name) != patched_get_adjustment:
setattr(pandas_formats.format, get_adjustment_name, patched_get_adjustment)
if (orig_get_formatter and getattr(dataframeformatter_class, "_get_formatter") != patched_get_formatter):
if getattr(get_adjustment_module, get_adjustment_name) != patched_get_adjustment:
setattr(get_adjustment_module, get_adjustment_name, patched_get_adjustment)
if (orig_get_formatter
and getattr(dataframeformatter_class, "_get_formatter") != patched_get_formatter):
setattr(dataframeformatter_class, "_get_formatter", patched_get_formatter)


Expand Down

0 comments on commit 462ed8f

Please sign in to comment.