"""
Streamlit web interface for perspective warp correction.

Uses core functionality from correct.py.
"""

import pathlib
import tempfile
from io import BytesIO

import cv2
import kornia
import kornia.geometry.transform as kgt
import numpy as np
import streamlit as st
import tifffile
import torch

from correct import (
    DEFAULT_PARAMS,
    RAW_EXTENSIONS,
    crop_image,
    detect_document_corners,
    warp_perspective_transform,  # NOTE: shadowed by the Kornia version defined below.
)

try:
    import rawpy
except ImportError:
    rawpy = None


# -----------------------------------------------------------
# Helper: convert an image to RGB channel order for display/export.
# -----------------------------------------------------------
def _to_display_rgb(image: np.ndarray, color_order: str) -> np.ndarray:
    """Return an RGB-ordered copy of *image*.

    Three-channel images tagged ``"BGR"`` are converted with OpenCV; anything
    else (RGB-tagged input, single-channel input) is returned as a plain copy.
    """
    if image.ndim == 3 and image.shape[2] == 3 and color_order.upper() == "BGR":
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image.copy()


# -----------------------------------------------------------
# Helper: Load image from Streamlit file uploader.
# Supports RAW files (using rawpy) and standard formats.
# For raw files, we call postprocess() with custom parameters
# to get a correctly color-balanced view.
# -----------------------------------------------------------
@st.cache_data(show_spinner=False)
def load_image_streamlit(
    uploaded_file,
    gamma: float = 1.0,
    curve: float = 1.0,
    use_auto_wb: bool = False,
    bright: float = 1.0,
):
    """Load image from Streamlit's UploadedFile object.

    Files whose (lower-cased) extension is in ``RAW_EXTENSIONS`` are decoded
    with rawpy; everything else is decoded with ``cv2.imdecode``.

    Args:
        uploaded_file: the object returned by ``st.file_uploader``.
        gamma: first element of rawpy's ``gamma=(gamma, curve)`` argument.
        curve: second element of rawpy's ``gamma`` argument.
        use_auto_wb: forwarded to ``RawPy.postprocess``.
        bright: forwarded to ``RawPy.postprocess``.

    Returns:
        ``(image, color_order)`` where ``color_order`` is ``"RGB"`` for RAW
        input and ``"BGR"`` for OpenCV-decoded input, or ``(None, None)`` after
        showing an error when the file cannot be loaded.
    """
    extension = pathlib.Path(uploaded_file.name).suffix.lower()
    # getvalue() returns the whole upload regardless of the stream position;
    # read() would only return the bytes after the current position.
    data = uploaded_file.getvalue()

    if extension in RAW_EXTENSIONS:
        if rawpy is None:
            st.error(
                "rawpy module not found. Please install it via 'pip install rawpy'"
            )
            return None, None

        # Write to temp file since rawpy needs file access. The file is closed
        # before rawpy opens it and is always removed afterwards.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=extension)
        tmp_path = pathlib.Path(tmp_file.name)
        try:
            with tmp_file:
                tmp_file.write(data)
            with rawpy.imread(str(tmp_path)) as raw:
                image = raw.postprocess(
                    gamma=(gamma, curve),
                    use_auto_wb=use_auto_wb,
                    bright=bright,
                    output_color=rawpy.ColorSpace.sRGB,
                )
        finally:
            tmp_path.unlink(missing_ok=True)
        return image, "RGB"

    image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        st.error("Error loading image.")
        return None, None
    return image, "BGR"


# -----------------------------------------------------------
# Helper: Order 4 points as [top-left, top-right, bottom-right, bottom-left]
# -----------------------------------------------------------
def order_points(pts: np.ndarray) -> np.ndarray:
    """Order four (x, y) points as [top-left, top-right, bottom-right, bottom-left].

    Top-left has the smallest ``x + y`` and bottom-right the largest;
    top-right has the smallest ``y - x`` and bottom-left the largest.

    Args:
        pts: array-like of four points, shaped (4, 2) or in OpenCV's
            (4, 1, 2) contour layout.

    Returns:
        A (4, 2) float32 array in the order described above.
    """
    pts = np.asarray(pts, dtype="float32").reshape(-1, 2)
    sums = pts.sum(axis=1)
    diffs = pts[:, 1] - pts[:, 0]  # y - x
    return np.array(
        [
            pts[np.argmin(sums)],
            pts[np.argmin(diffs)],
            pts[np.argmax(sums)],
            pts[np.argmax(diffs)],
        ],
        dtype="float32",
    )


