Skip to content

Instantly share code, notes, and snippets.

@HadrienG2
Last active April 3, 2020 15:31
Show Gist options
  • Select an option

  • Save HadrienG2/b21b35319470e2a025d6c3fe8a8792c3 to your computer and use it in GitHub Desktop.

Select an option

Save HadrienG2/b21b35319470e2a025d6c3fe8a8792c3 to your computer and use it in GitHub Desktop.
Demo of global <-> local bin coordinate conversions
#include <algorithm>
#include <cmath>
#include <exception>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>
// Throw std::runtime_error(msg) when the condition `x` evaluates to false.
// The do { ... } while (0) wrapper makes `ASSERT(...);` a single statement,
// so an `else` following it in an unbraced if/else binds to the caller's `if`
// instead of silently being captured by the macro's internal `if`.
#define ASSERT(x, msg)                      \
  do {                                      \
    if (!(x)) throw std::runtime_error(msg); \
  } while (0)
/// Abstract interface of a 1-D histogram axis.
///
/// Bins are split into "regular" bins and under/overflow bins. The
/// implementation and helpers in this file use local bin index -1 for the
/// underflow bin, -2 for the overflow bin, and 1..GetNBinsNoOver() for the
/// regular bins.
struct IAxis {
  /// Polymorphic base: a virtual destructor makes deleting a derived axis
  /// through an IAxis pointer well-defined.
  virtual ~IAxis() = default;

  /// Number of regular bins, excluding under/overflow bins.
  virtual int GetNBinsNoOver() const = 0;
  /// Number of under/overflow bins.
  virtual int GetNOverflowBins() const = 0;
  /// Total number of bins: regular bins plus under/overflow bins.
  int GetNBins() const {
    return GetNBinsNoOver() + GetNOverflowBins();
  }
  /// Local bin index containing coordinate `x`.
  virtual int FindBin(double x) const = 0;
  /// Lower edge of local bin `bin`.
  virtual double GetBinFrom(int bin) const = 0;
  /// Upper edge of local bin `bin`.
  virtual double GetBinTo(int bin) const = 0;
};
/// Mock equidistant axis with the following bins:
/// - nBinsNoOver regular bins of equal width covering [from, to),
///   numbered 1..nBinsNoOver;
/// - one underflow bin (-1) covering (-inf, from);
/// - one overflow bin (-2) covering [to, +inf).
class EqAxisMock : public IAxis
{
public:
  /// @param from        Lower edge of the first regular bin (inclusive).
  /// @param to          Upper edge of the last regular bin (exclusive).
  /// @param nBinsNoOver Number of regular bins.
  /// @throws std::runtime_error if nBinsNoOver < 1 or if !(from < to).
  EqAxisMock(double from, double to, int nBinsNoOver) :
    m_from(from), m_to(to), m_nBinsNoOver(nBinsNoOver)
  {
    // A zero bin count would make GetBinWidth() divide by zero.
    ASSERT(nBinsNoOver >= 1, "EqAxisMock needs at least one regular bin");
    // This check also rejects NaN bounds.
    ASSERT(from < to, "EqAxisMock needs from < to");
  }

  int GetNBinsNoOver() const final override {
    return m_nBinsNoOver;
  }

  /// Always one underflow bin plus one overflow bin.
  int GetNOverflowBins() const final override {
    return 2;
  }

  /// @return -1 if x < from, -2 if x >= to, otherwise the 1-based regular bin.
  /// @throws std::runtime_error if x is NaN.
  int FindBin(double x) const final override {
    // NaN fails both comparisons below, and converting NaN to int is UB.
    ASSERT(!std::isnan(x), "FindBin got a NaN coordinate");
    if (x < m_from) {
      return -1;
    } else if (x >= m_to) {
      return -2;
    }
    // (x - m_from) / width lies in [0, N) mathematically. Floating-point
    // rounding can still yield N for x just below m_to, so we clamp to avoid
    // returning the out-of-range index N + 1.
    const int bin = static_cast<int>((x - m_from) / GetBinWidth()) + 1;
    return std::min(bin, m_nBinsNoOver);
  }

  /// @throws std::runtime_error if `bin` is not -1, -2 or in [1, N].
  double GetBinFrom(int bin) const final override {
    switch (bin) {
      case -1:
        return -std::numeric_limits<double>::infinity();
      case -2:
        return m_to;
      default:
        ASSERT((bin >= 1) && (bin <= GetNBinsNoOver()),
               "GetBinFrom got an out-of-bounds bin index");
        return m_from + (bin - 1) * GetBinWidth();
    }
  }

