#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : cdl_parser.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/10/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import os
import os.path as osp
import contextlib
import jacinle
from typing import Any, Optional, Union, Sequence, Tuple, Set, List, Dict
from dataclasses import dataclass, field
from lark import Lark, Tree
from lark.visitors import Transformer, Interpreter, v_args
from lark.indenter import PythonIndenter
import concepts.dsl.expression as E
from concepts.dsl.dsl_types import TypeBase, AutoType, ListType, BOOL, Variable, ObjectConstant, UnnamedPlaceholder, QINDEX
from concepts.dsl.dsl_functions import FunctionType
from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.tensor_state import StateObjectList
from concepts.dm.crow.crow_function import CrowFunction
from concepts.dm.crow.controller import CrowControllerApplicationExpression
from concepts.dm.crow.behavior import CrowBehaviorBodyItem
from concepts.dm.crow.behavior import CrowBehaviorApplicationExpression
from concepts.dm.crow.behavior import CrowBehaviorOrderingSuite
from concepts.dm.crow.crow_generator import CrowGeneratorApplicationExpression
from concepts.dm.crow.crow_domain import CrowDomain, CrowProblem, CrowState
from concepts.dm.crow.behavior_utils import execute_effect_statements
from concepts.dm.crow.parsers.cdl_literal_parser import InTypedArgument, ArgumentsDef, CSList, CDLLiteralTransformer
from concepts.dm.crow.parsers.cdl_symbolic_execution import ArgumentsList, FunctionCall, Suite
# Module-level logger.
logger = jacinle.get_logger(__name__)
# Decorator for lark handlers: with inline=True the children of a rule are passed to the handler
# as positional arguments (``def handler(self, child1, child2, ...)``) instead of as one list.
inline_args = v_args(inline=True)
# Public API of this module.
__all__ = [
    'CDLPathResolver', 'get_default_path_resolver',
    'CDLParser', 'get_default_parser',
    'load_domain_file', 'load_domain_string', 'load_domain_string_incremental',
    'load_problem_file', 'load_problem_string',
    'parse_expression',
    'CDLDomainTransformer', 'CDLProblemTransformer', 'CDLExpressionInterpreter',
]
[docs]
class CDLPathResolver(object):
    """Locate CDL source files by name.

    :meth:`resolve` checks the following candidates in order and returns the first one for which
    ``os.path.exists`` is true:

    1. ``filename`` as given (absolute, or relative to the current working directory);
    2. ``filename`` inside the directory of ``relative_filename``, when one is given;
    3. ``filename`` joined onto ``os.getcwd()`` (normally the same location as step 1);
    4. ``filename`` inside each registered search path, in registration order.
    """

    def __init__(self, search_paths: Sequence[str] = tuple()):
        """Create a resolver.

        Args:
            search_paths: initial directories to search; copied into a fresh list.
        """
        self.search_paths = [path for path in search_paths]

    def resolve(self, filename: str, relative_filename: Optional[str] = None) -> str:
        """Return the first existing candidate path for ``filename``.

        Args:
            filename: the file name (or path) to look up.
            relative_filename: optional path of a referencing file; its directory is searched.

        Returns:
            the resolved path. When ``filename`` itself exists it is returned unchanged.

        Raises:
            FileNotFoundError: if no candidate exists.
        """
        def candidates():
            # Lazily produce candidates so later ones are only computed when earlier ones miss.
            yield filename
            if relative_filename is not None:
                yield osp.join(osp.dirname(relative_filename), filename)
            yield osp.join(os.getcwd(), filename)
            for directory in self.search_paths:
                yield osp.join(directory, filename)

        for candidate in candidates():
            if osp.exists(candidate):
                return candidate
        raise FileNotFoundError(f'File not found: {filename}')

    def add_search_path(self, path: str):
        """Append ``path`` to the search paths unless it is already registered."""
        if path in self.search_paths:
            return
        self.search_paths.append(path)

    def remove_search_path(self, path: str):
        """Remove ``path`` from the search paths.

        Raises:
            ValueError: if ``path`` is not registered (propagated from ``list.remove``).
        """
        self.search_paths.remove(path)
[docs]
class CDLParser(object):
    """The parser for PDSketch v3 (CDL) source files.

    Wraps a LALR :class:`lark.Lark` parser compiled from ``cdl.grammar`` (located next to this
    module), and provides helpers that parse files or strings and transform the parse trees.
    """

    grammar_file = osp.join(osp.dirname(osp.abspath(__file__)), 'cdl.grammar')
    """The grammar definition v3 for PDSketch."""

    def __init__(self):
        """Initialize the parser by reading :attr:`grammar_file` and compiling it with lark.

        The grammar is indentation-sensitive, so :class:`lark.indenter.PythonIndenter` is used as
        the post-lexer.
        """
        # Read as UTF-8 explicitly; otherwise the locale default (e.g. cp1252 on Windows) is used
        # and non-ASCII characters in the grammar can fail to decode.
        with open(self.grammar_file, 'r', encoding='utf-8') as f:
            self.grammar = f.read()
        self.parser = Lark(self.grammar, start='start', postlex=PythonIndenter(), parser='lalr')

    def parse(self, filename: str) -> Tree:
        """Parse a PDSketch v3 file.

        Args:
            filename: the filename to parse. It is located with the module-level default path resolver.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.

        Raises:
            FileNotFoundError: if the file cannot be resolved.
        """
        filename = _DEFAULT_PATH_RESOLVER.resolve(filename)
        # Same encoding rationale as in __init__: CDL sources are decoded as UTF-8 on every platform.
        with open(filename, 'r', encoding='utf-8') as f:
            return self.parse_str(f.read())

    def parse_str(self, s: str) -> Tree:
        """Parse a PDSketch v3 string.

        Args:
            s: the string to parse.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        # NB(Jiayuan Mao @ 2024/03/13): for reasons, the pdsketch-v3 grammar file requires that the string ends with a newline.
        # In particular, the suite definition requires that the file ends with a _DEDENT token, which seems to be only triggered by a newline.
        s = s.strip() + '\n'
        # Drop a leading shebang line; `#!pragma` lines are kept and passed on to the grammar.
        if s.startswith('#!') and not s.startswith('#!pragma'):
            s = s[s.find('\n') + 1:]
        return self.parser.parse(s)

    def parse_domain(self, filename: str) -> CrowDomain:
        """Parse a PDSketch v3 domain file.

        Args:
            filename: the filename to parse.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse(filename))

    def parse_domain_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Parse a PDSketch v3 domain string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse_str(s), domain=domain)

    def parse_problem(self, filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem file.

        Args:
            filename: the filename to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse(filename), domain=domain)

    def parse_problem_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse_str(s), domain=domain)

    def parse_expression(self, s: str, domain: CrowDomain, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
        """Parse a PDSketch v3 expression string.

        Args:
            s: the string to parse.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        return self.transform_expression(self.parse_str(s), domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)

    @staticmethod
    def transform_domain(tree: Tree, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Transform a parse tree into a domain.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain (the ``domain`` attribute of the transformer after the pass).
        """
        transformer = CDLDomainTransformer(domain)
        transformer.transform(tree)
        return transformer.domain
# Shared path resolver; CDLParser.parse uses it to locate files by name.
_DEFAULT_PATH_RESOLVER = CDLPathResolver()
# Shared parser instance, created lazily by get_default_parser().
_DEFAULT_PARSER = None
# Verbosity flag, toggled by set_parser_verbose().
_PARSER_VERBOSE = False
[docs]
def get_default_path_resolver() -> CDLPathResolver:
    """Return the module-wide path resolver used by :meth:`CDLParser.parse`.

    Extra search directories can be registered on the returned object with ``add_search_path``.
    """
    # Reading a module global needs no `global` declaration.
    return _DEFAULT_PATH_RESOLVER
[docs]
def get_default_parser() -> CDLParser:
    """Return the shared :class:`CDLParser`, constructing it on first use."""
    global _DEFAULT_PARSER
    if _DEFAULT_PARSER is not None:
        return _DEFAULT_PARSER
    _DEFAULT_PARSER = CDLParser()
    return _DEFAULT_PARSER
[docs]
def set_parser_verbose(verbose: bool = True):
    """Set the module-level ``_PARSER_VERBOSE`` flag.

    Args:
        verbose: the new value of the flag (defaults to True).
    """
    global _PARSER_VERBOSE
    _PARSER_VERBOSE = verbose
[docs]
def load_domain_file(filename: str) -> CrowDomain:
    """Parse a CDL domain file with the shared default parser.

    Args:
        filename: the domain file; located through the default path resolver.

    Returns:
        the parsed domain object.
    """
    parser = get_default_parser()
    return parser.parse_domain(filename)
[docs]
def load_domain_string(string: str) -> CrowDomain:
    """Parse a CDL domain definition held in a string into a new domain.

    Args:
        string: the CDL source of the domain.

    Returns:
        the parsed domain object.
    """
    parser = get_default_parser()
    return parser.parse_domain_str(string)
[docs]
def load_domain_string_incremental(domain: CrowDomain, string: str) -> CrowDomain:
    """Parse additional CDL definitions from a string into an existing domain.

    Args:
        domain: the domain object to be extended.
        string: the CDL source with the additional definitions.

    Returns:
        the domain object produced by the domain transformer (presumably ``domain`` itself,
        updated in place; confirm in CDLDomainTransformer).
    """
    parser = get_default_parser()
    return parser.parse_domain_str(string, domain=domain)
[docs]
def load_problem_file(filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Parse a CDL problem file with the shared default parser.

    Args:
        filename: the problem file; located through the default path resolver.
        domain: the domain to use. If omitted, the problem transformer is expected to load the
            domain referenced by the problem file.

    Returns:
        the parsed problem object.
    """
    parser = get_default_parser()
    return parser.parse_problem(filename, domain=domain)
[docs]
def load_problem_string(string: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Parse a CDL problem definition held in a string.

    Args:
        string: the CDL source of the problem.
        domain: the domain to use. If omitted, the problem transformer is expected to load the
            domain referenced by the problem source.

    Returns:
        the parsed problem object.
    """
    parser = get_default_parser()
    return parser.parse_problem_str(string, domain=domain)
[docs]
def parse_expression(domain: CrowDomain, string: str, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
    """Parse a single CDL expression with the shared default parser.

    Note the argument order: ``domain`` comes first here, whereas
    :meth:`CDLParser.parse_expression` takes the string first.

    Args:
        domain: the domain used to resolve names.
        string: the expression source.
        state: the current state; its object names are recognized as object constants.
        variables: variables from the outer scope.
        auto_constant_guess: whether unknown names may be guessed to be object constants.

    Returns:
        the parsed expression.
    """
    parser = get_default_parser()
    return parser.parse_expression(string, domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)
# Maps binary arithmetic operator symbols to the suffix of the per-type function that implements
# them: `a <op> b` on operands of type T becomes a call to `type::<T>::<suffix>`.
# '%' maps to 'mod', the same suffix that compound `%=` assignments already use.
g_term_op_mapping = {
    '*': 'mul',
    '/': 'div',
    '+': 'add',
    '-': 'sub',
    '%': 'mod',
}
def _gen_term_expr_func(expr_typename: str):
    """Generate the handler for a left-associative binary arithmetic rule.

    It is used for the following rules:

    - mul_expr
    - arith_expr
    - shift_expr

    The handler expects its children to alternate operands and operator tokens
    (``v0 op1 v1 op2 v2 ...``). It folds them left to right into nested
    ``type::<T>::<op>`` function applications, where ``T`` is the (element) type of the left operand.

    Args:
        expr_typename: the name of the expression type; only used in error messages.

    Returns:
        the generated handler function.
    """
    @inline_args
    def term(self, *values: Any):
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]
        assert len(values) % 2 == 1, f'[{expr_typename}] expressions expected an odd number of values, got {len(values)}. Values: {values}.'
        result = values[0]
        for i in range(1, len(values), 2):
            v1, v2 = _canonize_arguments_same_dtype([result, values[i + 1]])
            t = v1.return_type
            # Sequence operands are combined with the element-wise function of their element type.
            if t.is_uniform_sequence_type:
                t = t.element_type
            op = values[i]
            op_name = g_term_op_mapping.get(op)
            if op_name is None:
                # E.g. shift operators: report an unsupported operator instead of a bare KeyError.
                raise NotImplementedError(f'[{expr_typename}] operator {op} is not supported in the current version.')
            fname = f'type::{t.typename}::{op_name}'
            result = E.FunctionApplicationExpression(CrowFunction(fname, FunctionType([t, t], t)), [v1, v2])
        return result
    return term
def _gen_bitop_expr_func(expr_typename: str):
    """Generate the handler for an n-ary bitwise rule, interpreted as a boolean operator.

    It is used for the following rules:

    - bitand_expr (boolean AND)
    - bitxor_expr (boolean XOR)
    - bitor_expr (boolean OR)

    The operator tokens are not among the children passed to the handler, so the boolean operator
    is chosen from ``expr_typename`` instead.

    Args:
        expr_typename: one of ``'bitand'``, ``'bitxor'``, ``'bitor'``.

    Returns:
        the generated handler function.
    """
    bool_ops = {'bitand': E.BoolOpType.AND, 'bitxor': E.BoolOpType.XOR, 'bitor': E.BoolOpType.OR}

    @inline_args
    def term(self, *values: Any):
        operands = [self.visit(child) for child in values]
        if len(operands) == 1:
            # A single operand means the operator was not used at all.
            return operands[0]
        bool_op = bool_ops[expr_typename]
        return E.BoolExpression(bool_op, _canonize_arguments(operands))

    return term
def _canonize_single_argument(arg: Any, dtype: Optional[TypeBase] = None) -> Union[E.ObjectOrValueOutputExpression, E.VariableExpression, type(Ellipsis)]:
    """Convert one parsed argument into an expression, or pass it through.

    - Expressions and controller/behavior/generator application expressions are returned unchanged.
    - A :class:`Variable` is wrapped in :class:`E.VariableExpression`.
    - Python literals (bool, int, float, complex, str) become constant expressions typed with ``dtype``.
    - List-creation expressions are returned unchanged.
    - ``QINDEX`` becomes an object constant holding an ``AutoType`` list.
    - ``Ellipsis`` is returned as-is.

    Raises:
        ValueError: for any other input.
    """
    passthrough_types = (E.ObjectOrValueOutputExpression, CrowControllerApplicationExpression, CrowBehaviorApplicationExpression, CrowGeneratorApplicationExpression)
    if isinstance(arg, passthrough_types):
        return arg
    if isinstance(arg, Variable):
        return E.VariableExpression(arg)
    if isinstance(arg, (bool, int, float, complex, str)):
        return E.ConstantExpression.from_value(arg, dtype=dtype)
    if isinstance(arg, E.ListCreationExpression):
        return arg
    if arg is QINDEX:
        list_type = ListType(AutoType)
        return E.ObjectConstantExpression(ObjectConstant(StateObjectList(ListType(AutoType), QINDEX), list_type))
    if arg is Ellipsis:
        return Ellipsis
    raise ValueError(f'Invalid argument: {arg}. Type: {type(arg)}.')
def _canonize_arguments_same_dtype(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtype: Optional[TypeBase] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]:
    """Canonize arguments that should all share one dtype.

    When ``dtype`` is not given, it is taken from the first argument that is already an
    expression (or left as None if there is none). It is then used for every literal argument.

    Args:
        args: the arguments, as an :class:`ArgumentsList`, tuple or list; None means no arguments.
        dtype: the dtype to use for literals.

    Returns:
        a tuple of canonized arguments.
    """
    if args is None:
        return tuple()
    if isinstance(args, ArgumentsList):
        args = args.arguments
    if dtype is None:
        dtype = next((a.return_type for a in args if isinstance(a, E.ObjectOrValueOutputExpression)), None)
    return tuple(_canonize_single_argument(a, dtype=dtype) for a in args)
def _canonize_arguments(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtypes: Optional[Union[Sequence[TypeBase], TypeBase]] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]:
    """Canonize a list of arguments, optionally with a dtype for each argument.

    Args:
        args: the arguments, as an :class:`ArgumentsList`, tuple or list; None means no arguments.
        dtypes: either one dtype applied to every argument, or a sequence with one dtype per
            argument, or None.

    Returns:
        a tuple of canonized arguments; ``Ellipsis`` entries are kept as-is.

    Raises:
        ValueError: if a dtype sequence does not have one entry per argument, or if an argument
            cannot be canonized.
    """
    if args is None:
        return tuple()
    # Unwrap first so the broadcast below measures the real argument tuple; the ArgumentsList
    # wrapper itself may not support len() (NOTE(review): __len__ is not defined in this file).
    arguments = args.arguments if isinstance(args, ArgumentsList) else args
    if isinstance(dtypes, TypeBase):
        dtypes = [dtypes] * len(arguments)
    # TODO(Jiayuan Mao @ 2024/03/2): Strictly check the allowability of "Ellipsis" in the arguments.
    if dtypes is not None and len(arguments) != len(dtypes):
        raise ValueError(f'Number of arguments does not match the number of types: {len(arguments)} vs {len(dtypes)}. Args: {arguments}, Types: {dtypes}')
    canonized_args = list()
    for i, arg in enumerate(arguments):
        if arg is Ellipsis:
            canonized_args.append(Ellipsis)
        else:
            canonized_args.append(_canonize_single_argument(arg, dtype=dtypes[i] if dtypes is not None else None))
    return tuple(canonized_args)
def _safe_is_value_type(arg: Any) -> bool:
    """Return whether ``arg`` denotes a value (as opposed to an object).

    Expressions report ``return_type.is_value_type``; Python literals always count as values.

    Raises:
        ValueError: for any other input (e.g. ``Ellipsis``).
    """
    is_expression = isinstance(arg, E.ObjectOrValueOutputExpression)
    if not is_expression and not isinstance(arg, (bool, int, float, complex, str)):
        raise ValueError(f'Invalid argument: {arg}. Type: {type(arg)}.')
    return arg.return_type.is_value_type if is_expression else True
def _has_list_arguments(args: Tuple[E.ObjectOrValueOutputExpression, ...]) -> bool:
    """Return True if at least one argument has a list return type."""
    return any(arg.return_type.is_list_type for arg in args)
[docs]
class CDLExpressionInterpreter(Interpreter):
    """The interpreter for CDL expressions and statements.

    This is a :class:`lark.visitors.Interpreter` subclass: children are not visited automatically,
    so each handler calls :meth:`visit` on the sub-trees it needs (:meth:`visit` returns non-Tree
    values unchanged). Handlers decorated with ``inline_args`` receive the rule's children as
    positional arguments.

    Grammar rules handled include:

    - typename
    - sized_vector_typename
    - unsized_vector_typename
    - typed_argument
    - is_typed_argument
    - in_typed_argument
    - arguments_def
    - arguments
    - atom_colon
    - atom_expr_funccall
    - atom_expr_do_funccall
    - atom_subscript
    - atom_varname
    - atom
    - power
    - factor
    - unary_op_expr
    - mul_expr
    - arith_expr
    - shift_expr
    - bitand_expr
    - bitxor_expr
    - bitor_expr
    - comparison_expr
    - not_test
    - and_test
    - or_test
    - cond_test
    - test
    - test_nocond
    - tuple
    - list
    - cs_list
    - suite
    - expr_stmt
    - expr_list_expansion_stmt
    - compound_expr_stmt
    - assign_stmt
    - annotated_assign_stmt
    - local_assign_stmt
    - let_assign_stmt
    - symbol_assign_stmt
    - pass_stmt
    - commit_stmt
    - return_stmt
    - achieve_stmt
    - achieve_once_stmt / achieve_hold_stmt
    - pachieve_once_stmt / pachieve_hold_stmt
    - untrack_stmt
    - assert_once_stmt / assert_hold_stmt
    - if_stmt
    - foreach_stmt
    - foreach_in_stmt
    - while_stmt
    - forall_test
    - exists_test
    - findall_test
    - forall_in_test
    - exists_in_test
    """
[docs]
def __init__(self, domain: Optional[CrowDomain], state: Optional[CrowState], expression_def_ctx: E.ExpressionDefinitionContext, auto_constant_guess: bool = False):
    """Create an interpreter.

    Args:
        domain: the domain used to resolve types, features, functions, controllers, behaviors and generators.
        state: an optional state whose object names are recognized as object constants.
        expression_def_ctx: the expression-definition context used to wrap free variables.
        auto_constant_guess: when True, unknown names that are not variables of
            ``expression_def_ctx`` become ``AutoType`` object constants.
    """
    super().__init__()
    # Per-instance scoped state.
    self.local_variables = {}  # name -> VariableExpression introduced by let/symbol assignments
    self.generator_impl_outputs = None  # managed by set_generator_impl_outputs()
    # Configuration.
    self.domain, self.state = domain, state
    self.expression_def_ctx = expression_def_ctx
    self.auto_constant_guess = auto_constant_guess
[docs]
def set_domain(self, domain: CrowDomain):
    """Replace the domain used for name resolution."""
    self.domain = domain
[docs]
@contextlib.contextmanager
def local_variable_guard(self):
    """Scope local-variable declarations to a ``with`` block.

    A copy of :attr:`local_variables` is taken on entry and reinstated on exit, discarding
    variables declared inside the block. The restore runs in ``finally`` so that it also happens
    when the block raises; previously an exception leaked the inner scope's variables into the
    enclosing scope.
    """
    backup = self.local_variables.copy()
    try:
        yield
    finally:
        self.local_variables = backup
[docs]
@contextlib.contextmanager
def set_generator_impl_outputs(self, outputs: List[Variable]):
    """Temporarily set :attr:`generator_impl_outputs` to ``outputs`` within a ``with`` block.

    The previous value is restored on exit. The restore runs in ``finally`` so that it also
    happens when the block raises.
    """
    backup = self.generator_impl_outputs
    self.generator_impl_outputs = outputs
    try:
        yield
    finally:
        self.generator_impl_outputs = backup
# Type declarations for attributes assigned in __init__ (``domain`` is also replaced by set_domain).
domain: CrowDomain
expression_def_ctx: E.ExpressionDefinitionContext
[docs]
def visit(self, tree: Any) -> Any:
    """Visit ``tree`` if it is a lark Tree; return anything else (tokens, literals, None,
    already-built expressions) unchanged, so handlers can call this on any child safely."""
    return super().visit(tree) if isinstance(tree, Tree) else tree
[docs]
@inline_args
def atom_colon(self):
    """Handle the ``atom_colon`` rule (presumably a bare ``:``, as in ``x[:]``) by returning the
    ``QINDEX`` sentinel.

    ``_canonize_single_argument`` later turns ``QINDEX`` into an object constant holding an
    ``AutoType`` list.
    """
    return QINDEX
[docs]
@inline_args
def atom_varname(self, name: str) -> Union[E.VariableExpression, E.ObjectConstantExpression, E.ValueOutputExpression]:
    """Resolve a bare identifier such as `var_name` into an expression.

    Resolution order:

    1. local variables;
    2. objects of the current state (if any);
    3. domain constants;
    4. zero-argument domain features;
    5. an ``AutoType`` object constant, if ``auto_constant_guess`` is set and the name is not a
       variable of the expression-definition context;
    6. otherwise, the expression-definition context wraps the name as a variable.
    """
    if name in self.local_variables:
        return self.local_variables[name]

    state = self.state
    if state is not None and name in state.object_name2defaultindex:
        typename = state.get_default_typename(name)
        return E.ObjectConstantExpression(ObjectConstant(name, self.domain.types[typename]))

    domain = self.domain
    if name in domain.constants:
        constant = domain.constants[name]
        return E.ObjectConstantExpression(constant) if isinstance(constant, ObjectConstant) else E.ConstantExpression(constant)

    if name in domain.features and domain.features[name].nr_arguments == 0:
        return E.FunctionApplicationExpression(domain.features[name], tuple())

    # TODO(Jiayuan Mao @ 2024/03/12): smartly guess the type of the variable.
    if not self.expression_def_ctx.has_variable(name) and self.auto_constant_guess:
        return E.ObjectConstantExpression(ObjectConstant(name, AutoType))
    return self.expression_def_ctx.wrap_variable(name)
[docs]
@inline_args
def atom_expr_do_funccall(self, name: str, annotations: dict, args: Tree) -> CrowBehaviorBodyItem:
    """Handle the ``atom_expr_do_funccall`` rule: an explicit controller invocation.

    Args:
        name: the controller name.
        annotations: the annotation subtree (or an already-built dict, or None).
        args: the argument subtree, or None.

    Returns:
        a :class:`CrowControllerApplicationExpression` carrying the annotations as keyword arguments.

    Raises:
        KeyError: if the domain has no controller called ``name``.
    """
    if not self.domain.has_controller(name):
        raise KeyError(f'Controller {name} not found.')
    controller = self.domain.get_controller(name)
    # Visit the annotations like atom_expr_funccall does; an unvisited Tree cannot be **-expanded.
    # visit() returns non-Tree values (e.g. an already-built dict) unchanged.
    annotations: Optional[dict] = self.visit(annotations)
    args: Optional[ArgumentsList] = self.visit(args)
    args_c = _canonize_arguments(args, controller.argument_types)
    call_kwargs = annotations if annotations is not None else dict()
    return CrowControllerApplicationExpression(controller, args_c, **call_kwargs)
[docs]
@inline_args
def atom_expr_funccall(self, name: str, annotations: dict, args: Tree) -> Union[E.FunctionApplicationExpression, CrowBehaviorBodyItem, CrowBehaviorApplicationExpression, CrowGeneratorApplicationExpression]:
    """Captures function calls, such as `func_name(arg1, arg2, ...)`.

    ``name`` is resolved against the domain in this order: type constructor, feature, function,
    controller, behavior, generator. If nothing matches, the ``inplace_behavior_body`` or
    ``inplace_generator`` annotation may define a new function on the fly.

    For behaviors, a trailing ``...`` argument fills every remaining parameter with an
    :class:`UnnamedPlaceholder` of the corresponding type.

    Args:
        name: the callee name.
        annotations: the annotation subtree (visited here), or None.
        args: the argument subtree (visited here), or None.

    Returns:
        the expression representing the call.

    Raises:
        TypeError: if a type constructor is applied to something other than a constant expression.
        ValueError: if the number of arguments does not match the callee's argument types.
        KeyError: if ``name`` cannot be resolved and no in-place definition is requested.
    """
    annotations: Optional[dict] = self.visit(annotations)
    args: Optional[ArgumentsList] = self.visit(args)
    if annotations is None:
        annotations = dict()
    if args is None:
        args = ArgumentsList(tuple())

    if self.domain.has_type(name):
        assert len(args.arguments) == 1, 'Type constructor expects exactly one argument.'
        arg0 = args.arguments[0]
        dtype = self.domain.get_type(name)
        if isinstance(arg0, E.ConstantExpression):
            constant = arg0.constant
            # An empty literal cast to a sequence-of-objects type becomes an empty object list.
            if (
                isinstance(constant, TensorValue) and constant.tensor.numel() == 0
                and dtype.is_uniform_sequence_type and dtype.element_type.is_object_type
            ):
                return E.ObjectConstantExpression(ObjectConstant(StateObjectList(dtype, []), dtype))
            return E.ConstantExpression(constant.clone(dtype=dtype))
        if isinstance(arg0, E.ObjectConstantExpression):
            return E.ObjectConstantExpression(arg0.constant.clone(dtype=dtype))
        # f-string: `name` and the arguments were previously not interpolated into the message.
        raise TypeError(f'Invalid type definition: {name}({args.arguments}).')

    if self.domain.has_feature(name):
        function = self.domain.get_feature(name)
        args_c = _canonize_arguments(args, function.ftype.argument_types)
        return E.FunctionApplicationExpression(function, args_c, **annotations)

    if self.domain.has_function(name):
        function = self.domain.get_function(name)
        args_c = _canonize_arguments(args, function.ftype.argument_types)
        return E.FunctionApplicationExpression(function, args_c, **annotations)

    if self.domain.has_controller(name):
        controller = self.domain.get_controller(name)
        args_c = _canonize_arguments(args, controller.argument_types)
        return CrowControllerApplicationExpression(controller, args_c)

    if self.domain.has_behavior(name):
        behavior = self.domain.get_behavior(name)
        raw_args = args.arguments
        if len(raw_args) > 0 and raw_args[-1] is Ellipsis:
            # Trailing `...`: only the explicit prefix is type-checked against the matching prefix
            # of the signature; every remaining parameter becomes an unnamed placeholder.
            nr_explicit = len(raw_args) - 1
            explicit = _canonize_arguments(raw_args[:nr_explicit], behavior.argument_types[:nr_explicit])
            placeholders = tuple(UnnamedPlaceholder(t) for t in behavior.argument_types[nr_explicit:])
            args_c = explicit + placeholders
        else:
            args_c = _canonize_arguments(args, behavior.argument_types)
        return CrowBehaviorApplicationExpression(behavior, args_c)

    if self.domain.has_generator(name):
        generator = self.domain.get_generator(name)
        args_c = _canonize_arguments(args, generator.argument_types)
        return CrowGeneratorApplicationExpression(generator, args_c, [])

    if annotations.get('inplace_behavior_body'):
        # In-place definition of a function returning `__behavior_body__`, e.g. used to define a
        # __totally_ordered_plan__ function.
        args_c = _canonize_arguments(args)
        argument_types = [arg.return_type for arg in args_c]
        predicate = self.domain.define_crow_function(name, argument_types, self.domain.get_type('__behavior_body__'))
        return E.FunctionApplicationExpression(predicate, args_c)

    if annotations.get('inplace_generator'):
        # In-place definition of a generator: define a BOOL predicate `name` and a generator
        # `gen_<name>` whose outputs are the call arguments named in `inplace_generator_targets`.
        # Typically used together with the `generator_placeholder` annotation.
        args_c = _canonize_arguments(args)
        argument_types = [arg.return_type for arg in args_c]
        predicate = self.domain.define_crow_function(
            name, argument_types, BOOL,
            generator_placeholder=annotations.get('generator_placeholder', True)
        )
        assert 'inplace_generator_targets' in annotations, f'Inplace generator {name} requires inplace generator targets to be set.'
        inplace_generator_targets = annotations['inplace_generator_targets']
        generator_name = 'gen_' + name
        generator_arguments = predicate.arguments
        generator_goal = E.FunctionApplicationExpression(predicate, [E.VariableExpression(arg) for arg in generator_arguments])
        output_argument_names = [x.value for x in inplace_generator_targets.items]
        output_indices = []
        for i, arg in enumerate(args_c):
            if isinstance(arg, E.VariableExpression) and arg.name in output_argument_names:
                output_argument_names.remove(arg.name)
                output_indices.append(i)
        assert len(output_argument_names) == 0, f'Mismatched output arguments for inplace generator {name}: {output_argument_names}'
        inputs = [arg for i, arg in enumerate(generator_arguments) if i not in output_indices]
        outputs = [arg for i, arg in enumerate(generator_arguments) if i in output_indices]
        self.domain.define_generator(generator_name, generator_arguments, generator_goal, inputs, outputs)
        return E.FunctionApplicationExpression(predicate, args_c)

    raise KeyError(f'Function {name} not found. Note that recursive function calls are not supported in the current version.')
[docs]
@inline_args
def atom_subscript(self, name: str, annotations: dict, index: Tree) -> Union[E.FunctionApplicationExpression]:
    """Captures subscript expressions such as `name[index1, index2, ...]` on state variables.

    ``name[...]`` (a single ellipsis) yields a zero-argument application of the feature, without
    annotations. Otherwise the indices are canonized against the feature's argument types and the
    annotations are passed as keyword arguments.

    Raises:
        ValueError: if ``name`` is not a state variable.
    """
    feature = self.domain.get_feature(name)
    index: CSList = self.visit(index)
    annotations: Optional[dict] = self.visit(annotations)
    if not feature.is_state_variable:
        raise ValueError(f'Invalid subscript expression: {name} is not a state variable. Expression: {name}[{index.items}]')
    items = index.items
    if len(items) == 1 and items[0] is Ellipsis:
        return E.FunctionApplicationExpression(feature, tuple())
    arguments = _canonize_arguments(items, dtypes=feature.ftype.argument_types)
    call_kwargs = annotations if annotations is not None else dict()
    return E.FunctionApplicationExpression(feature, arguments, **call_kwargs)
[docs]
@inline_args
def atom(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
    """Captures the atom. This is used in the base case of the expression, including literal constants, variables, and subscript expressions.

    NOTE(review): ``value`` is returned without being visited; if it is still a lark Tree, the
    caller receives the Tree. Presumably the grammar inlines single-child atoms so this is rare;
    confirm against ``cdl.grammar``.
    """
    return value
[docs]
def arguments(self, args: Tree) -> ArgumentsList:
    """Captures the argument list of a function call: every child is visited, and the results are packed into an :class:`ArgumentsList` backed by a tuple."""
    return ArgumentsList(tuple(self.visit_children(args)))
[docs]
@inline_args
def power(self, base: Union[FunctionCall, Variable], exp: Optional[Union[FunctionCall, Variable]] = None) -> Union[FunctionCall, Variable]:
    """The highest-priority expression, `base ** exp`.

    Only the form without an exponent is supported; ``base`` is then returned unchanged.

    Raises:
        NotImplementedError: if an exponent is present.
    """
    if exp is not None:
        raise NotImplementedError('Power expression is not supported in the current version.')
    return base
[docs]
@inline_args
def factor(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
    """Pass-through handler for the single-child ``factor`` rule (unary operators go through :meth:`unary_op_expr`).

    NOTE(review): ``value`` is returned without being visited; if it is still a lark Tree, the
    caller receives the Tree. Presumably the grammar inlines this rule; confirm against ``cdl.grammar``.
    """
    return value
[docs]
@inline_args
def unary_op_expr(self, op: str, value: Union[FunctionCall, Variable]):
    """Unary ``+``/``-`` applied to an operand.

    ``+x`` returns the operand itself. ``-x`` negates Python numbers directly; otherwise it builds a
    ``type::<T>::neg`` application, where ``T`` is the operand's (element) type.

    Raises:
        NotImplementedError: for any other unary operator.
    """
    operand = self.visit(value)
    if op == '+':
        return operand
    if op != '-':
        raise NotImplementedError(f'Unary operator {op} is not supported in the current version.')
    if isinstance(operand, (int, float)):
        return -operand
    element_type = operand.return_type
    if element_type.is_uniform_sequence_type:
        element_type = element_type.element_type
    neg_function = CrowFunction(f'type::{element_type.typename}::neg', FunctionType([element_type], element_type))
    return E.FunctionApplicationExpression(neg_function, [operand])
# Binary arithmetic rules: the children alternate operands and operator tokens, and are folded left to
# right into `type::<T>::<op>` function applications (see _gen_term_expr_func).
# NOTE(review): g_term_op_mapping has no entry for shift operators, so `shift_expr` with an actual
# operator cannot be interpreted.
mul_expr = _gen_term_expr_func('mul')
arith_expr = _gen_term_expr_func('add')
shift_expr = _gen_term_expr_func('shift')
# Bitwise rules, interpreted as n-ary boolean AND / XOR / OR over all operands.
bitand_expr = _gen_bitop_expr_func('bitand')
bitxor_expr = _gen_bitop_expr_func('bitxor')
bitor_expr = _gen_bitop_expr_func('bitor')
[docs]
@inline_args
def comparison_expr(self, *values: Union[E.ValueOutputExpression, E.VariableExpression]) -> E.ValueOutputExpression:
    """Comparisons, including chains such as `a < b < c`.

    The children alternate operands and operator subtrees. Each adjacent pair becomes:

    - an ``ObjectCompareExpression`` if both sides are objects;
    - a ``ValueCompareExpression`` (after unifying the operand dtypes) if both sides are values.

    A chain of several comparisons is joined with AND.

    Raises:
        ValueError: if one side is an object and the other a value.
    """
    if len(values) == 1:
        return self.visit(values[0])
    assert len(values) % 2 == 1, f'[compare] expressions expected an odd number of values, got {len(values)}. Values: {values}.'
    values = [self.visit(value) for value in values]
    comparisons = []
    for i in range(1, len(values), 2):
        lhs, op_node, rhs = values[i - 1], values[i], values[i + 1]
        lhs_is_value = _safe_is_value_type(lhs)
        rhs_is_value = _safe_is_value_type(rhs)
        if not lhs_is_value and not rhs_is_value:
            comparisons.append(E.ObjectCompareExpression(E.CompareOpType.from_string(op_node[0].value), lhs, rhs))
        elif lhs_is_value and rhs_is_value:
            lhs_c, rhs_c = _canonize_arguments_same_dtype([lhs, rhs])
            comparisons.append(E.ValueCompareExpression(E.CompareOpType.from_string(op_node[0].value), lhs_c, rhs_c))
        else:
            raise ValueError(f'Invalid comparison: {lhs} vs {rhs}')
    return comparisons[0] if len(comparisons) == 1 else E.AndExpression(*comparisons)
[docs]
@inline_args
def not_test(self, value: Any) -> E.NotExpression:
    """Negation; the operand is canonized as a BOOL expression."""
    (operand,) = _canonize_arguments([self.visit(value)], BOOL)
    return E.NotExpression(operand)
[docs]
@inline_args
def and_test(self, *values: Any) -> E.AndExpression:
    """Conjunction of all operands (canonized as BOOL); a single operand is returned unchanged."""
    operands = [self.visit(child) for child in values]
    if len(operands) != 1:
        return E.AndExpression(*_canonize_arguments(operands, BOOL))
    return operands[0]
[docs]
@inline_args
def or_test(self, *values: Any) -> E.OrExpression:
    """Disjunction of all operands (canonized as BOOL); a single operand is returned unchanged."""
    operands = [self.visit(child) for child in values]
    if len(operands) != 1:
        return E.OrExpression(*_canonize_arguments(operands, BOOL))
    return operands[0]
[docs]
@inline_args
def cond_test(self, value1: Any, cond: Any, value2: Any) -> E.ConditionExpression:
    """Conditional expression with children ``(value1, cond, value2)``.

    Both branches are canonized to a shared dtype; the condition is canonized as BOOL.
    """
    then_value = self.visit(value1)
    else_value = self.visit(value2)
    then_expr, else_expr = _canonize_arguments_same_dtype([then_value, else_value])
    condition = _canonize_single_argument(self.visit(cond), BOOL)
    return E.ConditionExpression(condition, then_expr, else_expr)
[docs]
@inline_args
def test(self, value: Any):
    """Handle the ``test`` rule by visiting its single child and returning the result."""
    return self.visit(value)
[docs]
@inline_args
def test_nocond(self, value: Any):
    """Handle the ``test_nocond`` rule by visiting its single child and returning the result."""
    return self.visit(value)
[docs]
@inline_args
def tuple(self, *values: Any):
    """Tuple literal: visit each element and return a Python tuple of the results."""
    return tuple(map(self.visit, values))
[docs]
@inline_args
def list(self, *values: Any):
    """List literal.

    If every element is a Python int/float (including the empty list), the result is a constant
    :class:`TensorValue`. Otherwise it is a :class:`E.ListCreationExpression` of the canonized elements.
    """
    elements = [self.visit(child) for child in values]
    # TODO(Jiayuan Mao @ 2024/08/10): make this more general, like nested list etc.
    is_numeric = all(isinstance(e, (int, float)) for e in elements)
    if not is_numeric:
        return E.ListCreationExpression(_canonize_arguments(elements))
    return E.ConstantExpression(TensorValue.from_values(*elements))
[docs]
@inline_args
def cs_list(self, *values: Any):
    """Comma-separated list: visit each element and wrap the results in a :class:`CSList`."""
    return CSList(tuple(map(self.visit, values)))
[docs]
@inline_args
def suite(self, *values: Tree, activate_variable_guard: bool = True) -> Suite:
    """A block of statements.

    Each statement is visited, and ``None`` results are dropped. When ``activate_variable_guard``
    is set, local variables declared in the block are scoped to it. The returned :class:`Suite`
    carries a snapshot of the local variables as seen at the end of the block.
    """
    scope = self.local_variable_guard() if activate_variable_guard else contextlib.nullcontext()
    with scope:
        statements = [self.visit(value) for value in values]
        # Snapshot inside the scope, before the guard (if any) restores the outer variables.
        local_variables = self.local_variables.copy()
    return Suite(tuple(s for s in statements if s is not None), local_variables)
[docs]
@inline_args
def expr_stmt(self, value: Tree):
    """Expression statement, wrapped as an ``expr`` FunctionCall.

    ``...`` and bare string literals produce no statement; the NB below explains the strings.
    """
    result = self.visit(value)
    # NB(Jiayuan Mao @ 2024/06/21): for handling string literals as docs.
    if result is Ellipsis or isinstance(result, str):
        return None
    return FunctionCall('expr', ArgumentsList((_canonize_single_argument(result),)))
[docs]
@inline_args
def expr_list_expansion_stmt(self, value: Any):
    """List-expansion statement: wrap the canonized value in a ``ListExpansionExpression`` inside an ``expr`` FunctionCall."""
    expanded = E.ListExpansionExpression(_canonize_single_argument(self.visit(value)))
    return FunctionCall('expr', ArgumentsList((expanded,)))
[docs]
@inline_args
def compound_expr_stmt(self, value: Tree):
    """Compound expression statement; handled exactly like :meth:`expr_stmt`
    (``...`` and bare strings produce no statement)."""
    result = self.visit(value)
    if result is not Ellipsis and not isinstance(result, str):
        return FunctionCall('expr', ArgumentsList((_canonize_single_argument(result),)))
    return None
def _make_additive_assign_stmt(self, lv, rv, op: str, annotations: dict):
    """Build the ``assign`` FunctionCall for ``lv <op> rv``.

    - ``=`` assigns ``rv`` directly.
    - ``+= -= *= /= %=`` assign ``type::<T>::<add|sub|mul|div|mod>(lv, rv)``, where ``T`` is the
      (element) type of ``lv``.
    - ``&= |= ^=`` assign a boolean AND / OR / XOR of ``lv`` and ``rv``.

    Raises:
        ValueError: for any other operator.
    """
    if op == '=':
        return FunctionCall('assign', ArgumentsList((lv, rv)), annotations)

    if op in ('+=', '-=', '*=', '/=', '%='):
        element_type = lv.return_type
        if element_type.is_uniform_sequence_type:
            element_type = element_type.element_type
        suffix = 'mod' if op == '%=' else g_term_op_mapping[op[0]]
        function = CrowFunction(f'type::{element_type.typename}::{suffix}', FunctionType([element_type, element_type], element_type))
        combined = E.FunctionApplicationExpression(function, [lv, rv])
    elif op in ('&=', '|=', '^='):
        bool_op = {'&': E.BoolOpType.AND, '|': E.BoolOpType.OR, '^': E.BoolOpType.XOR}[op[0]]
        combined = E.BoolExpression(bool_op, (lv, rv))
    else:
        raise ValueError(f'Invalid assignment operator: {op}')
    return FunctionCall('assign', ArgumentsList((lv, combined)), annotations)
[docs]
def assign_stmt_inner(self, op: str, target: Any, value: Any, annotations: dict):
    """Shared implementation of plain and annotated assignments.

    ``target`` is an unvisited subtree:

    - a bare name (``atom_varname``) must be a local variable (the statement is then annotated
      ``local=True`` unless already set) or a zero-argument domain feature;
    - any other target is visited and canonized as an expression.

    ``annotations`` may be updated in place.

    Raises:
        NameError: if a bare name is neither a local variable nor a zero-argument feature.
    """
    if target.data == 'atom_varname':
        target_lv = target.children[0]
    else:
        target_lv = _canonize_single_argument(self.visit(target))  # left value

    if not isinstance(target_lv, str):
        # Structured target, e.g. a subscript expression.
        return self._make_additive_assign_stmt(target_lv, _canonize_single_argument(self.visit(value)), op, annotations)

    if target_lv in self.local_variables:
        annotations.setdefault('local', True)
        local_target = self.local_variables[target_lv]
        return self._make_additive_assign_stmt(local_target, _canonize_single_argument(self.visit(value)), op, annotations)

    is_nullary_feature = target_lv in self.domain.features and self.domain.get_feature(target_lv).nr_arguments == 0
    if not is_nullary_feature:
        raise NameError(f'Invalid assignment target: it is not a local variable and not a feature with 0 arguments: {target_lv}')
    feature_target = E.FunctionApplicationExpression(self.domain.get_feature(target_lv), tuple())
    return self._make_additive_assign_stmt(feature_target, _canonize_single_argument(self.visit(value)), op, annotations)
[docs]
@inline_args
def assign_stmt(self, target: Any, op: Any, value: Any):
    """Plain assignment ``target op value``; delegates to :meth:`assign_stmt_inner` with no annotations."""
    no_annotations = {}
    return self.assign_stmt_inner(op.value, target, value, no_annotations)
[docs]
@inline_args
def annotated_assign_stmt(self, annotations: dict, target: Any, op: Any, value: Any):
    """Annotated assignment ``[annotations] target op value``.

    The annotation subtree is visited (``visit`` returns an already-built dict unchanged) and
    defaults to an empty dict, as atom_expr_funccall does. This keeps the ``annotations.setdefault``
    call in :meth:`assign_stmt_inner` from receiving an unvisited Tree.
    """
    annotations = self.visit(annotations)
    if annotations is None:
        annotations = dict()
    return self.assign_stmt_inner(op.value, target, value, annotations)
[docs]
@inline_args
def let_assign_stmt(self, target: str, dtype: Any = None, value: Any = None):
    """Declare a local variable, with an optional type and an optional initial value.

    The variable's type is, in order of preference: the declared ``dtype``, the type of the initial
    value, or ``AutoType``. The variable is registered in :attr:`local_variables`.

    Args:
        target: the variable name.
        dtype: the (unvisited) type subtree, or None when no type is declared.
        value: the (unvisited) initializer subtree, or None when no initial value is given.

    Returns:
        an ``assign`` FunctionCall annotated ``local=True`` when an initial value is given,
        otherwise None (a pure declaration emits no statement).
    """
    assert isinstance(target, str), f'Invalid local variable name: {target}'
    dtype = self.visit(dtype)
    # _canonize_single_argument raises ValueError for None, so a declaration without an
    # initializer must skip canonization entirely.
    if value is not None:
        value = _canonize_single_argument(self.visit(value))
    if dtype is not None:
        dtype = self.domain.get_type(dtype)
        if value is not None:
            assert value.return_type.downcast_compatible(dtype), f'Invalid assignment: variable {target} has dtype {dtype} but the value has dtype {value.return_type}.'
    if dtype is None and value is not None:
        dtype = value.return_type
    if dtype is None:
        dtype = AutoType
    variable = E.VariableExpression(Variable(target, dtype))
    self.local_variables[target] = variable
    if value is None:
        return None
    return FunctionCall('assign', ArgumentsList((variable, value)), {'local': True})
[docs]
@inline_args
def symbol_assign_stmt(self, target: str, value: Any):
    """Bind a new symbol variable to a value in the current local scope.

    The symbol's type is the return type of ``value``.

    Args:
        target: the symbol name.
        value: the right-hand-side expression tree; it is visited and canonized here.

    Returns:
        An ``assign`` :class:`FunctionCall` annotated with ``{'symbol': True}``.

    Raises:
        RuntimeError: if a local variable named ``target`` already exists.
        ValueError: if the right-hand side canonizes to ``None``.
    """
    assert isinstance(target, str), f'Invalid symbol variable name: {target}'
    value = _canonize_single_argument(self.visit(value))
    if target in self.local_variables:
        raise RuntimeError(f'Local symbol variable {target} has been assigned before.')
    if value is None:
        # The symbol's type comes from the value, so a missing value cannot be handled. Fail explicitly
        # instead of with an AttributeError on ``value.return_type``.
        raise ValueError(f'Symbol variable {target} must be assigned a value.')
    self.local_variables[target] = E.VariableExpression(Variable(target, value.return_type))
    return FunctionCall('assign', ArgumentsList((self.local_variables[target], value)), {'symbol': True})
[docs]
@inline_args
def pass_stmt(self):
    """Translate a ``pass`` statement into an argument-less ``pass`` call."""
    no_arguments = ArgumentsList(())
    return FunctionCall('pass', no_arguments)
[docs]
@inline_args
def commit_stmt(self, kwargs: dict):
    """Translate a ``commit`` statement; its keyword arguments become the call's annotations."""
    no_arguments = ArgumentsList(())
    return FunctionCall('commit', no_arguments, kwargs)
[docs]
@inline_args
def achieve_once_stmt(self, value: CSList):
    """Translate an ``achieve_once_stmt`` into an ``achieve`` call annotated with ``{'once': True}``.

    Each listed expression is canonized against ``BOOL``.
    """
    goal_items = self.visit(value).items
    return FunctionCall('achieve', ArgumentsList(_canonize_arguments(goal_items, BOOL)), {'once': True})
[docs]
@inline_args
def achieve_hold_stmt(self, value: CSList):
    """Translate an ``achieve_hold_stmt`` into an ``achieve`` call annotated with ``{'once': False}``.

    Each listed expression is canonized against ``BOOL``.
    """
    goal_items = self.visit(value).items
    return FunctionCall('achieve', ArgumentsList(_canonize_arguments(goal_items, BOOL)), {'once': False})
[docs]
@inline_args
def pachieve_once_stmt(self, value: CSList):
    """Translate a ``pachieve_once_stmt`` into a ``pachieve`` call annotated with ``{'once': True}``.

    Each listed expression is canonized against ``BOOL``.
    """
    goal_items = self.visit(value).items
    return FunctionCall('pachieve', ArgumentsList(_canonize_arguments(goal_items, BOOL)), {'once': True})
[docs]
@inline_args
def pachieve_hold_stmt(self, value: CSList):
    """Translate a ``pachieve_hold_stmt`` into a ``pachieve`` call annotated with ``{'once': False}``.

    Each listed expression is canonized against ``BOOL``.
    """
    goal_items = self.visit(value).items
    return FunctionCall('pachieve', ArgumentsList(_canonize_arguments(goal_items, BOOL)), {'once': False})
[docs]
@inline_args
def untrack_stmt(self, value: CSList):
    """Translate an ``untrack`` statement.

    Without arguments, an argument-less ``untrack`` call is produced. Otherwise each listed expression is
    canonized against ``BOOL`` and becomes an argument.
    """
    if value is not None:
        targets = _canonize_arguments(self.visit(value).items, BOOL)
        return FunctionCall('untrack', ArgumentsList(targets))
    return FunctionCall('untrack', ArgumentsList(()))
[docs]
@inline_args
def assert_once_stmt(self, value: Any):
    """Translate an ``assert_once_stmt`` into an ``assert`` call annotated with ``{'once': True}``."""
    condition = _canonize_single_argument(self.visit(value), BOOL)
    return FunctionCall('assert', ArgumentsList((condition,)), {'once': True})
[docs]
@inline_args
def assert_hold_stmt(self, value: Any):
    """Translate an ``assert_hold_stmt`` into an ``assert`` call annotated with ``{'once': False}``."""
    condition = _canonize_single_argument(self.visit(value), BOOL)
    return FunctionCall('assert', ArgumentsList((condition,)), {'once': False})
[docs]
@inline_args
def return_stmt(self, value: Any):
    """Translate a ``return`` statement into a one-argument ``return`` call."""
    returned = _canonize_single_argument(self.visit(value))
    return FunctionCall('return', ArgumentsList((returned,)))
[docs]
@inline_args
def ordered_suite(self, ordering_op: Any, body: Any):
    """Translate an ordering keyword followed by a suite into an ``ordering`` call.

    Returns:
        ``FunctionCall('ordering', (ordering_string, visited_suite))``.
    """
    ordering_op = self.visit(ordering_op)
    assert body.data.value == 'suite', f'Invalid body type: {body}'
    order_is_unknown = ordering_op in ('promotable', 'unordered', 'promotable unordered', 'promotable sequential', 'alternative')
    if order_is_unknown:
        # The execution order of these sections is not fixed, so we cannot rely on the order of their statements.
        # Visit the body inside its own local-variable guard instead.
        with self.local_variable_guard():
            visited_body = self.visit(body)
    else:
        # Visit the statements directly in the current scope and snapshot the local variables into the Suite.
        statements = self.visit_children(body)
        visited_body = Suite(tuple(statements), self.local_variables.copy())
    return FunctionCall('ordering', ArgumentsList((ordering_op, visited_body)))
[docs]
@inline_args
def ordering_op(self, *ordering_op: Any):
    """Join the ordering keyword tokens into a single space-separated string (e.g. ``'promotable unordered'``)."""
    return ' '.join(token.value for token in ordering_op)
[docs]
@inline_args
def if_stmt(self, cond: Any, suite: Any, else_suite: Optional[Any] = None):
    """Translate an ``if`` statement into an ``if`` call with a condition, a then-suite and an else-suite.

    Each branch is visited inside its own local-variable guard. A missing ``else`` becomes a suite holding a
    single ``pass`` call.
    """
    condition = _canonize_single_argument(self.visit(cond))
    with self.local_variable_guard():
        then_branch = self.visit(suite)
    if else_suite is not None:
        with self.local_variable_guard():
            else_branch = self.visit(else_suite)
    else:
        else_branch = Suite((FunctionCall('pass', ArgumentsList(())),))
    return FunctionCall('if', ArgumentsList((condition, then_branch, else_branch)))
[docs]
@inline_args
def foreach_stmt(self, cs_list: Any, suite: Any):
    """Translate a ``foreach`` statement over typed loop variables into a ``foreach`` call.

    The body is visited with the loop variables in scope and inside a local-variable guard.
    """
    loop_variables = self.visit(cs_list)
    with self.expression_def_ctx.new_variables(*loop_variables.items), self.local_variable_guard():
        loop_body = self.visit(suite)
    return FunctionCall('foreach', ArgumentsList((loop_variables, loop_body)))
[docs]
@inline_args
def foreach_in_stmt(self, variables_cs_list: Any, values_cs_list: Any, suite: Any) -> object:
    """Translate a foreach-in statement into a ``foreach_in`` call.

    Loop variables and value expressions are paired positionally. Each variable gets the element type of its
    value, which must be a list type or a batched list type.

    Raises:
        ValueError: if the numbers of variables and values differ, or a value is not a list.
    """
    values_cs_list = self.visit(values_cs_list)
    variables_cs_list = self.visit(variables_cs_list)
    variable_names = variables_cs_list.items
    values = values_cs_list.items
    if len(variable_names) != len(values):
        raise ValueError(f'Number of variables does not match the number of values: {len(variable_names)} vs {len(values)}. Variables: {variable_names}, Values: {values}')

    variable_items = list()
    for variable_name, value in zip(variable_names, values):
        # Variables are parsed as bare names, so wrap each into a typed Variable using the element type of its value.
        return_type = value.return_type
        if return_type.is_list_type:
            variable_items.append(Variable(variable_name, return_type.element_type))
        elif return_type.is_batched_list_type:
            variable_items.append(Variable(variable_name, return_type.iter_element_type()))
        else:
            raise ValueError(f'Invalid foreach_in statement: {value} is not a list.')
    variables_cs_list = CSList(tuple(variable_items))

    with self.expression_def_ctx.new_variables(*variables_cs_list.items), self.local_variable_guard():
        suite = self.visit(suite)
    return FunctionCall('foreach_in', ArgumentsList((variables_cs_list, values_cs_list, suite, )))
[docs]
@inline_args
def while_stmt(self, cond: Any, suite: Any):
    """Translate a ``while`` loop into a ``while`` call; the body is visited inside a local-variable guard."""
    condition = _canonize_single_argument(self.visit(cond))
    with self.local_variable_guard():
        loop_body = self.visit(suite)
    return FunctionCall('while', ArgumentsList((condition, loop_body)))
def _quantification_expression(self, cs_list: Any, suite: Any, quantification_cls):
    """Visit ``suite`` with the variables of ``cs_list`` in scope and wrap it in nested quantifiers.

    The first listed variable becomes the outermost quantifier: ``Q(v1, Q(v2, ... Q(vn, body)))``.
    """
    variables = self.visit(cs_list).items
    with self.expression_def_ctx.new_variables(*variables):
        expression = self.visit(suite)
        for variable in reversed(variables):
            expression = quantification_cls(variable, expression)
    return expression
[docs]
@inline_args
def forall_test(self, cs_list: Any, suite: Any):
    """Translate a ``forall`` test into nested :class:`E.ForallExpression` nodes."""
    return self._quantification_expression(cs_list, suite, quantification_cls=E.ForallExpression)
[docs]
@inline_args
def exists_test(self, cs_list: Any, suite: Any):
    """Translate an ``exists`` test into nested :class:`E.ExistsExpression` nodes."""
    return self._quantification_expression(cs_list, suite, quantification_cls=E.ExistsExpression)
[docs]
@inline_args
def batched_test(self, cs_list: Any, suite: Any):
    """Translate a batched test into nested :class:`E.BatchedExpression` nodes."""
    return self._quantification_expression(cs_list, suite, quantification_cls=E.BatchedExpression)
[docs]
@inline_args
def findall_test(self, variable: Variable, suite: Any):
    """Build an :class:`E.FindAllExpression` over ``variable``.

    The condition is visited with ``variable`` in scope and inside a local-variable guard.
    """
    with self.expression_def_ctx.new_variables(variable), self.local_variable_guard():
        condition = self.visit(suite)
    return E.FindAllExpression(variable, condition)
[docs]
@inline_args
def findone_test(self, variable: Variable, suite: Any):
    """Build an :class:`E.FindOneExpression` over ``variable``.

    The condition is visited with ``variable`` in scope and inside a local-variable guard.
    """
    with self.expression_def_ctx.new_variables(variable), self.local_variable_guard():
        condition = self.visit(suite)
    return E.FindOneExpression(variable, condition)
def _quantification_in_expression(self, cs_list: Any, suite: Any, quantification_cls):
    """Build ``quantification_cls(body)`` for an ``x in values`` style quantifier.

    Inside a local-variable guard, each ``name in value`` binding of ``cs_list`` is installed as a local variable
    (mapped to the visited value expression) before the body is visited.
    """
    bindings = self.visit(cs_list)
    with self.local_variable_guard():
        binding: InTypedArgument
        for binding in bindings.items:
            self.local_variables[binding.name] = self.visit(binding.value)
        body = self.visit(suite)
    return quantification_cls(body)
[docs]
@inline_args
def forall_in_test(self, cs_list: Any, suite: Any):
    """Translate a ``forall ... in ...`` test; the visited body is wrapped in an :class:`E.AndExpression`."""
    return self._quantification_in_expression(cs_list, suite, quantification_cls=E.AndExpression)
[docs]
@inline_args
def exists_in_test(self, cs_list: Any, suite: Any):
    """Translate an ``exists ... in ...`` test; the visited body is wrapped in an :class:`E.OrExpression`."""
    return self._quantification_in_expression(cs_list, suite, quantification_cls=E.OrExpression)
[docs]
@inline_args
def bind_stmt(self, cs_list: Any, suite: Any):
    """Translate a ``bind`` statement that has a body into a ``bind`` call.

    The body is visited with the bound variables in scope and inside a local-variable guard. A list of derived
    expressions is conjoined with :class:`E.AndExpression`. Afterwards each bound variable is registered in
    ``self.local_variables``.
    """
    bound = self.visit(cs_list)
    with self.expression_def_ctx.new_variables(*bound.items), self.local_variable_guard():
        condition = self.visit(suite).get_derived_expression()
        if isinstance(condition, list):
            condition = E.AndExpression(*condition)
    # Registered after leaving the guard, presumably so later statements can refer to the bound variables.
    for variable in bound.items:
        self.local_variables[variable.name] = E.VariableExpression(variable)
    return FunctionCall('bind', ArgumentsList((bound, condition)))
[docs]
@inline_args
def bind_stmt_no_where(self, cs_list: Any):
    """Captures bind statements without a body. For example:

    .. code-block:: python

        bind x: int, y: int

    Each bound variable is registered in ``self.local_variables``. The resulting ``bind`` call uses a null
    boolean expression as its condition.
    """
    bound = self.visit(cs_list)
    self.local_variables.update({variable.name: E.VariableExpression(variable) for variable in bound.items})
    return FunctionCall('bind', ArgumentsList((bound, E.NullExpression(BOOL))))
[docs]
@inline_args
def annotated_compound_stmt(self, annotations: dict, stmt: Any):
    """Visit a compound statement and merge ``annotations`` into the resulting call's annotations.

    If the visited statement has no annotations yet, ``annotations`` is attached directly. Otherwise its existing
    annotation dict is updated in place.
    """
    result = self.visit(stmt)
    if result.annotations is not None:
        result.annotations.update(annotations)
    else:
        result.annotations = annotations
    return result
[docs]
@dataclass
class GoalPart:
    """The goal section of a behavior or generator definition.

    The section is kept as an unvisited parse tree and processed later by the definition handler.
    """

    suite: Tree  # unvisited parse tree of the section
[docs]
@dataclass
class MinimizePart:
    """The ``minimize`` section of a behavior definition, kept as an unvisited parse tree."""

    suite: Tree  # unvisited parse tree of the section
[docs]
@dataclass
class BodyPart:
    """The body section of a behavior definition, kept as an unvisited parse tree."""

    suite: Tree  # unvisited parse tree of the section
[docs]
@dataclass
class EffectPart:
    """The effect section of a behavior or controller definition, kept as an unvisited parse tree."""

    suite: Tree  # unvisited parse tree of the section
[docs]
@dataclass
class HeuristicPart:
    """The heuristic section of a behavior definition, kept as an unvisited parse tree."""

    suite: Tree  # unvisited parse tree of the section
[docs]
@dataclass
class InPart:
    """The ``in`` (inputs) section of a generator definition, kept as an unvisited parse tree."""

    suite: Tree  # unvisited tree of the listed input values
[docs]
@dataclass
class OutPart:
    """The ``out`` (outputs) section of a generator definition, kept as an unvisited parse tree."""

    suite: Tree  # unvisited tree of the listed output values
[docs]
class CDLDomainTransformer(CDLLiteralTransformer):
[docs]
def __init__(self, domain: Optional[CrowDomain] = None, auto_init_domain: bool = True):
    """Create a transformer that records definitions into a :class:`CrowDomain`.

    Args:
        domain: the domain to populate. If ``None`` and ``auto_init_domain`` is true, a fresh :class:`CrowDomain`
            is created.
        auto_init_domain: whether to create a domain when ``domain`` is ``None``. If false and no domain is given,
            the domain, the expression-definition context and the expression interpreter all stay ``None``.
    """
    super().__init__()
    if domain is None and not auto_init_domain:
        self._domain = None
        self._expression_def_ctx = None
        self._expression_interpreter = None
        return
    self._domain = domain if domain is not None else CrowDomain()
    self._expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
    self._expression_interpreter = CDLExpressionInterpreter(domain=self.domain, state=None, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=False)
@property
def domain(self) -> CrowDomain:
    """The domain that definitions are written into.

    Raises:
        ValueError: if the transformer was created without a domain.
    """
    current = self._domain
    if current is not None:
        return current
    raise ValueError('Domain is not initialized.')
@property
def expression_def_ctx(self) -> E.ExpressionDefinitionContext:
    """The expression definition context bound to :attr:`domain`.

    Raises:
        ValueError: if the transformer was created without a domain.
    """
    current = self._expression_def_ctx
    if current is not None:
        return current
    raise ValueError('Expression definition context is not initialized.')
@property
def expression_interpreter(self) -> CDLExpressionInterpreter:
    """The interpreter used to visit expression and statement suites.

    Raises:
        ValueError: if the transformer was created without a domain.
    """
    current = self._expression_interpreter
    if current is not None:
        return current
    raise ValueError('Expression interpreter is not initialized.')
[docs]
@inline_args
def include_definition(self, path: str):
    """Handle an include by parsing the referenced file and transforming it with this transformer.

    ``path`` is resolved with the default path resolver and parsed with the default parser.
    """
    resolved_path = _DEFAULT_PATH_RESOLVER.resolve(path)
    included_tree = get_default_parser().parse(resolved_path)
    self.transform(included_tree)
[docs]
@inline_args
def pragma_definition(self, pragma: Dict[str, Any]):
    """Handle a key/value pragma by delegating to :meth:`_handle_pragma`."""
    verbose = _PARSER_VERBOSE
    if verbose:
        print('pragma_definition', pragma)
    self._handle_pragma(pragma)
[docs]
@inline_args
def pragma_definition_with_args(self, function, arguments):
    """Handle a function-style pragma ``function(arguments...)``.

    The arguments are visited by the expression interpreter (no arguments gives an empty tuple). The pragma is
    then handled as ``{function: arguments}``.
    """
    if _PARSER_VERBOSE:
        print('pragma_definition_with_args', function, arguments)
    arguments = tuple() if arguments is None else self.expression_interpreter.visit(arguments).arguments
    self._handle_pragma({function: arguments})
def _handle_pragma(self, pragma: Dict[str, Any]):
    """Apply the entries of ``pragma`` to the domain.

    Only ``load_implementation`` is acted on here. Each listed path is resolved with the default path resolver
    and registered as an external function implementation file. Other keys are ignored by this method.
    """
    for key, value in pragma.items():
        if key != 'load_implementation':
            continue
        for library in value:
            resolved = _DEFAULT_PATH_RESOLVER.resolve(library)
            self.domain.add_external_function_implementation_file(resolved)
[docs]
@inline_args
def type_definition(self, typename, basetype: Optional[Union[str, TypeBase]]):
    """Define the type ``typename`` in the domain, optionally derived from ``basetype``."""
    verbose = _PARSER_VERBOSE
    if verbose:
        print(f'type_definition:: {typename=} {basetype=}')
    self.domain.define_type(typename, basetype)
[docs]
@inline_args
def object_constant_definition(self, name: str, typename: str):
    """Define the object constant ``name`` of type ``typename`` in the domain."""
    verbose = _PARSER_VERBOSE
    if verbose:
        print(f'object_constant_definition:: {name=} {typename=}')
    self.domain.define_object_constant(name, typename)
[docs]
@inline_args
def feature_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
    """Define a feature in the domain.

    Args:
        annotations: keyword annotations forwarded to ``define_feature``; ``None`` means none.
        name: the feature name.
        args: the feature arguments; ``None`` means no arguments.
        ret: the return type or its name; ``None`` means ``BOOL``.
        suite: an optional body. When present, it is visited with the arguments in scope, and its derived
            expression becomes ``derived_expression``.
    """
    annotations = dict() if annotations is None else annotations
    args = ArgumentsDef(tuple()) if args is None else args
    if ret is None:
        ret = BOOL
    elif isinstance(ret, str):
        ret = self.domain.get_type(ret)

    derived = None
    if suite is not None:
        with self.expression_def_ctx.with_variables(*args.arguments):
            suite = self.expression_interpreter.visit(suite)
            derived = suite.get_derived_expression()

    self.domain.define_feature(name, args.arguments, ret, derived_expression=derived, **annotations)
    if _PARSER_VERBOSE:
        # The self-documenting f-string names must stay as-is: they are part of the printed text.
        print(f'feature_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
        if derived is not None:
            print(jacinle.indent_text(f'Return statement: {derived}'))
[docs]
@inline_args
def function_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
    """Define a Crow function in the domain.

    Args:
        annotations: keyword annotations forwarded to ``define_crow_function``; ``None`` means none.
        name: the function name.
        args: the function arguments; ``None`` means no arguments.
        ret: the return type or its name; ``None`` means ``BOOL``.
        suite: an optional body. When present, it is visited with the arguments in scope, and its derived
            expression becomes ``derived_expression``.
    """
    annotations = dict() if annotations is None else annotations
    args = ArgumentsDef(tuple()) if args is None else args
    if ret is None:
        ret = BOOL
    elif isinstance(ret, str):
        ret = self.domain.get_type(ret)

    derived = None
    if suite is not None:
        with self.expression_def_ctx.with_variables(*args.arguments):
            suite = self.expression_interpreter.visit(suite)
            derived = suite.get_derived_expression()

    self.domain.define_crow_function(name, args.arguments, ret, derived_expression=derived, **annotations)
    if _PARSER_VERBOSE:
        # The self-documenting f-string names must stay as-is: they are part of the printed text.
        print(f'function_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
        if derived is not None:
            print(jacinle.indent_text(f'Return statement: {derived}'))
[docs]
@inline_args
def controller_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], effect: Optional[EffectPart]):
    """Define a controller in the domain.

    If an effect section is given, its effect statements are packed into a sequential
    :class:`CrowBehaviorOrderingSuite`. Otherwise ``None`` is passed as the effect.
    """
    annotations = dict() if annotations is None else annotations
    args = ArgumentsDef(tuple()) if args is None else args
    effect_suite = None
    if effect is not None:
        with self.expression_def_ctx.with_variables(*args.arguments):
            statements = self.expression_interpreter.visit(effect.suite).get_effect_statements()
            effect_suite = CrowBehaviorOrderingSuite.make_sequential(*statements)
    self.domain.define_controller(name, args.arguments, effect_suite, **annotations)
    if _PARSER_VERBOSE:
        print(f'controller_definition:: {name=} {args.arguments=}')
[docs]
@inline_args
def behavior_goal_definition(self, suite: Tree) -> GoalPart:
    """Wrap an unvisited goal section; it is processed later by :meth:`behavior_definition`."""
    return GoalPart(suite=suite)
[docs]
@inline_args
def behavior_minimize_definition(self, suite: Tree) -> MinimizePart:
    """Wrap an unvisited ``minimize`` section; it is processed later by :meth:`behavior_definition`."""
    return MinimizePart(suite=suite)
[docs]
@inline_args
def behavior_body_definition(self, suite: Tree) -> BodyPart:
    """Wrap an unvisited body section; it is processed later by :meth:`behavior_definition`."""
    return BodyPart(suite=suite)
[docs]
@inline_args
def behavior_effect_definition(self, suite: Tree) -> EffectPart:
    """Wrap an unvisited effect section; it is processed later by the behavior/controller handler."""
    return EffectPart(suite=suite)
[docs]
@inline_args
def behavior_heuristic_definition(self, suite: Tree) -> HeuristicPart:
    """Wrap an unvisited heuristic section; it is processed later by :meth:`behavior_definition`."""
    return HeuristicPart(suite=suite)
[docs]
@inline_args
def behavior_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[GoalPart, EffectPart, BodyPart, HeuristicPart]):
    """Define a behavior in the domain from its parsed sections.

    Args:
        annotations: keyword annotations forwarded to ``define_behavior``; ``None`` means none.
        name: the behavior name.
        args: the behavior arguments; ``None`` means no arguments.
        parts: goal / minimize / body / effect / heuristic sections, processed in order. An effect section sees
            the local variables of a body section that precedes it.

    How goals map to behaviors:

    - no goal: the behavior gets a null goal;
    - exactly one goal: it is used directly;
    - several goals: a behavior ``name`` with a null goal is defined, plus one behavior ``f'{name}_{i}'`` per goal.

    Raises:
        ValueError: if a part has an unsupported type.
    """
    if annotations is None:
        annotations = dict()
    if args is None:
        args = ArgumentsDef(tuple())
    if _PARSER_VERBOSE:
        print(f'behavior_definition:: {name=} {args.arguments=} {annotations=}')

    goals = list()
    minimize = None
    body = list()
    effect = None
    heuristic = None
    local_variables = None  # local variables declared by the body section, made visible to the effect section
    for part in parts:
        with self.expression_def_ctx.with_variables(*args.arguments):
            if isinstance(part, EffectPart):
                self.expression_interpreter.local_variables = local_variables if local_variables is not None else dict()
                try:
                    suite = self.expression_interpreter.visit(part.suite)
                finally:
                    # Reset even on failure, so body-local variables never leak into later definitions.
                    self.expression_interpreter.local_variables = dict()
            else:
                suite = self.expression_interpreter.visit(part.suite)

            if isinstance(part, GoalPart):
                suite = suite.get_derived_expression()
                goals.append(suite)
                if _PARSER_VERBOSE:
                    print(jacinle.indent_text(f'Goal: {suite}'))
            elif isinstance(part, MinimizePart):
                suite = suite.get_derived_expression()
                minimize = suite
                if _PARSER_VERBOSE:
                    print(jacinle.indent_text(f'Minimize: {suite}'))
            elif isinstance(part, BodyPart):
                local_variables = suite.local_variables
                statements = suite.get_behavior_body_statements()
                if len(statements) == 0:
                    body = CrowBehaviorOrderingSuite('sequential', tuple())
                else:
                    body = self._as_sequential_ordering_suite(statements)
                if _PARSER_VERBOSE:
                    print(jacinle.indent_text('Body:'))
                    print(jacinle.indent_text(str(body), level=2))
            elif isinstance(part, EffectPart):
                suite = suite.get_effect_statements()
                effect = CrowBehaviorOrderingSuite.make_sequential(*suite)
                if _PARSER_VERBOSE:
                    print(jacinle.indent_text(f'Effect: {effect}'))
            elif isinstance(part, HeuristicPart):
                statements = suite.get_heuristic_statements()
                heuristic = self._as_sequential_ordering_suite(statements) if len(statements) > 0 else None
            else:
                raise ValueError(f'Invalid part: {part}')

    if effect is None:
        effect = CrowBehaviorOrderingSuite.make_sequential()
    if len(goals) == 0:
        goal = E.NullExpression(BOOL)
        self.domain.define_behavior(name, args.arguments, goal, body, effect, heuristic=heuristic, minimize=minimize, **annotations)
    elif len(goals) == 1:
        self.domain.define_behavior(name, args.arguments, goals[0], body, effect, heuristic=heuristic, minimize=minimize, **annotations)
    else:
        # The umbrella behavior keeps the heuristic as well, consistent with every other branch.
        goal = E.NullExpression(BOOL)
        self.domain.define_behavior(name, args.arguments, goal, body, effect, heuristic=heuristic, minimize=minimize, **annotations)
        for i, goal in enumerate(goals):
            self.domain.define_behavior(f'{name}_{i}', args.arguments, goal, body, effect, heuristic=heuristic, minimize=minimize, **annotations)

@staticmethod
def _as_sequential_ordering_suite(statements: Sequence[Any]) -> CrowBehaviorOrderingSuite:
    """Pack a non-empty list of behavior statements into a sequential :class:`CrowBehaviorOrderingSuite`.

    A single statement that already is a sequential suite is returned as-is rather than left inside a list or
    wrapped again.
    """
    if len(statements) == 1:
        only = statements[0]
        if isinstance(only, CrowBehaviorOrderingSuite) and only.order.value == 'sequential':
            return only
    return CrowBehaviorOrderingSuite('sequential', tuple(statements), _skip_simplify=True)
[docs]
@inline_args
def generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[GoalPart, InPart, OutPart]):
    """Define a generator in the domain from its goal, ``in`` and ``out`` sections.

    Defaults apply when a section is absent:

    - the goal defaults to a null boolean expression;
    - the inputs default to an empty list;
    - the outputs default to ``None``.

    A later section of the same kind overrides an earlier one.

    Raises:
        ValueError: if a part has an unsupported type.
    """
    annotations = dict() if annotations is None else annotations
    args = ArgumentsDef(tuple()) if args is None else args
    if _PARSER_VERBOSE:
        print(f'generator_definition:: {name=} {args.arguments=} {annotations=}')

    inputs = list()
    outputs = None
    goal = E.NullExpression(BOOL)
    with self.expression_def_ctx.with_variables(*args.arguments):
        for part in parts:
            visited = self.expression_interpreter.visit(part.suite)
            if isinstance(part, GoalPart):
                goal = visited.get_derived_expression()
                if _PARSER_VERBOSE:
                    print(jacinle.indent_text(f'Goal: {goal}'))
            elif isinstance(part, InPart):
                inputs = [E.VariableExpression(variable) for variable in visited.items]
                if _PARSER_VERBOSE:
                    print(jacinle.indent_text(f'Inputs: {inputs}'))
            elif isinstance(part, OutPart):
                outputs = [E.VariableExpression(variable) for variable in visited.items]
                if _PARSER_VERBOSE:
                    print(jacinle.indent_text(f'Outputs: {outputs}'))
            else:
                raise ValueError(f'Invalid part: {part}')
    self.domain.define_generator(name, args.arguments, goal, inputs, outputs, **annotations)
[docs]
@inline_args
def generator_goal_definition(self, suite: Tree) -> GoalPart:
    """Wrap an unvisited generator goal section; it is processed later by :meth:`generator_definition`."""
    return GoalPart(suite=suite)
[docs]
@inline_args
def generator_in_definition(self, values: Tree) -> InPart:
    """Wrap an unvisited generator ``in`` section; it is processed later by :meth:`generator_definition`."""
    return InPart(suite=values)
[docs]
@inline_args
def generator_out_definition(self, values: Tree) -> OutPart:
    """Wrap an unvisited generator ``out`` section; it is processed later by :meth:`generator_definition`."""
    return OutPart(suite=values)