Skip to content

Commit 797c7e8

Browse files
committed
Allow recoloring of hsne embeddings
1 parent f945cb5 commit 797c7e8

File tree

1 file changed

+38
-29
lines changed

1 file changed

+38
-29
lines changed

src/ScatterplotPlugin.cpp

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -172,17 +172,17 @@ ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
172172
});
173173
}
174174

175-
//const auto numPointsCand = candidateDataset->getNumPoints();
176-
//const auto numPointsPos = _positionDataset->getNumPoints();
177-
//const bool sameNumPoints = numPointsCand == numPointsPos;
178-
//bool sameNumPointsAsFull = false;
179-
180-
//if (_positionDataset->isDerivedData()) {
181-
// const auto numPointsSor = _positionDataset->getSourceDataset<Points>()->getNumPoints();
182-
// sameNumPointsAsFull = numPointsCand == numPointsSor;
183-
//}
184-
185-
//if (sameNumPoints || sameNumPointsAsFull) {
175+
// Accept both data with the same number if points and data which is derived from
176+
// a parent that has the same number of points (e.g. for HSNE embeddings)
177+
const auto numPointsCandidate = candidateDataset->getNumPoints();
178+
const auto numPointsPosition = _positionDataset->getNumPoints();
179+
const bool sameNumPoints = numPointsPosition == numPointsCandidate;
180+
const bool sameNumPointsAsFull =
181+
/*if*/ _positionDataset->isDerivedData() ?
182+
/*then*/ _positionDataset->getSourceDataset<Points>()->getFullDataset<Points>()->getNumPoints() == numPointsCandidate :
183+
/*else*/ false;
184+
185+
if (sameNumPoints || sameNumPointsAsFull) {
186186

187187
// The number of points is equal, so offer the option to use the points dataset as source for points colors
188188
dropRegions << new DropWidget::DropRegion(this, "Point color", QString("Colorize %1 points with %2").arg(_positionDataset->text(), candidateDataset->text()), "palette", true, [this, candidateDataset]() {
@@ -201,7 +201,7 @@ ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
201201
_settingsAction.getPlotAction().getPointPlotAction().addPointOpacityDataset(candidateDataset);
202202
_settingsAction.getPlotAction().getPointPlotAction().getOpacityAction().setCurrentDataset(candidateDataset);
203203
});
204-
//}
204+
}
205205
}
206206
}
207207

@@ -580,9 +580,9 @@ void ScatterplotPlugin::samplePoints()
580580
if (getSamplerAction().getRestrictNumberOfElementsAction().isChecked() && numberOfPoints >= getSamplerAction().getMaximumNumberOfElementsAction().getValue())
581581
break;
582582

583-
const auto& distance = sampledPoint.first;
584-
const auto& localPointIndex = sampledPoint.second;
585-
const auto& globalPointIndex = localGlobalIndices[localPointIndex];
583+
const auto& distance = sampledPoint.first;
584+
const auto& localPointIndex = sampledPoint.second;
585+
const auto& globalPointIndex = localGlobalIndices[localPointIndex];
586586

587587
distances << distance;
588588
localPointIndices << localPointIndex;
@@ -659,24 +659,33 @@ void ScatterplotPlugin::loadColors(const Dataset<Points>& points, const std::uin
659659

660660
points->extractDataForDimension(scalars, dimensionIndex);
661661

662-
//if (_positionSourceDataset->getNumPoints() == points->getNumPoints())
663-
//{
664-
std::vector<std::uint32_t> globalIndices;
665-
_positionDataset->getGlobalIndices(globalIndices);
662+
const auto numColorPoints = points->getNumPoints();
666663

667-
std::vector<float> localScalars(_numPoints, 0);
668-
std::int32_t localColorIndex = 0;
669664

670-
for (const auto& globalIndex : globalIndices)
671-
localScalars[localColorIndex++] = scalars[globalIndex];
665+
if (numColorPoints != _numPoints) {
666+
667+
const bool sameNumPointsAsFull =
668+
/*if*/ _positionDataset->isDerivedData() ?
669+
/*then*/ _positionSourceDataset->getFullDataset<Points>()->getNumPoints() == numColorPoints :
670+
/*else*/ false;
672671

673-
std::swap(localScalars, scalars);
672+
if (sameNumPointsAsFull) {
673+
std::vector<std::uint32_t> globalIndices;
674+
_positionDataset->getGlobalIndices(globalIndices);
674675

675-
//}
676-
//else if (points->getNumPoints() != _numPoints) {
677-
// qWarning("Number of points used for coloring does not match number of points in data, aborting attempt to color plot");
678-
// return;
679-
//}
676+
std::vector<float> localScalars(_numPoints, 0);
677+
std::int32_t localColorIndex = 0;
678+
679+
for (const auto& globalIndex : globalIndices)
680+
localScalars[localColorIndex++] = scalars[globalIndex];
681+
682+
std::swap(localScalars, scalars);
683+
}
684+
else {
685+
qWarning("Number of points used for coloring does not match number of points in data, aborting attempt to color plot");
686+
return;
687+
}
688+
}
680689

681690
assert(scalars.size() == _numPoints);
682691

0 commit comments

Comments
 (0)