import xml.etree.ElementTree as ET
from pathlib import Path
import matplotlib.pyplot as plt
import bisect

# Config files to compare, one tuple per curve:
#   (path to the XML resource file, legend label, scale flag)
# When the scale flag is true, main() additionally runs adjust_curve() on the
# loaded arrays, plots the result as "<label> (scaled)" and prints it as XML.
XML_CONFIGS = [
    ("overlay/Frameworks/res/values/config.xml", "unicorn", False),
    ("config.xml", "unicorn - old", False),
    # Disabled entries. NOTE: these are 2-tuples, but main() unpacks three
    # fields per entry, so a scale flag must be added before re-enabling them.
    # ("../cupid/overlay/Frameworks/res/values/config.xml", "cupid"),
    # ("pantah/panther/overlay/frameworks/base/core/res/res/values/config.xml", "pantah"),
    # ("/home/arian/android/vendor/pixel/blazer/blazer-bp4a.260105.004.e1/framework-res__blazer__auto_generated_rro_vendor/res/values/arrays.xml", "blazer"),
]

# Values of the `name` attribute of the two XML arrays that load_config() reads.
LEVELS_NAME = "config_autoBrightnessLevels"
NITS_NAME = "config_autoBrightnessDisplayValuesNits"


# Default upper bound for adjust_curve(): lux levels above it are dropped.
MAX_LUX = 30000.0
# NOTE(review): not referenced by any code visible in this file; kept in case
# something else relies on it -- confirm before removing.
SCALE_FROM_LUX = 15000.0
# Default lux value up to which adjust_curve() leaves the curve untouched.
BASE_LUX = 15000.0


