#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : crow_expression_utils.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/17/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import numpy as np
import torch
from typing import Optional, Any, Union, Sequence, Tuple, Set, Dict
from jacinle.utils.meta import stmap
from concepts.dsl.dsl_types import Variable
from concepts.dsl.expression import Expression, ExpressionDefinitionContext
from concepts.dsl.expression import FunctionApplicationExpression, VariableExpression, ObjectOrValueOutputExpression, VariableAssignmentExpression, ValueOutputExpression
from concepts.dsl.expression import ListFunctionApplicationExpression, BoolExpression, BoolOpType
from concepts.dsl.expression_utils import iter_exprs, FlattenExpressionVisitor
from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.tensor_state import StateObjectReference
from concepts.dm.crow.crow_function import CrowFeature
from concepts.dm.crow.controller import CrowControllerApplier
# Public API of this module.
# NOTE(review): CrowFlattenExpressionVisitor is defined in this module but not exported here — confirm that is intended.
__all__ = [
    'make_plan_serializable',
    'crow_flatten_expression',
    'crow_replace_expression_variables',
    'crow_get_used_state_variables',
    'crow_is_simple_bool',
    'crow_split_simple_bool',
    'crow_get_simple_bool_predicate',
]
[docs]
def make_plan_serializable(plan: Sequence[CrowControllerApplier], json_compatible: bool = False) -> Tuple[Union[Dict[str, Any], str, list], ...]:
    """Make a serializable version of the plan.

    Each :class:`CrowControllerApplier` becomes ``{'name': ..., 'arguments': ...}``, where the arguments are converted
    element-wise (via ``stmap``) with the same rules. A :class:`StateObjectReference` becomes its name, and a
    :class:`TensorValue` is replaced by its (converted) underlying tensor.

    Args:
        plan: the plan to be serialized.
        json_compatible: whether to make the plan JSON-compatible. If True, numpy arrays and torch tensors are converted
            to (nested) Python lists, numpy scalars to the equivalent Python scalars, and any other object that has a
            ``__dict__`` to ``{'class': <class name>, 'data': <converted attributes>}``. If False, values other than the
            cases above are returned unchanged.

    Returns:
        a tuple with one serialized entry per element of ``plan``.
    """
    def prim(v, json_compatible=json_compatible):
        if isinstance(v, CrowControllerApplier):
            return {'name': v.name, 'arguments': stmap(prim, v.arguments)}
        if isinstance(v, StateObjectReference):
            return v.name
        if isinstance(v, TensorValue):
            return prim(v.tensor)
        if not json_compatible:
            return v

        if isinstance(v, np.ndarray):
            return v.tolist()
        if isinstance(v, np.generic):
            # numpy scalars (e.g. np.float32) are not accepted by the json module; convert to the Python equivalent.
            return v.item()
        if isinstance(v, torch.Tensor):
            # detach() avoids the error Tensor.numpy() raises for tensors that require grad, and Tensor.tolist()
            # also works for dtypes that have no numpy equivalent (e.g. bfloat16).
            return v.detach().cpu().tolist()
        if hasattr(v, '__dict__'):
            return {'class': v.__class__.__name__, 'data': stmap(prim, v.__dict__)}
        return v

    return tuple(prim(v) for v in plan)
[docs]
def crow_replace_expression_variables(
    expr: Expression,
    mappings: Optional[Dict[Union[FunctionApplicationExpression, VariableExpression], Union[Variable, ObjectOrValueOutputExpression]]] = None,
    ctx: Optional[ExpressionDefinitionContext] = None,
) -> Union[ObjectOrValueOutputExpression, VariableAssignmentExpression]:
    """Replace variables in an expression with other expressions, without expanding derived functions.

    Allowed replacements are:

    - Replace a :class:`~concepts.dsl.expression.VariableExpression` with a :class:`~concepts.dsl.dsl_types.Variable`
      or a :class:`~concepts.dsl.expression.ValueOutputExpression`.
    - Replace a :class:`~concepts.dsl.expression.FunctionApplicationExpression` with a
      :class:`~concepts.dsl.dsl_types.Variable` (only a Variable is accepted for this case). The key is matched by
      function name and argument variable names, so it should apply the function to variables only.

    This is :func:`crow_flatten_expression` with ``deep=False`` and ``flatten_cacheable_expression=False``.

    Args:
        expr: the expression to replace variables.
        mappings: a dictionary of {expression: sub-expression} to replace the expression with the sub-expression.
        ctx: a :class:`~concepts.dsl.expression.ExpressionDefinitionContext`. A fresh context is created if None.

    Returns:
        the expression with the replacements applied.
    """
    return crow_flatten_expression(expr, mappings=mappings, ctx=ctx, deep=False, flatten_cacheable_expression=False)
