#!/usr/bin/env python3
r"""
Streamlit app for segmenting insects using SAM (Segment Anything Model).

Expects the SAM ViT-H checkpoint ``sam_vit_h_4b8939.pth`` in the working
directory. Run with:

uv run \
    --with streamlit \
    --with segment_anything \
    --with opencv-python-headless \
    --with torch \
    --with matplotlib \
    streamlit run sam_segment_st.py
"""

import json

import cv2
import numpy as np
import streamlit as st
import torch
from matplotlib import pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Fixed SamAutomaticMaskGenerator settings. The IoU and stability thresholds
# are not listed here because they are chosen interactively in the UI.
_GENERATOR_KWARGS = dict(
    points_per_side=32,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Minimum area in pixels
)

# Opacity of each mask's colour fill in the segmentation panel.
_MASK_ALPHA = 0.35


@st.cache_resource
def load_sam_model():
    """Load the SAM ViT-H model once per Streamlit server process.

    Returns:
        Tuple ``(sam, device)``: the model moved to ``device``, and the device
        string.

    Raises:
        FileNotFoundError: presumably raised by the registry builder when the
            checkpoint file is missing -- main() relies on this to show
            download instructions.
    """
    model_type = "vit_h"  # Using the highest quality model
    checkpoint = "sam_vit_h_4b8939.pth"

    # Force CPU for now due to MPS float64 issues
    device = "cpu"
    st.info(f"Using device: {device}")

    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device=device)
    return sam, device


def _build_mask_generator(sam, pred_iou_thresh=0.86, stability_score_thresh=0.92):
    """Create a SamAutomaticMaskGenerator with the app's fixed settings."""
    return SamAutomaticMaskGenerator(
        model=sam,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        **_GENERATOR_KWARGS,
    )


def _to_uint8(image):
    """Return ``image`` as uint8, scaling non-uint8 input assumed to be in [0, 1]."""
    if image.dtype == np.uint8:
        return image
    # Clip before casting so out-of-range values saturate instead of wrapping.
    return np.clip(image * 255, 0, 255).astype(np.uint8)


def _run_generator(image, mask_generator):
    """Run the automatic mask generator on ``image`` with autograd disabled."""
    with torch.inference_mode():
        return mask_generator.generate(_to_uint8(image))


def _filter_masks_by_area(masks, image_shape, min_area, max_area):
    """Keep masks whose area is within [min_area, max_area] of the image area.

    Args:
        masks: SAM mask dicts with an ``"area"`` entry in pixels.
        image_shape: Shape of the image the masks were generated from.
        min_area: Minimum area as a fraction of the image area.
        max_area: Maximum area as a fraction of the image area.

    Returns:
        The surviving masks, sorted by area, largest first.
    """
    image_area = image_shape[0] * image_shape[1]
    min_area_pixels = image_area * min_area
    max_area_pixels = image_area * max_area
    kept = [m for m in masks if min_area_pixels <= m["area"] <= max_area_pixels]
    return sorted(kept, key=lambda m: m["area"], reverse=True)


def process_image(image, mask_generator, min_area=0.0001, max_area=0.1):
    """
    Generate segments using SAM's automatic mask generator.

    Args:
        image: RGB image array (uint8, or float assumed to be in [0, 1])
        mask_generator: SAM automatic mask generator
        min_area: Minimum area as fraction of image area
        max_area: Maximum area as fraction of image area

    Returns:
        Mask dicts within the area bounds, sorted by area, largest first.
    """
    image = _to_uint8(image)
    masks = _run_generator(image, mask_generator)
    return _filter_masks_by_area(masks, image.shape, min_area, max_area)


@st.cache_data(show_spinner=False, max_entries=4)
def _generate_masks_cached(image, _sam, pred_iou_thresh, stability_score_thresh):
    """Run SAM once per (image, thresholds) combination and cache the raw masks.

    Streamlit reruns the whole script on every widget interaction. Without this
    cache, moving an area slider or clicking the download button would repeat
    the expensive ViT-H pass. ``_sam`` starts with an underscore so Streamlit
    excludes it from the cache key; there is only one model per process.
    ``max_entries`` is kept small because each entry holds full-resolution
    masks.
    """
    mask_generator = _build_mask_generator(
        _sam,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
    )
    return _run_generator(image, mask_generator)


