Skip to content

Instantly share code, notes, and snippets.

@maulikmadhavi
Created August 26, 2024 18:30
Show Gist options
  • Select an option

  • Save maulikmadhavi/41a50155fe6d18a56e01dc05acbba9d4 to your computer and use it in GitHub Desktop.

Select an option

Save maulikmadhavi/41a50155fe6d18a56e01dc05acbba9d4 to your computer and use it in GitHub Desktop.
siglip_matching
import torch
from transformers import AutoProcessor, SiglipModel
class SigLipSimilarity:
    """Image/text similarity scoring with SigLIP (google/siglip-so400m-patch14-384).

    Encodes images and texts into embeddings and scores every (text, image)
    pair with SigLIP's sigmoid head, yielding independent pairwise
    probabilities (not a softmax over candidates).
    """

    def __init__(self):
        # Remember the device/dtype so the encode helpers can move inputs
        # accordingly instead of hard-coding .cuda()/.half() — the original
        # crashed on CPU-only hosts despite selecting "cpu" here.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 + flash_attention_2 require CUDA; fall back to fp32 + SDPA on CPU.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        attn_impl = "flash_attention_2" if self.device == "cuda" else "sdpa"
        self.model = SiglipModel.from_pretrained(
            "google/siglip-so400m-patch14-384",
            device_map=self.device,
            attn_implementation=attn_impl,
            torch_dtype=self.dtype,
        )
        self.model.eval()
        # Processors run on CPU and are dtype-agnostic; the original passed
        # device/torch_dtype kwargs that the processor does not use.
        self.processor = AutoProcessor.from_pretrained(
            "google/siglip-so400m-patch14-384"
        )

    @torch.no_grad()
    def run_image(self, image):
        """Return image embeddings, shape (num_images, dim)."""
        inputs = self.processor(
            images=image, text=None, padding="max_length", return_tensors="pt"
        )
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        return self.model.get_image_features(pixel_values=pixel_values)

    @torch.no_grad()
    def run_text(self, text):
        """Return text embeddings, shape (num_texts, dim)."""
        inputs = self.processor(
            images=None, text=text, padding="max_length", return_tensors="pt"
        )
        # input_ids stay integer; only move them to the model's device.
        input_ids = inputs["input_ids"].to(self.device)
        return self.model.get_text_features(input_ids=input_ids)

    @torch.no_grad()
    def get_similarity(
        self, image_embeds: torch.Tensor, text_embeds: torch.Tensor
    ) -> torch.Tensor:
        """Return sigmoid match probabilities, shape (num_texts, num_images).

        Args:
            image_embeds: (num_images, dim) embeddings from ``run_image``.
            text_embeds: (num_texts, dim) embeddings from ``run_text``.
        """
        # L2-normalize so the dot product below is cosine similarity.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
        # SigLIP head: scaled cosine similarity plus a learned bias.  In HF
        # terminology this (texts x images) matrix is logits_per_text; the
        # original mislabeled it logits_per_image and computed an unused
        # transpose.
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
            * self.model.logit_scale.exp()
            + self.model.logit_bias
        )
        return torch.sigmoid(logits_per_text)
#url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
#texts = ["a photo of 2 cats", "a photo of 2 dogs"]
#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment