#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : crow_parser.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/10/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import os
import os.path as osp
import contextlib
import jacinle
from typing import Any, Optional, Union, Sequence, Tuple, Set, List, Dict
from dataclasses import dataclass, field
from lark import Lark, Tree
from lark.visitors import Transformer, Interpreter, v_args
from lark.indenter import PythonIndenter
import concepts.dsl.expression as E
from concepts.dsl.dsl_types import TypeBase, AutoType, VectorValueType, BOOL, Variable, ObjectConstant, UnnamedPlaceholder
from concepts.dm.crow.function_utils import flatten_expression
from concepts.dm.crow.controller import CrowControllerApplicationExpression
from concepts.dm.crow.action import CrowPrecondition, CrowEffect, CrowActionBodyItem, CrowActionBodyPrimitiveBase, CrowActionBodySuiteBase
from concepts.dm.crow.action import CrowAchieveExpression, CrowBindExpression, CrowRuntimeAssignmentExpression, CrowAssertExpression, CrowActionApplicationExpression
from concepts.dm.crow.action import CrowActionConditionSuite, CrowActionWhileLoopSuite, CrowActionOrdering, CrowActionOrderingSuite
from concepts.dm.crow.generator import CrowGeneratorApplicationExpression
from concepts.dm.crow.domain import CrowDomain, CrowProblem, CrowState
# Module-level logger.
logger = jacinle.get_logger(__name__)
# Shorthand decorator for visitor methods below: with `inline=True`, the children of a tree node are passed to the
# decorated method as separate positional arguments (see e.g. `term(self, *values)` and `atom_varname(self, name)`).
inline_args = v_args(inline=True)
# Public API of this module.
# NOTE(review): several names listed here (the transformers and the `load_*` helpers) are not defined in this view of the
# file — confirm they exist further down.
__all__ = [
    'PDSketchV3Parser', 'path_resolver',
    'PDSketchV3PathResolver', 'PDSketchV3DomainTransformer', 'PDSketchV3ProblemTransformer', 'PDSketch3LiteralTransformer', 'PDSketch3ExpressionInterpreter',
    'load_domain_file', 'load_domain_string', 'load_domain_string_incremental',
    'load_problem_file', 'load_problem_string'
]
class PDSketchV3PathResolver(object):
    """Locates PDSketch v3 source files on disk.

    A name is resolved by trying, in this order: the name exactly as given, the directory containing
    ``relative_filename`` (if any), the current working directory, and finally every registered search path.
    """

    def __init__(self, search_paths: Sequence[str] = tuple()):
        """Create a resolver.

        Args:
            search_paths: extra directories to search, in priority order. They are copied into a new list.
        """
        self.search_paths = list(search_paths)

    def resolve(self, filename: str, relative_filename: Optional[str] = None) -> str:
        """Resolve ``filename`` to the first candidate path that exists.

        Args:
            filename: the file name (or path) to look up.
            relative_filename: optional path of another file; its directory is searched right after ``filename`` itself.

        Returns:
            the first existing candidate, returned exactly as constructed (it may be a relative path).

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        def candidates():
            # Produced lazily so that later candidates are only built when the earlier ones are missing.
            yield filename
            if relative_filename is not None:
                yield osp.join(osp.dirname(relative_filename), filename)
            # NOTE(review): a relative name is already checked against the working directory by the first
            # candidate, so this one normally cannot succeed when that one failed.
            yield osp.join(os.getcwd(), filename)
            for directory in self.search_paths:
                yield osp.join(directory, filename)

        for candidate in candidates():
            if osp.exists(candidate):
                return candidate
        raise FileNotFoundError(f'File not found: {filename}')

    def add_search_path(self, path: str):
        """Append ``path`` to the end of the search list (lowest priority)."""
        self.search_paths.append(path)

    def remove_search_path(self, path: str):
        """Remove the first occurrence of ``path`` from the search list; raises ValueError if it is absent."""
        self.search_paths.remove(path)
# Shared module-level resolver used by `PDSketchV3Parser.parse` to locate files; it starts with no extra search paths.
path_resolver = PDSketchV3PathResolver()
class PDSketchV3Parser(object):
    """The parser for PDSketch v3.

    It turns source text (files or strings) into :class:`lark.Tree` objects using the grammar in :attr:`grammar_file`,
    and converts parse trees into domains, problems, or expressions.
    """

    grammar_file = osp.join(osp.dirname(osp.abspath(__file__)), 'crow.grammar')
    """The grammar definition v3 for PDSketch."""

    def __init__(self):
        """Initialize the parser by loading :attr:`grammar_file` and building an LALR parser with Python-style indentation handling."""
        # Decode explicitly as UTF-8 so that loading does not depend on the platform's locale encoding.
        with open(self.grammar_file, 'r', encoding='utf-8') as f:
            self.grammar = f.read()
        self.parser = Lark(self.grammar, start='start', postlex=PythonIndenter(), parser='lalr')

    def parse(self, filename: str) -> Tree:
        """Parse a PDSketch v3 file.

        Args:
            filename: the filename to parse. It is located through the module-level ``path_resolver``.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.

        Raises:
            FileNotFoundError: if the file cannot be located.
        """
        filename = path_resolver.resolve(filename)
        # Decode explicitly as UTF-8 so that parsing does not depend on the platform's locale encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        return self.parse_str(content)

    def parse_str(self, s: str) -> Tree:
        """Parse a PDSketch v3 string.

        Args:
            s: the string to parse.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        # NB(Jiayuan Mao @ 2024/03/13): for reasons, the pdsketch-v3 grammar file requires that the string ends with a newline.
        # In particular, the suite definition requires that the file ends with a _DEDENT token, which seems to be only triggered by a newline.
        s = s.strip() + '\n'
        return self.parser.parse(s)

    def parse_domain(self, filename: str) -> CrowDomain:
        """Parse a PDSketch v3 domain file.

        Args:
            filename: the filename to parse.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse(filename))

    def parse_domain_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Parse a PDSketch v3 domain string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse_str(s), domain=domain)

    def parse_problem(self, filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem file.

        Args:
            filename: the filename to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        # NOTE(review): `transform_problem` is not defined in this view of the class — confirm it exists.
        return self.transform_problem(self.parse(filename), domain=domain)

    def parse_problem_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse_str(s), domain=domain)

    def parse_expression(self, s: str, domain: CrowDomain, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
        """Parse a PDSketch v3 expression string.

        Args:
            s: the string to parse.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        # NOTE(review): `transform_expression` is not defined in this view of the class — confirm it exists.
        return self.transform_expression(self.parse_str(s), domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)

    @staticmethod
    def transform_domain(tree: Tree, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Transform a parse tree into a domain.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        transformer = PDSketchV3DomainTransformer(domain)
        transformer.transform(tree)
        return transformer.domain
@dataclass()
class LiteralValue(object):
    """Wrapper around a single parsed literal (a bool, int, float, complex, or str)."""

    # The raw Python value of the literal.
    value: Union[bool, int, float, complex, str]
@dataclass()
class LiteralList(object):
    """An ordered sequence of parsed literals."""

    # The literal values, kept in source order as a tuple.
    items: Tuple[Union[bool, int, float, complex, str], ...]
@dataclass()
class LiteralSet(object):
    """An unordered collection of parsed literals."""

    # The literal values, stored as a Python set.
    items: Set[Union[bool, int, float, complex, str]]
@dataclass()
class InTypedArgument(object):
    """An argument written as ``name in value``; forall/exists statements use this form to range over a collection."""

    # The bound variable name.
    name: str
    # The collection expression the variable ranges over.
    value: Any
@dataclass()
class ArgumentsList(object):
    """The positional argument values of a call; each may be a variable, a function call, a suite, or another expression."""

    # The argument values in call order.
    arguments: Tuple[Union['Suite', E.ValueOutputExpression, E.ListExpansionExpression, E.VariableExpression, bool, int, float, complex, str], ...]
@dataclass
class FunctionCall(object):
    """Intermediate node for a parsed call.

    Besides ordinary function calls, statements such as primitive operators and control flow (assign, check, if, ...)
    are also represented as a named call with an argument list and optional annotations.
    """

    name: str
    args: ArgumentsList
    annotations: Optional[Dict[str, Any]] = None

    def __str__(self):
        prefix = ''
        if self.annotations is not None:
            rendered = ', '.join(f'{key}={value}' for key, value in self.annotations.items())
            prefix = '[[' + rendered + ']] '

        rendered_args = list(map(str, self.args.arguments))
        # Short argument lists are printed inline; long ones are printed one indented argument per line.
        if sum(map(len, rendered_args)) <= 80:
            return prefix + self.name + '(' + ', '.join(rendered_args) + ')'
        indented = [jacinle.indent_text(text) for text in rendered_args]
        return prefix + self.name + ':\n' + '\n'.join(indented)

    def __repr__(self):
        return 'FunctionCall{' + str(self) + '}'
@dataclass
class CSList(object):
    """A comma-separated list of arbitrary items, as written in the source."""

    items: Tuple[Any, ...]

    def __str__(self):
        body = ', '.join(map(str, self.items))
        return 'CSList(' + body + ')'

    def __repr__(self):
        return str(self)
@dataclass
class Suite(object):
    """A suite of statements. This is used as the intermediate representation of the parsed expressions.

    The statements are analyzed lazily: the first getter call simulates the suite with a :class:`FunctionCallTracker`
    and caches it in :attr:`tracker`; every later getter call reuses that cached tracker.
    """

    items: Tuple[Any, ...]
    local_variables: Dict[str, Any] = field(default_factory=dict)
    tracker: Optional['FunctionCallTracker'] = None

    def _init_tracker(self, use_runtime_assign: bool = False):
        """Simulate the suite from an empty local-variable scope and cache the resulting tracker."""
        self.tracker = FunctionCallTracker(self, dict(), use_runtime_assign=use_runtime_assign).run()

    def get_all_assign_expressions(self) -> List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]:
        """Return the ``(assign_expression, annotations)`` pairs collected from the suite."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.assign_expressions

    def get_all_check_expressions(self) -> List[E.ValueOutputExpression]:
        """Return the expressions of all ``check`` statements in the suite."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.check_expressions

    def get_all_action_expressions(self, use_runtime_assign=True) -> List[CrowActionBodyItem]:
        """Return the action body items collected from the suite.

        Args:
            use_runtime_assign: whether local-variable assignments become runtime assignment actions. This only takes
                effect when no tracker has been cached yet; otherwise the cached tracker's mode is used.
        """
        # NOTE(review): if another getter already cached a tracker (built with use_runtime_assign=False), the argument is
        # ignored here — confirm that callers never mix modes on the same suite.
        if self.tracker is None:
            self._init_tracker(use_runtime_assign=use_runtime_assign)
        return self.tracker.action_expressions

    def get_all_expr_expression(self, allow_multiple_expressions: bool = False) -> Optional[Union[E.ValueOutputExpression, Tuple[E.ValueOutputExpression, ...]]]:
        """Return the bare expression(s) of the suite.

        Args:
            allow_multiple_expressions: if True, zero or several bare expressions are returned as a tuple.

        Returns:
            the single bare expression if there is exactly one; otherwise (when allowed) a tuple of all of them.

        Raises:
            ValueError: if there is not exactly one bare expression and ``allow_multiple_expressions`` is False.
        """
        if self.tracker is None:
            self._init_tracker()
        expressions = self.tracker.expr_expressions
        if len(expressions) == 1:
            return expressions[0]
        if not allow_multiple_expressions:
            if len(expressions) == 0:
                # Previously reported as "Multiple expressions found ... []", which was misleading.
                raise ValueError(f'No expression found in a single suite: {self}')
            raise ValueError(f'Multiple expressions found in a single suite: {expressions}')
        return tuple(expressions)

    def get_combined_return_expression(self, allow_expr_expressions: bool = False, allow_multiple_expressions: bool = False) -> Optional[Union[E.ValueOutputExpression, Tuple[E.ValueOutputExpression, ...]]]:
        """Return the combined return expression of the suite.

        Args:
            allow_expr_expressions: if True and there is no ``return`` statement, fall back to the bare expression(s).
            allow_multiple_expressions: forwarded to :meth:`get_all_expr_expression`.

        Returns:
            the return expression, the bare expression(s) (see above), or None.
        """
        if self.tracker is None:
            self._init_tracker()
        if self.tracker.return_expression is not None:
            return self.tracker.return_expression
        if allow_expr_expressions:
            return self.get_all_expr_expression(allow_multiple_expressions=allow_multiple_expressions)
        return None

    def __str__(self):
        if len(self.items) == 0:
            return 'Suite{}'
        if len(self.items) == 1:
            return f'Suite{{{self.items[0]}}}'
        return 'Suite{\n' + '\n'.join(jacinle.indent_text(str(item)) for item in self.items) + '\n}'

    def __repr__(self):
        return self.__str__()
