Source code for concepts.benchmark.vision_language.babel_qa.dataset

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : dataset.py
# Author : Joy Hsu
# Email  : joycj@stanford.edu
# Date   : 03/23/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import math
import os.path as osp

import numpy as np

import jacinle.io as io
from jacinle.logging import get_logger
from jacinle.utils.container import GView
from jactorch.data.dataset import FilterableDatasetUnwrapped, FilterableDatasetView
from jactorch.data.dataloader import JacDataLoader
from jactorch.data.collate import VarLengthCollateV2

from concepts.benchmark.vision_language.babel_qa.utils import (
    nsclseq_to_nscltree, nsclseq_to_nsclqsseq, nscltree_to_nsclqstree, program_to_nsclseq
)


logger = get_logger(__file__)

__all__ = ['BabelQADataset', 'BabelMotionClassificationDataset']


class BabelQADatasetUnwrapped(FilterableDatasetUnwrapped):
    """BabelQA dataset."""

    def __init__(self, data_dir: str, data_split_file: str, split: str, data_source: str, no_gt_segments: bool, filter_supervision, max_frames: int = 150):
        """Initialize BabelQA dataset.

        Args:
            data_dir: path to data directory.
            data_split_file: path to the json file that contains the question ids for each split.
            split: split to use.
            data_source: data source to use. Either teach_synth or ???
            no_gt_segments: whether to use ground truth segments or not.
            filter_supervision: ???
            max_frames: maximum number of frames to use per segment. (is this correct???)
        """
        super().__init__()

        self.labels_json = osp.join(data_dir, 'motion_concepts.json')
        self.questions_json = osp.join(data_dir, 'questions.json')
        self.joints_root = osp.join(data_dir, 'motion_sequences')

        self.labels = io.load_json(self.labels_json)
        self.questions = io.load_json(self.questions_json)
        self.split_question_ids = io.load_json(data_split_file)[split]

        self.data_source = data_source
        self.max_frames = max_frames
        self.no_gt_segments = no_gt_segments
        self.filter_supervision = filter_supervision

    def _get_questions(self):
        return self.questions

    def _get_metainfo(self, index):
        question = self.questions[self.split_question_ids[index]]

        # program section
        has_program = False
        if 'program_nsclseq' in question:
            question['program_raw'] = question['program_nsclseq']
            question['program_seq'] = question['program_nsclseq']
            has_program = True
        elif 'program' in question:
            question['program_raw'] = question['program']
            question['program_seq'] = program_to_nsclseq(question['program'])
            has_program = True

        if has_program:
            question['program_tree'] = nsclseq_to_nscltree(question['program_seq'])
            question['program_qsseq'] = nsclseq_to_nsclqsseq(question['program_seq'])
            question['program_qstree'] = nscltree_to_nsclqstree(question['program_tree'])

        return question

    def __getitem__(self, index):
        metainfo = GView(self.get_metainfo(index))
        feed_dict = GView()
        question = self.questions[self.split_question_ids[index]]

        if 'program_raw' in metainfo:
            feed_dict.program_raw = metainfo.program_raw
            feed_dict.program_seq = metainfo.program_seq
            feed_dict.program_tree = metainfo.program_tree
            feed_dict.program_qsseq = metainfo.program_qsseq
            feed_dict.program_qstree = metainfo.program_qstree

        if '/' in question['answer']:
            question['answer'] = question['answer'].split('/')[0]
        feed_dict.answer = question['answer']
        feed_dict.question_type = question['query_type']
        feed_dict.segment_boundaries = []
        feed_dict.question_text = question['question']

        # process joints
        if self.data_source == 'teach_synth':
            id_name = 'teach_synth_id'
            motion_id = question[id_name]
            num_segments = len(self.labels[motion_id]['labels'])
        else:
            id_name = 'babel_id'
            motion_id = question[id_name]
            num_segments = len(self.labels[motion_id])
        feed_dict.babel_id = motion_id

        joints = np.load(osp.join(self.joints_root, motion_id, 'joints.npy'))  # T, V, C
        # change shape of joints to match model
        joints = joints[:, :, :, np.newaxis]  # T, V, C, M
        joints = joints.transpose(2, 0, 1, 3)  # C, T, V, M

        # label info
        if self.data_source == 'teach_synth':
            labels_frame_info = self.labels[motion_id]['labels']
        else:
            labels_frame_info = self.labels[motion_id]

        if 'filter_answer_0' in question:
            filter_segment = labels_frame_info[question['filter_answer_0']]
            if filter_segment['end_f'] > np.shape(joints)[1]:  # right now end frame can be slightly off (dataset issue)
                filter_segment['end_f'] = np.shape(joints)[1]
            feed_dict.filter_boundaries = [(filter_segment['start_f'], filter_segment['end_f'])]
        if 'filter_answer_1' in question:
            filter_segment = labels_frame_info[question['filter_answer_1']]
            if filter_segment['end_f'] > np.shape(joints)[1]:  # right now end frame can be slightly off (dataset issue)
                filter_segment['end_f'] = np.shape(joints)[1]
            feed_dict.filter_boundaries.append((filter_segment['start_f'], filter_segment['end_f']))

        if not self.no_gt_segments:
            joints_combined = np.zeros((num_segments, 3, self.max_frames, 22, 1), dtype=np.float32)  # num_segs, C, T, V, M
            for seg_i, seg in enumerate(labels_frame_info):
                if seg['end_f'] > np.shape(joints)[1]:  # right now end frame can be slightly off (dataset issue)
                    seg['end_f'] = np.shape(joints)[1]
                num_frames = seg['end_f'] - seg['start_f']
                if num_frames > self.max_frames:  # clip segments to max_frames
                    num_frames = self.max_frames
                joints_combined[seg_i, :, :num_frames, :, :] = joints[:, seg['start_f']: seg['start_f'] + num_frames, :, :]
                feed_dict.segment_boundaries.append((seg['start_f'], seg['start_f'] + num_frames))
            feed_dict.joints = joints_combined
            feed_dict.num_segs = num_segments
        else:
            total_num_frames = np.shape(joints)[1]
            num_frames_per_seg = 45
            overlap_frames = 15
            num_segments = math.ceil(total_num_frames / num_frames_per_seg)
            feed_dict['info'] = []
            joints_combined = np.zeros((num_segments, 3, num_frames_per_seg + overlap_frames * 2, 22, 1), dtype=np.float32)  # num_segs, C, T, V, M
            for i in range(num_segments):
                start_f = i * num_frames_per_seg
                end_f = (i + 1) * num_frames_per_seg
                if end_f > total_num_frames:
                    end_f = total_num_frames
                missing_before_context = overlap_frames - start_f if start_f < overlap_frames else 0
                existing_after_context = total_num_frames - end_f
                if existing_after_context > overlap_frames:
                    existing_after_context = overlap_frames
                joints_combined[i, :, missing_before_context:overlap_frames + (end_f - start_f) + existing_after_context, :, :] = \
                    joints[:, start_f - (overlap_frames - missing_before_context):end_f + existing_after_context, :, :]
                feed_dict.segment_boundaries.append((start_f - (overlap_frames - missing_before_context), end_f + existing_after_context))
            feed_dict.joints = joints_combined
            feed_dict.num_segs = num_segments

        return feed_dict.raw()

    def __len__(self):
        return len(self.split_question_ids)
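

# The no_gt_segments branch of __getitem__ above slices each motion into fixed windows of
# 45 frames, padded with up to 15 frames of context on either side when available. A minimal
# sketch of the boundary arithmetic it produces (the helper below is illustrative only and
# is not used by the dataset):
#
#   def _sliding_window_boundaries(total_frames, seg_len=45, context=15):
#       boundaries = []
#       for i in range(math.ceil(total_frames / seg_len)):
#           start_f, end_f = i * seg_len, min((i + 1) * seg_len, total_frames)
#           left = min(context, start_f)                # context available before the window
#           right = min(context, total_frames - end_f)  # context available after the window
#           boundaries.append((start_f - left, end_f + right))
#       return boundaries
#
#   # _sliding_window_boundaries(100) -> [(0, 60), (30, 100), (75, 100)]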


class BabelQADatasetFilterableView(FilterableDatasetView):
    def filter_questions(self, allowed):
        def filt(question):
            return question['query_type'] in allowed

        return self.filter(filt, 'filter-question-type[allowed={{{}}}]'.format(','.join(list(allowed))))

    def make_dataloader(self, batch_size, shuffle, drop_last, nr_workers):
        collate_guide = {
            'joints': 'concat',
            'answer': 'skip',
            'segment_boundaries': 'skip',
            'filter_boundaries': 'skip',
            'program_raw': 'skip',
            'program_seq': 'skip',
            'program_tree': 'skip',
            'program_qsseq': 'skip',
            'program_qstree': 'skip',
        }
        return JacDataLoader(
            self, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
            num_workers=nr_workers, pin_memory=True,
            collate_fn=VarLengthCollateV2(collate_guide)
        )


def BabelQADataset(*args, **kwargs):
    return BabelQADatasetFilterableView(BabelQADatasetUnwrapped(*args, **kwargs))
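

# Example usage (a minimal sketch; the paths, split file name, and query-type value below are
# hypothetical placeholders chosen for illustration, not constants defined by this module):
#
#   dataset = BabelQADataset(
#       data_dir='data/babel_qa',                      # assumed to contain motion_concepts.json, questions.json, motion_sequences/
#       data_split_file='data/babel_qa/split.json',
#       split='train',
#       data_source='babel',                           # anything other than 'teach_synth' takes the BABEL code path
#       no_gt_segments=False,
#       filter_supervision=None,
#       max_frames=150,
#   )
#   dataset = dataset.filter_questions({'query_action'})  # keep only the listed query types (name is illustrative)
#   loader = dataset.make_dataloader(batch_size=8, shuffle=True, drop_last=False, nr_workers=4)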


class BabelMotionClassificationDatasetUnwrapped(FilterableDatasetUnwrapped):
    def __init__(self, labels_json, joints_root, symbolic, max_frames=150):
        super().__init__()

        self.labels_json = labels_json
        self.joints_root = joints_root
        self.act2idx_json = 'nscl/datasets/babel/action_label_2_idx.json'

        self.labels = io.load_json(self.labels_json)
        self.babel_ids = list(self.labels.keys())
        self.act2idx = io.load_json(self.act2idx_json)

        self.symbolic = symbolic
        self.max_frames = max_frames

    def __getitem__(self, index):
        feed_dict = GView()
        babel_id = self.babel_ids[index]
        actions = []
        num_segments = len(self.labels[babel_id])

        joints = np.load(osp.join(self.joints_root, f'{babel_id}.npy'))  # T, V, C
        # change shape of joints to match model
        joints = joints[:, :, :, np.newaxis]  # T, V, C, M
        joints = joints.transpose(2, 0, 1, 3)  # C, T, V, M

        joints_combined = np.zeros((num_segments, 3, self.max_frames, 22, 1), dtype=np.float32)  # num_segs, C, T, V, M

        # TODO: take out segments that only have one frame (to be consistent with unbatched version)
        for seg_i, seg in enumerate(self.labels[babel_id]):
            if seg['end_f'] > np.shape(joints)[1]:  # right now end frame can be slightly off (dataset issue)
                seg['end_f'] = np.shape(joints)[1]
            num_frames = seg['end_f'] - seg['start_f']
            if num_frames > self.max_frames:  # clip segments to max_frames
                num_frames = self.max_frames
            joints_combined[seg_i, :, :num_frames, :, :] = joints[:, seg['start_f']: seg['start_f'] + num_frames, :, :]

            action = seg['action']
            if self.symbolic:
                # option 1: add all actions
                for i in range(len(action)):
                    action[i] = self.act2idx[action[i]]
                actions.append(action)
            else:
                # option 2: pick randomly among actions
                random_action_idx = np.random.randint(len(action))
                action_idx = self.act2idx[action[random_action_idx].replace('.', '')]  # remove periods from action names
                actions.append([action_idx])

        feed_dict.actions = actions
        feed_dict.joints = joints_combined
        feed_dict.id = babel_id
        feed_dict.num_segs = num_segments
        return feed_dict.raw()

    def __len__(self):
        return len(self.babel_ids)
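

# Based on the code above, the labels JSON consumed by BabelMotionClassificationDatasetUnwrapped
# is assumed to map a babel_id to a list of segments, each with frame boundaries and one or more
# action names. An illustrative entry (all values are made up):
#
#   {
#       "12345": [
#           {"start_f": 0, "end_f": 120, "action": ["walk"]},
#           {"start_f": 120, "end_f": 210, "action": ["turn", "turn around"]}
#       ]
#   }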


class BabelMotionClassificationDatasetFilterableView(FilterableDatasetView):
    def make_dataloader(self, batch_size, shuffle, drop_last, nr_workers):
        collate_guide = {
            'joints': 'concat',
            'actions': 'skip',
        }
        return JacDataLoader(
            self, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
            num_workers=nr_workers, pin_memory=True,
            collate_fn=VarLengthCollateV2(collate_guide)
        )


def BabelMotionClassificationDataset(*args, **kwargs):
    return BabelMotionClassificationDatasetFilterableView(BabelMotionClassificationDatasetUnwrapped(*args, **kwargs))
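

# Example usage (a minimal sketch; the paths below are hypothetical placeholders):
#
#   dataset = BabelMotionClassificationDataset(
#       labels_json='data/babel/motion_concepts.json',
#       joints_root='data/babel/joints',               # expects <babel_id>.npy files of shape (T, V, C)
#       symbolic=False,
#       max_frames=150,
#   )
#   loader = dataset.make_dataloader(batch_size=8, shuffle=True, drop_last=False, nr_workers=4)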