#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : ckyee.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 12/11/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""NeuralCCG with CKY-EE (CKY with Expected Execution)."""
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, Union, Iterable, Tuple, List, Dict, Callable
import torch
import torch.nn.functional as F
import jacinle.random as jacrandom
import jactorch
from concepts.dsl.function_domain import FunctionDomain
from concepts.dsl.value import Value
from concepts.dsl.executors.function_domain_executor import FunctionDomainExecutor
from concepts.language.ccg.composition import CCGCompositionSystem, CCGCoordinationImmNode
from concepts.language.neural_ccg.grammar import NeuralCCGSemanticsPartialTypeLex, NeuralCCGSemantics, NeuralCCGNode, NeuralCCG
from concepts.language.neural_ccg.search import NeuralCCGLexiconSearchResult
__all__ = ['CKYEEExpectationFunction', 'CKYEEExpectationConfig', 'NeuralCKYEE']

# Optional profiler hook: if a `profile` builtin has been injected (e.g. by a line
# profiler such as kernprof), use it as a decorator; otherwise fall back to identity.
_profile = getattr(__builtins__, 'profile', lambda x: x)
class CKYEEExpectationFunction(object):
    """A collection of functions to perform expected execution. This class maintains a set of functions of the following form:

    .. code-block:: python

        def expectation_<typename>(self, values: List[Value], weights: List[torch.Tensor]) -> Tuple[Optional[Value], Optional[Tuple[torch.Tensor, float]]]:
            ...

    where ``typename`` is the name of a type, and ``values`` is a list of values of the type, and ``weights`` is a list of corresponding weights.
    The function should return a tuple of (expectation, (sum_weight, max_weight)), where ``expectation`` is the expectation of the values,
    as a :class:`~concepts.dsl.value.Value` instance, and ``sum_weight`` is the sum of the weights (as a PyTorch tensor),
    and ``max_weight`` is the maximum of the weights (as a float).

    NOTE(review): the built-in helper :meth:`_expectation_set_tensors` actually returns the *argmax index*
    of the weights (an int) as ``max_weight``; callers use it to pick a representative node.
    """

    def __init__(self, domain: FunctionDomain):
        """Initialize the expectation function.

        Args:
            domain: the function domain.
        """
        self._domain = domain
        self._expectation_functions = dict()
        self._init_expectation_functions()

    # The function domain; its type names are scanned for `expectation_<typename>` methods.
    _domain: FunctionDomain
    # Registered expectation functions, keyed by typename.
    _expectation_functions: Dict[str, Callable[[List[Value], List[torch.Tensor]], Tuple[torch.Tensor, Tuple[torch.Tensor, float]]]]

    def register_function(self, typename: str, function: Callable[[List[Value], List[torch.Tensor]], Tuple[torch.Tensor, Tuple[torch.Tensor, float]]]):
        """Register a function to compute the expectation of a list of values.

        Args:
            typename: the name of the type.
            function: the function to compute the expectation of a list of values.
        """
        self._expectation_functions[typename] = function

    def get_function(self, typename: str) -> Optional[Callable[[List[Value], List[torch.Tensor]], Tuple[torch.Tensor, Tuple[torch.Tensor, float]]]]:
        """Get the registered function for a type.

        Args:
            typename: the name of the type.

        Returns:
            The registered function, or None if no function is registered for the type.
        """
        return self._expectation_functions.get(typename, None)

    def _init_expectation_functions(self):
        # Auto-register any method named `expectation_<typename>` for each type in the domain.
        for typename in self._domain.types:
            funcname = 'expectation_' + typename
            if hasattr(self, funcname):
                self.register_function(typename, getattr(self, funcname))

    def expectation(self, values: List[Value], weights: List[torch.Tensor]) -> Tuple[Optional[Value], Optional[Tuple[torch.Tensor, float]]]:
        """Compute the expectation of a list of values.

        Args:
            values: a list of values.
            weights: a list of corresponding weights.

        Returns:
            A tuple of (expectation, (sum_weight, max_weight)).
            The ``expectation`` is the expectation of the values, as a :class:`~concepts.dsl.value.Value` instance.
            The second element is a tuple of (sum_weight, max_weight), where ``sum_weight`` is the sum of the weights (as a PyTorch tensor),
            and ``max_weight`` is the maximum of the weights (as a float).
            Returns ``(None, None)`` when the values are not :class:`~concepts.dsl.value.Value`
            instances, or when no function is registered for their type.
        """
        if not isinstance(values[0], Value):
            return None, None
        function = self.get_function(values[0].dtype.typename)
        if function is None:
            return None, None
        return function(values, weights)

    def _expectation_set_tensors(self, sets: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], weights: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], log: bool = False) -> Tuple[torch.Tensor, Tuple[torch.Tensor, float]]:
        """A helper function to compute the expectation of a list of tensors representing sets.

        Typically, the tensor representation of a set is a one-hot vector, where each entry represents
        the probability of the corresponding element in the set.

        Args:
            sets: a list of tensors representing sets.
            weights: a list of corresponding weights.
            log: whether the probabilities in sets are in log-space.

        Returns:
            A tuple of (set, (sum_weight, max_weight)).
            The ``set`` is a tensor representing the expectation of the list of sets, as a PyTorch tensor.
            The second element is a tuple of (sum_weight, max_weight), where ``sum_weight`` is the sum of the weights (as a PyTorch tensor),
            and ``max_weight`` is the argmax index of the weights (as an int).
        """
        sets_tensor = torch.stack(sets, dim=-1)
        weights_tensor = torch.stack(weights, dim=0)
        weights_sum = jactorch.logsumexp(weights_tensor)
        # NOTE(review): despite the name, this is the argmax *index* of the weights
        # (an int), not the maximum value; callers use it to pick a representative node.
        weights_max = weights_tensor.argmax().item()
        if log:
            # Log-space: normalize the weights with log_softmax, broadcast them over the
            # stacked set dimension, and reduce with logsumexp.
            weights_tensor = F.log_softmax(weights_tensor, dim=-1)
            weights_tensor = jactorch.add_dim_as_except(weights_tensor, sets_tensor, -1)
            output = jactorch.logsumexp(sets_tensor + weights_tensor, dim=-1)
        else:
            # Probability space: the expectation is a weighted average of the sets.
            weights_tensor = F.softmax(weights_tensor, dim=-1)
            output = torch.matmul(sets_tensor, weights_tensor)
        return output, (weights_sum, weights_max)