class FunctionCallTracker(object):
    """This class is used to track the function calls and other statements in a suite. It supports simulating the execution of the program
    and generating the post-condition of the program.

    After :meth:`run`, the statements are split into assignments, checks, bare expressions, action body items, and one
    combined return expression.
    """

    def __init__(self, suite: Suite, init_local_variables: Optional[Dict[str, Any]] = None, use_runtime_assign: bool = False):
        """Initialize the tracker.

        Args:
            suite: the suite to simulate.
            init_local_variables: local-variable bindings visible at the start of the suite. The dictionary is used (and
                updated) in place, so callers pass a copy when theirs must not change.
            use_runtime_assign: whether assignments to local variables become runtime assignment actions.
        """
        self.suite = suite
        self.local_variables = dict() if init_local_variables is None else init_local_variables
        self.assign_expressions = list()
        self.assign_expression_signatures = dict()
        self.check_expressions = list()
        self.expr_expressions = list()
        self.action_expressions = list()
        self.return_expression = None
        self.use_runtime_assign = use_runtime_assign

    local_variables: Dict[str, Any]
    """The assignments of local variables, mapping variable names to the expressions bound to them."""

    assign_expressions: List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]
    """A list of assign expressions."""

    assign_expression_signatures: Dict[Tuple[str, ...], E.VariableAssignmentExpression]
    """A dictionary of assign expressions, indexed by their signatures."""

    check_expressions: List[E.ValueOutputExpression]
    """A list of check expressions."""

    expr_expressions: List[E.ValueOutputExpression]
    """A list of expr expressions (i.e. bare expressions in the body)."""

    return_expression: Optional[E.ValueOutputExpression]
    """The return expression. This is either None or a single expression."""

    use_runtime_assign: bool
    """Whether to use runtime assign expressions. If this is set to True, assignment expressions for local variables will be converted into runtime assign expressions."""

    action_expressions: List[CrowActionBodyItem]
    """A list of action body items (controller/action calls, bind/achieve/assert statements, ordering and condition suites)."""

    def _g(
        self, expr: Union[E.Expression, UnnamedPlaceholder, CrowActionBodyItem]
    ) -> Union[E.Expression, UnnamedPlaceholder, CrowActionBodyItem]:
        """Substitute the current local-variable bindings into ``expr``.

        Controller/action/generator applications and action body primitives/suites are returned unchanged.

        Raises:
            ValueError: if ``expr`` is neither one of those nor an :class:`E.Expression`.
        """
        if isinstance(expr, (CrowControllerApplicationExpression, CrowActionApplicationExpression, CrowGeneratorApplicationExpression)):
            return expr
        if isinstance(expr, (CrowActionBodyPrimitiveBase, CrowActionBodySuiteBase)):
            return expr
        if not isinstance(expr, E.Expression):
            raise ValueError(f'Invalid expression: {expr}')
        return flatten_expression(expr, {
            E.VariableExpression(Variable(k, None)): v
            for k, v in self.local_variables.items()
        })

    def _get_deictic_signature(self, e, known_deictic_vars=tuple()) -> Optional[Tuple[str, ...]]:
        """Compute a hashable signature ``(feature_name, arg_names...)`` of an assignment.

        Arguments bound by an enclosing deictic (forall) assignment are replaced by ``'?'``. Returns None for other
        kinds of assignment (e.g. conditional assignments), which are then not checked for duplicates.
        """
        if isinstance(e, E.DeicticAssignExpression):
            known_deictic_vars = known_deictic_vars + (e.variable.name,)
            return self._get_deictic_signature(e.expression, known_deictic_vars)
        elif isinstance(e, E.AssignExpression):
            args = [x.name if x.name not in known_deictic_vars else '?' for x in e.predicate.arguments]
            return tuple((e.predicate.function.name, *args))
        else:
            return None

    def _mark_assign(self, *exprs: E.VariableAssignmentExpression, annotations: Optional[Dict[str, Any]] = None):
        """Record assignment expressions, rejecting two assignments with the same signature.

        Raises:
            ValueError: if an assignment with the same signature has already been recorded.
        """
        if annotations is None:
            annotations = dict()
        for expr in exprs:
            signature = self._get_deictic_signature(expr)
            if signature is not None:
                if signature in self.assign_expression_signatures:
                    raise ValueError(f'Duplicate assign expression: {expr} vs {self.assign_expression_signatures[signature]}')
                self.assign_expressions.append((expr, annotations))
                self.assign_expression_signatures[signature] = expr
            else:
                self.assign_expressions.append((expr, annotations))

    @jacinle.log_function(verbose=False)
    def run(self):
        """Simulate the execution of the program and generates an equivalent return statement.
        This function handles if-else conditions. However, loops are not allowed.

        Returns:
            the tracker itself, so that calls can be chained.
        """
        # Early (one-branch) returns seen so far, in program order. Each entry is (continue_condition, returned_value):
        # if `continue_condition` does not hold, the program has already returned `returned_value`.
        pending_returns = list()

        def finalize_return(final_value):
            # Wrap the final value with the early returns, latest first, so that earlier returns take precedence.
            # With at most one early return this yields exactly `ConditionExpression(neg_cond, final_value, early_value)`.
            for continue_condition, returned_value in reversed(pending_returns):
                final_value = _make_conditional_return(returned_value, continue_condition, final_value)
            return final_value

        for item in self.suite.items:
            assert isinstance(item, FunctionCall), f'Invalid item in suite: {item}'
            if item.name == 'assign':
                lhs, rhs = item.args.arguments[0], item.args.arguments[1]
                if isinstance(lhs, E.FunctionApplicationExpression) and rhs is Ellipsis:
                    # `f(x) = ...` assigns a null value of the feature's type.
                    self._mark_assign(self._g(E.AssignExpression(lhs, E.NullExpression(lhs.return_type))), annotations=item.annotations)
                elif isinstance(lhs, E.FunctionApplicationExpression) and isinstance(rhs, (E.ValueOutputExpression, E.VariableExpression)):
                    self._mark_assign(self._g(E.AssignExpression(lhs, rhs)), annotations=item.annotations)
                elif isinstance(lhs, E.VariableExpression) and isinstance(rhs, (E.ValueOutputExpression, E.VariableExpression, E.FindAllExpression)):
                    if self.use_runtime_assign:
                        if lhs.name not in self.local_variables:
                            # Bind the name to the variable itself so later references are kept symbolic instead of
                            # being inlined. `lhs` is already a VariableExpression, so wrap its underlying Variable
                            # (wrapping `lhs` directly would nest one VariableExpression inside another).
                            self.local_variables[lhs.name] = E.VariableExpression(lhs.variable)
                        self.action_expressions.append(CrowRuntimeAssignmentExpression(
                            lhs.variable,
                            self._g(rhs)
                        ))
                    else:
                        self.local_variables[lhs.name] = self._g(rhs)
                else:
                    raise ValueError(f'Invalid assignment: {item}. Types: {type(lhs)}, {type(rhs)}.')
            elif item.name == 'check':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid check expression: {item.args.arguments[0]}'
                self.check_expressions.append(self._g(item.args.arguments[0]))
            elif item.name == 'expr':
                if isinstance(item.args.arguments[0], (CrowControllerApplicationExpression, CrowActionApplicationExpression, E.ListExpansionExpression)):
                    self.action_expressions.append(self._g(item.args.arguments[0]))
                else:
                    assert isinstance(
                        item.args.arguments[0],
                        (E.ValueOutputExpression, E.VariableExpression, CrowGeneratorApplicationExpression)
                    ), f'Invalid expr expression: {item.args.arguments[0]}'
                    self.expr_expressions.append(self._g(item.args.arguments[0]))
            elif item.name == 'bind':
                arguments = item.args.arguments[0].items
                body = item.args.arguments[1]
                self.action_expressions.append(CrowBindExpression(arguments, body, **(item.annotations if item.annotations is not None else dict())))
            elif item.name == 'achieve':
                term = item.args.arguments[0]
                self.action_expressions.append(CrowAchieveExpression(term, **(item.annotations if item.annotations is not None else dict())))
            elif item.name == 'assert':
                term = item.args.arguments[0]
                self.action_expressions.append(CrowAssertExpression(term, **(item.annotations if item.annotations is not None else dict())))
            elif item.name == 'return':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid return expression: {item.args.arguments[0]}'
                self.return_expression = finalize_return(self._g(item.args.arguments[0]))
                break
            elif item.name == 'ordering':
                suite = item.args.arguments[1]
                tracker = FunctionCallTracker(suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                self.local_variables = tracker.local_variables
                action_expressions = tracker.action_expressions
                if item.args.arguments[0] == 'promotable unordered':
                    prog = CrowActionOrderingSuite('promotable', (CrowActionOrderingSuite('unordered', action_expressions), ))
                elif item.args.arguments[0] == 'promotable sequential':
                    prog = CrowActionOrderingSuite('promotable', action_expressions)
                else:
                    assert ' ' not in item.args.arguments[0], f'Invalid ordering type: {item.args.arguments[0]}'
                    prog = CrowActionOrderingSuite(item.args.arguments[0], action_expressions)
                self.action_expressions.append(prog)
            elif item.name == 'if':
                condition = self._g(item.args.arguments[0])
                neg_condition = E.NotExpression(condition)
                assert isinstance(condition, E.ValueOutputExpression), f'Invalid condition: {condition}. Type: {type(condition)}.'
                t_suite = item.args.arguments[1]
                f_suite = item.args.arguments[2]
                t_tracker = FunctionCallTracker(t_suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                f_tracker = FunctionCallTracker(f_suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                assert set(t_tracker.local_variables.keys()) == set(f_tracker.local_variables.keys()), f'Local variables in the true and false branches are not consistent: {t_tracker.local_variables.keys()} vs {f_tracker.local_variables.keys()}'

                # Variables bound differently in the two branches become conditional expressions.
                new_local_variables = t_tracker.local_variables
                for k, v in t_tracker.local_variables.items():
                    if f_tracker.local_variables[k] != v:
                        new_local_variables[k] = E.ConditionExpression(condition, v, f_tracker.local_variables[k])
                self.local_variables = new_local_variables

                # TODO(Jiayuan Mao @ 2024/03/2): optimize the implementation for this by merging the conditions.
                for expr, annotations in t_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, condition), annotations=annotations)
                for expr, annotations in f_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, neg_condition), annotations=annotations)
                for expr in t_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(condition, expr))
                for expr in f_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(neg_condition, expr))

                if len(t_tracker.expr_expressions) != len(f_tracker.expr_expressions):
                    raise ValueError(f'Number of bare expressions in the true and false branches are not consistent: {len(t_tracker.expr_expressions)} vs {len(f_tracker.expr_expressions)}')
                if len(t_tracker.expr_expressions) == 0:
                    pass
                elif len(t_tracker.expr_expressions) == 1:
                    self.expr_expressions.append(E.ConditionExpression(condition, t_tracker.expr_expressions[0], f_tracker.expr_expressions[0]))
                else:
                    raise ValueError(f'Multiple bare expressions in the true and false branches are not supported: {t_tracker.expr_expressions} vs {f_tracker.expr_expressions}')

                if len(t_tracker.action_expressions) > 0:
                    self.action_expressions.append(CrowActionConditionSuite(condition, t_tracker.action_expressions))
                if len(f_tracker.action_expressions) > 0:
                    self.action_expressions.append(CrowActionConditionSuite(neg_condition, f_tracker.action_expressions))

                if t_tracker.return_expression is not None and f_tracker.return_expression is not None:
                    # Both branches return: the rest of the suite is unreachable.
                    statement = E.ConditionExpression(condition, t_tracker.return_expression, f_tracker.return_expression)
                    self.return_expression = finalize_return(statement)
                    break
                elif t_tracker.return_expression is not None:
                    # Execution continues past this `if` only when the condition is false.
                    pending_returns.append((E.NotExpression(condition), t_tracker.return_expression))
                elif f_tracker.return_expression is not None:
                    # Execution continues past this `if` only when the condition is true.
                    pending_returns.append((condition, f_tracker.return_expression))
                # NOTE(review): early returns that are nested deeper (e.g. inside a branch without a full return) are not
                # propagated, and early returns without a final `return` leave `return_expression` as None.
            elif item.name == 'forall':
                suite = item.args.arguments[1]
                tracker = FunctionCallTracker(suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')
                for expr, annotations in tracker.assign_expressions:
                    for var in item.args.arguments[0].items:
                        expr = E.DeicticAssignExpression(var, expr)
                    self._mark_assign(expr, annotations=annotations)
                for expr in tracker.check_expressions:
                    for var in item.args.arguments[0].items:
                        expr = E.ForallExpression(var, expr)
                    self.check_expressions.append(expr)
                if len(tracker.expr_expressions) > 0:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    for var in item.args.arguments[0].items:
                        merged = E.ForallExpression(var, merged)
                    self.expr_expressions.append(merged)
                if len(tracker.action_expressions) > 0:
                    raise ValueError(f'Actions are not allowed in a forall statement: {tracker.action_expressions}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall statement: {tracker.return_expression}')
            elif item.name == 'forall_in':
                suite = item.args.arguments[0]
                tracker = FunctionCallTracker(suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')
                for expr, annotations in tracker.assign_expressions:
                    raise ValueError(f'Assign statements are not allowed in a forall_in statement: {expr}')
                for expr in tracker.check_expressions:
                    self.check_expressions.append(E.AndExpression(expr))
                if len(tracker.expr_expressions) > 0:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    self.expr_expressions.append(merged)
                if len(tracker.action_expressions) > 0:
                    # TODO(Jiayuan Mao @ 2024/03/12): implement the rest parts of action statements.
                    for expr in tracker.action_expressions:
                        if isinstance(expr, CrowAchieveExpression):
                            self.action_expressions.append(E.ListExpansionExpression(E.AndExpression(expr.goal)))
                        else:
                            raise ValueError(f'Action items except for achieve statements are not allowed in a forall_in statement: {expr}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall_in statement: {tracker.return_expression}')
            elif item.name == 'pass':
                pass
            # NOTE(review): items with any other name are silently ignored here — confirm this is intended.
        return self
def _make_conditional_implies(condition: E.ValueOutputExpression, test: E.ValueOutputExpression):
    """Guard ``test`` by ``condition``, producing ``condition => test``.

    If ``test`` is already an implication ``p => q``, the guard is merged into its premise, giving
    ``(condition and p) => q``; a conjunctive premise is flattened into the new conjunction instead of being nested.
    """
    is_implication = isinstance(test, E.BoolExpression) and test.op == E.BoolOpType.IMPLIES
    if not is_implication:
        return E.ImpliesExpression(condition, test)

    premise = test.arguments[0]
    if isinstance(premise, E.BoolExpression) and premise.op == E.BoolOpType.AND:
        premise_terms = tuple(premise.arguments)
    else:
        premise_terms = (premise,)
    return E.ImpliesExpression(E.AndExpression(condition, *premise_terms), test.arguments[1])
def _make_conditional_return(current_stmt: Optional[E.ValueOutputExpression], current_condition_neg: Optional[E.ValueOutputExpression], new_stmt: E.ValueOutputExpression):
    """Combine an earlier conditional return with a later return value.

    Args:
        current_stmt: the value returned earlier, or None if there was no earlier return.
        current_condition_neg: the condition under which the earlier return was *not* taken.
        new_stmt: the value returned later.

    Returns:
        ``new_stmt`` if there is no earlier return; otherwise ``current_condition_neg ? new_stmt : current_stmt``.
    """
    if current_stmt is not None:
        return E.ConditionExpression(current_condition_neg, new_stmt, current_stmt)
    return new_stmt
def _make_conditional_assign(assign_stmt: E.VariableAssignmentExpression, condition: E.ValueOutputExpression):
    """Return a copy of ``assign_stmt`` that only takes effect when ``condition`` holds.

    A plain assignment gains the condition; a conditional assignment gets ``condition and <old condition>`` (flattening a
    conjunctive old condition); a deictic assignment is rebuilt around the conditioned inner assignment.

    Raises:
        ValueError: for any other kind of assignment.
    """
    if isinstance(assign_stmt, E.AssignExpression):
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, condition)

    if isinstance(assign_stmt, E.ConditionalAssignExpression):
        existing = assign_stmt.condition
        if isinstance(existing, E.BoolExpression) and existing.op == E.BoolOpType.AND:
            conjuncts = tuple(existing.arguments)
        else:
            conjuncts = (existing,)
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, E.AndExpression(condition, *conjuncts))

    if isinstance(assign_stmt, E.DeicticAssignExpression):
        inner = _make_conditional_assign(assign_stmt.expression, condition)
        return E.DeicticAssignExpression(assign_stmt.variable, inner)

    raise ValueError(f'Invalid assign statement: {assign_stmt}')