  /// @throws std::runtime_error if `bin` is not -1, -2 or in [1, N].
  double GetBinTo(int bin) const final override {
    switch (bin) {
      case -1:
        return m_from;
      case -2:
        return std::numeric_limits<double>::infinity();
      default:
        ASSERT((bin >= 1) && (bin <= GetNBinsNoOver()),
               "GetBinTo got an out-of-bounds bin index");
        // Return m_to exactly for the last regular bin. That way it abuts the
        // overflow bin (which starts at m_to) with no rounding error.
        return (bin == m_nBinsNoOver) ? m_to : m_from + bin * GetBinWidth();
    }
  }

private:
  double m_from, m_to;
  int m_nBinsNoOver;

  // NOTE: Specific to equidistant axis, so general algorithm shouldn't use it
  double GetBinWidth() const {
    return (m_to - m_from) / m_nBinsNoOver;
  }
};
// TODO: Generalize the algorithm to N-d: make it take a vector of (IAxis, int)
// pairs and implement it using a loop.

/// Convert a pair of local bin indices into a single global bin index.
///
/// Local bins (per axis): -1 = underflow, -2 = overflow, [1, N] = regular,
/// where N = axis.GetNBinsNoOver().
///
/// Global bins:
/// - If both local bins are regular, the global bin is positive. Such bins
///   are numbered 1, 2, ... in row-major order, with axis 0 varying fastest.
/// - Otherwise the global bin is negative. The bins with at least one
///   under/overflow local index are numbered -1, -2, ... in the same
///   row-major order.
///
/// NOTE: assumes each axis has exactly one underflow and one overflow bin,
/// i.e. axis.GetNBins() == axis.GetNBinsNoOver() + 2.
///
/// @throws std::runtime_error if a local bin index is invalid for its axis.
int ComputeGlobalBin(const IAxis& axis_0, int bin_0,
                     const IAxis& axis_1, int bin_1) {
  // Get regular bins out of the way, after checking that they are in range
  if ((bin_0 >= 1) && (bin_1 >= 1)) {
    ASSERT((bin_0 <= axis_0.GetNBinsNoOver()) &&
               (bin_1 <= axis_1.GetNBinsNoOver()),
           "Received an invalid local bin index as input");
    return bin_0 + (bin_1 - 1) * axis_0.GetNBinsNoOver();
  }

  // Convert bin indices to another coordinate system. In it, the underflow
  // bin has coordinate 0, regular bins have coordinates [1, N], and the
  // overflow bin has coordinate N+1, where N is GetNBinsNoOver().
  auto compute_virtual_bin = [](const IAxis& axis, int bin) {
    switch (bin) {
      case -1:
        return 0;
      case -2:
        return axis.GetNBins() - 1;
      default:
        ASSERT((bin >= 1) && (bin <= axis.GetNBinsNoOver()),
               "Received an invalid local bin index as input");
        return bin;
    }
  };
  const int virtual_bin_0 = compute_virtual_bin(axis_0, bin_0);
  ASSERT((virtual_bin_0 >= 0) && (virtual_bin_0 < axis_0.GetNBins()),
         "Computed local virtual bin index is out of expected range for axis 0");
  const int virtual_bin_1 = compute_virtual_bin(axis_1, bin_1);
  ASSERT((virtual_bin_1 >= 0) && (virtual_bin_1 < axis_1.GetNBins()),
         "Computed local virtual bin index is out of expected range for axis 1");

  // Deduce what the global bin index would be in this coordinate system
  const int global_virtual_bin =
      virtual_bin_0 + virtual_bin_1 * axis_0.GetNBins();
  ASSERT((global_virtual_bin >= 0) &&
             (global_virtual_bin < axis_0.GetNBins() * axis_1.GetNBins()),
         "Computed global virtual bin index is out of expected range");

  // Move to negative and 1-based indexing
  const int neg_1based_virtual_bin = -global_virtual_bin - 1;

  // The index is now right for the first overflow bins. Past any "regular"
  // bins it is wrong, because those regular bins were counted as if they
  // were overflow bins.
  //
  // To fix this, count how many regular bins come before the current bin
  // (in row-major order) and add that count to the index. This undoes the
  // counting of regular bins as overflow.
  int num_regular_bins_before = 0;

  // Full rows of regular bins above us: a clamp handles the underflow row
  // (no rows before) and the overflow row (all N1 rows before).
  const int num_rows_before =
      std::clamp(virtual_bin_1 - 1, 0, axis_1.GetNBinsNoOver());
  num_regular_bins_before += num_rows_before * axis_0.GetNBinsNoOver();

  // On a regular row, bin_0 must be -1 or -2 here. The underflow bin has no
  // regular bins before it on its row; the overflow bin has all N0 of them.
  // The clamp bound is axis 0's regular bin count, because we are counting
  // columns.
  if (bin_1 >= 1) {
    const int num_cols_before =
        std::clamp(virtual_bin_0 - 1, 0, axis_0.GetNBinsNoOver());
    num_regular_bins_before += num_cols_before;
  }

  // With this correction, the result counts only under/overflow bins
  return neg_1based_virtual_bin + num_regular_bins_before;
}
// TODO: Generalize the algorithm to N-d: make it take a vector of IAxis as
// input, return a vector of Int as output, and implement it using a loop.

