Source code for concepts.vision.fm_match.diff3f.extractor_dino
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : extractor_dino.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 09/16/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
from typing import Union
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms as tfs
# Side length, in pixels, of one square ViT patch. It is used to derive the
# patch-grid size (height and width in patches) from the resized image size.
# NOTE(review): presumably matches the "14" in the `dinov2_vitb14` checkpoint
# name loaded by `init_dino` -- confirm if the checkpoint is ever changed.
patch_size = 14
[docs]
def init_dino(device: str) -> torch.nn.Module:
    """Load the DINOv2 backbone and prepare it for inference.

    The model is obtained through ``torch.hub`` using the ``dinov2_vitb14``
    entry point of the ``facebookresearch/dinov2`` repository.

    Args:
        device: the device to place the model on.

    Returns:
        the DINO model, moved to ``device`` and switched to evaluation mode.
    """
    dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
    dino = dino.to(device)
    # Evaluation mode: the model is only used as a frozen feature extractor.
    dino.eval()
    return dino
[docs]
@torch.no_grad()
def get_dino_features(device: str, dino_model: torch.nn.Module, img: Union[Image.Image, np.ndarray], grid: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Get the DINO features for a given image and grid.

    This function will always resize the image to (518, 518) and then compute its DINO features.
    With a patch size of 14 this results in a feature map of (37, 37).
    Then, we "project" the feature map to the grid using bilinear interpolation
    (``torch.nn.functional.grid_sample``).

    Args:
        device: the device to use. Both plain strings and ``torch.device`` objects work.
        dino_model: the DINO model. It must expose ``get_intermediate_layers``.
        img: the image to extract features from. NumPy arrays are converted to PIL images first;
            any channels beyond the first three (e.g., alpha) are dropped.
        grid: the grid to project the features to, in the sampling-grid layout expected by
            ``F.grid_sample``. It is moved to the device and dtype of the feature map, so callers
            do not need to match them by hand. The output is reshaped with a batch size of 1.
        normalize: whether to L2-normalize the features along the channel dimension.

    Returns:
        the grid-projected DINO features of shape ``(1, dim, N)``, where ``dim`` is the feature
        width of the model and ``N`` is the number of grid locations. The result is float32
        when running on CPU and float16 otherwise.
    """
    if isinstance(img, np.ndarray):
        # torchvision's Resize only accepts PIL images or tensors, so raw arrays must be wrapped first.
        img = Image.fromarray(img)

    to_resized_tensor = tfs.Compose([
        tfs.Resize((518, 518)),
        tfs.ToTensor(),
    ])
    normalize_rgb = tfs.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    # Drop extra channels *before* normalizing: the 3-channel mean/std cannot be applied to a
    # 4-channel (RGBA) tensor, so slicing after Normalize would never be reached for such inputs.
    img = normalize_rgb(to_resized_tensor(img)[:3]).unsqueeze(0).to(device)

    features = dino_model.get_intermediate_layers(img, n=1)[0].half()
    h, w = img.shape[2] // patch_size, img.shape[3] // patch_size
    dim = features.shape[-1]
    # (B, h * w, dim) patch tokens -> (B, dim, h, w) feature map, as required by grid_sample.
    features = features.reshape(-1, h, w, dim).permute(0, 3, 1, 2)

    # Inspect the tensor's actual device instead of comparing `device` with the string 'cpu', so
    # that `torch.device('cpu')` and 'cpu:0' are handled as well.
    if features.device.type == 'cpu':
        # NB(Jiayuan Mao @ 2024/09/16): When using CPU, some operations such as grid_sample are not supported for half precision.
        features = features.to(torch.float32)
    # grid_sample requires the input and the grid to share device and dtype.
    grid = grid.to(device=features.device, dtype=features.dtype)

    # Use the model's own feature width rather than a hard-coded 768.
    features = F.grid_sample(features, grid, align_corners=False).reshape(1, dim, -1)
    if normalize:
        features = F.normalize(features, dim=1)
    return features