def gen_term_expr(expr_typename: str):
    """Generate a term expression function. This function is used to generate the term expression functions for the transformer.

    It is used to generate the following functions:

    - mul_expr
    - arith_expr
    - shift_expr

    These binary operators are not supported yet: the generated visitor passes a lone child through unchanged and raises
    :class:`NotImplementedError` as soon as an operator is present.

    Args:
        expr_typename: the name of the expression type. This is only used in the error message.

    Returns:
        the generated term expression function.
    """
    @inline_args
    def term(self, *values: Any):
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]
        # The previous (unreachable) draft assumed the children alternate operand / operator / operand and folded them
        # left-to-right into `FunctionCall(op, ArgumentsList((lhs, rhs)))` nodes; revisit that when adding support.
        raise NotImplementedError(f'{expr_typename} expression is not supported in the current version.')
    return term
def gen_term_expr_noop(expr_typename: str):
    """Generate a visitor for the bitwise-style boolean operators.

    It is named ``_noop`` because the tree children do not include the operator token, so the operator is selected from
    ``expr_typename`` instead. Used for:

    - bitand_expr (``&`` -> AND)
    - bitxor_expr (``^`` -> XOR)
    - bitor_expr (``|`` -> OR)

    Args:
        expr_typename: one of ``'bitand'``, ``'bitxor'``, ``'bitor'``. It is looked up when the visitor runs.

    Returns:
        the generated term expression function.
    """
    op_mapping = {
        'bitand': E.BoolOpType.AND,
        'bitxor': E.BoolOpType.XOR,
        'bitor': E.BoolOpType.OR,
    }

    @inline_args
    def term(self, *values: Any):
        visited = list(map(self.visit, values))
        if len(visited) == 1:
            # A single child means no operator was applied; pass it through.
            return visited[0]
        return E.BoolExpression(op_mapping[expr_typename], _canonize_arguments(visited))

    return term
