#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : pdsketch_v3_parser.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/10/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import os
import os.path as osp
import itertools
import contextlib
import collections
import jacinle
from typing import Any, Optional, Union, Sequence, Tuple, Set, List, Dict
from dataclasses import dataclass, field
from lark import Lark, Tree
from lark.visitors import Transformer, Interpreter, v_args
from lark.indenter import PythonIndenter
import concepts.dsl.expression as E
from concepts.dsl.dsl_types import TypeBase, AutoType, VectorValueType, TupleType, Variable, ObjectConstant, UnnamedPlaceholder
from concepts.pdsketch.domain import Domain, Problem3, State
from concepts.pdsketch.operator import Precondition, Effect, Implementation, OperatorApplicationExpression
from concepts.pdsketch.generator import FancyGenerator, GeneratorApplicationExpression
from concepts.pdsketch.regression_rule import RegressionRuleBodyItemType, FindExpression, AchieveExpression, RuntimeAssignExpression, RegressionCommitFlag
from concepts.pdsketch.regression_rule import ConditionalRegressionRuleBodyExpression, LoopRegressionRuleBodyExpression
from concepts.pdsketch.regression_rule import RegressionRuleApplicationExpression
from concepts.pdsketch.executor import PDSketchExecutor
logger = jacinle.get_logger(__name__)

# Decorator that makes lark pass a node's children as separate positional arguments
# (see e.g. `atom_varname(self, name)` below) instead of a single list.
inline_args = v_args(inline=True)

# Public API of this module.
__all__ = [
    'PDSketchV3Parser',
    'path_resolver',
    'PDSketchV3PathResolver',
    'PDSketchV3DomainTransformer',
    'PDSketchV3ProblemTransformer',
    'PDSketch3LiteralTransformer',
    'PDSketch3ExpressionInterpreter',
    'load_domain_file3',
    'load_domain_string3',
    'load_domain_string3_incremental',
    'load_problem_file3',
    'load_problem_string3',
]
class PDSketchV3PathResolver:
    """Locate PDSketch v3 files on disk.

    Candidate paths are tried in this order, and the first one that exists is returned:

    1. ``filename`` exactly as given;
    2. ``filename`` relative to the directory of ``relative_filename`` (when provided);
    3. ``filename`` relative to the current working directory;
    4. ``filename`` relative to each registered search path, in registration order.
    """

    def __init__(self, search_paths: Sequence[str] = tuple()):
        """Create a resolver.

        Args:
            search_paths: the initial search directories. They are copied into a new list.
        """
        self.search_paths = list(search_paths)

    def _candidates(self, filename: str, relative_filename: Optional[str]):
        # Generate candidates lazily so that later ones are only computed when earlier ones do not exist.
        yield filename
        if relative_filename is not None:
            yield osp.join(osp.dirname(relative_filename), filename)
        yield osp.join(os.getcwd(), filename)
        for directory in self.search_paths:
            yield osp.join(directory, filename)

    def resolve(self, filename: str, relative_filename: Optional[str] = None) -> str:
        """Resolve ``filename`` to an existing path.

        Args:
            filename: the (possibly relative) filename to locate.
            relative_filename: an optional file whose directory is searched before the working directory,
                e.g. the file that references ``filename``.

        Returns:
            the first candidate path that exists.

        Raises:
            FileNotFoundError: if no candidate exists.
        """
        for candidate in self._candidates(filename, relative_filename):
            if osp.exists(candidate):
                return candidate
        raise FileNotFoundError(f'File not found: {filename}')

    def add_search_path(self, path: str):
        """Append ``path`` to the list of search directories."""
        self.search_paths.append(path)

    def remove_search_path(self, path: str):
        """Remove the first occurrence of ``path`` from the search directories (``ValueError`` if absent)."""
        self.search_paths.remove(path)


# Module-level default resolver; `PDSketchV3Parser.parse` resolves filenames through it.
path_resolver = PDSketchV3PathResolver()
class PDSketchV3Parser(object):
    """The parser for PDSketch v3."""

    grammar_file = osp.join(osp.dirname(osp.abspath(__file__)), 'pdsketch-v3.grammar')
    """The grammar definition v3 for PDSketch."""

    def __init__(self):
        """Initialize the parser: read :attr:`grammar_file` and build an LALR lark parser with a Python-style indenter."""
        with open(self.grammar_file, 'r') as f:
            self.grammar = f.read()
        self.parser = Lark(self.grammar, start='start', postlex=PythonIndenter(), parser='lalr')

    def parse(self, filename: str) -> Tree:
        """Parse a PDSketch v3 file.

        Args:
            filename: the filename to parse. It is located through the module-level ``path_resolver``.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        filename = path_resolver.resolve(filename)
        with open(filename, 'r') as f:
            content = f.read()
        return self.parse_str(content)

    def parse_str(self, s: str) -> Tree:
        """Parse a PDSketch v3 string.

        Indentation shared by every line (e.g. when ``s`` is an indented triple-quoted literal inside Python
        code) is removed before parsing, so the outermost statements start at column 0.

        Args:
            s: the string to parse.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        import textwrap

        # Remove the indentation common to all lines first. `strip()` alone only removes the first line's
        # indentation, which leaves the remaining lines inconsistently indented for the indentation-sensitive grammar.
        s = textwrap.dedent(s)
        # NB(Jiayuan Mao @ 2024/03/13): for reasons, the pdsketch-v3 grammar file requires that the string ends with a newline.
        # In particular, the suite definition requires that the file ends with a _DEDENT token, which seems to be only triggered by a newline.
        s = s.strip() + '\n'
        return self.parser.parse(s)

    def parse_domain(self, filename: str) -> Domain:
        """Parse a PDSketch v3 domain file.

        Args:
            filename: the filename to parse.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse(filename))

    def parse_domain_str(self, s: str, domain: Optional[Domain] = None) -> Domain:
        """Parse a PDSketch v3 domain string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse_str(s), domain=domain)

    def parse_problem(self, filename: str, domain: Optional[Domain] = None) -> Problem3:
        """Parse a PDSketch v3 problem file.

        Args:
            filename: the filename to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        # NOTE(review): `transform_problem` is not defined in this class body; confirm it is provided elsewhere,
        # otherwise this call raises AttributeError.
        return self.transform_problem(self.parse(filename), domain=domain)

    def parse_problem_str(self, s: str, domain: Optional[Domain] = None) -> Problem3:
        """Parse a PDSketch v3 problem string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        # NOTE(review): see `parse_problem` -- `transform_problem` is not defined in this class body.
        return self.transform_problem(self.parse_str(s), domain=domain)

    def parse_expression(self, s: str, domain: Domain, state: Optional[State] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
        """Parse a PDSketch v3 expression string.

        Args:
            s: the string to parse.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        # NOTE(review): `transform_expression` is not defined in this class body; confirm it is provided elsewhere.
        return self.transform_expression(self.parse_str(s), domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)

    @staticmethod
    def transform_domain(tree: Tree, domain: Optional[Domain] = None) -> Domain:
        """Transform a parse tree into a domain.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        transformer = PDSketchV3DomainTransformer(domain)
        transformer.transform(tree)
        return transformer.domain
@dataclass(init=True, repr=True, eq=True)
class LiteralValue:
    """A single literal value.

    Attributes:
        value: the literal itself -- a bool, int, float, complex, or str.
    """

    value: Union[bool, int, float, complex, str]
@dataclass(init=True, repr=True, eq=True)
class LiteralList:
    """An ordered, tuple-backed list of literal values.

    Attributes:
        items: the literals (bool, int, float, complex, or str), in order.
    """

    items: Tuple[Union[bool, int, float, complex, str], ...]
@dataclass(init=True, repr=True, eq=True)
class LiteralSet:
    """An unordered set of literal values.

    Attributes:
        items: the literals (bool, int, float, complex, or str).
    """

    items: Set[Union[bool, int, float, complex, str]]
@dataclass(init=True, repr=True, eq=True)
class InTypedArgument:
    """An argument written as ``name in value``, as used by forall/exists statements.

    Attributes:
        name: the bound variable name.
        value: whatever appears on the right-hand side of ``in``.
    """

    name: str
    value: Any
@dataclass(init=True, repr=True, eq=True)
class ArgumentsList:
    """The argument values of a call, in order.

    Attributes:
        arguments: each entry is an expression (variable, function application, list expansion, ...) or a raw
            Python literal (bool, int, float, complex, or str).
    """

    arguments: Tuple[Union[E.ValueOutputExpression, E.ListExpansionExpression, E.VariableExpression, bool, int, float, complex, str], ...]
@dataclass
class FunctionCall:
    """Intermediate representation of a call-like construct produced while parsing.

    Besides ordinary function applications, this also represents primitive operators and statements such as
    ``assign``, ``check``, ``if`` or ``forall``; the kind is stored in :attr:`name`.

    Attributes:
        name: the function / operator / statement name.
        args: the arguments of the call.
        annotations: optional ``[[key=value, ...]]`` annotations attached to the call.
    """

    name: str
    args: ArgumentsList
    annotations: Optional[Dict[str, Any]] = None

    def __str__(self):
        prefix = ''
        if self.annotations is not None:
            rendered_annotations = ', '.join(f'{key}={val}' for key, val in self.annotations.items())
            prefix = '[[' + rendered_annotations + ']] '
        rendered_args = [str(a) for a in self.args.arguments]
        # Short argument lists are printed inline; long ones are printed one (indented) argument per line.
        if sum(map(len, rendered_args)) <= 80:
            return f'{prefix}{self.name}(' + ', '.join(rendered_args) + ')'
        return f'{prefix}{self.name}:\n' + '\n'.join(jacinle.indent_text(a) for a in rendered_args)

    def __repr__(self):
        return f'FunctionCall{{{str(self)}}}'
@dataclass
class CSList:
    """A comma-separated list of arbitrary parsed items.

    Attributes:
        items: the items, in source order.
    """

    items: Tuple[Any, ...]

    def __str__(self):
        joined = ', '.join(map(str, self.items))
        return 'CSList(' + joined + ')'

    def __repr__(self):
        return str(self)
@dataclass
class Suite(object):
    """A suite of statements. This is used as the intermediate representation of the parsed expressions.

    The statements are analyzed lazily: the first ``get_*`` call simulates them with a :class:`FunctionCallTracker`
    and caches that tracker in :attr:`tracker`.
    """

    items: Tuple[Any, ...]
    tracker: Optional['FunctionCallTracker'] = None

    def _init_tracker(self, use_runtime_assign: bool = False):
        """Simulate the statements of this suite and cache the resulting tracker, replacing any previous one."""
        self.tracker = FunctionCallTracker(self, dict(), use_runtime_assign=use_runtime_assign).run()

    def get_all_assign_expressions(self) -> List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]:
        """Return the assign expressions of this suite, each paired with its statement annotations."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.assign_expressions

    def get_all_check_expressions(self) -> List[E.ValueOutputExpression]:
        """Return the check expressions of this suite."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.check_expressions

    def get_all_regression_expressions(self, use_runtime_assign=True) -> List[RegressionRuleBodyItemType]:
        """Return the regression rule body items of this suite.

        Args:
            use_runtime_assign: whether assignments to local variables are emitted as runtime-assign expressions.
                If the cached tracker was built with a different setting, the suite is simulated again so that
                the requested setting is honored; the new tracker replaces the cached one.
        """
        # The cached tracker may come from another accessor that used `use_runtime_assign=False`; reusing it
        # would silently drop the runtime assignments requested here.
        if self.tracker is None or self.tracker.use_runtime_assign != use_runtime_assign:
            self._init_tracker(use_runtime_assign=use_runtime_assign)
        return self.tracker.regression_expressions

    def get_all_expr_expression(self, allow_multiple_expressions: bool = False) -> Optional[Union[E.ValueOutputExpression, Tuple[E.ValueOutputExpression, ...]]]:
        """Return the bare expression(s) of this suite.

        Args:
            allow_multiple_expressions: if True, zero or several expressions are returned as a tuple;
                otherwise the suite must contain exactly one expression.

        Returns:
            the expression itself if there is exactly one; otherwise (only when allowed) a tuple of all of them.

        Raises:
            ValueError: if ``allow_multiple_expressions`` is False and the suite does not contain exactly one expression.
        """
        if self.tracker is None:
            self._init_tracker()
        expressions = self.tracker.expr_expressions
        if len(expressions) == 1:
            return expressions[0]
        if not allow_multiple_expressions:
            if len(expressions) == 0:
                raise ValueError(f'No expression found in a single suite: {self}')
            raise ValueError(f'Multiple expressions found in a single suite: {expressions}')
        return tuple(expressions)

    def get_combined_return_expression(self, allow_expr_expressions: bool = False) -> Optional[E.ValueOutputExpression]:
        """Return the combined return expression of this suite.

        Args:
            allow_expr_expressions: if True and the suite has no return statement, fall back to its single bare
                expression (see :meth:`get_all_expr_expression`).

        Returns:
            the return expression, the fallback expression, or None.
        """
        if self.tracker is None:
            self._init_tracker()
        if self.tracker.return_expression is not None:
            return self.tracker.return_expression
        if allow_expr_expressions:
            return self.get_all_expr_expression(allow_multiple_expressions=False)
        return None

    def __str__(self):
        if len(self.items) == 0:
            return 'Suite{}'
        if len(self.items) == 1:
            return f'Suite{{{self.items[0]}}}'
        return 'Suite{\n' + '\n'.join(jacinle.indent_text(str(item)) for item in self.items) + '\n}'

    def __repr__(self):
        return self.__str__()
