Source code for concepts.dm.crow.function_utils

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : function_utils.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/17/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import contextlib
from typing import Optional, Union, Sequence, Tuple, List, Set, Dict

import torch
import concepts.dsl.expression as E
from concepts.dsl.dsl_types import Variable
from concepts.dsl.expression import Expression, ExpressionDefinitionContext, iter_exprs
from concepts.dsl.expression import FunctionApplicationExpression, VariableExpression, ObjectOrValueOutputExpression, VariableAssignmentExpression, ValueOutputExpression
from concepts.dsl.expression import ListCreationExpression, ListExpansionExpression, ListFunctionApplicationExpression
from concepts.dsl.expression import BoolExpression, QuantificationExpression, FindAllExpression, BoolOpType, ObjectCompareExpression, ValueCompareExpression
from concepts.dsl.expression import PredicateEqualExpression, AssignExpression, ConditionalSelectExpression, DeicticSelectExpression, ConditionalAssignExpression, DeicticAssignExpression
from concepts.dsl.expression_visitor import IdentityExpressionVisitor
from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.tensor_state import StateObjectReference

from concepts.dm.crow.function import CrowFeature

# Public API of this module. Each name is listed exactly once.
__all__ = [
    'expand_argument_values',
    'flatten_expression',
    'get_used_state_variables',
    'is_simple_bool',
    'split_simple_bool',
    'get_simple_bool_predicate',
]


def expand_argument_values(argument_values: Sequence[Union[TensorValue, int, str, slice, StateObjectReference]]) -> List[TensorValue]:
    """Expand a list of argument values so that all tensor values share the same batch variables.

    Non-tensor arguments (int, str, slice, StateObjectReference) are passed through unchanged.
    If any tensor argument has the slot placeholder ``'??'`` among its batch variables, or if
    there are fewer than two arguments, a plain list copy of the input is returned.

    Args:
        argument_values: a list of argument values.

    Returns:
        the result list of argument values. All returned tensor values have been expanded to the
        union of batch variables, and, when at least one of them carries a mask, they all share
        the element-wise minimum of those masks.
    """
    values = list(argument_values)

    # Arguments containing a slot variable are returned untouched.
    contains_slot = any(
        isinstance(value, TensorValue) and any(name == '??' for name in value.batch_variables)
        for value in values
    )
    if contains_slot or len(values) < 2:
        return values

    # Collect the union of batch variables in first-seen order, remembering the size of each.
    variable_sizes = dict()
    for value in values:
        if not isinstance(value, TensorValue):
            assert isinstance(value, (int, str, slice, StateObjectReference)), value
            continue
        for name in value.batch_variables:
            if name not in variable_sizes:
                variable_sizes[name] = value.get_variable_size(name)
    all_variables = list(variable_sizes.keys())
    all_sizes = list(variable_sizes.values())

    # Expand every tensor value, and gather the masks of the expanded values.
    expanded_masks = list()
    for index, value in enumerate(values):
        if not isinstance(value, TensorValue):
            continue
        expanded = value.expand(all_variables, all_sizes)
        values[index] = expanded
        if expanded.tensor_mask is not None:
            expanded_masks.append(expanded.tensor_mask)

    if expanded_masks:
        # Combine the masks with an element-wise minimum and share the result across all tensors.
        shared_mask = torch.stack(expanded_masks, dim=-1).amin(dim=-1)
        for value in values:
            if isinstance(value, TensorValue):
                value.tensor_mask = shared_mask
                value._mask_certified_flag = True  # now we have corrected the mask.

    return values