[docs]
def crow_flatten_expression(
    expr: Expression,
    mappings: Optional[Dict[Union[FunctionApplicationExpression, VariableExpression], Union[Variable, ObjectOrValueOutputExpression]]] = None,
    ctx: Optional[ExpressionDefinitionContext] = None,
    deep: bool = True,
    flatten_cacheable_expression: bool = True,
) -> Union[ObjectOrValueOutputExpression, VariableAssignmentExpression]:
    """Flatten an expression by replacing certain variables or function applications with sub-expressions.

    The input mapping is a dictionary of {expression: sub-expression}. There are two cases:

    - The expression is a :class:`~concepts.dsl.expression.VariableExpression`, and the sub-expression is a
      :class:`~concepts.dsl.dsl_types.Variable` or a :class:`~concepts.dsl.expression.ValueOutputExpression`. In this
      case, occurrences of the variable expression are replaced by the sub-expression.
    - The expression is a :class:`~concepts.dsl.expression.FunctionApplicationExpression`, and the sub-expression is a
      :class:`~concepts.dsl.dsl_types.Variable`. Here, the function application expression must be a "simple"
      function application expression, i.e., it contains only variables as arguments. The Variable will replace the
      entire function application expression.

    Args:
        expr: the expression to flatten.
        mappings: a dictionary of {expression: sub-expression} to replace the expression with the sub-expression.
            Treated as empty if None.
        ctx: a :class:`~concepts.dsl.expression.ExpressionDefinitionContext`. A fresh context is created if None.
        deep: whether to recursively flatten the expression (expand derived functions). Default is True.
        flatten_cacheable_expression: whether to flatten cacheable expressions. If False, cacheable function
            applications will be kept as-is.

    Returns:
        the flattened expression.
    """
    if mappings is None:
        mappings = dict()
    if ctx is None:
        ctx = ExpressionDefinitionContext()

    with ctx.as_default():
        visitor = CrowFlattenExpressionVisitor(ctx, mappings, deep=deep, flatten_cacheable_expression=flatten_cacheable_expression)
        return visitor.visit(expr)
[docs]
class CrowFlattenExpressionVisitor(FlattenExpressionVisitor):
    """Expression visitor that applies a replacement mapping and (optionally) inlines derived Crow functions.

    Applications of Crow state variables (:class:`CrowFeature` with ``is_state_variable`` set) and of non-derived
    functions are never inlined. Applications of cacheable functions are kept when ``flatten_cacheable_expression``
    is False.
    """

    def __init__(
        self,
        ctx: ExpressionDefinitionContext,
        mappings: Dict[Union[FunctionApplicationExpression, VariableExpression], Union[Variable, ValueOutputExpression]],
        deep: bool = True,
        flatten_cacheable_expression: bool = True,
    ):
        """Initialize the visitor.

        Args:
            ctx: the expression definition context; the arguments of an inlined derived function are pushed into it
                while its body is visited.
            mappings: a dictionary of {expression: sub-expression} replacements (see :func:`crow_flatten_expression`).
            deep: whether to inline derived functions. If False, only the replacements in ``mappings`` are applied.
            flatten_cacheable_expression: whether to inline derived functions whose type is cacheable.
        """
        super().__init__(ctx, mappings, deep)
        self.flatten_cacheable_expression = flatten_cacheable_expression

    def visit_function_application_expression(self, expr: Union[FunctionApplicationExpression, ListFunctionApplicationExpression]) -> Union[VariableExpression, ValueOutputExpression]:
        """Visit a function application, applying replacements and inlining derived functions where requested.

        Args:
            expr: the function application to visit.

        Returns:
            the rewritten expression.

        Raises:
            TypeError: if a derived function to be inlined declares an argument that is not a :class:`Variable`.
        """
        # Case 1: the function application is replaced by a Variable from the mappings. A key matches when it applies
        # the function with the same name to variable expressions with the same names.
        for k, v in self.mappings.items():
            if not isinstance(k, FunctionApplicationExpression):
                continue
            if expr.function.name == k.function.name and all(
                isinstance(a1, VariableExpression) and isinstance(a2, VariableExpression) and a1.name == a2.name
                for a1, a2 in zip(expr.arguments, k.arguments)
            ):
                assert isinstance(v, Variable), f'A function application can only be replaced by a Variable, got {v!r} for {k}.'
                return VariableExpression(v)

        # Case 2: keep the function application and only visit its arguments when
        #   (0) we are not flattening deeply,
        #   (1) the function is not a derived function,
        #   (2) the function corresponds to a state variable, or
        #   (3) the function is cacheable and we do not want to flatten cacheable functions.
        keep_application = (
            not self.deep
            or not expr.function.is_derived
            or (isinstance(expr.function, CrowFeature) and expr.function.is_state_variable)
            or (not self.flatten_cacheable_expression and expr.function.ftype.is_cacheable)
        )
        if keep_application:
            return type(expr)(expr.function, [self.visit(e) for e in expr.arguments])

        # Case 3: the function is a derived function and we want to inline its body.
        for arg in expr.function.arguments:
            if not isinstance(arg, Variable):
                raise TypeError(f'Cannot flatten function application {expr} because it contains non-variable arguments.')

        # (1) Resolve the actual arguments under the current mappings.
        argvs = [self.visit(e) for e in expr.arguments]

        # (2) Visit the derived body with mappings that bind only the formal arguments. The outer mappings must be
        #     restored even if visiting the body raises, so the visitor is not left in an inconsistent state.
        old_mappings = self.mappings
        self.mappings = {VariableExpression(arg): argv for arg, argv in zip(expr.function.arguments, argvs)}
        try:
            with self.ctx.with_variables(*expr.function.arguments):
                rv = self.visit(expr.function.derived_expression)
        finally:
            self.mappings = old_mappings

        # (3) Flatten the result again, using the outer mappings.
        return self.visit(rv)
[docs]
def crow_get_used_state_variables(expr: ValueOutputExpression) -> Set[CrowFeature]:
    """Return the set of state variables used in the given expression.

    Derived functions applied in the expression are expanded, so state variables used inside their derived
    expressions are collected as well. Each derived function body is expanded at most once. This avoids redundant
    work when a function is applied several times, and prevents infinite recursion if a derived expression
    (directly or indirectly) applies its own function.

    Args:
        expr: the expression to be analyzed.

    Returns:
        the set of state variables (the :class:`~concepts.dm.crow.crow_function.CrowFeature` objects) used in the given expression.
    """
    assert isinstance(expr, ValueOutputExpression), (
        'Only value output expression has well-defined used-state-variables.\n'
        'For value assignment expressions, please separately process the targets, conditions, and values.'
    )

    used_svs = set()
    # Keyed by id() so that function objects other than CrowFeature need not be hashable.
    expanded_function_ids = set()
    pending = [expr]
    while pending:
        for e in iter_exprs(pending.pop()):
            if not isinstance(e, FunctionApplicationExpression):
                continue
            function = e.function
            if isinstance(function, CrowFeature) and function.is_state_variable:
                used_svs.add(function)
            elif function.derived_expression is not None and id(function) not in expanded_function_ids:
                expanded_function_ids.add(id(function))
                pending.append(function.derived_expression)
    return used_svs
[docs]
def crow_is_simple_bool(expr: Expression) -> bool:
    """Test whether ``expr`` is a simple Boolean expression.

    A simple Boolean expression is an application of a Crow state variable (a :class:`CrowFeature` whose
    ``is_state_variable`` flag is set), optionally wrapped in any number of ``NOT`` operators.
    NOTE(review): the value type of the state variable is not checked here; it is presumed to be Boolean.

    Args:
        expr: the expression to test.

    Returns:
        True if ``expr`` is a simple Boolean expression, False otherwise.
    """
    node = expr
    while True:
        if isinstance(node, FunctionApplicationExpression) and isinstance(node.function, CrowFeature) and node.function.is_state_variable:
            return True
        if not (isinstance(node, BoolExpression) and node.bool_op is BoolOpType.NOT):
            return False
        # Peel off one NOT and keep looking at its operand.
        node = node.arguments[0]
[docs]
def crow_split_simple_bool(expr: Expression, initial_negated: bool = False) -> Tuple[Optional[FunctionApplicationExpression], bool]:
    """Split a simple Boolean expression into its state-variable application and its negation parity.

    See :func:`crow_is_simple_bool` for the definition of a simple Boolean expression. Every ``NOT`` peeled off the
    expression flips the returned negation flag.

    Args:
        expr (Expression): the expression to be split.
        initial_negated (bool, optional): whether the outer context of ``expr`` is already negated.

    Returns:
        a tuple ``(feature_application, negated)``. ``feature_application`` is the state-variable application, or None
        if ``expr`` is not a simple Boolean expression. ``negated`` is the negation flag accumulated up to the point
        where the search stopped.
    """
    negated = initial_negated
    node = expr
    while True:
        if isinstance(node, FunctionApplicationExpression) and isinstance(node.function, CrowFeature) and node.function.is_state_variable:
            return node, negated
        if not (isinstance(node, BoolExpression) and node.bool_op is BoolOpType.NOT):
            return None, negated
        node, negated = node.arguments[0], not negated
[docs]
def crow_get_simple_bool_predicate(expr: Expression) -> CrowFeature:
    """Return the state-variable feature underlying a simple Boolean expression.

    See :func:`crow_is_simple_bool`. Any ``NOT`` wrappers are stripped; reaching an expression that is neither a
    state-variable application nor a ``NOT`` fails the assertion.

    Args:
        expr: the expression, assumed to be a simple Boolean expression.

    Returns:
        the underlying :class:`CrowFeature`.
    """
    node = expr
    while not (isinstance(node, FunctionApplicationExpression) and isinstance(node.function, CrowFeature) and node.function.is_state_variable):
        assert isinstance(node, BoolExpression) and node.bool_op is BoolOpType.NOT
        node = node.arguments[0]
    return node.function