class FunctionCallTracker(object):
    """This class is used to track the function calls and other statements in a suite. It supports simulating the execution of the program
    and generating the post-condition of the program.
    """

    def __init__(self, suite: Suite, init_local_variables: Optional[Dict[str, Any]] = None, use_runtime_assign: bool = False):
        """Initialize the tracker.

        Args:
            suite: the suite whose statements are simulated by :meth:`run`.
            init_local_variables: the initial local-variable assignments. The dictionary is used as-is (not copied),
                so callers that must not observe updates should pass a copy.
            use_runtime_assign: see :attr:`use_runtime_assign`.
        """
        self.suite = suite
        self.local_variables = dict() if init_local_variables is None else init_local_variables
        self.assign_expressions = list()
        self.assign_expression_signatures = dict()
        self.check_expressions = list()
        self.expr_expressions = list()
        self.regression_expressions = list()
        self.return_expression = None
        self.use_runtime_assign = use_runtime_assign

    local_variables: Dict[str, Any]
    """The assignments of local variables: a mapping from variable names to the expressions they currently denote."""

    assign_expressions: List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]
    """A list of assign expressions, each paired with the annotations of the statement that produced it."""

    assign_expression_signatures: Dict[Tuple[str, ...], E.VariableAssignmentExpression]
    """A dictionary of assign expressions, indexed by their signatures."""

    check_expressions: List[E.ValueOutputExpression]
    """A list of check expressions."""

    expr_expressions: List[E.ValueOutputExpression]
    """A list of expr expressions (i.e. bare expressions in the body)."""

    return_expression: Optional[E.ValueOutputExpression]
    """The return expression. This is either None or a single expression."""

    use_runtime_assign: bool
    """Whether to use runtime assign expressions. If this is set to True, assignment expressions for local variables will be converted into runtime assign expressions."""

    regression_expressions: List[Union[
        OperatorApplicationExpression, FindExpression, AchieveExpression, RuntimeAssignExpression, E.ListExpansionExpression,
        RegressionCommitFlag, RegressionRuleApplicationExpression,
        ConditionalRegressionRuleBodyExpression, LoopRegressionRuleBodyExpression
    ]]
    """The regression rule body items collected from the suite (applications, find/achieve statements, runtime assignments, conditional groups)."""

    def _g(
        self, expr: Union[E.Expression, UnnamedPlaceholder, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression, RuntimeAssignExpression, Implementation]
    ) -> Union[E.Expression, UnnamedPlaceholder, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression, RuntimeAssignExpression, Implementation]:
        """Rewrite ``expr`` with the current local variables.

        Operator, generator, regression-rule and implementation applications (and runtime assignments) are rebuilt
        with each argument rewritten recursively. Unnamed placeholders are returned unchanged. Any other
        :class:`E.Expression` is passed to ``flatten_expression`` together with a mapping from
        ``VariableExpression(Variable(name, None))`` to each local variable's value.

        Raises:
            ValueError: if ``expr`` is none of the supported kinds.
        """
        from concepts.pdsketch.predicate import flatten_expression
        if isinstance(expr, OperatorApplicationExpression):
            return OperatorApplicationExpression(expr.operator, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, GeneratorApplicationExpression):
            return GeneratorApplicationExpression(expr.generator, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, RegressionRuleApplicationExpression):
            return RegressionRuleApplicationExpression(expr.rule, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, RuntimeAssignExpression):
            return RuntimeAssignExpression(expr.variable, self._g(expr.expression))
        if isinstance(expr, Implementation):
            return Implementation(expr.name, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, UnnamedPlaceholder):
            return expr
        if not isinstance(expr, E.Expression):
            raise ValueError(f'Invalid expression: {expr}')
        return flatten_expression(expr, {
            E.VariableExpression(Variable(k, None)): v
            for k, v in self.local_variables.items()
        })

    def _get_deictic_signature(self, e, known_deictic_vars=tuple()) -> Optional[Tuple[str, ...]]:
        """Compute a duplicate-detection signature for an assignment.

        The signature is the assigned function's name followed by its argument names, where arguments bound by an
        enclosing deictic assignment are replaced by ``'?'``. Returns None for anything that is not a (possibly
        deictic-wrapped) :class:`E.AssignExpression`.
        """
        if isinstance(e, E.DeicticAssignExpression):
            known_deictic_vars = known_deictic_vars + (e.variable.name,)
            return self._get_deictic_signature(e.expression, known_deictic_vars)
        elif isinstance(e, E.AssignExpression):
            args = [x.name if x.name not in known_deictic_vars else '?' for x in e.predicate.arguments]
            return tuple((e.predicate.function.name, *args))
        else:
            return None

    def _mark_assign(self, *exprs: E.VariableAssignmentExpression, annotations: Optional[Dict[str, Any]] = None):
        """Record assign expressions, rejecting two assignments with the same signature.

        Raises:
            ValueError: if an expression's signature has already been recorded.
        """
        if annotations is None:
            annotations = dict()
        for expr in exprs:
            signature = self._get_deictic_signature(expr)
            if signature is not None:
                if signature in self.assign_expression_signatures:
                    raise ValueError(f'Duplicate assign expression: {expr} vs {self.assign_expression_signatures[signature]}')
                self.assign_expressions.append((expr, annotations))
                self.assign_expression_signatures[signature] = expr
            else:
                self.assign_expressions.append((expr, annotations))

    @staticmethod
    def _merge_fallthrough_condition(previous: Optional[E.ValueOutputExpression], new: E.ValueOutputExpression) -> E.ValueOutputExpression:
        """Conjoin the fall-through condition of earlier partial returns with that of a new partial return."""
        if previous is None:
            return new
        return E.AndExpression(previous, new)

    @jacinle.log_function(verbose=False)
    def run(self):
        """Simulate the execution of the program and generate an equivalent return statement.

        This function handles if-else conditions. However, loops are not allowed.

        Returns:
            the tracker itself, so that construction and simulation can be chained.

        Raises:
            ValueError: for invalid or unsupported statements.
        """
        # A return that fires only on some paths (inside one branch of an `if`) is "partial". We keep the combined
        # value of the partial returns seen so far, and the condition under which none of them fired (i.e. execution
        # falls through to the following statements). Both stay None until the first partial return is seen.
        # NOTE(review): statements after a partial return are recorded unconditionally, although they only execute
        # on the fall-through path; partial returns that are never followed by a final return are dropped.
        current_return_statement = None
        current_return_statement_condition_neg = None
        for item in self.suite.items:
            assert isinstance(item, FunctionCall), f'Invalid item in suite: {item}'
            # NOTE(review): items whose name is not handled below are silently ignored.
            if item.name == 'assign':
                if isinstance(item.args.arguments[0], E.FunctionApplicationExpression) and item.args.arguments[1] is Ellipsis:
                    self._mark_assign(self._g(E.AssignExpression(item.args.arguments[0], E.NullExpression(item.args.arguments[0].return_type))), annotations=item.annotations)
                elif isinstance(item.args.arguments[0], E.FunctionApplicationExpression) and isinstance(item.args.arguments[1], (E.ValueOutputExpression, E.VariableExpression)):
                    self._mark_assign(self._g(E.AssignExpression(item.args.arguments[0], item.args.arguments[1])), annotations=item.annotations)
                elif isinstance(item.args.arguments[0], E.VariableExpression) and isinstance(item.args.arguments[1], (E.ValueOutputExpression, E.VariableExpression, E.FindAllExpression)):
                    if self.use_runtime_assign:
                        if item.args.arguments[0].name not in self.local_variables:
                            # NOTE(review): this wraps the VariableExpression itself rather than its `.variable`;
                            # confirm that E.VariableExpression accepts an expression here.
                            self.local_variables[item.args.arguments[0].name] = E.VariableExpression(item.args.arguments[0])
                        self.regression_expressions.append(RuntimeAssignExpression(
                            item.args.arguments[0].variable,
                            self._g(item.args.arguments[1])
                        ))
                    else:
                        self.local_variables[item.args.arguments[0].name] = self._g(item.args.arguments[1])
                else:
                    raise ValueError(f'Invalid assignment: {item}. Types: {type(item.args.arguments[0])}, {type(item.args.arguments[1])}.')
            elif item.name == 'check':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid check expression: {item.args.arguments[0]}'
                self.check_expressions.append(self._g(item.args.arguments[0]))
            elif item.name == 'expr':
                if isinstance(item.args.arguments[0], (OperatorApplicationExpression, RegressionRuleApplicationExpression, E.ListExpansionExpression)):
                    self.regression_expressions.append(self._g(item.args.arguments[0]))
                else:
                    assert isinstance(
                        item.args.arguments[0],
                        (E.ValueOutputExpression, E.VariableExpression, GeneratorApplicationExpression, Implementation)
                    ), f'Invalid expr expression: {item.args.arguments[0]}'
                    self.expr_expressions.append(self._g(item.args.arguments[0]))
            elif item.name == 'find':
                arguments = item.args.arguments[0].items
                body: Suite = item.args.arguments[1]
                self.regression_expressions.append(FindExpression(arguments, body.get_combined_return_expression(allow_expr_expressions=True), **item.annotations if item.annotations is not None else dict()))
            elif item.name == 'achieve':
                term = item.args.arguments[0]
                self.regression_expressions.append(AchieveExpression(term, tuple(), **item.annotations if item.annotations is not None else dict()))
                # TODO(Jiayuan Mao @ 2024/03/2): implement achieve-maintain
            elif item.name == 'return':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid return expression: {item.args.arguments[0]}'
                self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, self._g(item.args.arguments[0]))
                break
            elif item.name == 'if':
                condition = self._g(item.args.arguments[0])
                neg_condition = E.NotExpression(condition)
                assert isinstance(condition, E.ValueOutputExpression), f'Invalid condition: {condition}. Type: {type(condition)}.'
                t_suite = item.args.arguments[1]
                f_suite = item.args.arguments[2]
                # NOTE(review): the branch trackers do not inherit `use_runtime_assign`, so local assignments inside a
                # branch are always substituted rather than emitted as runtime assignments. Confirm this is intended.
                t_tracker = FunctionCallTracker(t_suite, self.local_variables.copy()).run()
                f_tracker = FunctionCallTracker(f_suite, self.local_variables.copy()).run()
                assert set(t_tracker.local_variables.keys()) == set(f_tracker.local_variables.keys()), f'Local variables in the true and false branches are not consistent: {t_tracker.local_variables.keys()} vs {f_tracker.local_variables.keys()}'
                new_local_variables = t_tracker.local_variables
                for k, v in t_tracker.local_variables.items():
                    if f_tracker.local_variables[k] != v:
                        new_local_variables[k] = E.ConditionExpression(condition, v, f_tracker.local_variables[k])
                self.local_variables = new_local_variables
                # TODO(Jiayuan Mao @ 2024/03/2): optimize the implementation for this by merging the conditions.
                for expr, annotations in t_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, condition), annotations=annotations)
                for expr, annotations in f_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, neg_condition), annotations=annotations)
                for expr in t_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(condition, expr))
                for expr in f_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(neg_condition, expr))
                if len(t_tracker.expr_expressions) != len(f_tracker.expr_expressions):
                    raise ValueError(f'Number of bare expressions in the true and false branches are not consistent: {len(t_tracker.expr_expressions)} vs {len(f_tracker.expr_expressions)}')
                if len(t_tracker.expr_expressions) == 0:
                    pass
                elif len(t_tracker.expr_expressions) == 1:
                    self.expr_expressions.append(E.ConditionExpression(condition, t_tracker.expr_expressions[0], f_tracker.expr_expressions[0]))
                else:
                    raise ValueError(f'Multiple bare expressions in the true and false branches are not supported: {t_tracker.expr_expressions} vs {f_tracker.expr_expressions}')
                if len(t_tracker.regression_expressions) > 0:
                    self.regression_expressions.append(ConditionalRegressionRuleBodyExpression(condition, t_tracker.regression_expressions))
                if len(f_tracker.regression_expressions) > 0:
                    self.regression_expressions.append(ConditionalRegressionRuleBodyExpression(neg_condition, f_tracker.regression_expressions))
                if t_tracker.return_expression is not None and f_tracker.return_expression is not None:
                    # Both branches have return statements.
                    statement = E.ConditionExpression(condition, t_tracker.return_expression, f_tracker.return_expression)
                    self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, statement)
                    break
                elif t_tracker.return_expression is not None:
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, t_tracker.return_expression)
                    # Execution falls through only if no earlier partial return fired AND this one did not either;
                    # overwriting the condition (instead of conjoining it) would lose the earlier returns.
                    current_return_statement_condition_neg = self._merge_fallthrough_condition(current_return_statement_condition_neg, E.NotExpression(condition))
                elif f_tracker.return_expression is not None:
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, f_tracker.return_expression)
                    current_return_statement_condition_neg = self._merge_fallthrough_condition(current_return_statement_condition_neg, condition)
                else:
                    pass
            elif item.name == 'forall':
                suite = item.args.arguments[1]
                tracker = FunctionCallTracker(suite, self.local_variables.copy()).run()
                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')
                for expr, annotations in tracker.assign_expressions:
                    for var in item.args.arguments[0].items:
                        expr = E.DeicticAssignExpression(var, expr)
                    self._mark_assign(expr, annotations=annotations)
                for expr in tracker.check_expressions:
                    for var in item.args.arguments[0].items:
                        expr = E.ForallExpression(var, expr)
                    self.check_expressions.append(expr)
                if len(tracker.expr_expressions) == 0:
                    pass
                else:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    for var in item.args.arguments[0].items:
                        merged = E.ForallExpression(var, merged)
                    self.expr_expressions.append(merged)
                if len(tracker.regression_expressions) > 0:
                    raise ValueError(f'Regression rules are not allowed in a forall statement: {tracker.regression_expressions}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall statement: {tracker.return_expression}')
            elif item.name == 'forall_in':
                suite = item.args.arguments[0]
                tracker = FunctionCallTracker(suite, self.local_variables.copy()).run()
                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')
                for expr, annotations in tracker.assign_expressions:
                    raise ValueError(f'Assign statements are not allowed in a forall_in statement: {expr}')
                for expr in tracker.check_expressions:
                    self.check_expressions.append(E.AndExpression(expr))
                if len(tracker.expr_expressions) == 0:
                    pass
                else:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    self.expr_expressions.append(merged)
                if len(tracker.regression_expressions) > 0:
                    # TODO(Jiayuan Mao @ 2024/03/12): implement the rest parts of regression statements.
                    for expr in tracker.regression_expressions:
                        if isinstance(expr, AchieveExpression):
                            self.regression_expressions.append(E.ListExpansionExpression(E.AndExpression(expr.goal)))
                        else:
                            raise ValueError(f'Regression rule items except for achieve statements are not allowed in a forall_in statement: {expr}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall_in statement: {tracker.return_expression}')
            elif item.name == 'pass':
                pass
        return self