def _canonize_single_argument(arg: Any, dtype: Optional[TypeBase] = None) -> Union[E.ObjectOrValueOutputExpression, E.VariableExpression, type(Ellipsis)]:
    """Convert one parsed argument into an expression.

    Expressions and controller/action/generator applications are returned as-is, variables are wrapped, Python numbers
    and booleans become constants of ``dtype``, and ``...`` is passed through.

    Raises:
        ValueError: for any other kind of argument.
    """
    passthrough_types = (
        E.ObjectOrValueOutputExpression,
        CrowControllerApplicationExpression, CrowActionApplicationExpression, CrowGeneratorApplicationExpression,
    )
    if isinstance(arg, passthrough_types):
        return arg
    if isinstance(arg, Variable):
        return E.VariableExpression(arg)
    if isinstance(arg, (bool, int, float, complex)):
        return E.ConstantExpression.from_value(arg, dtype=dtype)
    if arg is Ellipsis:
        return Ellipsis
    raise ValueError(f'Invalid argument: {arg}. Type: {type(arg)}.')
def _canonize_arguments(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtypes: Optional[Sequence[TypeBase]] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]:
    """Canonize every argument of a call with :func:`_canonize_single_argument`.

    Args:
        args: an :class:`ArgumentsList`, a plain tuple/list of arguments, or None (treated as no arguments).
        dtypes: optional expected type for each argument; used when converting literal constants.

    Returns:
        a tuple of canonized arguments; ``...`` entries are kept as ``Ellipsis``.

    Raises:
        ValueError: if ``dtypes`` is given and its length differs from the number of arguments.
    """
    if args is None:
        return tuple()

    # TODO(Jiayuan Mao @ 2024/03/2): Strictly check the allowability of "Ellipsis" in the arguments.
    arguments = args.arguments if isinstance(args, ArgumentsList) else args
    if dtypes is not None and len(arguments) != len(dtypes):
        raise ValueError(f'Number of arguments does not match the number of types: {len(arguments)} vs {len(dtypes)}. Args: {arguments}, Types: {dtypes}')

    def canonize(index, arg):
        if arg is Ellipsis:
            return Ellipsis
        return _canonize_single_argument(arg, dtype=None if dtypes is None else dtypes[index])

    return tuple(canonize(index, arg) for index, arg in enumerate(arguments))
def _has_list_arguments(args: Tuple[E.ObjectOrValueOutputExpression, ...]) -> bool:
    """Return True if any argument has a list return type.

    ``Ellipsis`` entries (which the argument canonization keeps as-is for ``...`` placeholders) carry no type; they are
    skipped instead of raising ``AttributeError``.
    """
    return any(arg is not Ellipsis and arg.return_type.is_list_type for arg in args)
