Skip to content

Instantly share code, notes, and snippets.

@singularitti
Last active December 15, 2025 12:49
Show Gist options
  • Select an option

  • Save singularitti/8a9e5dc4a79f9d5ab6ceeb5911db8b7e to your computer and use it in GitHub Desktop.

Select an option

Save singularitti/8a9e5dc4a79f9d5ab6ceeb5911db8b7e to your computer and use it in GitHub Desktop.
Vector heatmaps #Python #plotting
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from typing import Iterable, Tuple, Optional, Union
# Type alias for the color arguments of the plotting functions below:
# either a color string (e.g. "white", "tab:blue", "#F8971F") or an
# RGB / RGBA tuple of floats.
Color = Union[str, Tuple[float, float, float], Tuple[float, float, float, float]]
def vector_to_heatmap_array(
    vec: Iterable[float],
    vertical: bool = True,
    *,
    dtype: type = int,
) -> np.ndarray:
    """Converts a 1D vector into a 2D array for heatmap plotting.

    This utility reshapes a vector such as ``[0, 1, 0, 1, 0]`` into a format
    compatible with ``matplotlib.imshow``. The orientation can be vertical
    (column vector) or horizontal (row vector).

    Args:
        vec: A 1D iterable of numbers (typically 0/1). One-shot iterables
            such as generators are supported.
        vertical: If ``True``, returns an array of shape ``(N, 1)``.
            If ``False``, returns an array of shape ``(1, N)``.
        dtype: Element type of the returned array. Defaults to ``int``,
            which truncates non-integer values toward zero (``0.9`` becomes
            ``0``); pass ``float`` to keep continuous values intact.

    Returns:
        A 2D NumPy array suitable for heatmap visualization.

    Raises:
        ValueError: If ``vec`` is not one-dimensional.
    """
    # list() materializes generators; np.asarray on a bare generator would
    # produce a 0-d object array instead of a 1D numeric array.
    v = np.asarray(list(vec), dtype=dtype)
    if v.ndim != 1:
        raise ValueError("vec must be a 1D iterable of integers.")
    # Indexing with None inserts the length-1 axis on the requested side.
    return v[:, None] if vertical else v[None, :]
def plot_binary_vector_heatmap(
    vec: Iterable[int],
    *,
    vertical: bool = True,
    color0: Color = "white",
    color1: Color = "tab:blue",
    text_color: Color = "black",
    fontsize: int = 16,
    figsize: Optional[Tuple[float, float]] = None,
    out_path: Optional[str] = None,
    dpi: int = 200,
    show: bool = True,
):
    """Plots a binary vector as a borderless heatmap with centered values.

    Each element of the vector is rendered as a square cell. Cells with value
    ``1`` are highlighted using ``color1``, and cells with value ``0`` use
    ``color0``. The numeric value of each element is written at the center
    of its corresponding cell. The plot contains no axes, ticks, titles,
    or gridlines.

    Args:
        vec: A 1D iterable of binary values (0 or 1). Multiple ones are allowed
            (i.e., multi-hot vectors).
        vertical: If ``True``, plots the vector as an ``N × 1`` column.
            If ``False``, plots it as a ``1 × N`` row.
        color0: Matplotlib-compatible color for elements equal to 0.
        color1: Matplotlib-compatible color for elements equal to 1.
        text_color: Color used for the numeric labels inside each cell.
        fontsize: Font size for the numeric labels.
        figsize: Optional figure size ``(width, height)`` in inches.
            If ``None``, a default is chosen based on orientation
            and vector length (one inch per cell, at least 1.6 inches).
        out_path: Optional file path. If provided, the figure is saved to
            this location.
        dpi: Resolution (dots per inch) used when saving the figure.
        show: If ``True``, displays the figure using ``plt.show()``.

    Returns:
        A tuple ``(fig, ax)`` where ``fig`` is the Matplotlib ``Figure`` and
        ``ax`` is the corresponding ``Axes`` object.

    Raises:
        ValueError: If ``vec`` is not one-dimensional, or contains values
            other than 0 or 1 (including fractional values such as 0.5).
    """
    # Materialize once: vec may be a one-shot iterator and is needed both for
    # the reshape and for validation.
    values = list(vec)
    arr = vector_to_heatmap_array(values, vertical=vertical)

    # Validate the raw values rather than the integer-cast array; the cast
    # truncates, so e.g. 0.5 or 1.9 would otherwise slip through as 0 or 1.
    raw = np.asarray(values)
    if not np.isin(raw, (0, 1)).all():
        raise ValueError(
            f"vec must contain only 0/1 values. Found: {np.unique(raw).tolist()}"
        )

    if figsize is None:
        # One axis of arr always has length 1, so arr.size is the vector length.
        long_side = max(1.6, 1.0 * arr.size)
        figsize = (1.6, long_side) if vertical else (long_side, 1.6)

    # Two-color map; the boundaries place 0 in the first bin and 1 in the second.
    cmap = ListedColormap([color0, color1])
    norm = BoundaryNorm([-0.5, 0.5, 1.5], cmap.N)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(arr, cmap=cmap, norm=norm, aspect="equal", interpolation="nearest")

    # imshow centers cell (row, col) at data coordinates x=col, y=row, which
    # covers both the column and the row orientation with a single loop.
    for (row, col), value in np.ndenumerate(arr):
        ax.text(
            col,
            row,
            str(int(value)),
            ha="center",
            va="center",
            fontsize=fontsize,
            color=text_color,
        )

    ax.axis("off")
    # Act on this figure explicitly instead of pyplot's "current figure".
    fig.tight_layout(pad=0)

    if out_path is not None:
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight", pad_inches=0)
    if show:
        plt.show()
    return fig, ax