def _make_conditional_implies(condition: E.ValueOutputExpression, test: E.ValueOutputExpression):
    """Build ``condition => test``, folding ``condition`` into the premise when ``test`` is already an implication.

    ``condition => (p => q)`` becomes ``(condition and p) => q``; when ``p`` is itself a conjunction, its operands
    are spliced in directly, giving ``(condition and p1 and ... and pn) => q``.
    """
    is_implication = isinstance(test, E.BoolExpression) and test.op == E.BoolOpType.IMPLIES
    if not is_implication:
        return E.ImpliesExpression(condition, test)
    premise, conclusion = test.arguments[0], test.arguments[1]
    if isinstance(premise, E.BoolExpression) and premise.op == E.BoolOpType.AND:
        merged_premise = E.AndExpression(condition, *premise.arguments)
    else:
        merged_premise = E.AndExpression(condition, premise)
    return E.ImpliesExpression(merged_premise, conclusion)
def _make_conditional_return(current_stmt: Optional[E.ValueOutputExpression], current_condition_neg: Optional[E.ValueOutputExpression], new_stmt: E.ValueOutputExpression):
    """Combine previously collected (partial) return values with a new return value.

    Returns ``new_stmt`` when there is no earlier return. Otherwise returns the conditional expression
    ``ConditionExpression(current_condition_neg, new_stmt, current_stmt)``: ``new_stmt`` when
    ``current_condition_neg`` holds, ``current_stmt`` otherwise. Callers pass in ``current_condition_neg`` the
    condition under which none of the earlier returns fired.
    """
    return new_stmt if current_stmt is None else E.ConditionExpression(current_condition_neg, new_stmt, current_stmt)
