Last active: April 3, 2020 15:31
-
-
Save HadrienG2/b21b35319470e2a025d6c3fe8a8792c3 to your computer and use it in GitHub Desktop.
Demo of global <-> local bin coordinate conversions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <exception>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>
// Throws std::runtime_error(msg) when the condition `x` evaluates to false.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to
// exactly one statement. A bare `if` expansion would silently capture a
// following `else` when ASSERT is used as the body of an unbraced if/else
// (the dangling-else problem).
#define ASSERT(x, msg)                           \
    do {                                         \
        if (!(x)) throw std::runtime_error(msg); \
    } while (0)
// Abstract interface of a 1-D histogram axis.
//
// Local bin indices, as produced by EqAxisMock::FindBin and consumed by
// ComputeGlobalBin / ComputeLocalBins, follow this convention:
//   -1        underflow bin
//   -2        overflow bin
//   1..N      regular bins, where N = GetNBinsNoOver()
struct IAxis {
    // Axes are used polymorphically (through IAxis references), so the base
    // needs a virtual destructor: destroying a derived axis through an IAxis
    // pointer would otherwise be undefined behavior.
    virtual ~IAxis() = default;

    // Number of regular bins, excluding under/overflow bins.
    virtual int GetNBinsNoOver() const = 0;

    // Number of under/overflow bins.
    virtual int GetNOverflowBins() const = 0;

    // Total number of bins, regular and under/overflow.
    int GetNBins() const {
        return GetNBinsNoOver() + GetNOverflowBins();
    }

    // Local index of the bin containing coordinate x.
    virtual int FindBin(double x) const = 0;

    // Lower edge of the given local bin.
    virtual double GetBinFrom(int bin) const = 0;

