Source code for concepts.dm.crow.parsers.crow_parser

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : crow_parser.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/10/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import os
import os.path as osp
import contextlib
import jacinle

from typing import Any, Optional, Union, Sequence, Tuple, Set, List, Dict
from dataclasses import dataclass, field
from lark import Lark, Tree
from lark.visitors import Transformer, Interpreter, v_args
from lark.indenter import PythonIndenter

import concepts.dsl.expression as E
from concepts.dsl.dsl_types import TypeBase, AutoType, VectorValueType, BOOL, Variable, ObjectConstant, UnnamedPlaceholder

from concepts.dm.crow.function_utils import flatten_expression
from concepts.dm.crow.controller import CrowControllerApplicationExpression
from concepts.dm.crow.action import CrowPrecondition, CrowEffect, CrowActionBodyItem, CrowActionBodyPrimitiveBase, CrowActionBodySuiteBase
from concepts.dm.crow.action import CrowAchieveExpression, CrowBindExpression, CrowRuntimeAssignmentExpression, CrowAssertExpression, CrowActionApplicationExpression
from concepts.dm.crow.action import CrowActionConditionSuite, CrowActionWhileLoopSuite, CrowActionOrdering, CrowActionOrderingSuite
from concepts.dm.crow.generator import CrowGeneratorApplicationExpression
from concepts.dm.crow.domain import CrowDomain, CrowProblem, CrowState

logger = jacinle.get_logger(__name__)

# Shorthand used on lark callbacks below so that each child of a matched rule arrives as its own
# positional argument (e.g. `typed_argument(self, name, typename)`).
inline_args = v_args(inline=True)


__all__ = [
    'PDSketchV3Parser',
    'path_resolver',
    'PDSketchV3PathResolver',
    'PDSketchV3DomainTransformer',
    'PDSketchV3ProblemTransformer',
    'PDSketch3LiteralTransformer',
    'PDSketch3ExpressionInterpreter',
    'load_domain_file',
    'load_domain_string',
    'load_domain_string_incremental',
    'load_problem_file',
    'load_problem_string',
]


class PDSketchV3PathResolver(object):
    """Locates source files referenced by Crow/PDSketch programs.

    :meth:`resolve` tries, in order: the filename exactly as given, the directory containing
    ``relative_filename`` (when provided), the current working directory, and finally every
    registered search path in insertion order.
    """

    def __init__(self, search_paths: Sequence[str] = tuple()):
        """Create a resolver.

        Args:
            search_paths: extra directories consulted after all other locations. A private list copy is kept.
        """
        self.search_paths = list(search_paths)

    def resolve(self, filename: str, relative_filename: Optional[str] = None) -> str:
        """Return the first existing path for ``filename``.

        Args:
            filename: the file to look for.
            relative_filename: optional path of the referring file; its directory is searched second.

        Returns:
            the first candidate path that exists.

        Raises:
            FileNotFoundError: if no candidate location exists.
        """
        def _candidate_paths():
            # Candidates are produced lazily so that later lookups (including os.getcwd())
            # only happen when every earlier location has failed.
            yield filename
            if relative_filename is not None:
                yield osp.join(osp.dirname(relative_filename), filename)
            yield osp.join(os.getcwd(), filename)
            for search_dir in self.search_paths:
                yield osp.join(search_dir, filename)

        for candidate in _candidate_paths():
            if osp.exists(candidate):
                return candidate
        raise FileNotFoundError(f'File not found: {filename}')

    def add_search_path(self, path: str):
        """Append ``path`` to the end of the search-path list."""
        self.search_paths.append(path)

    def remove_search_path(self, path: str):
        """Remove the first occurrence of ``path`` from the search-path list (ValueError if absent)."""
        self.search_paths.remove(path)
# Module-wide resolver used by `PDSketchV3Parser.parse`; starts with no extra search paths.
path_resolver = PDSketchV3PathResolver(search_paths=tuple())
class PDSketchV3Parser(object):
    """The parser for PDSketch v3."""

    grammar_file = osp.join(osp.dirname(osp.abspath(__file__)), 'crow.grammar')
    """The grammar definition v3 for PDSketch."""

    def __init__(self):
        """Initialize the parser.

        Reads the grammar file located next to this module and builds an LALR :class:`lark.Lark`
        parser that uses :class:`lark.indenter.PythonIndenter` as its post-lexer.
        """
        # Read as UTF-8 explicitly so the result does not depend on the platform's locale encoding.
        with open(self.grammar_file, 'r', encoding='utf-8') as f:
            self.grammar = f.read()
        self.parser = Lark(self.grammar, start='start', postlex=PythonIndenter(), parser='lalr')

    def parse(self, filename: str) -> Tree:
        """Parse a PDSketch v3 file.

        The filename is first resolved with the module-level ``path_resolver``.

        Args:
            filename: the filename to parse.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.

        Raises:
            FileNotFoundError: if the path resolver cannot locate the file.
        """
        filename = path_resolver.resolve(filename)
        with open(filename, 'r', encoding='utf-8') as f:
            return self.parse_str(f.read())

    def parse_str(self, s: str) -> Tree:
        """Parse a PDSketch v3 string.

        Leading and trailing whitespace is stripped and exactly one trailing newline is appended
        before parsing.

        Args:
            s: the string to parse.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        # NB(Jiayuan Mao @ 2024/03/13): for reasons, the pdsketch-v3 grammar file requires that the string ends with a newline.
        # In particular, the suite definition requires that the file ends with a _DEDENT token, which seems to be only triggered by a newline.
        s = s.strip() + '\n'
        return self.parser.parse(s)

    def parse_domain(self, filename: str) -> CrowDomain:
        """Parse a PDSketch v3 domain file.

        Args:
            filename: the filename to parse.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse(filename))

    def parse_domain_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Parse a PDSketch v3 domain string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse_str(s), domain=domain)

    def parse_problem(self, filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem file.

        Args:
            filename: the filename to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse(filename), domain=domain)

    def parse_problem_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse_str(s), domain=domain)

    def parse_expression(
        self, s: str, domain: CrowDomain, state: Optional[CrowState] = None,
        variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True
    ) -> E.Expression:
        """Parse a PDSketch v3 expression string.

        Args:
            s: the string to parse.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        return self.transform_expression(self.parse_str(s), domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)

    @staticmethod
    def transform_domain(tree: Tree, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Transform a parse tree into a domain.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        transformer = PDSketchV3DomainTransformer(domain)
        transformer.transform(tree)
        return transformer.domain

    @staticmethod
    def transform_problem(tree: Tree, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Transform a parse tree into a problem.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        transformer = PDSketchV3ProblemTransformer(domain)
        transformer.transform(tree)
        return transformer.problem

    @staticmethod
    def transform_expression(
        tree: Tree, domain: CrowDomain, state: Optional[CrowState] = None,
        variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True
    ) -> E.Expression:
        """Transform a parse tree into an expression.

        Args:
            tree: the parse tree.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        transformer = PDSketchV3ProblemTransformer(domain, state, auto_constant_guess=auto_constant_guess)
        interpreter = transformer.expression_interpreter
        expression_def_ctx = transformer.expression_def_ctx
        # the root of the tree is the `start` rule.
        tree = transformer.transform(tree).children[0]
        if variables is None:
            variables = tuple()
        # The outer-scope variables are made visible only while the interpreter visits the tree.
        with expression_def_ctx.with_variables(*variables):
            return interpreter.visit(tree)
@dataclass(init=True, repr=True, eq=True)
class LiteralValue(object):
    """A single scalar literal (bool, int, float, complex or str) wrapped for the parser."""

    value: Union[bool, int, float, complex, str]
@dataclass(init=True, repr=True, eq=True)
class LiteralList(object):
    """An ordered, tuple-backed collection of literals written as ``[a, b, ...]``."""

    items: Tuple[Union[bool, int, float, complex, str], ...]
@dataclass(init=True, repr=True, eq=True)
class LiteralSet(object):
    """An unordered, set-backed collection of literals written as ``{a, b, ...}``."""

    items: Set[Union[bool, int, float, complex, str]]
@dataclass(init=True, repr=True, eq=True)
class InTypedArgument(object):
    """A quantifier binding of the form ``name in value``, as used by forall/exists statements."""

    name: str
    value: Any
class PDSketch3LiteralTransformer(Transformer):
    """The transformer for literal types. Including:

    - VARNAME, CONSTNAME, BASIC_TYPENAME
    - number, DEC_NUMBER, HEX_NUMBER, BIN_NUMBER, OCT_NUMBER, FLOAT_NUMBER, IMAG_NUMBER
    - boolean, TRUE, FALSE
    - string
    - literal_list
    - literal_set
    - decorator_k, decorator_kwarg, decorator_kwargs

    Type names are looked up through ``self.domain``. It is only declared (not assigned) in this
    class; presumably subclasses set it — verify against the concrete transformers.
    """

    domain: CrowDomain

    @inline_args
    def typename(self, name: Union[str, TypeBase]) -> TypeBase:
        """Captures typenames including basic types and vector types."""
        return name

    @inline_args
    def sized_vector_typename(self, name: Union[str, TypeBase], size: int) -> VectorValueType:
        """Captures sized vector typenames defined as `vector[typename, size]`."""
        return VectorValueType(self.domain.get_type(name), size)

    @inline_args
    def unsized_vector_typename(self, name: Union[str, TypeBase]) -> VectorValueType:
        """Captures unsized vector typenames defined as `vector[typename]`."""
        return VectorValueType(self.domain.get_type(name))

    @inline_args
    def typed_argument(self, name: str, typename: Union[str, TypeBase]) -> Variable:
        """Captures typed arguments defined as `name: typename`.

        A string typename is resolved through ``self.domain.get_type``; a type object is used as is.
        """
        if isinstance(typename, str):
            typename = self.domain.get_type(typename)
        return Variable(name, typename)

    @inline_args
    def is_typed_argument(self, name: str, typename: Union[str, TypeBase]) -> Variable:
        """Captures typed arguments defined as `name is typename`. This is used in forall/exists statements."""
        if isinstance(typename, str):
            typename = self.domain.get_type(typename)
        return Variable(name, typename)

    @inline_args
    def in_typed_argument(self, name: str, value: Any) -> InTypedArgument:
        """Captures typed arguments defined as `name in value`. This is used in forall/exists statements."""
        return InTypedArgument(name, value)

    def arguments_def(self, args):
        """Captures the arguments definition. This is used in function definitions."""
        # NOTE(review): `ArgumentsDef` is not defined in the visible part of this file; presumably it is defined later.
        return ArgumentsDef(tuple(args))

    def VARNAME(self, token):
        """Captures variable names, such as `var_name`."""
        return token.value

    def CONSTNAME(self, token):
        """Captures constant names, such as `CONST_NAME`."""
        return token.value

    def BASIC_TYPENAME(self, token):
        """Captures basic type names (non-vector types), such as `int`, `float`, `bool`, `object`, etc."""
        return token.value

    @inline_args
    def number(self, value: Union[int, float, complex]) -> Union[int, float, complex]:
        """Captures number literals, including integers, floats, and complex numbers."""
        return value

    # The terminal callbacks below receive the matched token text and convert it to a Python value.

    @inline_args
    def BIN_NUMBER(self, value: str) -> int:
        """Captures binary number literals."""
        return int(value, 2)

    @inline_args
    def OCT_NUMBER(self, value: str) -> int:
        """Captures octal number literals."""
        return int(value, 8)

    @inline_args
    def DEC_NUMBER(self, value: str) -> int:
        """Captures decimal number literals."""
        return int(value)

    @inline_args
    def HEX_NUMBER(self, value: str) -> int:
        """Captures hexadecimal number literals."""
        return int(value, 16)

    @inline_args
    def FLOAT_NUMBER(self, value: str) -> float:
        """Captures floating point number literals."""
        return float(value)

    @inline_args
    def IMAG_NUMBER(self, value: str) -> complex:
        """Captures complex number literals."""
        return complex(value)

    @inline_args
    def boolean(self, value: bool) -> bool:
        """Captures boolean literals."""
        return value

    @inline_args
    def TRUE(self, _) -> bool:
        """Captures the `True` literal."""
        return True

    @inline_args
    def FALSE(self, _) -> bool:
        """Captures the `False` literal."""
        return False

    @inline_args
    def ELLIPSIS(self, _) -> Any:
        """Captures the `...` literal and returns the :data:`Ellipsis` singleton (not a string)."""
        return Ellipsis

    @inline_args
    def string(self, value: str) -> str:
        """Captures string literals, removing one pair of matching surrounding quotes.

        Escape sequences inside the literal are not interpreted.
        """
        # Require at least two characters so an empty value does not raise IndexError and a
        # lone quote character is not mistaken for an empty quoted string.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        return str(value)

    @inline_args
    def literal_list(self, *items: Any) -> LiteralList:
        """Captures literal lists, such as `[1, 2, 3, 4]`."""
        return LiteralList(tuple(items))

    @inline_args
    def literal_set(self, *items: Any) -> LiteralSet:
        """Captures literal sets, such as `{1, 2, 3, 4}`."""
        return LiteralSet(set(items))

    @inline_args
    def literal(self, value: Union[bool, int, float, complex, str, LiteralList, LiteralSet]) -> Union[LiteralValue, LiteralList, LiteralSet]:
        """Captures literal values.

        Scalars are wrapped in :class:`LiteralValue`; lists and sets are passed through.

        Raises:
            ValueError: if ``value`` is of any other type.
        """
        if isinstance(value, (bool, int, float, complex, str)):
            return LiteralValue(value)
        elif isinstance(value, (LiteralList, LiteralSet)):
            return value
        else:
            raise ValueError(f'Invalid literal value: {value}')

    @inline_args
    def decorator_kwarg(self, k, v: Union[LiteralValue, LiteralList, LiteralSet] = True) -> Tuple[str, Union[bool, int, float, complex, str, LiteralList, LiteralSet]]:
        """Captures the key-value pair of a decorator. This is used in the decorator syntax, such as [[k=True]].

        A bare key (no value) defaults to ``True``; a :class:`LiteralValue` is unwrapped to its scalar.
        """
        return k, v.value if isinstance(v, LiteralValue) else v

    def decorator_kwargs(self, args) -> Dict[str, Union[bool, int, float, complex, str, LiteralList, LiteralSet]]:
        """Captures the key-value pairs of a decorator. This is used in the decorator syntax, such as [[k=True, k2=123, k3=[1, 2, 3]]]."""
        return {k: v for k, v in args}
@dataclass
class ArgumentsList(object):
    """Positional argument values attached to a :class:`FunctionCall`.

    Each entry may be a nested suite, an expression (value output, list expansion or variable),
    or a plain Python literal.
    """

    arguments: Tuple[
        Union[
            'Suite',
            E.ValueOutputExpression,
            E.ListExpansionExpression,
            E.VariableExpression,
            bool, int, float, complex, str,
        ],
        ...,
    ]
@dataclass
class FunctionCall(object):
    """Intermediate representation of one parsed call-like construct.

    Besides ordinary function calls, this also represents primitive operators and control-flow
    statements (the statement kind is stored in ``name``).
    """

    name: str
    args: ArgumentsList
    annotations: Optional[Dict[str, Any]] = None

    def __str__(self):
        prefix = ''
        if self.annotations is not None:
            pairs = ', '.join(f'{key}={value}' for key, value in self.annotations.items())
            prefix = '[[' + pairs + ']] '

        rendered = [str(argument) for argument in self.args.arguments]
        # Short argument lists are printed inline; long ones get one indented argument per line.
        if sum(map(len, rendered)) <= 80:
            return f'{prefix}{self.name}(' + ', '.join(rendered) + ')'
        indented = [jacinle.indent_text(text) for text in rendered]
        return f'{prefix}{self.name}:\n' + '\n'.join(indented)

    def __repr__(self):
        return 'FunctionCall{' + str(self) + '}'
@dataclass
class CSList(object):
    """A comma-separated sequence of arbitrary parsed items."""

    items: Tuple[Any, ...]

    def __str__(self):
        joined = ', '.join(map(str, self.items))
        return 'CSList(' + joined + ')'

    def __repr__(self):
        return self.__str__()
@dataclass
class Suite(object):
    """A suite of statements. This is used as the intermediate representation of the parsed expressions.

    The ``get_*`` accessors lazily run a :class:`FunctionCallTracker` over ``items`` and cache it in ``tracker``.
    """

    items: Tuple[Any, ...]
    local_variables: Dict[str, Any] = field(default_factory=dict)
    tracker: Optional['FunctionCallTracker'] = None

    def _init_tracker(self, use_runtime_assign: bool = False):
        """Run a fresh tracker over this suite (with no initial local variables) and cache it."""
        self.tracker = FunctionCallTracker(self, dict(), use_runtime_assign=use_runtime_assign).run()

    def get_all_assign_expressions(self) -> List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]:
        """Return ``(assign_expression, annotations)`` pairs collected by the (cached) tracker."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.assign_expressions

    def get_all_check_expressions(self) -> List[E.ValueOutputExpression]:
        """Return the check expressions collected by the (cached) tracker."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.check_expressions

    def get_all_action_expressions(self, use_runtime_assign=True) -> List[CrowActionBodyItem]:
        """Return the action body items of this suite.

        Args:
            use_runtime_assign: whether local-variable assignments are emitted as runtime assignment actions.
                The result depends on this flag, so a cached tracker built with the other setting is not
                reused for the answer.

        Returns:
            the list of action body items.
        """
        if self.tracker is None:
            self._init_tracker(use_runtime_assign=use_runtime_assign)
        if self.tracker.use_runtime_assign != use_runtime_assign:
            # The cached tracker was produced in the other mode (e.g. by a prior get_all_assign_expressions call).
            # Trace again in the requested mode without replacing the cache other accessors rely on.
            return FunctionCallTracker(self, dict(), use_runtime_assign=use_runtime_assign).run().action_expressions
        return self.tracker.action_expressions

    def get_all_expr_expression(self, allow_multiple_expressions: bool = False) -> Optional[Union[E.ValueOutputExpression, Tuple[E.ValueOutputExpression, ...]]]:
        """Return the bare expression(s) found in this suite.

        Args:
            allow_multiple_expressions: if False, the suite must contain exactly one bare expression.

        Returns:
            the single expression when there is exactly one; otherwise (only if ``allow_multiple_expressions``)
            a tuple of all bare expressions, which is empty when there are none.

        Raises:
            ValueError: if ``allow_multiple_expressions`` is False and the suite does not contain exactly one bare expression.
        """
        if self.tracker is None:
            self._init_tracker()
        expressions = self.tracker.expr_expressions
        if len(expressions) == 1:
            return expressions[0]
        if not allow_multiple_expressions:
            if len(expressions) == 0:
                raise ValueError('No expression found in the suite.')
            raise ValueError(f'Multiple expressions found in a single suite: {expressions}')
        return tuple(expressions)

    def get_combined_return_expression(self, allow_expr_expressions: bool = False, allow_multiple_expressions: bool = False) -> Optional[E.ValueOutputExpression]:
        """Return the suite's (conditionally combined) return expression.

        Args:
            allow_expr_expressions: when there is no return statement, fall back to the bare expression(s).
            allow_multiple_expressions: forwarded to :meth:`get_all_expr_expression` for the fallback.

        Returns:
            the return expression, the fallback bare expression(s), or None.
        """
        if self.tracker is None:
            self._init_tracker()
        if self.tracker.return_expression is not None:
            return self.tracker.return_expression
        if allow_expr_expressions:
            return self.get_all_expr_expression(allow_multiple_expressions=allow_multiple_expressions)
        return None

    def __str__(self):
        if len(self.items) == 0:
            return 'Suite{}'
        if len(self.items) == 1:
            return f'Suite{{{self.items[0]}}}'
        return 'Suite{\n' + '\n'.join(jacinle.indent_text(str(item)) for item in self.items) + '\n}'

    def __repr__(self):
        return self.__str__()