def _make_conditional_assign(assign_stmt: E.VariableAssignmentExpression, condition: E.ValueOutputExpression):
    """Guard an assignment statement by ``condition``.

    * A plain assignment becomes a conditional assignment guarded by ``condition``.
    * A conditional assignment gets ``condition`` conjoined to its existing condition (the operands of an existing
      conjunction are spliced in).
    * A deictic assignment is rebuilt with its inner statement guarded recursively.

    Raises:
        ValueError: for any other kind of statement.
    """
    if isinstance(assign_stmt, E.AssignExpression):
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, condition)
    if isinstance(assign_stmt, E.ConditionalAssignExpression):
        existing = assign_stmt.condition
        if isinstance(existing, E.BoolExpression) and existing.op == E.BoolOpType.AND:
            combined = E.AndExpression(condition, *existing.arguments)
        else:
            combined = E.AndExpression(condition, existing)
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, combined)
    if isinstance(assign_stmt, E.DeicticAssignExpression):
        guarded_inner = _make_conditional_assign(assign_stmt.expression, condition)
        return E.DeicticAssignExpression(assign_stmt.variable, guarded_inner)
    raise ValueError(f'Invalid assign statement: {assign_stmt}')
def _make_conditional_regression(regression_stmt: Union[OperatorApplicationExpression, FindExpression, AchieveExpression, RegressionCommitFlag, RegressionRuleApplicationExpression, ConditionalRegressionRuleBodyExpression], condition: E.ValueOutputExpression):
    """Make a regression-rule body item apply only under ``condition``.

    An existing :class:`ConditionalRegressionRuleBodyExpression` gets ``condition`` conjoined to its own condition
    (splicing the operands of an existing conjunction). Any other item is wrapped into a new conditional whose
    body holds just that item.
    """
    # NB(Jiayuan Mao @ 2024/03/2): Not used. Because the original ConditionalRegressionRuleBodyExpression already supports recursive conditions.
    if not isinstance(regression_stmt, ConditionalRegressionRuleBodyExpression):
        return ConditionalRegressionRuleBodyExpression(condition, (regression_stmt, ))
    inner = regression_stmt.condition
    if isinstance(inner, E.BoolExpression) and inner.op == E.BoolOpType.AND:
        merged = E.AndExpression(condition, *inner.arguments)
    else:
        merged = E.AndExpression(condition, inner)
    return ConditionalRegressionRuleBodyExpression(merged, regression_stmt.body)
def gen_term_expr(expr_typename: str):
    """Generate a term expression function. This function is used to generate the term expression functions for the transformer.

    It is used to generate the following functions:

    - mul_expr
    - arith_expr
    - shift_expr

    Binary operators are not supported yet: the generated function returns its operand unchanged when there is
    exactly one, and raises :class:`NotImplementedError` as soon as an operator is present.

    Args:
        expr_typename: the name of the expression type. This is only used in the error message.

    Returns:
        the generated term expression function.
    """
    @inline_args
    def term(self, *values: Any):
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]
        # TODO: support binary operators. `values` alternates operands and operator tokens
        # (operand, op, operand, ...); fold it left-to-right into nested FunctionCall(op, ArgumentsList((lhs, rhs))).
        raise NotImplementedError(f'{expr_typename} expression is not supported in the current version.')
    return term
def gen_term_expr_noop(expr_typename: str):
    """Generate a term expression function whose children do not include the operator token.

    Because the operator is absent from the parsed children, it has to be given explicitly through
    ``expr_typename``. This is used for the following functions:

    - bitand_expr
    - bitxor_expr
    - bitor_expr

    Args:
        expr_typename: ``'bitand'``, ``'bitxor'`` or ``'bitor'``; selects the boolean operator (AND, XOR, OR).
            The lookup happens when the generated function is called.

    Returns:
        the generated term expression function. It returns its single operand unchanged; otherwise it combines
        all operands into one (n-ary) :class:`E.BoolExpression`.
    """
    op_mapping = {
        'bitand': E.BoolOpType.AND,
        'bitxor': E.BoolOpType.XOR,
        'bitor': E.BoolOpType.OR,
    }

    @inline_args
    def term(self, *values: Any):
        operands = [self.visit(v) for v in values]
        if len(operands) == 1:
            return operands[0]
        return E.BoolExpression(op_mapping[expr_typename], _canonize_arguments(operands))

    return term
def _canonize_single_argument(arg: Any, dtype: Optional[TypeBase] = None) -> Union[E.ValueOutputExpression, E.VariableExpression, type(Ellipsis)]:
    """Convert one parsed argument into an expression.

    * Expressions and operator/generator/rule/implementation applications are returned unchanged.
    * A raw :class:`Variable` is wrapped into an :class:`E.VariableExpression`.
    * A bool/int/float/complex literal becomes a constant via ``E.ConstantExpression.from_value(arg, dtype=dtype)``.
    * ``...`` is passed through.

    Raises:
        ValueError: if ``arg`` is of any other kind.
    """
    passthrough_types = (E.ObjectOrValueOutputExpression, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression, Implementation)
    if isinstance(arg, passthrough_types):
        return arg
    elif isinstance(arg, Variable):
        return E.VariableExpression(arg)
    elif isinstance(arg, (bool, int, float, complex)):
        return E.ConstantExpression.from_value(arg, dtype=dtype)
    elif arg is Ellipsis:
        return Ellipsis
    else:
        raise ValueError(f'Invalid argument: {arg}')