def flatten_expression(
    expr: Expression,
    mappings: Optional[Dict[Union[FunctionApplicationExpression, VariableExpression], Union[Variable, ValueOutputExpression]]] = None,
    ctx: Optional[ExpressionDefinitionContext] = None,
    flatten_cacheable_expression: bool = True,
) -> Union[ObjectOrValueOutputExpression, VariableAssignmentExpression]:
    """Flatten an expression by replacing certain variables or function applications with sub-expressions.

    The input mapping is a dictionary of {expression: sub-expression}. There are two cases:

    - The key is a :class:`~concepts.dsl.expression.VariableExpression`, and the value is a
      :class:`~concepts.dsl.dsl_types.Variable` or a :class:`~concepts.dsl.expression.ValueOutputExpression`.
      In this case, the value is the sub-expression used for replacing the variable.
    - The key is a :class:`~concepts.dsl.expression.FunctionApplicationExpression`, and the value is a
      :class:`~concepts.dsl.dsl_types.Variable`. Here, the function application expression must be a "simple"
      function application expression, i.e., it contains only variables as arguments. The Variable will
      replace the entire function application expression.

    Args:
        expr: the expression to flatten.
        mappings: a dictionary of {expression: sub-expression} to replace the expression with the sub-expression.
            Defaults to an empty dictionary. A given dictionary is handed to the visitor as-is (not copied).
        ctx: a :class:`~concepts.dsl.expression.ExpressionDefinitionContext`. A fresh one is created if omitted.
        flatten_cacheable_expression: whether to flatten cacheable expressions. If False, cacheable function
            applications will be kept as-is.

    Returns:
        the flattened expression.
    """
    active_mappings = dict() if mappings is None else mappings
    active_ctx = ExpressionDefinitionContext() if ctx is None else ctx

    with active_ctx.as_default():
        visitor = _FlattenExpressionVisitor(active_ctx, active_mappings, flatten_cacheable_expression)
        return visitor.visit(expr)
class _FlattenExpressionVisitor(IdentityExpressionVisitor):
    """Expression visitor that rebuilds an expression while applying variable / function-application
    substitutions and inlining derived functions.

    The visitor keeps a ``mappings`` dictionary of {expression: replacement}. Variable expressions are
    matched against it by name; function applications are matched by function name and argument names.
    """

    def __init__(
        self,
        ctx: ExpressionDefinitionContext,
        mappings: Dict[Union[FunctionApplicationExpression, VariableExpression], Union[Variable, ValueOutputExpression]],
        flatten_cacheable_expression: bool = True,
    ):
        """Store the definition context, the substitution mappings, and the cacheable-flattening flag.

        Args:
            ctx: the expression definition context used to generate fresh variables.
            mappings: the substitution dictionary. It is used as-is and is updated temporarily while
                visiting expressions that bind a variable.
            flatten_cacheable_expression: if False, cacheable function applications are not inlined.
        """
        self.ctx = ctx
        self.mappings = mappings
        self.flatten_cacheable_expression = flatten_cacheable_expression

    def visit_variable_expression(self, expr: VariableExpression) -> Union[VariableExpression, ValueOutputExpression]:
        """Replace the variable by the first mapping whose key is a variable expression of the same name."""
        rv = expr
        for k, v in self.mappings.items():
            if isinstance(k, VariableExpression):
                if k.name == expr.name:
                    if isinstance(v, Variable):
                        rv = VariableExpression(v)
                    else:
                        rv = v
                    break
        return rv

    def visit_function_application_expression(self, expr: Union[FunctionApplicationExpression, ListFunctionApplicationExpression]) -> Union[VariableExpression, ValueOutputExpression]:
        """Substitute, keep, or inline a function application.

        Raises:
            TypeError: if a derived function to be inlined declares a non-variable argument.
        """
        # Case 1: the function application will be replaced by something in the mappings.
        for k, v in self.mappings.items():
            if isinstance(k, FunctionApplicationExpression):
                if expr.function.name == k.function.name and all(
                    isinstance(a1, VariableExpression) and isinstance(a2, VariableExpression) and a1.name == a2.name
                    for a1, a2 in zip(expr.arguments, k.arguments)
                ):
                    assert isinstance(v, Variable)
                    return VariableExpression(v)

        # Case 2 contains three sub-cases:
        #   (1) the function is not a derived function
        #   (2) the function corresponds to a state variable
        #   (3) the function is a cacheable function and we do not want to flatten it.
        if (
            not expr.function.is_derived
            or (isinstance(expr.function, CrowFeature) and expr.function.is_state_variable)
            or (not self.flatten_cacheable_expression and expr.function.ftype.is_cacheable)
        ):
            return type(expr)(expr.function, [self.visit(e) for e in expr.arguments])

        # Case 3: the function is a derived function and we want to flatten it.
        for arg in expr.function.arguments:
            if not isinstance(arg, Variable):
                raise TypeError(f'Cannot flatten function application {expr} because it contains non-variable arguments.')

        # (1) First resolve the arguments (using the current mappings).
        argvs = [self.visit(e) for e in expr.arguments]

        # (2) Make a backup of the current context, and then create a new context using the arguments.
        old_mappings = self.mappings
        self.mappings = {VariableExpression(arg): argv for arg, argv in zip(expr.function.arguments, argvs)}
        try:
            # (3) Flatten the derived expression.
            with self.ctx.with_variables(*expr.function.arguments):
                rv = self.visit(expr.function.derived_expression)
        finally:
            # (4) Restore the old context, even if flattening the derived expression failed.
            self.mappings = old_mappings

        # (5) Flatten the result again, using the old context + mappings.
        return self.visit(rv)

    def visit_list_creation_expression(self, expr: ListCreationExpression) -> ListCreationExpression:
        return type(expr)([self.visit(e) for e in expr.arguments])

    def visit_list_expansion_expression(self, expr: ListExpansionExpression) -> ListExpansionExpression:
        return type(expr)(self.visit(expr.expression))

    def visit_list_function_application_expression(self, expr: ListFunctionApplicationExpression) -> Union[VariableExpression, ValueOutputExpression]:
        # List function applications follow exactly the same rules as regular function applications.
        return self.visit_function_application_expression(expr)

    def visit_bool_expression(self, expr: BoolExpression) -> BoolExpression:
        return BoolExpression(expr.bool_op, [self.visit(child) for child in expr.arguments])

    def visit_object_compare_expression(self, expr: ObjectCompareExpression) -> ObjectCompareExpression:
        return ObjectCompareExpression(expr.compare_op, self.visit(expr.lhs), self.visit(expr.rhs))

    def visit_value_compare_expression(self, expr: ValueCompareExpression) -> ValueCompareExpression:
        return ValueCompareExpression(expr.compare_op, self.visit(expr.lhs), self.visit(expr.rhs))

    def visit_quantification_expression(self, expr: QuantificationExpression) -> QuantificationExpression:
        # The bound variable is renamed to a freshly generated variable while visiting the body.
        with self.make_dummy_variable(expr.variable) as dummy_variable:
            return QuantificationExpression(expr.quantification_op, dummy_variable, self.visit(expr.expression))

    def visit_find_all_expression(self, expr: FindAllExpression) -> FindAllExpression:
        with self.make_dummy_variable(expr.variable) as dummy_variable:
            return FindAllExpression(dummy_variable, self.visit(expr.expression))

    def visit_predicate_equal_expression(self, expr: PredicateEqualExpression) -> PredicateEqualExpression:
        return type(expr)(self.visit(expr.predicate), self.visit(expr.value))

    def visit_assign_expression(self, expr: AssignExpression) -> AssignExpression:
        return type(expr)(self.visit(expr.predicate), self.visit(expr.value))

    def visit_conditional_select_expression(self, expr: ConditionalSelectExpression) -> ConditionalSelectExpression:
        return type(expr)(self.visit(expr.predicate), self.visit(expr.condition))

    def visit_deictic_select_expression(self, expr: DeicticSelectExpression) -> DeicticSelectExpression:
        with self.make_dummy_variable(expr.variable) as dummy_variable:
            return type(expr)(dummy_variable, self.visit(expr.expression))

    def visit_conditional_assign_expression(self, expr: ConditionalAssignExpression) -> ConditionalAssignExpression:
        return type(expr)(self.visit(expr.predicate), self.visit(expr.value), self.visit(expr.condition))

    def visit_deictic_assign_expression(self, expr: DeicticAssignExpression) -> DeicticAssignExpression:
        with self.make_dummy_variable(expr.variable) as dummy_variable:
            return type(expr)(dummy_variable, self.visit(expr.expression))

    def visit_constant_expression(self, expr: Expression) -> Expression:
        return expr

    def visit_object_constant_expression(self, expr: Expression) -> Expression:
        return expr

    @contextlib.contextmanager
    def make_dummy_variable(self, variable: Variable):
        """Temporarily map ``variable`` to a freshly generated variable of the same dtype.

        On exit the mappings are put back exactly as they were: a pre-existing entry is restored,
        and an entry that did not exist before is removed again. Writing ``None`` back for a
        previously absent key would make later same-named variables resolve to ``None``.

        Args:
            variable: the bound variable to be renamed.

        Yields:
            the generated dummy variable.
        """
        dummy_variable = self.ctx.gen_random_named_variable(variable.dtype)
        dummy_variable_expr = VariableExpression(variable)

        # Pin the dictionary we modify so that the cleanup always touches the same object.
        mappings = self.mappings
        had_old_mapping = dummy_variable_expr in mappings
        old_mapping = mappings.get(dummy_variable_expr, None)
        mappings[dummy_variable_expr] = dummy_variable
        try:
            yield dummy_variable
        finally:
            # Run the cleanup even when the body of the with-statement raises.
            if had_old_mapping:
                mappings[dummy_variable_expr] = old_mapping
            else:
                mappings.pop(dummy_variable_expr, None)
