# Source code for concepts.vision.fm_match.diff3f.diff3f_renderer

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : diff3f_renderer.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 09/16/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Render 3D meshes with the PyTorch3D renderer."""

import math
from typing import Optional, Union

import torch
import torch.nn as nn

from pytorch3d.structures.meshes import Meshes
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes

from pytorch3d.renderer.blending import BlendParams
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.utils import TensorProperties
from pytorch3d.renderer.mesh.rasterizer import Fragments, RasterizationSettings, MeshRasterizer
from pytorch3d.renderer.mesh.shader import HardPhongShader
from pytorch3d.renderer.mesh.renderer import MeshRenderer


from concepts.vision.fm_match.diff3f.diff3f_mesh import MeshContainer


class HardPhongNormalShader(nn.Module):
    """A variant of PyTorch3D's ``HardPhongShader`` that outputs per-pixel normals.

    Per-pixel lighting in the stock shader applies the lighting model using the
    interpolated coordinates and normals for each pixel; this variant skips the
    lighting/blending stages entirely and returns the per-pixel normal map instead.

    To use the default values, simply initialize the shader with the desired device, e.g.

    .. code-block::

        shader = HardPhongNormalShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
    ) -> None:
        """Initialize the shader.

        Args:
            device: device on which default lights/materials are created.
            cameras: optional cameras; may also be supplied per-call via ``forward(..., cameras=...)``.
            lights: optional lights; defaults to ``PointLights`` on ``device``.
            materials: optional materials; defaults to ``Materials`` on ``device``.
            blend_params: optional blending parameters; defaults to ``BlendParams()``.
        """
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = materials if materials is not None else Materials(device=device)
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def to(self, device):
        """Move the shader to ``device``.

        Cameras, materials and lights are ``TensorProperties`` rather than
        ``nn.Module`` submodules, so ``nn.Module.to`` would not move them;
        move them manually here.
        """
        cameras = self.cameras
        if cameras is not None:
            self.cameras = cameras.to(device)
        self.materials = self.materials.to(device)
        self.lights = self.lights.to(device)
        return self

    def phong_normal_shading(self, meshes, fragments) -> torch.Tensor:
        """Compute per-pixel normals for the rasterized fragments.

        Note: the interpolation weights are all-ones instead of the barycentric
        coordinates, so each pixel receives the (unnormalized) SUM of the three
        vertex normals of its face rather than a barycentric blend.
        """
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]  # (F, 3, 3): per-face vertex normals
        ones = torch.ones_like(fragments.bary_coords)
        pixel_normals = interpolate_face_attributes(
            fragments.pix_to_face, ones, faces_normals
        )
        return pixel_normals

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """Return the per-pixel normal map for the given fragments/meshes.

        Raises:
            ValueError: if no cameras were given at initialization nor in ``kwargs``.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            # Single-line message: the original used a backslash continuation
            # inside the string literal, which baked indentation whitespace
            # into the runtime error message.
            msg = "Cameras must be specified either at initialization or in the forward pass of HardPhongShader"
            raise ValueError(msg)
        normals = self.phong_normal_shading(meshes=meshes, fragments=fragments)
        return normals
@torch.no_grad()
def run_rendering(
    device: str,
    mesh: Meshes,
    num_views: int,
    H: int,
    W: int,
    additional_angle_azi: float = 0,
    additional_angle_ele: float = 0,
    use_normal_map: bool = False,
):
    """Render ``mesh`` from a grid of viewpoints around its bounding box.

    Cameras are placed on a ``steps x steps`` azimuth/elevation grid (with
    ``steps = int(sqrt(num_views))``) at a distance proportional to the
    bounding-box diagonal, all looking at the bounding-box center.

    NOTE(review): ``num_views`` is assumed to be a perfect square — the number
    of cameras is ``steps ** 2`` while the mesh is extended ``num_views``
    times; a non-square ``num_views`` would make these batch sizes disagree.
    TODO confirm against callers.

    Args:
        device: torch device string for cameras/lights/rendering.
        mesh: the PyTorch3D mesh to render.
        num_views: number of views (a perfect square).
        H: output image height.
        W: output image width.
        additional_angle_azi: azimuth offset (degrees) added to every view.
        additional_angle_ele: elevation offset (degrees) added to every view.
        use_normal_map: if True, additionally render per-pixel normal maps.

    Returns:
        Tuple ``(batched_renderings, normal_batched_renderings, camera, depth)``
        where ``normal_batched_renderings`` is None unless ``use_normal_map``.
    """
    # Camera distance: scaled length of the bounding-box diagonal.
    bbox = mesh.get_bounding_boxes()
    bbox_min = bbox.min(dim=-1).values[0]
    bbox_max = bbox.max(dim=-1).values[0]
    bb_diff = bbox_max - bbox_min
    bbox_center = (bbox_min + bbox_max) / 2.0
    scaling_factor = 0.65
    distance = torch.sqrt((bb_diff * bb_diff).sum())
    distance *= scaling_factor

    # Evenly spaced azimuth/elevation grid covering [0, 360) degrees,
    # offset by the caller-provided jitter angles.
    steps = int(math.sqrt(num_views))
    end = 360 - 360 / steps
    elevation = torch.linspace(start=0, end=end, steps=steps).repeat(steps) + additional_angle_ele
    azimuth = torch.linspace(start=0, end=end, steps=steps)
    azimuth = torch.repeat_interleave(azimuth, steps) + additional_angle_azi

    bbox_center = bbox_center.unsqueeze(0)
    rotation, translation = look_at_view_transform(
        dist=distance, azim=azimuth, elev=elevation, device=device, at=bbox_center
    )
    camera = PerspectiveCameras(R=rotation, T=translation, device=device)
    rasterization_settings = RasterizationSettings(
        image_size=(H, W), blur_radius=0.0, faces_per_pixel=1, bin_size=0
    )
    rasterizer = MeshRasterizer(cameras=camera, raster_settings=rasterization_settings)

    # Co-locate a point light with each camera for mostly-ambient shading.
    camera_centre = camera.get_camera_center()
    lights = PointLights(
        diffuse_color=((0.4, 0.4, 0.5),),
        ambient_color=((0.6, 0.6, 0.6),),
        specular_color=((0.01, 0.01, 0.01),),
        location=camera_centre,
        device=device,
    )
    shader = HardPhongShader(device=device, cameras=camera, lights=lights)
    batch_renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
    batch_mesh = mesh.extend(num_views)

    normal_batched_renderings = None
    batched_renderings = batch_renderer(batch_mesh)
    if use_normal_map:
        normal_shader = HardPhongNormalShader(device=device, cameras=camera, lights=lights)
        normal_batch_renderer = MeshRenderer(rasterizer=rasterizer, shader=normal_shader)
        normal_batched_renderings = normal_batch_renderer(batch_mesh)

    # Re-rasterize to recover the per-pixel depth (z-buffer).
    fragments = rasterizer(batch_mesh)
    depth = fragments.zbuf
    # Bug fix: the statement was previously split so that a bare `return`
    # executed (yielding None) with the result tuple orphaned on the next line.
    return batched_renderings, normal_batched_renderings, camera, depth
def batch_render(device: str, mesh: Union[MeshContainer, Meshes], num_views: int, H: int, W: int, use_normal_map: bool = False):
    """Render a mesh from ``num_views`` viewpoints, retrying on linear-algebra failures.

    ``look_at_view_transform`` can raise ``torch.linalg.LinAlgError`` for
    degenerate camera placements; on failure the view grid is jittered by
    random azimuth/elevation offsets and rendering is retried (up to 5 times).

    Args:
        device: torch device string.
        mesh: either a project ``MeshContainer`` (converted to PyTorch3D
            meshes) or a PyTorch3D ``Meshes`` object.
        num_views: number of views to render (see :func:`run_rendering`).
        H: output image height.
        W: output image width.
        use_normal_map: if True, also render per-pixel normal maps.

    Returns:
        The tuple produced by :func:`run_rendering`.

    Raises:
        RuntimeError: if rendering still fails after 5 attempts. (Previously
            the function fell through and implicitly returned None, which
            callers would then fail to unpack with a confusing error.)
    """
    if isinstance(mesh, MeshContainer):
        mesh = mesh.to_pytorch3d_meshes(device)
    trials = 0
    additional_angle_azi = 0.
    additional_angle_ele = 0.
    last_error = None
    while trials < 5:
        try:
            return run_rendering(
                device, mesh, num_views, H, W,
                additional_angle_azi=additional_angle_azi,
                additional_angle_ele=additional_angle_ele,
                use_normal_map=use_normal_map,
            )
        except torch.linalg.LinAlgError as e:
            last_error = e
            trials += 1
            print("lin alg exception at rendering, retrying ", trials)
            # Jitter the camera grid to escape the degenerate configuration.
            additional_angle_azi = torch.randn(1).item()
            additional_angle_ele = torch.randn(1).item()
    raise RuntimeError('Rendering failed after 5 retries due to repeated linear-algebra errors.') from last_error