Skip to content

Commit

Permalink
Bad lasso (#6751)
Browse files Browse the repository at this point in the history
* Better separation of lines.

* Only put atoms in colour list once.

* Test.

* Hash codes.

* Response to review.

* First attempt at fixing stray line.

* Tidier.

* Squared distances.
  • Loading branch information
DavidACosgrove committed Oct 13, 2023
1 parent e3d605d commit 24c11d7
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 33 deletions.
152 changes: 130 additions & 22 deletions Code/GraphMol/MolDraw2D/DrawMolMCHLasso.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,23 @@ void DrawMolMCHLasso::extractAtomColourLists(
std::vector<DrawColour> &colours,
std::vector<std::vector<int>> &colourAtoms,
std::vector<std::vector<int>> &colourLists) const {
std::vector<boost::dynamic_bitset<>> inColourAtoms;
for (const auto &cm : mcHighlightAtomMap_) {
for (const auto &col : cm.second) {
auto cpos = std::find(colours.begin(), colours.end(), col);
if (cpos == colours.end()) {
colours.push_back(col);
colourAtoms.push_back(std::vector<int>(1, cm.first));
inColourAtoms.push_back(
boost::dynamic_bitset<>(drawMol_->getNumAtoms()));
inColourAtoms.back().set(cm.first);
} else {
colourAtoms[std::distance(colours.begin(), cpos)].push_back(cm.first);
auto ln = std::distance(colours.begin(), cpos);
// it's important that each atom is only in the list once - Github6749
if (!inColourAtoms[ln][cm.first]) {
colourAtoms[ln].push_back(cm.first);
inColourAtoms[ln].set(cm.first);
}
}
}
}
Expand Down Expand Up @@ -137,6 +146,7 @@ void DrawMolMCHLasso::drawLasso(size_t lassoNum, const RDKit::DrawColour &col,
fixIntersectingLines(arcs, lines);
fixIntersectingArcsAndLines(arcs, lines);
fixProtrudingLines(lines);
fixOrphanLines(arcs, lines);

for (auto &it : arcs) {
highlights_.push_back(std::move(it));
Expand All @@ -146,23 +156,40 @@ void DrawMolMCHLasso::drawLasso(size_t lassoNum, const RDKit::DrawColour &col,
}
}

namespace {
double getLassoWidth(const DrawMolMCH *dm, int atNum, int lassoNum) {
PRECONDITION(dm, "Needs valid DrawMolMCH")
double xrad, yrad;
dm->getAtomRadius(atNum, xrad, yrad);
// Double the area of the circles for successive lassos.
const static double rats[] = {1.0, 1.414, 2, 2.828, 4};
if (lassoNum > 4) {
// It's going to look horrible, probably, but it's a lot of lassos.
return xrad * (1 + lassoNum) * 0.75;
} else {
return xrad * rats[lassoNum];
}
}
} // namespace

// ****************************************************************************
void DrawMolMCHLasso::extractAtomArcs(
size_t lassoNum, const RDKit::DrawColour &col,
const std::vector<int> &colAtoms,
std::vector<std::unique_ptr<DrawShapeArc>> &arcs) const {
double xradius, yradius;
// an empirically derived lineWidth.
int lineWidth = 3;
bool scaleLineWidth = true;
for (auto ca : colAtoms) {
if (ca < 0 || static_cast<unsigned int>(ca) >= drawMol_->getNumAtoms()) {
// there's an error in the colour map
continue;
}
getAtomRadius(ca, xradius, yradius);
xradius += xradius * lassoNum * 0.5;
Point2D radii(xradius, xradius);
double lassoWidth = getLassoWidth(this, ca, lassoNum);
Point2D radii(lassoWidth, lassoWidth);
std::vector<Point2D> pts{atCds_[ca], radii};
DrawShapeArc *ell = new DrawShapeArc(
pts, 0.0, 360.0, drawOptions_.bondLineWidth, true, col, false, ca);
DrawShapeArc *ell = new DrawShapeArc(pts, 0.0, 360.0, lineWidth,
scaleLineWidth, col, false, ca);
arcs.emplace_back(ell);
}
}
Expand All @@ -172,27 +199,25 @@ void DrawMolMCHLasso::extractBondLines(
size_t lassoNum, const RDKit::DrawColour &col,
const std::vector<int> &colAtoms,
std::vector<std::unique_ptr<DrawShapeSimpleLine>> &lines) const {
int lineWidth = 3;
bool scaleLineWidth = true;
if (colAtoms.size() > 1) {
for (size_t i = 0U; i < colAtoms.size() - 1; ++i) {
if (colAtoms[i] < 0 ||
static_cast<unsigned int>(colAtoms[i]) >= drawMol_->getNumAtoms()) {
// there's an error in the colour map.
continue;
}
double xradiusI, yradiusI;
getAtomRadius(colAtoms[i], xradiusI, yradiusI);
xradiusI += xradiusI * lassoNum * 0.5;
auto dispI = xradiusI * 0.75;
auto lassoWidthI = getLassoWidth(this, colAtoms[i], lassoNum);
auto dispI = lassoWidthI * 0.75;
for (size_t j = i + 1; j < colAtoms.size(); ++j) {
if (colAtoms[j] < 0 ||
static_cast<unsigned int>(colAtoms[j]) >= drawMol_->getNumAtoms()) {
// there's an error in the colour map.
continue;
}
double xradiusJ, yradiusJ;
getAtomRadius(colAtoms[j], xradiusJ, yradiusJ);
xradiusJ += xradiusJ * lassoNum * 0.5;
auto dispJ = xradiusJ * 0.75;
auto lassoWidthJ = getLassoWidth(this, colAtoms[j], lassoNum);
auto dispJ = lassoWidthJ * 0.75;
auto bond = drawMol_->getBondBetweenAtoms(colAtoms[i], colAtoms[j]);
if (bond) {
if (!mcHighlightBondMap_.empty()) {
Expand All @@ -213,14 +238,13 @@ void DrawMolMCHLasso::extractBondLines(
// less than the radii of the circles (just less, so that they still
// intersect rather than hitting on the tangent)
auto mid = (p1 + p2) / 2.0;
if ((atCdsI - mid).lengthSq() < xradiusI * xradiusI) {
p1 = atCdsI + perp * xradiusI * 0.99 * m;
p2 = atCdsJ + perp * xradiusJ * 0.99 * m;
if ((atCdsI - mid).lengthSq() < lassoWidthI * lassoWidthI) {
p1 = atCdsI + perp * lassoWidthI * 0.99 * m;
p2 = atCdsJ + perp * lassoWidthJ * 0.99 * m;
}
DrawShapeSimpleLine *pl = new DrawShapeSimpleLine(
{p1, p2}, drawOptions_.bondLineWidth,
drawOptions_.scaleBondWidth, col, colAtoms[i], colAtoms[j],
bond->getIdx(), noDash);
{p1, p2}, lineWidth, scaleLineWidth, col, colAtoms[i],
colAtoms[j], bond->getIdx(), noDash);
lines.emplace_back(pl);
}
}
Expand Down Expand Up @@ -488,7 +512,8 @@ void DrawMolMCHLasso::fixIntersectingArcsAndLines(
void DrawMolMCHLasso::fixProtrudingLines(
std::vector<std::unique_ptr<DrawShapeSimpleLine>> &lines) const {
// lasso_highlight_7.svg also had the problem where two lines didn't
// intersect, but one protruded beyond the end of the other inside the lasso.
// intersect, but one protruded beyond the end of the other inside the
// lasso.
for (auto &line1 : lines) {
for (auto &line2 : lines) {
auto d1_0 = (line1->points_[0] - line2->points_[0]).length();
Expand All @@ -514,5 +539,88 @@ void DrawMolMCHLasso::fixProtrudingLines(
}
}
}

namespace {
std::pair<Point2D, Point2D> getArcEnds(const DrawShapeArc &arc) {
std::pair<Point2D, Point2D> retVal;
// for these purposes, it's always a circle, so just use the x
// radius
retVal.first.x =
arc.points_[0].x + arc.points_[1].x * cos(arc.ang1_ * M_PI / 180.0);
retVal.first.y =
arc.points_[0].y + arc.points_[1].x * sin(arc.ang1_ * M_PI / 180.0);
retVal.second.x =
arc.points_[0].x + arc.points_[1].x * cos(arc.ang2_ * M_PI / 180.0);
retVal.second.y =
arc.points_[0].y + arc.points_[1].x * sin(arc.ang2_ * M_PI / 180.0);
return retVal;
}
} // namespace
void DrawMolMCHLasso::fixOrphanLines(
std::vector<std::unique_ptr<DrawShapeArc>> &arcs,
std::vector<std::unique_ptr<DrawShapeSimpleLine>> &lines) {
// lasso_highlights_7.svg had a line close to an arc at
// one end, but not at the other. Such lines are clearly
// artifacts that need to be removed. Takes out all lines
// that aren't within a tolerance of something at both ends.
std::vector<std::pair<Point2D, Point2D>> arcEnds;
for (const auto &arc : arcs) {
arcEnds.push_back(getArcEnds(*arc));
}
// This tolerance was arrived at empirically. It's enough
// to fix lasso_highlights_7.svg without breaking it and
// lasso_highlights_6.svg. A slightly tighter tolerance
// (e.g. 0.003) causes green lines associated with bonds 9
// and 12 to be removed incorrectly in both of these.
static const double tol = 0.005;
for (auto &line1 : lines) {
bool attached0 = false, attached1 = false;
for (const auto &arcEnd : arcEnds) {
if ((line1->points_[0] - arcEnd.first).lengthSq() < tol) {
attached0 = true;
}
if ((line1->points_[1] - arcEnd.first).lengthSq() < tol) {
attached1 = true;
}
if ((line1->points_[0] - arcEnd.second).lengthSq() < tol) {
attached0 = true;
}
if ((line1->points_[1] - arcEnd.second).lengthSq() < tol) {
attached1 = true;
}
}
if (!attached0 || !attached1) {
for (auto &line2 : lines) {
if (line1 == line2 || !line1 || !line2) {
continue;
}
if ((line1->points_[0] - line2->points_[0]).lengthSq() < tol) {
attached0 = true;
}
if ((line1->points_[0] - line2->points_[1]).lengthSq() < tol) {
attached0 = true;
}
if ((line1->points_[1] - line2->points_[0]).lengthSq() < tol) {
attached1 = true;
}
if ((line1->points_[1] - line2->points_[1]).lengthSq() < tol) {
attached1 = true;
}
if (attached0 && attached1) {
break;
}
}
}
if (!attached0 || !attached1) {
line1.reset();
}
}
lines.erase(std::remove_if(
lines.begin(), lines.end(),
[](const std::unique_ptr<DrawShapeSimpleLine> &line) -> bool {
return !line;
}),
lines.end());
}
} // namespace MolDraw2D_detail
} // namespace RDKit
4 changes: 4 additions & 0 deletions Code/GraphMol/MolDraw2D/DrawMolMCHLasso.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class DrawMolMCHLasso : public DrawMolMCH {
std::vector<std::unique_ptr<DrawShapeSimpleLine>> &lines) const;
void fixProtrudingLines(
std::vector<std::unique_ptr<DrawShapeSimpleLine>> &lines) const;
// Orphan lines are ones where at least one end isn't close to the end
// of any line or arc.
void fixOrphanLines(std::vector<std::unique_ptr<DrawShapeArc>> &arcs,
std::vector<std::unique_ptr<DrawShapeSimpleLine>> &lines);
};
} // namespace MolDraw2D_detail
} // namespace RDKit
Expand Down
116 changes: 105 additions & 11 deletions Code/GraphMol/MolDraw2D/catch_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,17 +307,18 @@ static const std::map<std::string, std::hash_result_t> SVG_HASHES = {
{"github6504_2.svg", 2871662880U},
{"github6569_1.svg", 116573839U},
{"github6569_2.svg", 2367779037U},
{"lasso_highlights_1.svg", 1259621911U},
{"lasso_highlights_2.svg", 39607542U},
{"lasso_highlights_3.svg", 2744227900U},
{"lasso_highlights_4.svg", 307369790U},
{"lasso_highlights_5.svg", 1103195299U},
{"lasso_highlights_6.svg", 3501858565U},
{"lasso_highlights_7.svg", 4002684466U},
{"lasso_highlights_1.svg", 689837467U},
{"lasso_highlights_2.svg", 348394942U},
{"lasso_highlights_3.svg", 2174136207U},
{"lasso_highlights_4.svg", 1265047504U},
{"lasso_highlights_5.svg", 35778943U},
{"lasso_highlights_6.svg", 1359376880U},
{"lasso_highlights_7.svg", 773081917U},
{"testGithub6685_1.svg", 1835717197U},
{"testGithub6685_2.svg", 116380465U},
{"testGithub6685_3.svg", 409385402U},
{"testGithub6685_4.svg", 1239628830U}};
{"testGithub6685_4.svg", 1239628830U},
{"bad_lasso_1.svg", 1183031575U}};

// These PNG hashes aren't completely reliable due to floating point cruft,
// but they can still reduce the number of drawings that need visual
Expand Down Expand Up @@ -8692,8 +8693,9 @@ M END)RXN";
MolDraw2DSVG drawer(-1, -1);
auto reactIter = rxn->beginReactantTemplates();
REQUIRE_THROWS_AS(MolDraw2DUtils::setACS1996Options(
drawer.drawOptions(),
MolDraw2DUtils::meanBondLength(*(*reactIter)) * 0.7), ValueErrorException);
drawer.drawOptions(),
MolDraw2DUtils::meanBondLength(*(*reactIter)) * 0.7),
ValueErrorException);
}
}