def aggregate_weights(weights: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]], log=True) -> Tuple[torch.Tensor, float]:
    """Aggregate a collection of parsing weights (probabilities) into a single total.

    Args:
        weights: the weights to aggregate. Either a single tensor, a tuple/list of
            scalar tensors (stacked along a new dim 0), or a plain Python number /
            nested list (converted with ``torch.tensor``).
        log: whether the weights are in log space. When True, the total is computed
            with ``logsumexp``; otherwise it is a plain sum.

    Returns:
        A tuple ``(total, argmax_index)``: the aggregated weight (a tensor) and the
        index of the largest weight (an int, via ``Tensor.argmax().item()``).
    """
    if torch.is_tensor(weights):
        stacked = weights
    elif isinstance(weights, (tuple, list)):
        stacked = torch.stack(weights, dim=0)
    else:
        stacked = torch.tensor(weights)

    best_index = stacked.argmax().item()
    total = jactorch.logsumexp(stacked) if log else stacked.sum()
    return total, best_index
@dataclass
class CKYEEExpectationConfig:
    """Configurations for CKY-EE expectation computation.

    Each flag enables one of the node-compression passes applied by
    ``NeuralCKYEE._collect_nodes_by_type`` during chart parsing.
    """

    # Merge value-typed nodes of the same (return type, syntax type) into one expected node.
    compress_values: bool = True
    # Merge function nodes whose partial type has zero bound variables.
    compress_0varbinding_functions: bool = True
    # Merge function nodes whose partial type has one bound variable.
    compress_1varbinding_functions: bool = True
    # Keep a single sampled node per syntax (sampling instead of exact expectation).
    sample: bool = False