class FunctionCallTracker(object):
    """This class is used to track the function calls and other statements in a suite.

    It supports simulating the execution of the program and generating the post-condition of the program.
    """

    def __init__(self, suite: Suite, init_local_variables: Optional[Dict[str, Any]] = None, use_runtime_assign: bool = False):
        """Create a tracker for ``suite``.

        Args:
            suite: the suite whose items will be traced by :meth:`run`.
            init_local_variables: initial local-variable bindings. The dict is used directly, not copied;
                callers in this class pass ``.copy()`` when they need isolation.
            use_runtime_assign: see the ``use_runtime_assign`` attribute.
        """
        self.suite = suite
        self.local_variables = dict() if init_local_variables is None else init_local_variables
        self.assign_expressions = list()
        self.assign_expression_signatures = dict()
        self.check_expressions = list()
        self.expr_expressions = list()
        self.action_expressions = list()
        self.return_expression = None
        self.use_runtime_assign = use_runtime_assign

    local_variables: Dict[str, Any]
    """The assignments of local variables."""

    assign_expressions: List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]
    """A list of assign expressions."""

    assign_expression_signatures: Dict[Tuple[str, ...], E.VariableAssignmentExpression]
    """A dictionary of assign expressions, indexed by their signatures."""

    check_expressions: List[E.ValueOutputExpression]
    """A list of check expressions."""

    expr_expressions: List[E.ValueOutputExpression]
    """A list of expr expressions (i.e. bare expressions in the body)."""

    return_expression: Optional[E.ValueOutputExpression]
    """The return expression. This is either None or a single expression."""

    use_runtime_assign: bool
    """Whether to use runtime assign expressions. If this is set to True, assignment expressions for local variables will be converted into runtime assign expressions."""

    # Action body items (controller/action applications, bind/achieve/assert, ordering and condition suites).
    action_expressions: List[CrowActionBodyItem]

    def _g(
        self, expr: Union[E.Expression, UnnamedPlaceholder, CrowActionBodyItem]
    ) -> Union[E.Expression, UnnamedPlaceholder, CrowActionBodyItem]:
        """Substitute the current local-variable bindings into ``expr`` via ``flatten_expression``.

        Controller/action/generator applications and action body primitives/suites are returned unchanged.

        Raises:
            ValueError: if ``expr`` is none of the accepted kinds.
        """
        if isinstance(expr, (CrowControllerApplicationExpression, CrowActionApplicationExpression, CrowGeneratorApplicationExpression)):
            return expr
        if isinstance(expr, (CrowActionBodyPrimitiveBase, CrowActionBodySuiteBase)):
            return expr
        if not isinstance(expr, E.Expression):
            raise ValueError(f'Invalid expression: {expr}')
        # Keys are untyped variable placeholders named after each local variable.
        return flatten_expression(expr, {
            E.VariableExpression(Variable(k, None)): v
            for k, v in self.local_variables.items()
        })

    def _get_deictic_signature(self, e, known_deictic_vars=tuple()) -> Optional[Tuple[str, ...]]:
        """Return ``(predicate_name, *arg_names)`` for an (optionally deictic-wrapped) assignment.

        Arguments bound by an enclosing :class:`E.DeicticAssignExpression` are replaced by ``'?'``.
        Other expression kinds (e.g. conditional assignments) have no signature and yield None.
        """
        if isinstance(e, E.DeicticAssignExpression):
            known_deictic_vars = known_deictic_vars + (e.variable.name,)
            return self._get_deictic_signature(e.expression, known_deictic_vars)
        elif isinstance(e, E.AssignExpression):
            args = [x.name if x.name not in known_deictic_vars else '?' for x in e.predicate.arguments]
            return tuple((e.predicate.function.name, *args))
        else:
            return None

    def _mark_assign(self, *exprs: E.VariableAssignmentExpression, annotations: Optional[Dict[str, Any]] = None):
        """Record assignment expressions together with their annotations.

        Raises:
            ValueError: if an expression's signature (see :meth:`_get_deictic_signature`) was already recorded.
        """
        if annotations is None:
            annotations = dict()
        for expr in exprs:
            signature = self._get_deictic_signature(expr)
            if signature is not None:
                if signature in self.assign_expression_signatures:
                    raise ValueError(f'Duplicate assign expression: {expr} vs {self.assign_expression_signatures[signature]}')
                self.assign_expressions.append((expr, annotations))
                self.assign_expression_signatures[signature] = expr
            else:
                self.assign_expressions.append((expr, annotations))

    @staticmethod
    def _extend_fallthrough_guard(
        previous_guard: Optional[E.ValueOutputExpression], new_guard: E.ValueOutputExpression
    ) -> E.ValueOutputExpression:
        """Conjoin ``new_guard`` onto the guard under which no earlier conditional return has fired."""
        if previous_guard is None:
            return new_guard
        return E.AndExpression(previous_guard, new_guard)

    @jacinle.log_function(verbose=False)
    def run(self):
        """Simulate the execution of the program and generate an equivalent return statement.

        Walks ``self.suite.items`` in order, filling ``local_variables``, ``assign_expressions``,
        ``check_expressions``, ``expr_expressions``, ``action_expressions`` and ``return_expression``.
        If-else branches are traced separately and merged under their (negated) condition. A ``return``
        statement stops the walk. Items whose name is not handled below are skipped. Loops are not allowed.

        Returns:
            this tracker.
        """
        # `current_return_statement` combines the returns seen so far in one-sided ifs.
        # `current_return_statement_condition_neg` is the guard under which none of them fired. It must
        # accumulate across successive ifs: otherwise, for `if c1: return r1; if c2: return r2; return r3`,
        # the final expression would yield r3 when c1 holds and c2 does not.
        current_return_statement = None
        current_return_statement_condition_neg = None
        for item in self.suite.items:
            assert isinstance(item, FunctionCall), f'Invalid item in suite: {item}'
            if item.name == 'assign':
                if isinstance(item.args.arguments[0], E.FunctionApplicationExpression) and item.args.arguments[1] is Ellipsis:
                    # `f(...) = ...` assigns a null value of the function's return type.
                    self._mark_assign(self._g(E.AssignExpression(item.args.arguments[0], E.NullExpression(item.args.arguments[0].return_type))), annotations=item.annotations)
                elif isinstance(item.args.arguments[0], E.FunctionApplicationExpression) and isinstance(item.args.arguments[1], (E.ValueOutputExpression, E.VariableExpression)):
                    self._mark_assign(self._g(E.AssignExpression(item.args.arguments[0], item.args.arguments[1])), annotations=item.annotations)
                elif isinstance(item.args.arguments[0], E.VariableExpression) and isinstance(item.args.arguments[1], (E.ValueOutputExpression, E.VariableExpression, E.FindAllExpression)):
                    if self.use_runtime_assign:
                        if item.args.arguments[0].name not in self.local_variables:
                            # NOTE(review): arguments[0] is already a VariableExpression, so this wraps an expression
                            # in another VariableExpression; `arguments[0]` itself may be what is intended — confirm.
                            self.local_variables[item.args.arguments[0].name] = E.VariableExpression(item.args.arguments[0])
                        self.action_expressions.append(CrowRuntimeAssignmentExpression(
                            item.args.arguments[0].variable,
                            self._g(item.args.arguments[1])
                        ))
                    else:
                        # Without runtime assignment, the local variable is substituted symbolically.
                        self.local_variables[item.args.arguments[0].name] = self._g(item.args.arguments[1])
                else:
                    raise ValueError(f'Invalid assignment: {item}. Types: {type(item.args.arguments[0])}, {type(item.args.arguments[1])}.')
            elif item.name == 'check':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid check expression: {item.args.arguments[0]}'
                self.check_expressions.append(self._g(item.args.arguments[0]))
            elif item.name == 'expr':
                if isinstance(item.args.arguments[0], (CrowControllerApplicationExpression, CrowActionApplicationExpression, E.ListExpansionExpression)):
                    self.action_expressions.append(self._g(item.args.arguments[0]))
                else:
                    assert isinstance(
                        item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression, CrowGeneratorApplicationExpression)
                    ), f'Invalid expr expression: {item.args.arguments[0]}'
                    self.expr_expressions.append(self._g(item.args.arguments[0]))
            elif item.name == 'bind':
                arguments = item.args.arguments[0].items
                body = item.args.arguments[1]
                self.action_expressions.append(CrowBindExpression(arguments, body, **item.annotations if item.annotations is not None else dict()))
            elif item.name == 'achieve':
                term = item.args.arguments[0]
                self.action_expressions.append(CrowAchieveExpression(term, **item.annotations if item.annotations is not None else dict()))
            elif item.name == 'assert':
                term = item.args.arguments[0]
                self.action_expressions.append(CrowAssertExpression(term, **item.annotations if item.annotations is not None else dict()))
            elif item.name == 'return':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid return expression: {item.args.arguments[0]}'
                self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, self._g(item.args.arguments[0]))
                break
            elif item.name == 'ordering':
                suite = item.args.arguments[1]
                tracker = FunctionCallTracker(suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                # Local-variable updates made inside an ordering block remain visible afterwards.
                self.local_variables = tracker.local_variables
                action_expressions = tracker.action_expressions
                if item.args.arguments[0] == 'promotable unordered':
                    prog = CrowActionOrderingSuite('promotable', (CrowActionOrderingSuite('unordered', action_expressions), ))
                elif item.args.arguments[0] == 'promotable sequential':
                    prog = CrowActionOrderingSuite('promotable', action_expressions)
                else:
                    assert ' ' not in item.args.arguments[0], f'Invalid ordering type: {item.args.arguments[0]}'
                    prog = CrowActionOrderingSuite(item.args.arguments[0], action_expressions)
                self.action_expressions.append(prog)
            elif item.name == 'if':
                condition = self._g(item.args.arguments[0])
                neg_condition = E.NotExpression(condition)
                assert isinstance(condition, E.ValueOutputExpression), f'Invalid condition: {condition}. Type: {type(condition)}.'
                t_suite = item.args.arguments[1]
                f_suite = item.args.arguments[2]
                t_tracker = FunctionCallTracker(t_suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                f_tracker = FunctionCallTracker(f_suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                assert set(t_tracker.local_variables.keys()) == set(f_tracker.local_variables.keys()), f'Local variables in the true and false branches are not consistent: {t_tracker.local_variables.keys()} vs {f_tracker.local_variables.keys()}'
                # Variables that differ between the branches become `condition ? true_value : false_value`.
                new_local_variables = t_tracker.local_variables
                for k, v in t_tracker.local_variables.items():
                    if f_tracker.local_variables[k] != v:
                        new_local_variables[k] = E.ConditionExpression(condition, v, f_tracker.local_variables[k])
                self.local_variables = new_local_variables
                # TODO(Jiayuan Mao @ 2024/03/2): optimize the implementation for this by merging the conditions.
                for expr, annotations in t_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, condition), annotations=annotations)
                for expr, annotations in f_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, neg_condition), annotations=annotations)
                for expr in t_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(condition, expr))
                for expr in f_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(neg_condition, expr))
                if len(t_tracker.expr_expressions) != len(f_tracker.expr_expressions):
                    raise ValueError(f'Number of bare expressions in the true and false branches are not consistent: {len(t_tracker.expr_expressions)} vs {len(f_tracker.expr_expressions)}')
                if len(t_tracker.expr_expressions) == 1:
                    self.expr_expressions.append(E.ConditionExpression(condition, t_tracker.expr_expressions[0], f_tracker.expr_expressions[0]))
                elif len(t_tracker.expr_expressions) > 1:
                    raise ValueError(f'Multiple bare expressions in the true and false branches are not supported: {t_tracker.expr_expressions} vs {f_tracker.expr_expressions}')
                if len(t_tracker.action_expressions) > 0:
                    self.action_expressions.append(CrowActionConditionSuite(condition, t_tracker.action_expressions))
                if len(f_tracker.action_expressions) > 0:
                    self.action_expressions.append(CrowActionConditionSuite(neg_condition, f_tracker.action_expressions))
                if t_tracker.return_expression is not None and f_tracker.return_expression is not None:
                    # Both branches have return statements.
                    statement = E.ConditionExpression(condition, t_tracker.return_expression, f_tracker.return_expression)
                    self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, statement)
                    break
                elif t_tracker.return_expression is not None:
                    # Combine with the guard accumulated so far, then require `not condition` to fall through.
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, t_tracker.return_expression)
                    current_return_statement_condition_neg = self._extend_fallthrough_guard(current_return_statement_condition_neg, E.NotExpression(condition))
                elif f_tracker.return_expression is not None:
                    # Combine with the guard accumulated so far, then require `condition` to fall through.
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, f_tracker.return_expression)
                    current_return_statement_condition_neg = self._extend_fallthrough_guard(current_return_statement_condition_neg, condition)
            elif item.name == 'forall':
                suite = item.args.arguments[1]
                tracker = FunctionCallTracker(suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')
                # Assignments are wrapped in a deictic assignment per quantified variable; checks and bare expressions in a forall.
                for expr, annotations in tracker.assign_expressions:
                    for var in item.args.arguments[0].items:
                        expr = E.DeicticAssignExpression(var, expr)
                    self._mark_assign(expr, annotations=annotations)
                for expr in tracker.check_expressions:
                    for var in item.args.arguments[0].items:
                        expr = E.ForallExpression(var, expr)
                    self.check_expressions.append(expr)
                if len(tracker.expr_expressions) > 0:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    for var in item.args.arguments[0].items:
                        merged = E.ForallExpression(var, merged)
                    self.expr_expressions.append(merged)
                if len(tracker.action_expressions) > 0:
                    raise ValueError(f'Actions are not allowed in a forall statement: {tracker.action_expressions}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall statement: {tracker.return_expression}')
            elif item.name == 'forall_in':
                suite = item.args.arguments[0]
                tracker = FunctionCallTracker(suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()
                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')
                for expr, annotations in tracker.assign_expressions:
                    raise ValueError(f'Assign statements are not allowed in a forall_in statement: {expr}')
                for expr in tracker.check_expressions:
                    self.check_expressions.append(E.AndExpression(expr))
                if len(tracker.expr_expressions) > 0:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    self.expr_expressions.append(merged)
                if len(tracker.action_expressions) > 0:
                    # TODO(Jiayuan Mao @ 2024/03/12): implement the rest parts of action statements.
                    for expr in tracker.action_expressions:
                        if isinstance(expr, CrowAchieveExpression):
                            self.action_expressions.append(E.ListExpansionExpression(E.AndExpression(expr.goal)))
                        else:
                            raise ValueError(f'Action items except for achieve statements are not allowed in a forall_in statement: {expr}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall_in statement: {tracker.return_expression}')
            elif item.name == 'pass':
                pass
        return self
def _make_conditional_implies(condition: E.ValueOutputExpression, test: E.ValueOutputExpression):
    """Guard ``test`` with ``condition``, i.e. build ``condition => test``.

    When ``test`` is already an implication ``a => b``, the result is flattened into
    ``(condition and a) => b``; if ``a`` is itself a conjunction its arguments are spliced in
    directly instead of nesting another AND.
    """
    is_implication = isinstance(test, E.BoolExpression) and test.op == E.BoolOpType.IMPLIES
    if not is_implication:
        return E.ImpliesExpression(condition, test)

    antecedent = test.arguments[0]
    if isinstance(antecedent, E.BoolExpression) and antecedent.op == E.BoolOpType.AND:
        merged_antecedent = E.AndExpression(condition, *antecedent.arguments)
    else:
        merged_antecedent = E.AndExpression(condition, antecedent)
    return E.ImpliesExpression(merged_antecedent, test.arguments[1])


def _make_conditional_return(current_stmt: Optional[E.ValueOutputExpression], current_condition_neg: Optional[E.ValueOutputExpression], new_stmt: E.ValueOutputExpression):
    """Merge a newly seen return value into the return value accumulated so far.

    With nothing accumulated yet, ``new_stmt`` is returned as is. Otherwise the result is
    ``new_stmt if current_condition_neg else current_stmt``.
    """
    if current_stmt is not None:
        return E.ConditionExpression(current_condition_neg, new_stmt, current_stmt)
    return new_stmt


def _make_conditional_assign(assign_stmt: E.VariableAssignmentExpression, condition: E.ValueOutputExpression):
    """Attach ``condition`` as an extra guard to an assignment expression.

    - A plain assignment becomes a conditional assignment guarded by ``condition``.
    - A conditional assignment gets ``condition`` prepended to its existing guard; an existing AND
      guard is flattened.
    - A deictic (quantified) assignment is rebuilt with its inner assignment guarded recursively.

    Raises:
        ValueError: if ``assign_stmt`` is none of the supported assignment kinds.
    """
    if isinstance(assign_stmt, E.AssignExpression):
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, condition)

    if isinstance(assign_stmt, E.ConditionalAssignExpression):
        existing_guard = assign_stmt.condition
        if isinstance(existing_guard, E.BoolExpression) and existing_guard.op == E.BoolOpType.AND:
            combined_guard = E.AndExpression(condition, *existing_guard.arguments)
        else:
            combined_guard = E.AndExpression(condition, existing_guard)
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, combined_guard)

    if isinstance(assign_stmt, E.DeicticAssignExpression):
        guarded_inner = _make_conditional_assign(assign_stmt.expression, condition)
        return E.DeicticAssignExpression(assign_stmt.variable, guarded_inner)

    raise ValueError(f'Invalid assign statement: {assign_stmt}')