def plot_results(image, masks, seed=0):
    """Plot the original image next to the segmentation overlay.

    Args:
        image: RGB image array, as passed to ``imshow``.
        masks: SAM mask dicts with ``"segmentation"`` (boolean array matching
            the image's height and width) and ``"bbox"`` ([x, y, w, h]).
        seed: Seed for the per-mask colours, so colours stay stable across
            Streamlit reruns. Pass ``None`` to get fresh random colours.

    Returns:
        The matplotlib Figure. The caller should close it when done.
    """
    fig, (ax_orig, ax_seg) = plt.subplots(1, 2, figsize=(15, 7))

    # Original image
    ax_orig.imshow(image)
    ax_orig.set_title("Original")
    ax_orig.axis("off")

    # Segmentation
    ax_seg.imshow(image)

    # Composite all masks into a single RGBA layer whose alpha is non-zero only
    # on mask pixels. An RGB overlay drawn with a global alpha would also tint
    # every non-mask pixel black, darkening the whole image once per mask.
    rng = np.random.default_rng(seed)
    overlay = np.zeros((*image.shape[:2], 4), dtype=np.float32)
    for mask in masks:
        color = rng.random(3, dtype=np.float32)
        region = mask["segmentation"]
        overlay[region, :3] = color
        overlay[region, 3] = _MASK_ALPHA

        # Outline the mask's bounding box in the same colour.
        x, y, w, h = mask["bbox"]
        ax_seg.add_patch(
            plt.Rectangle(
                (x, y),
                w,
                h,
                linewidth=1,
                edgecolor=color,
                facecolor="none",
            )
        )
    ax_seg.imshow(overlay)

    ax_seg.set_title(f"Segmentation ({len(masks)} segments)")
    ax_seg.axis("off")

    plt.tight_layout()
    return fig


def main():
    """Streamlit entry point: upload an image, tune parameters, show masks."""
    st.title("Insect Segmentation with SAM")

    # Load SAM model (cached across reruns by st.cache_resource)
    try:
        sam, _device = load_sam_model()
    except FileNotFoundError:
        st.error(
            """
        Please download the SAM checkpoint file:
        wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
        """
        )
        return

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the full content regardless of the stream position.
    file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Parameter controls
    col1, col2 = st.columns(2)
    with col1:
        min_area_pct = st.slider(
            "Minimum Area (%)",
            min_value=0.01,
            max_value=1.0,
            value=0.05,
            step=0.01,
            help="Minimum segment area as percentage of image area",
        )
        pred_iou_thresh = st.slider(
            "Prediction IoU Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.86,
            help="Higher values = more selective segmentation",
        )
    with col2:
        max_area_pct = st.slider(
            "Maximum Area (%)",
            min_value=1.0,
            max_value=20.0,
            value=5.0,
            step=0.1,
            help="Maximum segment area as percentage of image area",
        )
        stability_score_thresh = st.slider(
            "Stability Score Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.92,
            help="Higher values = more stable segments",
        )

    # Process image. SAM only re-runs when the image or a threshold changes;
    # area filtering is cheap and is applied on every rerun.
    with st.spinner("Processing image with SAM..."):
        raw_masks = _generate_masks_cached(
            image, sam, pred_iou_thresh, stability_score_thresh
        )
    masks = _filter_masks_by_area(
        raw_masks,
        image.shape,
        min_area=min_area_pct / 100,
        max_area=max_area_pct / 100,
    )

    # Plot results, then release the figure so reruns don't accumulate them.
    fig = plot_results(image, masks)
    st.pyplot(fig)
    plt.close(fig)

    # Add download button for masks. Serialising full-resolution masks is
    # costly, so it is only done on request.
    if st.button("Download Masks as JSON"):
        # Convert masks to JSON-serializable format
        masks_json = [
            {
                "segmentation": mask["segmentation"].tolist(),
                "area": float(mask["area"]),
                "bbox": [float(x) for x in mask["bbox"]],
            }
            for mask in masks
        ]
        st.download_button(
            "Download JSON",
            data=json.dumps(masks_json),
            file_name="masks.json",
            mime="application/json",
        )


if __name__ == "__main__":
    main()