/// Convert a global bin index back into (axis 0, axis 1) local bin indices.
///
/// This is the inverse of ComputeGlobalBin():
/// - Positive global bins map to pairs of regular local bins, in row-major
///   order with axis 0 varying fastest.
/// - Negative global bins -1, -2, ... enumerate, in the same row-major
///   order, the bins where at least one local index is under/overflow
///   (-1 = underflow, -2 = overflow).
///
/// NOTE: assumes each axis has exactly one underflow and one overflow bin,
/// i.e. axis.GetNBins() == axis.GetNBinsNoOver() + 2.
///
/// @param global_bin Global bin index; must be nonzero and in range.
/// @return (bin_0, bin_1) pair of local bin indices.
/// @throws std::runtime_error if global_bin is 0 or out of range for the axes.
std::pair<int, int> ComputeLocalBins(int global_bin, const IAxis& axis_0, const IAxis& axis_1) {
  const int n_regular_0 = axis_0.GetNBinsNoOver();
  const int n_regular_1 = axis_1.GetNBinsNoOver();

  // Get regular bins out of the way (row-major, axis 0 varies fastest)
  if (global_bin >= 1) {
    const int bin_0 = ((global_bin - 1) % n_regular_0) + 1;
    const int bin_1 = ((global_bin - 1) / n_regular_0) + 1;
    // A global bin past the last regular bin would decode to bin_1 > N1
    ASSERT(bin_1 <= n_regular_1, "Received an invalid global bin index as input");
    return std::make_pair(bin_0, bin_1);
  }

  // Convert our negative index to something positive and 0-based, which is
  // more convenient to work with. Note that this is _not_ the virtual_bin we
  // had in ComputeGlobalBin(): what we have here counts overflow bins only,
  // not all bins. It is written as -(global_bin + 1) rather than
  // -global_bin - 1; the value is the same, but this form cannot overflow
  // when global_bin == INT_MIN.
  ASSERT(global_bin != 0, "Received an invalid global bin index as input");
  const int corrected_virtual_bin = -(global_bin + 1);

  // Reject indices past the last under/overflow bin up front, so that the
  // arithmetic below cannot overflow on very negative inputs.
  const int num_overflow_bins =
      axis_0.GetNBins() * axis_1.GetNBins() - n_regular_0 * n_regular_1;
  ASSERT(corrected_virtual_bin < num_overflow_bins,
         "Received an invalid global bin index as input");

  // This overflow bin count accounts for...
  // - At most one full row of "top" underflow bins, which are the axis 0 bins
  //   associated with the axis 1 underflow bin.
  // - Any number of axis 0 (underflow, overflow) bin pairs, with regular
  //   bins in the middle.
  // - Possibly one trailing axis 0 underflow bin, if we are on the overflow
  //   bin of such a pair.
  // - At most one full row of "bottom" overflow bins, which are the axis 0
  //   bins associated with the axis 1 overflow bin.
  //
  // The clamp removes the contribution of the top and bottom rows of
  // under/overflow bins:
  const int middle_overflow_bins =
      std::clamp(corrected_virtual_bin - axis_0.GetNBins(),
                 0,
                 2 * n_regular_1);

  // Next, deduce how many regular bins to add back into the index. If only
  // the underflow bin of the current row precedes us (odd count, so we are
  // that row's overflow bin), that row's regular bins must be added too.
  const int middle_regular_bins =
      ((middle_overflow_bins + 1) / 2) * n_regular_0;

  // With this regular bin count, we recover the same kind of virtual bin
  // index that ComputeGlobalBin() used...
  const int global_virtual_bin = corrected_virtual_bin + middle_regular_bins;
  ASSERT((global_virtual_bin >= 0) &&
             (global_virtual_bin < axis_0.GetNBins() * axis_1.GetNBins()),
         "Computed global virtual bin index is out of expected range");

  // ...and from that we deduce "virtual" local bin indices on each axis
  const int virtual_bin_1 = (global_virtual_bin / axis_0.GetNBins());
  ASSERT((virtual_bin_1 >= 0) && (virtual_bin_1 < axis_1.GetNBins()),
         "Computed local virtual bin index is out of expected range for axis 1");
  const int virtual_bin_0 = (global_virtual_bin % axis_0.GetNBins());
  ASSERT((virtual_bin_0 >= 0) && (virtual_bin_0 < axis_0.GetNBins()),
         "Computed local virtual bin index is out of expected range for axis 0");

  // Finally, go back to the -1/-2 under/overflow convention
  auto compute_bin = [](const IAxis& axis, int virtual_bin) -> int {
    if (virtual_bin == 0) {
      return -1;
    } else if (virtual_bin == (axis.GetNBins() - 1)) {
      return -2;
    } else {
      return virtual_bin;
    }
  };
  return std::make_pair(compute_bin(axis_0, virtual_bin_0),
                        compute_bin(axis_1, virtual_bin_1));
}
/// Demo: for every (bin_0, bin_1) pair of two test axes, print the global bin
/// index, the local bins recovered from it, and the bin boundaries.
/// Returns 0 if every pair survives the local -> global -> local round trip.
/// Returns 1 if any pair does not, or if an ASSERT check throws.
int main() {
  try {
    // Set up a pair of test axes
    EqAxisMock axis_0(3.0, 6.0, 3);
    EqAxisMock axis_1(7.0, 9.5, 5);

    // Enumerate all the bins of an axis: underflow, regular bins, overflow
    auto enumerate_bins = [](const IAxis& axis) -> std::vector<int> {
      std::vector<int> all_bins;
      all_bins.push_back(-1);
      for (int bin = 1; bin <= axis.GetNBinsNoOver(); ++bin) {
        all_bins.push_back(bin);
      }
      all_bins.push_back(-2);
      return all_bins;
    };
    const std::vector<int> all_bins_0 = enumerate_bins(axis_0);
    const std::vector<int> all_bins_1 = enumerate_bins(axis_1);

    // Display the bin boundaries and check that ComputeLocalBins() inverts
    // ComputeGlobalBin() for every bin
    bool all_round_trips_ok = true;
    for (const int bin_1 : all_bins_1) {
      for (const int bin_0 : all_bins_0) {
        std::cout << "On true bin (" << bin_0 << ", " << bin_1 << "):\t";
        const int global_bin = ComputeGlobalBin(axis_0, bin_0, axis_1, bin_1);
        const auto [computed_bin_0, computed_bin_1] =
            ComputeLocalBins(global_bin, axis_0, axis_1);
        std::cout << "Global bin " << global_bin
                  << " aka (" << computed_bin_0 << ", "
                  << computed_bin_1 << ')'
                  << " goes from (" << axis_0.GetBinFrom(bin_0) << ", "
                  << axis_1.GetBinFrom(bin_1) << ')'
                  << " to (" << axis_0.GetBinTo(bin_0) << ", "
                  << axis_1.GetBinTo(bin_1) << ')';
        if ((computed_bin_0 != bin_0) || (computed_bin_1 != bin_1)) {
          std::cout << "\t<-- ROUND-TRIP MISMATCH";
          all_round_trips_ok = false;
        }
        // '\n' instead of std::endl: no need to flush on every line
        std::cout << '\n';
      }
    }
    return all_round_trips_ok ? 0 : 1;
  } catch (const std::exception& e) {
    // Report ASSERT failures instead of letting std::terminate() kill us
    std::cerr << "Error: " << e.what() << '\n';
    return 1;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment