#!/usr/bin/env -S uv run
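# Inline script metadata (PEP 723) so `uv run` can resolve dependencies on
# its own; assumes mlx-vlm is the only package this script needs.
# /// script
# dependencies = ["mlx-vlm"]
# ///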

from mlx_vlm import generate, load, utils

# The checkpoint's config reports the "llava" model type; remap it so
# mlx_vlm loads these weights with its Pixtral implementation instead.
utils.MODEL_REMAPPING["llava"] = "pixtral"

# Community-hosted Pixtral 12B weights on Hugging Face.
model_path = "mistral-community/pixtral-12b"

# Jinja2 chat template in Mistral's [INST] ... [/INST] format: text chunks
# are rendered inline, each image becomes an [IMG] placeholder, and an
# optional leading system message is folded into the final user turn.
template = """{%- if messages[0]["role"] == "system" %}
    {%- set system_message = messages[0]["content"] %}
    {%- set loop_messages = messages[1:] %}
{%- else %}
    {%- set loop_messages = messages %}
{%- endif %}

{{- bos_token }}
{%- for message in loop_messages %}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
    {%- endif %}
    {%- if message["role"] == "user" %}
        {%- if loop.last and system_message is defined %}
            {{- "[INST]" + system_message + "\n\n" }}
        {%- else %}
            {{- "[INST]" }}
        {%- endif %}
        {%- if message["content"] is not string %}
            {%- for chunk in message["content"] %}
                {%- if chunk["type"] == "text" %}
                    {{- chunk["content"] }}
                {%- elif chunk["type"] == "image" %}
                    {{- "[IMG]" }}
                {%- else %}
                    {{- raise_exception("Unrecognized content type!") }}
                {%- endif %}
            {%- endfor %}
        {%- else %}
            {{- message["content"] }}
        {%- endif %}
        {{- "[/INST]" }}
    {%- elif message["role"] == "assistant" %}
        {{- message["content"] + eos_token }}
    {%- else %}
        {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
    {%- endif %}
{%- endfor %}"""

# Load the model and processor, passing the chat template through the
# processor config so it overrides whatever the repo ships with.
model, processor = load(model_path, {"chat_template": template})

# Render the prompt string (tokenize=False) for a single user turn pairing
# one image placeholder with a text question.
prompt = processor.tokenizer.apply_chat_template(
    [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "content": "What are these?"},
            ],
        }
    ],
    tokenize=False,
    add_generation_prompt=True,
)
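# With the template above, `prompt` should render to roughly
#   <s>[INST][IMG]What are these?[/INST]
# (assuming the tokenizer's BOS token is <s>).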

# Download the example image (two cats on a couch, from the COCO validation
# set) and run generation; verbose=True prints the response and timing
# stats as it is generated.
output = generate(
    model,
    processor,
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    prompt,
    verbose=True,
)