#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : strips_expression.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/26/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""
This file defines a set of classes for representing STRIPS expressions, including:
- Boolean constant expressions.
- Boolean predicate expressions.
- SAS predicate expressions.
- And/Or expressions.
- Not expressions.
- Forall/Exists expressions.
- SAS expressions.
- Single predicate assignment expressions.
- Conditional assignment expressions.
- Deictic assignment expressions.
At the highest level, the STRIPS expressions are categorized into two types:
- SValueOutputExpression, which outputs a value. A special case of this type is SBoolOutputExpression, which outputs a boolean value.
- SVariableAssignmentExpression, which assigns a value to a state variable.
"""
import jacinle
from copy import deepcopy
from abc import abstractmethod, ABC
from typing import Optional, Union, Iterable, Sequence, Tuple, Set, FrozenSet, Dict
from concepts.dsl.dsl_types import Variable, ObjectConstant, BOOL
from concepts.dsl.expression import FunctionApplicationExpression, VariableExpression, ObjectConstantExpression, BoolExpression, BoolOpType
# Names exported by a star-import of this module.
# NOTE(review): make_sproposition_from_function_application is defined in this module but is not
# listed here -- confirm whether that omission is intentional.
__all__ = [
'SPredicateName', 'SProposition', 'make_sproposition',
'SState', 'SStateCompatible', 'SStateDict',
'SExpression', 'SValueOutputExpression', 'SVariableAssignmentExpression', 'SBoolOutputExpression',
'SBoolConstant', 'SBoolPredicateApplicationExpression', 'SSASPredicateApplicationExpression',
'SSimpleBoolExpression', 'SBoolNot', 'SQuantificationExpression', 'SSASExpression',
'SAssignExpression', 'SConditionalAssignExpression', 'SDeicticAssignExpression'
]
# Type alias: a predicate name is a plain string.
"""The name of predicates, represented as strings."""
SPredicateName = str
# Type alias: a proposition is also a plain string (make_sproposition builds them as
# "name arg1 arg2 ...").
"""The name of propositions. A proposition is a predicate grounded on a set of arguments."""
SProposition = str
# """The representation of a SAS proposition. It is a tuple of (predicate name, SAS value)."""
# StripsSASProposition = Tuple[str, int]
def _variable_or_constant_to_str(x: Union[str, Variable, ObjectConstant]) -> str:
    """Convert a single proposition argument into its string form.

    Args:
        x: a plain string, an integer, a :class:`Variable` or an :class:`ObjectConstant`.

    Returns:
        The string itself, the decimal form of the integer, or the ``name`` attribute of the
        variable / object constant.

    Raises:
        TypeError: if ``x`` is none of the supported types.
    """
    if isinstance(x, str):
        return x
    if isinstance(x, int):
        return str(x)
    # Variables and object constants are both rendered by their name.
    if isinstance(x, (Variable, ObjectConstant)):
        return x.name
    raise TypeError(f'Unknown type: {type(x)}.')
[docs]
def make_sproposition(name: SPredicateName, *args: Union[str, Variable, ObjectConstantExpression]) -> SProposition:
    """Build the string form of a proposition from a predicate name and its arguments.

    Args:
        name: the predicate name.
        *args: the arguments the predicate is applied to. Each one is converted to a string
            with :func:`_variable_or_constant_to_str`.

    Returns:
        ``name`` alone when no arguments are given, otherwise the name followed by the
        space-separated argument strings.
    """
    if not args:
        return name
    joined_arguments = ' '.join([_variable_or_constant_to_str(arg) for arg in args])
    return f'{name} {joined_arguments}'
[docs]
def make_sproposition_from_function_application(expr: FunctionApplicationExpression, objects: Optional[Dict[str, str]] = None) -> SProposition:
    """Compose a proposition from a function application expression.

    Args:
        expr: the function application. Every argument must be an :class:`ObjectConstantExpression`.
        objects: an optional dictionary keyed by known object names. When given, the name of every
            argument must be one of its keys. When None (the default), no membership check is done.

    Returns:
        The proposition built from the function name and the argument names.

    Raises:
        ValueError: if an argument is not an object constant, or if ``objects`` is given and the
            argument name is not one of its keys.
    """
    name = expr.function.name
    arguments = []
    for arg in expr.arguments:
        if not isinstance(arg, ObjectConstantExpression):
            raise ValueError(f'Expected object constant, got {arg}.')
        # The membership check only applies when a table of objects was supplied. The previous
        # unconditional `assert arg.name in objects` raised TypeError for the default objects=None
        # and was redundant with this check otherwise.
        if objects is not None and arg.name not in objects:
            raise ValueError(f'Unknown object {arg.name}.')
        arguments.append(arg.name)
    return make_sproposition(name, *arguments)
[docs]
class SState(frozenset, FrozenSet[SProposition]):
    """A STRIPS state, stored as a frozenset of proposition strings."""
[docs]
class SStateDict(dict, Dict[str, Set[Tuple[Union[int, str], ...]]]):
    """A STRIPS state stored as a dictionary from predicate names to sets of argument tuples."""

    def add(self, predicate_name: SPredicateName, arguments: Sequence[Union[int, str]]):
        """Record that ``predicate_name`` holds for ``arguments``.

        Args:
            predicate_name: the name of the predicate.
            arguments: the arguments of the predicate; stored as a tuple.
        """
        self.setdefault(predicate_name, set()).add(tuple(arguments))

    def remove(self, predicate_name: SPredicateName, arguments: Sequence[Union[int, str]]):
        """Remove ``arguments`` from the entries of ``predicate_name``.

        Does nothing when the predicate or the argument tuple is not present.

        Args:
            predicate_name: the name of the predicate.
            arguments: the arguments of the predicate.
        """
        if predicate_name in self:
            self[predicate_name].discard(tuple(arguments))

    def contains(self, predicate_name: SPredicateName, arguments: Sequence[Union[int, str]], negated: bool = False, check_negation: bool = False) -> bool:
        """Check whether the state contains the given (possibly negated) proposition.

        Args:
            predicate_name: the name of the predicate.
            arguments: the arguments of the predicate, as a sequence of integers or strings.
            negated: whether the proposition is negated. If True, the function will check whether the state does not contain the proposition.
            check_negation: whether the function should also check "{predicate_name}_not" in the state. This will only be used when `negated` is True.
                This is useful for delete-relaxed planning.

        Returns:
            For ``negated=False``: True if the proposition is in the state. For ``negated=True``: True if the
            proposition is absent, or (with ``check_negation``) if it is recorded under "{predicate_name}_not".
        """
        # Entries are stored as tuples by `add`; normalise here as well so that any sequence
        # (e.g. a list, which is unhashable) can be used for the lookup.
        arguments = tuple(arguments)

        if not check_negation:
            if predicate_name in self:
                return (arguments in self[predicate_name]) ^ negated
            # A missing predicate means the proposition is absent: True only for the negated query.
            return negated

        if not negated:
            if predicate_name in self:
                return arguments in self[predicate_name]
            return False

        true_set = self.get(predicate_name, None)
        false_set = self.get(f'{predicate_name}_not', None)
        if false_set is not None and arguments in false_set:
            return True
        return true_set is None or arguments not in true_set

    def clone(self):
        """Return a deep copy of this state dictionary."""
        return deepcopy(self)

    def as_state(self) -> SState:
        """Convert the dictionary into an :class:`SState` of proposition strings.

        Each proposition is the predicate name followed by its space-separated arguments. A predicate
        with no arguments yields just its name (no trailing space), which is the same format that
        :func:`make_sproposition` produces.
        """
        return SState(
            ' '.join((predicate_name, *map(str, arguments)))
            for predicate_name, argument_tuples in self.items()
            for arguments in argument_tuples
        )
# Type alias used in annotations where a state is accepted: either an SState or a plain set of
# proposition strings.
SStateCompatible = Union[SState, Set[SProposition]]
# class StripsSASState(dict, Dict[Tuple[StripsPredicateName, str], int]):
# """The representation of a SAS state, which is a mapping from (predicate name, SAS value) to the number of occurrences."""
# """StripsSASState is a mapping from (predicate name, arguments_str) to value."""
# pass
[docs]
class SExpression(ABC):
    """Abstract root of the STRIPS expression hierarchy.

    Subclasses must report which predicates they read (preconditions) and which they write (effects).
    """

    # presumably derives repr() from __str__ -- see jacinle.repr_from_str.
    __repr__ = jacinle.repr_from_str

    def ground(self, variable_dict: Dict[str, str], state: Optional[SStateCompatible] = None):
        """Return a new expression with all variables grounded according to ``variable_dict``.

        This base implementation is not abstract, but always raises :class:`NotImplementedError`;
        subclasses that support grounding override it.
        """
        raise NotImplementedError()

    @abstractmethod
    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the names of the predicates this expression reads."""
        raise NotImplementedError()

    @abstractmethod
    def iter_effect_predicates(self) -> Iterable[SPredicateName]:
        """Return the names of the predicates this expression writes."""
        raise NotImplementedError()
[docs]
class SValueOutputExpression(SExpression, ABC):
    """Abstract base for STRIPS expressions that compute a value.

    Such expressions have no effect predicates; only the precondition predicates are left
    for subclasses to provide.
    """

    @abstractmethod
    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the names of the predicates this expression reads."""
        raise NotImplementedError()

    def iter_effect_predicates(self) -> Iterable[SPredicateName]:
        """Return an empty set: a value-output expression writes no predicate."""
        return set()
[docs]
class SVariableAssignmentExpression(SExpression, ABC):
    """Abstract base for STRIPS expressions that assign a value to a state variable.

    By default an assignment reads no predicates; the effect predicates are left for
    subclasses to provide.
    """

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return an empty set; subclasses with a condition override this."""
        return set()

    @abstractmethod
    def iter_effect_predicates(self) -> Iterable[SPredicateName]:
        """Return the names of the predicates this expression writes."""
        raise NotImplementedError()
[docs]
class SBoolOutputExpression(SValueOutputExpression, ABC):
    """Abstract base for STRIPS expressions whose output is a boolean value.

    It adds no members of its own and serves as a common type for boolean expressions.
    """
[docs]
class SBoolConstant(SBoolOutputExpression):
    """A literal boolean (``true`` / ``false``) expression."""

    def __init__(self, constant: bool):
        """Initialize a boolean constant.

        Args:
            constant: the value of the constant.
        """
        self.constant = constant

    def ground(self, variable_dict: Dict[str, str], state: Optional[SStateCompatible] = None):
        """Return a ``GSBoolConstantExpression`` holding the same constant.

        Neither ``variable_dict`` nor ``state`` is used.
        """
        # Imported inside the method -- presumably to avoid a circular import; confirm.
        from concepts.dm.pdsketch.strips.strips_grounded_expression import GSBoolConstantExpression
        return GSBoolConstantExpression(self.constant)

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return an empty set: a constant reads no predicate."""
        return set()

    def __str__(self) -> str:
        if self.constant:
            return 'true'
        return 'false'
[docs]
class SBoolPredicateApplicationExpression(SBoolOutputExpression):
    """A boolean expression that applies a (possibly negated) predicate to a tuple of arguments."""

    name: SPredicateName
    """The name of the predicate."""

    arguments: Tuple[Union[Variable, str], ...]
    """The arguments of the predicate."""

    negated: bool
    """Whether the predicate is negated."""

    def __init__(self, name: SPredicateName, negated: bool, arguments: Sequence[Union[Variable, str]]):
        """Initialize a boolean predicate expression.

        Args:
            name: the name of the predicate.
            negated: whether the predicate is negated.
            arguments: the arguments of the predicate. Either variables or str (constants).
        """
        self.name = name
        self.arguments = tuple(arguments)
        self.negated = negated

    def ground(self, variable_dict: Dict[str, str], state: Optional[SStateCompatible] = None, negated: bool = False, return_proposition: bool = False):
        """Ground the expression according to the given variable dictionary.

        Args:
            variable_dict: the mapping from variable names to the values that replace them.
            state: accepted for interface compatibility; not used by this implementation.
            negated: an additional negation applied on top of ``self.negated``.
            return_proposition: whether to return a SProposition instead of a GSSimpleBoolExpression.

        Returns:
            The proposition string if ``return_proposition`` is True, otherwise a
            ``GSSimpleBoolExpression`` built from the singleton set containing that proposition.
        """
        from concepts.dm.pdsketch.strips.strips_grounded_expression import GSSimpleBoolExpression

        # The two negation flags cancel out; a net negation is encoded as the "<name>_not" predicate.
        effectively_negated = self.negated ^ negated
        identifier = self.name + '_not' if effectively_negated else self.name

        # Variables are looked up by name; every other argument is passed through unchanged.
        grounded_arguments = [
            variable_dict[argument.name] if isinstance(argument, Variable) else argument
            for argument in self.arguments
        ]
        proposition = make_sproposition(identifier, *grounded_arguments)

        if return_proposition:
            return proposition
        return GSSimpleBoolExpression({proposition})

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the singleton set holding this predicate's name."""
        return {self.name}

    def __str__(self) -> str:
        if self.arguments:
            argument_str = ' '.join(x.name if isinstance(x, Variable) else x for x in self.arguments)
            fmt = f'({self.name} {argument_str})'
        else:
            fmt = f'({self.name})'
        return f'(not {fmt})' if self.negated else fmt

    @classmethod
    def from_function_application_expression(cls, expression: Union[FunctionApplicationExpression, BoolExpression], negated: bool = False):
        """Build an instance from a function application, optionally wrapped in a single NOT.

        Args:
            expression: a :class:`FunctionApplicationExpression` whose function returns ``BOOL``, or a
                :class:`BoolExpression` with operator ``NOT`` whose only argument is such an application.
            negated: the negation flag to start from; it is flipped when a NOT wrapper is removed.

        Returns:
            A new instance of ``cls``. Variable arguments are stored as the variable object; object
            constants are stored as the constant's name.

        Raises:
            AssertionError: if the expression does not have one of the shapes described above.
            TypeError: if an argument is neither a variable expression nor an object constant expression.
        """
        assert isinstance(expression, (FunctionApplicationExpression, BoolExpression)), f'Invalid expression type: {type(expression)}.'

        if isinstance(expression, BoolExpression):
            # Strip the NOT wrapper and recurse on the inner application with the flag flipped.
            assert expression.bool_op is BoolOpType.NOT
            assert len(expression.arguments) == 1
            inner = expression.arguments[0]
            assert isinstance(inner, FunctionApplicationExpression)
            return cls.from_function_application_expression(inner, not negated)

        assert expression.function.return_type == BOOL

        converted_arguments = []
        for arg in expression.arguments:
            if isinstance(arg, VariableExpression):
                converted_arguments.append(arg.variable)
            elif isinstance(arg, ObjectConstantExpression):
                converted_arguments.append(arg.constant.name)
            else:
                raise TypeError(f'Invalid argument type: {type(arg)}.')
        return cls(expression.function.name, negated, converted_arguments)
[docs]
class SSASPredicateApplicationExpression(SBoolPredicateApplicationExpression):
    """The representation for an SAS predicate expression. It is composed of a predicate name and an SAS index."""

    name: SPredicateName
    arguments: Tuple[Variable, ...]
    negated: bool

    sas_name: SPredicateName
    """The name of the SAS predicate."""

    sas_index: Optional[int]
    """The index of the SAS predicate. If None, it is a normal predicate."""

    def __init__(self, sas_name: SPredicateName, sas_index: Optional[int], negated: bool, arguments: Sequence[Variable]):
        """Initialize an SAS predicate expression.

        The predicate name stored by the base class is ``sas_name`` itself when ``sas_index`` is None,
        and ``"<sas_name>@<sas_index>"`` otherwise.

        Args:
            sas_name: the name of the SAS predicate.
            sas_index: the index of the SAS predicate, or None.
            negated: whether the predicate is negated.
            arguments: the arguments of the predicate.
        """
        full_name = sas_name if sas_index is None else sas_name + '@' + str(sas_index)
        super().__init__(full_name, negated, arguments)
        self.sas_name = sas_name
        self.sas_index = sas_index
[docs]
class SSimpleBoolExpression(SBoolOutputExpression):
    """The representation of a boolean expression. Note that since the negation is recorded in the raw :class:`SBoolPredicateApplicationExpression`,
    we do not need to record it here. Therefore, in this class, we only need to record whether the expression is an AND or an OR."""

    arguments: Sequence[SBoolOutputExpression]
    """The arguments of the expression."""

    is_disjunction: bool
    """Whether the expression is a disjunction."""

    def __init__(self, arguments: Sequence[SBoolOutputExpression], is_disjunction: bool):
        """Initialize a boolean expression.

        Args:
            arguments: the arguments of the expression.
            is_disjunction: whether the expression is a disjunction.
        """
        self.arguments = arguments
        self.is_disjunction = is_disjunction

    def ground(self, variable_dict: Dict[str, str], state: Optional[SStateCompatible] = None):
        """Ground every argument and compose the results with ``gs_compose_bool_expressions``.

        Args:
            variable_dict: the variable dictionary passed to each argument's ``ground``.
            state: accepted for interface compatibility.
        """
        from concepts.dm.pdsketch.strips.strips_grounded_expression import gs_compose_bool_expressions
        # NOTE(review): `state` is not forwarded to the arguments -- confirm that this is intended.
        return gs_compose_bool_expressions(
            [argument.ground(variable_dict) for argument in self.arguments],
            is_disjunction=self.is_disjunction,
        )

    @property
    def is_conjunction(self) -> bool:
        """Whether the expression is a conjunction."""
        return not self.is_disjunction

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the union of the precondition predicates of all arguments.

        Returns an empty set when there are no arguments.
        """
        # Accumulate into a fresh set instead of calling the unbound `set.union(*...)`, which raises
        # TypeError for zero arguments and requires the first child to return an actual `set`.
        predicates = set()
        for argument in self.arguments:
            predicates.update(argument.iter_precondition_predicates())
        return predicates

    def __str__(self) -> str:
        operator = 'or' if self.is_disjunction else 'and'
        arguments_str = [str(arg) for arg in self.arguments]
        # Long expressions are rendered one indented argument per line.
        if sum(len(s) for s in arguments_str) > 120:
            fmt = '\n'.join(jacinle.indent_text(s) for s in arguments_str)
            return f'({operator}\n{fmt}\n)'
        return '({} {})'.format(operator, ' '.join(arguments_str))
[docs]
class SBoolNot(SBoolOutputExpression):
    """The representation of a boolean NOT expression. Note that this class is usually only used as a temporary expression during parsing.
    At the end, the negation is recorded in the raw :class:`SBoolPredicateApplicationExpression`."""

    expr: SBoolOutputExpression
    """The expression to be negated."""

    def __init__(self, expr: SBoolOutputExpression):
        """Initialize a boolean NOT expression.

        Args:
            expr: the expression to be negated.
        """
        self.expr = expr

    def ground(self, variable_dict: Dict[str, str], state: Optional[SStateCompatible] = None):
        """Ground the negated expression.

        Only a negated predicate application is supported: it is grounded with ``negated=True``.

        Raises:
            NotImplementedError: if the inner expression is not a predicate application.
        """
        if not isinstance(self.expr, SBoolPredicateApplicationExpression):
            raise NotImplementedError()
        return self.expr.ground(variable_dict, negated=True)

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the precondition predicates of the inner expression."""
        return self.expr.iter_precondition_predicates()

    def __str__(self) -> str:
        return f'(not {self.expr!s})'
[docs]
class SQuantificationExpression(SBoolOutputExpression):
    """A quantified boolean expression: FORALL (conjunction) or EXISTS (disjunction) over one variable."""

    variable: Variable
    """The variable to be quantified."""

    expr: SBoolOutputExpression
    """The expression to be quantified."""

    is_disjunction: bool
    """Whether the expression is a disjunction (EXISTS quantification)."""

    def __init__(self, variable: Variable, expr: SBoolOutputExpression, is_disjunction: bool):
        """Initialize a quantification expression.

        Args:
            variable: the variable to be quantified.
            expr: the expression to be quantified.
            is_disjunction: whether the expression is a disjunction (EXISTS quantification).
        """
        self.variable = variable
        self.expr = expr
        self.is_disjunction = is_disjunction

    @property
    def is_conjunction(self) -> bool:
        """Whether the expression is a conjunction (FORALL quantification)."""
        return not self.is_disjunction

    @property
    def is_forall(self) -> bool:
        """Whether the expression is a FORALL quantification; same value as :attr:`is_conjunction`."""
        return not self.is_disjunction

    @property
    def is_exists(self) -> bool:
        """Whether the expression is an EXISTS quantification; same value as :attr:`is_disjunction`."""
        return self.is_disjunction

    def ground(self, variable_dict: Dict[str, str], state: Optional[SStateCompatible] = None):
        """Grounding is not implemented for quantified expressions; always raises."""
        raise NotImplementedError()

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the precondition predicates of the quantified expression."""
        return self.expr.iter_precondition_predicates()

    def __str__(self) -> str:
        quantifier = 'exists' if self.is_disjunction else 'forall'
        return f'({quantifier} ({self.variable!s}) {self.expr!s})'
[docs]
class SSASExpression(SValueOutputExpression):  # For all external functions.
    """The representation of an SAS expression. The return value of the expression is an SAS index, therefore it can be represented as a dictionary,
    mapping from SAS indices to Boolean expressions. The execution procedure is to first evaluate all Boolean expressions, and then set the SAS index.

    Suggested implementation is:

    .. code-block:: python

        for sas_index, expr in self.mappings.items():
            if evaluate(expr, state):
                return sas_index
    """

    mappings: Dict[int, SBoolOutputExpression]
    """The mappings from SAS indices to Boolean expressions."""

    def __init__(self, mappings: Dict[int, SBoolOutputExpression]):
        """Initialize an SAS expression.

        Args:
            mappings: the mappings from SAS indices to Boolean expressions.
        """
        self.mappings: Dict[int, SBoolOutputExpression] = mappings

    def ground(self, variable_dict: Dict[str, str], state: Optional[SStateCompatible] = None):
        """Grounding is not implemented for SAS expressions; always raises."""
        raise NotImplementedError()

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the union of the precondition predicates of all mapped Boolean expressions.

        The parent class declares this method abstract, so it has to be implemented here for the
        class to be instantiable. Returns an empty set when there are no mappings.
        """
        predicates = set()
        for expr in self.mappings.values():
            predicates.update(expr.iter_precondition_predicates())
        return predicates

    def __str__(self) -> str:
        rows = [' ' + str(index) + ' <- ' + str(expr) for index, expr in self.mappings.items()]
        return '(SAS\n{}\n)'.format('\n'.join(rows))
[docs]
class SAssignExpression(SVariableAssignmentExpression):
    """An assignment of a value expression to a predicate of the state representation."""

    predicate: Union[SBoolPredicateApplicationExpression, SSASPredicateApplicationExpression]
    """The predicate in the state representation to be assigned."""

    value: Union[SBoolOutputExpression, SSASExpression]
    """The value to be assigned."""

    def __init__(self, predicate: Union[SBoolPredicateApplicationExpression, SSASPredicateApplicationExpression], value: Union[SBoolOutputExpression, SSASExpression]):
        """Initialize an assignment expression.

        Args:
            predicate: the predicate in the state representation to be assigned.
            value: the value to be assigned.
        """
        self.predicate = predicate
        self.value = value

    def iter_effect_predicates(self) -> Iterable[SPredicateName]:
        """Return the predicate names of the assignment target.

        The target is a predicate application, so its names are obtained through its
        ``iter_precondition_predicates`` method.
        """
        return self.predicate.iter_precondition_predicates()

    def __str__(self) -> str:
        return f'({self.predicate!s} <- {self.value!s})'
[docs]
class SConditionalAssignExpression(SVariableAssignmentExpression):
    """The representation of a conditional assignment expression. Note that the inner assignment expression is always a :class:`SAssignExpression`."""

    assign_op: SAssignExpression
    """The assignment expression."""

    condition: SBoolOutputExpression
    """The condition expression."""

    def __init__(self, assign_op: SAssignExpression, condition: SBoolOutputExpression):
        """Initialize a conditional assignment expression.

        Args:
            assign_op: the assignment expression.
            condition: the condition expression.
        """
        self.assign_op = assign_op
        self.condition = condition

    @property
    def predicate(self):
        """The predicate in the state representation to be assigned (taken from the inner assignment)."""
        return self.assign_op.predicate

    @property
    def value(self):
        """The value to be assigned, if the condition is satisfied (taken from the inner assignment)."""
        return self.assign_op.value

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the precondition predicates of the condition."""
        return self.condition.iter_precondition_predicates()

    def iter_effect_predicates(self) -> Iterable[SPredicateName]:
        """Return the effect predicates of the inner assignment."""
        return self.assign_op.iter_effect_predicates()

    def __str__(self) -> str:
        return f'({self.assign_op!s} if {self.condition!s})'
[docs]
class SDeicticAssignExpression(SVariableAssignmentExpression):
    """A deictic assignment: an inner assignment expression wrapped with a deictic variable."""

    variable: Variable
    """The deictic variable."""

    expression: SVariableAssignmentExpression
    """The inner assignment expression."""

    def __init__(self, variable: Variable, expression: SVariableAssignmentExpression):
        """Initialize a deictic assignment expression.

        Args:
            variable: the deictic variable.
            expression: the inner assignment expression.
        """
        self.variable = variable
        self.expression = expression

    def iter_precondition_predicates(self) -> Iterable[SPredicateName]:
        """Return the precondition predicates of the inner assignment expression."""
        return self.expression.iter_precondition_predicates()

    def iter_effect_predicates(self) -> Iterable[SPredicateName]:
        """Return the effect predicates of the inner assignment expression."""
        return self.expression.iter_effect_predicates()

    def __str__(self) -> str:
        return f'(foreach ({self.variable}) {self.expression!s})'