[docs]class PDSketch3ExpressionInterpreter(Interpreter):
"""The transformer for expressions. Including:
- typename
- sized_vector_typename
- unsized_vector_typename
- typed_argument
- is_typed_argument
- in_typed_argument
- arguments_def
- atom_expr_funccall
- atom_varname
- atom
- power
- factor
- unary_op_expr
- mul_expr
- arith_expr
- shift_expr
- bitand_expr
- bitxor_expr
- bitor_expr
- comparison_expr
- not_test
- and_test
- or_test
- cond_test
- test
- test_nocond
- tuple
- list
- cs_list
- suite
- expr_stmt
- expr_list_expansion_stmt
- assign_stmt
- annotated_assign_stmt
- local_assign_stmt
- pass_stmt
- check_stmt
- return_stmt
- achieve_stmt
- compound_assign_stmt
- compound_check_stmt
- compound_return_stmt
- compound_achieve_stmt
- if_stmt
- forall_stmt
- forall_test
- exists_test
- findall_test
- forall_in_test
- exists_in_test
"""
[docs] def __init__(self, domain: Optional[CrowDomain], state: Optional[CrowState], expression_def_ctx: E.ExpressionDefinitionContext, auto_constant_guess: bool = False):
super().__init__()
self.domain = domain
self.state = state
self.expression_def_ctx = expression_def_ctx
self.auto_constant_guess = auto_constant_guess
self.generator_impl_outputs = None
self.local_variables = dict()
[docs] def set_domain(self, domain: CrowDomain):
self.domain = domain
[docs] @contextlib.contextmanager
def local_variable_guard(self):
backup = self.local_variables.copy()
yield
self.local_variables = backup
[docs] @contextlib.contextmanager
def set_generator_impl_outputs(self, outputs: List[Variable]):
backup = self.generator_impl_outputs
self.generator_impl_outputs = outputs
yield
self.generator_impl_outputs = backup
domain: CrowDomain
expression_def_ctx: E.ExpressionDefinitionContext
[docs] def visit(self, tree: Any) -> Any:
if isinstance(tree, Tree):
return super().visit(tree)
return tree
[docs] @inline_args
def atom_varname(self, name: str) -> Union[E.VariableExpression, E.ObjectConstantExpression, E.ValueOutputExpression]:
"""Captures variable names such as `var_name`."""
if name in self.local_variables:
return self.local_variables[name]
if self.state is not None and name in self.state.object_name2defaultindex:
return E.ObjectConstantExpression(ObjectConstant(name, self.domain.types[self.state.get_typename(name)]))
# TODO(Jiayuan Mao @ 2024/03/12): smartly guess the type of the variable.
if not self.expression_def_ctx.has_variable(name):
if self.auto_constant_guess:
return E.ObjectConstantExpression(ObjectConstant(name, AutoType))
variable = self.expression_def_ctx.wrap_variable(name)
return variable
[docs] @inline_args
def atom_expr_funccall(self, annotations: dict, name: str, args: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression, CrowActionBodyItem]:
"""Captures function calls, such as `func_name(arg1, arg2, ...)`.

The callee name is resolved against the domain in this order: feature, function, controller,
action, generator. If none matches, the call defines a new crow function on the fly when the
`inplace_action_body` or `inplace_generator` annotation is truthy; otherwise a KeyError is raised.
"""
annotations: Optional[dict] = self.visit(annotations)
args: Optional[ArgumentsList] = self.visit(args)
if annotations is None:
annotations = dict()
if args is None:
args = ArgumentsList(tuple())
# Features and functions produce a list application when _has_list_arguments reports list-valued arguments.
if self.domain.has_feature(name):
function = self.domain.get_feature(name)
args_c = _canonize_arguments(args, function.ftype.argument_types)
if _has_list_arguments(args_c):
return E.ListFunctionApplicationExpression(function, args_c)
return E.FunctionApplicationExpression(function, args_c)
elif self.domain.has_function(name):
function = self.domain.get_function(name)
args_c = _canonize_arguments(args, function.ftype.argument_types)
if _has_list_arguments(args_c):
return E.ListFunctionApplicationExpression(function, args_c)
return E.FunctionApplicationExpression(function, args_c)
elif self.domain.has_controller(name):
controller = self.domain.get_controller(name)
args_c = _canonize_arguments(args, controller.argument_types)
return CrowControllerApplicationExpression(controller, args_c)
elif self.domain.has_action(name):
action = self.domain.get_action(name)
args_c = _canonize_arguments(args, action.argument_types)
# A trailing `...` is expanded into one UnnamedPlaceholder per remaining action argument type.
# NOTE(review): the `+` below assumes _canonize_arguments returns a tuple — confirm.
if len(args_c) > 0 and args_c[-1] is Ellipsis:
args_c = args_c[:-1] + tuple([UnnamedPlaceholder(t) for t in action.argument_types[len(args_c) - 1:]])
return CrowActionApplicationExpression(action, args_c)
elif self.domain.has_generator(name):
generator = self.domain.get_generator(name)
args_c = _canonize_arguments(args, generator.argument_types)
return CrowGeneratorApplicationExpression(generator, args_c, list())
else:
# Unknown name: an inplace action body defines a new crow function whose return type is the
# domain's '__totally_ordered_plan__' type; argument types come from the call arguments.
if 'inplace_action_body' in annotations and annotations['inplace_action_body']:
args_c = _canonize_arguments(args)
argument_types = [arg.return_type for arg in args_c]
# logger.warning(f'Action {name} not found, creating a new one with argument types {argument_types}.')
predicate = self.domain.define_crow_function(name, argument_types, self.domain.get_type('__totally_ordered_plan__'))
return E.FunctionApplicationExpression(predicate, args_c)
# Unknown name used as an inplace generator: define a BOOL placeholder function plus a generator
# 'gen_<name>' over the same arguments. Call arguments that are variables named in
# 'inplace_generator_targets' become generator outputs; all other positions become inputs.
elif 'inplace_generator' in annotations and annotations['inplace_generator']:
args_c = _canonize_arguments(args)
argument_types = [arg.return_type for arg in args_c]
# logger.warning(f'Generator placeholder function {name} not found, creating a new one with argument types {argument_types}.')
predicate = self.domain.define_crow_function(
name, argument_types, BOOL,
generator_placeholder=annotations.get('generator_placeholder', True)
)
assert 'inplace_generator_targets' in annotations, f'Inplace generator {name} requires inplace generator targets to be set.'
inplace_generator_targets = annotations['inplace_generator_targets']
generator_name = 'gen_' + name
generator_arguments = predicate.arguments
# The generator goal is the placeholder predicate applied to its own formal arguments.
generator_goal = E.FunctionApplicationExpression(predicate, [E.VariableExpression(arg) for arg in generator_arguments])
output_argument_names = [x.value for x in inplace_generator_targets.items]
output_indices = list()
# Each target name is consumed once; leftover names mean a target did not appear as a variable argument.
for i, arg in enumerate(args_c):
if isinstance(arg, E.VariableExpression) and arg.name in output_argument_names:
output_argument_names.remove(arg.name)
output_indices.append(i)
assert len(output_argument_names) == 0, f'Mismatched output arguments for inplace generator {name}: {output_argument_names}'
inputs = [arg for i, arg in enumerate(generator_arguments) if i not in output_indices]
outputs = [arg for i, arg in enumerate(generator_arguments) if i in output_indices]
self.domain.define_generator(generator_name, generator_arguments, generator_goal, inputs, outputs)
return E.FunctionApplicationExpression(predicate, args_c)
else:
raise KeyError(f'Function {name} not found. Note that recursive function calls are not supported in the current version.')
@inline_args
def atom_subscript(self, name: str, index: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression]:
    """Captures subscript expressions such as `name[index1, index2, ...]` on a state-variable feature.

    `name[...]` (a single Ellipsis) applies the feature to no arguments. When `_has_list_arguments`
    reports list-valued indices, a list function application is produced instead.

    Raises:
        ValueError: if ``name`` is not a state variable.
    """
    feature = self.domain.get_feature(name)
    index: CSList = self.visit(index)
    if not feature.is_state_variable:
        raise ValueError(f'Invalid subscript expression: {name} is not a state variable. Expression: {name}[{index.items}]')

    subscripts = index.items
    if len(subscripts) == 1 and subscripts[0] is Ellipsis:
        return E.FunctionApplicationExpression(feature, tuple())

    arguments = _canonize_arguments(subscripts, dtypes=feature.ftype.argument_types)
    application_cls = E.ListFunctionApplicationExpression if _has_list_arguments(arguments) else E.FunctionApplicationExpression
    return application_cls(feature, arguments)