    // Upper edge of the given local bin.
    virtual double GetBinTo(int bin) const = 0;
};
| class EqAxisMock : public IAxis | |
| { | |
| public: | |
| EqAxisMock(double from, double to, int nBinsNoOver) : | |
| m_from(from), m_to(to), m_nBinsNoOver(nBinsNoOver) | |
| {} | |
| int GetNBinsNoOver() const final override { | |
| return m_nBinsNoOver; | |
| } | |
| int GetNOverflowBins() const final override { | |
| return 2; | |
| } | |
| int FindBin(double x) const final override { | |
| if (x < m_from) { | |
| return -1; | |
| } else if (x >= m_to) { | |
| return -2; | |
| } else { | |
| return (x - m_from) / GetBinWidth() + 1; | |
| } | |
| } | |
| double GetBinFrom(int bin) const final override { | |
| switch (bin) { | |
| case -1: | |
| return -std::numeric_limits<double>::infinity(); | |
| case -2: | |
| return m_to; | |
| default: | |
| ASSERT((bin >= 1) || (bin <= GetNBinsNoOver()), | |
| "GetBinFrom got an out-of-bounds bin index"); | |
| return m_from + (bin - 1) * GetBinWidth(); | |
| } | |
| } | |
| double GetBinTo(int bin) const final override { | |
| switch (bin) { | |
| case -1: | |
| return m_from; | |
| case -2: | |
| return std::numeric_limits<double>::infinity(); | |
| default: | |
| ASSERT((bin >= 1) || (bin <= GetNBinsNoOver()), | |
| "GetBinTo got an out-of-bounds bin index"); | |
| return m_from + bin * GetBinWidth(); | |
| } | |
| } | |
| private: | |
| double m_from, m_to; | |
| int m_nBinsNoOver; | |
| // NOTE: Specific to equidistant axis, so general algorithm shouldn't use it | |
| double GetBinWidth() const { | |
| return (m_to - m_from) / m_nBinsNoOver; | |
| } | |
| }; | |
| // TODO: Generalize the algorithm to N-d: make it take a vector of (IAxis, int) | |
| // pairs and implement it using a loop. | |
| int ComputeGlobalBin(const IAxis& axis_0, int bin_0, | |
| const IAxis& axis_1, int bin_1) { | |
| // Get regular bins out of the way | |
| if ((bin_0 >= 1) && (bin_1 >= 1)) { | |
| return bin_0 + (bin_1 - 1) * axis_0.GetNBinsNoOver(); | |
| } | |
| // Convert bin indices to another coordinate system where the underflow bin | |
| // has coordinate 0, regular bins have coordinates [1, N], and the overflow | |
| // bin has coordinate N+1, where N is GetNBinsNoOver(). | |
| auto compute_virtual_bin = [](const IAxis& axis, int bin) { | |
| switch (bin) { | |
| case -1: | |
| return 0; | |
| case -2: | |
| return axis.GetNBins() - 1; | |
| default: | |
| ASSERT((bin >= 1) || (bin <= axis.GetNBinsNoOver()), | |
| "Received an invalid local bin index as input"); | |
| return bin; | |
| } | |
| }; | |
| const int virtual_bin_0 = compute_virtual_bin(axis_0, bin_0); | |
| ASSERT((virtual_bin_0 >= 0) || (virtual_bin_0 < axis_0.GetNBins()), | |
| "Computed local virtual bin index is out of expected range for axis 0"); | |
| const int virtual_bin_1 = compute_virtual_bin(axis_1, bin_1); | |
| ASSERT((virtual_bin_1 >= 0) || (virtual_bin_1 < axis_1.GetNBins()), | |
| "Computed local virtual bin index is out of expected range for axis 1"); | |
| // Deduce what the global bin index would be in this coordinate system | |
| const int global_virtual_bin = virtual_bin_0 + virtual_bin_1 * axis_0.GetNBins(); | |
| ASSERT((global_virtual_bin >= 0) || (global_virtual_bin < axis_0.GetNBins() * axis_1.GetNBins()), | |
| "Computed global virtual bin index is out of expected range"); | |
| // Move to negative and 1-based indexing | |
| const int neg_1based_virtual_bin = -global_virtual_bin - 1; | |
| // Now we have the right index for the first overflow bins, but the indices | |
| // are wrong after crossing "regular" bins, because the count of overflow | |
| // bins takes these into account when it shouldn't. | |
| // | |
| // To fix this, we need to know how many regular bins exist before the | |
| // current bin (in row-major order), and increment the virtual bin index by | |
| // this amount. This will un-do the counting of regular bins as overflow. | |
| // | |
| int num_regular_bins_before = 0; | |
| // Rows of regular bins can be taken into account with just a clamp trick | |
| const int num_rows_before = std::clamp(virtual_bin_1 - 1, 0, axis_1.GetNBinsNoOver()); | |
| num_regular_bins_before += num_rows_before * axis_0.GetNBinsNoOver(); | |
| // For columns, however, we need to know whether we are on a regular row | |
| // or an under/overflow row. | |
| if (bin_1 >= 1) { | |
| const int num_cols_before = std::clamp(virtual_bin_0 - 1, 0, axis_1.GetNBinsNoOver()); | |
| num_regular_bins_before += num_cols_before; | |
| } | |
| // And with this correction, we should be good | |
| return neg_1based_virtual_bin + num_regular_bins_before; | |
| } | |
| // TODO: Generalize the algorithm to N-d: make it take a vector of IAxis as | |
| // input, return a vector of Int as output, and implement it using a loop. | |
| std::pair<int, int> ComputeLocalBins(int global_bin, const IAxis& axis_0, const IAxis& axis_1) { | |
| // Get regular bins out of the way | |
| if (global_bin >= 1) { | |
| const int bin_0 = ((global_bin - 1) % axis_0.GetNBinsNoOver()) + 1; | |
| const int bin_1 = | |
| (((global_bin - 1) - (bin_0 - 1)) / axis_0.GetNBinsNoOver()) + 1; | |
| return std::make_pair(bin_0, bin_1); | |
| } | |
| // Convert our negative index to something positive and 0-based, as that is | |
| // more convenient to work with. Note, however, that this is _not_ | |
| // equivalent to the virtual_bin that we had before, because what we have | |
| // here is a count of overflow bins, not of all bins. | |
| ASSERT(global_bin != 0, "Received an invalid global bin index as input"); | |
| const int corrected_virtual_bin = -global_bin - 1; | |
| // This overflow bin count, accounts for... | |
| // - At most one full row of "top" underflow bins, which are the axis 0 bins | |
| // associated with the axis 1 underflow bin. | |
| // - Any number of axis 0 (underflow, overflow) bin pairs, with regular | |
| // bins in the middle. | |
| // - Possibly one trailing axis 0 underflow bin, if we are on the overflow | |
| // bin of such a pair. | |
| // - At most one full row of "bottom" overflow bins, which are the axis 0 | |
| // bins associated with the axis 1 overflow bin. | |
| // | |
| // We can suppress the contribution of the top and bottom row of | |
| // under/overflow bins like so: | |
| const int middle_overflow_bins = | |
| std::clamp(corrected_virtual_bin - axis_0.GetNBins(), | |
| 0, | |
| 2 * axis_1.GetNBinsNoOver()); | |
| // Then, we can deduce the number of rows of regular bins that we must add | |
| // back into the index, taking into account the fact that we must also add | |
| // a row if there is only one underflow bin before us (i.e. we are the | |
| // overflow bin of a row), like so: | |
| const int middle_regular_bins = | |
| ((middle_overflow_bins + 1) / 2) * axis_0.GetNBinsNoOver(); | |
| // With this regular bin count, we can recover the same kind of virtual bin | |
| // index that we had in ComputeGlobalBin()... | |
| const int global_virtual_bin = corrected_virtual_bin + middle_regular_bins; | |
| ASSERT((global_virtual_bin >= 0) || (global_virtual_bin < axis_0.GetNBins() * axis_1.GetNBins()), | |
| "Computed global virtual bin index is out of expected range"); | |
| // ...then from that we can deduce "virtual" local bin indices on each axis | |
| const int virtual_bin_1 = (global_virtual_bin / axis_0.GetNBins()); | |
| ASSERT((virtual_bin_1 >= 0) || (virtual_bin_1 < axis_1.GetNBins()), | |
| "Computed local virtual bin index is out of expected range for axis 1"); | |
| const int virtual_bin_0 = (global_virtual_bin % axis_0.GetNBins()); | |
| ASSERT((virtual_bin_0 >= 0) || (virtual_bin_0 < axis_0.GetNBins()), | |
| "Computed local virtual bin index is out of expected range for axis 0"); | |
| // And from that point, we can go back to the -1/-2 overflow convention | |
| auto compute_bin = [](const IAxis& axis, int virtual_bin) -> int { | |
| if (virtual_bin == 0) { | |
| return -1; | |
| } else if (virtual_bin == (axis.GetNBins() - 1)) { | |
| return -2; | |
| } else { | |
| return virtual_bin; | |
| } | |
| }; | |
| return std::make_pair(compute_bin(axis_0, virtual_bin_0), | |
| compute_bin(axis_1, virtual_bin_1)); | |
| } | |
| int main() { | |
| // Set up a pair of test axes | |
| EqAxisMock axis_0(3.0, 6.0, 3); | |
| EqAxisMock axis_1(7.0, 9.5, 5); | |
| // Enumerate all the bins | |
| auto enumerate_bins = [](const IAxis& axis) -> std::vector<int> { | |
| std::vector<int> all_bins; | |
| all_bins.push_back(-1); | |
| for (int bin = 1; bin <= axis.GetNBinsNoOver(); ++bin) { | |
| all_bins.push_back(bin); | |
| } | |
| all_bins.push_back(-2); | |
| return all_bins; | |
| }; | |
| std::vector<int> all_bins_0 = enumerate_bins(axis_0); | |
| std::vector<int> all_bins_1 = enumerate_bins(axis_1); | |
| // Display the bins boundaries | |
| for (const int bin_1: all_bins_1) { | |
| for (const int bin_0: all_bins_0) { | |
| std::cout << "On true bin (" << bin_0 << ", " << bin_1 << "):\t"; | |
| int global_bin = ComputeGlobalBin(axis_0, bin_0, axis_1, bin_1); | |
| auto [computed_bin_0, computed_bin_1] = | |
| ComputeLocalBins(global_bin, axis_0, axis_1); | |
| std::cout << "Global bin " << global_bin | |
| << " aka (" << computed_bin_0 << ", " | |
| << computed_bin_1 << ')' | |
| << " goes from (" << axis_0.GetBinFrom(bin_0) << ", " | |
| << axis_1.GetBinFrom(bin_1) << ')' | |
| << " to (" << axis_0.GetBinTo(bin_0) << ", " | |
| << axis_1.GetBinTo(bin_1) << ')' | |
| << std::endl; | |
| } | |
| } | |
| } |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.