Expand Down Expand Up @@ -9042,4 +9044,96 @@ M END
checkImage(text);
check_file_hash(baseName + "4.svg");
}
}
}

TEST_CASE("Github 6749 : various bad things in the lasso highlighting") {
std::string baseName = "bad_lasso_";
auto mol =
"CCCS(=O)(=O)Nc1ccc(F)c(c1F)C(=O)c2c[nH]c3c2cc(cn3)c4ccc(Cl)cc4"_smiles;
REQUIRE(mol);

auto update_colour_map =
[](const std::vector<int> &ats, DrawColour col,
std::map<int, std::vector<DrawColour>> &ha_map) -> void {
for (auto h : ats) {
auto ex = ha_map.find(h);
if (ex == ha_map.end()) {
std::vector<DrawColour> cvec(1, col);
ha_map.insert(make_pair(h, cvec));
} else {
if (ex->second.end() ==
find(ex->second.begin(), ex->second.end(), col)) {
ex->second.push_back(col);
}
}
}
};
auto update_bond_map =
[](const std::vector<int> &ats, DrawColour col, const ROMol &mol,
std::map<int, std::vector<DrawColour>> &hb_map) -> void {
for (auto at1 : ats) {
for (auto at2 : ats) {
if (at1 > at2) {
auto b = mol.getBondBetweenAtoms(at1, at2);
if (b) {
auto ex = hb_map.find(b->getIdx());
if (ex == hb_map.end()) {
std::vector<DrawColour> cvec(1, col);
hb_map.insert(make_pair(b->getIdx(), cvec));
} else {
ex->second.push_back(col);
}
}
}
}
}
};
std::vector<DrawColour> colours = {
DrawColour(1.0, 0.2, 1.0), DrawColour(0.2, 1.0, 1.0),
DrawColour(0.8, 0.8, 0.2), DrawColour(0.4, 0.4, 0.2)};
std::map<int, std::vector<DrawColour>> ha_map;
std::map<int, std::vector<DrawColour>> hb_map;
update_colour_map({6, 19, 25}, colours[0], ha_map);
update_bond_map({6, 19, 25}, colours[0], *mol, hb_map);

update_colour_map({25, 4, 5, 11, 14, 16}, colours[1], ha_map);
update_bond_map({25, 4, 5, 11, 14, 16}, colours[1], *mol, hb_map);

update_colour_map({19, 25, 17, 18, 20, 21, 7, 8, 9, 10, 12,
13, 22, 23, 24, 26, 27, 28, 29, 31, 32},
colours[2], ha_map);
update_bond_map({19, 25, 17, 18, 20, 21, 7, 8, 9, 10, 12,
13, 22, 23, 24, 26, 27, 28, 29, 31, 32},
colours[2], *mol, hb_map);
// If there are duplicate colours in the list, there was a bug where
// arcs weren't removed correctly (Github 6749)
ha_map[20].push_back(colours[2]);
ha_map[21].push_back(colours[2]);

update_colour_map({7, 8, 9, 10, 12, 13, 26, 27, 28, 29, 31, 32}, colours[3],
ha_map);
update_bond_map({7, 8, 9, 10, 12, 13, 26, 27, 28, 29, 31, 32}, colours[3],
*mol, hb_map);

std::map<int, double> h_rads;
std::map<int, int> h_lw_mult;
MolDraw2DSVG drawer(600, 400);
drawer.drawOptions().multiColourHighlightStyle =
RDKit::MultiColourHighlightStyle::LASSO;
drawer.drawOptions().fillHighlights = false;
drawer.drawMoleculeWithHighlights(*mol, "Bad Lasso", ha_map, hb_map, h_rads,
h_lw_mult);
drawer.finishDrawing();
std::string text = drawer.getDrawingText();
std::ofstream outs(baseName + "1.svg");
outs << text;
outs.flush();
outs.close();
std::regex atom20("<path class='atom-20'");
// there should be 3 matches for "class='atom-20'" - the buggy version gave 5
std::ptrdiff_t const match_count(
std::distance(std::sregex_iterator(text.begin(), text.end(), atom20),
std::sregex_iterator()));
REQUIRE(match_count == 3);
check_file_hash(baseName + "1.svg");
}

0 comments on commit 24c11d7

Please sign in to comment.