def gen_term_expr(expr_typename: str):
    """Build the handler for a binary-operator term rule.

    Used for the following interpreter rules:

        - mul_expr
        - arith_expr
        - shift_expr

    The generated handler visits all of its children. When there is exactly one child, the rule
    only wrapped a lower-precedence expression, and that visited child is returned unchanged. An
    actual operator chain is not supported in the current version and raises
    :class:`NotImplementedError`.

    Args:
        expr_typename: the name of the expression type; only used in the error message.

    Returns:
        the generated handler, to be assigned as an interpreter method.
    """

    @inline_args
    def term(self, *values: Any):
        operands = [self.visit(child) for child in values]
        if len(operands) != 1:
            # NOTE: a left fold into FunctionCall(op, ArgumentsList((lhs, rhs))) used to sit here as
            # unreachable code; it stays disabled until operator terms are supported.
            raise NotImplementedError(f'{expr_typename} expression is not supported in the current version.')
        return operands[0]

    return term
def gen_term_expr_noop(expr_typename: str):
    """Build the handler for a bitwise term rule whose operator is implied by the rule itself.

    The parse tree of these rules does not contain the operator token, so the operator is chosen
    from ``expr_typename``:

        - bitand_expr  -> ``BoolOpType.AND``
        - bitxor_expr  -> ``BoolOpType.XOR``
        - bitor_expr   -> ``BoolOpType.OR``

    A single visited child is returned unchanged. Otherwise all children are combined into one
    :class:`E.BoolExpression`.
    """
    op_mapping = {
        'bitand': E.BoolOpType.AND,
        'bitxor': E.BoolOpType.XOR,
        'bitor': E.BoolOpType.OR,
    }

    @inline_args
    def term(self, *values: Any):
        operands = [self.visit(child) for child in values]
        if len(operands) == 1:
            return operands[0]
        # The operator lookup happens per call (not at generation time), as before.
        return E.BoolExpression(op_mapping[expr_typename], _canonize_arguments(operands))

    return term
