import copy import os import json import logging import uuid from urllib.parse import urljoin import httpx from fastapi import FastAPI, Request, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from dotenv import load_dotenv # ======================== # Конфигурация и инициализация # ======================== load_dotenv() LOG_LEVEL = os.getenv("LOG_LEVEL", "info").upper() logging.basicConfig(level=LOG_LEVEL) logger = logging.getLogger(__name__) app = FastAPI() OPENWEBUI_URL = os.getenv("OPENWEBUI_URL", "your_endpoint_here") API_KEY = os.getenv("OPENWEBUI_API_KEY", "your_api_key_here") TIMEOUT = 30.0 ZED_SYSTEM_PROMPT_FILE = os.getenv("ZED_SYSTEM_PROMPT_FILE") ZED_SYSTEM_PROMPT_MODE = os.getenv("ZED_SYSTEM_PROMPT_MODE", "default").lower() EMULATE_TOOLS_CALLING = os.getenv("EMULATE_TOOLS_CALLING", True) START_MARKER = "<|tools▁calls▁start|>" END_MARKER = "<|tools▁calls▁end|>" MARKER_MAX_LEN = max(len(START_MARKER), len(END_MARKER)) # ======================== # Загрузка и обработка системного промта # ======================== def load_system_prompt() -> str | None: if ZED_SYSTEM_PROMPT_FILE and os.path.exists(ZED_SYSTEM_PROMPT_FILE): with open(ZED_SYSTEM_PROMPT_FILE, encoding="utf-8") as f: return f.read().strip() return None def apply_system_prompt_policy(messages: list[dict], mode: str, custom_prompt: str | None) -> list[dict]: if mode == "disable": return [m for m in messages if m.get("role") != "system"] if mode == "replace" and custom_prompt: filtered = [m for m in messages if m.get("role") != "system"] filtered.append({"role": "system", "content": custom_prompt}) return filtered return messages def inject_tools_as_prompt(tools: dict, messages: list[dict]) -> None: if not tools: return tools_text = json.dumps(tools, indent=2, ensure_ascii=False) tools_message = { "role": "system", "content": f"Available tools:\n{tools_text}" } for i, msg in enumerate(messages): if msg.get("role") == "system": messages.insert(i + 1, tools_message) return messages.append(tools_message) # ======================== # Основной endpoint # ======================== @app.post("/v1/chat/completions") async def openai_proxy(request: Request): logger.info(">>> Вызван openai_proxy") body = await request.json() original_body = copy.deepcopy(body) # Системный промт system_prompt = load_system_prompt() body["messages"] = apply_system_prompt_policy(body.get("messages", []), ZED_SYSTEM_PROMPT_MODE, system_prompt) # Интеграция tools в messages if EMULATE_TOOLS_CALLING: tools = body.pop("tools", None) if tools: inject_tools_as_prompt(tools, body.get("messages", [])) logger.info("Инструменты встроены в messages, ключ 'tools' удалён") if body != original_body: logger.info(f"Тело запроса изменено: {json.dumps(body, ensure_ascii=False)}") else: logger.info(f"Тело запроса без изменений: {json.dumps(body, ensure_ascii=False)}") # Извлекаем Authorization из исходного запроса, если есть auth_header = request.headers.get("Authorization", f"Bearer {API_KEY}") headers = { "Authorization": auth_header, "Content-Type": "application/json", "Accept": "text/event-stream" if body.get("stream") else "application/json", } generator = func_calling_event_generator if EMULATE_TOOLS_CALLING else default_event_generator return StreamingResponse(generator(body, headers), media_type="text/event-stream") # ======================== # Прокси для всех /v1/* путей # ======================== @app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy_all(request: Request, path: str): if path == "chat/completions": return await openai_proxy(request) target_url = urljoin(f"{OPENWEBUI_URL}/", path) try: request_body = None if request.method in ["POST", "PUT"]: try: request_body = await request.json() except json.JSONDecodeError: request_body = None async with httpx.AsyncClient(timeout=TIMEOUT) as client: response = await client.request( method=request.method, url=target_url, headers={ "Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json", }, json=request_body, params=dict(request.query_params), ) filtered_headers = { k: v for k, v in response.headers.items() if k.lower() not in ["content-encoding", "content-length", "transfer-encoding", "connection"] } return JSONResponse( content=response.json(), status_code=response.status_code, headers=filtered_headers, ) except httpx.ReadTimeout: logger.error("Таймаут при обращении к Open WebUI") raise HTTPException(status_code=504, detail="Таймаут соединения с Open WebUI") except Exception as e: logger.error(f"Ошибка проксирования: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) # ======================== # Генераторы событий # ======================== async def default_event_generator(body: dict, headers: dict): max_log_chunk = 200 try: async with httpx.AsyncClient(timeout=TIMEOUT) as client: async with client.stream("POST", f"{OPENWEBUI_URL}/api/chat/completions", json=body, headers=headers) as response: if response.status_code != 200: text = await response.aread() logger.error(f"OpenWebUI error: {text.decode()}") yield format_error_event(text.decode()) return async for line in response.aiter_lines(): if not line.strip(): continue if line.startswith("data: "): json_str = line[len("data: "):].strip() try: data = json.loads(json_str) if "sources" in data: snippet = json_str[:max_log_chunk].replace("\n", " ") logger.info(f"Пропущен чанк с 'sources': {snippet}...") continue except json.JSONDecodeError: pass logger.info(line) yield f"{line}\n" except Exception as e: logger.error(f"Ошибка стриминга: {e}") yield format_error_event("Internal server error") async def func_calling_event_generator(body: dict, headers: dict): text_accumulator = [] ignore_rest = False try: async with httpx.AsyncClient(timeout=60) as client: async with client.stream("POST", f"{OPENWEBUI_URL}/api/chat/completions", json=body, headers=headers) as response: if response.status_code != 200: text = await response.aread() logger.error(f"Ошибка от API: {text.decode()}") yield f"data: {{\"error\": \"{text.decode()}\"}}\n\n" return async for line in response.aiter_lines(): if ignore_rest: if line.strip() == "data: [DONE]": yield line + "\n" return continue if not line.startswith("data: "): continue data_part = line[len("data: "):].strip() if not data_part: continue try: data_json = json.loads(data_part) choice = data_json.get("choices", [{}])[0] delta = choice.get("delta", {}) content = delta.get("content", "") if not content: continue except Exception: continue if "sources" in data_json: # Пропускаем системные служебные чанки OpenWEBUI continue # Добавляем текст к аккумулятору text_accumulator.append(content) accumulated = "".join(text_accumulator) start_idx = accumulated.find(START_MARKER) end_idx = accumulated.find(END_MARKER) if start_idx != -1 and end_idx != -1 and end_idx > start_idx: pre_text = accumulated[:start_idx].strip() json_block = accumulated[start_idx + len(START_MARKER):end_idx] if pre_text: yield emit_with_content(data_json, pre_text) first_brace = json_block.find("{") last_brace = json_block.rfind("}") if first_brace != -1 and last_brace != -1: json_str = json_block[first_brace : last_brace + 1] try: parsed = json.loads(json_str) functions = parsed.get("functions", []) if functions: # Отдаем все вызовы функций async for tool_chunk in stream_tool_calls(data_json, functions): logger.info(tool_chunk[:-2]) yield tool_chunk except json.JSONDecodeError: pass logger.info("data: [DONE]\n") yield "data: [DONE]\n\n" return else: # Пока маркеры не найдены — ищем безопасную часть для отдачи accumulated = ''.join(text_accumulator) def find_safe_cutoff(buffer: str, marker: str) -> int: """ Возвращает позицию начала маркера `ABC` (или его префикса `A`, `AB`) в любом месте строки. - Если найден маркер → индекс его начала. - Если найден префикс маркера в конце → его индекс. - Если ничего не найдено → длина буфера. """ # Сначала проверяем полный маркер pos = buffer.find(marker) if pos != -1: return pos # Проверяем префиксы маркера в конце строки for i in range(len(marker) - 1, 0, -1): prefix = marker[:i] if buffer.endswith(prefix): return len(buffer) - len(prefix) return len(buffer) # Маркер не найден safe_idx = find_safe_cutoff(accumulated, START_MARKER) safe_part = accumulated[:safe_idx] remaining_tail = accumulated[safe_idx:] chunk = emit_with_content(data_json, safe_part) logger.info(chunk[:-2]) yield chunk text_accumulator = [remaining_tail] yield "data: [DONE]\n\n" except Exception as e: logger.error(f"Ошибка стриминга: {e}") yield f"data: {{\"error\": \"Internal server error\"}}\n\n" # ======================== # Потоковый вывод функций # ======================== def emit_with_content(base_json: dict, content: str) -> str: base = dict(base_json) base["choices"][0]["delta"] = {"role": "assistant", "content": content} return f"data: {json.dumps(base)}\n\n" async def stream_tool_calls(base_json: dict, functions: list): for i, func in enumerate(functions): fc = func.get("function_call", {}) chunk = { "id": f"{base_json["id"]}", "object": "chat.completion.chunk", "model": f"{base_json["model"]}", "created": f"{base_json["created"]}", "choices": [ { "index": 0, "delta": { "tool_calls": [ { "id": f"call_{i}", "index": i, "type": "function", "function": { "name": fc.get("name", ""), "arguments": fc.get("arguments", "") } } ] }, "finish_reason": "tool_calls" if i == len(functions) - 1 else None, "native_finish_reason": "tool_calls" if i == len(functions) - 1 else None } ] } yield f"data: {json.dumps(chunk)}\n\n" # ======================== # Утилиты # ======================== def format_error_event(message): return f'data: {{"error": "{message}"}}\n\n' def find_safe_cutoff(buffer: str, marker: str) -> int: pos = buffer.find(marker) if pos != -1: return pos for i in range(len(marker) - 1, 0, -1): if buffer.endswith(marker[:i]): return len(buffer) - i return len(buffer) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="127.0.0.1", port=5000, log_level=LOG_LEVEL.lower())