def _find_quadrilateral(binary: np.ndarray, approx_factor: float):
    """Return ordered corners of the largest-area contour that approximates to
    exactly 4 vertices in *binary*, or ``None`` if no such contour exists.

    ``approx_factor`` is multiplied by each contour's perimeter to obtain the
    ``cv2.approxPolyDP`` epsilon.
    """
    cnts, _ = cv2.findContours(
        binary.copy(), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
    )
    for contour in sorted(cnts, key=cv2.contourArea, reverse=True):
        peri = cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, approx_factor * peri, True)
        if len(approx) == 4:
            return order_points(approx.reshape(4, 2))
    return None


# -----------------------------------------------------------
# Detect document/frame corners using either Canny or Adaptive thresholding.
#
# Returns:
#   - corners: a 4x2 array if detected (or None)
#   - debug_info: a dictionary of intermediate images and info.
#
# This version downsizes the grayscale for processing and then scales
# the detected corners back up.
# -----------------------------------------------------------
@st.cache_data(show_spinner=False)
def detect_document_corners_custom(image, color_order, detection_method, params):
    """Detect the four corners of a document/frame in *image*.

    Args:
        image: 2-D grayscale or 3-channel image.
        color_order: ``"BGR"`` or ``"RGB"`` (case-insensitive) for 3-channel input.
        detection_method: ``"Adaptive"``, ``"Canny"`` or ``"Auto"``
            (case-insensitive). ``"Auto"`` tries adaptive thresholding first and
            falls back to Canny when no quadrilateral is found.
        params: dict of tuning values (``downscale_factor``,
            ``blur_kernel_size``, ``adaptive_*``, ``canny_*``); missing keys use
            the defaults below.

    Returns:
        ``(corners, debug_info)``: ``corners`` is a (4, 2) float32 array in
        full-resolution pixel coordinates ordered TL, TR, BR, BL, or ``None``;
        ``debug_info`` maps names to intermediate images plus ``"Method Used"``.
    """
    debug_info = {}
    method = str(detection_method).strip().lower()

    # Convert to grayscale.
    if image.ndim == 3 and image.shape[2] == 3:
        if color_order.upper() == "BGR":
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image.copy()
    debug_info["Grayscale"] = gray

    # Downscale for processing if required (never below 1 px per side).
    full_h, full_w = gray.shape[:2]
    downscale_factor = params.get("downscale_factor", 1)
    if downscale_factor > 1:
        gray = cv2.resize(
            gray,
            (max(1, full_w // downscale_factor), max(1, full_h // downscale_factor)),
        )
    # Use the actual resize ratio: integer division makes it differ slightly
    # from downscale_factor when the size is not a multiple of it.
    scale = np.array(
        [full_w / gray.shape[1], full_h / gray.shape[0]], dtype="float32"
    )

    # Blur the image; GaussianBlur requires an odd kernel size.
    blur_kernel = params.get("blur_kernel_size", 5)
    if blur_kernel % 2 == 0:
        blur_kernel += 1
    blurred = cv2.GaussianBlur(gray, (blur_kernel, blur_kernel), 0)
    debug_info["Blurred"] = blurred

    corners = None
    method_used = None

    # --- Adaptive threshold approach (default and primary) ---
    if method in ("adaptive", "auto"):
        # adaptiveThreshold requires an odd block size of at least 3.
        block_size = max(3, params.get("adaptive_block_size", 11))
        if block_size % 2 == 0:
            block_size += 1
        thresh = cv2.adaptiveThreshold(
            blurred,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,
            params.get("adaptive_C", 2),
        )
        debug_info["Adaptive Threshold"] = thresh
        corners = _find_quadrilateral(thresh, params.get("adaptive_approx", 0.02))
        if corners is not None:
            method_used = "Adaptive"

    # --- Canny-based approach (when forced, or as the "Auto" fallback) ---
    if corners is None and method in ("canny", "auto"):
        edges = cv2.Canny(
            blurred, params.get("canny_low", 50), params.get("canny_high", 150)
        )
        debug_info["Canny Edges"] = edges
        corners = _find_quadrilateral(edges, params.get("canny_approx", 0.1))
        if corners is not None:
            method_used = "Canny"

    # Draw the detected corners (in full-resolution coordinates) on a copy.
    corners_overlay = image.copy()
    if corners is not None:
        corners = corners * scale
        for x, y in corners:
            cv2.circle(corners_overlay, (int(x), int(y)), 10, (0, 255, 0), -1)
    debug_info["Detected Corners Overlay"] = corners_overlay
    debug_info["Method Used"] = method_used

    return corners, debug_info


# -----------------------------------------------------------
# Warp perspective using Kornia.
# -----------------------------------------------------------
@st.cache_data(show_spinner=False)
def warp_perspective_transform(image, src_pts):
    """Map the quadrilateral *src_pts* of *image* onto an upright rectangle.

    The output width/height are the longer of each pair of opposite edges.

    Args:
        image: 2-D or (H, W, C) image array.
        src_pts: four (x, y) corner points, in any order.

    Returns:
        The warped image as a uint8 (height, width, C) array (C = 1 for 2-D
        input), or ``None`` when the output width or height would be below
        2 px, since the destination corners would then be collinear.
    """
    src = order_points(src_pts)
    tl, tr, br, bl = src

    max_width = int(max(np.linalg.norm(br - bl), np.linalg.norm(tr - tl)))
    max_height = int(max(np.linalg.norm(tr - br), np.linalg.norm(tl - bl)))
    if max_width < 2 or max_height < 2:
        return None

    dst = np.array(
        [
            [0, 0],
            [max_width - 1, 0],
            [max_width - 1, max_height - 1],
            [0, max_height - 1],
        ],
        dtype="float32",
    )

    # Kornia works on batched tensors: points (1, 4, 2), images (1, C, H, W).
    M = kgt.get_perspective_transform(
        torch.from_numpy(src).unsqueeze(0), torch.from_numpy(dst).unsqueeze(0)
    )

    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()

    # dsize is (height, width) for Kornia.
    warped_tensor = kgt.warp_perspective(
        image_tensor, M, dsize=(max_height, max_width)
    )
    warped = warped_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return np.clip(warped, 0, 255).astype(np.uint8)


# -----------------------------------------------------------
# UI helpers.
# -----------------------------------------------------------
_DEBUG_IMAGE_KEYS = (
    "Grayscale",
    "Blurred",
    "Canny Edges",
    "Adaptive Threshold",
    "Detected Corners Overlay",
)


def _detection_params_from_sidebar() -> dict:
    """Render the detection-parameter sidebar widgets and return their values."""
    st.sidebar.header("Processing Parameters")
    params = {}
    params["downscale_factor"] = st.sidebar.number_input(
        "Downscale Factor (for processing)",
        min_value=1,
        value=DEFAULT_PARAMS["downscale_factor"],
        step=1,
        help="Downscale factor to speed up processing on very high-res images.",
    )
    params["blur_kernel_size"] = st.sidebar.slider(
        "Gaussian Blur Kernel Size (odd)",
        3,
        21,
        DEFAULT_PARAMS["blur_kernel_size"],
        step=2,
    )
    params["canny_low"] = st.sidebar.slider(
        "Canny Low Threshold", 0, 255, DEFAULT_PARAMS["canny_low"]
    )
    params["canny_high"] = st.sidebar.slider(
        "Canny High Threshold", 0, 255, DEFAULT_PARAMS["canny_high"]
    )
    params["canny_approx"] = st.sidebar.slider(
        "Contour Approx Factor (Canny)",
        0.01,
        0.2,
        DEFAULT_PARAMS["canny_approx"],
        step=0.01,
    )
    params["adaptive_block_size"] = st.sidebar.slider(
        "Adaptive Threshold Block Size (odd)",
        3,
        25,
        DEFAULT_PARAMS["adaptive_block_size"],
        step=2,
    )
    params["adaptive_C"] = st.sidebar.slider(
        "Adaptive Threshold Constant", 0, 10, DEFAULT_PARAMS["adaptive_C"]
    )
    params["adaptive_approx"] = st.sidebar.slider(
        "Contour Approx Factor (Adaptive)",
        0.01,
        0.2,
        DEFAULT_PARAMS["adaptive_approx"],
        step=0.01,
    )
    return params


def _show_debug_images(debug_info: dict, color_order: str) -> None:
    """Display whichever intermediate images in *debug_info* are present."""
    for key in _DEBUG_IMAGE_KEYS:
        img = debug_info.get(key)
        if img is None:
            continue
        st.image(
            _to_display_rgb(img, color_order), caption=key, use_container_width=True
        )


def main():
    """Main Streamlit app: upload, load/crop, detect corners, warp, download."""
    st.title("Perspective Warp Debugger")

    # File upload
    uploaded_file = st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png", "nef", "cr2", "arw", "raf", "raw", "dng", "rw2"],
    )
    if not uploaded_file:
        return

    # Load image. Gamma/curve/white balance/brightness are only consumed by the
    # RAW branch of load_image_streamlit.
    with st.sidebar.expander("Image Processing Parameters", expanded=True):
        gamma = st.slider(
            "Gamma",
            0.1,
            5.0,
            DEFAULT_PARAMS["gamma"],
            0.1,
            help="Gamma correction for the image.",
        )
        curve = st.slider(
            "Curve",
            0.1,
            5.0,
            DEFAULT_PARAMS["curve"],
            0.1,
            help="Curve correction for the image.",
        )
        use_auto_wb = st.checkbox(
            "Use Auto White Balance",
            value=DEFAULT_PARAMS["use_auto_wb"],
            help="Use auto white balance for the image.",
        )
        bright = st.slider("Brightness", 0.1, 5.0, DEFAULT_PARAMS["bright"], 0.1)

        st.markdown("---")
        st.markdown("### Crop Settings")

        # Use columns for a more compact layout
        col1, col2 = st.columns(2)
        with col1:
            crop_top = st.slider(
                "Crop Top",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_top"],
                0.01,
                help="Percentage to crop from top edge",
            )
            crop_bottom = st.slider(
                "Crop Bottom",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_bottom"],
                0.01,
                help="Percentage to crop from bottom edge",
            )
        with col2:
            crop_left = st.slider(
                "Crop Left",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_left"],
                0.01,
                help="Percentage to crop from left edge",
            )
            crop_right = st.slider(
                "Crop Right",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_right"],
                0.01,
                help="Percentage to crop from right edge",
            )

    image, color_order = load_image_streamlit(
        uploaded_file, gamma=gamma, curve=curve, use_auto_wb=use_auto_wb, bright=bright
    )
    if image is None:
        return

    # Apply cropping
    image = crop_image(
        image,
        crop_top=crop_top,
        crop_bottom=crop_bottom,
        crop_left=crop_left,
        crop_right=crop_right,
    )

    # Display original
    st.image(
        _to_display_rgb(image, color_order),
        caption="Original Image (After Cropping)",
        use_container_width=True,
    )

    # Parameter controls
    params = _detection_params_from_sidebar()
    detection_method = st.sidebar.radio(
        "Detection Method", ["Adaptive", "Canny", "Auto"], index=0
    ).lower()

    # Process image (detection is delegated to correct.detect_document_corners).
    corners, debug_info = detect_document_corners(
        image, color_order, method=detection_method, params=params
    )

    # Show results
    method_used = debug_info.get("Method Used")
    st.write(f"Detection method used: **{method_used}**")
    if corners is None:
        st.error("Could not detect a 4-corner contour. Try adjusting parameters.")
    else:
        st.success("Detected corners!")

    # Show debug images even when detection failed: that is when they are
    # needed to tune the parameters.
    _show_debug_images(debug_info, color_order)
    if corners is None:
        return

    # Apply warp
    warped = warp_perspective_transform(image, corners)
    if warped is None:
        st.error("Warping failed due to invalid dimensions.")
        return

    # Show warped image
    warped_rgb = _to_display_rgb(warped, color_order)
    st.image(
        warped_rgb,
        caption="Warped (Perspective Corrected) Image",
        use_container_width=True,
    )

    # Save/download functionality. Write the RGB-ordered array: TIFF stores
    # 3-sample data as RGB, so writing OpenCV's BGR array would swap red/blue.
    base_name = pathlib.Path(uploaded_file.name).stem
    output_filename = f"{base_name}.tiff"
    buf = BytesIO()
    if warped_rgb.ndim == 3 and warped_rgb.shape[2] == 3:
        tifffile.imwrite(buf, warped_rgb, photometric="rgb")
    else:
        tifffile.imwrite(buf, warped_rgb)
    buf.seek(0)
    st.sidebar.download_button(
        label="Save Output Image",
        data=buf,
        file_name=output_filename,
        mime="image/tiff",
        key="download",
    )


if __name__ == "__main__":
    main()