def _canonize_single_argument(arg: Any, dtype: Optional[TypeBase] = None) -> Union[E.ObjectOrValueOutputExpression, E.VariableExpression, type(Ellipsis)]:
    """Convert one parsed argument into an expression.

    Args:
        arg: an expression, a controller/action/generator application, a :class:`Variable`,
            a Python literal (bool/int/float/complex), or ``Ellipsis``.
        dtype: the expected type; only used when ``arg`` is a Python literal.

    Returns:
        the canonized argument. Expressions and applications are returned as is, variables are
        wrapped into :class:`E.VariableExpression`, literals become :class:`E.ConstantExpression`,
        and ``Ellipsis`` is passed through.

    Raises:
        ValueError: if ``arg`` is of an unsupported kind.
    """
    if isinstance(arg, E.ObjectOrValueOutputExpression):
        return arg
    if isinstance(arg, (CrowControllerApplicationExpression, CrowActionApplicationExpression, CrowGeneratorApplicationExpression)):
        return arg
    if isinstance(arg, Variable):
        return E.VariableExpression(arg)
    if isinstance(arg, (bool, int, float, complex)):
        return E.ConstantExpression.from_value(arg, dtype=dtype)
    if arg is Ellipsis:
        return Ellipsis
    raise ValueError(f'Invalid argument: {arg}. Type: {type(arg)}.')


def _canonize_arguments(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtypes: Optional[Sequence[TypeBase]] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]:
    """Canonize a whole argument list with :func:`_canonize_single_argument`.

    Args:
        args: an :class:`ArgumentsList` or a plain tuple/list of arguments; ``None`` means no arguments.
        dtypes: optional expected types, one per argument.

    Returns:
        a tuple of canonized arguments. ``Ellipsis`` entries are kept as ``Ellipsis``.

    Raises:
        ValueError: if ``dtypes`` is given and its length differs from the number of arguments.
    """
    if args is None:
        return tuple()

    canonized_args = list()
    # TODO(Jiayuan Mao @ 2024/03/2): Strictly check the allowability of "Ellipsis" in the arguments.
    arguments = args.arguments if isinstance(args, ArgumentsList) else args
    if dtypes is not None:
        if len(arguments) != len(dtypes):
            raise ValueError(f'Number of arguments does not match the number of types: {len(arguments)} vs {len(dtypes)}. Args: {arguments}, Types: {dtypes}')
    for i, arg in enumerate(arguments):
        if arg is Ellipsis:
            canonized_args.append(Ellipsis)
        else:
            canonized_args.append(_canonize_single_argument(arg, dtype=dtypes[i] if dtypes is not None else None))
    return tuple(canonized_args)


def _has_list_arguments(args: Tuple[E.ObjectOrValueOutputExpression, ...]) -> bool:
    """Return True if any argument's return type is a list type.

    ``Ellipsis`` entries, which :func:`_canonize_arguments` passes through unchanged, carry no type.
    They are skipped instead of raising ``AttributeError``.
    """
    return any(arg is not Ellipsis and arg.return_type.is_list_type for arg in args)
class PDSketch3ExpressionInterpreter(Interpreter):
    """The interpreter for expressions and statements.

    This is a lark :class:`Interpreter`, so it walks the tree top-down and each handler decides
    when to visit its children. Expressions become :mod:`concepts.dsl.expression` objects.
    Statements become ``FunctionCall`` records, and blocks become ``Suite`` records.

    Handled rules:

        - atom_varname, atom_expr_funccall, atom_subscript, atom, arguments
        - power, factor, unary_op_expr
        - mul_expr, arith_expr, shift_expr, bitand_expr, bitxor_expr, bitor_expr
        - comparison_expr, not_test, and_test, or_test, cond_test, test, test_nocond
        - tuple, list, cs_list, suite, ordered_suite, ordering_op
        - expr_stmt, expr_list_expansion_stmt, assign_stmt, annotated_assign_stmt, local_assign_stmt, pass_stmt
        - check_stmt, achieve_stmt, assert_stmt, return_stmt
        - compound_assign_stmt, compound_check_stmt, compound_achieve_stmt, compound_assert_stmt, compound_return_stmt
        - if_stmt, forall_stmt, forall_in_stmt, bind_stmt, bind_stmt_no_where, annotated_compound_stmt
        - forall_test, exists_test, findall_test, forall_in_test, exists_in_test
    """

    def __init__(self, domain: Optional[CrowDomain], state: Optional[CrowState], expression_def_ctx: E.ExpressionDefinitionContext, auto_constant_guess: bool = False):
        """Initialize the interpreter.

        Args:
            domain: the domain used to resolve features, functions, controllers, actions and generators.
            state: if given, object names in this state resolve to object constants.
            expression_def_ctx: the context that tracks the variables in scope.
            auto_constant_guess: if True, unknown names become object constants of :data:`AutoType`.
        """
        super().__init__()
        self.domain = domain
        self.state = state
        self.expression_def_ctx = expression_def_ctx
        self.auto_constant_guess = auto_constant_guess
        self.generator_impl_outputs = None
        self.local_variables = dict()

    domain: CrowDomain
    expression_def_ctx: E.ExpressionDefinitionContext

    def set_domain(self, domain: CrowDomain):
        self.domain = domain

    @contextlib.contextmanager
    def local_variable_guard(self):
        """Restore ``self.local_variables`` on exit, even when the guarded body raises."""
        backup = self.local_variables.copy()
        try:
            yield
        finally:
            self.local_variables = backup

    @contextlib.contextmanager
    def set_generator_impl_outputs(self, outputs: List[Variable]):
        """Temporarily set ``self.generator_impl_outputs``; the previous value is restored even on error."""
        backup = self.generator_impl_outputs
        self.generator_impl_outputs = outputs
        try:
            yield
        finally:
            self.generator_impl_outputs = backup

    def visit(self, tree: Any) -> Any:
        """Visit ``tree`` if it is a lark :class:`Tree`; any other value (token, literal, None) is returned as is."""
        if isinstance(tree, Tree):
            return super().visit(tree)
        return tree

    @inline_args
    def atom_varname(self, name: str) -> Union[E.VariableExpression, E.ObjectConstantExpression, E.ValueOutputExpression]:
        """Captures variable names such as `var_name`.

        Lookup order: local variables, then objects of ``self.state``, then the variables of the
        expression definition context. With ``auto_constant_guess``, an unknown name becomes an
        object constant of :data:`AutoType` instead.
        """
        if name in self.local_variables:
            return self.local_variables[name]
        if self.state is not None and name in self.state.object_name2defaultindex:
            return E.ObjectConstantExpression(ObjectConstant(name, self.domain.types[self.state.get_typename(name)]))

        # TODO(Jiayuan Mao @ 2024/03/12): smartly guess the type of the variable.
        if not self.expression_def_ctx.has_variable(name):
            if self.auto_constant_guess:
                return E.ObjectConstantExpression(ObjectConstant(name, AutoType))

        variable = self.expression_def_ctx.wrap_variable(name)
        return variable

    @inline_args
    def atom_expr_funccall(self, annotations: dict, name: str, args: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression, CrowActionBodyItem]:
        """Captures function calls, such as `func_name(arg1, arg2, ...)`.

        ``name`` is resolved against the domain's features, functions, controllers, actions and
        generators, in that order. Unknown names are only accepted with the ``inplace_action_body``
        or ``inplace_generator`` annotations, which define a new function (and, for generators, a
        ``gen_<name>`` generator) on the fly.

        For actions, a trailing ``...`` stands for all remaining arguments. Each of them becomes an
        :class:`UnnamedPlaceholder` of the corresponding argument type.

        Raises:
            KeyError: if ``name`` is not found and no in-place annotation applies.
        """
        annotations: Optional[dict] = self.visit(annotations)
        args: Optional[ArgumentsList] = self.visit(args)
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsList(tuple())

        if self.domain.has_feature(name):
            function = self.domain.get_feature(name)
            args_c = _canonize_arguments(args, function.ftype.argument_types)
            if _has_list_arguments(args_c):
                return E.ListFunctionApplicationExpression(function, args_c)
            return E.FunctionApplicationExpression(function, args_c)
        elif self.domain.has_function(name):
            function = self.domain.get_function(name)
            args_c = _canonize_arguments(args, function.ftype.argument_types)
            if _has_list_arguments(args_c):
                return E.ListFunctionApplicationExpression(function, args_c)
            return E.FunctionApplicationExpression(function, args_c)
        elif self.domain.has_controller(name):
            controller = self.domain.get_controller(name)
            args_c = _canonize_arguments(args, controller.argument_types)
            return CrowControllerApplicationExpression(controller, args_c)
        elif self.domain.has_action(name):
            action = self.domain.get_action(name)
            raw_arguments = args.arguments
            if len(raw_arguments) > 0 and raw_arguments[-1] is Ellipsis:
                # Canonize only the explicitly given arguments against their own types. Checking the
                # full argument list against all types would reject `act(x, ...)` for 3-ary actions.
                nr_given = len(raw_arguments) - 1
                args_c = _canonize_arguments(raw_arguments[:nr_given], action.argument_types[:nr_given])
                args_c = args_c + tuple(UnnamedPlaceholder(t) for t in action.argument_types[nr_given:])
            else:
                args_c = _canonize_arguments(args, action.argument_types)
            return CrowActionApplicationExpression(action, args_c)
        elif self.domain.has_generator(name):
            generator = self.domain.get_generator(name)
            args_c = _canonize_arguments(args, generator.argument_types)
            return CrowGeneratorApplicationExpression(generator, args_c, list())
        else:
            if 'inplace_action_body' in annotations and annotations['inplace_action_body']:
                args_c = _canonize_arguments(args)
                argument_types = [arg.return_type for arg in args_c]
                # logger.warning(f'Action {name} not found, creating a new one with argument types {argument_types}.')
                predicate = self.domain.define_crow_function(name, argument_types, self.domain.get_type('__totally_ordered_plan__'))
                return E.FunctionApplicationExpression(predicate, args_c)
            elif 'inplace_generator' in annotations and annotations['inplace_generator']:
                args_c = _canonize_arguments(args)
                argument_types = [arg.return_type for arg in args_c]
                # logger.warning(f'Generator placeholder function {name} not found, creating a new one with argument types {argument_types}.')
                predicate = self.domain.define_crow_function(
                    name, argument_types, BOOL,
                    generator_placeholder=annotations.get('generator_placeholder', True)
                )

                assert 'inplace_generator_targets' in annotations, f'Inplace generator {name} requires inplace generator targets to be set.'
                inplace_generator_targets = annotations['inplace_generator_targets']

                generator_name = 'gen_' + name
                generator_arguments = predicate.arguments
                generator_goal = E.FunctionApplicationExpression(predicate, [E.VariableExpression(arg) for arg in generator_arguments])
                # An argument position is an output when the call site passes one of the target variables there.
                output_argument_names = [x.value for x in inplace_generator_targets.items]
                output_indices = list()
                for i, arg in enumerate(args_c):
                    if isinstance(arg, E.VariableExpression) and arg.name in output_argument_names:
                        output_argument_names.remove(arg.name)
                        output_indices.append(i)
                assert len(output_argument_names) == 0, f'Mismatched output arguments for inplace generator {name}: {output_argument_names}'
                inputs = [arg for i, arg in enumerate(generator_arguments) if i not in output_indices]
                outputs = [arg for i, arg in enumerate(generator_arguments) if i in output_indices]
                self.domain.define_generator(generator_name, generator_arguments, generator_goal, inputs, outputs)

                return E.FunctionApplicationExpression(predicate, args_c)
            else:
                raise KeyError(f'Function {name} not found. Note that recursive function calls are not supported in the current version.')

    @inline_args
    def atom_subscript(self, name: str, index: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression]:
        """Captures subscript expressions such as `name[index1, index2, ...]`.

        ``name[...]`` (a single ellipsis) applies the state variable with no arguments.

        Raises:
            ValueError: if ``name`` is not a state variable.
        """
        feature = self.domain.get_feature(name)
        index: CSList = self.visit(index)

        if not feature.is_state_variable:
            raise ValueError(f'Invalid subscript expression: {name} is not a state variable. Expression: {name}[{index.items}]')

        items = index.items
        if len(items) == 1 and items[0] is Ellipsis:
            return E.FunctionApplicationExpression(feature, tuple())

        arguments = _canonize_arguments(index.items, dtypes=feature.ftype.argument_types)
        if _has_list_arguments(arguments):
            return E.ListFunctionApplicationExpression(feature, arguments)
        return E.FunctionApplicationExpression(feature, arguments)

    @inline_args
    def atom(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
        """Captures the atom. This is used in the base case of the expression, including literal constants, variables, and subscript expressions."""
        return value

    def arguments(self, args: Tree) -> ArgumentsList:
        """Captures the argument list. This is used in function calls."""
        args = self.visit_children(args)
        return ArgumentsList(tuple(args))

    @inline_args
    def power(self, base: Union[FunctionCall, Variable], exp: Optional[float] = None) -> Union[FunctionCall, Variable]:
        """The highest-priority expression. This is used to capture the power expression, such as `base ** exp`. If `exp` is None, it is treated as `base ** 1`."""
        if exp is None:
            return base
        raise NotImplementedError('Power expression is not supported in the current version.')

    @inline_args
    def factor(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
        return value

    @inline_args
    def unary_op_expr(self, op: str, value: Union[FunctionCall, Variable]) -> FunctionCall:
        raise NotImplementedError('Unary operators are not supported in the current version.')

    mul_expr = gen_term_expr('mul')
    arith_expr = gen_term_expr('add')
    shift_expr = gen_term_expr('shift')

    bitand_expr = gen_term_expr_noop('bitand')
    bitxor_expr = gen_term_expr_noop('bitxor')
    bitor_expr = gen_term_expr_noop('bitor')

    @inline_args
    def comparison_expr(self, *values: Union[E.ValueOutputExpression, E.VariableExpression]) -> E.ValueOutputExpression:
        """Captures (chained) comparisons such as `a == b` or `a == b == c`.

        Chains become a conjunction of pairwise comparisons. Only comparisons between object-typed
        operands are supported.

        Raises:
            NotImplementedError: if an operand pair is not object-typed. Previously such pairs were
                silently dropped, which turned the comparison into an empty (trivially true) conjunction.
        """
        if len(values) == 1:
            return self.visit(values[0])

        assert len(values) % 2 == 1, f'[compare] expressions expected an odd number of values, got {len(values)}. Values: {values}.'
        values = [self.visit(value) for value in values]
        results = list()
        for i in range(1, len(values), 2):
            lhs, op, rhs = values[i - 1], values[i], values[i + 1]
            if lhs.return_type.is_object_type and rhs.return_type.is_object_type:
                # The visited operator node is a list of tokens; its first token carries the operator text.
                results.append(E.ObjectCompareExpression(E.CompareOpType.from_string(op[0].value), lhs, rhs))
            else:
                raise NotImplementedError(f'Comparison between non-object values is not supported in the current version. Got: {lhs} and {rhs}.')

        if len(results) == 1:
            return results[0]
        result = E.AndExpression(*results)
        return result

    @inline_args
    def not_test(self, value: Any) -> E.NotExpression:
        return E.NotExpression(*_canonize_arguments([self.visit(value)]))

    @inline_args
    def and_test(self, *values: Any) -> E.AndExpression:
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]

        result = E.AndExpression(*_canonize_arguments(values))
        return result

    @inline_args
    def or_test(self, *values: Any) -> E.OrExpression:
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]

        result = E.OrExpression(*_canonize_arguments(values))
        return result

    @inline_args
    def cond_test(self, value1: Any, cond: Any, value2: Any) -> E.ConditionExpression:
        """Captures `value1 if cond else value2`."""
        return E.ConditionExpression(*_canonize_arguments([self.visit(cond), self.visit(value1), self.visit(value2)]))

    @inline_args
    def test(self, value: Any):
        return self.visit(value)

    @inline_args
    def test_nocond(self, value: Any):
        return self.visit(value)

    @inline_args
    def tuple(self, *values: Any):
        return tuple(self.visit(v) for v in values)

    @inline_args
    def list(self, *values: Any):
        return E.ListCreationExpression(_canonize_arguments([self.visit(v) for v in values]))

    @inline_args
    def cs_list(self, *values: Any):
        return CSList(tuple(self.visit(v) for v in values))

    @inline_args
    def suite(self, *values: Tree, activate_variable_guard: bool = True) -> Suite:
        """Visit every statement of a block and return a :class:`Suite` of the non-None results.

        With ``activate_variable_guard``, local variables assigned inside the block are recorded
        in the returned suite but removed from the interpreter afterwards.
        """
        if activate_variable_guard:
            with self.local_variable_guard():
                values = [self.visit(value) for value in values]
                local_variables = self.local_variables.copy()
        else:
            values = [self.visit(value) for value in values]
            local_variables = self.local_variables.copy()

        return Suite(tuple(v for v in values if v is not None), local_variables)

    @inline_args
    def expr_stmt(self, value: Tree):
        """An expression statement. Bare `...` and string literals produce no statement."""
        value = self.visit(value)
        if value is Ellipsis:
            return None
        if isinstance(value, str):
            return None
        return FunctionCall('expr', ArgumentsList((_canonize_single_argument(value),)))

    @inline_args
    def expr_list_expansion_stmt(self, value: Any):
        value = _canonize_single_argument(self.visit(value))
        return FunctionCall('expr', ArgumentsList((E.ListExpansionExpression(value), )))

    @inline_args
    def assign_stmt(self, target: Any, value: Any):
        return FunctionCall('assign', ArgumentsList((_canonize_single_argument(self.visit(target)), _canonize_single_argument(self.visit(value)))))

    @inline_args
    def annotated_assign_stmt(self, annotations: dict, target: Any, value: Any):
        return FunctionCall('assign', ArgumentsList((_canonize_single_argument(self.visit(target)), _canonize_single_argument(self.visit(value)))), annotations)

    @inline_args
    def local_assign_stmt(self, target: str, value: Any = None):
        """Declare a local variable, optionally assigning it a value.

        With a value, the variable takes the value's return type and an ``assign`` statement is
        produced. Without a value (``value is None``), the variable is declared with
        :data:`AutoType` and no statement is produced. Previously this case crashed in
        ``_canonize_single_argument(None)``.
        """
        assert isinstance(target, str), f'Invalid local variable name: {target}'
        if value is None:
            self.local_variables[target] = E.VariableExpression(Variable(target, AutoType))
            return None

        value = _canonize_single_argument(self.visit(value))
        self.local_variables[target] = E.VariableExpression(Variable(target, value.return_type))
        return FunctionCall('assign', ArgumentsList((self.local_variables[target], value)))

    @inline_args
    def pass_stmt(self):
        return FunctionCall('pass', ArgumentsList(tuple()))

    @inline_args
    def check_stmt(self, value: Any):
        return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def achieve_stmt(self, value: CSList):
        return FunctionCall('achieve', ArgumentsList(_canonize_arguments(self.visit(value).items)))

    @inline_args
    def assert_stmt(self, value: Any):
        return FunctionCall('assert', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def return_stmt(self, value: Any):
        return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def compound_assign_stmt(self, target: Variable, value: Any):
        return FunctionCall('assign', ArgumentsList((self.visit(target), _canonize_single_argument(self.visit(value)))))

    @inline_args
    def compound_check_stmt(self, value: Any):
        return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def compound_achieve_stmt(self, value: Any):
        return FunctionCall('achieve', ArgumentsList(_canonize_arguments([self.visit(value)])))

    @inline_args
    def compound_assert_stmt(self, value: Any):
        return FunctionCall('assert', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def compound_return_stmt(self, value: Any):
        return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def ordered_suite(self, ordering_op: Any, body: Any):
        """An ordering block such as `promotable unordered: ...`.

        For the promotable/unordered variants, local variables assigned in the body stay scoped
        to it. Otherwise the body's statements are visited in the current local scope.
        """
        ordering_op = self.visit(ordering_op)
        assert body.data.value == 'suite', f'Invalid body type: {body}'

        if ordering_op in ('promotable', 'unordered', 'promotable unordered', 'promotable sequential'):
            with self.local_variable_guard():
                body = self.visit(body)
            return FunctionCall('ordering', ArgumentsList((ordering_op, body)))
        else:
            body = self.visit_children(body)
            body = Suite(tuple(body), self.local_variables.copy())
            return FunctionCall('ordering', ArgumentsList((ordering_op, body)))

    @inline_args
    def ordering_op(self, *ordering_op: Any):
        return ' '.join([x.value for x in ordering_op])

    @inline_args
    def if_stmt(self, cond: Any, suite: Any, else_suite: Optional[Any] = None):
        """An if statement; a missing else branch becomes a suite holding a single `pass`."""
        cond = _canonize_single_argument(self.visit(cond))
        with self.local_variable_guard():
            suite = self.visit(suite)
        if else_suite is None:
            else_suite = Suite((FunctionCall('pass', ArgumentsList(tuple())), ))
        else:
            with self.local_variable_guard():
                else_suite = self.visit(else_suite)
        return FunctionCall('if', ArgumentsList((cond, suite, else_suite)))

    @inline_args
    def forall_stmt(self, cs_list: Any, suite: Any):
        cs_list = self.visit(cs_list)
        with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
            suite = self.visit(suite)
        return FunctionCall('forall', ArgumentsList((cs_list, suite)))

    @inline_args
    def forall_in_stmt(self, variables_cs_list: Any, values_cs_list: Any, suite: Any):
        """A `forall x in value` statement: each variable is bound as a local to its value while visiting the body.

        Raises:
            ValueError: if the numbers of variables and values differ. Previously the extra
                entries were silently ignored.
        """
        variables_cs_list = self.visit(variables_cs_list)
        values_cs_list = self.visit(values_cs_list)
        if len(variables_cs_list.items) != len(values_cs_list.items):
            raise ValueError(f'Number of variables does not match the number of values in a forall_in statement: {len(variables_cs_list.items)} vs {len(values_cs_list.items)}.')
        with self.local_variable_guard():
            for variable_item, value_item in zip(variables_cs_list.items, values_cs_list.items):
                self.local_variables[variable_item] = value_item
            suite = self.visit(suite)
        return FunctionCall('forall_in', ArgumentsList((suite, )))

    def _quantification_expression(self, cs_list: Any, suite: Any, quantification_cls):
        """Build nested quantifiers (outermost = first variable) around the combined return expression of ``suite``."""
        cs_list = self.visit(cs_list)
        with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
            suite = self.visit(suite)
            body = suite.get_combined_return_expression(allow_expr_expressions=True)

        for item in reversed(cs_list.items):
            body = quantification_cls(item, body)
        return body

    @inline_args
    def forall_test(self, cs_list: Any, suite: Any):
        return self._quantification_expression(cs_list, suite, E.ForallExpression)

    @inline_args
    def exists_test(self, cs_list: Any, suite: Any):
        return self._quantification_expression(cs_list, suite, E.ExistsExpression)

    @inline_args
    def findall_test(self, variable: Variable, suite: Any):
        with self.expression_def_ctx.new_variables(variable), self.local_variable_guard():
            suite = self.visit(suite)
            body = suite.get_combined_return_expression(allow_expr_expressions=True)
        return E.FindAllExpression(variable, body)

    def _quantification_in_expression(self, cs_list: Any, suite: Any, quantification_cls):
        """Bind each `name in value` item as a local, then wrap the combined return expression of ``suite`` with ``quantification_cls``."""
        cs_list = self.visit(cs_list)
        with self.local_variable_guard():
            item: InTypedArgument
            for item in cs_list.items:
                self.local_variables[item.name] = self.visit(item.value)
            suite = self.visit(suite)
            body = suite.get_combined_return_expression(allow_expr_expressions=True)
        return quantification_cls(body)

    @inline_args
    def forall_in_test(self, cs_list: Any, suite: Any):
        return self._quantification_in_expression(cs_list, suite, E.AndExpression)

    @inline_args
    def exists_in_test(self, cs_list: Any, suite: Any):
        return self._quantification_in_expression(cs_list, suite, E.OrExpression)

    @inline_args
    def bind_stmt(self, cs_list: Any, suite: Any):
        """A `bind x: T where: ...` statement.

        The bound variables become visible as locals after the statement, which is why they are
        registered after the guarded scope exits.
        """
        cs_list = self.visit(cs_list)
        with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
            suite = self.visit(suite)
            body = suite.get_combined_return_expression(allow_expr_expressions=True, allow_multiple_expressions=True)
        if isinstance(body, list):
            body = E.AndExpression(*body)

        for item in cs_list.items:
            self.local_variables[item.name] = E.VariableExpression(item)
        return FunctionCall('bind', ArgumentsList((cs_list, body)))

    @inline_args
    def bind_stmt_no_where(self, cs_list: Any):
        """Captures bind statements without a body. For example:

        .. code-block:: python

            bind x: int, y: int
        """
        cs_list = self.visit(cs_list)
        for item in cs_list.items:
            self.local_variables[item.name] = E.VariableExpression(item)
        return FunctionCall('bind', ArgumentsList((cs_list, E.NullExpression(BOOL))))

    @inline_args
    def annotated_compound_stmt(self, annotations: dict, stmt: Any):
        stmt = self.visit(stmt)
        stmt.annotations = annotations
        return stmt
@dataclass
class ArgumentsDef:
    """The argument list of a definition, as a tuple of :class:`Variable`."""

    arguments: Tuple[Variable, ...]
@dataclass
class PreconditionPart:
    """Holds the not-yet-interpreted precondition section of an action or generator definition."""

    suite: Tree
@dataclass
class EffectPart:
    """Holds the not-yet-interpreted effect section of an action definition."""

    suite: Tree
@dataclass
class GoalPart:
    """Holds the not-yet-interpreted goal section of an action or generator definition."""

    suite: Tree
@dataclass
class BodyPart:
    """Holds the not-yet-interpreted body section of an action definition."""

    suite: Tree
@dataclass
class SideEffectPart:
    """Holds a not-yet-interpreted side-effect section.

    NOTE(review): not produced or consumed anywhere in this part of the file; confirm where it is used.
    """

    suite: Tree
@dataclass
class ImplPart:
    """Holds the implementation section of a generator definition (currently skipped by the domain transformer)."""

    suite: Tree
@dataclass
class InPart:
    """Holds the not-yet-interpreted input list of a generator definition."""

    suite: Tree
@dataclass
class OutPart:
    """Holds the not-yet-interpreted output list of a generator definition."""

    suite: Tree
class PDSketchV3DomainTransformer(PDSketch3LiteralTransformer):
    """Turns the parse tree of a domain definition into definitions on a :class:`CrowDomain`.

    Definition bodies (suites) are kept as trees by the section handlers. They are interpreted
    with :class:`PDSketch3ExpressionInterpreter` while the definition's arguments are in scope.
    """

    def __init__(self, domain: Optional[CrowDomain] = None):
        """Initialize the transformer.

        Args:
            domain: the domain to extend; a fresh :class:`CrowDomain` is created if None.
        """
        super().__init__()
        self.domain = CrowDomain() if domain is None else domain
        self.expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
        self.expression_interpreter = PDSketch3ExpressionInterpreter(domain=self.domain, state=None, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=False)

    domain: CrowDomain
    expression_def_ctx: E.ExpressionDefinitionContext
    expression_interpreter: PDSketch3ExpressionInterpreter

    @inline_args
    def pragma_definition(self, pragma: Dict[str, Any]):
        print('pragma_definition', pragma)

    @inline_args
    def type_definition(self, typename, basetype: Optional[Union[str, TypeBase]]):
        print(f'type_definition:: {typename=} {basetype=}')
        self.domain.define_type(typename, basetype)

    @inline_args
    def feature_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
        """Define a feature. The return type defaults to BOOL; a body, if present, becomes the derived expression."""
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        if ret is None:
            ret = BOOL
        elif isinstance(ret, str):
            ret = self.domain.get_type(ret)

        return_stmt = None
        if suite is not None:
            with self.expression_def_ctx.with_variables(*args.arguments):
                suite = self.expression_interpreter.visit(suite)
                return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False)

        self.domain.define_feature(name, args.arguments, ret, derived_expression=return_stmt, **annotations)
        print(f'feature_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
        if return_stmt is not None:
            print(jacinle.indent_text(f'Return statement: {return_stmt}'))

    @inline_args
    def function_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
        """Define a function. The return type defaults to BOOL; a body, if present, becomes the derived expression."""
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        if ret is None:
            ret = BOOL
        elif isinstance(ret, str):
            ret = self.domain.get_type(ret)

        return_stmt = None
        if suite is not None:
            with self.expression_def_ctx.with_variables(*args.arguments):
                suite = self.expression_interpreter.visit(suite)
                return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False)

        self.domain.define_crow_function(name, args.arguments, ret, derived_expression=return_stmt, **annotations)
        print(f'function_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
        if return_stmt is not None:
            print(jacinle.indent_text(f'Return statement: {return_stmt}'))

    @inline_args
    def controller_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef]):
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        self.domain.define_controller(name, args.arguments, **annotations)
        print(f'controller_definition:: {name=} {args.arguments=}')

    @inline_args
    def action_precondition_definition(self, suite: Tree) -> PreconditionPart:
        return PreconditionPart(suite)

    @inline_args
    def action_effect_definition(self, suite: Tree) -> EffectPart:
        return EffectPart(suite)

    @inline_args
    def action_goal_definition(self, suite: Tree) -> GoalPart:
        return GoalPart(suite)

    @inline_args
    def action_body_definition(self, suite: Tree) -> BodyPart:
        return BodyPart(suite)

    @inline_args
    def action_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, EffectPart, BodyPart]):
        """Define an action from its precondition / goal / body / effect sections.

        Effects are interpreted with the local variables recorded by the body section, if one was
        already seen. With several goal sections, the action is defined with a null goal, plus one
        ``{name}_{i}`` action per goal.

        Raises:
            NotImplementedError: if a goal section contains more than one item.
            ValueError: on an unknown section type.
        """
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())

        print(f'action_definition:: {name=} {args.arguments=} {annotations=}')

        precondition = list()
        goals = list()
        body = list()
        effect = list()
        local_variables = None  # set by the body section; effects may refer to these locals.
        for part in parts:
            with self.expression_def_ctx.with_variables(*args.arguments):
                if isinstance(part, EffectPart):
                    self.expression_interpreter.local_variables = local_variables if local_variables is not None else dict()
                    try:
                        suite = self.expression_interpreter.visit(part.suite)
                    finally:
                        # Do not leak the body's locals into the shared interpreter, even if parsing fails.
                        self.expression_interpreter.local_variables = dict()
                else:
                    suite = self.expression_interpreter.visit(part.suite)

                if isinstance(part, BodyPart):
                    print(jacinle.indent_text(f'Body: {suite}'))

                if isinstance(part, PreconditionPart):
                    suite = suite.get_all_check_expressions()
                    precondition = [CrowPrecondition(x) for x in suite]
                    print(jacinle.indent_text(f'Precondition: {precondition}'))
                elif isinstance(part, GoalPart):
                    if len(suite.items) != 1:
                        raise NotImplementedError('Multiple goals are not supported in the current version.')
                    goal = suite.items[0]
                    goals.append(goal)
                    print(jacinle.indent_text(f'Goal: {goal}'))
                elif isinstance(part, BodyPart):
                    local_variables = suite.local_variables
                    body = suite.get_all_action_expressions(use_runtime_assign=True)
                    print(jacinle.indent_text('Body:'))
                    # Normalize the body into a single top-level sequential ordering suite.
                    if len(body) == 0:
                        body = CrowActionOrderingSuite('sequential', tuple())
                    elif len(body) == 1:
                        if not isinstance(body[0], CrowActionOrderingSuite) or body[0].ordering_op != 'sequential':
                            body = CrowActionOrderingSuite('sequential', (body[0],))
                    else:
                        body = CrowActionOrderingSuite('sequential', tuple(body))
                    print(jacinle.indent_text(str(body), level=2))
                elif isinstance(part, EffectPart):
                    suite = suite.get_all_assign_expressions()
                    effect = [CrowEffect(x, **a) for x, a in suite]
                    print(jacinle.indent_text(f'Effect: {effect}'))
                else:
                    raise ValueError(f'Invalid part: {part}')

        if len(goals) == 0:
            goal = E.NullExpression(BOOL)
            self.domain.define_action(name, args.arguments, goal, body, precondition, effect, **annotations)
        elif len(goals) == 1:
            self.domain.define_action(name, args.arguments, goals[0], body, precondition, effect, **annotations)
        else:
            goal = E.NullExpression(BOOL)
            self.domain.define_action(name, args.arguments, goal, body, precondition, effect, **annotations)
            for i, goal in enumerate(goals):
                self.domain.define_action(f'{name}_{i}', args.arguments, goal, body, precondition, effect, **annotations)

    @inline_args
    def generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, InPart, OutPart]):
        """Define a generator from its goal / in / out sections; implementation sections are skipped.

        Raises:
            ValueError: on any other section type (including preconditions).
        """
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())

        print(f'generator_definition:: {name=} {args.arguments=} {annotations=}')

        inputs = None
        outputs = None
        goal = None
        with self.expression_def_ctx.with_variables(*args.arguments):
            for part in parts:
                if isinstance(part, ImplPart):
                    continue

                suite = self.expression_interpreter.visit(part.suite)
                if isinstance(part, GoalPart):
                    goal = suite.get_all_expr_expression(allow_multiple_expressions=False)
                    print(jacinle.indent_text(f'Goal: {goal}'))
                elif isinstance(part, InPart):
                    inputs = [E.VariableExpression(x) for x in suite.items]
                    print(jacinle.indent_text(f'Inputs: {inputs}'))
                elif isinstance(part, OutPart):
                    outputs = [E.VariableExpression(x) for x in suite.items]
                    print(jacinle.indent_text(f'Outputs: {outputs}'))
                else:
                    raise ValueError(f'Invalid part: {part}')

        self.domain.define_generator(name, args.arguments, goal, inputs, outputs, **annotations)

    @inline_args
    def generator_precondition_definition(self, suite: Tree) -> PreconditionPart:
        return PreconditionPart(suite)

    @inline_args
    def generator_goal_definition(self, suite: Tree) -> GoalPart:
        return GoalPart(suite)

    @inline_args
    def generator_in_definition(self, values: Tree) -> InPart:
        return InPart(values)

    @inline_args
    def generator_out_definition(self, values: Tree) -> OutPart:
        return OutPart(values)

    @inline_args
    def generator_impl_definition(self, suite: Tree) -> ImplPart:
        return ImplPart(suite)
class PDSketchV3ProblemTransformer(PDSketch3LiteralTransformer):
    """Turns the parse tree of a problem definition into a :class:`CrowProblem`.

    The domain can be given up front or loaded from a ``domain`` declaration inside the problem.
    """

    def __init__(self, domain: Optional[CrowDomain] = None, state: Optional[CrowState] = None, auto_constant_guess: bool = False):
        """Initialize the transformer.

        Args:
            domain: the domain of the problem. If None, it must be declared in the problem itself.
            state: an optional state used to resolve object names; only used when ``domain`` is given.
            auto_constant_guess: forwarded to the expression interpreter.
        """
        super().__init__()
        self.domain = None
        self.state = None
        self.problem = None
        self.expression_def_ctx = None
        self.expression_interpreter = None
        self.auto_constant_guess = auto_constant_guess

        if domain is not None:
            self._init_domain(domain, state)

    domain: Optional[CrowDomain]
    state: Optional[CrowState]

    def _init_domain(self, domain: CrowDomain, state: Optional[CrowState] = None):
        """Bind the transformer to ``domain`` and create the problem, definition context and interpreter.

        Raises:
            ValueError: if a domain has already been set.
        """
        if self.domain is not None:
            raise ValueError('Domain is already initialized. Cannot overwrite the domain.')
        self.domain = domain
        # Keep `self.state` consistent with the state handed to the interpreter (previously it stayed None).
        self.state = state
        self.problem = CrowProblem(domain=self.domain)
        self.expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
        self.expression_interpreter = PDSketch3ExpressionInterpreter(domain=self.domain, state=state, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=self.auto_constant_guess)

    @inline_args
    def domain_def(self, filename: str):
        """Load the domain declared in the problem, unless one was already provided."""
        if self.domain is not None:
            logger.warning('Domain is already initialized. Skip the in-place domain loading.')
            return
        domain = _parser.parse_domain(filename)
        self._init_domain(domain)

    @inline_args
    def objects_definition(self, *objects):
        """Add each object to the problem and make its name resolvable as an object constant."""
        for o in objects:
            self.problem.add_object(o.name, o.dtype.typename)
            self.expression_interpreter.local_variables[o.name] = E.ObjectConstantExpression(ObjectConstant(o.name, o.dtype))

    @inline_args
    def init_definition(self, suite: Tree):
        """Initialize the problem state from the ``init`` section.

        Assignments are executed on the state. Bare function applications are treated as
        ``f(...) = True``; other expression statements are ignored.
        """
        self.problem.init_state()
        suite = self.expression_interpreter.visit(suite)
        executor = self.domain.make_executor()
        for stmt, _ in suite.get_all_assign_expressions():
            executor.execute(stmt, state=self.problem.state)
        for stmt in suite.get_all_expr_expression(allow_multiple_expressions=True):
            if isinstance(stmt, E.FunctionApplicationExpression):
                executor.execute(E.AssignExpression(stmt, E.ConstantExpression.TRUE), state=self.problem.state)

    @inline_args
    def goal_definition(self, suite: Tree):
        suite = self.expression_interpreter.visit(suite)
        suite = suite.get_all_expr_expression(allow_multiple_expressions=False)
        self.problem.set_goal(suite)