def _canonize_arguments(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtypes: Optional[Sequence[TypeBase]] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]:
    """Canonize a sequence of parsed arguments with :func:`_canonize_single_argument`.

    Args:
        args: an :class:`ArgumentsList` or a plain tuple/list of arguments; None yields an empty tuple.
        dtypes: optional per-position types, forwarded as the ``dtype`` of each non-ellipsis argument.

    Returns:
        a tuple of canonized arguments; ``...`` entries are kept as ``Ellipsis``.
    """
    if args is None:
        return tuple()
    raw_arguments = args.arguments if isinstance(args, ArgumentsList) else args
    result = []
    # TODO(Jiayuan Mao @ 2024/03/2): Strictly check the allowability of "Ellipsis" in the arguments.
    for index, item in enumerate(raw_arguments):
        if item is Ellipsis:
            # Ellipsis entries never consult `dtypes`.
            result.append(Ellipsis)
            continue
        item_dtype = None if dtypes is None else dtypes[index]
        result.append(_canonize_single_argument(item, dtype=item_dtype))
    return tuple(result)
def _has_list_arguments(args: Tuple[E.ObjectOrValueOutputExpression, ...]) -> bool:
    """Return True if at least one argument's ``return_type`` is a list type."""
    return any(arg.return_type.is_list_type for arg in args)
[docs]class PDSketch3ExpressionInterpreter(Interpreter):
"""The transformer for expressions. Including:
- typename
- sized_vector_typename
- unsized_vector_typename
- typed_argument
- is_typed_argument
- in_typed_argument
- arguments_def
- atom_expr_funccall
- atom_varname
- atom
- power
- factor
- unary_op_expr
- mul_expr
- arith_expr
- shift_expr
- bitand_expr
- bitxor_expr
- bitor_expr
- comparison_expr
- not_test
- and_test
- or_test
- cond_test
- test
- test_nocond
- tuple
- list
- cs_list
- suite
- expr_stmt
- expr_list_expansion_stmt
- assign_stmt
- annotated_assign_stmt
- local_assign_stmt
- pass_stmt
- check_stmt
- return_stmt
- achieve_stmt
- compound_assign_stmt
- compound_check_stmt
- compound_return_stmt
- compound_achieve_stmt
- if_stmt
- forall_stmt
- forall_test
- exists_test
- findall_test
- forall_in_test
- exists_in_test
"""
def __init__(self, domain: Optional[Domain], state: Optional[State], expression_def_ctx: E.ExpressionDefinitionContext, auto_constant_guess: bool = False):
    """Create the expression interpreter.

    Args:
        domain: the domain in which function, operator and type names are looked up.
        state: an optional state; its object names are recognized as object constants.
        expression_def_ctx: the expression definition context used to resolve and wrap variables.
        auto_constant_guess: whether an unknown identifier is read as an object constant of :data:`AutoType`.
    """
    super().__init__()
    self.domain, self.state = domain, state
    self.expression_def_ctx = expression_def_ctx
    self.auto_constant_guess = auto_constant_guess
    # Output variables of the generator implementation being parsed, if any (see set_generator_impl_outputs).
    self.generator_impl_outputs = None
    # Name -> expression bindings consulted first when resolving identifiers (see atom_varname).
    self.local_variables = {}
def set_domain(self, domain: Domain):
    """Replace the domain used to resolve function, operator and type names."""
    self.domain = domain
@contextlib.contextmanager
def local_variable_guard(self):
    """Snapshot :attr:`local_variables` and restore it when the ``with`` block exits.

    The restore also runs when the block raises, so bindings created inside the block never leak into the
    enclosing scope.
    """
    backup = self.local_variables.copy()
    try:
        yield
    finally:
        # Without `finally`, an exception raised inside the `with` body would skip this restore.
        self.local_variables = backup
@contextlib.contextmanager
def set_generator_impl_outputs(self, outputs: List[Variable]):
    """Temporarily set :attr:`generator_impl_outputs` to ``outputs`` for the duration of a ``with`` block.

    The previous value is restored on exit, including when the block raises.

    Args:
        outputs: the output variables of the generator implementation being parsed.
    """
    backup = self.generator_impl_outputs
    self.generator_impl_outputs = outputs
    try:
        yield
    finally:
        # Without `finally`, an exception raised inside the `with` body would leave `outputs` installed.
        self.generator_impl_outputs = backup