@inline_args
def atom(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
    """Base case of expressions (literals, variables, subscripts): the single child is returned as-is."""
    result = value
    return result
def arguments(self, args: Tree) -> ArgumentsList:
    """Captures the argument list of a function call: every child is visited and packed into an ``ArgumentsList``."""
    visited_children = self.visit_children(args)
    return ArgumentsList(tuple(visited_children))
@inline_args
def power(self, base: Union[FunctionCall, Variable], exp: Optional[float] = None) -> Union[FunctionCall, Variable]:
    """Highest-priority expression, ``base ** exp``.

    Only the degenerate case without an exponent is supported; ``base`` is returned unchanged.

    Raises:
        NotImplementedError: if an exponent is present.
    """
    if exp is not None:
        raise NotImplementedError('Power expression is not supported in the current version.')
    return base
@inline_args
def factor(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
    """Pass-through for the ``factor`` rule: the single child is returned unchanged."""
    result = value
    return result
@inline_args
def unary_op_expr(self, op: str, value: Union[FunctionCall, Variable]) -> FunctionCall:
    """Unary operator expressions (e.g. ``-x``); not supported yet.

    Raises:
        NotImplementedError: always.
    """
    message = 'Unary operators are not supported in the current version.'
    raise NotImplementedError(message)
# Binary-operator grammar rules are bound to handlers built by factory helpers defined elsewhere
# in this file: arithmetic/shift rules use gen_term_expr, bitwise rules use gen_term_expr_noop.
# NOTE(review): see those helpers for how the two variants differ.
mul_expr = gen_term_expr('mul')
arith_expr = gen_term_expr('add')
shift_expr = gen_term_expr('shift')
bitand_expr = gen_term_expr_noop('bitand')
bitxor_expr = gen_term_expr_noop('bitxor')
bitor_expr = gen_term_expr_noop('bitor')
@inline_args
def comparison_expr(self, *values: Union[E.ValueOutputExpression, E.VariableExpression]) -> E.ValueOutputExpression:
    """Captures (possibly chained) comparisons such as ``a == b`` or ``a == b != c``.

    The children alternate operand / operator / operand / ... . A chain ``a op1 b op2 c`` becomes
    the conjunction of ``a op1 b`` and ``b op2 c``. Only comparisons between two object-typed
    operands are supported; each becomes an ``E.ObjectCompareExpression``.

    Raises:
        NotImplementedError: if a comparison has an operand that is not object-typed.
    """
    if len(values) == 1:
        return self.visit(values[0])

    assert len(values) % 2 == 1, f'[compare] expressions expected an odd number of values, got {len(values)}. Values: {values}.'
    values = [self.visit(value) for value in values]
    results = list()
    for i in range(1, len(values), 2):
        lhs, op, rhs = values[i - 1], values[i], values[i + 1]
        # The visited operator child is indexed as op[0].value to get the operator text
        # (presumably a list of tokens from the default visitor — verify against the grammar).
        op_string = op[0].value
        if not (lhs.return_type.is_object_type and rhs.return_type.is_object_type):
            # Dropping the comparison silently would change the meaning of the whole expression
            # (and produce an empty conjunction if nothing remained), so fail loudly instead.
            raise NotImplementedError(f'Comparison {op_string!r} between non-object operands is not supported in the current version: {lhs} and {rhs}.')
        results.append(E.ObjectCompareExpression(E.CompareOpType.from_string(op_string), lhs, rhs))
    if len(results) == 1:
        return results[0]
    return E.AndExpression(*results)
@inline_args
def not_test(self, value: Any) -> E.NotExpression:
    """Captures ``not <value>``; the visited operand is canonized before being negated."""
    operand = self.visit(value)
    canonized = _canonize_arguments([operand])
    return E.NotExpression(*canonized)
@inline_args
def and_test(self, *values: Any) -> E.AndExpression:
    """Captures ``a and b and ...``; a single operand is returned as-is without wrapping."""
    operands = [self.visit(operand) for operand in values]
    if len(operands) != 1:
        return E.AndExpression(*_canonize_arguments(operands))
    return operands[0]
@inline_args
def or_test(self, *values: Any) -> E.OrExpression:
    """Captures ``a or b or ...``; a single operand is returned as-is without wrapping."""
    operands = [self.visit(operand) for operand in values]
    if len(operands) != 1:
        return E.OrExpression(*_canonize_arguments(operands))
    return operands[0]
@inline_args
def cond_test(self, value1: Any, cond: Any, value2: Any) -> E.ConditionExpression:
    """Captures the conditional expression ``value1 if cond else value2``.

    The children are visited in the order condition, then-branch, else-branch.
    """
    condition = self.visit(cond)
    if_true = self.visit(value1)
    if_false = self.visit(value2)
    return E.ConditionExpression(*_canonize_arguments([condition, if_true, if_false]))
@inline_args
def test(self, value: Any):
    """Pass-through for the ``test`` rule: visit and return the single child."""
    result = self.visit(value)
    return result
@inline_args
def test_nocond(self, value: Any):
    """Pass-through for the ``test_nocond`` rule: visit and return the single child."""
    result = self.visit(value)
    return result
@inline_args
def tuple(self, *values: Any):
    """Captures tuple literals: every element is visited and a Python tuple of the results is returned."""
    return tuple(map(self.visit, values))
@inline_args
def list(self, *values: Any):
    """Captures list literals as an ``E.ListCreationExpression`` over the visited, canonized elements."""
    elements = [self.visit(element) for element in values]
    return E.ListCreationExpression(_canonize_arguments(elements))
@inline_args
def cs_list(self, *values: Any):
    """Captures comma-separated lists: visits every item and wraps the results in a ``CSList``."""
    items = tuple(map(self.visit, values))
    return CSList(items)
@inline_args
def suite(self, *values: Tree, activate_variable_guard: bool = True) -> Suite:
    """Visit every statement of a suite and collect the non-None results into a ``Suite``.

    When ``activate_variable_guard`` is True, locals defined inside the suite are discarded once it has
    been visited. The suite's own view of the locals is snapshotted before leaving the scope, so the
    returned ``Suite`` still records the variables it defined.
    """
    scope = self.local_variable_guard() if activate_variable_guard else contextlib.nullcontext()
    with scope:
        statements = [self.visit(value) for value in values]
        scope_snapshot = self.local_variables.copy()
    return Suite(tuple(statement for statement in statements if statement is not None), scope_snapshot)
@inline_args
def expr_stmt(self, value: Tree):
    """Captures an expression statement as ``FunctionCall('expr', ...)``.

    A bare ``...`` or a plain string value (e.g. a docstring-like literal) produces no statement.
    """
    visited = self.visit(value)
    if visited is Ellipsis or isinstance(visited, str):
        return None
    return FunctionCall('expr', ArgumentsList((_canonize_single_argument(visited),)))
@inline_args
def expr_list_expansion_stmt(self, value: Any):
    """Captures a list-expansion statement: the value is wrapped in ``E.ListExpansionExpression`` inside an ``expr`` call."""
    expansion = E.ListExpansionExpression(_canonize_single_argument(self.visit(value)))
    return FunctionCall('expr', ArgumentsList((expansion, )))
@inline_args
def assign_stmt(self, target: Any, value: Any):
    """Captures ``target = value`` as ``FunctionCall('assign', (target, value))`` with both sides canonized."""
    lhs = _canonize_single_argument(self.visit(target))
    rhs = _canonize_single_argument(self.visit(value))
    return FunctionCall('assign', ArgumentsList((lhs, rhs)))
@inline_args
def annotated_assign_stmt(self, annotations: dict, target: Any, value: Any):
    """Like ``assign_stmt``, but attaches ``annotations`` to the produced ``assign`` call."""
    lhs = _canonize_single_argument(self.visit(target))
    rhs = _canonize_single_argument(self.visit(value))
    return FunctionCall('assign', ArgumentsList((lhs, rhs)), annotations)
@inline_args
def local_assign_stmt(self, target: str, value: Any = None):
    """Captures a local variable definition with an optional initial value.

    The variable is registered in ``self.local_variables`` so that later references resolve to it
    (see ``atom_varname``). With an initializer, an ``assign`` call is returned and the variable's
    type is the value's ``return_type``. Without one, the variable is only declared (typed
    ``AutoType``, since there is no value to infer the type from) and no statement is produced.
    """
    assert isinstance(target, str), f'Invalid local variable name: {target}'
    if value is not None:
        value = _canonize_single_argument(self.visit(value))
    # Guard the type lookup: the declaration-only form has value None and would otherwise crash
    # on `None.return_type` before reaching the `value is not None` check below.
    value_type = value.return_type if value is not None else AutoType
    self.local_variables[target] = E.VariableExpression(Variable(target, value_type))
    if value is not None:
        return FunctionCall('assign', ArgumentsList((self.local_variables[target], value)))
    return None
@inline_args
def pass_stmt(self):
    """Captures ``pass`` as a ``pass`` call with no arguments."""
    no_arguments = ArgumentsList(())
    return FunctionCall('pass', no_arguments)
@inline_args
def check_stmt(self, value: Any):
    """Captures ``check <value>`` as a ``check`` call on the canonized value."""
    condition = _canonize_single_argument(self.visit(value))
    return FunctionCall('check', ArgumentsList((condition, )))
@inline_args
def achieve_stmt(self, value: CSList):
    """Captures ``achieve a, b, ...``: every item of the comma-separated list becomes an argument."""
    goals = self.visit(value).items
    return FunctionCall('achieve', ArgumentsList(_canonize_arguments(goals)))
@inline_args
def assert_stmt(self, value: Any):
    """Captures ``assert <value>`` as an ``assert`` call on the canonized value."""
    condition = _canonize_single_argument(self.visit(value))
    return FunctionCall('assert', ArgumentsList((condition, )))
@inline_args
def return_stmt(self, value: Any):
    """Captures ``return <value>`` as a ``return`` call on the canonized value."""
    result = _canonize_single_argument(self.visit(value))
    return FunctionCall('return', ArgumentsList((result, )))
@inline_args
def compound_assign_stmt(self, target: Variable, value: Any):
    """Compound form of assignment: the target is visited but, unlike the value, not canonized."""
    lhs = self.visit(target)
    rhs = _canonize_single_argument(self.visit(value))
    return FunctionCall('assign', ArgumentsList((lhs, rhs)))
@inline_args
def compound_check_stmt(self, value: Any):
    """Compound form of ``check``: emits a ``check`` call on the canonized value."""
    condition = _canonize_single_argument(self.visit(value))
    return FunctionCall('check', ArgumentsList((condition, )))
@inline_args
def compound_achieve_stmt(self, value: Any):
    """Compound form of ``achieve``: the single visited value becomes the only goal argument."""
    goal = self.visit(value)
    return FunctionCall('achieve', ArgumentsList(_canonize_arguments([goal])))
@inline_args
def compound_assert_stmt(self, value: Any):
    """Compound form of ``assert``: emits an ``assert`` call on the canonized value."""
    condition = _canonize_single_argument(self.visit(value))
    return FunctionCall('assert', ArgumentsList((condition, )))
@inline_args
def compound_return_stmt(self, value: Any):
    """Compound form of ``return``: emits a ``return`` call on the canonized value."""
    result = _canonize_single_argument(self.visit(value))
    return FunctionCall('return', ArgumentsList((result, )))
@inline_args
def ordered_suite(self, ordering_op: Any, body: Any):
    """Captures a suite prefixed by an ordering keyword, producing ``FunctionCall('ordering', (op, Suite))``.

    For the promotable / unordered variants the body is visited as a scoped suite, so its locals are
    discarded afterwards. For any other ordering op (e.g. 'sequential') the statements are visited
    directly, so locals defined in the body stay visible to the statements that follow.
    """
    ordering_op = self.visit(ordering_op)
    assert body.data.value == 'suite', f'Invalid body type: {body}'
    scoped_ops = ('promotable', 'unordered', 'promotable unordered', 'promotable sequential')
    if ordering_op not in scoped_ops:
        statements = self.visit_children(body)
        ordered_body = Suite(tuple(statements), self.local_variables.copy())
    else:
        with self.local_variable_guard():
            ordered_body = self.visit(body)
    return FunctionCall('ordering', ArgumentsList((ordering_op, ordered_body)))
@inline_args
def ordering_op(self, *ordering_op: Any):
    """Join the ordering keyword tokens (e.g. ``promotable unordered``) into one space-separated string."""
    words = [token.value for token in ordering_op]
    return ' '.join(words)
@inline_args
def if_stmt(self, cond: Any, suite: Any, else_suite: Optional[Any] = None):
    """Captures ``if``/``else`` as ``FunctionCall('if', (cond, then_suite, else_suite))``.

    Each branch is visited in its own local-variable scope. A missing ``else`` becomes a suite holding
    a single ``pass`` call.
    """
    condition = _canonize_single_argument(self.visit(cond))
    with self.local_variable_guard():
        then_branch = self.visit(suite)
    if else_suite is not None:
        with self.local_variable_guard():
            else_branch = self.visit(else_suite)
    else:
        else_branch = Suite((FunctionCall('pass', ArgumentsList(tuple())), ))
    return FunctionCall('if', ArgumentsList((condition, then_branch, else_branch)))
@inline_args
def forall_stmt(self, cs_list: Any, suite: Any):
    """Captures ``forall`` statements: the listed variables are introduced while visiting the body.

    Returns ``FunctionCall('forall', (variables, body_suite))``.
    """
    quantified = self.visit(cs_list)
    with self.expression_def_ctx.new_variables(*quantified.items), self.local_variable_guard():
        body = self.visit(suite)
    return FunctionCall('forall', ArgumentsList((quantified, body)))
@inline_args
def forall_in_stmt(self, variables_cs_list: Any, values_cs_list: Any, suite: Any):
    """Captures ``forall ... in ...`` statements.

    Each variable in the first list is bound as a local to the value at the same position in the second
    list (pairs beyond the shorter list are ignored, as with ``zip``) while the suite is visited. Only
    the visited suite is recorded in the returned ``forall_in`` call.
    """
    names = self.visit(variables_cs_list)
    bound_values = self.visit(values_cs_list)
    with self.local_variable_guard():
        for local_name, bound_value in zip(names.items, bound_values.items):
            self.local_variables[local_name] = bound_value
        body = self.visit(suite)
    return FunctionCall('forall_in', ArgumentsList((body, )))
def _quantification_expression(self, cs_list: Any, suite: Any, quantification_cls):
    """Build nested quantifier expressions of type ``quantification_cls`` over the listed variables.

    The suite's combined return expression is the innermost body. The first listed variable becomes the
    outermost quantifier.
    """
    quantified = self.visit(cs_list)
    with self.expression_def_ctx.new_variables(*quantified.items), self.local_variable_guard():
        visited_suite = self.visit(suite)
        expression = visited_suite.get_combined_return_expression(allow_expr_expressions=True)
    for variable in reversed(quantified.items):
        expression = quantification_cls(variable, expression)
    return expression
@inline_args
def forall_test(self, cs_list: Any, suite: Any):
    """Captures ``forall`` expressions as nested ``E.ForallExpression`` objects."""
    quantifier = E.ForallExpression
    return self._quantification_expression(cs_list, suite, quantifier)
@inline_args
def exists_test(self, cs_list: Any, suite: Any):
    """Captures ``exists`` expressions as nested ``E.ExistsExpression`` objects."""
    quantifier = E.ExistsExpression
    return self._quantification_expression(cs_list, suite, quantifier)
@inline_args
def findall_test(self, variable: Variable, suite: Any):
    """Captures ``findall`` expressions: ``E.FindAllExpression(variable, body)``.

    The body is the suite's combined return expression, computed with ``variable`` in scope.
    """
    with self.expression_def_ctx.new_variables(variable), self.local_variable_guard():
        visited_suite = self.visit(suite)
        condition = visited_suite.get_combined_return_expression(allow_expr_expressions=True)
    return E.FindAllExpression(variable, condition)
def _quantification_in_expression(self, cs_list: Any, suite: Any, quantification_cls):
    """Bind ``name in value`` items as locals, visit the suite, and wrap its result in ``quantification_cls``.

    Each item of ``cs_list`` is expected to expose ``.name`` and ``.value`` (presumably an
    ``InTypedArgument`` — verify). The name is bound to the visited value for the duration of the suite.
    """
    bindings = self.visit(cs_list)
    with self.local_variable_guard():
        for binding in bindings.items:
            self.local_variables[binding.name] = self.visit(binding.value)
        visited_suite = self.visit(suite)
        combined = visited_suite.get_combined_return_expression(allow_expr_expressions=True)
    return quantification_cls(combined)
@inline_args
def forall_in_test(self, cs_list: Any, suite: Any):
    """Captures ``forall ... in ...`` expressions; the body is wrapped in ``E.AndExpression``."""
    combiner = E.AndExpression
    return self._quantification_in_expression(cs_list, suite, combiner)
@inline_args
def exists_in_test(self, cs_list: Any, suite: Any):
    """Captures ``exists ... in ...`` expressions; the body is wrapped in ``E.OrExpression``."""
    combiner = E.OrExpression
    return self._quantification_in_expression(cs_list, suite, combiner)
@inline_args
def bind_stmt(self, cs_list: Any, suite: Any):
    """Captures bind statements with a body: ``FunctionCall('bind', (variables, condition))``.

    The body is visited with the bound variables in scope. Its return expressions form the condition;
    several expressions are combined with ``E.AndExpression``. After the statement, the bound variables
    remain visible as locals to later statements.
    """
    variables = self.visit(cs_list)
    with self.expression_def_ctx.new_variables(*variables.items), self.local_variable_guard():
        visited_suite = self.visit(suite)
        condition = visited_suite.get_combined_return_expression(allow_expr_expressions=True, allow_multiple_expressions=True)
    if isinstance(condition, list):
        condition = E.AndExpression(*condition)
    # Registered outside the guard so the bindings survive this statement.
    self.local_variables.update({variable.name: E.VariableExpression(variable) for variable in variables.items})
    return FunctionCall('bind', ArgumentsList((variables, condition)))
@inline_args
def bind_stmt_no_where(self, cs_list: Any):
    """Captures bind statements without a body. For example:

    .. code-block:: python

        bind x: int, y: int

    The variables become locals for subsequent statements. The emitted ``bind`` call carries a
    ``E.NullExpression(BOOL)`` as its condition.
    """
    variables = self.visit(cs_list)
    self.local_variables.update({variable.name: E.VariableExpression(variable) for variable in variables.items})
    return FunctionCall('bind', ArgumentsList((variables, E.NullExpression(BOOL))))
@inline_args
def annotated_compound_stmt(self, annotations: dict, stmt: Any):
    """Visit a compound statement and attach ``annotations`` to the object it produces."""
    result = self.visit(stmt)
    result.annotations = annotations
    return result
@dataclass
class ArgumentsDef:
    """Parsed formal arguments of a definition; read via ``.arguments`` by the domain transformer."""

    arguments: Tuple[Variable, ...]
@dataclass
class PreconditionPart:
    """Holds the not-yet-interpreted suite of a precondition section (actions and generators)."""

    suite: Tree
@dataclass
class EffectPart:
    """Holds the not-yet-interpreted suite of an action's effect section."""

    suite: Tree
@dataclass
class GoalPart:
    """Holds the not-yet-interpreted suite of a goal section (actions and generators)."""

    suite: Tree
@dataclass
class BodyPart:
    """Holds the not-yet-interpreted suite of an action's body section."""

    suite: Tree
@dataclass
class SideEffectPart:
    """Holds the suite of a side-effect section (its producer and consumer are outside this chunk)."""

    suite: Tree
@dataclass
class ImplPart:
    """Holds the suite of a generator's implementation section; generator_definition currently skips it."""

    suite: Tree
@dataclass
class InPart:
    """Holds the values tree of a generator's input section (stored in the field named ``suite``)."""

    suite: Tree
@dataclass
class OutPart:
    """Holds the values tree of a generator's output section (stored in the field named ``suite``)."""

    suite: Tree
# Transformer (via PDSketch3LiteralTransformer) that registers the definitions of a parsed domain file on
# `self.domain`. Bodies of definitions are interpreted with `self.expression_interpreter`, which shares
# the same domain and expression definition context. Each handler prints a debug trace to stdout.
[docs]class PDSketchV3DomainTransformer(PDSketch3LiteralTransformer):
# domain: an existing domain to extend; a fresh CrowDomain is created when None.
[docs] def __init__(self, domain: Optional[CrowDomain] = None):
super().__init__()
self.domain = CrowDomain() if domain is None else domain
self.expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
# No state is attached and auto_constant_guess is off, so unknown names go through expression_def_ctx.
self.expression_interpreter = PDSketch3ExpressionInterpreter(domain=self.domain, state=None, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=False)
domain: CrowDomain
expression_def_ctx: E.ExpressionDefinitionContext
expression_interpreter: PDSketch3ExpressionInterpreter
# NOTE(review): pragmas are only printed here; nothing is applied to the domain.
[docs] @inline_args
def pragma_definition(self, pragma: Dict[str, Any]):
print('pragma_definition', pragma)
# Registers a type (optionally with a base type) on the domain.
[docs] @inline_args
def type_definition(self, typename, basetype: Optional[Union[str, TypeBase]]):
print(f'type_definition:: {typename=} {basetype=}')
self.domain.define_type(typename, basetype)
# Registers a feature. Defaults: no annotations, no arguments, BOOL return type; a string return type is
# looked up in the domain. If a body suite is present it is interpreted with the arguments in scope, and
# its combined return expression (expr statements not allowed) becomes the derived expression.
[docs] @inline_args
def feature_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
if annotations is None:
annotations = dict()
if args is None:
args = ArgumentsDef(tuple())
if ret is None:
ret = BOOL
elif isinstance(ret, str):
ret = self.domain.get_type(ret)
return_stmt = None
if suite is not None:
with self.expression_def_ctx.with_variables(*args.arguments):
suite = self.expression_interpreter.visit(suite)
return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False)
self.domain.define_feature(name, args.arguments, ret, derived_expression=return_stmt, **annotations)
print(f'feature_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
if return_stmt is not None:
print(jacinle.indent_text(f'Return statement: {return_stmt}'))
# Same processing as feature_definition, but registers the result via define_crow_function.
[docs] @inline_args
def function_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
if annotations is None:
annotations = dict()
if args is None:
args = ArgumentsDef(tuple())
if ret is None:
ret = BOOL
elif isinstance(ret, str):
ret = self.domain.get_type(ret)
return_stmt = None
if suite is not None:
with self.expression_def_ctx.with_variables(*args.arguments):
suite = self.expression_interpreter.visit(suite)
return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False)
self.domain.define_crow_function(name, args.arguments, ret, derived_expression=return_stmt, **annotations)
print(f'function_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
if return_stmt is not None:
print(jacinle.indent_text(f'Return statement: {return_stmt}'))
# Registers a controller with its arguments and annotations.
[docs] @inline_args
def controller_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef]):
if annotations is None:
annotations = dict()
if args is None:
args = ArgumentsDef(tuple())
self.domain.define_controller(name, args.arguments, **annotations)
print(f'controller_definition:: {name=} {args.arguments=}')
# The action_*_definition handlers only wrap their raw suite; interpretation happens in action_definition.
[docs] @inline_args
def action_precondition_definition(self, suite: Tree) -> PreconditionPart:
return PreconditionPart(suite)
[docs] @inline_args
def action_effect_definition(self, suite: Tree) -> EffectPart:
return EffectPart(suite)
[docs] @inline_args
def action_goal_definition(self, suite: Tree) -> GoalPart:
return GoalPart(suite)
[docs] @inline_args
def action_body_definition(self, suite: Tree) -> BodyPart:
return BodyPart(suite)
# Interprets the parts of an action (in source order) with the action arguments in scope, then registers it.
# - precondition: all check expressions become CrowPrecondition objects;
# - goal: each goal part must contain exactly one item;
# - body: action expressions (runtime assigns enabled), normalized toward a 'sequential' CrowActionOrderingSuite;
# - effect: assign expressions become CrowEffect objects, visited with the body's local variables (if the
#   body part was already processed) installed on the interpreter.
# With several goals, a goal-less base action is defined plus one variant `{name}_{i}` per goal.
# NOTE(review): effects only see body locals when the body part precedes the effect part; the interpreter's
# local_variables is not reset if visiting the effect suite raises.
# NOTE(review): when the body has a single item that is already a 'sequential' CrowActionOrderingSuite, no
# branch reassigns `body`, so it appears to stay a one-element list — confirm define_action accepts that.
[docs] @inline_args
def action_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, EffectPart, BodyPart]):
if annotations is None:
annotations = dict()
if args is None:
args = ArgumentsDef(tuple())
print(f'action_definition:: {name=} {args.arguments=} {annotations=}')
precondition = list()
goals = list()
body = list()
effect = list()
# Locals recorded by the body suite; handed to the interpreter when an effect part is visited.
local_variables = None
for part in parts:
with self.expression_def_ctx.with_variables(*args.arguments):
if isinstance(part, EffectPart):
self.expression_interpreter.local_variables = local_variables if local_variables is not None else dict()
suite = self.expression_interpreter.visit(part.suite)
self.expression_interpreter.local_variables = dict()
else:
suite = self.expression_interpreter.visit(part.suite)
if isinstance(part, BodyPart):
print(jacinle.indent_text(f'Body: {suite}'))
if isinstance(part, PreconditionPart):
suite = suite.get_all_check_expressions()
precondition = [CrowPrecondition(x) for x in suite]
print(jacinle.indent_text(f'Precondition: {precondition}'))
elif isinstance(part, GoalPart):
if len(suite.items) != 1:
raise NotImplementedError('Multiple goals are not supported in the current version.')
goal = suite.items[0]
goals.append(goal)
print(jacinle.indent_text(f'Goal: {goal}'))
elif isinstance(part, BodyPart):
local_variables = suite.local_variables
body = suite.get_all_action_expressions(use_runtime_assign=True)
print(jacinle.indent_text(f'Body:'))
if len(body) == 0:
body = CrowActionOrderingSuite('sequential', tuple())
elif len(body) == 1:
if not isinstance(body[0], CrowActionOrderingSuite) or body[0].ordering_op != 'sequential':
body = CrowActionOrderingSuite('sequential', (body[0],))
else:
body = CrowActionOrderingSuite('sequential', tuple(body))
print(jacinle.indent_text(str(body), level=2))
elif isinstance(part, EffectPart):
suite = suite.get_all_assign_expressions()
effect = [CrowEffect(x, **a) for x, a in suite]
print(jacinle.indent_text(f'Effect: {effect}'))
else:
raise ValueError(f'Invalid part: {part}')
if len(goals) == 0:
goal = E.NullExpression(BOOL)
self.domain.define_action(name, args.arguments, goal, body, precondition, effect, **annotations)
elif len(goals) == 1:
self.domain.define_action(name, args.arguments, goals[0], body, precondition, effect, **annotations)
else:
goal = E.NullExpression(BOOL)
self.domain.define_action(name, args.arguments, goal, body, precondition, effect, **annotations)
for i, goal in enumerate(goals):
self.domain.define_action(f'{name}_{i}', args.arguments, goal, body, precondition, effect, **annotations)
# Interprets a generator's parts with its arguments in scope and registers it. Impl parts are skipped;
# the goal must be a single expr expression; in/out parts list the input/output variables.
# NOTE(review): generator_precondition_definition produces PreconditionPart, but this loop has no branch for
# it and raises ValueError('Invalid part: ...') after visiting its suite.
[docs] @inline_args
def generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, InPart, OutPart]):
if annotations is None:
annotations = dict()
if args is None:
args = ArgumentsDef(tuple())
print(f'generator_definition:: {name=} {args.arguments=} {annotations=}')
inputs = None
outputs = None
goal = None
with self.expression_def_ctx.with_variables(*args.arguments):
for part in parts:
if isinstance(part, ImplPart):
continue
suite = self.expression_interpreter.visit(part.suite)
if isinstance(part, GoalPart):
goal = suite.get_all_expr_expression(allow_multiple_expressions=False)
print(jacinle.indent_text(f'Goal: {goal}'))
elif isinstance(part, InPart):
inputs = [E.VariableExpression(x) for x in suite.items]
print(jacinle.indent_text(f'Inputs: {inputs}'))
elif isinstance(part, OutPart):
outputs = [E.VariableExpression(x) for x in suite.items]
print(jacinle.indent_text(f'Outputs: {outputs}'))
else:
raise ValueError(f'Invalid part: {part}')
self.domain.define_generator(name, args.arguments, goal, inputs, outputs, **annotations)
# The generator_*_definition handlers only wrap their raw subtree for generator_definition.
[docs] @inline_args
def generator_precondition_definition(self, suite: Tree) -> PreconditionPart:
return PreconditionPart(suite)
[docs] @inline_args
def generator_goal_definition(self, suite: Tree) -> GoalPart:
return GoalPart(suite)
[docs] @inline_args
def generator_in_definition(self, values: Tree) -> InPart:
return InPart(values)
[docs] @inline_args
def generator_out_definition(self, values: Tree) -> OutPart:
return OutPart(values)
[docs] @inline_args
def generator_impl_definition(self, suite: Tree) -> ImplPart:
return ImplPart(suite)
# Module-level parser instance shared by the load_* and parse_expression helpers in this file.
_parser = PDSketchV3Parser()
def load_domain_file(filename: str) -> CrowDomain:
    """Parse the domain file at ``filename`` using the shared module-level parser.

    Args:
        filename: path of the domain file to load.

    Returns:
        the parsed domain object.
    """
    domain = _parser.parse_domain(filename)
    return domain