def adjust_curve(lux_levels, nits_values,
                 max_lux=MAX_LUX, base_lux=BASE_LUX):
    """
    Truncate an auto-brightness curve at max_lux and stretch its upper part.

    The curve is a step function: nits_values[0] applies below lux_levels[0],
    nits_values[i] applies from lux_levels[i - 1] up to (excluding)
    lux_levels[i], and the final nits value applies from the last lux level
    upwards.

    Steps:
        1. Drop all lux levels > max_lux (and their segments).
        2. Let base_nits be the brightness of the segment containing base_lux.
           Segments after that one have their *extra* brightness above
           base_nits multiplied by a common factor, chosen so that the last
           remaining segment reaches the maximum of the original nits_values.
           Scaled values are clamped to [base_nits, max(nits_values)].

    If the last remaining segment is not brighter than the base segment,
    only the truncation is applied.

    Args:
        lux_levels: lux thresholds, sorted ascending (required by bisect).
        nits_values: brightness per segment; must contain exactly one more
            entry than lux_levels.
        max_lux: lux levels above this value are removed.
        base_lux: the segment containing this lux value, and every segment
            before it, is left unchanged.

    Returns:
        A (lux_levels, nits_values) tuple of new lists; the inputs are not
        modified.

    Raises:
        ValueError: if len(nits_values) != len(lux_levels) + 1.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(nits_values) != len(lux_levels) + 1:
        raise ValueError(
            f"expected len(nits_values) == len(lux_levels) + 1, "
            f"got {len(nits_values)} and {len(lux_levels)}"
        )

    # Global max brightness before any changes (the target for the last segment).
    max_nits_global = max(nits_values)

    # --- 1) Truncate all lux levels above max_lux ---
    # cut_idx = number of lux levels <= max_lux
    cut_idx = bisect.bisect_right(lux_levels, max_lux)
    # list(...) copies, so tuples (or other sequences) are accepted as input
    # and the caller's data is never mutated.
    lux_trunc = list(lux_levels[:cut_idx])
    nits_trunc = list(nits_values[:cut_idx + 1])   # +1: one more segment than levels

    # --- 2) Find the segment that contains base_lux ---
    # nits_trunc[i] is used for lux in:
    #   [ 0, lux_trunc[0] )            -> i = 0
    #   [ lux_trunc[0], lux_trunc[1] ) -> i = 1
    #   ...
    base_seg = bisect.bisect_right(lux_trunc, base_lux)
    base_nits = nits_trunc[base_seg]

    # The last segment is the open-ended one (">= last remaining lux level").
    tail_delta = nits_trunc[-1] - base_nits
    if tail_delta <= 0:
        # Nothing to scale above base (already flat or decreasing)
        # -> just return with truncation applied.
        return lux_trunc, nits_trunc

    # --- 3) Scale only the *extra* brightness above base_nits ---
    # Here max(nits_values) >= nits_trunc[-1] > base_nits, so the target is
    # strictly above base_nits and no separate "target <= base" guard is needed.
    desired_last = max_nits_global
    scale_factor = (desired_last - base_nits) / tail_delta

    # Everything up to and including base_seg stays unchanged.
    unchanged = nits_trunc[:base_seg + 1]
    # tail_delta > 0 implies base_seg is not the last segment, so this is non-empty.
    scaled = [
        min(max(base_nits + (value - base_nits) * scale_factor, base_nits),
            desired_last)
        for value in nits_trunc[base_seg + 1:]
    ]
    # Make sure the last segment hits the target exactly (no rounding drift).
    scaled[-1] = desired_last

    return lux_trunc, unchanged + scaled

def get_array(root, name, as_type=float):
    """Return the converted <item> values of the array named *name*.

    The array is searched anywhere below *root*, first as an
    <integer-array> element and then as an <array> element. Items whose
    text is missing or blank after stripping are skipped; every other item
    text is passed through *as_type*.

    Raises:
        ValueError: if neither element kind with that name exists.
    """
    candidates = (
        f".//integer-array[@name='{name}']",
        f".//array[@name='{name}']",
    )
    for xpath in candidates:
        node = root.find(xpath)
        if node is not None:
            break
    else:
        raise ValueError(f"Array with name '{name}' not found")

    stripped_texts = ((child.text or "").strip() for child in node.findall("item"))
    return [as_type(text) for text in stripped_texts if text]


def load_config(path):
    """Parse the XML file at *path* and return (lux_levels, nits_values).

    Both arrays are looked up by the names in LEVELS_NAME and NITS_NAME and
    converted to lists of floats via get_array().
    """
    document_root = ET.parse(path).getroot()
    return (
        get_array(document_root, LEVELS_NAME, as_type=float),
        get_array(document_root, NITS_NAME, as_type=float),
    )


def plot_auto_brightness_curve(lux_levels, nits_values, label, ax):
    """Plot one auto-brightness curve as a normalized step function onto ax.

    The x positions are 0 followed by the lux levels; the y values are the
    nits values divided by the *last* nits value, so the final segment is
    drawn at 1.0.

    Args:
        lux_levels: lux thresholds of the curve (may be empty).
        nits_values: brightness per segment; must contain exactly one more
            entry than lux_levels.
        label: legend label, also used as prefix of warning messages.
        ax: object providing a matplotlib-style ``step(x, y, where=, label=)``.

    A curve with inconsistent array lengths, or whose last nits value is 0
    (normalization impossible), is reported on stdout and not plotted.
    """
    if len(nits_values) != len(lux_levels) + 1:
        print(f"[{label}] WARNING: len(nits)={len(nits_values)} "
              f"!= len(lux)+1={len(lux_levels)+1}; skipping this config.")
        return

    # Non-empty here: the length check guarantees at least one nits value.
    reference_nits = nits_values[-1]
    if reference_nits == 0:
        print(f"[{label}] WARNING: last nits value is 0, cannot normalize; "
              f"skipping this config.")
        return

    # Unpacking (instead of list concatenation) also accepts tuples, and an
    # empty lux_levels simply yields the single bound 0.0.
    lux_bounds = [0.0, *lux_levels]
    y_step = [n / reference_nits for n in nits_values]

    ax.step(lux_bounds, y_step, where="post", label=label)


def _print_resource_arrays(lux_levels, nits_values):
    """Print both arrays to stdout as XML resource snippets."""
    print(f'<integer-array name="{LEVELS_NAME}">')
    for v in lux_levels:
        # Lux levels are emitted as integers.
        print(f'    <item>{int(round(v))}</item>')
    print('</integer-array>')

    print(f'<array name="{NITS_NAME}">')
    for v in nits_values:
        print(f'    <item>{v:.1f}</item>')
    print('</array>')


def main():
    """Plot every curve listed in XML_CONFIGS into one figure and show it.

    For each entry the XML file is loaded and plotted. Entries whose scale
    flag is set are additionally passed through adjust_curve(), plotted
    again as "<label> (scaled)" and printed as XML. Entries that cannot be
    used (file missing, file unparsable, array missing, inconsistent array
    lengths) are reported on stdout and skipped, so one bad entry does not
    abort the whole comparison.
    """
    _fig, ax = plt.subplots()

    for path_str, label, scale in XML_CONFIGS:
        path = Path(path_str)
        if not path.is_file():
            print(f"[{label}] File not found: {path}")
            continue

        # ET.ParseError: malformed XML; ValueError: array missing or an item
        # that cannot be converted to float.
        try:
            lux_levels, nits_values = load_config(path)
        except (ET.ParseError, ValueError) as err:
            print(f"[{label}] Could not load {path}: {err}")
            continue

        print(f"[{label}] lux_levels: {len(lux_levels)} entries")
        print(f"[{label}] nits_values: {len(nits_values)} entries")
        plot_auto_brightness_curve(lux_levels, nits_values, label, ax)
        if not scale:
            continue

        # plot_auto_brightness_curve() has already warned about this case;
        # adjust_curve() rejects such input, so skip instead of crashing.
        if len(nits_values) != len(lux_levels) + 1:
            print(f"[{label}] Array lengths are inconsistent; not scaling.")
            continue

        lux_levels, nits_values = adjust_curve(lux_levels, nits_values)
        plot_auto_brightness_curve(lux_levels, nits_values, label + " (scaled)", ax)
        _print_resource_arrays(lux_levels, nits_values)

    ax.set_xlabel("Ambient light (lux)")
    ax.set_ylabel("Display brightness")
    ax.set_title("Auto-brightness curve comparison")
    ax.grid(True)
    ax.legend()
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