# Module-level parser instance shared by the load_* / parse_expression helpers and by
# PDSketchV3ProblemTransformer.domain_def (which calls _parser.parse_domain).
# NOTE(review): a single shared instance; confirm PDSketchV3Parser is safe to reuse across calls.
_parser = PDSketchV3Parser()
def load_domain_file(filename: str) -> CrowDomain:
    """Parse a Crow domain from a file.

    Args:
        filename: the domain file to parse.

    Returns:
        The parsed domain.
    """
    domain = _parser.parse_domain(filename)
    return domain
def load_domain_string(string: str) -> CrowDomain:
    """Parse a Crow domain from its source text.

    Args:
        string: the domain definition text.

    Returns:
        The parsed domain.
    """
    domain = _parser.parse_domain_str(string)
    return domain
def load_domain_string_incremental(domain: CrowDomain, string: str) -> CrowDomain:
    """Extend an existing domain with definitions parsed from source text.

    Args:
        domain: the domain to extend.
        string: the additional domain definition text.

    Returns:
        The domain returned by the parser (presumably the same, updated ``domain`` instance).
    """
    updated = _parser.parse_domain_str(string, domain=domain)
    return updated
def load_problem_file(filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Parse a Crow problem from a file.

    Args:
        filename: the problem file to parse.
        domain: the domain the problem refers to. When omitted, the domain declared inside the problem
            definition is loaded instead.

    Returns:
        The parsed problem.
    """
    problem = _parser.parse_problem(filename, domain=domain)
    return problem
def load_problem_string(string: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Parse a Crow problem from its source text.

    Args:
        string: the problem definition text.
        domain: the domain the problem refers to. When omitted, the domain declared inside the problem
            string is loaded instead.

    Returns:
        The parsed problem.
    """
    problem = _parser.parse_problem_str(string, domain=domain)
    return problem
def parse_expression(domain: CrowDomain, string: str, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
    """Parse a single expression against a domain.

    Args:
        domain: the domain providing types and functions.
        string: the expression text.
        state: an optional current state, which supplies objects.
        variables: optional variables in scope for the expression.
        auto_constant_guess: whether the parser may guess that an unknown name refers to a constant.

    Returns:
        The parsed expression.
    """
    expr = _parser.parse_expression(
        string,
        domain,
        state=state,
        variables=variables,
        auto_constant_guess=auto_constant_guess,
    )
    return expr