def get_used_state_variables(expr: ValueOutputExpression) -> Set[CrowFeature]:
    """Return the set of state variables used in the given expression.

    Function applications of derived functions are followed into their derived expressions, so state
    variables that are only used indirectly are included as well.

    Args:
        expr: the expression to be analyzed.

    Returns:
        the set of state variables (the :class:`~concepts.dm.crow.function.CrowFeature` objects) used in
        the given expression.
    """
    assert isinstance(expr, ValueOutputExpression), (
        'Only value output expression has well-defined used-state-variables.\n'
        'For value assignment expressions, please separately process the targets, conditions, and values.'
    )

    state_variables = set()
    # Worklist of expressions that still have to be scanned.
    pending = [expr]
    while pending:
        current = pending.pop()
        for node in iter_exprs(current):
            if not isinstance(node, FunctionApplicationExpression):
                continue
            function = node.function
            if isinstance(function, CrowFeature) and function.is_state_variable:
                state_variables.add(function)
            elif function.derived_expression is not None:
                # Not a state variable itself: look inside its definition later.
                pending.append(function.derived_expression)
    return state_variables
def is_simple_bool(expr: Expression) -> bool:
    """Check if the expression is a simple Boolean expression.

    A simple Boolean expression is the application of a state variable, optionally wrapped in any
    number of negations.

    Args:
        expr: the expression to check.

    Returns:
        True if the expression is a simple boolean expression, False otherwise.
    """
    node = expr
    while True:
        if isinstance(node, FunctionApplicationExpression):
            function = node.function
            if isinstance(function, CrowFeature) and function.is_state_variable:
                return True
        if not (isinstance(node, BoolExpression) and node.bool_op is BoolOpType.NOT):
            return False
        # Peel off one negation and inspect what is underneath.
        node = node.arguments[0]
def split_simple_bool(expr: Expression, initial_negated: bool = False) -> Tuple[Optional[FunctionApplicationExpression], bool]:
    """Split a simple Boolean expression (see :func:`is_simple_bool`) into the state-variable application
    and its negation flag.

    Args:
        expr (Expression): the expression to be checked.
        initial_negated (bool, optional): whether outer context of the feature expression is a negated function.

    Returns:
        a tuple of the feature application and a boolean indicating whether the expression is negated.
        The first element is None if the feature is not a simple Boolean feature application.
    """
    node = expr
    negated = initial_negated
    while True:
        if isinstance(node, FunctionApplicationExpression):
            function = node.function
            if isinstance(function, CrowFeature) and function.is_state_variable:
                return node, negated
        if not (isinstance(node, BoolExpression) and node.bool_op is BoolOpType.NOT):
            return None, negated
        # Each negation that is stripped flips the flag.
        node = node.arguments[0]
        negated = not negated
def get_simple_bool_predicate(expr: Expression) -> CrowFeature:
    """Return the underlying predicate of a simple bool expression (see :func:`is_simple_bool`).

    Args:
        expr: the expression, assumed to be a simple Boolean expression.

    Returns:
        the underlying predicate.
    """
    node = expr
    while True:
        if isinstance(node, FunctionApplicationExpression):
            function = node.function
            if isinstance(function, CrowFeature) and function.is_state_variable:
                return function
        # Anything that is not a state-variable application must be a negation.
        assert isinstance(node, BoolExpression) and node.bool_op is BoolOpType.NOT
        node = node.arguments[0]