if __name__ == "__main__":
    # Demo: render a five-element, two-hot vector as a vertical heatmap.
    demo_vector = [0, 1, 0, 1, 0]

    # First with the default palette ...
    plot_binary_vector_heatmap(demo_vector, vertical=True)

    # ... then with a white/orange palette, also saving the result to disk.
    custom_palette = {
        "color0": "#FFFFFF",
        "color1": "#F8971F",  # orange
        "text_color": "black",
    }
    plot_binary_vector_heatmap(
        demo_vector,
        vertical=True,
        out_path="5elem_twohot_vertical_custom.png",
        **custom_palette,
    )
def plot_continuous_vector_heatmap(
    vec: Iterable[float],
    *,
    vertical: bool = True,
    cmap: Union[str, plt.cm.ScalarMappable] = "viridis",
    text_color: Color = "black",
    fontsize: int = 16,
    fmt: str = "{:.2f}",
    figsize: Optional[Tuple[float, float]] = None,
    out_path: Optional[str] = None,
    dpi: int = 200,
    show: bool = True,
):
    """Plots a continuous-valued vector as a borderless heatmap with transparency.

    Each element is drawn as a square cell colored through ``cmap`` on a
    fixed [0, 1] scale, with its formatted value written at the cell center.
    The figure and axes backgrounds are made fully transparent.

    Args:
        vec: A 1D iterable of floating-point values in [0, 1].
        vertical: If True, plots as an N×1 column; otherwise 1×N row.
        cmap: Matplotlib colormap for values in [0, 1]; passed to ``imshow``.
        text_color: Color of numeric labels.
        fontsize: Font size for numeric labels.
        fmt: Format string for numeric labels.
        figsize: Optional figure size in inches. If ``None``, one inch per
            cell is used along the vector, with a minimum of 1.6 inches.
        out_path: Optional output file path.
        dpi: DPI used when saving.
        show: If True, displays the figure.

    Returns:
        (fig, ax): Matplotlib Figure and Axes objects.

    Raises:
        ValueError: If ``vec`` is not one-dimensional, or if any value in
            ``vec`` lies outside [0, 1] (NaN included).
    """
    # Convert as float here: an integer cast would truncate every value in
    # [0, 1) to 0 and let out-of-range values such as 1.5 pass validation.
    values = np.asarray(list(vec), dtype=float)
    if values.ndim != 1:
        raise ValueError("vec must be a 1D iterable of floats.")
    arr = values[:, None] if vertical else values[None, :]

    # Written as a positive "all inside" test so that NaN is rejected too.
    if not np.all((arr >= 0.0) & (arr <= 1.0)):
        raise ValueError("All elements of vec must be in the interval [0, 1].")

    if figsize is None:
        # One axis of arr always has length 1, so arr.size is the vector length.
        long_side = max(1.6, arr.size)
        figsize = (1.6, long_side) if vertical else (long_side, 1.6)

    fig, ax = plt.subplots(figsize=figsize)

    # --- key lines for transparency ---
    fig.patch.set_alpha(0.0)
    ax.set_facecolor((0, 0, 0, 0))
    # ---------------------------------

    # vmin/vmax pin the color scale to [0, 1] regardless of the data range;
    # this is the linear normalization imshow applies when no norm is given.
    ax.imshow(
        arr,
        cmap=cmap,
        vmin=0.0,
        vmax=1.0,
        aspect="equal",
        interpolation="nearest",
    )

    # imshow centers cell (row, col) at data coordinates x=col, y=row, which
    # covers both the column and the row orientation with a single loop.
    for (row, col), value in np.ndenumerate(arr):
        ax.text(
            col,
            row,
            fmt.format(value),
            ha="center",
            va="center",
            fontsize=fontsize,
            color=text_color,
        )

    ax.axis("off")
    # Act on this figure explicitly instead of pyplot's "current figure".
    fig.tight_layout(pad=0)

    if out_path is not None:
        fig.savefig(
            out_path,
            dpi=dpi,
            bbox_inches="tight",
            pad_inches=0,
            transparent=True,  # REQUIRED to keep the background transparent on disk
        )
    if show:
        plt.show()
    return fig, ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment