Last active
December 15, 2025 12:49
-
-
Save singularitti/8a9e5dc4a79f9d5ab6ceeb5911db8b7e to your computer and use it in GitHub Desktop.
Vector heatmaps #Python #plotting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Iterable, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import BoundaryNorm, Colormap, ListedColormap, Normalize
# A color specification accepted by the plotting helpers in this module:
# a Matplotlib color string (name or hex) or an RGB / RGBA tuple of floats.
Color = Union[
    str,
    Tuple[float, float, float],
    Tuple[float, float, float, float],
]
def vector_to_heatmap_array(
    vec: Iterable[Union[int, float]],
    vertical: bool = True,
    dtype: Union[type, str, np.dtype] = int,
) -> np.ndarray:
    """Converts a 1D vector into a 2D array for heatmap plotting.

    This utility reshapes a vector such as ``[0, 1, 0, 1, 0]`` into a format
    compatible with ``matplotlib.imshow``. The orientation can be vertical
    (column vector) or horizontal (row vector).

    Args:
        vec: A 1D iterable of numbers (typically 0/1). It is consumed exactly
            once, so one-shot iterators such as generators are accepted.
        vertical: If ``True``, returns an array of shape ``(N, 1)``.
            If ``False``, returns an array of shape ``(1, N)``.
        dtype: NumPy-compatible dtype the values are converted to. Defaults
            to ``int`` (the original behavior, which truncates fractional
            values toward zero); pass ``float`` to keep fractional values.

    Returns:
        A 2D NumPy array of the requested orientation and dtype, suitable
        for heatmap visualization.

    Raises:
        ValueError: If ``vec`` is not one-dimensional.
    """
    # ``list`` first so that arbitrary iterables (not just sequences) work.
    values = np.asarray(list(vec), dtype=dtype)
    if values.ndim != 1:
        raise ValueError("vec must be a 1D iterable of integers.")
    # ``None`` inserts a length-1 axis: (N,) -> (N, 1) or (1, N).
    return values[:, None] if vertical else values[None, :]
def plot_binary_vector_heatmap(
    vec: Iterable[int],
    *,
    vertical: bool = True,
    color0: Color = "white",
    color1: Color = "tab:blue",
    text_color: Color = "black",
    fontsize: int = 16,
    figsize: Optional[Tuple[float, float]] = None,
    out_path: Optional[str] = None,
    dpi: int = 200,
    show: bool = True,
):
    """Plots a binary vector as a borderless heatmap with centered values.

    Each element of the vector is rendered as a square cell. Cells with value
    ``1`` are filled with ``color1``, and cells with value ``0`` use
    ``color0``. The numeric value of each element is written at the center
    of its corresponding cell. The plot contains no axes, ticks, titles,
    or gridlines.

    Args:
        vec: A 1D iterable of binary values (0 or 1). Multiple ones are allowed
            (i.e., multi-hot vectors). It is consumed exactly once, so
            one-shot iterators such as generators are accepted.
        vertical: If ``True``, plots the vector as an ``N x 1`` column.
            If ``False``, plots it as a ``1 x N`` row.
        color0: Matplotlib-compatible color for elements equal to 0.
        color1: Matplotlib-compatible color for elements equal to 1.
        text_color: Color used for the numeric labels inside each cell.
        fontsize: Font size for the numeric labels.
        figsize: Optional figure size ``(width, height)`` in inches.
            If ``None``, the long side is 1 inch per element (at least
            1.6 inches) and the short side is 1.6 inches.
        out_path: Optional file path. If provided, the figure is saved to
            this location.
        dpi: Resolution (dots per inch) used when saving the figure.
        show: If ``True``, displays the figure using ``plt.show()``.

    Returns:
        A tuple ``(fig, ax)`` where ``fig`` is the Matplotlib ``Figure`` and
        ``ax`` is the corresponding ``Axes`` object.

    Raises:
        ValueError: If ``vec`` is not one-dimensional, or contains values
            other than 0 or 1 (fractional values such as 0.5 included).
    """
    # Materialize once: ``vec`` may be a one-shot iterator, and it is read
    # twice below (to build the image array and to validate the raw values).
    values = list(vec)
    arr = vector_to_heatmap_array(values, vertical=vertical)

    # Validate the raw values rather than ``arr``: the helper casts to int,
    # so checking ``arr`` alone would accept e.g. 0.5 as a truncated 0.
    as_float = np.asarray(values, dtype=float)
    if not np.all((as_float == 0.0) | (as_float == 1.0)):
        found = np.unique(np.asarray(values)).tolist()
        raise ValueError(f"vec must contain only 0/1 values. Found: {found}")

    if figsize is None:
        # One inch per cell along the vector, with a 1.6-inch floor so short
        # vectors still get a usable canvas.
        long_side = max(1.6, 1.0 * arr.size)
        figsize = (1.6, long_side) if vertical else (long_side, 1.6)

    # Two-color discrete colormap: value 0 falls in [-0.5, 0.5) -> color0,
    # value 1 falls in [0.5, 1.5) -> color1.
    cmap = ListedColormap([color0, color1])
    norm = BoundaryNorm([-0.5, 0.5, 1.5], cmap.N)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(arr, cmap=cmap, norm=norm, aspect="equal", interpolation="nearest")

    # imshow draws cell (row, col) centered at data coordinates x=col, y=row,
    # which covers both the (N, 1) and (1, N) orientations.
    for (row, col), value in np.ndenumerate(arr):
        ax.text(
            col,
            row,
            str(int(value)),
            ha="center",
            va="center",
            fontsize=fontsize,
            color=text_color,
        )

    ax.axis("off")
    # Lay out this specific figure rather than pyplot's "current" figure.
    fig.tight_layout(pad=0)

    if out_path is not None:
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight", pad_inches=0)
    if show:
        plt.show()
    return fig, ax