domain: Domain
expression_def_ctx: E.ExpressionDefinitionContext
[docs] def visit(self, tree: Any) -> Any:
if isinstance(tree, Tree):
return super().visit(tree)
return tree
[docs] @inline_args
def atom_varname(self, name: str) -> Union[E.VariableExpression, E.ObjectConstantExpression, E.ValueOutputExpression]:
"""Captures variable names such as `var_name`."""
if name in self.local_variables:
return self.local_variables[name]
if self.state is not None and name in self.state.object_name2defaultindex:
return E.ObjectConstantExpression(ObjectConstant(name, self.domain.types[self.state.get_typename(name)]))
# TODO(Jiayuan Mao @ 2024/03/12): smartly guess the type of the variable.
if not self.expression_def_ctx.has_variable(name):
if self.auto_constant_guess:
return E.ObjectConstantExpression(ObjectConstant(name, AutoType))
variable = self.expression_def_ctx.wrap_variable(name)
return variable
[docs] @inline_args
def atom_expr_funccall(self, annotations: dict, name: str, args: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression, Implementation, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression]:
"""Captures function calls, such as `func_name(arg1, arg2, ...)`."""
annotations: Optional[dict] = self.visit(annotations)
args: Optional[ArgumentsList] = self.visit(args)
if annotations is None:
annotations = dict()
if args is None:
args = ArgumentsList(tuple())
if self.domain.has_function(name):
function = self.domain.get_function(name)
args_c = _canonize_arguments(args, function.ftype.argument_types)
if _has_list_arguments(args_c):
return E.ListFunctionApplicationExpression(function, args_c)
return E.FunctionApplicationExpression(function, args_c)
elif self.domain.has_operator(name):
operator = self.domain.get_operator(name)
args_c = _canonize_arguments(args, operator.argument_types)
if len(args_c) > 0 and args_c[-1] is Ellipsis:
args_c = args_c[:-1] + tuple([UnnamedPlaceholder(t) for t in operator.argument_types[len(args_c) - 1:]])
return OperatorApplicationExpression(operator, args_c)
elif self.domain.has_generator(name):
generator = self.domain.get_generator(name)
args_c = _canonize_arguments(args, generator.argument_types)
if isinstance(generator, FancyGenerator):
raise NotImplementedError('Fancy generators are not supported in the current version.')
return GeneratorApplicationExpression(generator, args_c)
elif self.domain.has_regression_rule(name):
rule = self.domain.get_regression_rule(name)
args_c = _canonize_arguments(args, rule.argument_types)
return RegressionRuleApplicationExpression(rule, args_c)
else:
if 'action_impl' in annotations and annotations['action_impl']:
args_c = _canonize_arguments(args)
argument_types = [arg.return_type for arg in args_c]
# logger.warning(f'Controller function {name} not found, creating a new one with argument types {argument_types}.')
predicate = self.domain.define_predicate(name, argument_types, self.domain.get_type('__control__'), observation=False, state=False)
return Implementation(predicate.name, args_c)
elif 'generator_impl' in annotations and annotations['generator_impl']:
args_c = _canonize_arguments(args)
argument_types = [arg.return_type for arg in args_c]
# logger.warning(f'Generator function {name} not found, creating a new one with argument types {argument_types}.')
assert self.generator_impl_outputs is not None, f'Generator implementation {name} requires generator outputs to be set.'
generator_outputs = TupleType([value.dtype for value in self.generator_impl_outputs])
predicate = self.domain.define_predicate(name, argument_types, generator_outputs, observation=False, state=False, is_generator_function=True)
return Implementation(predicate.name, args_c)
elif 'regression_impl' in annotations and annotations['regression_impl']:
args_c = _canonize_arguments(args)
argument_types = [arg.return_type for arg in args_c]
# logger.warning(f'Regression function {name} not found, creating a new one with argument types {argument_types}.')
predicate = self.domain.define_predicate(name, argument_types, self.domain.get_type('__totally_ordered_plan__'), observation=False, state=False)
return E.FunctionApplicationExpression(predicate, args_c)
elif 'inplace_generator' in annotations and annotations['inplace_generator']:
args_c = _canonize_arguments(args)
argument_types = [arg.return_type for arg in args_c]
# logger.warning(f'Generator placeholder function {name} not found, creating a new one with argument types {argument_types}.')
predicate = self.domain.define_predicate(
name, argument_types, self.domain.get_type('bool'), observation=False, state=False,
generator_placeholder=annotations.get('generator_placeholder', True)
)
assert 'inplace_generator_targets' in annotations, f'Inplace generator {name} requires inplace generator targets to be set.'
inplace_generator_targets = annotations['inplace_generator_targets']
generator_name = 'gen_' + name
generator_arguments = predicate.arguments
generator_goal = E.FunctionApplicationExpression(predicate, [E.VariableExpression(arg) for arg in generator_arguments])
output_argument_names = [x.value for x in inplace_generator_targets.items]
output_indices = list()
for i, arg in enumerate(args_c):
if isinstance(arg, E.VariableExpression) and arg.name in output_argument_names:
output_argument_names.remove(arg.name)
output_indices.append(i)
assert len(output_argument_names) == 0, f'Mismatched output arguments for inplace generator {name}: {output_argument_names}'
inputs = [arg for i, arg in enumerate(generator_arguments) if i not in output_indices]
outputs = [arg for i, arg in enumerate(generator_arguments) if i in output_indices]
impl = Implementation(generator_name + '_impl', inputs)
self.domain.define_generator(generator_name, generator_arguments, generator_goal, inputs, outputs, impl)
return E.FunctionApplicationExpression(predicate, args_c)
else:
raise KeyError(f'Function {name} not found. Note that recursive function calls are not supported in the current version.')
[docs] @inline_args
def atom_subscript(self, name: str, index: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression]:
"""Captures subscript expressions such as `name[index1, index2, ...]`."""
predicate = self.domain.get_predicate(name)
index: CSList = self.visit(index)
if not predicate.is_state_variable:
raise ValueError(f'Invalid subscript expression: {name} is not a state variable. Expression: {name}[{index.items}]')
items = index.items
if len(items) == 1 and items[0] is Ellipsis:
return E.FunctionApplicationExpression(predicate, tuple())
arguments = _canonize_arguments(index.items, dtypes=predicate.ftype.argument_types)
if _has_list_arguments(arguments):
return E.ListFunctionApplicationExpression(predicate, arguments)
return E.FunctionApplicationExpression(predicate, arguments)
[docs] @inline_args
def atom(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
"""Captures atoms. This is used to Captures the base case of the expression, including literal constants, variables, and subscript expressions."""
return value
[docs] def arguments(self, args: Tree) -> ArgumentsList:
"""Captures the argument list. This is used in function calls."""
args = self.visit_children(args)
return ArgumentsList(tuple(args))
[docs] @inline_args
def power(self, base: Union[FunctionCall, Variable], exp: Optional[float] = None) -> Union[FunctionCall, Variable]:
"""The highest-priority expression. This is used to capture the power expression, such as `base ** exp`. If `exp` is None, it is treated as `base ** 1`."""
if exp is None:
return base
raise NotImplementedError('Power expression is not supported in the current version.')
[docs] @inline_args
def factor(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
return value
[docs] @inline_args
def unary_op_expr(self, op: str, value: Union[FunctionCall, Variable]) -> FunctionCall:
raise NotImplementedError('Unary operators are not supported in the current version.')
mul_expr = gen_term_expr('mul')
arith_expr = gen_term_expr('add')
shift_expr = gen_term_expr('shift')
bitand_expr = gen_term_expr_noop('bitand')
bitxor_expr = gen_term_expr_noop('bitxor')
bitor_expr = gen_term_expr_noop('bitor')
[docs] @inline_args
def comparison_expr(self, *values: Union[E.ValueOutputExpression, E.VariableExpression]) -> E.ValueOutputExpression:
if len(values) == 1:
return self.visit(values[0])
assert len(values) % 2 == 1, f'[compare] expressions expected an odd number of values, got {len(values)}. Values: {values}.'
values = [self.visit(value) for value in values]
results = list()
for i in range(1, len(values), 2):
if values[i - 1].return_type.is_object_type and values[i + 1].return_type.is_object_type:
results.append(E.ObjectCompareExpression(E.CompareOpType.from_string(values[i][0].value), values[i - 1], values[i + 1]))
if len(results) == 1:
return results[0]
result = E.AndExpression(*results)
return result
[docs] @inline_args
def not_test(self, value: Any) -> E.NotExpression:
return E.NotExpression(*_canonize_arguments([self.visit(value)]))
[docs] @inline_args
def and_test(self, *values: Any) -> E.AndExpression:
values = [self.visit(value) for value in values]
if len(values) == 1:
return values[0]
result = E.AndExpression(*_canonize_arguments(values))
return result
[docs] @inline_args
def or_test(self, *values: Any) -> E.OrExpression:
values = [self.visit(value) for value in values]
if len(values) == 1:
return values[0]
result = E.OrExpression(*_canonize_arguments(values))
return result
[docs] @inline_args
def cond_test(self, value1: Any, cond: Any, value2: Any) -> E.ConditionExpression:
return E.ConditionExpression(*_canonize_arguments([self.visit(cond), self.visit(value1), self.visit(value2)]))
[docs] @inline_args
def test(self, value: Any):
return self.visit(value)
[docs] @inline_args
def test_nocond(self, value: Any):
return self.visit(value)
[docs] @inline_args
def tuple(self, *values: Any):
return tuple(self.visit(v) for v in values)
[docs] @inline_args
def list(self, *values: Any):
return E.ListCreationExpression(_canonize_arguments([self.visit(v) for v in values]))
[docs] @inline_args
def cs_list(self, *values: Any):
return CSList(tuple(self.visit(v) for v in values))
[docs] @inline_args
def suite(self, *values: Tree) -> Suite:
with self.local_variable_guard():
values = [self.visit(value) for value in values]
return Suite(tuple(v for v in values if v is not None))
[docs] @inline_args
def expr_stmt(self, value: Tree):
value = self.visit(value)
if value is Ellipsis:
return None
return FunctionCall('expr', ArgumentsList((_canonize_single_argument(value),)))
[docs] @inline_args
def expr_list_expansion_stmt(self, value: Any):
value = _canonize_single_argument(self.visit(value))
return FunctionCall('expr', ArgumentsList((E.ListExpansionExpression(value), )))
[docs] @inline_args
def assign_stmt(self, target: Any, value: Any):
return FunctionCall('assign', ArgumentsList((_canonize_single_argument(self.visit(target)), _canonize_single_argument(self.visit(value)))))
[docs] @inline_args
def annotated_assign_stmt(self, annotations: dict, target: Any, value: Any):
return FunctionCall('assign', ArgumentsList((_canonize_single_argument(self.visit(target)), _canonize_single_argument(self.visit(value)))), annotations)
[docs] @inline_args
def local_assign_stmt(self, target: str, value: Any = None):
assert isinstance(target, str), f'Invalid local variable name: {target}'
value = _canonize_single_argument(self.visit(value))
self.local_variables[target] = E.VariableExpression(Variable(target, value.return_type))
if value is not None:
return FunctionCall('assign', ArgumentsList((self.local_variables[target], value)))
return None
[docs] @inline_args
def pass_stmt(self):
return FunctionCall('pass', ArgumentsList(tuple()))
[docs] @inline_args
def check_stmt(self, value: Any):
return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args
def return_stmt(self, value: Any):
return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args
def achieve_stmt(self, value: CSList):
return FunctionCall('achieve', ArgumentsList(_canonize_arguments(self.visit(value).items)))
[docs] @inline_args
def compound_assign_stmt(self, target: Variable, value: Any):
return FunctionCall('assign', ArgumentsList((self.visit(target), _canonize_single_argument(self.visit(value)))))
[docs] @inline_args
def compound_check_stmt(self, value: Any):
return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args
def compound_return_stmt(self, value: Any):
return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args
def compound_achieve_stmt(self, value: Any):
return FunctionCall('achieve', ArgumentsList(_canonize_arguments([self.visit(value)])))
[docs] @inline_args
def if_stmt(self, cond: Any, suite: Any, else_suite: Optional[Any] = None):
cond = _canonize_single_argument(self.visit(cond))
with self.local_variable_guard():
suite = self.visit(suite)
if else_suite is None:
else_suite = Suite((FunctionCall('pass', ArgumentsList(tuple())), ))
else:
with self.local_variable_guard():
else_suite = self.visit(else_suite)
return FunctionCall('if', ArgumentsList((cond, suite, else_suite)))
[docs] @inline_args
def forall_stmt(self, cs_list: Any, suite: Any):
cs_list = self.visit(cs_list)
with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
suite = self.visit(suite)
return FunctionCall('forall', ArgumentsList((cs_list, suite)))
[docs] @inline_args
def forall_in_stmt(self, variables_cs_list: Any, values_cs_list: Any, suite: Any):
variables_cs_list = self.visit(variables_cs_list)
values_cs_list = self.visit(values_cs_list)
with self.local_variable_guard():
for variable_item, value_item in zip(variables_cs_list.items, values_cs_list.items):
self.local_variables[variable_item] = value_item
suite = self.visit(suite)
return FunctionCall('forall_in', ArgumentsList((suite, )))
def _quantification_expression(self, cs_list: Any, suite: Any, quantification_cls):
cs_list = self.visit(cs_list)
with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
suite = self.visit(suite)
body = suite.get_combined_return_expression(allow_expr_expressions=True)
for item in reversed(cs_list.items):
body = quantification_cls(item, body)
return body
[docs] @inline_args
def forall_test(self, cs_list: Any, suite: Any):
return self._quantification_expression(cs_list, suite, E.ForallExpression)
[docs] @inline_args
def exists_test(self, cs_list: Any, suite: Any):
return self._quantification_expression(cs_list, suite, E.ExistsExpression)
[docs] @inline_args
def findall_test(self, variable: Variable, suite: Any):
with self.expression_def_ctx.new_variables(variable), self.local_variable_guard():
suite = self.visit(suite)
body = suite.get_combined_return_expression(allow_expr_expressions=True)
return E.FindAllExpression(variable, body)
def _quantification_in_expression(self, cs_list: Any, suite: Any, quantification_cls):
cs_list = self.visit(cs_list)
with self.local_variable_guard():
item: InTypedArgument
for item in cs_list.items:
self.local_variables[item.name] = self.visit(item.value)
suite = self.visit(suite)
body = suite.get_combined_return_expression(allow_expr_expressions=True)
return quantification_cls(body)
[docs] @inline_args
def forall_in_test(self, cs_list: Any, suite: Any):
return self._quantification_in_expression(cs_list, suite, E.AndExpression)
[docs] @inline_args
def exists_in_test(self, cs_list: Any, suite: Any):
return self._quantification_in_expression(cs_list, suite, E.OrExpression)
[docs] @inline_args
def find_stmt(self, cs_list: Any, suite: Any):
cs_list = self.visit(cs_list)
with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
suite = self.visit(suite)
body = suite.get_combined_return_expression()
for item in cs_list.items:
self.local_variables[item.name] = E.VariableExpression(item)
return FunctionCall('find', ArgumentsList((cs_list, suite)))
[docs] @inline_args
def annotated_compound_stmt(self, annotations: dict, stmt: Any):
stmt = self.visit(stmt)
stmt.annotations = annotations
return stmt
@dataclass
class ArgumentsDef:
    """The argument list declared in a definition header."""

    # Argument variables, in declaration order.
    arguments: Tuple[Variable, ...]
@dataclass
class PreconditionPart:
    """Un-interpreted precondition section of an action, generator, or regression rule."""

    # Raw parse tree; interpreted later by the domain transformer.
    suite: Tree
@dataclass
class EffectPart:
    """Un-interpreted effect section of an action definition."""

    # Raw parse tree; interpreted later by the domain transformer.
    suite: Tree
@dataclass
class GoalPart:
    """Un-interpreted goal section of a generator or regression rule."""

    # Raw parse tree; interpreted later by the domain transformer.
    suite: Tree
@dataclass
class BodyPart:
    """Un-interpreted body section of a regression rule."""

    # Raw parse tree; interpreted later by the domain transformer.
    suite: Tree
@dataclass
class SideEffectPart:
    """Un-interpreted side-effect section of a regression rule."""

    # Raw parse tree; interpreted later by the domain transformer.
    suite: Tree
@dataclass
class ImplPart:
    """Un-interpreted implementation section of an action or generator."""

    # Raw parse tree; interpreted later by the domain transformer.
    suite: Tree
@dataclass
class InPart:
    """Un-interpreted input list of a generator definition."""

    # Raw parse tree of the input values; interpreted later by the domain transformer.
    suite: Tree
@dataclass
class OutPart:
    """Un-interpreted output list of a generator definition."""

    # Raw parse tree of the output values; interpreted later by the domain transformer.
    suite: Tree
class PDSketchV3DomainTransformer(PDSketch3LiteralTransformer):
    """Transformer that builds a ``Domain`` from a parsed PDSketch v3 domain definition.

    Literal rules are inherited from ``PDSketch3LiteralTransformer``. Definition bodies (suites) are interpreted by a
    ``PDSketch3ExpressionInterpreter`` sharing this transformer's domain and expression-definition context.
    """

    def __init__(self, domain: Optional[Domain] = None):
        """Initialize the transformer.

        Args:
            domain: an existing domain to extend. If None, a new ``Domain(pdsketch_version=3)`` is created.
        """
        super().__init__()
        self.domain = Domain(pdsketch_version=3) if domain is None else domain
        self.expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
        self.expression_interpreter = PDSketch3ExpressionInterpreter(domain=self.domain, state=None, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=False)

    domain: Domain
    expression_def_ctx: E.ExpressionDefinitionContext
    expression_interpreter: PDSketch3ExpressionInterpreter

    @staticmethod
    def _normalize_header(annotations: Optional[dict], args: Optional[ArgumentsDef]) -> Tuple[dict, ArgumentsDef]:
        """Replace a missing annotation dict with ``{}`` and a missing argument list with an empty ``ArgumentsDef``."""
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        return annotations, args

    def _resolve_return_type(self, ret: Optional[Union[str, TypeBase]]) -> TypeBase:
        """Resolve a declared return type: ``None`` means ``bool``, a string is looked up in the domain."""
        if ret is None:
            return self.domain.get_type('bool')
        if isinstance(ret, str):
            return self.domain.get_type(ret)
        return ret

    @inline_args
    def pragma_definition(self, pragma: Dict[str, Any]):
        """Handle a pragma; currently it is only printed."""
        print('pragma_definition', pragma)

    @inline_args
    def type_definition(self, typename, basetype: Optional[Union[str, TypeBase]]):
        """Define a new type in the domain."""
        print(f'type_definition:: {typename=} {basetype=}')
        self.domain.define_type(typename, basetype)

    @inline_args
    def derived_feature_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
        """Define a feature: a plain predicate when there is no return expression, otherwise a derived predicate.

        The return type defaults to ``bool``.
        """
        annotations, args = self._normalize_header(annotations, args)
        ret = self._resolve_return_type(ret)
        return_stmt = None
        if suite is not None:
            with self.expression_def_ctx.with_variables(*args.arguments):
                suite = self.expression_interpreter.visit(suite)
                return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False)
        if return_stmt is None:
            self.domain.define_predicate(name, args.arguments, ret, **annotations)
        else:
            self.domain.define_derived(name, args.arguments, ret, return_stmt, **annotations)
        print(f'derived_feature_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
        if return_stmt is not None:
            print(jacinle.indent_text(f'Return statement: {return_stmt}'))

    @inline_args
    def derived_function_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Tree):
        """Define a non-state function: a plain predicate without a return expression, otherwise a derived one.

        The return type defaults to ``bool``.
        """
        annotations, args = self._normalize_header(annotations, args)
        ret = self._resolve_return_type(ret)
        with self.expression_def_ctx.with_variables(*args.arguments):
            suite = self.expression_interpreter.visit(suite)
            return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False)
        if return_stmt is None:
            self.domain.define_predicate(name, args.arguments, ret, state=False, observation=False, **annotations)
        else:
            self.domain.define_derived(name, args.arguments, ret, return_stmt, state=False, **annotations)
        print(f'derived_function_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
        if return_stmt is not None:
            print(jacinle.indent_text(f'Return statement: {return_stmt}'))

    @inline_args
    def action_precondition_definition(self, suite: Tree) -> PreconditionPart:
        """Wrap an action's precondition suite for later interpretation."""
        return PreconditionPart(suite)

    @inline_args
    def action_effect_definition(self, suite: Tree) -> EffectPart:
        """Wrap an action's effect suite for later interpretation."""
        return EffectPart(suite)

    @inline_args
    def action_impl_definition(self, suite: Tree) -> ImplPart:
        """Wrap an action's implementation suite for later interpretation."""
        return ImplPart(suite)

    @inline_args
    def action_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, EffectPart, ImplPart]):
        """Define an operator from its precondition, effect, and implementation parts.

        Preconditions come from the check statements of the precondition suite, effects from the assignments of the
        effect suite, and the controller from the single expression of the implementation suite.
        """
        annotations, args = self._normalize_header(annotations, args)
        print(f'action_definition:: {name=} {args.arguments=} {annotations=}')
        precondition = list()
        effect = list()
        implementation = None
        for part in parts:
            with self.expression_def_ctx.with_variables(*args.arguments):
                suite = self.expression_interpreter.visit(part.suite)
                if isinstance(part, PreconditionPart):
                    precondition = [Precondition(x) for x in suite.get_all_check_expressions()]
                    print(jacinle.indent_text(f'Precondition: {precondition}'))
                elif isinstance(part, EffectPart):
                    effect = [Effect(x, **a) for x, a in suite.get_all_assign_expressions()]
                    print(jacinle.indent_text(f'Effect: {effect}'))
                elif isinstance(part, ImplPart):
                    # TODO(Jiayuan Mao @ 2024/03/2): For now we just allow a single expression.
                    implementation = suite.get_all_expr_expression(allow_multiple_expressions=False)
                    print(jacinle.indent_text(f'Implementation: {implementation}'))
        self.domain.define_operator(name, args.arguments, precondition, effect, controller=implementation, **annotations)

    @inline_args
    def generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, InPart, OutPart]):
        """Define a generator from its goal, input, output, and implementation parts.

        Implementation parts are interpreted last, because implementation calls need the output variables.

        Raises:
            ValueError: for any part other than goal, in, out, or impl.
        """
        annotations, args = self._normalize_header(annotations, args)
        print(f'generator_definition:: {name=} {args.arguments=} {annotations=}')
        inputs = list()
        outputs = list()
        goal = None
        implementation = None
        with self.expression_def_ctx.with_variables(*args.arguments):
            for part in parts:
                if isinstance(part, ImplPart):
                    continue
                suite = self.expression_interpreter.visit(part.suite)
                if isinstance(part, GoalPart):
                    goal = suite.get_all_expr_expression(allow_multiple_expressions=False)
                    print(jacinle.indent_text(f'Goal: {goal}'))
                elif isinstance(part, InPart):
                    inputs = [E.VariableExpression(x) for x in suite.items]
                    print(jacinle.indent_text(f'Inputs: {inputs}'))
                elif isinstance(part, OutPart):
                    outputs = [E.VariableExpression(x) for x in suite.items]
                    print(jacinle.indent_text(f'Outputs: {outputs}'))
                else:
                    raise ValueError(f'Invalid part: {part}')
            for part in parts:
                if isinstance(part, ImplPart):
                    with self.expression_interpreter.set_generator_impl_outputs(outputs):
                        suite = self.expression_interpreter.visit(part.suite)
                    implementation = suite.get_all_expr_expression(allow_multiple_expressions=False)
                    print(jacinle.indent_text(f'Implementation: {implementation}'))
        self.domain.define_generator(name, args.arguments, goal, inputs, outputs, implementation=implementation, **annotations)

    @inline_args
    def undirected_generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart]):
        """Parse an undirected generator. Its parts are interpreted and printed; nothing is added to the domain yet."""
        annotations, args = self._normalize_header(annotations, args)
        print(f'undirected_generator_definition:: {name=} {args.arguments=} {annotations=}')
        # TODO(Jiayuan Mao @ 2024/03/10): Implement the undirected generator definition.
        for part in parts:
            with self.expression_def_ctx.with_variables(*args.arguments), self.expression_interpreter.set_generator_impl_outputs([]):
                suite = self.expression_interpreter.visit(part.suite)
            print(jacinle.indent_text(f'{part.__class__.__name__}: ' + str(suite)))

    @inline_args
    def generator_precondition_definition(self, suite: Tree) -> PreconditionPart:
        """Wrap a generator's precondition suite for later interpretation."""
        return PreconditionPart(suite)

    @inline_args
    def generator_goal_definition(self, suite: Tree) -> GoalPart:
        """Wrap a generator's goal suite for later interpretation."""
        return GoalPart(suite)

    @inline_args
    def generator_in_definition(self, values: Tree) -> InPart:
        """Wrap a generator's input list for later interpretation."""
        return InPart(values)

    @inline_args
    def generator_out_definition(self, values: Tree) -> OutPart:
        """Wrap a generator's output list for later interpretation."""
        return OutPart(values)

    @inline_args
    def generator_impl_definition(self, suite: Tree) -> ImplPart:
        """Wrap a generator's implementation suite for later interpretation."""
        return ImplPart(suite)

    @inline_args
    def regression_rule_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, BodyPart]):
        """Define a regression rule from its precondition, goal, body, and side-effect parts.

        Raises:
            NotImplementedError: if the goal suite does not contain exactly one item.
            ValueError: for any unknown part.
        """
        annotations, args = self._normalize_header(annotations, args)
        print(f'regression_rule_definition:: {name=} {args.arguments=} {annotations=}')
        precondition = list()
        goal = None
        body = list()
        side_effect = list()
        for part in parts:
            with self.expression_def_ctx.with_variables(*args.arguments):
                suite = self.expression_interpreter.visit(part.suite)
                if isinstance(part, PreconditionPart):
                    precondition = [Precondition(x) for x in suite.get_all_check_expressions()]
                    print(jacinle.indent_text(f'Precondition: {precondition}'))
                elif isinstance(part, GoalPart):
                    if len(suite.items) != 1:
                        raise NotImplementedError('Multiple goals are not supported in the current version.')
                    goal = suite.items[0]
                    print(jacinle.indent_text(f'Goal: {goal}'))
                elif isinstance(part, BodyPart):
                    body = suite.get_all_regression_expressions(use_runtime_assign=True)
                    print(jacinle.indent_text(f'Body:'))
                    for x in body:
                        print(jacinle.indent_text(str(x), level=2))
                elif isinstance(part, SideEffectPart):
                    side_effect = [Effect(x, **a) for x, a in suite.get_all_assign_expressions()]
                    print(jacinle.indent_text(f'SideEffect: {side_effect}'))
                else:
                    raise ValueError(f'Invalid part: {part}')
        self.domain.define_regression_rule(name, args.arguments, precondition, goal, side_effect, body, **annotations)

    @inline_args
    def regression_rule_precondition_definition(self, suite: Tree) -> PreconditionPart:
        """Wrap a regression rule's precondition suite for later interpretation."""
        return PreconditionPart(suite)

    @inline_args
    def regression_rule_goal_definition(self, suite: Tree) -> GoalPart:
        """Wrap a regression rule's goal suite for later interpretation."""
        return GoalPart(suite)

    @inline_args
    def regression_rule_body_definition(self, suite: Tree) -> BodyPart:
        """Wrap a regression rule's body suite for later interpretation."""
        return BodyPart(suite)

    @inline_args
    def regression_rule_side_effect_definition(self, suite: Tree) -> SideEffectPart:
        """Wrap a regression rule's side-effect suite for later interpretation."""
        return SideEffectPart(suite)
# Module-level parser instance shared by the load_*3 / parse_expression3 helpers in this module.
_parser = PDSketchV3Parser()
def load_domain_file3(filename: str) -> Domain:
    """Parse a PDSketch v3 domain file.

    Args:
        filename: path of the domain file to parse.

    Returns:
        the resulting domain object.
    """
    domain = _parser.parse_domain(filename)
    return domain
def load_domain_string3(string: str) -> Domain:
    """Parse a PDSketch v3 domain definition given as text.

    Args:
        string: source text of the domain definition.

    Returns:
        the resulting domain object.
    """
    domain = _parser.parse_domain_str(string)
    return domain
def load_domain_string3_incremental(domain: Domain, string: str) -> Domain:
    """Parse additional PDSketch v3 definitions into an existing domain.

    Args:
        domain: the domain to extend.
        string: source text of the additional definitions.

    Returns:
        the resulting domain object.
    """
    updated = _parser.parse_domain_str(string, domain=domain)
    return updated
def load_problem_file3(filename: str, domain: Optional[Domain] = None) -> Problem3:
    """Parse a PDSketch v3 problem file.

    Args:
        filename: path of the problem file to parse.
        domain: the domain the problem refers to. When omitted, the domain file named inside the problem is loaded.

    Returns:
        the resulting problem object.
    """
    problem = _parser.parse_problem(filename, domain=domain)
    return problem
def load_problem_string3(string: str, domain: Optional[Domain] = None) -> Problem3:
    """Parse a PDSketch v3 problem definition given as text.

    Args:
        string: source text of the problem definition.
        domain: the domain the problem refers to. When omitted, the domain file named inside the problem is loaded.

    Returns:
        the resulting problem object.
    """
    problem = _parser.parse_problem_str(string, domain=domain)
    return problem
def parse_expression3(domain: Domain, string: str, state: Optional[State] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
    """Parse a single PDSketch v3 expression.

    Args:
        domain: the domain used to resolve names in the expression.
        string: source text of the expression.
        state: optional state whose objects may be referenced by name.
        variables: optional variables that are in scope for the expression.
        auto_constant_guess: whether unknown names may be treated as object constants.

    Returns:
        the parsed expression.
    """
    expression = _parser.parse_expression(
        string, domain,
        state=state, variables=variables, auto_constant_guess=auto_constant_guess,
    )
    return expression