#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : python_function.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/17/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import itertools
from dataclasses import dataclass
from typing import Any, Optional, Union, Iterator, Sequence, Tuple, Dict, Callable, TYPE_CHECKING
import numpy as np
import torch
from concepts.dsl.dsl_functions import Function
from concepts.dsl.dsl_types import BOOL, PyObjValueType, ObjectType, UniformSequenceType, TupleType, QINDEX, TensorValueTypeBase
from concepts.dsl.executors.tensor_value_executor import TensorValueExecutorReturnType
from concepts.dsl.expression import Expression
from concepts.dsl.tensor_value import TensorValue, TensorizedPyObjValues
from concepts.dsl.tensor_state import StateObjectReference
from concepts.dsl.tensor_value_utils import expand_argument_values
from concepts.dsl.value import ListValue
from concepts.dm.crow.crow_domain import CrowState
if TYPE_CHECKING:
from concepts.dm.crow.executors.crow_executor import CrowExecutor
__all__ = ['CrowPythonFunctionRef', 'CrowPythonFunctionCrossRef', 'config_function_implementation', 'CrowSGC']
class CrowPythonFunctionRef(object):
"""A reference to a Python function.
This class is used to wrap external function implementations in domains.
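
    Example (an illustrative sketch; the predicate ``is_small`` is hypothetical)::

        def is_small(x: float) -> bool:
            return x < 0.5

        ref = CrowPythonFunctionRef(is_small, support_batch=False, use_object_names=True)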
"""
def __init__(
self, function: Callable, function_quantized: Optional[Callable] = None,
*,
return_type: Optional[Union[TensorValueTypeBase, PyObjValueType, Tuple[Union[TensorValueTypeBase, PyObjValueType], ...]]] = None,
support_batch: bool = False, auto_broadcast: bool = False,
use_object_names: bool = True, unwrap_values: Optional[bool] = None,
include_executor_args: bool = False, is_iterator: bool = False, is_sgc_function: bool = False,
executor: Optional['CrowExecutor'] = None
):
"""Initialize a Python function reference.
Args:
function: the function to be wrapped.
            function_quantized: the quantized version of the function (can be None).
            return_type: the return type of the function, used to wrap raw return values; can be None if the function already returns :class:`TensorValue` objects.
support_batch: whether the function supports batched inputs.
auto_broadcast: whether the executor should automatically broadcast the arguments before calling the function.
use_object_names: whether the executor should use the object names in the state (instead of the index).
unwrap_values: whether the executor should unwrap the tensor values before calling the function.
            include_executor_args: whether the executor should be passed to the function as the first argument.
is_iterator: whether the function is an iterator.
is_sgc_function: whether the function is an SGC function (state-goal-constraints).
executor: the executor that is using this function reference.
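
        Example (an illustrative sketch of a batched implementation; ``close_to`` and its threshold are hypothetical)::

            def close_to(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                # With support_batch=True and auto_broadcast=True, x and y arrive broadcast to a common batch shape.
                return (x - y).abs().sum(dim=-1) < 0.1

            ref = CrowPythonFunctionRef(close_to, support_batch=True, auto_broadcast=True)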
"""
self.function = function
self.function_quantized = function_quantized
self.return_type = return_type
self.support_batch = support_batch
self.auto_broadcast = auto_broadcast
self.use_object_names = use_object_names
if unwrap_values is None:
unwrap_values = not support_batch or auto_broadcast
self.unwrap_values = unwrap_values
self.include_executor_args = include_executor_args
self.is_iterator = is_iterator
self.is_sgc_function = is_sgc_function
self._executor = executor
function: Callable
"""The internal implementation of the function."""
function_quantized: Optional[Callable]
"""The quantized version of the function (can be None)."""
return_type: Optional[Union[TensorValueTypeBase, PyObjValueType, Tuple[Union[TensorValueTypeBase, PyObjValueType], ...]]]
"""The return type of the function."""
support_batch: bool
"""Whether the function supports batched inputs."""
auto_broadcast: bool
"""Whether the executor should automatically broadcast the arguments before calling the function."""
use_object_names: bool
"""Whether the executor should use the object names in the state (instead of the index)."""
unwrap_values: bool
"""Whether the executor should unwrap the tensor values before calling the function."""
include_executor_args: bool
"""Whether the caller should include the executor as the first argument."""
is_iterator: bool
"""Whether the function is an iterator."""
is_sgc_function: bool
"""Whether the function is an SGC function (state-goal-constraints)."""
def set_executor(self, executor: 'CrowExecutor') -> 'CrowPythonFunctionRef':
"""Set the executor that is using this function reference.
Args:
executor: the executor that is using this function reference.
Returns:
the function reference itself.
"""
self._executor = executor
return self
def __str__(self) -> str:
return (
            'CrowPythonFunctionRef('
f'{self.function}, support_batch={self.support_batch}, auto_broadcast={self.auto_broadcast}, '
f'use_object_names={self.use_object_names}, unwrap_values={self.unwrap_values}, include_executor_args={self.include_executor_args}, '
            f'is_iterator={self.is_iterator}, is_sgc_function={self.is_sgc_function})'
)
def __repr__(self) -> str:
return self.__str__()
@property
def flags(self) -> Dict[str, bool]:
return {
'support_batch': self.support_batch,
'auto_broadcast': self.auto_broadcast,
'use_object_names': self.use_object_names,
'unwrap_values': self.unwrap_values,
'include_executor_args': self.include_executor_args,
'is_iterator': self.is_iterator,
'is_sgc_function': self.is_sgc_function
}
def forward(
self, argument_values: Sequence[TensorValueExecutorReturnType], return_type: Optional[Union[TensorValueTypeBase, PyObjValueType]] = None,
additional_parameters: Optional[Sequence[Any]] = None, auto_broadcast: bool = True, wrap_rv: bool = True,
function_def: Optional[Function] = None
) -> Union[TensorValue, Tuple[TensorValue, ...]]:
"""Call the function.
Args:
argument_values: the arguments to the function.
return_type: the type of the return value.
additional_parameters: the additional parameters to the function.
auto_broadcast: whether the executor should automatically broadcast the arguments before calling the function.
wrap_rv: whether the executor should wrap the return value.
function_def: the function definition, used to wrap return values and to handle QINDEX.
Returns:
the result of the function.
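
        Example (a minimal sketch; assumes ``ref`` wraps a two-argument Boolean predicate and ``v1``, ``v2`` are :class:`TensorValue` arguments)::

            rv = ref.forward([v1, v2], return_type=BOOL)  # rv is wrapped as a TensorValue of dtype BOOL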
"""
function = self.function
if self.unwrap_values:
if self.use_object_names:
argument_values = [v.name if isinstance(v, StateObjectReference) else v for v in argument_values]
else:
argument_values = [v.index if isinstance(v, StateObjectReference) else v for v in argument_values]
if self.support_batch:
if self.auto_broadcast and auto_broadcast:
argument_values = expand_argument_values(argument_values)
argument_values_flat = list(argument_values)
if self.unwrap_values:
argument_values_flat = [v.tensor if isinstance(v, TensorValue) else v for v in argument_values_flat]
if not self.support_batch:
argument_values_flat = [v.item() if isinstance(v, TensorizedPyObjValues) else v for v in argument_values_flat]
if additional_parameters is not None:
additional_parameters = list(additional_parameters)
else:
additional_parameters = []
if self.is_sgc_function:
assert self._executor is not None, 'Executor is None.'
additional_parameters.insert(0, self._executor.sgc)
if self.include_executor_args:
assert self._executor is not None, 'Executor is None.'
additional_parameters.insert(0, self._executor)
argument_values_flat = additional_parameters + argument_values_flat
if QINDEX in argument_values_flat:
if not self.support_batch:
if function_def is None:
raise RuntimeError('For functions with QINDEX, the function_def argument must be provided.')
rv = self.forward_internal_autobatch(function_def, function, argument_values_flat)
else:
rv = function(*argument_values_flat)
else:
rv = function(*argument_values_flat)
if not wrap_rv:
return rv
return self._wrap_rv(rv, return_type, argument_values, auto_broadcast)
def forward_internal_autobatch(self, function_def, function, argument_values_flat):
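        """Evaluate a non-batched function for all objects bound to ``QINDEX`` arguments.

        For every argument equal to ``QINDEX``, this enumerates all objects of the corresponding type in the
        current state, calls the underlying function once per combination, and packs the Boolean results into a
        single :class:`TensorValue` whose batch variables correspond to the ``QINDEX`` arguments.
        """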
options_per_argument = list()
output_dims = list()
batch_variables = list()
for i, arg in enumerate(argument_values_flat):
if arg is QINDEX:
objects = self._executor.state.object_type2name[function_def.ftype.argument_types[i].typename]
if self.use_object_names:
options_per_argument.append(objects)
else:
options_per_argument.append(list(range(len(objects))))
output_dims.append(len(objects))
batch_variables.append(function_def.ftype.argument_names[i])
else:
options_per_argument.append([arg])
rtype = function_def.ftype.return_type
if rtype != BOOL:
raise TypeError('Only BOOL is supported for auto-batch functions with QINDEX.')
output = torch.zeros(np.prod(output_dims), dtype=torch.bool)
options = list(itertools.product(*options_per_argument))
for i, option in enumerate(options):
rv = function(*option)
if isinstance(rv, TensorValue):
rv = rv.item()
if isinstance(rv, torch.Tensor):
if output.device != rv.device:
                        output = output.to(rv.device)
output[i] = rv
output = TensorValue.from_tensor(output.reshape(output_dims), rtype, batch_variables=batch_variables)
return output
def forward_generator(
self, argument_values: Sequence[TensorValueExecutorReturnType], return_type: Optional[Union[TensorValueTypeBase, PyObjValueType]] = None,
auto_broadcast: bool = True, wrap_rv: bool = True
) -> Union[Iterator[TensorValue], Iterator[Tuple[TensorValue, ...]]]:
"""Call the function and return a generator.
Args:
argument_values: the arguments to the function.
return_type: the type of the return value.
auto_broadcast: whether the executor should automatically broadcast the arguments before calling the function.
            wrap_rv: whether the executor should wrap the return value.
"""
generator = self.forward(argument_values, return_type=return_type, auto_broadcast=auto_broadcast, wrap_rv=False)
if not wrap_rv:
yield from generator
else:
for v in generator:
yield self._wrap_rv(v, return_type, argument_values, auto_broadcast)
def forward_sgc_function(
self, state: CrowState, goal: Expression, constraints: Sequence[Expression], additional_arguments: Sequence[TensorValueExecutorReturnType],
return_type: Optional[Union[TensorValueTypeBase, PyObjValueType]] = None,
auto_broadcast: bool = True, wrap_rv: bool = True
):
        """Call an SGC (state-goal-constraints) function.
Args:
state: the current state.
goal: the goal expression.
constraints: the constraints, as a list of expressions.
additional_arguments: the additional arguments.
return_type: the type of the return value.
auto_broadcast: whether the executor should automatically broadcast the arguments before calling the function.
wrap_rv: whether the executor should wrap the return value.
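
        Example (a minimal sketch; assumes ``ref`` wraps an SGC function that takes ``(state, goal, constraints)`` plus one extra argument ``v``)::

            rv = ref.forward_sgc_function(state, goal, constraints, [v])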
"""
return self.forward(additional_arguments, return_type=return_type, additional_parameters=(state, goal, constraints), auto_broadcast=auto_broadcast, wrap_rv=wrap_rv)
def __call__(self, *args, return_type: Optional[Union[TensorValueTypeBase, PyObjValueType]] = None, auto_broadcast: bool = True, wrap_rv: bool = True) -> Union[TensorValue, Tuple[TensorValue, ...]]:
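        """Call the wrapped function with positional arguments; equivalent to :meth:`forward` with ``args`` collected into a list."""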
assert not self.is_iterator, 'Use iter_from to call an iterator function.'
return self.forward(args, return_type=return_type, auto_broadcast=auto_broadcast, wrap_rv=wrap_rv)
def iter_from(self, *args, return_type: Optional[Union[TensorValueTypeBase, PyObjValueType]] = None, auto_broadcast: bool = True, wrap_rv: bool = True) -> Union[Iterator[TensorValue], Iterator[Tuple[TensorValue, ...]]]:
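        """Call an iterator (generator-style) function and yield its results, optionally wrapped as :class:`TensorValue` objects.

        Example (a minimal sketch; assumes ``gen_ref`` was constructed with ``is_iterator=True``)::

            for value in gen_ref.iter_from(v1, return_type=BOOL):
                ...
        """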
assert self.is_iterator, 'Use __call__ to call a non-iterator function.'
return self.forward_generator(args, return_type=return_type, auto_broadcast=auto_broadcast, wrap_rv=wrap_rv)
def _wrap_rv(self, rv, return_type, argument_values, auto_broadcast):
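        """Wrap a raw return value (or a tuple of raw values) into :class:`TensorValue` objects according to the return type."""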
if isinstance(rv, (TensorValue, ListValue)):
return rv
elif isinstance(rv, tuple) and all(isinstance(v, (TensorValue, ListValue)) for v in rv):
return rv
if return_type is None:
return_type = self.return_type
if return_type is None:
            raise RuntimeError('Return type cannot be None if the function does not return a TensorValue.')
if isinstance(return_type, TupleType):
if not isinstance(rv, tuple) and len(return_type) == 1:
rv = (rv, )
return tuple(self._wrap_single_rv(v, t, argument_values, auto_broadcast) for v, t in zip(rv, return_type.element_types))
else:
return self._wrap_single_rv(rv, return_type, argument_values, auto_broadcast)
def _wrap_single_rv(self, rv, return_type, argument_values, auto_broadcast):
if isinstance(rv, (TensorValue, ListValue)):
return rv
# TODO(Jiayuan Mao @ 2023/11/18): have an actual type check.
if return_type.alias is not None and return_type.alias.startswith('__') and return_type.alias.endswith('__'):
return rv
if not self.support_batch:
if isinstance(return_type, PyObjValueType):
if isinstance(rv, TensorizedPyObjValues):
return TensorValue.from_tensorized_pyobj(rv, return_type)
return TensorValue.from_scalar(rv, return_type)
elif isinstance(return_type, ObjectType):
if isinstance(rv, str):
return self._executor.state.get_state_object_reference(return_type, name=rv)
elif isinstance(rv, int):
return self._executor.state.get_state_object_reference(return_type, index=rv)
else:
return rv
elif isinstance(return_type, UniformSequenceType) and isinstance(return_type.element_type, ObjectType):
if isinstance(rv, (list, tuple)):
if len(rv) == 0:
return self._executor.state.get_state_object_list(return_type.element_type, [])
else:
if isinstance(rv[0], str):
return self._executor.state.get_state_object_list(return_type.element_type, names=rv)
elif isinstance(rv[0], int):
return self._executor.state.get_state_object_list(return_type.element_type, indices=rv)
else:
raise ValueError(f'Unsupported return type: {rv}')
else:
return rv
else:
if isinstance(rv, torch.Tensor):
return TensorValue.from_tensor(rv, return_type)
elif isinstance(rv, (bool, int, float)):
return TensorValue.from_scalar(rv, return_type)
else:
raise ValueError(f'Unsupported return type: {type(rv)}')
else:
if isinstance(return_type, PyObjValueType):
raise TypeError('Cannot return a PyObjValueType for a batched function.')
else:
if isinstance(rv, torch.Tensor):
first_tensor_arg = None
for arg in argument_values:
if isinstance(arg, TensorValue):
first_tensor_arg = arg
break
if not self.auto_broadcast or not auto_broadcast or first_tensor_arg is None:
raise ValueError('Cannot return a raw PyTorch tensor for a batched function without auto_broadcast.')
return TensorValue.from_tensor(rv, return_type, batch_variables=first_tensor_arg.batch_variables, batch_dims=first_tensor_arg.batch_dims)
else:
raise ValueError(f'Unsupported return type: {type(rv)}')
class CrowPythonFunctionCrossRef(object):
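    """A reference to another function implementation by name (a cross-reference to be resolved later)."""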
def __init__(self, cross_ref_name: str):
self.cross_ref_name = cross_ref_name
def config_function_implementation(
function: Optional[Callable] = None, *, function_quantized: Optional[Callable] = None,
support_batch: bool = False, auto_broadcast: bool = True, use_object_names: bool = True, unwrap_values: Optional[bool] = None, include_executor_args: bool = False,
is_iterator: bool = False, is_sgc_function: bool = False
) -> Callable:
"""Configure the implementation of a function in a domain.
Args:
function: the function to be wrapped.
function_quantized: the quantized version of the function (can be None).
support_batch: whether the function supports batched inputs.
auto_broadcast: whether the executor should automatically broadcast the arguments before calling the function.
use_object_names: whether the executor should use object names instead of indices.
unwrap_values: whether the executor should unwrap the values before calling the function.
        include_executor_args: whether the executor should be passed to the function as the first argument.
is_iterator: whether the function is an iterator.
is_sgc_function: whether the function is an SGC function.
Returns:
the decorator.
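
    Example (an illustrative sketch; ``on_table`` is a hypothetical external predicate)::

        @config_function_implementation(use_object_names=True)
        def on_table(x: str) -> bool:
            return x == 'table'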
"""
function_implementation_configs = {
'function_quantized': function_quantized,
'support_batch': support_batch,
'auto_broadcast': auto_broadcast,
'use_object_names': use_object_names,
'unwrap_values': unwrap_values,
'include_executor_args': include_executor_args,
'is_iterator': is_iterator,
'is_sgc_function': is_sgc_function
}
def wrapper(function: Callable, configs=function_implementation_configs):
return CrowPythonFunctionRef(function, **configs)
if function is None:
return wrapper
return wrapper(function)
def _check_no_quantized_arguments(arguments):
"""A helper function to check that there are no quantized arguments. This function handles the migration from quantized tensor CSP computation to non-quantized tensor CSP computation."""
# TODO(Jiayuan Mao @ 2023/08/15): remove this after the migration.
for arg in arguments:
if isinstance(arg, TensorValue):
if isinstance(arg.dtype, TensorValueTypeBase):
if arg.quantized and not arg.dtype.is_intrinsically_quantized():
raise RuntimeError('Quantized arguments are not supported.')
@dataclass
class CrowSGC(object):
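    """A simple container for the (state, goal, constraints) triple that is passed to SGC functions."""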
state: CrowState
goal: Expression
constraints: Sequence[Expression]