# Source code for concepts.vision.fm_match.dino.get_feature

import gc
from typing import Optional, Union, List

from tqdm import tqdm
from PIL import Image

import numpy as np
import torch
import torch.nn.functional as F

from concepts.vision.fm_match.dino.extractor_dino import ViTExtractor


def resize(img: Union[np.ndarray, Image.Image], target_res, resize=True, to_pil=True, edge=False):
    """Fit an image into a ``target_res`` x ``target_res`` square, keeping its aspect ratio.

    The longer side is scaled to ``target_res`` with LANCZOS resampling; the
    shorter side is centred and padded, either with zeros or by replicating
    the border pixels.

    Args:
        img: A PIL image, or a numpy array accepted by ``Image.fromarray``.
        target_res: Side length of the square output, in pixels.
        resize: When False the rescaling step is skipped. The padding offsets
            are still computed against ``target_res``, so the longer side of
            the image must then already equal ``target_res``.
        to_pil: Return a PIL image when True, otherwise a numpy array.
        edge: Pad by replicating border pixels (``np.pad(..., mode='edge')``)
            instead of centring the image on a zero-filled canvas.

    Returns:
        The padded image as a PIL image or numpy array. The zero-filled canvas
        is always ``uint8``; edge-padded output keeps the dtype of the image
        data.
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)

    original_width, original_height = img.size
    # Square images are handled by the landscape branch.
    is_landscape = original_height <= original_width

    if resize:
        # Clamp the short side to 1 px: for extreme aspect ratios the rounded
        # value can be 0, which is not a valid image size.
        if is_landscape:
            short_side = int(np.around(target_res * original_height / original_width))
            new_size = (target_res, max(short_side, 1))
        else:
            short_side = int(np.around(target_res * original_width / original_height))
            new_size = (max(short_side, 1), target_res)
        img = img.resize(new_size, Image.Resampling.LANCZOS)

    width, height = img.size
    pixels = np.asarray(img)
    # Empty for single-band (H, W) data, ``(C,)`` for multi-band (H, W, C)
    # data. Deriving the canvas/padding rank from the data keeps RGBA and
    # grayscale inputs working in both padding modes.
    channel_dims = pixels.shape[2:]

    if edge:
        if is_landscape:
            before = (target_res - height) // 2
            after = target_res - height - before
            pad_width = [(before, after), (0, 0)]
        else:
            before = (target_res - width) // 2
            after = target_res - width - before
            pad_width = [(0, 0), (before, after)]
        # Never pad along the channel axis (if there is one).
        pad_width += [(0, 0)] * len(channel_dims)
        canvas = np.pad(pixels, pad_width=pad_width, mode='edge')
    else:
        canvas = np.zeros((target_res, target_res) + channel_dims, dtype=np.uint8)
        if is_landscape:
            top = (width - height) // 2
            canvas[top:top + height] = pixels
        else:
            left = (height - width) // 2
            canvas[:, left:left + width] = pixels

    if to_pil:
        canvas = Image.fromarray(canvas)
    return canvas
def torch_pca(feature: torch.Tensor, target_dim: int = 256) -> torch.Tensor:
    """Project features onto their leading principal directions.

    Args:
        feature: Tensor of shape (N, D), i.e. N samples with D-dimensional
            features.
        target_dim: Number of principal directions to keep. Defaults to 256.

    Returns:
        Tensor of shape (N, target_dim): the mean-centred features expressed
        in the basis returned by ``torch.pca_lowrank``.
    """
    # Centre every feature dimension; the same centred data is projected below.
    centered = feature - feature.mean(dim=0, keepdim=True)
    # Only the third output (the principal directions) is needed.
    _, _, components = torch.pca_lowrank(centered, q=target_dim)
    return centered @ components[:, :target_dim]
def compute_dino_feature(
    source_img: Union[Image.Image, List[Union[np.ndarray, Image.Image]]],
    target_imgs: Optional[List[Union[np.ndarray, Image.Image]]] = None, *,
    model_size: str = 'base', use_dino_v2: bool = True, stride: Optional[int] = None,
    edge_pad: bool = False, pca: bool = False, pca_dim: int = 256,
    reusable_extractor: Optional[ViTExtractor] = None
) -> tuple[torch.Tensor, List[Image.Image], List[Image.Image]]:
    """Extract L2-normalized dense DINO feature maps for one or more images.

    Args:
        source_img: A single image (PIL image or numpy array) or a list of images.
        target_imgs: Optional additional image(s), processed after ``source_img``.
        model_size: One of ``'small'``, ``'base'``, ``'large'``, ``'giant'``;
            only consulted when ``use_dino_v2`` is True.
        use_dino_v2: Use a DINOv2 backbone (840 px input, ``'token'`` facet)
            instead of ``'dino_vits8'`` (244 px input, ``'key'`` facet).
        stride: Stride passed to the extractor; defaults to 14 for DINOv2 and
            4 otherwise.
        edge_pad: Pad non-square images by edge replication instead of zeros.
        pca: Reduce the feature dimension to ``pca_dim`` with a PCA fitted
            jointly over the patches of all images.
        pca_dim: Target feature dimension when ``pca`` is True.
        reusable_extractor: An existing extractor to use instead of building
            a new one. ``model_size``, ``use_dino_v2`` and ``stride`` are still
            used to pick the layer, facet and patch-grid size, so they should
            describe this extractor.

    Returns:
        ``(result, resized_imgs, downsampled_imgs)`` where ``result`` is a
        tensor of shape (N, C, num_patches, num_patches) normalized along the
        channel dimension (C is ``pca_dim`` when ``pca`` is True, otherwise the
        extractor's feature dimension), ``resized_imgs`` are the PIL images
        resized to the model input size, and ``downsampled_imgs`` are the PIL
        images resized to the feature-map size.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): 244 is not a multiple of the v1 patch size (8); 224 may
    # have been intended -- confirm before changing.
    img_size = 840 if use_dino_v2 else 244

    # Resolve the model-dependent settings up front so that they are defined
    # regardless of whether an extractor is built here or supplied by the caller.
    if use_dino_v2:
        dinov2_models = {
            'small': ('dinov2_vits14', 11),
            'base': ('dinov2_vitb14', 11),
            'large': ('dinov2_vitl14', 23),
            'giant': ('dinov2_vitg14', 39),
        }
        model_type, layer = dinov2_models[model_size]
        facet = 'token'
        default_stride = 14
    else:
        model_type, layer = 'dino_vits8', 9
        facet = 'key'
        default_stride = 4
    if stride is None:
        stride = default_stride

    if reusable_extractor is None:
        extractor = ViTExtractor(model_type, stride, device=device)
    else:
        extractor = reusable_extractor

    patch_embed = extractor.model.patch_embed
    patch_size = patch_embed.patch_size[0] if use_dino_v2 else patch_embed.patch_size
    # Number of patch positions along one side of the image. This assumes the
    # extractor's patch embedding behaves like an unpadded strided convolution
    # -- TODO confirm against ViTExtractor. It matches the previous formula
    # whenever img_size is a multiple of patch_size and, unlike it, does not
    # under-count otherwise.
    num_patches = int((img_size - patch_size) // stride + 1)

    original_imgs = []
    if isinstance(source_img, (np.ndarray, Image.Image)):
        original_imgs.append(source_img)
    else:
        original_imgs.extend(source_img)
    if target_imgs is not None:
        if isinstance(target_imgs, (np.ndarray, Image.Image)):
            original_imgs.append(target_imgs)
        else:
            original_imgs.extend(target_imgs)

    resized_imgs = [resize(img, img_size, resize=True, to_pil=True, edge=edge_pad) for img in original_imgs]

    descriptors = []
    with torch.no_grad():
        for img in tqdm(resized_imgs, desc='Extracting dino feature'):
            img_batch = extractor.preprocess_pil(img)
            # Expected shape: (1, 1, num_patches * num_patches, feature_dim).
            descriptors.append(extractor.extract_descriptors(img_batch.to(device), layer, facet))
    result = torch.concat(descriptors, dim=0)  # N, 1, num_patches*num_patches, feature_dim

    if pca:
        # Pool the patches of all images so that they share one PCA basis.
        num_imgs, _, _, feature_dim = result.shape
        result = torch_pca(result.reshape(-1, feature_dim), pca_dim)
        result = result.reshape(num_imgs, 1, -1, pca_dim)

    # (N, 1, P*P, C) -> (N, C, P, P)
    num_imgs, _, _, feature_dim = result.shape
    result = result.permute(0, 1, 3, 2).reshape(num_imgs, feature_dim, num_patches, num_patches)
    result = F.normalize(result, dim=1)

    gc.collect()
    torch.cuda.empty_cache()

    output_size = result.shape[-1]
    downsampled_imgs = [resize(img, output_size, resize=True, to_pil=True, edge=edge_pad) for img in original_imgs]
    return result, resized_imgs, downsampled_imgs