def load_domain_string(string: str) -> CrowDomain:
    """Parse a domain definition given as source text.

    Args:
        string: text containing the domain definition.

    Returns:
        the parsed domain object.
    """
    domain = _parser.parse_domain_str(string)
    return domain
def load_domain_string_incremental(domain: CrowDomain, string: str) -> CrowDomain:
    """Extend an existing domain with the definitions contained in ``string``.

    Args:
        domain: the domain to update; passed to the parser as the target domain.
        string: text containing additional domain definitions.

    Returns:
        the domain object returned by the parser.
    """
    updated = _parser.parse_domain_str(string, domain=domain)
    return updated
def load_problem_file(filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Parse the problem file at ``filename``.

    Args:
        filename: path of the problem file.
        domain: the domain to use. If omitted, the domain is loaded from the domain file specified in the problem file.

    Returns:
        the parsed problem object.
    """
    problem = _parser.parse_problem(filename, domain=domain)
    return problem
def load_problem_string(string: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Parse a problem definition given as source text.

    Args:
        string: text containing the problem definition.
        domain: the domain to use. If omitted, the domain is loaded from the domain file that the problem definition specifies.

    Returns:
        the parsed problem object.
    """
    problem = _parser.parse_problem_str(string, domain=domain)
    return problem
def parse_expression(domain: CrowDomain, string: str, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
    """Parse a single expression in the context of ``domain``.

    Args:
        domain: the domain used to resolve names.
        string: the expression source text.
        state: an optional state whose objects can be referenced by name.
        variables: optional variables available to the expression.
        auto_constant_guess: whether unknown names may be guessed to be object constants.

    Returns:
        the parsed expression.
    """
    options = dict(state=state, variables=variables, auto_constant_guess=auto_constant_guess)
    return _parser.parse_expression(string, domain, **options)