Source code for concepts.benchmark.vision_language.babel_qa.utils
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : utils.py
# Author : Joy Hsu
# Email : joycj@stanford.edu
# Date : 03/24/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
from collections import defaultdict
from copy import deepcopy
# Names exported by ``from <this module> import *``. iter_nscltree and the
# get_clevr_* helpers are also defined in this module but are not listed here.
__all__ = ['nsclseq_to_nscltree', 'nsclseq_to_nsclqsseq', 'nscltree_to_nsclqstree', 'program_to_nsclseq']
[docs]
def nsclseq_to_nscltree(seq_program):
    """Convert a sequential NSCL program into a nested tree.

    ``seq_program`` is a list of block dicts whose ``'inputs'`` entries are
    indices into ``seq_program``. The tree is rooted at the last block. Every
    node is a deep copy of its block whose ``'inputs'`` list is replaced by
    the list of child nodes (the ``'inputs'`` key ends up last in the dict).

    A block consumed by several others is copied once per consumer, so the
    result contains no shared nodes. ``seq_program`` itself is not modified.
    """
    def expand(position):
        node = deepcopy(seq_program[position])
        child_positions = node.pop('inputs')
        node['inputs'] = [expand(p) for p in child_positions]
        return node

    try:
        return expand(-1)
    finally:
        # ``expand`` refers to itself through its closure cell. Deleting the
        # name clears that cell and breaks the self-reference cycle
        # (presumably the reason for this cleanup).
        del expand
[docs]
def nsclseq_to_nsclqsseq(seq_program, *, parameter_types=()):
    """Return a deep copy of a sequential NSCL program, optionally annotating
    block parameters with per-type indices.

    Blocks are visited in list order. For every block that has a key listed
    in ``parameter_types``, two keys are added:

    * ``<param_type>_idx``: the number of values of that type seen in
      earlier blocks (i.e. this block's position among them);
    * ``<param_type>_values``: a list object shared by every block carrying
      that type. Once the function returns, it holds all such values in
      program order. Its entries are the copied blocks' own value objects,
      not further copies.

    Args:
        seq_program: list of NSCL block dicts.
        parameter_types: iterable of parameter keys to annotate (keyword-only),
            e.g. a dataset definition's ``parameter_types`` supplied by the
            caller. Defaults to empty, in which case no annotation is done and
            a plain deep copy is returned.

    Returns:
        The deep-copied (and possibly annotated) program. ``seq_program``
        itself is not modified.
    """
    qs_seq = deepcopy(seq_program)
    # Materialize once: the inner loop re-iterates it for every block.
    parameter_types = tuple(parameter_types)
    cached = defaultdict(list)
    for sblock in qs_seq:
        for param_type in parameter_types:
            if param_type in sblock:
                values = cached[param_type]
                sblock[param_type + '_idx'] = len(values)
                # Deliberately the shared list, so every block sees all values.
                sblock[param_type + '_values'] = values
                values.append(sblock[param_type])
    return qs_seq
[docs]
def nscltree_to_nsclqstree(tree_program, *, parameter_types=()):
    """Return a deep copy of an NSCL program tree, optionally annotating node
    parameters with per-type indices.

    Nodes are visited in the pre-order produced by ``iter_nscltree``. For
    every node that has a key listed in ``parameter_types``, two keys are
    added:

    * ``<param_type>_idx``: the number of values of that type seen in
      earlier-visited nodes;
    * ``<param_type>_values``: a list object shared by every node carrying
      that type. Once the function returns, it holds all such values in
      visiting order.

    If the input tree shares one node object between several parents,
    ``deepcopy`` keeps that sharing. Such a node is then visited, and
    annotated, once per parent.

    Args:
        tree_program: root node dict; each node's ``'inputs'`` holds its
            child nodes.
        parameter_types: iterable of parameter keys to annotate (keyword-only),
            e.g. a dataset definition's ``parameter_types`` supplied by the
            caller. Defaults to empty, in which case no annotation is done and
            a plain deep copy is returned.

    Returns:
        The deep-copied (and possibly annotated) tree. ``tree_program`` itself
        is not modified.
    """
    qs_tree = deepcopy(tree_program)
    # Materialize once: the inner loop re-iterates it for every node.
    parameter_types = tuple(parameter_types)
    if not parameter_types:
        # Nothing to annotate; skip the traversal entirely.
        return qs_tree

    cached = defaultdict(list)
    for tblock in iter_nscltree(qs_tree):
        for param_type in parameter_types:
            if param_type in tblock:
                values = cached[param_type]
                tblock[param_type + '_idx'] = len(values)
                # Deliberately the shared list, so every node sees all values.
                tblock[param_type + '_values'] = values
                values.append(tblock[param_type])
    return qs_tree
[docs]
def iter_nscltree(tree_program):
    """Yield every node of an NSCL program tree in pre-order.

    A node is yielded before the subtrees in its ``'inputs'`` list, which are
    visited left to right. Nodes are yielded as-is (no copies), so a consumer
    may annotate them in place. A node's ``'inputs'`` list is read only after
    the consumer resumes the generator following that node.
    """
    exhausted = object()
    yield tree_program
    # Stack of iterators over the 'inputs' lists still being walked. The top
    # iterator belongs to the most recently entered node.
    frontier = [iter(tree_program['inputs'])]
    while frontier:
        child = next(frontier[-1], exhausted)
        if child is exhausted:
            frontier.pop()
            continue
        yield child
        frontier.append(iter(child['inputs']))
[docs]
def get_clevr_pblock_op(block):
    """Return the operation name of a CLEVR program block.

    The name is read from ``block['type']`` when that key exists; otherwise
    from ``block['function']``, whose presence is asserted.
    """
    if 'type' not in block:
        assert 'function' in block
        return block['function']
    return block['type']
[docs]
def get_clevr_op_attribute(op):
    """Return the part of a CLEVR op name after its first underscore.

    For example ``'query_color'`` gives ``'color'`` and ``'same_shape'`` gives
    ``'shape'``. A name without an underscore (e.g. ``'exist'``) gives ``''``.
    """
    _, _, attribute = op.partition('_')
    return attribute
[docs]
def program_to_nsclseq(program):
    """Translate a CLEVR functional program into a sequential NSCL program.

    ``program`` is a sequence of CLEVR block dicts. Each block's op name is
    read with ``get_clevr_pblock_op``, and its ``'inputs'`` are indices of
    other blocks. ``filter*`` and ``relate*`` blocks also need
    ``'value_inputs'``.

    The result is a list of NSCL block dicts. Each has an ``'op'`` key,
    op-specific fields, and ``'inputs'`` given as indices into the returned
    list. An input id that does not refer to an earlier block becomes
    ``None``. A block's ``'_output'``, if present, is deep-copied to
    ``'output'``.

    Op mapping (checked in this order):

    * ``scene`` -> ``scene``
    * ``filter*`` -> ``filter`` with a ``concept`` list. If the input already
      resolves to an NSCL ``filter`` block, the concept is appended to that
      block in place and no new block is emitted.
    * ``relate*`` -> ``relate`` with ``relational_concept``
    * ``same*`` -> ``relate_attribute_equal`` with ``attribute``
    * ``intersect`` / ``union`` -> same name
    * ``unique`` -> dropped
    * ``query*`` -> ``query`` with ``attribute``, only for the last block;
      otherwise dropped
    * ``equal*`` other than ``equal_integer`` -> ``query_attribute_equal``
    * ``exist`` -> ``exist``
    * ``count`` -> ``count``, only for the last block; otherwise dropped
    * ``equal_integer`` / ``less_than`` / ``greater_than`` ->
      ``count_equal`` / ``count_less`` / ``count_greater``

    A dropped or merged block must have exactly one input (asserted). It
    resolves to whatever that input resolves to.

    NOTE(review): when filters are merged, the NSCL block keeps the
    ``'output'`` of the first filter in the chain; later filters'
    ``'_output'`` values are discarded. Confirm this is intended.

    Raises:
        ValueError: for an op name not covered above.
    """
    nscl_blocks = []
    # CLEVR block index -> index in ``nscl_blocks`` of the block producing
    # its value. Dropped or merged blocks alias the entry of their input.
    resolved = {}

    for position, clevr_block in enumerate(program):
        op_name = get_clevr_pblock_op(clevr_block)
        new_block = None

        if op_name == 'scene':
            new_block = {'op': 'scene'}
        elif op_name.startswith('filter'):
            concept = clevr_block['value_inputs'][0]
            upstream = nscl_blocks[resolved[clevr_block['inputs'][0]]]
            if upstream['op'] != 'filter':
                new_block = {'op': 'filter', 'concept': [concept]}
            else:
                # Chained filters collapse into a single NSCL filter block.
                upstream['concept'].append(concept)
        elif op_name.startswith('relate'):
            new_block = {'op': 'relate', 'relational_concept': [clevr_block['value_inputs'][0]]}
        elif op_name.startswith('same'):
            new_block = {'op': 'relate_attribute_equal', 'attribute': get_clevr_op_attribute(op_name)}
        elif op_name in ('intersect', 'union'):
            new_block = {'op': op_name}
        elif op_name == 'unique':
            pass  # No NSCL counterpart; aliased to its input below.
        elif op_name.startswith('query'):
            # Only a final query becomes a block; intermediate ones are
            # aliased to their input.
            if position == len(program) - 1:
                new_block = {'op': 'query', 'attribute': get_clevr_op_attribute(op_name)}
        elif op_name.startswith('equal') and op_name != 'equal_integer':
            new_block = {'op': 'query_attribute_equal', 'attribute': get_clevr_op_attribute(op_name)}
        elif op_name == 'exist':
            new_block = {'op': 'exist'}
        elif op_name == 'count':
            # Same rule as query: only a final count becomes a block.
            if position == len(program) - 1:
                new_block = {'op': 'count'}
        elif op_name == 'equal_integer':
            new_block = {'op': 'count_equal'}
        elif op_name == 'less_than':
            new_block = {'op': 'count_less'}
        elif op_name == 'greater_than':
            new_block = {'op': 'count_greater'}
        else:
            raise ValueError('Unknown CLEVR operation: {}.'.format(op_name))

        if new_block is None:
            assert len(clevr_block['inputs']) == 1
            resolved[position] = resolved[clevr_block['inputs'][0]]
            continue

        new_block['inputs'] = [resolved.get(i) for i in clevr_block['inputs']]
        if '_output' in clevr_block:
            new_block['output'] = deepcopy(clevr_block['_output'])
        nscl_blocks.append(new_block)
        resolved[position] = len(nscl_blocks) - 1

    return nscl_blocks