def _get_typename(x):
return x.typename if x is not None else None
class NeuralCKYEE(NeuralCCG):
    """The neural CCG grammar with CKY-EE (CKY with Expected Execution) for chart parsing.

    Compared with the base :class:`NeuralCCG`, this class compresses chart cells by
    merging groups of nodes (values and partially-applied functions) into a single
    node carrying the *expectation* of their execution results, weighted by parse
    probabilities. The passes are controlled by :class:`CKYEEExpectationConfig`.
    """

    def __init__(
        self,
        domain: FunctionDomain,
        executor: FunctionDomainExecutor,
        candidate_lexicon_entries: Iterable[NeuralCCGLexiconSearchResult],
        expectation_function: CKYEEExpectationFunction,
        composition_system: Optional[CCGCompositionSystem] = None,
        *,
        expectation_config: Optional[CKYEEExpectationConfig] = None,
        joint_execution: bool = True,
        allow_none_lexicon: bool = False,
        reweight_meaning_lex: bool = False
    ):
        """Initialize the neural CCG grammar with CKY-EE.

        Args:
            domain: the function domain of the grammar.
            executor: the executor for the function domain.
            candidate_lexicon_entries: a list of candidate lexicon entries.
            expectation_function: the expectation function for CKY-EE. See :class:`CKYEEExpectationFunction`.
            composition_system: the composition system. If None, the default composition system will be used.
            expectation_config: the configuration for CKY-EE expectation computation.
            joint_execution: whether to execute the partial programs during CKY.
            allow_none_lexicon: whether to allow None lexicon.
            reweight_meaning_lex: whether to reweight the meaning lexicon entries. Specifically, if there are two parsings
                share the same set of lexicon entries (i.e., caused by ambiguities in combination), specifying this flag
                will reweight both of them to be 1 / (number of parsings that used this set of lexicon entries).
        """
        super().__init__(
            domain, executor, candidate_lexicon_entries, composition_system,
            joint_execution=joint_execution, allow_none_lexicon=allow_none_lexicon, reweight_meaning_lex=reweight_meaning_lex
        )
        self.expectation_function = expectation_function
        self.expectation_config = expectation_config if expectation_config is not None else CKYEEExpectationConfig()

    # Attribute annotations; presumably set by the base NeuralCCG constructor — see grammar.py.
    training: bool
    domain: FunctionDomain
    executor: FunctionDomainExecutor
    composition_system: CCGCompositionSystem
    candidate_lexicon_entries: Tuple['NeuralCCGLexiconSearchResult', ...]
    joint_execution: bool
    allow_none_lexicon: bool
    reweight_meaning_lex: bool

    def _unique_lexicon(self, nodes: List[NeuralCCGNode], syntax_only: bool) -> List[NeuralCCGNode]:
        # Deduplicate lexicon-level nodes. In syntax-only mode defer to the base
        # syntax deduplication; otherwise run the expected-execution compression passes.
        if syntax_only:
            return self._unique_syntax_only(nodes)
        return self._collect_nodes_by_type(nodes)

    def _unique(self, nodes: List[NeuralCCGNode], syntax_only: bool) -> List[NeuralCCGNode]:
        # Deduplicate a chart cell via the base implementation, then (when semantics
        # are tracked) compress the surviving nodes by type.
        nodes = super()._unique(nodes, syntax_only)
        if syntax_only:
            return nodes
        nodes = self._collect_nodes_by_type(nodes)
        return nodes

    def _reweight_parse_trees(self, parsings: List[NeuralCCGNode]) -> List[NeuralCCGNode]:
        # Reweighting by used-lexicon sets is incompatible with the compressed chart:
        # merged nodes no longer correspond to individual parse trees.
        raise NotImplementedError('Reweight pr[meaning | lex_i] is not supported for NeuralSoftCCG.')

    def _collect_nodes_by_type(self, nodes: List[NeuralCCGNode]) -> List[NeuralCCGNode]:
        # Apply the compression passes enabled in the expectation config, in order.
        output_nodes = nodes
        if self.expectation_config.compress_values:
            output_nodes = self._collect_values(output_nodes)
        if self.expectation_config.compress_0varbinding_functions:
            output_nodes = self._collect_0varbinding_functions(output_nodes)
        if self.expectation_config.compress_1varbinding_functions:
            output_nodes = self._collect_1varbinding_functions(output_nodes)
        if self.expectation_config.sample:
            output_nodes = self._collect_by_sample(output_nodes)
        return output_nodes

    @_profile
    def _collect_values(self, nodes):
        # Group value-typed (fully applied) nodes by (return type, language syntax type)
        # and merge each group into a single expected node.
        output_nodes = list()
        value_nodes = defaultdict(lambda: (list(), list(), list()))  # (nodes, execution results, weights)
        for node in nodes:
            if not isinstance(node.syntax, CCGCoordinationImmNode) and node.syntax.is_value:
                rec = value_nodes[(node.syntax.return_type.typename, _get_typename(node.syntax.lang_syntax_type))]
                rec[0].append(node)
                if self.joint_execution:
                    rec[1].append(node.execution_result)
                rec[2].append(node.weight)
            else:
                # Coordination and non-value nodes pass through untouched.
                output_nodes.append(node)
        for typename, rec in value_nodes.items():
            output_nodes.extend(self._collect_values_once(*rec))
        return output_nodes

    @_profile
    def _collect_0varbinding_functions(self, nodes):
        # Group function nodes whose partial type is a single lexicon-level entry
        # (zero variable bindings); their weights are aggregated without execution.
        output_nodes = list()
        function_nodes = defaultdict(lambda: (list(), list(), list()))  # (nodes, unused, weights)
        for node in nodes:
            sem = node.semantics
            if not isinstance(node.syntax, CCGCoordinationImmNode) and sem.partial_type is not None and len(sem.partial_type) == 1 and isinstance(sem.partial_type[0], NeuralCCGSemanticsPartialTypeLex):
                rec = function_nodes[(sem.partial_type[0], _get_typename(node.syntax.lang_syntax_type))]
                rec[0].append(node)
                # NOTE(review): rec[1] is never filled here; _collect_0varbinding_functions_once
                # receives an empty `all_results` list and ignores it.
                rec[2].append(node.weight)
            else:
                output_nodes.append(node)
        for typename, rec in function_nodes.items():
            output_nodes.extend(self._collect_0varbinding_functions_once(*rec))
        return output_nodes

    @_profile
    def _collect_1varbinding_functions(self, nodes):
        # Group function nodes with exactly one bound argument (partial type of length 2);
        # the bound argument (execution_buffer[1]) is compressed via expectation.
        output_nodes = list()
        function_nodes = defaultdict(lambda: (list(), list(), list()))  # (nodes, bound arguments, weights)
        for node in nodes:
            sem = node.semantics
            if not isinstance(node.syntax, CCGCoordinationImmNode) and sem.partial_type is not None and len(sem.partial_type) == 2 and isinstance(sem.partial_type[0], NeuralCCGSemanticsPartialTypeLex):
                rec = function_nodes[(sem.partial_type[0], _get_typename(node.syntax.lang_syntax_type))]
                rec[0].append(node)
                if self.joint_execution:
                    rec[1].append(node.semantics.execution_buffer[1])
                rec[2].append(node.weight)
            else:
                output_nodes.append(node)
        for typename, rec in function_nodes.items():
            output_nodes.extend(self._collect_1varbinding_functions_once(*rec))
        return output_nodes

    @_profile
    def _collect_by_sample(self, nodes):
        # Group all nodes by their (stringified) syntax and keep one sampled node per group.
        output_nodes = list()
        value_nodes = defaultdict(lambda: (list(), list()))  # (nodes, weights)
        for node in nodes:
            rec = value_nodes[str(node.syntax)]
            rec[0].append(node)
            rec[1].append(node.weight)
        for typename, rec in value_nodes.items():
            output_nodes.extend(self._collect_by_sample_once(*rec))
        return output_nodes

    def _collect_values_once(self, nodes: List[NeuralCCGNode], all_results: List[Value], all_weights: List[torch.Tensor]) -> List[NeuralCCGNode]:
        # Merge one group of value nodes into a single node whose execution result is
        # the expectation over the group. Returns the group unchanged when it has a
        # single member or no expectation function is available for the result type.
        if len(nodes) > 1:
            if self.joint_execution:
                compressed_node, weights = self.expectation_function.expectation(all_results, all_weights)
            else:
                compressed_node = None
                weights = aggregate_weights(all_weights, log=True)
            if self.joint_execution and compressed_node is None:
                # No registered expectation function for this type: keep the group as-is.
                return nodes
            else:
                # weights[1] is the argmax index of the weights: use the highest-weight
                # node as the structural representative (syntax, lexicon, composition metadata).
                sample_node = nodes[weights[1]]
                new_node = NeuralCCGNode(
                    composition_system=sample_node.composition_system,
                    syntax=sample_node.syntax,
                    semantics=NeuralCCGSemantics(
                        sample_node.semantics.value,
                        execution_buffer=[compressed_node] if compressed_node is not None else None,
                        partial_type=sample_node.semantics.partial_type,
                        nr_execution_steps=self.count_nr_execution_steps(nodes)
                    ),
                    composition_type=sample_node.composition_type,
                    lexicon=sample_node.lexicon, lhs=sample_node.lhs, rhs=sample_node.rhs,
                    composition_str=sample_node.composition_str,
                    weight=weights[0]  # aggregated (logsumexp) weight of the whole group
                )
                new_node.set_used_lexicon_entries(self.gen_valid_lexicons(nodes))
                return [new_node]
        else:
            return nodes

    def _collect_0varbinding_functions_once(self, nodes: List[NeuralCCGNode], all_results: List[Value], all_weights: List[torch.Tensor]) -> List[NeuralCCGNode]:
        # Merge one group of zero-binding function nodes: no execution state to combine,
        # so only the weights are aggregated. `all_results` is unused (always empty here).
        if len(nodes) > 1:
            weights = aggregate_weights(all_weights, log=True)
            # NOTE(review): unlike the other passes, the representative is nodes[0],
            # not the argmax-weight node — TODO confirm this is intentional.
            sample_node = nodes[0]
            new_node = NeuralCCGNode(
                composition_system=sample_node.composition_system,
                syntax=sample_node.syntax,
                semantics=sample_node.semantics,
                composition_type=sample_node.composition_type,
                lexicon=sample_node.lexicon, lhs=sample_node.lhs, rhs=sample_node.rhs,
                composition_str=sample_node.composition_str,
                weight=weights[0]
            )
            new_node.set_used_lexicon_entries(self.gen_valid_lexicons(nodes))
            return [new_node]
        else:
            return nodes

    def _collect_1varbinding_functions_once(self, nodes: List[NeuralCCGNode], all_results: List[Value], all_weights: List[torch.Tensor]) -> List[NeuralCCGNode]:
        # Merge one group of one-binding function nodes: the bound argument
        # (execution_buffer[1]) is replaced by its expectation over the group.
        if len(nodes) > 1:
            if self.joint_execution:
                compressed_node, weights = self.expectation_function.expectation(all_results, all_weights)
            else:
                compressed_node = None
                weights = aggregate_weights(all_weights, log=True)
            if self.joint_execution and compressed_node is None:
                # No registered expectation function for the bound argument's type.
                return nodes
            else:
                # weights[1] is the argmax index: highest-weight node is the representative.
                sample_node = nodes[weights[1]]
                new_node = NeuralCCGNode(
                    composition_system=sample_node.composition_system,
                    syntax=sample_node.syntax,
                    semantics=NeuralCCGSemantics(
                        sample_node.semantics.value,
                        execution_buffer=[
                            sample_node.semantics.execution_buffer[0],  # the function itself
                            compressed_node  # expectation of the bound argument
                        ] if compressed_node is not None else None,
                        partial_type=sample_node.semantics.partial_type,
                        nr_execution_steps=self.count_nr_execution_steps(nodes)
                    ),
                    composition_type=sample_node.composition_type,
                    lexicon=sample_node.lexicon, lhs=sample_node.lhs, rhs=sample_node.rhs,
                    composition_str=sample_node.composition_str,
                    weight=weights[0]
                )
                new_node.set_used_lexicon_entries(self.gen_valid_lexicons(nodes))
                return [new_node]
        else:
            return nodes

    def _collect_by_sample_once(self, nodes: List[NeuralCCGNode], all_weights: List[torch.Tensor]) -> List[NeuralCCGNode]:
        # Keep a single node per group, sampled proportionally to the softmaxed weights.
        if len(nodes) > 1:
            weights_tensor = torch.stack(all_weights, dim=0).detach()
            # NOTE(review): F.softmax is called without an explicit dim — OK for this
            # 1-D tensor, but consider dim=-1 (implicit dim is deprecated in PyTorch).
            index = jacrandom.choice(weights_tensor.shape[0], p=jactorch.as_numpy(F.softmax(weights_tensor)))
            # Mask out the sampled entry in the detached tensor, then add the sampled
            # weight back undetached, so gradients flow only through the sampled node.
            weights_tensor.data[index] = -1e9
            total_weights = jactorch.logaddexp(all_weights[index], jactorch.logsumexp(weights_tensor))
            sample_node = nodes[index]
            new_node = NeuralCCGNode(
                composition_system=sample_node.composition_system,
                syntax=sample_node.syntax,
                semantics=NeuralCCGSemantics(
                    sample_node.semantics.value,
                    execution_buffer=sample_node.semantics.execution_buffer,
                    partial_type=sample_node.semantics.partial_type,
                    nr_execution_steps=self.count_nr_execution_steps(nodes)
                ),
                composition_type=sample_node.composition_type,
                lexicon=sample_node.lexicon, lhs=sample_node.lhs, rhs=sample_node.rhs,
                composition_str=sample_node.composition_str,
                weight=total_weights
            )
            new_node.set_used_lexicon_entries(self.gen_valid_lexicons(nodes))
            return [new_node]
        else:
            return nodes