if __name__ == "__main__":
    # Demo: a 5-element two-hot vector rendered as a single column.
    demo = [0, 1, 0, 1, 0]

    # 1) Default palette (white for 0, "tab:blue" for 1); displayed only.
    plot_binary_vector_heatmap(demo, vertical=True)

    # 2) Custom palette (white / orange), displayed and also written to a PNG.
    custom_palette = dict(
        color0="#FFFFFF",
        color1="#F8971F",  # orange
        text_color="black",
    )
    plot_binary_vector_heatmap(
        demo,
        vertical=True,
        out_path="5elem_twohot_vertical_custom.png",
        **custom_palette,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_continuous_vector_heatmap(
    vec: Iterable[float],
    *,
    vertical: bool = True,
    cmap: Union[str, Colormap] = "viridis",
    text_color: Color = "black",
    fontsize: int = 16,
    fmt: str = "{:.2f}",
    figsize: Optional[Tuple[float, float]] = None,
    out_path: Optional[str] = None,
    dpi: int = 200,
    show: bool = True,
):
    """Plots a continuous-valued vector as a borderless heatmap with transparency.

    Each element is drawn as a square cell colored by ``cmap`` on a fixed
    [0, 1] scale, and its value is written at the cell center using ``fmt``.
    The figure and axes backgrounds are transparent, and saved files keep
    that transparency.

    Args:
        vec: A 1D iterable of floating-point values in [0, 1]. It is consumed
            exactly once, so one-shot iterators such as generators are
            accepted.
        vertical: If True, plots as an N x 1 column; otherwise a 1 x N row.
        cmap: Matplotlib colormap (name or ``Colormap`` instance) applied to
            values on the fixed [0, 1] scale.
        text_color: Color of numeric labels.
        fontsize: Font size for numeric labels.
        fmt: ``str.format`` template for numeric labels.
        figsize: Optional figure size ``(width, height)`` in inches. If None,
            the long side is 1 inch per element (at least 1.6 inches) and the
            short side is 1.6 inches.
        out_path: Optional output file path.
        dpi: DPI used when saving.
        show: If True, displays the figure.

    Returns:
        (fig, ax): Matplotlib Figure and Axes objects.

    Raises:
        ValueError: If vec is not one-dimensional, or if any value in vec lies
            outside [0, 1] (NaN included).
    """
    # Convert as float. Routing through the int-casting helper would
    # truncate every value in [0, 1) to 0, destroying the data being plotted.
    values = np.asarray(list(vec), dtype=float)
    if values.ndim != 1:
        raise ValueError("vec must be a 1D iterable of numbers.")
    # Insert a length-1 axis: (N,) -> (N, 1) column or (1, N) row.
    arr = values[:, None] if vertical else values[None, :]

    # Written as "all inside" rather than "any outside" so that NaN, which
    # compares False against both bounds, is rejected too.
    if not np.all((arr >= 0.0) & (arr <= 1.0)):
        raise ValueError("All elements of vec must be in the interval [0, 1].")

    if figsize is None:
        long_side = max(1.6, arr.size)
        figsize = (1.6, long_side) if vertical else (long_side, 1.6)

    fig, ax = plt.subplots(figsize=figsize)
    # Transparent figure and axes backgrounds, so only the cells are painted.
    fig.patch.set_alpha(0.0)
    ax.set_facecolor((0, 0, 0, 0))

    # Fixed [0, 1] normalization, so colors are comparable across vectors
    # instead of being rescaled to each vector's own min/max.
    ax.imshow(
        arr,
        cmap=cmap,
        norm=Normalize(vmin=0.0, vmax=1.0),
        aspect="equal",
        interpolation="nearest",
    )

    # imshow draws cell (row, col) centered at data coordinates x=col, y=row.
    for (row, col), value in np.ndenumerate(arr):
        ax.text(
            col,
            row,
            fmt.format(value),
            ha="center",
            va="center",
            fontsize=fontsize,
            color=text_color,
        )

    ax.axis("off")
    # Lay out this specific figure rather than pyplot's "current" figure.
    fig.tight_layout(pad=0)

    if out_path is not None:
        fig.savefig(
            out_path,
            dpi=dpi,
            bbox_inches="tight",
            pad_inches=0,
            transparent=True,  # keep the transparent background in the file
        )
    if show:
        plt.show()
    return fig, ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment